diesel/sqlite/connection/
raw.rs

1#![allow(unsafe_code)] // ffi calls
2#[cfg(not(all(target_family = "wasm", target_os = "unknown")))]
3extern crate libsqlite3_sys as ffi;
4
5#[cfg(all(target_family = "wasm", target_os = "unknown"))]
6use sqlite_wasm_rs::export as ffi;
7
8use std::ffi::{CString, NulError};
9use std::io::{stderr, Write};
10use std::os::raw as libc;
11use std::ptr::NonNull;
12use std::{mem, ptr, slice, str};
13
14use super::functions::{build_sql_function_args, process_sql_function_result};
15use super::serialized_database::SerializedDatabase;
16use super::stmt::ensure_sqlite_ok;
17use super::{Sqlite, SqliteAggregateFunction};
18use crate::deserialize::FromSqlRow;
19use crate::result::Error::DatabaseError;
20use crate::result::*;
21use crate::serialize::ToSql;
22use crate::sql_types::HasSqlType;
23
/// For use in FFI functions, which cannot unwind.
///
/// Prints the formatted message, asks the user to open an issue on GitHub,
/// reports the source location and then [`abort`](std::process::abort)s.
///
/// The caller-supplied message is terminated with a newline so it is not run
/// together with the follow-up instructions.
macro_rules! assert_fail {
    ($fmt:expr $(,$args:tt)*) => {
        eprint!(concat!(
            $fmt,
            "\n",
            "If you see this message, please open an issue at https://github.com/diesel-rs/diesel/issues/new.\n",
            "Source location: {}:{}\n",
        ), $($args,)* file!(), line!());
        std::process::abort()
    };
}
36
/// Wrapper around a raw `sqlite3` connection handle.
///
/// Dropping a `RawConnection` closes the handle with `sqlite3_close`. Code that
/// only borrows a handle owned elsewhere (e.g. inside SQL function callbacks)
/// wraps the value in `ManuallyDrop` so the connection is not closed.
#[allow(missing_debug_implementations, missing_copy_implementations)]
pub(super) struct RawConnection {
    /// Non-null `sqlite3` connection handle used for all FFI calls.
    pub(super) internal_connection: NonNull<ffi::sqlite3>,
}
41
42impl RawConnection {
43    pub(super) fn establish(database_url: &str) -> ConnectionResult<Self> {
44        let mut conn_pointer = ptr::null_mut();
45
46        let database_url = if database_url.starts_with("sqlite://") {
47            CString::new(database_url.replacen("sqlite://", "file:", 1))?
48        } else {
49            CString::new(database_url)?
50        };
51        let flags = ffi::SQLITE_OPEN_READWRITE | ffi::SQLITE_OPEN_CREATE | ffi::SQLITE_OPEN_URI;
52        let connection_status = unsafe {
53            ffi::sqlite3_open_v2(database_url.as_ptr(), &mut conn_pointer, flags, ptr::null())
54        };
55
56        match connection_status {
57            ffi::SQLITE_OK => {
58                let conn_pointer = unsafe { NonNull::new_unchecked(conn_pointer) };
59                Ok(RawConnection {
60                    internal_connection: conn_pointer,
61                })
62            }
63            err_code => {
64                let message = super::error_message(err_code);
65                Err(ConnectionError::BadConnection(message.into()))
66            }
67        }
68    }
69
70    pub(super) fn exec(&self, query: &str) -> QueryResult<()> {
71        let query = CString::new(query)?;
72        let callback_fn = None;
73        let callback_arg = ptr::null_mut();
74        let result = unsafe {
75            ffi::sqlite3_exec(
76                self.internal_connection.as_ptr(),
77                query.as_ptr(),
78                callback_fn,
79                callback_arg,
80                ptr::null_mut(),
81            )
82        };
83
84        ensure_sqlite_ok(result, self.internal_connection.as_ptr())
85    }
86
87    pub(super) fn rows_affected_by_last_query(
88        &self,
89    ) -> Result<usize, Box<dyn std::error::Error + Send + Sync>> {
90        let r = unsafe { ffi::sqlite3_changes(self.internal_connection.as_ptr()) };
91
92        Ok(r.try_into()?)
93    }
94
95    pub(super) fn register_sql_function<F, Ret, RetSqlType>(
96        &self,
97        fn_name: &str,
98        num_args: usize,
99        deterministic: bool,
100        f: F,
101    ) -> QueryResult<()>
102    where
103        F: FnMut(&Self, &mut [*mut ffi::sqlite3_value]) -> QueryResult<Ret>
104            + std::panic::UnwindSafe
105            + Send
106            + 'static,
107        Ret: ToSql<RetSqlType, Sqlite>,
108        Sqlite: HasSqlType<RetSqlType>,
109    {
110        let callback_fn = Box::into_raw(Box::new(CustomFunctionUserPtr {
111            callback: f,
112            function_name: fn_name.to_owned(),
113        }));
114        let fn_name = Self::get_fn_name(fn_name)?;
115        let flags = Self::get_flags(deterministic);
116        let num_args = num_args
117            .try_into()
118            .map_err(|e| Error::SerializationError(Box::new(e)))?;
119
120        let result = unsafe {
121            ffi::sqlite3_create_function_v2(
122                self.internal_connection.as_ptr(),
123                fn_name.as_ptr(),
124                num_args,
125                flags,
126                callback_fn as *mut _,
127                Some(run_custom_function::<F, Ret, RetSqlType>),
128                None,
129                None,
130                Some(destroy_boxed::<CustomFunctionUserPtr<F>>),
131            )
132        };
133
134        Self::process_sql_function_result(result)
135    }
136
137    pub(super) fn register_aggregate_function<ArgsSqlType, RetSqlType, Args, Ret, A>(
138        &self,
139        fn_name: &str,
140        num_args: usize,
141    ) -> QueryResult<()>
142    where
143        A: SqliteAggregateFunction<Args, Output = Ret> + 'static + Send + std::panic::UnwindSafe,
144        Args: FromSqlRow<ArgsSqlType, Sqlite>,
145        Ret: ToSql<RetSqlType, Sqlite>,
146        Sqlite: HasSqlType<RetSqlType>,
147    {
148        let fn_name = Self::get_fn_name(fn_name)?;
149        let flags = Self::get_flags(false);
150        let num_args = num_args
151            .try_into()
152            .map_err(|e| Error::SerializationError(Box::new(e)))?;
153
154        let result = unsafe {
155            ffi::sqlite3_create_function_v2(
156                self.internal_connection.as_ptr(),
157                fn_name.as_ptr(),
158                num_args,
159                flags,
160                ptr::null_mut(),
161                None,
162                Some(run_aggregator_step_function::<_, _, _, _, A>),
163                Some(run_aggregator_final_function::<_, _, _, _, A>),
164                None,
165            )
166        };
167
168        Self::process_sql_function_result(result)
169    }
170
171    pub(super) fn register_collation_function<F>(
172        &self,
173        collation_name: &str,
174        collation: F,
175    ) -> QueryResult<()>
176    where
177        F: Fn(&str, &str) -> std::cmp::Ordering + std::panic::UnwindSafe + Send + 'static,
178    {
179        let callback_fn = Box::into_raw(Box::new(CollationUserPtr {
180            callback: collation,
181            collation_name: collation_name.to_owned(),
182        }));
183        let collation_name = Self::get_fn_name(collation_name)?;
184
185        let result = unsafe {
186            ffi::sqlite3_create_collation_v2(
187                self.internal_connection.as_ptr(),
188                collation_name.as_ptr(),
189                ffi::SQLITE_UTF8,
190                callback_fn as *mut _,
191                Some(run_collation_function::<F>),
192                Some(destroy_boxed::<CollationUserPtr<F>>),
193            )
194        };
195
196        let result = Self::process_sql_function_result(result);
197        if result.is_err() {
198            destroy_boxed::<CollationUserPtr<F>>(callback_fn as *mut _);
199        }
200        result
201    }
202
203    pub(super) fn serialize(&mut self) -> SerializedDatabase {
204        unsafe {
205            let mut size: ffi::sqlite3_int64 = 0;
206            let data_ptr = ffi::sqlite3_serialize(
207                self.internal_connection.as_ptr(),
208                std::ptr::null(),
209                &mut size as *mut _,
210                0,
211            );
212            SerializedDatabase::new(
213                data_ptr,
214                size.try_into()
215                    .expect("Cannot fit the serialized database into memory"),
216            )
217        }
218    }
219
220    pub(super) fn deserialize(&mut self, data: &[u8]) -> QueryResult<()> {
221        let db_size = data
222            .len()
223            .try_into()
224            .map_err(|e| Error::DeserializationError(Box::new(e)))?;
225        // the cast for `ffi::SQLITE_DESERIALIZE_READONLY` is required for old libsqlite3-sys versions
226        #[allow(clippy::unnecessary_cast)]
227        unsafe {
228            let result = ffi::sqlite3_deserialize(
229                self.internal_connection.as_ptr(),
230                std::ptr::null(),
231                data.as_ptr() as *mut u8,
232                db_size,
233                db_size,
234                ffi::SQLITE_DESERIALIZE_READONLY as u32,
235            );
236
237            ensure_sqlite_ok(result, self.internal_connection.as_ptr())
238        }
239    }
240
241    fn get_fn_name(fn_name: &str) -> Result<CString, NulError> {
242        CString::new(fn_name)
243    }
244
245    fn get_flags(deterministic: bool) -> i32 {
246        let mut flags = ffi::SQLITE_UTF8;
247        if deterministic {
248            flags |= ffi::SQLITE_DETERMINISTIC;
249        }
250        flags
251    }
252
253    fn process_sql_function_result(result: i32) -> Result<(), Error> {
254        if result == ffi::SQLITE_OK {
255            Ok(())
256        } else {
257            let error_message = super::error_message(result);
258            Err(DatabaseError(
259                DatabaseErrorKind::Unknown,
260                Box::new(error_message.to_string()),
261            ))
262        }
263    }
264}
265
266impl Drop for RawConnection {
267    fn drop(&mut self) {
268        use std::thread::panicking;
269
270        let close_result = unsafe { ffi::sqlite3_close(self.internal_connection.as_ptr()) };
271        if close_result != ffi::SQLITE_OK {
272            let error_message = super::error_message(close_result);
273            if panicking() {
274                write!(stderr(), "Error closing SQLite connection: {error_message}")
275                    .expect("Error writing to `stderr`");
276            } else {
277                panic!("Error closing SQLite connection: {}", error_message);
278            }
279        }
280    }
281}
282
/// Failure raised inside an SQLite callback. It is reported back to SQLite
/// (or printed before aborting, for collations) instead of unwinding across the
/// FFI boundary.
enum SqliteCallbackError {
    /// An internal invariant was violated; carries a static description.
    Abort(&'static str),
    /// A diesel error returned by the callback or produced while converting
    /// its arguments or result.
    DieselError(crate::result::Error),
    /// The callback panicked; carries text identifying the panicking callback.
    Panic(String),
}
288
289impl SqliteCallbackError {
290    fn emit(&self, ctx: *mut ffi::sqlite3_context) {
291        let s;
292        let msg = match self {
293            SqliteCallbackError::Abort(msg) => *msg,
294            SqliteCallbackError::DieselError(e) => {
295                s = e.to_string();
296                &s
297            }
298            SqliteCallbackError::Panic(msg) => msg,
299        };
300        unsafe {
301            context_error_str(ctx, msg);
302        }
303    }
304}
305
306impl From<crate::result::Error> for SqliteCallbackError {
307    fn from(e: crate::result::Error) -> Self {
308        Self::DieselError(e)
309    }
310}
311
/// User data handed to SQLite for a custom scalar SQL function; it is boxed on
/// registration and freed through `destroy_boxed`.
struct CustomFunctionUserPtr<F> {
    /// The user-provided implementation of the SQL function.
    callback: F,
    /// Name the function was registered under; used as the error text if
    /// `callback` panics.
    function_name: String,
}
316
317#[allow(warnings)]
318extern "C" fn run_custom_function<F, Ret, RetSqlType>(
319    ctx: *mut ffi::sqlite3_context,
320    num_args: libc::c_int,
321    value_ptr: *mut *mut ffi::sqlite3_value,
322) where
323    F: FnMut(&RawConnection, &mut [*mut ffi::sqlite3_value]) -> QueryResult<Ret>
324        + std::panic::UnwindSafe
325        + Send
326        + 'static,
327    Ret: ToSql<RetSqlType, Sqlite>,
328    Sqlite: HasSqlType<RetSqlType>,
329{
330    use std::ops::Deref;
331    static NULL_DATA_ERR: &str = "An unknown error occurred. sqlite3_user_data returned a null pointer. This should never happen.";
332    static NULL_CONN_ERR: &str = "An unknown error occurred. sqlite3_context_db_handle returned a null pointer. This should never happen.";
333
334    let conn = match unsafe { NonNull::new(ffi::sqlite3_context_db_handle(ctx)) } {
335        // We use `ManuallyDrop` here because we do not want to run the
336        // Drop impl of `RawConnection` as this would close the connection
337        Some(conn) => mem::ManuallyDrop::new(RawConnection {
338            internal_connection: conn,
339        }),
340        None => {
341            unsafe { context_error_str(ctx, NULL_CONN_ERR) };
342            return;
343        }
344    };
345
346    let data_ptr = unsafe { ffi::sqlite3_user_data(ctx) };
347
348    let mut data_ptr = match NonNull::new(data_ptr as *mut CustomFunctionUserPtr<F>) {
349        None => unsafe {
350            context_error_str(ctx, NULL_DATA_ERR);
351            return;
352        },
353        Some(mut f) => f,
354    };
355    let data_ptr = unsafe { data_ptr.as_mut() };
356
357    // We need this to move the reference into the catch_unwind part
358    // this is sound as `F` itself and the stored string is `UnwindSafe`
359    let callback = std::panic::AssertUnwindSafe(&mut data_ptr.callback);
360
361    let result = std::panic::catch_unwind(move || {
362        let _ = &callback;
363        let args = unsafe { slice::from_raw_parts_mut(value_ptr, num_args as _) };
364        let res = (callback.0)(&*conn, args)?;
365        let value = process_sql_function_result(&res)?;
366        // We've checked already that ctx is not null
367        unsafe {
368            value.result_of(&mut *ctx);
369        }
370        Ok(())
371    })
372    .unwrap_or_else(|p| Err(SqliteCallbackError::Panic(data_ptr.function_name.clone())));
373    if let Err(e) = result {
374        e.emit(ctx);
375    }
376}
377
// Need a custom option type here, because the std lib one does not have guarantees about the discriminant values
// See: https://github.com/rust-lang/rfcs/blob/master/text/2195-really-tagged-unions.md#opaque-tags
/// Aggregator state stored inside SQLite's aggregate context.
///
/// `#[repr(u8)]` pins `None` to discriminant 0, so the zero-filled memory
/// SQLite hands out reads as `None` until the first step writes `Some`.
/// The variant order is therefore load-bearing and must not change.
#[repr(u8)]
enum OptionalAggregator<A> {
    // Discriminant is 0; must remain the first variant
    None,
    Some(A),
}
386
387#[allow(warnings)]
388extern "C" fn run_aggregator_step_function<ArgsSqlType, RetSqlType, Args, Ret, A>(
389    ctx: *mut ffi::sqlite3_context,
390    num_args: libc::c_int,
391    value_ptr: *mut *mut ffi::sqlite3_value,
392) where
393    A: SqliteAggregateFunction<Args, Output = Ret> + 'static + Send + std::panic::UnwindSafe,
394    Args: FromSqlRow<ArgsSqlType, Sqlite>,
395    Ret: ToSql<RetSqlType, Sqlite>,
396    Sqlite: HasSqlType<RetSqlType>,
397{
398    let result = std::panic::catch_unwind(move || {
399        let args = unsafe { slice::from_raw_parts_mut(value_ptr, num_args as _) };
400        run_aggregator_step::<A, Args, ArgsSqlType>(ctx, args)
401    })
402    .unwrap_or_else(|e| {
403        Err(SqliteCallbackError::Panic(format!(
404            "{}::step() panicked",
405            std::any::type_name::<A>()
406        )))
407    });
408
409    match result {
410        Ok(()) => {}
411        Err(e) => e.emit(ctx),
412    }
413}
414
415fn run_aggregator_step<A, Args, ArgsSqlType>(
416    ctx: *mut ffi::sqlite3_context,
417    args: &mut [*mut ffi::sqlite3_value],
418) -> Result<(), SqliteCallbackError>
419where
420    A: SqliteAggregateFunction<Args>,
421    Args: FromSqlRow<ArgsSqlType, Sqlite>,
422{
423    static NULL_AG_CTX_ERR: &str = "An unknown error occurred. sqlite3_aggregate_context returned a null pointer. This should never happen.";
424    static NULL_CTX_ERR: &str =
425        "We've written the aggregator to the aggregate context, but it could not be retrieved.";
426
427    let n_bytes: i32 = std::mem::size_of::<OptionalAggregator<A>>()
428        .try_into()
429        .expect("Aggregate context should be larger than 2^32");
430    let aggregate_context = unsafe {
431        // This block of unsafe code makes the following assumptions:
432        //
433        // * sqlite3_aggregate_context allocates sizeof::<OptionalAggregator<A>>
434        //   bytes of zeroed memory as documented here:
435        //   https://www.sqlite.org/c3ref/aggregate_context.html
436        //   A null pointer is returned for negative or zero sized types,
437        //   which should be impossible in theory. We check that nevertheless
438        //
439        // * OptionalAggregator::None has a discriminant of 0 as specified by
440        //   #[repr(u8)] + RFC 2195
441        //
442        // * If all bytes are zero, the discriminant is also zero, so we can
443        //   assume that we get OptionalAggregator::None in this case. This is
444        //   not UB as we only access the discriminant here, so we do not try
445        //   to read any other zeroed memory. After that we initialize our enum
446        //   by writing a correct value at this location via ptr::write_unaligned
447        //
448        // * We use ptr::write_unaligned as we did not found any guarantees that
449        //   the memory will have a correct alignment.
450        //   (Note I(weiznich): would assume that it is aligned correctly, but we
451        //    we cannot guarantee it, so better be safe than sorry)
452        ffi::sqlite3_aggregate_context(ctx, n_bytes)
453    };
454    let aggregate_context = NonNull::new(aggregate_context as *mut OptionalAggregator<A>);
455    let aggregator = unsafe {
456        match aggregate_context.map(|a| &mut *a.as_ptr()) {
457            Some(&mut OptionalAggregator::Some(ref mut agg)) => agg,
458            Some(a_ptr @ &mut OptionalAggregator::None) => {
459                ptr::write_unaligned(a_ptr as *mut _, OptionalAggregator::Some(A::default()));
460                if let OptionalAggregator::Some(ref mut agg) = a_ptr {
461                    agg
462                } else {
463                    return Err(SqliteCallbackError::Abort(NULL_CTX_ERR));
464                }
465            }
466            None => {
467                return Err(SqliteCallbackError::Abort(NULL_AG_CTX_ERR));
468            }
469        }
470    };
471    let args = build_sql_function_args::<ArgsSqlType, Args>(args)?;
472
473    aggregator.step(args);
474    Ok(())
475}
476
477extern "C" fn run_aggregator_final_function<ArgsSqlType, RetSqlType, Args, Ret, A>(
478    ctx: *mut ffi::sqlite3_context,
479) where
480    A: SqliteAggregateFunction<Args, Output = Ret> + 'static + Send,
481    Args: FromSqlRow<ArgsSqlType, Sqlite>,
482    Ret: ToSql<RetSqlType, Sqlite>,
483    Sqlite: HasSqlType<RetSqlType>,
484{
485    static NO_AGGREGATOR_FOUND: &str = "We've written to the aggregator in the xStep callback. If xStep was never called, then ffi::sqlite_aggregate_context() would have returned a NULL pointer.";
486    let aggregate_context = unsafe {
487        // Within the xFinal callback, it is customary to set nBytes to 0 so no pointless memory
488        // allocations occur, a null pointer is returned in this case
489        // See: https://www.sqlite.org/c3ref/aggregate_context.html
490        //
491        // For the reasoning about the safety of the OptionalAggregator handling
492        // see the comment in run_aggregator_step_function.
493        ffi::sqlite3_aggregate_context(ctx, 0)
494    };
495
496    let result = std::panic::catch_unwind(|| {
497        let mut aggregate_context = NonNull::new(aggregate_context as *mut OptionalAggregator<A>);
498
499        let aggregator = if let Some(a) = aggregate_context.as_mut() {
500            let a = unsafe { a.as_mut() };
501            match std::mem::replace(a, OptionalAggregator::None) {
502                OptionalAggregator::None => {
503                    return Err(SqliteCallbackError::Abort(NO_AGGREGATOR_FOUND));
504                }
505                OptionalAggregator::Some(a) => Some(a),
506            }
507        } else {
508            None
509        };
510
511        let res = A::finalize(aggregator);
512        let value = process_sql_function_result(&res)?;
513        // We've checked already that ctx is not null
514        let r = unsafe { value.result_of(&mut *ctx) };
515        r.map_err(|e| {
516            SqliteCallbackError::DieselError(crate::result::Error::SerializationError(Box::new(e)))
517        })?;
518        Ok(())
519    })
520    .unwrap_or_else(|_e| {
521        Err(SqliteCallbackError::Panic(format!(
522            "{}::finalize() panicked",
523            std::any::type_name::<A>()
524        )))
525    });
526    if let Err(e) = result {
527        e.emit(ctx);
528    }
529}
530
531unsafe fn context_error_str(ctx: *mut ffi::sqlite3_context, error: &str) {
532    let len: i32 = error
533        .len()
534        .try_into()
535        .expect("Trying to set a error message with more than 2^32 byte is not supported");
536    ffi::sqlite3_result_error(ctx, error.as_ptr() as *const _, len);
537}
538
/// User data handed to SQLite for a custom collation; it is boxed on
/// registration and freed through `destroy_boxed` (or directly if
/// registration fails).
struct CollationUserPtr<F> {
    /// The user-provided comparison function.
    callback: F,
    /// Name the collation was registered under; used in the diagnostics printed
    /// before aborting on a failing or panicking collation.
    collation_name: String,
}
543
544#[allow(warnings)]
545extern "C" fn run_collation_function<F>(
546    user_ptr: *mut libc::c_void,
547    lhs_len: libc::c_int,
548    lhs_ptr: *const libc::c_void,
549    rhs_len: libc::c_int,
550    rhs_ptr: *const libc::c_void,
551) -> libc::c_int
552where
553    F: Fn(&str, &str) -> std::cmp::Ordering + Send + std::panic::UnwindSafe + 'static,
554{
555    let user_ptr = user_ptr as *const CollationUserPtr<F>;
556    let user_ptr = std::panic::AssertUnwindSafe(unsafe { user_ptr.as_ref() });
557
558    let result = std::panic::catch_unwind(|| {
559        let user_ptr = user_ptr.ok_or_else(|| {
560            SqliteCallbackError::Abort(
561                "Got a null pointer as data pointer. This should never happen",
562            )
563        })?;
564        for (ptr, len, side) in &[(rhs_ptr, rhs_len, "rhs"), (lhs_ptr, lhs_len, "lhs")] {
565            if *len < 0 {
566                assert_fail!(
567                    "An unknown error occurred. {}_len is negative. This should never happen.",
568                    side
569                );
570            }
571            if ptr.is_null() {
572                assert_fail!(
573                "An unknown error occurred. {}_ptr is a null pointer. This should never happen.",
574                side
575            );
576            }
577        }
578
579        let (rhs, lhs) = unsafe {
580            // Depending on the eTextRep-parameter to sqlite3_create_collation_v2() the strings can
581            // have various encodings. register_collation_function() always selects SQLITE_UTF8, so the
582            // pointers point to valid UTF-8 strings (assuming correct behavior of libsqlite3).
583            (
584                str::from_utf8(slice::from_raw_parts(rhs_ptr as *const u8, rhs_len as _)),
585                str::from_utf8(slice::from_raw_parts(lhs_ptr as *const u8, lhs_len as _)),
586            )
587        };
588
589        let rhs =
590            rhs.map_err(|_| SqliteCallbackError::Abort("Got an invalid UTF-8 string for rhs"))?;
591        let lhs =
592            lhs.map_err(|_| SqliteCallbackError::Abort("Got an invalid UTF-8 string for lhs"))?;
593
594        Ok((user_ptr.callback)(rhs, lhs))
595    })
596    .unwrap_or_else(|p| {
597        Err(SqliteCallbackError::Panic(
598            user_ptr
599                .map(|u| u.collation_name.clone())
600                .unwrap_or_default(),
601        ))
602    });
603
604    match result {
605        Ok(std::cmp::Ordering::Less) => -1,
606        Ok(std::cmp::Ordering::Equal) => 0,
607        Ok(std::cmp::Ordering::Greater) => 1,
608        Err(SqliteCallbackError::Abort(a)) => {
609            eprintln!(
610                "Collation function {} failed with: {}",
611                user_ptr
612                    .map(|c| &c.collation_name as &str)
613                    .unwrap_or_default(),
614                a
615            );
616            std::process::abort()
617        }
618        Err(SqliteCallbackError::DieselError(e)) => {
619            eprintln!(
620                "Collation function {} failed with: {}",
621                user_ptr
622                    .map(|c| &c.collation_name as &str)
623                    .unwrap_or_default(),
624                e
625            );
626            std::process::abort()
627        }
628        Err(SqliteCallbackError::Panic(msg)) => {
629            eprintln!("Collation function {} panicked", msg);
630            std::process::abort()
631        }
632    }
633}
634
/// Destructor handed to SQLite as `xDestroy`.
///
/// Reclaims a value that was leaked with `Box::<F>::into_raw` and drops it.
/// `data` must originate from such a call and must not be used afterwards.
extern "C" fn destroy_boxed<F>(data: *mut libc::c_void) {
    let boxed: Box<F> = unsafe { Box::from_raw(data.cast::<F>()) };
    drop(boxed);
}