diesel/mysql/connection/url.rs

1extern crate percent_encoding;
2extern crate url;
3
4use self::percent_encoding::percent_decode;
5use self::url::{Host, Url};
6use std::collections::HashMap;
7use std::ffi::{CStr, CString};
8
9use crate::result::{ConnectionError, ConnectionResult};
10
11use mysqlclient_sys::mysql_ssl_mode;
12
13bitflags::bitflags! {
14    #[derive(Clone, Copy)]
15    pub struct CapabilityFlags: u32 {
16        const CLIENT_LONG_PASSWORD = 0x00000001;
17        const CLIENT_FOUND_ROWS = 0x00000002;
18        const CLIENT_LONG_FLAG = 0x00000004;
19        const CLIENT_CONNECT_WITH_DB = 0x00000008;
20        const CLIENT_NO_SCHEMA = 0x00000010;
21        const CLIENT_COMPRESS = 0x00000020;
22        const CLIENT_ODBC = 0x00000040;
23        const CLIENT_LOCAL_FILES = 0x00000080;
24        const CLIENT_IGNORE_SPACE = 0x00000100;
25        const CLIENT_PROTOCOL_41 = 0x00000200;
26        const CLIENT_INTERACTIVE = 0x00000400;
27        const CLIENT_SSL = 0x00000800;
28        const CLIENT_IGNORE_SIGPIPE = 0x00001000;
29        const CLIENT_TRANSACTIONS = 0x00002000;
30        const CLIENT_RESERVED = 0x00004000;
31        const CLIENT_SECURE_CONNECTION = 0x00008000;
32        const CLIENT_MULTI_STATEMENTS = 0x00010000;
33        const CLIENT_MULTI_RESULTS = 0x00020000;
34        const CLIENT_PS_MULTI_RESULTS = 0x00040000;
35        const CLIENT_PLUGIN_AUTH = 0x00080000;
36        const CLIENT_CONNECT_ATTRS = 0x00100000;
37        const CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA = 0x00200000;
38        const CLIENT_CAN_HANDLE_EXPIRED_PASSWORDS = 0x00400000;
39        const CLIENT_SESSION_TRACK = 0x00800000;
40        const CLIENT_DEPRECATE_EOF = 0x01000000;
41    }
42}
43
44pub(super) struct ConnectionOptions {
45    host: Option<CString>,
46    user: CString,
47    password: Option<CString>,
48    database: Option<CString>,
49    port: Option<u16>,
50    unix_socket: Option<CString>,
51    client_flags: CapabilityFlags,
52    ssl_mode: Option<mysql_ssl_mode>,
53    ssl_ca: Option<CString>,
54    ssl_cert: Option<CString>,
55    ssl_key: Option<CString>,
56}
57
58impl ConnectionOptions {
59    pub(super) fn parse(database_url: &str) -> ConnectionResult<Self> {
60        let url = match Url::parse(database_url) {
61            Ok(url) => url,
62            Err(_) => return Err(connection_url_error()),
63        };
64
65        if url.scheme() != "mysql" {
66            return Err(connection_url_error());
67        }
68
69        if url.path_segments().map(Iterator::count).unwrap_or(0) > 1 {
70            return Err(connection_url_error());
71        }
72
73        let query_pairs = url.query_pairs().into_owned().collect::<HashMap<_, _>>();
74        if query_pairs.contains_key("database") {
75            return Err(connection_url_error());
76        }
77
78        let unix_socket = match query_pairs.get("unix_socket") {
79            Some(v) => Some(CString::new(v.as_bytes())?),
80            _ => None,
81        };
82
83        let ssl_ca = match query_pairs.get("ssl_ca") {
84            Some(v) => Some(CString::new(v.as_bytes())?),
85            _ => None,
86        };
87
88        let ssl_cert = match query_pairs.get("ssl_cert") {
89            Some(v) => Some(CString::new(v.as_bytes())?),
90            _ => None,
91        };
92
93        let ssl_key = match query_pairs.get("ssl_key") {
94            Some(v) => Some(CString::new(v.as_bytes())?),
95            _ => None,
96        };
97
98        let ssl_mode = match query_pairs.get("ssl_mode") {
99            Some(v) => {
100                let ssl_mode = match v.to_lowercase().as_str() {
101                    "disabled" => mysql_ssl_mode::SSL_MODE_DISABLED,
102                    "preferred" => mysql_ssl_mode::SSL_MODE_PREFERRED,
103                    "required" => mysql_ssl_mode::SSL_MODE_REQUIRED,
104                    "verify_ca" => mysql_ssl_mode::SSL_MODE_VERIFY_CA,
105                    "verify_identity" => mysql_ssl_mode::SSL_MODE_VERIFY_IDENTITY,
106                    _ => {
107                        let msg = "unknown ssl_mode";
108                        return Err(ConnectionError::InvalidConnectionUrl(msg.into()));
109                    }
110                };
111                Some(ssl_mode)
112            }
113            _ => None,
114        };
115
116        let host = match url.host() {
117            Some(Host::Ipv6(host)) => Some(CString::new(host.to_string())?),
118            Some(host) if host.to_string() == "localhost" && unix_socket.is_some() => None,
119            Some(host) => Some(CString::new(host.to_string())?),
120            None => None,
121        };
122        let user = decode_into_cstring(url.username())?;
123        let password = match url.password() {
124            Some(password) => Some(decode_into_cstring(password)?),
125            None => None,
126        };
127
128        let database = match url.path_segments().and_then(|mut iter| iter.next()) {
129            Some("") | None => None,
130            Some(segment) => Some(CString::new(segment.as_bytes())?),
131        };
132
133        // this is not present in the database_url, using a default value
134        let client_flags = CapabilityFlags::CLIENT_FOUND_ROWS;
135
136        Ok(ConnectionOptions {
137            host,
138            user,
139            password,
140            database,
141            port: url.port(),
142            client_flags,
143            ssl_mode,
144            unix_socket,
145            ssl_ca,
146            ssl_cert,
147            ssl_key,
148        })
149    }
150
151    pub(super) fn host(&self) -> Option<&CStr> {
152        self.host.as_deref()
153    }
154
155    pub(super) fn user(&self) -> &CStr {
156        &self.user
157    }
158
159    pub(super) fn password(&self) -> Option<&CStr> {
160        self.password.as_deref()
161    }
162
163    pub(super) fn database(&self) -> Option<&CStr> {
164        self.database.as_deref()
165    }
166
167    pub(super) fn port(&self) -> Option<u16> {
168        self.port
169    }
170
171    pub(super) fn unix_socket(&self) -> Option<&CStr> {
172        self.unix_socket.as_deref()
173    }
174
175    pub(super) fn ssl_ca(&self) -> Option<&CStr> {
176        self.ssl_ca.as_deref()
177    }
178
179    pub(super) fn ssl_cert(&self) -> Option<&CStr> {
180        self.ssl_cert.as_deref()
181    }
182
183    pub(super) fn ssl_key(&self) -> Option<&CStr> {
184        self.ssl_key.as_deref()
185    }
186
187    pub(super) fn client_flags(&self) -> CapabilityFlags {
188        self.client_flags
189    }
190
191    pub(super) fn ssl_mode(&self) -> Option<mysql_ssl_mode> {
192        self.ssl_mode
193    }
194}
195
196fn decode_into_cstring(s: &str) -> ConnectionResult<CString> {
197    let decoded = percent_decode(s.as_bytes())
198        .decode_utf8()
199        .map_err(|_| connection_url_error())?;
200    CString::new(decoded.as_bytes()).map_err(Into::into)
201}
202
203fn connection_url_error() -> ConnectionError {
204    let msg = "MySQL connection URLs must be in the form \
205               `mysql://[[user]:[password]@]host[:port][/database][?unix_socket=socket-path]`";
206    ConnectionError::InvalidConnectionUrl(msg.into())
207}
208
#[test]
fn urls_with_schemes_other_than_mysql_are_errors() {
    // Foreign schemes, and a database passed as a query parameter, are rejected.
    let rejected = [
        "postgres://localhost",
        "http://localhost",
        "file:///tmp/mysql.sock",
        "socket:///tmp/mysql.sock",
        "mysql://localhost?database=somedb",
    ];
    for url in rejected.iter().copied() {
        assert!(ConnectionOptions::parse(url).is_err());
    }
    assert!(ConnectionOptions::parse("mysql://localhost").is_ok());
}
218
#[test]
fn urls_must_have_zero_or_one_path_segments() {
    let parses = |url: &str| ConnectionOptions::parse(url).is_ok();
    // Two path segments cannot be mapped onto a single database name.
    assert!(!parses("mysql://localhost/foo/bar"));
    assert!(parses("mysql://localhost/foo"));
}
224
#[test]
fn first_path_segment_is_treated_as_database() {
    let database_of = |url: &str| {
        ConnectionOptions::parse(url)
            .unwrap()
            .database()
            .map(|db| db.to_owned())
    };

    assert_eq!(
        Some(CString::new("foo").unwrap()),
        database_of("mysql://localhost/foo")
    );
    assert_eq!(
        Some(CString::new("bar").unwrap()),
        database_of("mysql://localhost/bar")
    );
    // Without a path there is no database.
    assert_eq!(None, database_of("mysql://localhost"));
}
248
#[test]
fn userinfo_should_be_percent_decode() {
    use self::percent_encoding::{utf8_percent_encode, AsciiSet, CONTROLS};

    // Characters that must be escaped inside the userinfo part of a URL.
    const USERINFO_ENCODE_SET: &AsciiSet = &CONTROLS
        .add(b' ')
        .add(b'"')
        .add(b'<')
        .add(b'>')
        .add(b'`')
        .add(b'#')
        .add(b'?')
        .add(b'{')
        .add(b'}')
        .add(b'/')
        .add(b':')
        .add(b';')
        .add(b'=')
        .add(b'@')
        .add(b'[')
        .add(b'\\')
        .add(b']')
        .add(b'^')
        .add(b'|');

    let raw_user = "x#gfuL?4Zuj{n73m}eeJt0";
    let raw_pass = "x/gfuL?4Zuj{n73m}eeJt1";

    let url = Url::parse(&format!(
        "mysql://{}:{}@localhost/bar",
        utf8_percent_encode(raw_user, USERINFO_ENCODE_SET),
        utf8_percent_encode(raw_pass, USERINFO_ENCODE_SET),
    ))
    .unwrap();

    // The parsed credentials must equal the original, un-encoded strings.
    let opts = ConnectionOptions::parse(url.as_str()).unwrap();
    assert_eq!(CString::new(raw_user).unwrap(), opts.user);
    assert_eq!(CString::new(raw_pass).unwrap(), opts.password.unwrap());
}
288
#[test]
fn ipv6_host_not_wrapped_in_brackets() {
    // (URL, expected host): the brackets are URL syntax, not part of the address.
    let cases = [
        ("mysql://[::1]", "::1"),
        (
            "mysql://[2001:db8:85a3::8a2e:370:7334]",
            "2001:db8:85a3::8a2e:370:7334",
        ),
    ];
    for (url, expected) in cases {
        let expected = CString::new(expected).unwrap();
        assert_eq!(
            Some(&*expected),
            ConnectionOptions::parse(url).unwrap().host()
        );
    }
}
305
#[test]
fn unix_socket_tests() {
    let (user, pass, socket) = ("foo", "bar", "/var/run/mysqld.sock");
    let to_c = |s: &str| CString::new(s).unwrap();

    let url = format!("mysql://{}:{}@localhost?unix_socket={}", user, pass, socket);
    let opts = ConnectionOptions::parse(&url).unwrap();

    // `localhost` plus a socket path: the host is dropped, the socket kept.
    assert_eq!(None, opts.host);
    assert_eq!(None, opts.port);
    assert_eq!(to_c(user), opts.user);
    assert_eq!(to_c(pass), opts.password.unwrap());
    assert_eq!(to_c(socket), opts.unix_socket.unwrap());
}
323
#[test]
fn ssl_ca_tests() {
    let (user, pass, ca) = ("foo", "bar", "/etc/ssl/certs/ca-certificates.crt");
    let to_c = |s: &str| CString::new(s).unwrap();

    // Host-based URL: host is kept and the CA path is recorded.
    let over_host =
        ConnectionOptions::parse(&format!("mysql://{user}:{pass}@localhost?ssl_ca={ca}"))
            .unwrap();
    assert_eq!(Some(to_c("localhost")), over_host.host);
    assert_eq!(None, over_host.port);
    assert_eq!(to_c(user), over_host.user);
    assert_eq!(to_c(pass), over_host.password.unwrap());
    assert_eq!(to_c(ca), over_host.ssl_ca.unwrap());

    // Socket-based URL: host is dropped, CA path is still recorded.
    let over_socket = ConnectionOptions::parse(&format!(
        "mysql://{user}:{pass}@localhost?unix_socket=/var/run/mysqld.sock&ssl_ca={ca}"
    ))
    .unwrap();
    assert_eq!(None, over_socket.host);
    assert_eq!(None, over_socket.port);
    assert_eq!(to_c(ca), over_socket.ssl_ca.unwrap());
}
347
#[test]
fn ssl_cert_tests() {
    let user = "foo";
    let pass = "bar";
    let cert = "/etc/ssl/certs/client-cert.crt";
    let to_c = |s: &str| CString::new(s).unwrap();

    // Host-based URL: host is kept and the certificate path is recorded.
    let over_host =
        ConnectionOptions::parse(&format!("mysql://{user}:{pass}@localhost?ssl_cert={cert}"))
            .unwrap();
    assert_eq!(Some(to_c("localhost")), over_host.host);
    assert_eq!(None, over_host.port);
    assert_eq!(to_c(user), over_host.user);
    assert_eq!(to_c(pass), over_host.password.unwrap());
    assert_eq!(to_c(cert), over_host.ssl_cert.unwrap());

    // Socket-based URL: host is dropped, certificate path is still recorded.
    let over_socket = ConnectionOptions::parse(&format!(
        "mysql://{user}:{pass}@localhost?unix_socket=/var/run/mysqld.sock&ssl_cert={cert}"
    ))
    .unwrap();
    assert_eq!(None, over_socket.host);
    assert_eq!(None, over_socket.port);
    assert_eq!(to_c(cert), over_socket.ssl_cert.unwrap());
}
374
#[test]
fn ssl_key_tests() {
    let user = "foo";
    let pass = "bar";
    let key = "/etc/ssl/certs/client-key.crt";
    let to_c = |s: &str| CString::new(s).unwrap();

    // Host-based URL: host is kept and the key path is recorded.
    let over_host =
        ConnectionOptions::parse(&format!("mysql://{user}:{pass}@localhost?ssl_key={key}"))
            .unwrap();
    assert_eq!(Some(to_c("localhost")), over_host.host);
    assert_eq!(None, over_host.port);
    assert_eq!(to_c(user), over_host.user);
    assert_eq!(to_c(pass), over_host.password.unwrap());
    assert_eq!(to_c(key), over_host.ssl_key.unwrap());

    // Socket-based URL: host is dropped, key path is still recorded.
    let over_socket = ConnectionOptions::parse(&format!(
        "mysql://{user}:{pass}@localhost?unix_socket=/var/run/mysqld.sock&ssl_key={key}"
    ))
    .unwrap();
    assert_eq!(None, over_socket.host);
    assert_eq!(None, over_socket.port);
    assert_eq!(to_c(key), over_socket.ssl_key.unwrap());
}
398
#[test]
fn ssl_mode() {
    let mode_of = |url: &str| ConnectionOptions::parse(url).unwrap().ssl_mode();

    // No `ssl_mode` parameter leaves the mode unset.
    assert_eq!(mode_of("mysql://localhost"), None);

    // Values are matched case-insensitively.
    let cases = [
        ("mysql://localhost?ssl_mode=disabled", mysql_ssl_mode::SSL_MODE_DISABLED),
        ("mysql://localhost?ssl_mode=PREFERRED", mysql_ssl_mode::SSL_MODE_PREFERRED),
        ("mysql://localhost?ssl_mode=required", mysql_ssl_mode::SSL_MODE_REQUIRED),
        ("mysql://localhost?ssl_mode=VERIFY_CA", mysql_ssl_mode::SSL_MODE_VERIFY_CA),
        (
            "mysql://localhost?ssl_mode=verify_identity",
            mysql_ssl_mode::SSL_MODE_VERIFY_IDENTITY,
        ),
    ];
    for (url, expected) in cases {
        assert_eq!(mode_of(url), Some(expected));
    }
}