1extern crate percent_encoding;
2extern crate url;
3
4use self::percent_encoding::percent_decode;
5use self::url::{Host, Url};
6use std::collections::HashMap;
7use std::ffi::{CStr, CString};
8
9use crate::result::{ConnectionError, ConnectionResult};
10
11use mysqlclient_sys::mysql_ssl_mode;
12
bitflags::bitflags! {
    /// Set of `CLIENT_*` bit flags, stored in a `u32`, that is carried in
    /// `ConnectionOptions::client_flags`.
    ///
    /// Each constant occupies a single distinct bit, from `0x00000001` up to
    /// `0x01000000`, so flags can be combined freely.
    ///
    // NOTE(review): the names and values look like the MySQL client
    // capability flags — confirm against the MySQL protocol documentation
    // before relying on the meaning of an individual flag.
    #[derive(Clone, Copy)]
    pub struct CapabilityFlags: u32 {
        const CLIENT_LONG_PASSWORD = 0x00000001;
        // The only flag set by `ConnectionOptions::parse` in this file.
        const CLIENT_FOUND_ROWS = 0x00000002;
        const CLIENT_LONG_FLAG = 0x00000004;
        const CLIENT_CONNECT_WITH_DB = 0x00000008;
        const CLIENT_NO_SCHEMA = 0x00000010;
        const CLIENT_COMPRESS = 0x00000020;
        const CLIENT_ODBC = 0x00000040;
        const CLIENT_LOCAL_FILES = 0x00000080;
        const CLIENT_IGNORE_SPACE = 0x00000100;
        const CLIENT_PROTOCOL_41 = 0x00000200;
        const CLIENT_INTERACTIVE = 0x00000400;
        const CLIENT_SSL = 0x00000800;
        const CLIENT_IGNORE_SIGPIPE = 0x00001000;
        const CLIENT_TRANSACTIONS = 0x00002000;
        const CLIENT_RESERVED = 0x00004000;
        const CLIENT_SECURE_CONNECTION = 0x00008000;
        const CLIENT_MULTI_STATEMENTS = 0x00010000;
        const CLIENT_MULTI_RESULTS = 0x00020000;
        const CLIENT_PS_MULTI_RESULTS = 0x00040000;
        const CLIENT_PLUGIN_AUTH = 0x00080000;
        const CLIENT_CONNECT_ATTRS = 0x00100000;
        const CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA = 0x00200000;
        const CLIENT_CAN_HANDLE_EXPIRED_PASSWORDS = 0x00400000;
        const CLIENT_SESSION_TRACK = 0x00800000;
        const CLIENT_DEPRECATE_EOF = 0x01000000;
    }
}
43
/// Connection settings extracted from a `mysql://` URL by
/// `ConnectionOptions::parse`.
///
/// All textual values are stored as `CString`s and exposed through the
/// accessor methods as `&CStr`.
pub(super) struct ConnectionOptions {
    // Host part of the URL. `None` when the URL has no host, or when the
    // host is `localhost` and a `unix_socket` parameter was supplied.
    host: Option<CString>,
    // Percent-decoded user name from the URL's userinfo (may be empty).
    user: CString,
    // Percent-decoded password from the URL's userinfo, if one was given.
    password: Option<CString>,
    // First (and only permitted) path segment of the URL, if non-empty.
    database: Option<CString>,
    // Explicit port from the URL, if any.
    port: Option<u16>,
    // Value of the `unix_socket` query parameter, if present.
    unix_socket: Option<CString>,
    // Client flags; `parse` always sets this to `CLIENT_FOUND_ROWS`.
    client_flags: CapabilityFlags,
    // Value of the `ssl_mode` query parameter mapped to `mysql_ssl_mode`.
    ssl_mode: Option<mysql_ssl_mode>,
    // Value of the `ssl_ca` query parameter, if present.
    ssl_ca: Option<CString>,
    // Value of the `ssl_cert` query parameter, if present.
    ssl_cert: Option<CString>,
    // Value of the `ssl_key` query parameter, if present.
    ssl_key: Option<CString>,
}
57
58impl ConnectionOptions {
59 pub(super) fn parse(database_url: &str) -> ConnectionResult<Self> {
60 let url = match Url::parse(database_url) {
61 Ok(url) => url,
62 Err(_) => return Err(connection_url_error()),
63 };
64
65 if url.scheme() != "mysql" {
66 return Err(connection_url_error());
67 }
68
69 if url.path_segments().map(Iterator::count).unwrap_or(0) > 1 {
70 return Err(connection_url_error());
71 }
72
73 let query_pairs = url.query_pairs().into_owned().collect::<HashMap<_, _>>();
74 if query_pairs.contains_key("database") {
75 return Err(connection_url_error());
76 }
77
78 let unix_socket = match query_pairs.get("unix_socket") {
79 Some(v) => Some(CString::new(v.as_bytes())?),
80 _ => None,
81 };
82
83 let ssl_ca = match query_pairs.get("ssl_ca") {
84 Some(v) => Some(CString::new(v.as_bytes())?),
85 _ => None,
86 };
87
88 let ssl_cert = match query_pairs.get("ssl_cert") {
89 Some(v) => Some(CString::new(v.as_bytes())?),
90 _ => None,
91 };
92
93 let ssl_key = match query_pairs.get("ssl_key") {
94 Some(v) => Some(CString::new(v.as_bytes())?),
95 _ => None,
96 };
97
98 let ssl_mode = match query_pairs.get("ssl_mode") {
99 Some(v) => {
100 let ssl_mode = match v.to_lowercase().as_str() {
101 "disabled" => mysql_ssl_mode::SSL_MODE_DISABLED,
102 "preferred" => mysql_ssl_mode::SSL_MODE_PREFERRED,
103 "required" => mysql_ssl_mode::SSL_MODE_REQUIRED,
104 "verify_ca" => mysql_ssl_mode::SSL_MODE_VERIFY_CA,
105 "verify_identity" => mysql_ssl_mode::SSL_MODE_VERIFY_IDENTITY,
106 _ => {
107 let msg = "unknown ssl_mode";
108 return Err(ConnectionError::InvalidConnectionUrl(msg.into()));
109 }
110 };
111 Some(ssl_mode)
112 }
113 _ => None,
114 };
115
116 let host = match url.host() {
117 Some(Host::Ipv6(host)) => Some(CString::new(host.to_string())?),
118 Some(host) if host.to_string() == "localhost" && unix_socket.is_some() => None,
119 Some(host) => Some(CString::new(host.to_string())?),
120 None => None,
121 };
122 let user = decode_into_cstring(url.username())?;
123 let password = match url.password() {
124 Some(password) => Some(decode_into_cstring(password)?),
125 None => None,
126 };
127
128 let database = match url.path_segments().and_then(|mut iter| iter.next()) {
129 Some("") | None => None,
130 Some(segment) => Some(CString::new(segment.as_bytes())?),
131 };
132
133 let client_flags = CapabilityFlags::CLIENT_FOUND_ROWS;
135
136 Ok(ConnectionOptions {
137 host,
138 user,
139 password,
140 database,
141 port: url.port(),
142 client_flags,
143 ssl_mode,
144 unix_socket,
145 ssl_ca,
146 ssl_cert,
147 ssl_key,
148 })
149 }
150
151 pub(super) fn host(&self) -> Option<&CStr> {
152 self.host.as_deref()
153 }
154
155 pub(super) fn user(&self) -> &CStr {
156 &self.user
157 }
158
159 pub(super) fn password(&self) -> Option<&CStr> {
160 self.password.as_deref()
161 }
162
163 pub(super) fn database(&self) -> Option<&CStr> {
164 self.database.as_deref()
165 }
166
167 pub(super) fn port(&self) -> Option<u16> {
168 self.port
169 }
170
171 pub(super) fn unix_socket(&self) -> Option<&CStr> {
172 self.unix_socket.as_deref()
173 }
174
175 pub(super) fn ssl_ca(&self) -> Option<&CStr> {
176 self.ssl_ca.as_deref()
177 }
178
179 pub(super) fn ssl_cert(&self) -> Option<&CStr> {
180 self.ssl_cert.as_deref()
181 }
182
183 pub(super) fn ssl_key(&self) -> Option<&CStr> {
184 self.ssl_key.as_deref()
185 }
186
187 pub(super) fn client_flags(&self) -> CapabilityFlags {
188 self.client_flags
189 }
190
191 pub(super) fn ssl_mode(&self) -> Option<mysql_ssl_mode> {
192 self.ssl_mode
193 }
194}
195
196fn decode_into_cstring(s: &str) -> ConnectionResult<CString> {
197 let decoded = percent_decode(s.as_bytes())
198 .decode_utf8()
199 .map_err(|_| connection_url_error())?;
200 CString::new(decoded.as_bytes()).map_err(Into::into)
201}
202
203fn connection_url_error() -> ConnectionError {
204 let msg = "MySQL connection URLs must be in the form \
205 `mysql://[[user]:[password]@]host[:port][/database][?unix_socket=socket-path]`";
206 ConnectionError::InvalidConnectionUrl(msg.into())
207}
208
#[cfg(test)]
mod tests {
    use super::*;

    /// Only the `mysql` scheme is accepted, and `database` may not be passed
    /// as a query parameter.
    #[diesel_test_helper::test]
    fn urls_with_schemes_other_than_mysql_are_errors() {
        assert!(ConnectionOptions::parse("postgres://localhost").is_err());
        assert!(ConnectionOptions::parse("http://localhost").is_err());
        assert!(ConnectionOptions::parse("file:///tmp/mysql.sock").is_err());
        assert!(ConnectionOptions::parse("socket:///tmp/mysql.sock").is_err());
        assert!(ConnectionOptions::parse("mysql://localhost?database=somedb").is_err());
        assert!(ConnectionOptions::parse("mysql://localhost").is_ok());
    }

    /// More than one path segment is rejected.
    #[diesel_test_helper::test]
    fn urls_must_have_zero_or_one_path_segments() {
        assert!(ConnectionOptions::parse("mysql://localhost/foo/bar").is_err());
        assert!(ConnectionOptions::parse("mysql://localhost/foo").is_ok());
    }

    /// The single path segment becomes the database name; no path means no
    /// database.
    #[diesel_test_helper::test]
    fn first_path_segment_is_treated_as_database() {
        let foo_cstr = CString::new("foo").unwrap();
        let bar_cstr = CString::new("bar").unwrap();
        assert_eq!(
            Some(&*foo_cstr),
            ConnectionOptions::parse("mysql://localhost/foo")
                .unwrap()
                .database()
        );
        assert_eq!(
            Some(&*bar_cstr),
            ConnectionOptions::parse("mysql://localhost/bar")
                .unwrap()
                .database()
        );
        assert_eq!(
            None,
            ConnectionOptions::parse("mysql://localhost")
                .unwrap()
                .database()
        );
    }

    /// User name and password containing reserved characters survive a
    /// percent-encode / parse round trip.
    #[diesel_test_helper::test]
    fn userinfo_should_be_percent_decode() {
        use self::percent_encoding::{utf8_percent_encode, AsciiSet, CONTROLS};
        const USERINFO_ENCODE_SET: &AsciiSet = &CONTROLS
            .add(b' ')
            .add(b'"')
            .add(b'<')
            .add(b'>')
            .add(b'`')
            .add(b'#')
            .add(b'?')
            .add(b'{')
            .add(b'}')
            .add(b'/')
            .add(b':')
            .add(b';')
            .add(b'=')
            .add(b'@')
            .add(b'[')
            .add(b'\\')
            .add(b']')
            .add(b'^')
            .add(b'|');

        let username = "x#gfuL?4Zuj{n73m}eeJt0";
        let encoded_username = utf8_percent_encode(username, USERINFO_ENCODE_SET);

        let password = "x/gfuL?4Zuj{n73m}eeJt1";
        let encoded_password = utf8_percent_encode(password, USERINFO_ENCODE_SET);

        let db_url = format!("mysql://{encoded_username}:{encoded_password}@localhost/bar");
        let db_url = Url::parse(&db_url).unwrap();

        let conn_opts = ConnectionOptions::parse(db_url.as_str()).unwrap();
        let username = CString::new(username.as_bytes()).unwrap();
        let password = CString::new(password.as_bytes()).unwrap();
        assert_eq!(username, conn_opts.user);
        assert_eq!(password, conn_opts.password.unwrap());
    }

    /// IPv6 hosts are stored as the bare address, without the URL brackets.
    #[diesel_test_helper::test]
    fn ipv6_host_not_wrapped_in_brackets() {
        let host1 = CString::new("::1").unwrap();
        let host2 = CString::new("2001:db8:85a3::8a2e:370:7334").unwrap();

        assert_eq!(
            Some(&*host1),
            ConnectionOptions::parse("mysql://[::1]").unwrap().host()
        );
        assert_eq!(
            Some(&*host2),
            ConnectionOptions::parse("mysql://[2001:db8:85a3::8a2e:370:7334]")
                .unwrap()
                .host()
        );
    }

    /// `localhost` plus `unix_socket` yields no host and the given socket path.
    #[diesel_test_helper::test]
    fn unix_socket_tests() {
        let unix_socket = "/var/run/mysqld.sock";
        let username = "foo";
        let password = "bar";
        let db_url = format!("mysql://{username}:{password}@localhost?unix_socket={unix_socket}");
        let conn_opts = ConnectionOptions::parse(db_url.as_str()).unwrap();
        let cstring = |s| CString::new(s).unwrap();
        assert_eq!(None, conn_opts.host);
        assert_eq!(None, conn_opts.port);
        assert_eq!(cstring(username), conn_opts.user);
        assert_eq!(cstring(password), conn_opts.password.unwrap());
        assert_eq!(
            CString::new(unix_socket).unwrap(),
            conn_opts.unix_socket.unwrap()
        );
    }

    /// `ssl_ca` is parsed both with a TCP host and alongside `unix_socket`.
    #[diesel_test_helper::test]
    fn ssl_ca_tests() {
        let ssl_ca = "/etc/ssl/certs/ca-certificates.crt";
        let username = "foo";
        let password = "bar";
        let db_url = format!("mysql://{username}:{password}@localhost?ssl_ca={ssl_ca}");
        let conn_opts = ConnectionOptions::parse(db_url.as_str()).unwrap();
        let cstring = |s| CString::new(s).unwrap();
        assert_eq!(Some(cstring("localhost")), conn_opts.host);
        assert_eq!(None, conn_opts.port);
        assert_eq!(cstring(username), conn_opts.user);
        assert_eq!(cstring(password), conn_opts.password.unwrap());
        assert_eq!(CString::new(ssl_ca).unwrap(), conn_opts.ssl_ca.unwrap());

        let url_with_unix_str_and_ssl_ca = format!(
            "mysql://{username}:{password}@localhost?unix_socket=/var/run/mysqld.sock&ssl_ca={ssl_ca}"
        );

        let conn_opts2 = ConnectionOptions::parse(url_with_unix_str_and_ssl_ca.as_str()).unwrap();
        assert_eq!(None, conn_opts2.host);
        assert_eq!(None, conn_opts2.port);
        assert_eq!(CString::new(ssl_ca).unwrap(), conn_opts2.ssl_ca.unwrap());
    }

    /// `ssl_cert` is parsed both with a TCP host and alongside `unix_socket`.
    #[diesel_test_helper::test]
    fn ssl_cert_tests() {
        let ssl_cert = "/etc/ssl/certs/client-cert.crt";
        let username = "foo";
        let password = "bar";
        let db_url = format!("mysql://{username}:{password}@localhost?ssl_cert={ssl_cert}");
        let conn_opts = ConnectionOptions::parse(db_url.as_str()).unwrap();
        let cstring = |s| CString::new(s).unwrap();
        assert_eq!(Some(cstring("localhost")), conn_opts.host);
        assert_eq!(None, conn_opts.port);
        assert_eq!(cstring(username), conn_opts.user);
        assert_eq!(cstring(password), conn_opts.password.unwrap());
        assert_eq!(CString::new(ssl_cert).unwrap(), conn_opts.ssl_cert.unwrap());

        let url_with_unix_str_and_ssl_cert = format!(
            "mysql://{username}:{password}@localhost?unix_socket=/var/run/mysqld.sock&ssl_cert={ssl_cert}"
        );

        let conn_opts2 = ConnectionOptions::parse(url_with_unix_str_and_ssl_cert.as_str()).unwrap();
        assert_eq!(None, conn_opts2.host);
        assert_eq!(None, conn_opts2.port);
        assert_eq!(
            CString::new(ssl_cert).unwrap(),
            conn_opts2.ssl_cert.unwrap()
        );
    }

    /// `ssl_key` is parsed both with a TCP host and alongside `unix_socket`.
    #[diesel_test_helper::test]
    fn ssl_key_tests() {
        let ssl_key = "/etc/ssl/certs/client-key.crt";
        let username = "foo";
        let password = "bar";
        let db_url = format!("mysql://{username}:{password}@localhost?ssl_key={ssl_key}");
        let conn_opts = ConnectionOptions::parse(db_url.as_str()).unwrap();
        let cstring = |s| CString::new(s).unwrap();
        assert_eq!(Some(cstring("localhost")), conn_opts.host);
        assert_eq!(None, conn_opts.port);
        assert_eq!(cstring(username), conn_opts.user);
        assert_eq!(cstring(password), conn_opts.password.unwrap());
        assert_eq!(CString::new(ssl_key).unwrap(), conn_opts.ssl_key.unwrap());

        let url_with_unix_str_and_ssl_key = format!(
            "mysql://{username}:{password}@localhost?unix_socket=/var/run/mysqld.sock&ssl_key={ssl_key}"
        );

        let conn_opts2 = ConnectionOptions::parse(url_with_unix_str_and_ssl_key.as_str()).unwrap();
        assert_eq!(None, conn_opts2.host);
        assert_eq!(None, conn_opts2.port);
        assert_eq!(CString::new(ssl_key).unwrap(), conn_opts2.ssl_key.unwrap());
    }

    /// Every supported `ssl_mode` value maps to its enum variant regardless of
    /// letter case, and an unsupported value is rejected.
    #[diesel_test_helper::test]
    fn ssl_mode() {
        let ssl_mode = |url| ConnectionOptions::parse(url).unwrap().ssl_mode();
        assert_eq!(ssl_mode("mysql://localhost"), None);
        assert_eq!(
            ssl_mode("mysql://localhost?ssl_mode=disabled"),
            Some(mysql_ssl_mode::SSL_MODE_DISABLED)
        );
        assert_eq!(
            ssl_mode("mysql://localhost?ssl_mode=PREFERRED"),
            Some(mysql_ssl_mode::SSL_MODE_PREFERRED)
        );
        assert_eq!(
            ssl_mode("mysql://localhost?ssl_mode=required"),
            Some(mysql_ssl_mode::SSL_MODE_REQUIRED)
        );
        assert_eq!(
            ssl_mode("mysql://localhost?ssl_mode=VERIFY_CA"),
            Some(mysql_ssl_mode::SSL_MODE_VERIFY_CA)
        );
        assert_eq!(
            ssl_mode("mysql://localhost?ssl_mode=verify_identity"),
            Some(mysql_ssl_mode::SSL_MODE_VERIFY_IDENTITY)
        );

        // An unrecognised mode must be reported as an error, not ignored.
        assert!(ConnectionOptions::parse("mysql://localhost?ssl_mode=bogus").is_err());
    }
}