Skip to main content

diesel/sqlite/connection/
raw.rs

1#![allow(unsafe_code)] // ffi calls
2#[cfg(not(all(target_family = "wasm", target_os = "unknown")))]
3extern crate libsqlite3_sys as ffi;
4
5#[cfg(all(target_family = "wasm", target_os = "unknown"))]
6use sqlite_wasm_rs as ffi;
7
8use std::ffi::{CString, NulError};
9use std::io::{stderr, Write};
10use std::os::raw as libc;
11use std::ptr::NonNull;
12use std::{mem, ptr, slice, str};
13
14use super::functions::{build_sql_function_args, process_sql_function_result};
15use super::serialized_database::SerializedDatabase;
16use super::stmt::ensure_sqlite_ok;
17use super::{Sqlite, SqliteAggregateFunction};
18use crate::deserialize::FromSqlRow;
19use crate::result::Error::DatabaseError;
20use crate::result::*;
21use crate::serialize::ToSql;
22use crate::sql_types::HasSqlType;
23
/// For use in FFI functions, which must not unwind.
///
/// Prints the formatted message, a request to open an issue on GitHub and the source
/// location of the invocation to stderr, then [`abort`](std::process::abort)s the process.
///
/// `$fmt` must be a string literal (it is spliced into `concat!`); the remaining arguments
/// are ordinary format arguments and may be arbitrary expressions. The expansion is a block
/// expression of type `!`, so the macro works in statement and in expression position.
macro_rules! assert_fail {
    ($fmt:expr $(, $args:expr)* $(,)?) => {{
        eprint!(
            concat!(
                $fmt,
                // Keep the caller's message on its own line instead of gluing it to the
                // boilerplate below.
                "\n",
                "If you see this message, please open an issue at https://github.com/diesel-rs/diesel/issues/new.\n",
                "Source location: {}:{}\n",
            ),
            $($args,)*
            file!(),
            line!()
        );
        std::process::abort()
    }};
}
36
37#[allow(missing_debug_implementations, missing_copy_implementations)]
38pub(super) struct RawConnection {
39    pub(super) internal_connection: NonNull<ffi::sqlite3>,
40}
41
42impl RawConnection {
43    pub(super) fn establish(database_url: &str) -> ConnectionResult<Self> {
44        let mut conn_pointer = ptr::null_mut();
45
46        let database_url = if database_url.starts_with("sqlite://") {
47            CString::new(database_url.replacen("sqlite://", "file:", 1))?
48        } else {
49            CString::new(database_url)?
50        };
51        let flags = ffi::SQLITE_OPEN_READWRITE | ffi::SQLITE_OPEN_CREATE | ffi::SQLITE_OPEN_URI;
52        let connection_status = unsafe {
53            ffi::sqlite3_open_v2(database_url.as_ptr(), &mut conn_pointer, flags, ptr::null())
54        };
55
56        match connection_status {
57            ffi::SQLITE_OK => {
58                let conn_pointer = unsafe { NonNull::new_unchecked(conn_pointer) };
59                Ok(RawConnection {
60                    internal_connection: conn_pointer,
61                })
62            }
63            err_code => {
64                let message = super::error_message(err_code);
65                // sqlite3_open_v2() may allocate a database connection handle
66                // even on failure. To avoid a resource leak, it must be released
67                // with sqlite3_close(). Passing a null pointer to sqlite3_close()
68                // is a harmless no-op, so no null check is needed.
69                // See: https://www.sqlite.org/c3ref/open.html
70                unsafe { ffi::sqlite3_close(conn_pointer) };
71                Err(ConnectionError::BadConnection(message.into()))
72            }
73        }
74    }
75
76    pub(super) fn exec(&self, query: &str) -> QueryResult<()> {
77        let query = CString::new(query)?;
78        let callback_fn = None;
79        let callback_arg = ptr::null_mut();
80        let result = unsafe {
81            ffi::sqlite3_exec(
82                self.internal_connection.as_ptr(),
83                query.as_ptr(),
84                callback_fn,
85                callback_arg,
86                ptr::null_mut(),
87            )
88        };
89
90        ensure_sqlite_ok(result, self.internal_connection.as_ptr())
91    }
92
93    pub(super) fn rows_affected_by_last_query(
94        &self,
95    ) -> Result<usize, Box<dyn std::error::Error + Send + Sync>> {
96        let r = unsafe { ffi::sqlite3_changes(self.internal_connection.as_ptr()) };
97
98        Ok(r.try_into()?)
99    }
100
101    pub(super) fn register_sql_function<F, Ret, RetSqlType>(
102        &self,
103        fn_name: &str,
104        num_args: usize,
105        deterministic: bool,
106        f: F,
107    ) -> QueryResult<()>
108    where
109        F: FnMut(&Self, &mut [*mut ffi::sqlite3_value]) -> QueryResult<Ret>
110            + std::panic::UnwindSafe
111            + Send
112            + 'static,
113        Ret: ToSql<RetSqlType, Sqlite>,
114        Sqlite: HasSqlType<RetSqlType>,
115    {
116        let c_fn_name = Self::get_fn_name(fn_name)?;
117        let flags = Self::get_flags(deterministic);
118        let num_args = num_args
119            .try_into()
120            .map_err(|e| Error::SerializationError(Box::new(e)))?;
121        // only create the pointer as last step here
122        // as we can otherwise leak memory
123        let callback_fn = Box::into_raw(Box::new(CustomFunctionUserPtr {
124            callback: f,
125            function_name: fn_name.to_owned(),
126        }));
127
128        let result = unsafe {
129            ffi::sqlite3_create_function_v2(
130                self.internal_connection.as_ptr(),
131                c_fn_name.as_ptr(),
132                num_args,
133                flags,
134                callback_fn as *mut _,
135                Some(run_custom_function::<F, Ret, RetSqlType>),
136                None,
137                None,
138                Some(destroy_boxed::<CustomFunctionUserPtr<F>>),
139            )
140        };
141
142        Self::process_sql_function_result(result)
143    }
144
145    pub(super) fn register_aggregate_function<ArgsSqlType, RetSqlType, Args, Ret, A>(
146        &self,
147        fn_name: &str,
148        num_args: usize,
149    ) -> QueryResult<()>
150    where
151        A: SqliteAggregateFunction<Args, Output = Ret> + 'static + Send + std::panic::UnwindSafe,
152        Args: FromSqlRow<ArgsSqlType, Sqlite>,
153        Ret: ToSql<RetSqlType, Sqlite>,
154        Sqlite: HasSqlType<RetSqlType>,
155    {
156        let fn_name = Self::get_fn_name(fn_name)?;
157        let flags = Self::get_flags(false);
158        let num_args = num_args
159            .try_into()
160            .map_err(|e| Error::SerializationError(Box::new(e)))?;
161
162        // we pass in the aggregate instance as a user pointer
163        // instead of relying on `sqlite3_aggregate_context` to have
164        // control about the allocation. Specifically allocating
165        // on the rust side will make sure the alignment of the
166        // aggregate instance is correct
167        // Any potential failing code is happening above so we don't leak this memory
168        let ctx_ptr = Box::into_raw(Box::new(None::<A>));
169
170        let result = unsafe {
171            ffi::sqlite3_create_function_v2(
172                self.internal_connection.as_ptr(),
173                fn_name.as_ptr(),
174                num_args,
175                flags,
176                ctx_ptr.cast(),
177                None,
178                Some(run_aggregator_step_function::<_, _, _, _, A>),
179                Some(run_aggregator_final_function::<_, _, _, _, A>),
180                Some(destroy_boxed::<Option<A>>),
181            )
182        };
183
184        Self::process_sql_function_result(result)
185    }
186
187    pub(super) fn register_collation_function<F>(
188        &self,
189        collation_name: &str,
190        collation: F,
191    ) -> QueryResult<()>
192    where
193        F: Fn(&str, &str) -> std::cmp::Ordering + std::panic::UnwindSafe + Send + 'static,
194    {
195        let c_collation_name = Self::get_fn_name(collation_name)?;
196        // only create the pointer as last step here as we otherwise could leak memory
197        let callback_fn = Box::into_raw(Box::new(CollationUserPtr {
198            callback: collation,
199            collation_name: collation_name.to_owned(),
200        }));
201
202        let result = unsafe {
203            ffi::sqlite3_create_collation_v2(
204                self.internal_connection.as_ptr(),
205                c_collation_name.as_ptr(),
206                ffi::SQLITE_UTF8,
207                callback_fn as *mut _,
208                Some(run_collation_function::<F>),
209                Some(destroy_boxed::<CollationUserPtr<F>>),
210            )
211        };
212
213        let result = Self::process_sql_function_result(result);
214        if result.is_err() {
215            destroy_boxed::<CollationUserPtr<F>>(callback_fn as *mut _);
216        }
217        result
218    }
219
220    pub(super) fn serialize(&mut self) -> SerializedDatabase {
221        unsafe {
222            let mut size: ffi::sqlite3_int64 = 0;
223            let data_ptr = ffi::sqlite3_serialize(
224                self.internal_connection.as_ptr(),
225                std::ptr::null(),
226                &mut size as *mut _,
227                0,
228            );
229            SerializedDatabase::new(
230                data_ptr,
231                size.try_into()
232                    .expect("Cannot fit the serialized database into memory"),
233            )
234        }
235    }
236
237    pub(super) fn deserialize(&mut self, data: &[u8]) -> QueryResult<()> {
238        let db_size = data
239            .len()
240            .try_into()
241            .map_err(|e| Error::DeserializationError(Box::new(e)))?;
242        // the cast for `ffi::SQLITE_DESERIALIZE_READONLY` is required for old libsqlite3-sys versions
243        #[allow(clippy::unnecessary_cast)]
244        unsafe {
245            let result = ffi::sqlite3_deserialize(
246                self.internal_connection.as_ptr(),
247                std::ptr::null(),
248                data.as_ptr() as *mut u8,
249                db_size,
250                db_size,
251                ffi::SQLITE_DESERIALIZE_READONLY as u32,
252            );
253
254            ensure_sqlite_ok(result, self.internal_connection.as_ptr())
255        }
256    }
257
258    fn get_fn_name(fn_name: &str) -> Result<CString, NulError> {
259        CString::new(fn_name)
260    }
261
262    fn get_flags(deterministic: bool) -> i32 {
263        let mut flags = ffi::SQLITE_UTF8;
264        if deterministic {
265            flags |= ffi::SQLITE_DETERMINISTIC;
266        }
267        flags
268    }
269
270    fn process_sql_function_result(result: i32) -> Result<(), Error> {
271        if result == ffi::SQLITE_OK {
272            Ok(())
273        } else {
274            let error_message = super::error_message(result);
275            Err(DatabaseError(
276                DatabaseErrorKind::Unknown,
277                Box::new(error_message.to_string()),
278            ))
279        }
280    }
281}
282
283impl Drop for RawConnection {
284    fn drop(&mut self) {
285        use std::thread::panicking;
286
287        let close_result = unsafe { ffi::sqlite3_close(self.internal_connection.as_ptr()) };
288        if close_result != ffi::SQLITE_OK {
289            let error_message = super::error_message(close_result);
290            if panicking() {
291                stderr().write_fmt(format_args!("Error closing SQLite connection: {0}",
        error_message))write!(stderr(), "Error closing SQLite connection: {error_message}")
292                    .expect("Error writing to `stderr`");
293            } else {
294                {
    ::core::panicking::panic_fmt(format_args!("Error closing SQLite connection: {0}",
            error_message));
};panic!("Error closing SQLite connection: {error_message}");
295            }
296        }
297    }
298}
299
300enum SqliteCallbackError {
301    Abort(&'static str),
302    DieselError(crate::result::Error),
303    Panic(String),
304}
305
306impl SqliteCallbackError {
307    fn emit(&self, ctx: *mut ffi::sqlite3_context) {
308        let s;
309        let msg = match self {
310            SqliteCallbackError::Abort(msg) => *msg,
311            SqliteCallbackError::DieselError(e) => {
312                s = e.to_string();
313                &s
314            }
315            SqliteCallbackError::Panic(msg) => msg,
316        };
317        unsafe {
318            context_error_str(ctx, msg);
319        }
320    }
321}
322
323impl From<crate::result::Error> for SqliteCallbackError {
324    fn from(e: crate::result::Error) -> Self {
325        Self::DieselError(e)
326    }
327}
328
/// User data attached to a scalar SQL function registered by
/// `RawConnection::register_sql_function`.
struct CustomFunctionUserPtr<F> {
    /// The Rust closure implementing the SQL function.
    callback: F,
    /// SQL-visible name of the function, used as the message if the callback panics.
    function_name: String,
}
333
334#[allow(warnings)]
335extern "C" fn run_custom_function<F, Ret, RetSqlType>(
336    ctx: *mut ffi::sqlite3_context,
337    num_args: libc::c_int,
338    value_ptr: *mut *mut ffi::sqlite3_value,
339) where
340    F: FnMut(&RawConnection, &mut [*mut ffi::sqlite3_value]) -> QueryResult<Ret>
341        + std::panic::UnwindSafe
342        + Send
343        + 'static,
344    Ret: ToSql<RetSqlType, Sqlite>,
345    Sqlite: HasSqlType<RetSqlType>,
346{
347    use std::ops::Deref;
348    static NULL_DATA_ERR: &str = "An unknown error occurred. sqlite3_user_data returned a null pointer. This should never happen.";
349    static NULL_CONN_ERR: &str = "An unknown error occurred. sqlite3_context_db_handle returned a null pointer. This should never happen.";
350
351    let conn = match unsafe { NonNull::new(ffi::sqlite3_context_db_handle(ctx)) } {
352        // We use `ManuallyDrop` here because we do not want to run the
353        // Drop impl of `RawConnection` as this would close the connection
354        Some(conn) => mem::ManuallyDrop::new(RawConnection {
355            internal_connection: conn,
356        }),
357        None => {
358            unsafe { context_error_str(ctx, NULL_CONN_ERR) };
359            return;
360        }
361    };
362
363    let data_ptr = unsafe { ffi::sqlite3_user_data(ctx) };
364
365    let mut data_ptr = match NonNull::new(data_ptr as *mut CustomFunctionUserPtr<F>) {
366        None => unsafe {
367            context_error_str(ctx, NULL_DATA_ERR);
368            return;
369        },
370        Some(mut f) => f,
371    };
372    let data_ptr = unsafe { data_ptr.as_mut() };
373
374    // We need this to move the reference into the catch_unwind part
375    // this is sound as `F` itself and the stored string is `UnwindSafe`
376    let callback = std::panic::AssertUnwindSafe(&mut data_ptr.callback);
377
378    let result = std::panic::catch_unwind(move || {
379        let _ = &callback;
380        let args = unsafe { slice::from_raw_parts_mut(value_ptr, num_args as _) };
381        let res = (callback.0)(&*conn, args)?;
382        let value = process_sql_function_result(&res)?;
383        // We've checked already that ctx is not null
384        unsafe {
385            value.result_of(&mut *ctx);
386        }
387        Ok(())
388    })
389    .unwrap_or_else(|p| Err(SqliteCallbackError::Panic(data_ptr.function_name.clone())));
390    if let Err(e) = result {
391        e.emit(ctx);
392    }
393}
394
395#[allow(warnings)]
396extern "C" fn run_aggregator_step_function<ArgsSqlType, RetSqlType, Args, Ret, A>(
397    ctx: *mut ffi::sqlite3_context,
398    num_args: libc::c_int,
399    value_ptr: *mut *mut ffi::sqlite3_value,
400) where
401    A: SqliteAggregateFunction<Args, Output = Ret> + 'static + Send + std::panic::UnwindSafe,
402    Args: FromSqlRow<ArgsSqlType, Sqlite>,
403    Ret: ToSql<RetSqlType, Sqlite>,
404    Sqlite: HasSqlType<RetSqlType>,
405{
406    let result = std::panic::catch_unwind(move || {
407        let args = unsafe { slice::from_raw_parts_mut(value_ptr, num_args as _) };
408        run_aggregator_step::<A, Args, ArgsSqlType>(ctx, args)
409    })
410    .unwrap_or_else(|e| {
411        Err(SqliteCallbackError::Panic(::alloc::__export::must_use({
        ::alloc::fmt::format(format_args!("{0}::step() panicked",
                std::any::type_name::<A>()))
    })format!(
412            "{}::step() panicked",
413            std::any::type_name::<A>()
414        )))
415    });
416
417    match result {
418        Ok(()) => {}
419        Err(e) => e.emit(ctx),
420    }
421}
422
423fn run_aggregator_step<A, Args, ArgsSqlType>(
424    ctx: *mut ffi::sqlite3_context,
425    args: &mut [*mut ffi::sqlite3_value],
426) -> Result<(), SqliteCallbackError>
427where
428    A: SqliteAggregateFunction<Args>,
429    Args: FromSqlRow<ArgsSqlType, Sqlite>,
430{
431    static NULL_CTX_ERR: &str =
432        "We've written the aggregator to the user data, but it could not be retrieved.";
433
434    let aggregator = unsafe {
435        // SAFETY:
436        // * We set the corresponding data in `register_aggregate_function` above
437        // * It has the correct type
438        let aggregate_context = ffi::sqlite3_user_data(ctx) as *mut Option<A>;
439        // SAFETY:
440        // * The layout is correct as we allocated the value on the rust side
441        // * The alignment is correct as we allocated the value on the rust side
442        match aggregate_context.as_mut() {
443            Some(Some(agg)) => agg,
444            Some(r) => {
445                *r = Some(A::default());
446                r.as_mut().expect("Initialised literally above")
447            }
448            None => return Err(SqliteCallbackError::Abort(NULL_CTX_ERR)),
449        }
450    };
451
452    let args = build_sql_function_args::<ArgsSqlType, Args>(args)?;
453
454    aggregator.step(args);
455    Ok(())
456}
457
458extern "C" fn run_aggregator_final_function<ArgsSqlType, RetSqlType, Args, Ret, A>(
459    ctx: *mut ffi::sqlite3_context,
460) where
461    A: SqliteAggregateFunction<Args, Output = Ret> + 'static + Send,
462    Args: FromSqlRow<ArgsSqlType, Sqlite>,
463    Ret: ToSql<RetSqlType, Sqlite>,
464    Sqlite: HasSqlType<RetSqlType>,
465{
466    let result = std::panic::catch_unwind(|| {
467        let aggregator = unsafe {
468            // SAFETY:
469            // * We set the corresponding data in `register_aggregate_function` above
470            // * It has the correct type
471            let aggregate_context = ffi::sqlite3_user_data(ctx) as *mut Option<A>;
472            // SAFETY:
473            // * The value has the correct layout as we initialised it on the rust side
474            // * The value has the correct alignment as we initialised it on the rust side
475            aggregate_context.as_mut()
476        }
477        .and_then(|a| a.take());
478
479        let res = A::finalize(aggregator);
480        let value = process_sql_function_result(&res)?;
481        // We've checked already that ctx is not null
482        let r = unsafe { value.result_of(&mut *ctx) };
483        r.map_err(|e| {
484            SqliteCallbackError::DieselError(crate::result::Error::SerializationError(Box::new(e)))
485        })?;
486        Ok(())
487    })
488    .unwrap_or_else(|_e| {
489        Err(SqliteCallbackError::Panic(::alloc::__export::must_use({
        ::alloc::fmt::format(format_args!("{0}::finalize() panicked",
                std::any::type_name::<A>()))
    })format!(
490            "{}::finalize() panicked",
491            std::any::type_name::<A>()
492        )))
493    });
494    if let Err(e) = result {
495        e.emit(ctx);
496    }
497}
498
499unsafe fn context_error_str(ctx: *mut ffi::sqlite3_context, error: &str) {
500    let len: i32 = error.len().try_into().unwrap_or(i32::MAX);
501    unsafe {
502        ffi::sqlite3_result_error(ctx, error.as_ptr() as *const _, len);
503    }
504}
505
/// User data attached to a collation registered by
/// `RawConnection::register_collation_function`.
struct CollationUserPtr<F> {
    /// The Rust comparison function implementing the collation.
    callback: F,
    /// SQL-visible collation name, used in failure messages.
    collation_name: String,
}
510
511#[allow(warnings)]
512extern "C" fn run_collation_function<F>(
513    user_ptr: *mut libc::c_void,
514    lhs_len: libc::c_int,
515    lhs_ptr: *const libc::c_void,
516    rhs_len: libc::c_int,
517    rhs_ptr: *const libc::c_void,
518) -> libc::c_int
519where
520    F: Fn(&str, &str) -> std::cmp::Ordering + Send + std::panic::UnwindSafe + 'static,
521{
522    let user_ptr = user_ptr as *const CollationUserPtr<F>;
523    let user_ptr = std::panic::AssertUnwindSafe(unsafe { user_ptr.as_ref() });
524
525    let result = std::panic::catch_unwind(|| {
526        let user_ptr = user_ptr.ok_or_else(|| {
527            SqliteCallbackError::Abort(
528                "Got a null pointer as data pointer. This should never happen",
529            )
530        })?;
531        for (ptr, len, side) in &[(rhs_ptr, rhs_len, "rhs"), (lhs_ptr, lhs_len, "lhs")] {
532            if *len < 0 {
533                {
    ::std::io::_eprint(format_args!("An unknown error occurred. {0}_len is negative. This should never happen.If you see this message, please open an issue at https://github.com/diesel-rs/diesel/issues/new.\nSource location: {1}:{2}\n",
            side, "diesel/src/sqlite/connection/raw.rs", 533u32));
};
std::process::abort();assert_fail!(
534                    "An unknown error occurred. {}_len is negative. This should never happen.",
535                    side
536                );
537            }
538            if ptr.is_null() {
539                {
    ::std::io::_eprint(format_args!("An unknown error occurred. {0}_ptr is a null pointer. This should never happen.If you see this message, please open an issue at https://github.com/diesel-rs/diesel/issues/new.\nSource location: {1}:{2}\n",
            side, "diesel/src/sqlite/connection/raw.rs", 539u32));
};
std::process::abort();assert_fail!(
540                "An unknown error occurred. {}_ptr is a null pointer. This should never happen.",
541                side
542            );
543            }
544        }
545
546        let (rhs, lhs) = unsafe {
547            // Depending on the eTextRep-parameter to sqlite3_create_collation_v2() the strings can
548            // have various encodings. register_collation_function() always selects SQLITE_UTF8, so the
549            // pointers point to valid UTF-8 strings (assuming correct behavior of libsqlite3).
550            (
551                str::from_utf8(slice::from_raw_parts(rhs_ptr as *const u8, rhs_len as _)),
552                str::from_utf8(slice::from_raw_parts(lhs_ptr as *const u8, lhs_len as _)),
553            )
554        };
555
556        let rhs =
557            rhs.map_err(|_| SqliteCallbackError::Abort("Got an invalid UTF-8 string for rhs"))?;
558        let lhs =
559            lhs.map_err(|_| SqliteCallbackError::Abort("Got an invalid UTF-8 string for lhs"))?;
560
561        Ok((user_ptr.callback)(rhs, lhs))
562    })
563    .unwrap_or_else(|p| {
564        Err(SqliteCallbackError::Panic(
565            user_ptr
566                .map(|u| u.collation_name.clone())
567                .unwrap_or_default(),
568        ))
569    });
570
571    match result {
572        Ok(std::cmp::Ordering::Less) => -1,
573        Ok(std::cmp::Ordering::Equal) => 0,
574        Ok(std::cmp::Ordering::Greater) => 1,
575        Err(SqliteCallbackError::Abort(a)) => {
576            {
    ::std::io::_eprint(format_args!("Collation function {0} failed with: {1}\n",
            user_ptr.map(|c| &c.collation_name as &str).unwrap_or_default(),
            a));
};eprintln!(
577                "Collation function {} failed with: {}",
578                user_ptr
579                    .map(|c| &c.collation_name as &str)
580                    .unwrap_or_default(),
581                a
582            );
583            std::process::abort()
584        }
585        Err(SqliteCallbackError::DieselError(e)) => {
586            {
    ::std::io::_eprint(format_args!("Collation function {0} failed with: {1}\n",
            user_ptr.map(|c| &c.collation_name as &str).unwrap_or_default(),
            e));
};eprintln!(
587                "Collation function {} failed with: {}",
588                user_ptr
589                    .map(|c| &c.collation_name as &str)
590                    .unwrap_or_default(),
591                e
592            );
593            std::process::abort()
594        }
595        Err(SqliteCallbackError::Panic(msg)) => {
596            {
    ::std::io::_eprint(format_args!("Collation function {0} panicked\n",
            msg));
};eprintln!("Collation function {} panicked", msg);
597            std::process::abort()
598        }
599    }
600}
601
/// Generic `xDestroy` callback: takes back ownership of a `Box<F>` that was turned into a
/// raw pointer with `Box::into_raw`, and drops it.
extern "C" fn destroy_boxed<F>(data: *mut libc::c_void) {
    let boxed: *mut F = data.cast();
    // SAFETY: callers only pass pointers obtained from `Box::into_raw(Box::<F>::new(..))`,
    // and each of them exactly once.
    let owned = unsafe { Box::from_raw(boxed) };
    drop(owned);
}