diesel/sqlite/connection/raw.rs

1#![allow(unsafe_code)] // ffi calls
2extern crate libsqlite3_sys as ffi;
3
4use std::ffi::{CString, NulError};
5use std::io::{stderr, Write};
6use std::os::raw as libc;
7use std::ptr::NonNull;
8use std::{mem, ptr, slice, str};
9
10use super::functions::{build_sql_function_args, process_sql_function_result};
11use super::serialized_database::SerializedDatabase;
12use super::stmt::ensure_sqlite_ok;
13use super::{Sqlite, SqliteAggregateFunction};
14use crate::deserialize::FromSqlRow;
15use crate::result::Error::DatabaseError;
16use crate::result::*;
17use crate::serialize::ToSql;
18use crate::sql_types::HasSqlType;
19
/// For use in FFI functions, which cannot unwind.
///
/// Prints the formatted message, asks the user to open an issue on GitHub, prints the
/// source location of the invocation and then [`abort`](std::process::abort)s the process.
macro_rules! assert_fail {
    ($fmt:expr $(,$args:tt)*) => {{
        // `$fmt` is followed by a newline so the caller's message and the "please open an
        // issue" hint are not glued together on a single line.
        eprint!(concat!(
            $fmt,
            "\n",
            "If you see this message, please open an issue at https://github.com/diesel-rs/diesel/issues/new.\n",
            "Source location: {}:{}\n",
        ), $($args,)* file!(), line!());
        std::process::abort()
    }};
}
32
33#[allow(missing_debug_implementations, missing_copy_implementations)]
34pub(super) struct RawConnection {
35    pub(super) internal_connection: NonNull<ffi::sqlite3>,
36}
37
38impl RawConnection {
39    pub(super) fn establish(database_url: &str) -> ConnectionResult<Self> {
40        let mut conn_pointer = ptr::null_mut();
41
42        let database_url = if database_url.starts_with("sqlite://") {
43            CString::new(database_url.replacen("sqlite://", "file:", 1))?
44        } else {
45            CString::new(database_url)?
46        };
47        let flags = ffi::SQLITE_OPEN_READWRITE | ffi::SQLITE_OPEN_CREATE | ffi::SQLITE_OPEN_URI;
48        let connection_status = unsafe {
49            ffi::sqlite3_open_v2(database_url.as_ptr(), &mut conn_pointer, flags, ptr::null())
50        };
51
52        match connection_status {
53            ffi::SQLITE_OK => {
54                let conn_pointer = unsafe { NonNull::new_unchecked(conn_pointer) };
55                Ok(RawConnection {
56                    internal_connection: conn_pointer,
57                })
58            }
59            err_code => {
60                let message = super::error_message(err_code);
61                Err(ConnectionError::BadConnection(message.into()))
62            }
63        }
64    }
65
66    pub(super) fn exec(&self, query: &str) -> QueryResult<()> {
67        let query = CString::new(query)?;
68        let callback_fn = None;
69        let callback_arg = ptr::null_mut();
70        let result = unsafe {
71            ffi::sqlite3_exec(
72                self.internal_connection.as_ptr(),
73                query.as_ptr(),
74                callback_fn,
75                callback_arg,
76                ptr::null_mut(),
77            )
78        };
79
80        ensure_sqlite_ok(result, self.internal_connection.as_ptr())
81    }
82
83    pub(super) fn rows_affected_by_last_query(
84        &self,
85    ) -> Result<usize, Box<dyn std::error::Error + Send + Sync>> {
86        let r = unsafe { ffi::sqlite3_changes(self.internal_connection.as_ptr()) };
87
88        Ok(r.try_into()?)
89    }
90
91    pub(super) fn register_sql_function<F, Ret, RetSqlType>(
92        &self,
93        fn_name: &str,
94        num_args: usize,
95        deterministic: bool,
96        f: F,
97    ) -> QueryResult<()>
98    where
99        F: FnMut(&Self, &mut [*mut ffi::sqlite3_value]) -> QueryResult<Ret>
100            + std::panic::UnwindSafe
101            + Send
102            + 'static,
103        Ret: ToSql<RetSqlType, Sqlite>,
104        Sqlite: HasSqlType<RetSqlType>,
105    {
106        let callback_fn = Box::into_raw(Box::new(CustomFunctionUserPtr {
107            callback: f,
108            function_name: fn_name.to_owned(),
109        }));
110        let fn_name = Self::get_fn_name(fn_name)?;
111        let flags = Self::get_flags(deterministic);
112        let num_args = num_args
113            .try_into()
114            .map_err(|e| Error::SerializationError(Box::new(e)))?;
115
116        let result = unsafe {
117            ffi::sqlite3_create_function_v2(
118                self.internal_connection.as_ptr(),
119                fn_name.as_ptr(),
120                num_args,
121                flags,
122                callback_fn as *mut _,
123                Some(run_custom_function::<F, Ret, RetSqlType>),
124                None,
125                None,
126                Some(destroy_boxed::<CustomFunctionUserPtr<F>>),
127            )
128        };
129
130        Self::process_sql_function_result(result)
131    }
132
133    pub(super) fn register_aggregate_function<ArgsSqlType, RetSqlType, Args, Ret, A>(
134        &self,
135        fn_name: &str,
136        num_args: usize,
137    ) -> QueryResult<()>
138    where
139        A: SqliteAggregateFunction<Args, Output = Ret> + 'static + Send + std::panic::UnwindSafe,
140        Args: FromSqlRow<ArgsSqlType, Sqlite>,
141        Ret: ToSql<RetSqlType, Sqlite>,
142        Sqlite: HasSqlType<RetSqlType>,
143    {
144        let fn_name = Self::get_fn_name(fn_name)?;
145        let flags = Self::get_flags(false);
146        let num_args = num_args
147            .try_into()
148            .map_err(|e| Error::SerializationError(Box::new(e)))?;
149
150        let result = unsafe {
151            ffi::sqlite3_create_function_v2(
152                self.internal_connection.as_ptr(),
153                fn_name.as_ptr(),
154                num_args,
155                flags,
156                ptr::null_mut(),
157                None,
158                Some(run_aggregator_step_function::<_, _, _, _, A>),
159                Some(run_aggregator_final_function::<_, _, _, _, A>),
160                None,
161            )
162        };
163
164        Self::process_sql_function_result(result)
165    }
166
167    pub(super) fn register_collation_function<F>(
168        &self,
169        collation_name: &str,
170        collation: F,
171    ) -> QueryResult<()>
172    where
173        F: Fn(&str, &str) -> std::cmp::Ordering + std::panic::UnwindSafe + Send + 'static,
174    {
175        let callback_fn = Box::into_raw(Box::new(CollationUserPtr {
176            callback: collation,
177            collation_name: collation_name.to_owned(),
178        }));
179        let collation_name = Self::get_fn_name(collation_name)?;
180
181        let result = unsafe {
182            ffi::sqlite3_create_collation_v2(
183                self.internal_connection.as_ptr(),
184                collation_name.as_ptr(),
185                ffi::SQLITE_UTF8,
186                callback_fn as *mut _,
187                Some(run_collation_function::<F>),
188                Some(destroy_boxed::<CollationUserPtr<F>>),
189            )
190        };
191
192        let result = Self::process_sql_function_result(result);
193        if result.is_err() {
194            destroy_boxed::<CollationUserPtr<F>>(callback_fn as *mut _);
195        }
196        result
197    }
198
199    pub(super) fn serialize(&mut self) -> SerializedDatabase {
200        unsafe {
201            let mut size: ffi::sqlite3_int64 = 0;
202            let data_ptr = ffi::sqlite3_serialize(
203                self.internal_connection.as_ptr(),
204                std::ptr::null(),
205                &mut size as *mut _,
206                0,
207            );
208            SerializedDatabase::new(
209                data_ptr,
210                size.try_into()
211                    .expect("Cannot fit the serialized database into memory"),
212            )
213        }
214    }
215
216    pub(super) fn deserialize(&mut self, data: &[u8]) -> QueryResult<()> {
217        let db_size = data
218            .len()
219            .try_into()
220            .map_err(|e| Error::DeserializationError(Box::new(e)))?;
221        // the cast for `ffi::SQLITE_DESERIALIZE_READONLY` is required for old libsqlite3-sys versions
222        #[allow(clippy::unnecessary_cast)]
223        unsafe {
224            let result = ffi::sqlite3_deserialize(
225                self.internal_connection.as_ptr(),
226                std::ptr::null(),
227                data.as_ptr() as *mut u8,
228                db_size,
229                db_size,
230                ffi::SQLITE_DESERIALIZE_READONLY as u32,
231            );
232
233            ensure_sqlite_ok(result, self.internal_connection.as_ptr())
234        }
235    }
236
237    fn get_fn_name(fn_name: &str) -> Result<CString, NulError> {
238        CString::new(fn_name)
239    }
240
241    fn get_flags(deterministic: bool) -> i32 {
242        let mut flags = ffi::SQLITE_UTF8;
243        if deterministic {
244            flags |= ffi::SQLITE_DETERMINISTIC;
245        }
246        flags
247    }
248
249    fn process_sql_function_result(result: i32) -> Result<(), Error> {
250        if result == ffi::SQLITE_OK {
251            Ok(())
252        } else {
253            let error_message = super::error_message(result);
254            Err(DatabaseError(
255                DatabaseErrorKind::Unknown,
256                Box::new(error_message.to_string()),
257            ))
258        }
259    }
260}
261
262impl Drop for RawConnection {
263    fn drop(&mut self) {
264        use std::thread::panicking;
265
266        let close_result = unsafe { ffi::sqlite3_close(self.internal_connection.as_ptr()) };
267        if close_result != ffi::SQLITE_OK {
268            let error_message = super::error_message(close_result);
269            if panicking() {
270                write!(stderr(), "Error closing SQLite connection: {error_message}")
271                    .expect("Error writing to `stderr`");
272            } else {
273                panic!("Error closing SQLite connection: {}", error_message);
274            }
275        }
276    }
277}
278
279enum SqliteCallbackError {
280    Abort(&'static str),
281    DieselError(crate::result::Error),
282    Panic(String),
283}
284
285impl SqliteCallbackError {
286    fn emit(&self, ctx: *mut ffi::sqlite3_context) {
287        let s;
288        let msg = match self {
289            SqliteCallbackError::Abort(msg) => *msg,
290            SqliteCallbackError::DieselError(e) => {
291                s = e.to_string();
292                &s
293            }
294            SqliteCallbackError::Panic(msg) => msg,
295        };
296        unsafe {
297            context_error_str(ctx, msg);
298        }
299    }
300}
301
302impl From<crate::result::Error> for SqliteCallbackError {
303    fn from(e: crate::result::Error) -> Self {
304        Self::DieselError(e)
305    }
306}
307
/// User data registered together with a custom scalar SQL function.
struct CustomFunctionUserPtr<F> {
    /// The Rust closure implementing the function.
    callback: F,
    /// SQL name of the function; used to describe the function when its callback panics.
    function_name: String,
}
312
313#[allow(warnings)]
314extern "C" fn run_custom_function<F, Ret, RetSqlType>(
315    ctx: *mut ffi::sqlite3_context,
316    num_args: libc::c_int,
317    value_ptr: *mut *mut ffi::sqlite3_value,
318) where
319    F: FnMut(&RawConnection, &mut [*mut ffi::sqlite3_value]) -> QueryResult<Ret>
320        + std::panic::UnwindSafe
321        + Send
322        + 'static,
323    Ret: ToSql<RetSqlType, Sqlite>,
324    Sqlite: HasSqlType<RetSqlType>,
325{
326    use std::ops::Deref;
327    static NULL_DATA_ERR: &str = "An unknown error occurred. sqlite3_user_data returned a null pointer. This should never happen.";
328    static NULL_CONN_ERR: &str = "An unknown error occurred. sqlite3_context_db_handle returned a null pointer. This should never happen.";
329
330    let conn = match unsafe { NonNull::new(ffi::sqlite3_context_db_handle(ctx)) } {
331        // We use `ManuallyDrop` here because we do not want to run the
332        // Drop impl of `RawConnection` as this would close the connection
333        Some(conn) => mem::ManuallyDrop::new(RawConnection {
334            internal_connection: conn,
335        }),
336        None => {
337            unsafe { context_error_str(ctx, NULL_CONN_ERR) };
338            return;
339        }
340    };
341
342    let data_ptr = unsafe { ffi::sqlite3_user_data(ctx) };
343
344    let mut data_ptr = match NonNull::new(data_ptr as *mut CustomFunctionUserPtr<F>) {
345        None => unsafe {
346            context_error_str(ctx, NULL_DATA_ERR);
347            return;
348        },
349        Some(mut f) => f,
350    };
351    let data_ptr = unsafe { data_ptr.as_mut() };
352
353    // We need this to move the reference into the catch_unwind part
354    // this is sound as `F` itself and the stored string is `UnwindSafe`
355    let callback = std::panic::AssertUnwindSafe(&mut data_ptr.callback);
356
357    let result = std::panic::catch_unwind(move || {
358        let _ = &callback;
359        let args = unsafe { slice::from_raw_parts_mut(value_ptr, num_args as _) };
360        let res = (callback.0)(&*conn, args)?;
361        let value = process_sql_function_result(&res)?;
362        // We've checked already that ctx is not null
363        unsafe {
364            value.result_of(&mut *ctx);
365        }
366        Ok(())
367    })
368    .unwrap_or_else(|p| Err(SqliteCallbackError::Panic(data_ptr.function_name.clone())));
369    if let Err(e) = result {
370        e.emit(ctx);
371    }
372}
373
// A dedicated option type is needed here because the standard library's `Option` gives no
// guarantees about its discriminant values, while the aggregate context handed out by SQLite
// starts out zeroed and must read as "no aggregator yet".
// See: https://github.com/rust-lang/rfcs/blob/master/text/2195-really-tagged-unions.md#opaque-tags
#[repr(u8)]
enum OptionalAggregator<A> {
    /// No aggregator stored yet. As the first variant of a `#[repr(u8)]` enum its
    /// discriminant is 0.
    None,
    /// The aggregator state accumulated so far.
    Some(A),
}
382
383#[allow(warnings)]
384extern "C" fn run_aggregator_step_function<ArgsSqlType, RetSqlType, Args, Ret, A>(
385    ctx: *mut ffi::sqlite3_context,
386    num_args: libc::c_int,
387    value_ptr: *mut *mut ffi::sqlite3_value,
388) where
389    A: SqliteAggregateFunction<Args, Output = Ret> + 'static + Send + std::panic::UnwindSafe,
390    Args: FromSqlRow<ArgsSqlType, Sqlite>,
391    Ret: ToSql<RetSqlType, Sqlite>,
392    Sqlite: HasSqlType<RetSqlType>,
393{
394    let result = std::panic::catch_unwind(move || {
395        let args = unsafe { slice::from_raw_parts_mut(value_ptr, num_args as _) };
396        run_aggregator_step::<A, Args, ArgsSqlType>(ctx, args)
397    })
398    .unwrap_or_else(|e| {
399        Err(SqliteCallbackError::Panic(format!(
400            "{}::step() panicked",
401            std::any::type_name::<A>()
402        )))
403    });
404
405    match result {
406        Ok(()) => {}
407        Err(e) => e.emit(ctx),
408    }
409}
410
411fn run_aggregator_step<A, Args, ArgsSqlType>(
412    ctx: *mut ffi::sqlite3_context,
413    args: &mut [*mut ffi::sqlite3_value],
414) -> Result<(), SqliteCallbackError>
415where
416    A: SqliteAggregateFunction<Args>,
417    Args: FromSqlRow<ArgsSqlType, Sqlite>,
418{
419    static NULL_AG_CTX_ERR: &str = "An unknown error occurred. sqlite3_aggregate_context returned a null pointer. This should never happen.";
420    static NULL_CTX_ERR: &str =
421        "We've written the aggregator to the aggregate context, but it could not be retrieved.";
422
423    let n_bytes: i32 = std::mem::size_of::<OptionalAggregator<A>>()
424        .try_into()
425        .expect("Aggregate context should be larger than 2^32");
426    let aggregate_context = unsafe {
427        // This block of unsafe code makes the following assumptions:
428        //
429        // * sqlite3_aggregate_context allocates sizeof::<OptionalAggregator<A>>
430        //   bytes of zeroed memory as documented here:
431        //   https://www.sqlite.org/c3ref/aggregate_context.html
432        //   A null pointer is returned for negative or zero sized types,
433        //   which should be impossible in theory. We check that nevertheless
434        //
435        // * OptionalAggregator::None has a discriminant of 0 as specified by
436        //   #[repr(u8)] + RFC 2195
437        //
438        // * If all bytes are zero, the discriminant is also zero, so we can
439        //   assume that we get OptionalAggregator::None in this case. This is
440        //   not UB as we only access the discriminant here, so we do not try
441        //   to read any other zeroed memory. After that we initialize our enum
442        //   by writing a correct value at this location via ptr::write_unaligned
443        //
444        // * We use ptr::write_unaligned as we did not found any guarantees that
445        //   the memory will have a correct alignment.
446        //   (Note I(weiznich): would assume that it is aligned correctly, but we
447        //    we cannot guarantee it, so better be safe than sorry)
448        ffi::sqlite3_aggregate_context(ctx, n_bytes)
449    };
450    let aggregate_context = NonNull::new(aggregate_context as *mut OptionalAggregator<A>);
451    let aggregator = unsafe {
452        match aggregate_context.map(|a| &mut *a.as_ptr()) {
453            Some(&mut OptionalAggregator::Some(ref mut agg)) => agg,
454            Some(a_ptr @ &mut OptionalAggregator::None) => {
455                ptr::write_unaligned(a_ptr as *mut _, OptionalAggregator::Some(A::default()));
456                if let OptionalAggregator::Some(ref mut agg) = a_ptr {
457                    agg
458                } else {
459                    return Err(SqliteCallbackError::Abort(NULL_CTX_ERR));
460                }
461            }
462            None => {
463                return Err(SqliteCallbackError::Abort(NULL_AG_CTX_ERR));
464            }
465        }
466    };
467    let args = build_sql_function_args::<ArgsSqlType, Args>(args)?;
468
469    aggregator.step(args);
470    Ok(())
471}
472
473extern "C" fn run_aggregator_final_function<ArgsSqlType, RetSqlType, Args, Ret, A>(
474    ctx: *mut ffi::sqlite3_context,
475) where
476    A: SqliteAggregateFunction<Args, Output = Ret> + 'static + Send,
477    Args: FromSqlRow<ArgsSqlType, Sqlite>,
478    Ret: ToSql<RetSqlType, Sqlite>,
479    Sqlite: HasSqlType<RetSqlType>,
480{
481    static NO_AGGREGATOR_FOUND: &str = "We've written to the aggregator in the xStep callback. If xStep was never called, then ffi::sqlite_aggregate_context() would have returned a NULL pointer.";
482    let aggregate_context = unsafe {
483        // Within the xFinal callback, it is customary to set nBytes to 0 so no pointless memory
484        // allocations occur, a null pointer is returned in this case
485        // See: https://www.sqlite.org/c3ref/aggregate_context.html
486        //
487        // For the reasoning about the safety of the OptionalAggregator handling
488        // see the comment in run_aggregator_step_function.
489        ffi::sqlite3_aggregate_context(ctx, 0)
490    };
491
492    let result = std::panic::catch_unwind(|| {
493        let mut aggregate_context = NonNull::new(aggregate_context as *mut OptionalAggregator<A>);
494
495        let aggregator = if let Some(a) = aggregate_context.as_mut() {
496            let a = unsafe { a.as_mut() };
497            match std::mem::replace(a, OptionalAggregator::None) {
498                OptionalAggregator::None => {
499                    return Err(SqliteCallbackError::Abort(NO_AGGREGATOR_FOUND));
500                }
501                OptionalAggregator::Some(a) => Some(a),
502            }
503        } else {
504            None
505        };
506
507        let res = A::finalize(aggregator);
508        let value = process_sql_function_result(&res)?;
509        // We've checked already that ctx is not null
510        let r = unsafe { value.result_of(&mut *ctx) };
511        r.map_err(|e| {
512            SqliteCallbackError::DieselError(crate::result::Error::SerializationError(Box::new(e)))
513        })?;
514        Ok(())
515    })
516    .unwrap_or_else(|_e| {
517        Err(SqliteCallbackError::Panic(format!(
518            "{}::finalize() panicked",
519            std::any::type_name::<A>()
520        )))
521    });
522    if let Err(e) = result {
523        e.emit(ctx);
524    }
525}
526
527unsafe fn context_error_str(ctx: *mut ffi::sqlite3_context, error: &str) {
528    let len: i32 = error
529        .len()
530        .try_into()
531        .expect("Trying to set a error message with more than 2^32 byte is not supported");
532    ffi::sqlite3_result_error(ctx, error.as_ptr() as *const _, len);
533}
534
/// User data registered together with a custom collation.
struct CollationUserPtr<F> {
    /// The Rust closure comparing two strings.
    callback: F,
    /// SQL name of the collation; used in diagnostics when the collation fails.
    collation_name: String,
}
539
540#[allow(warnings)]
541extern "C" fn run_collation_function<F>(
542    user_ptr: *mut libc::c_void,
543    lhs_len: libc::c_int,
544    lhs_ptr: *const libc::c_void,
545    rhs_len: libc::c_int,
546    rhs_ptr: *const libc::c_void,
547) -> libc::c_int
548where
549    F: Fn(&str, &str) -> std::cmp::Ordering + Send + std::panic::UnwindSafe + 'static,
550{
551    let user_ptr = user_ptr as *const CollationUserPtr<F>;
552    let user_ptr = std::panic::AssertUnwindSafe(unsafe { user_ptr.as_ref() });
553
554    let result = std::panic::catch_unwind(|| {
555        let user_ptr = user_ptr.ok_or_else(|| {
556            SqliteCallbackError::Abort(
557                "Got a null pointer as data pointer. This should never happen",
558            )
559        })?;
560        for (ptr, len, side) in &[(rhs_ptr, rhs_len, "rhs"), (lhs_ptr, lhs_len, "lhs")] {
561            if *len < 0 {
562                assert_fail!(
563                    "An unknown error occurred. {}_len is negative. This should never happen.",
564                    side
565                );
566            }
567            if ptr.is_null() {
568                assert_fail!(
569                "An unknown error occurred. {}_ptr is a null pointer. This should never happen.",
570                side
571            );
572            }
573        }
574
575        let (rhs, lhs) = unsafe {
576            // Depending on the eTextRep-parameter to sqlite3_create_collation_v2() the strings can
577            // have various encodings. register_collation_function() always selects SQLITE_UTF8, so the
578            // pointers point to valid UTF-8 strings (assuming correct behavior of libsqlite3).
579            (
580                str::from_utf8(slice::from_raw_parts(rhs_ptr as *const u8, rhs_len as _)),
581                str::from_utf8(slice::from_raw_parts(lhs_ptr as *const u8, lhs_len as _)),
582            )
583        };
584
585        let rhs =
586            rhs.map_err(|_| SqliteCallbackError::Abort("Got an invalid UTF-8 string for rhs"))?;
587        let lhs =
588            lhs.map_err(|_| SqliteCallbackError::Abort("Got an invalid UTF-8 string for lhs"))?;
589
590        Ok((user_ptr.callback)(rhs, lhs))
591    })
592    .unwrap_or_else(|p| {
593        Err(SqliteCallbackError::Panic(
594            user_ptr
595                .map(|u| u.collation_name.clone())
596                .unwrap_or_default(),
597        ))
598    });
599
600    match result {
601        Ok(std::cmp::Ordering::Less) => -1,
602        Ok(std::cmp::Ordering::Equal) => 0,
603        Ok(std::cmp::Ordering::Greater) => 1,
604        Err(SqliteCallbackError::Abort(a)) => {
605            eprintln!(
606                "Collation function {} failed with: {}",
607                user_ptr
608                    .map(|c| &c.collation_name as &str)
609                    .unwrap_or_default(),
610                a
611            );
612            std::process::abort()
613        }
614        Err(SqliteCallbackError::DieselError(e)) => {
615            eprintln!(
616                "Collation function {} failed with: {}",
617                user_ptr
618                    .map(|c| &c.collation_name as &str)
619                    .unwrap_or_default(),
620                e
621            );
622            std::process::abort()
623        }
624        Err(SqliteCallbackError::Panic(msg)) => {
625            eprintln!("Collation function {} panicked", msg);
626            std::process::abort()
627        }
628    }
629}
630
/// `xDestroy` callback that frees user data previously leaked with `Box::into_raw`.
///
/// `data` must come from `Box::into_raw(Box::new(value))` with `value: F`, and must not be
/// used again after this call.
extern "C" fn destroy_boxed<F>(data: *mut libc::c_void) {
    let boxed: Box<F> = unsafe { Box::from_raw(data.cast::<F>()) };
    drop(boxed);
}