Skip to main content

diesel/sqlite/
auto_extension.rs

1#[cfg(not(all(target_family = "wasm", target_os = "unknown")))]
2extern crate libsqlite3_sys as ffi;
3#[cfg(all(target_family = "wasm", target_os = "unknown"))]
4use sqlite_wasm_rs as ffi;
5
6use crate::result::Error::DatabaseError;
7use crate::result::*;
8use crate::sqlite::SqliteConnection;
9use alloc::boxed::Box;
10use alloc::string::{String, ToString};
11use core::ffi::{c_char, c_int};
12
13/// The fn type `libsqlite3-sys` expects for `sqlite3_auto_extension`, used only
14/// to name the private trampoline below.
15type RawAutoExtension = unsafe extern "C" fn(
16    db: *mut ffi::sqlite3,
17    pz_err_msg: *mut *mut c_char,
18    p_api: *const ffi::sqlite3_api_routines,
19) -> c_int;
20
21/// Registers an auto-extension that runs for every SQLite connection opened in
22/// this process, including non-Diesel ones.
23///
24/// This is a safe wrapper around [`sqlite3_auto_extension`][docs]. The callback
25/// receives the [`SqliteConnection`] being opened and returns `Ok(())` to
26/// continue or an error to fail the open. Use it to register SQL functions,
27/// collations, or aggregates through the usual connection API, or to initialize
28/// a statically linked C extension such as Spatialite or sqlite-vec via
29/// [`SqliteConnection::with_raw_connection`].
30///
31/// Call this before opening any connection. Extensions run in registration
32/// order, and the first error aborts the open. The callback must be a `fn` item
33/// or a closure that captures only zero-sized values (enforced at compile
34/// time), and registering the same `fn` twice is a no-op. It may run on several
35/// threads at once and must not open another connection (which would re-enter
36/// the auto-extensions and recurse) or call [`register_auto_extension`],
37/// [`cancel_auto_extension`], or [`reset_auto_extension`]. Panics are caught and
38/// turned into a failed open.
39///
40/// [docs]: https://www.sqlite.org/c3ref/auto_extension.html
41///
42/// # Example
43///
44/// ```rust
45/// use diesel::dsl::sql;
46/// use diesel::prelude::*;
47/// use diesel::sql_types::Integer;
48/// use diesel::sqlite::{register_auto_extension, reset_auto_extension, SqliteConnection};
49///
50/// // Registers a case-insensitive collation on every new connection.
51/// fn my_ext(conn: &mut SqliteConnection) -> QueryResult<()> {
52///     conn.register_collation("RUSTNOCASE", |a, b| a.to_lowercase().cmp(&b.to_lowercase()))
53/// }
54///
55/// register_auto_extension(my_ext).unwrap();
56///
57/// // Every future connection now has the collation.
58/// let mut conn = SqliteConnection::establish(":memory:").unwrap();
59/// let equal: i32 = sql::<Integer>("SELECT 'a' = 'A' COLLATE RUSTNOCASE")
60///     .get_result(&mut conn)
61///     .unwrap();
62/// assert_eq!(equal, 1);
63/// # reset_auto_extension();
64/// ```
65#[allow(unsafe_code)]
66pub fn register_auto_extension<F>(extension: F) -> QueryResult<()>
67where
68    F: Fn(&mut SqliteConnection) -> QueryResult<()> + Sync + 'static,
69{
70    let result = unsafe { ffi::sqlite3_auto_extension(Some(entry_point(extension))) };
71    if result == ffi::SQLITE_OK {
72        Ok(())
73    } else {
74        Err(DatabaseError(
75            DatabaseErrorKind::Unknown,
76            Box::new(ffi::code_to_str(result).to_string()),
77        ))
78    }
79}
80
81/// Removes a previously registered auto-extension, returning `true` if it was
82/// found ([docs][cancel_docs]).
83///
84/// Pass the same `fn` item given to [`register_auto_extension`]. A closure
85/// cannot be cancelled this way, because its type cannot be named again. Use
86/// [`reset_auto_extension`] to clear everything instead.
87///
88/// [cancel_docs]: https://www.sqlite.org/c3ref/cancel_auto_extension.html
89#[allow(unsafe_code)]
90pub fn cancel_auto_extension<F>(extension: F) -> bool
91where
92    F: Fn(&mut SqliteConnection) -> QueryResult<()> + Sync + 'static,
93{
94    unsafe { ffi::sqlite3_cancel_auto_extension(Some(entry_point(extension))) != 0 }
95}
96
97/// Clears **all** registered auto-extensions ([docs][reset_docs]).
98///
99/// After this call, no auto-extensions will run for newly opened connections.
100///
101/// [reset_docs]: https://www.sqlite.org/c3ref/reset_auto_extension.html
102#[allow(unsafe_code)]
103pub fn reset_auto_extension() {
104    unsafe { ffi::sqlite3_reset_auto_extension() }
105}
106
107/// Returns the trampoline for `F`. `extension` is taken by value to infer `F`,
108/// then `forget`-ten so the zero-sized callback stays conceptually alive for the
109/// process and its destructor never runs.
110fn entry_point<F>(extension: F) -> RawAutoExtension
111where
112    F: Fn(&mut SqliteConnection) -> QueryResult<()> + Sync + 'static,
113{
114    core::mem::forget(extension);
115    trampoline::<F>
116}
117
118/// The C entry point handed to SQLite, monomorphized per callback type `F` so
119/// each distinct callback maps to a distinct, stable address. SQLite's
120/// pointer-based deduplication and [`cancel_auto_extension`] rely on that.
121#[allow(unsafe_code)]
122unsafe extern "C" fn trampoline<F>(
123    db: *mut ffi::sqlite3,
124    pz_err_msg: *mut *mut c_char,
125    _p_api: *const ffi::sqlite3_api_routines,
126) -> c_int
127where
128    F: Fn(&mut SqliteConnection) -> QueryResult<()> + Sync + 'static,
129{
130    const {
131        if !(core::mem::size_of::<F>() == 0) {
    {
        ::core::panicking::panic_fmt(format_args!("an auto-extension callback must not capture non-zero-sized state. Use a `fn` item or a closure that captures only zero-sized values"));
    }
};assert!(
132            core::mem::size_of::<F>() == 0,
133            "an auto-extension callback must not capture non-zero-sized state. \
134             Use a `fn` item or a closure that captures only zero-sized values"
135        );
136    }
137
138    // `_p_api` matters only for runtime-loaded shared libraries. Statically
139    // linked extensions link the SQLite symbols directly, so we ignore it.
140    let result: Result<(), String> =
141        crate::util::std_compat::catch_unwind(core::panic::AssertUnwindSafe(|| {
142            let Some(db) = core::ptr::NonNull::new(db) else {
143                return Err(String::from(
144                    "auto-extension received a null database handle",
145                ));
146            };
147            // Reconstruct a *reference* to the zero-sized callback, never an
148            // owned value, so its destructor never runs. `&F: Fn` because `F: Fn`.
149            // SAFETY: `F` is zero-sized (asserted above), so a dangling, aligned,
150            // non-null pointer is a valid `&F`, which `NonNull::dangling` provides.
151            let extension: &F = unsafe { core::ptr::NonNull::<F>::dangling().as_ref() };
152            // SAFETY: `db` is a valid handle for the duration of this call, and
153            // the borrowed connection does not take ownership of it.
154            unsafe { SqliteConnection::with_borrowed_connection(db, extension) }
155                .map_err(|e| e.to_string())
156        }))
157        .unwrap_or_else(|panic| {
158            Err(match panic_detail(panic) {
159                Some(message) => ::alloc::__export::must_use({
        ::alloc::fmt::format(format_args!("auto-extension panicked: {0}",
                message))
    })alloc::format!("auto-extension panicked: {message}"),
160                None => String::from("auto-extension panicked"),
161            })
162        });
163
164    match result {
165        Ok(()) => ffi::SQLITE_OK,
166        Err(message) => {
167            set_error_message(pz_err_msg, &message);
168            ffi::SQLITE_ERROR
169        }
170    }
171}
172
/// Extracts a human-readable message from a caught panic payload, if any.
///
/// With `std`, the two payload types produced by `panic!` are recognized: a
/// `&'static str` (literal message) and a `String` (formatted message). Any
/// other payload yields `None`.
#[cfg(feature = "std")]
fn panic_detail(panic: alloc::boxed::Box<dyn core::any::Any + Send>) -> Option<String> {
    if let Some(text) = panic.downcast_ref::<&str>() {
        return Some(String::from(*text));
    }
    panic.downcast_ref::<String>().map(String::clone)
}

/// Fallback for builds without `std`. The `catch_unwind` shim hands over `()`
/// instead of a payload, so there is never a message to recover.
#[cfg(not(feature = "std"))]
fn panic_detail(_panic: ()) -> Option<String> {
    None
}
187
188/// Writes `message` into `*pz_err_msg` with `sqlite3_malloc`, which is the
189/// allocator SQLite later frees it with. The message is truncated at the first
190/// NUL byte to form a valid C string, and allocation failure is ignored.
191#[allow(unsafe_code)]
192fn set_error_message(pz_err_msg: *mut *mut c_char, message: &str) {
193    if pz_err_msg.is_null() {
194        return;
195    }
196
197    let bytes = message.as_bytes();
198    let len = bytes.iter().position(|&b| b == 0).unwrap_or(bytes.len());
199
200    // SQLite sizes allocations with a C `int`. A message that does not fit is
201    // dropped rather than truncated to a bogus length.
202    let Ok(size) = c_int::try_from(len + 1) else {
203        return;
204    };
205    let buffer = unsafe { ffi::sqlite3_malloc(size) } as *mut u8;
206    if buffer.is_null() {
207        return;
208    }
209
210    unsafe {
211        core::ptr::copy_nonoverlapping(bytes.as_ptr(), buffer, len);
212        *buffer.add(len) = 0;
213        *pz_err_msg = buffer as *mut c_char;
214    }
215}
216
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dsl::sql;
    use crate::prelude::*;
    use crate::sql_types::Integer;
    use std::sync::Mutex;

    // `sqlite3_auto_extension` is process-global, so these tests serialize on
    // this lock and register only benign (never-failing) extensions, leaving
    // connections opened by other tests unaffected. The failing path is covered
    // by `trampoline_maps_result_to_return_code`, which calls the trampoline
    // directly without touching the global registry.
    static AUTO_EXT_TEST_LOCK: Mutex<()> = Mutex::new(());

    // A benign auto-extension: registers a `TESTCOLL` collation through the
    // normal connection API.
    fn test_ext_init(conn: &mut SqliteConnection) -> QueryResult<()> {
        conn.register_collation("TESTCOLL", |a, b| a.cmp(b))
    }

    fn open_memory_connection() -> SqliteConnection {
        SqliteConnection::establish(":memory:").expect("Failed to open :memory: connection")
    }

    // Errors out if `TESTCOLL` is not registered on a freshly opened connection.
    fn probe_collation() -> QueryResult<i32> {
        let mut conn = open_memory_connection();
        sql::<Integer>("SELECT 'a' = 'a' COLLATE TESTCOLL").get_result(&mut conn)
    }

    /// Copies the SQLite-allocated C string `err` into an owned `String` and
    /// frees it. The allocation is thus released before any assertion on the
    /// text can panic.
    ///
    /// # Safety
    ///
    /// `err` must be non-null, NUL-terminated, allocated with `sqlite3_malloc`
    /// (as `set_error_message` does), and not freed anywhere else.
    #[allow(unsafe_code)]
    unsafe fn take_error_message(err: *mut c_char) -> String {
        // SAFETY: guaranteed by the caller (non-null, NUL-terminated).
        let message = unsafe { core::ffi::CStr::from_ptr(err) }
            .to_string_lossy()
            .into_owned();
        // SAFETY: `err` came from `sqlite3_malloc` and is freed exactly once here.
        unsafe { ffi::sqlite3_free(err as *mut core::ffi::c_void) };
        message
    }

    /// RAII guard that calls `reset_auto_extension()` on drop, ensuring global
    /// state is cleaned up even if a test panics.
    struct TestResetGuard;

    impl Drop for TestResetGuard {
        fn drop(&mut self) {
            reset_auto_extension();
        }
    }

    #[test]
    fn auto_extension_lifecycle() {
        let _lock = AUTO_EXT_TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner());
        let _guard = TestResetGuard;
        reset_auto_extension();

        // -- 1. register + new connection has the collation --
        register_auto_extension(test_ext_init).unwrap();
        assert_eq!(probe_collation().unwrap(), 1);

        // -- 2. cancel + new connection does NOT have the collation --
        let removed = cancel_auto_extension(test_ext_init);
        assert!(
            removed,
            "cancel should return true for registered extension"
        );
        assert!(
            probe_collation().is_err(),
            "collation should not be available after cancel"
        );

        // -- 3. cancel returns false for unregistered --
        let removed = cancel_auto_extension(test_ext_init);
        assert!(
            !removed,
            "cancel should return false for unregistered extension"
        );

        // -- 4. reset clears all --
        register_auto_extension(test_ext_init).unwrap();
        reset_auto_extension();
        assert!(
            probe_collation().is_err(),
            "collation should not be available after reset"
        );

        // -- 5. duplicate registration is idempotent --
        register_auto_extension(test_ext_init).unwrap();
        register_auto_extension(test_ext_init).unwrap();
        assert_eq!(probe_collation().unwrap(), 1);
        // SQLite keeps a single entry, so one cancel must remove it entirely.
        // If the duplicate had been stored, the collation would still be
        // available after this cancel.
        assert!(
            cancel_auto_extension(test_ext_init),
            "cancel should remove the deduplicated registration"
        );
        assert!(
            probe_collation().is_err(),
            "duplicate registration must not leave a second entry behind"
        );
        // _guard drops here, ensuring reset even on panic.
    }

    // Drives the trampoline directly (via `entry_point`), without registering
    // it in SQLite's global list. The Ok/Err/null paths can therefore be
    // checked deterministically without affecting connections opened by other
    // tests.
    #[test]
    #[allow(unsafe_code)]
    fn trampoline_maps_result_to_return_code() {
        fn ok_ext(_conn: &mut SqliteConnection) -> QueryResult<()> {
            Ok(())
        }
        fn err_ext(_conn: &mut SqliteConnection) -> QueryResult<()> {
            Err(Error::QueryBuilderError("boom".into()))
        }

        let ok_tramp = entry_point(ok_ext);
        let err_tramp = entry_point(err_ext);

        let mut conn = open_memory_connection();
        // SAFETY: the pointer is only used for the duration of the closure,
        // while `conn` is alive.
        unsafe {
            conn.with_raw_connection(|db| {
                let mut err: *mut c_char = core::ptr::null_mut();

                // Ok -> SQLITE_OK, no error message allocated.
                let rc = ok_tramp(db, &mut err, core::ptr::null());
                assert_eq!(rc, ffi::SQLITE_OK);
                assert!(err.is_null());

                // Err -> SQLITE_ERROR, message written via sqlite3_malloc.
                let rc = err_tramp(db, &mut err, core::ptr::null());
                assert_eq!(rc, ffi::SQLITE_ERROR);
                assert!(!err.is_null());
                assert_eq!(take_error_message(err), "boom");

                // Null db handle -> SQLITE_ERROR, never dereferenced.
                let mut err: *mut c_char = core::ptr::null_mut();
                let rc = ok_tramp(core::ptr::null_mut(), &mut err, core::ptr::null());
                assert_eq!(rc, ffi::SQLITE_ERROR);
                assert!(!err.is_null());
                assert_eq!(
                    take_error_message(err),
                    "auto-extension received a null database handle"
                );
            })
        }
    }

    // Regression test for the closure-reconstruction soundness fix. The callback
    // captures a zero-sized guard with a load-bearing `Drop` that must never run.
    // The trampoline only reproduces the callback behind a reference and
    // `forget`s the registered value (the old `mem::zeroed()` ran it repeatedly).
    #[test]
    fn callback_zero_sized_capture_is_never_dropped() {
        use std::sync::atomic::{AtomicUsize, Ordering};

        static DROPS: AtomicUsize = AtomicUsize::new(0);

        // Zero-sized, `Sync`, with a load-bearing `Drop`.
        struct Guard;
        impl Drop for Guard {
            fn drop(&mut self) {
                DROPS.fetch_add(1, Ordering::SeqCst);
            }
        }

        let _lock = AUTO_EXT_TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner());
        let _reset = TestResetGuard;
        reset_auto_extension();

        let guard = Guard;
        // `move` captures the zero-sized `guard` by value, so the closure is a
        // zero-sized type with a non-trivial destructor.
        register_auto_extension(move |conn: &mut SqliteConnection| {
            let _ = &guard;
            conn.register_collation("TESTCOLL", |a, b| a.cmp(b))
        })
        .unwrap();

        for _ in 0..5 {
            assert_eq!(probe_collation().unwrap(), 1);
        }
        reset_auto_extension();

        assert_eq!(
            DROPS.load(Ordering::SeqCst),
            0,
            "the captured guard's destructor must never run"
        );
    }

    // A panicking callback becomes `SQLITE_ERROR`, with the payload recovered
    // into the message. `&str` and `String` payloads are recovered; any other
    // payload gets the generic fallback. This also drives the panic-unwind path
    // through the drop guard.
    //
    // The test is gated off WASM, where `catch_unwind` aborts because
    // `panic = "abort"`. It also requires `std`, because only the `std`
    // `panic_detail` can recover the payload text.
    #[test]
    #[allow(unsafe_code)]
    #[cfg(all(
        feature = "std",
        not(all(target_family = "wasm", target_os = "unknown"))
    ))]
    fn trampoline_reports_panic_message() {
        fn panic_str(_conn: &mut SqliteConnection) -> QueryResult<()> {
            panic!("boom-str");
        }
        fn panic_string(_conn: &mut SqliteConnection) -> QueryResult<()> {
            panic!("boom-{}", 7);
        }
        fn panic_other(_conn: &mut SqliteConnection) -> QueryResult<()> {
            std::panic::panic_any(7_u8);
        }

        let cases: [(RawAutoExtension, &str); 3] = [
            (entry_point(panic_str), "auto-extension panicked: boom-str"),
            (entry_point(panic_string), "auto-extension panicked: boom-7"),
            (entry_point(panic_other), "auto-extension panicked"),
        ];

        let mut conn = open_memory_connection();
        // SAFETY: the pointer is only used while `conn` is alive.
        unsafe {
            conn.with_raw_connection(|db| {
                for (tramp, expected) in cases {
                    let mut err: *mut c_char = core::ptr::null_mut();
                    let rc = tramp(db, &mut err, core::ptr::null());
                    assert_eq!(rc, ffi::SQLITE_ERROR);
                    assert!(!err.is_null());
                    assert_eq!(take_error_message(err), expected);
                }
            })
        }
    }

    #[test]
    #[allow(unsafe_code)]
    fn error_message_truncates_at_interior_nul() {
        let mut err: *mut c_char = core::ptr::null_mut();
        set_error_message(&mut err, "before\0after");
        assert!(!err.is_null());
        // SAFETY: `err` is a non-null, sqlite-allocated C string we own until
        // `take_error_message` frees it.
        let truncated = unsafe { take_error_message(err) };
        assert_eq!(truncated, "before");

        // A null out-pointer is a no-op (must not write through it or crash).
        set_error_message(core::ptr::null_mut(), "ignored");
    }
}