1#[cfg(not(all(target_family = "wasm", target_os = "unknown")))]
2extern crate libsqlite3_sys as ffi;
3#[cfg(all(target_family = "wasm", target_os = "unknown"))]
4use sqlite_wasm_rs as ffi;
56use crate::result::Error::DatabaseError;
7use crate::result::*;
8use crate::sqlite::SqliteConnection;
9use alloc::boxed::Box;
10use alloc::string::{String, ToString};
11use core::ffi::{c_char, c_int};
1213/// The fn type `libsqlite3-sys` expects for `sqlite3_auto_extension`, used only
14/// to name the private trampoline below.
15type RawAutoExtension = unsafe extern "C" fn(
16 db: *mut ffi::sqlite3,
17 pz_err_msg: *mut *mut c_char,
18 p_api: *const ffi::sqlite3_api_routines,
19) -> c_int;
2021/// Registers an auto-extension that runs for every SQLite connection opened in
22/// this process, including non-Diesel ones.
23///
24/// This is a safe wrapper around [`sqlite3_auto_extension`][docs]. The callback
25/// receives the [`SqliteConnection`] being opened and returns `Ok(())` to
26/// continue or an error to fail the open. Use it to register SQL functions,
27/// collations, or aggregates through the usual connection API, or to initialize
28/// a statically linked C extension such as Spatialite or sqlite-vec via
29/// [`SqliteConnection::with_raw_connection`].
30///
31/// Call this before opening any connection. Extensions run in registration
32/// order, and the first error aborts the open. The callback must be a `fn` item
33/// or a closure that captures only zero-sized values (enforced at compile
34/// time), and registering the same `fn` twice is a no-op. It may run on several
35/// threads at once and must not open another connection (which would re-enter
36/// the auto-extensions and recurse) or call [`register_auto_extension`],
37/// [`cancel_auto_extension`], or [`reset_auto_extension`]. Panics are caught and
38/// turned into a failed open.
39///
40/// [docs]: https://www.sqlite.org/c3ref/auto_extension.html
41///
42/// # Example
43///
44/// ```rust
45/// use diesel::dsl::sql;
46/// use diesel::prelude::*;
47/// use diesel::sql_types::Integer;
48/// use diesel::sqlite::{register_auto_extension, reset_auto_extension, SqliteConnection};
49///
50/// // Registers a case-insensitive collation on every new connection.
51/// fn my_ext(conn: &mut SqliteConnection) -> QueryResult<()> {
52/// conn.register_collation("RUSTNOCASE", |a, b| a.to_lowercase().cmp(&b.to_lowercase()))
53/// }
54///
55/// register_auto_extension(my_ext).unwrap();
56///
57/// // Every future connection now has the collation.
58/// let mut conn = SqliteConnection::establish(":memory:").unwrap();
59/// let equal: i32 = sql::<Integer>("SELECT 'a' = 'A' COLLATE RUSTNOCASE")
60/// .get_result(&mut conn)
61/// .unwrap();
62/// assert_eq!(equal, 1);
63/// # reset_auto_extension();
64/// ```
65#[allow(unsafe_code)]
66pub fn register_auto_extension<F>(extension: F) -> QueryResult<()>
67where
68F: Fn(&mut SqliteConnection) -> QueryResult<()> + Sync + 'static,
69{
70let result = unsafe { ffi::sqlite3_auto_extension(Some(entry_point(extension))) };
71if result == ffi::SQLITE_OK {
72Ok(())
73 } else {
74Err(DatabaseError(
75 DatabaseErrorKind::Unknown,
76Box::new(ffi::code_to_str(result).to_string()),
77 ))
78 }
79}
8081/// Removes a previously registered auto-extension, returning `true` if it was
82/// found ([docs][cancel_docs]).
83///
84/// Pass the same `fn` item given to [`register_auto_extension`]. A closure
85/// cannot be cancelled this way, because its type cannot be named again. Use
86/// [`reset_auto_extension`] to clear everything instead.
87///
88/// [cancel_docs]: https://www.sqlite.org/c3ref/cancel_auto_extension.html
89#[allow(unsafe_code)]
90pub fn cancel_auto_extension<F>(extension: F) -> bool91where
92F: Fn(&mut SqliteConnection) -> QueryResult<()> + Sync + 'static,
93{
94unsafe { ffi::sqlite3_cancel_auto_extension(Some(entry_point(extension))) != 0 }
95}
9697/// Clears **all** registered auto-extensions ([docs][reset_docs]).
98///
99/// After this call, no auto-extensions will run for newly opened connections.
100///
101/// [reset_docs]: https://www.sqlite.org/c3ref/reset_auto_extension.html
102#[allow(unsafe_code)]
103pub fn reset_auto_extension() {
104unsafe { ffi::sqlite3_reset_auto_extension() }
105}
106107/// Returns the trampoline for `F`. `extension` is taken by value to infer `F`,
108/// then `forget`-ten so the zero-sized callback stays conceptually alive for the
109/// process and its destructor never runs.
110fn entry_point<F>(extension: F) -> RawAutoExtension111where
112F: Fn(&mut SqliteConnection) -> QueryResult<()> + Sync + 'static,
113{
114 core::mem::forget(extension);
115trampoline::<F>
116}
117118/// The C entry point handed to SQLite, monomorphized per callback type `F` so
119/// each distinct callback maps to a distinct, stable address. SQLite's
120/// pointer-based deduplication and [`cancel_auto_extension`] rely on that.
121#[allow(unsafe_code)]
122unsafe extern "C" fn trampoline<F>(
123 db: *mut ffi::sqlite3,
124 pz_err_msg: *mut *mut c_char,
125 _p_api: *const ffi::sqlite3_api_routines,
126) -> c_int127where
128F: Fn(&mut SqliteConnection) -> QueryResult<()> + Sync + 'static,
129{
130const {
131if !(core::mem::size_of::<F>() == 0) {
{
::core::panicking::panic_fmt(format_args!("an auto-extension callback must not capture non-zero-sized state. Use a `fn` item or a closure that captures only zero-sized values"));
}
};assert!(
132 core::mem::size_of::<F>() == 0,
133"an auto-extension callback must not capture non-zero-sized state. \
134 Use a `fn` item or a closure that captures only zero-sized values"
135);
136 }
137138// `_p_api` matters only for runtime-loaded shared libraries. Statically
139 // linked extensions link the SQLite symbols directly, so we ignore it.
140let result: Result<(), String> =
141crate::util::std_compat::catch_unwind(core::panic::AssertUnwindSafe(|| {
142let Some(db) = core::ptr::NonNull::new(db) else {
143return Err(String::from(
144"auto-extension received a null database handle",
145 ));
146 };
147// Reconstruct a *reference* to the zero-sized callback, never an
148 // owned value, so its destructor never runs. `&F: Fn` because `F: Fn`.
149 // SAFETY: `F` is zero-sized (asserted above), so a dangling, aligned,
150 // non-null pointer is a valid `&F`, which `NonNull::dangling` provides.
151let extension: &F = unsafe { core::ptr::NonNull::<F>::dangling().as_ref() };
152// SAFETY: `db` is a valid handle for the duration of this call, and
153 // the borrowed connection does not take ownership of it.
154unsafe { SqliteConnection::with_borrowed_connection(db, extension) }
155 .map_err(|e| e.to_string())
156 }))
157 .unwrap_or_else(|panic| {
158Err(match panic_detail(panic) {
159Some(message) => ::alloc::__export::must_use({
::alloc::fmt::format(format_args!("auto-extension panicked: {0}",
message))
})alloc::format!("auto-extension panicked: {message}"),
160None => String::from("auto-extension panicked"),
161 })
162 });
163164match result {
165Ok(()) => ffi::SQLITE_OK,
166Err(message) => {
167set_error_message(pz_err_msg, &message);
168 ffi::SQLITE_ERROR169 }
170 }
171}
/// Best-effort message from a caught panic payload. The no_std `catch_unwind`
/// carries no payload, so only the `std` variant can recover the text.
///
/// With the `std` feature, a `&str` payload is tried first, then a `String`
/// payload; any other payload type yields `None`.
#[cfg(feature = "std")]
fn panic_detail(panic: alloc::boxed::Box<dyn core::any::Any + Send>) -> Option<String> {
    panic
        .downcast_ref::<&str>()
        .map(|s| (*s).to_owned())
        .or_else(|| panic.downcast_ref::<String>().cloned())
}

/// Fallback used without the `std` feature: there is no payload to inspect,
/// so no message can be recovered.
#[cfg(not(feature = "std"))]
fn panic_detail(_panic: ()) -> Option<String> {
    None
}
187188/// Writes `message` into `*pz_err_msg` with `sqlite3_malloc`, which is the
189/// allocator SQLite later frees it with. The message is truncated at the first
190/// NUL byte to form a valid C string, and allocation failure is ignored.
191#[allow(unsafe_code)]
192fn set_error_message(pz_err_msg: *mut *mut c_char, message: &str) {
193if pz_err_msg.is_null() {
194return;
195 }
196197let bytes = message.as_bytes();
198let len = bytes.iter().position(|&b| b == 0).unwrap_or(bytes.len());
199200// SQLite sizes allocations with a C `int`. A message that does not fit is
201 // dropped rather than truncated to a bogus length.
202let Ok(size) = c_int::try_from(len + 1) else {
203return;
204 };
205let buffer = unsafe { ffi::sqlite3_malloc(size) } as *mut u8;
206if buffer.is_null() {
207return;
208 }
209210unsafe {
211 core::ptr::copy_nonoverlapping(bytes.as_ptr(), buffer, len);
212*buffer.add(len) = 0;
213*pz_err_msg = bufferas *mut c_char;
214 }
215}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dsl::sql;
    use crate::prelude::*;
    use crate::sql_types::Integer;
    use std::sync::Mutex;

    // `sqlite3_auto_extension` is process-global, so these tests serialize on
    // this lock and register only benign (never-failing) extensions, leaving
    // connections opened by other tests unaffected. The failing path is covered
    // by `trampoline_maps_result_to_return_code`, which calls the trampoline
    // directly without touching the global registry.
    static AUTO_EXT_TEST_LOCK: Mutex<()> = Mutex::new(());

    // A benign auto-extension: registers a `TESTCOLL` collation through the
    // normal connection API.
    fn test_ext_init(conn: &mut SqliteConnection) -> QueryResult<()> {
        conn.register_collation("TESTCOLL", |a, b| a.cmp(b))
    }

    fn open_memory_connection() -> SqliteConnection {
        SqliteConnection::establish(":memory:").expect("Failed to open :memory: connection")
    }

    // Errors out if `TESTCOLL` is not registered on a freshly opened connection.
    fn probe_collation() -> QueryResult<i32> {
        let mut conn = open_memory_connection();
        sql::<Integer>("SELECT 'a' = 'a' COLLATE TESTCOLL").get_result(&mut conn)
    }

    /// RAII guard that calls `reset_auto_extension()` on drop, ensuring global
    /// state is cleaned up even if a test panics.
    struct TestResetGuard;

    impl Drop for TestResetGuard {
        fn drop(&mut self) {
            reset_auto_extension();
        }
    }

    #[test]
    fn auto_extension_lifecycle() {
        // A poisoned lock only means another test panicked; the guarded state
        // is reset below, so it is safe to keep going.
        let _lock = AUTO_EXT_TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner());
        let _guard = TestResetGuard;
        reset_auto_extension();

        // -- 1. register + new connection has the collation --
        register_auto_extension(test_ext_init).unwrap();
        assert_eq!(probe_collation().unwrap(), 1);

        // -- 2. cancel + new connection does NOT have the collation --
        let removed = cancel_auto_extension(test_ext_init);
        assert!(
            removed,
            "cancel should return true for registered extension"
        );
        assert!(
            probe_collation().is_err(),
            "collation should not be available after cancel"
        );

        // -- 3. cancel returns false for unregistered --
        let removed = cancel_auto_extension(test_ext_init);
        assert!(
            !removed,
            "cancel should return false for unregistered extension"
        );

        // -- 4. reset clears all --
        register_auto_extension(test_ext_init).unwrap();
        reset_auto_extension();
        assert!(
            probe_collation().is_err(),
            "collation should not be available after reset"
        );

        // -- 5. duplicate registration is idempotent --
        register_auto_extension(test_ext_init).unwrap();
        register_auto_extension(test_ext_init).unwrap();
        assert_eq!(probe_collation().unwrap(), 1);
        // _guard drops here, ensuring reset even on panic.
    }

    // Drives the trampoline directly (via `entry_point`), without registering
    // it in SQLite's global list, so the Ok/Err/null paths can be checked
    // deterministically without affecting connections opened by other tests.
    #[test]
    #[allow(unsafe_code)]
    fn trampoline_maps_result_to_return_code() {
        fn ok_ext(_conn: &mut SqliteConnection) -> QueryResult<()> {
            Ok(())
        }
        fn err_ext(_conn: &mut SqliteConnection) -> QueryResult<()> {
            Err(Error::QueryBuilderError("boom".into()))
        }

        let ok_tramp = entry_point(ok_ext);
        let err_tramp = entry_point(err_ext);

        let mut conn = open_memory_connection();
        // SAFETY: the pointer is only used for the duration of the closure,
        // while `conn` is alive.
        unsafe {
            conn.with_raw_connection(|db| {
                let mut err: *mut c_char = core::ptr::null_mut();

                // Ok -> SQLITE_OK, no error message allocated.
                let rc = ok_tramp(db, &mut err, core::ptr::null());
                assert_eq!(rc, ffi::SQLITE_OK);
                assert!(err.is_null());

                // Err -> SQLITE_ERROR, message written via sqlite3_malloc.
                let rc = err_tramp(db, &mut err, core::ptr::null());
                assert_eq!(rc, ffi::SQLITE_ERROR);
                assert!(!err.is_null());
                let message = core::ffi::CStr::from_ptr(err)
                    .to_string_lossy()
                    .into_owned();
                assert_eq!(message, "boom");
                ffi::sqlite3_free(err as *mut core::ffi::c_void);

                // Null db handle -> SQLITE_ERROR, never dereferenced.
                let mut err: *mut c_char = core::ptr::null_mut();
                let rc = ok_tramp(core::ptr::null_mut(), &mut err, core::ptr::null());
                assert_eq!(rc, ffi::SQLITE_ERROR);
                if !err.is_null() {
                    ffi::sqlite3_free(err as *mut core::ffi::c_void);
                }
            })
        }
    }

    // Regression test for the closure-reconstruction soundness fix. The callback
    // captures a zero-sized guard with a load-bearing `Drop` that must never run,
    // because the trampoline only reproduces the callback behind a reference and
    // `forget`s the registered value (the old `mem::zeroed()` ran it repeatedly).
    #[test]
    fn callback_zero_sized_capture_is_never_dropped() {
        use std::sync::atomic::{AtomicUsize, Ordering};

        static DROPS: AtomicUsize = AtomicUsize::new(0);

        // Zero-sized, `Sync`, with a load-bearing `Drop`.
        struct Guard;
        impl Drop for Guard {
            fn drop(&mut self) {
                DROPS.fetch_add(1, Ordering::SeqCst);
            }
        }

        let _lock = AUTO_EXT_TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner());
        let _reset = TestResetGuard;
        reset_auto_extension();

        let guard = Guard;
        // `move` captures the zero-sized `guard` by value, so the closure is a
        // zero-sized type with a non-trivial destructor.
        register_auto_extension(move |conn: &mut SqliteConnection| {
            let _ = &guard;
            conn.register_collation("TESTCOLL", |a, b| a.cmp(b))
        })
        .unwrap();

        // Each probe opens a new connection and therefore runs the callback.
        for _ in 0..5 {
            assert_eq!(probe_collation().unwrap(), 1);
        }
        reset_auto_extension();

        assert_eq!(
            DROPS.load(Ordering::SeqCst),
            0,
            "the captured guard's destructor must never run"
        );
    }

    // A panicking callback becomes `SQLITE_ERROR` with the payload recovered
    // into the message (`&str` and `String` payloads, generic fallback
    // otherwise), and drives the panic-unwind path through the drop guard. Gated
    // off WASM, where `catch_unwind` aborts because `panic = "abort"`.
    #[test]
    #[allow(unsafe_code)]
    #[cfg(not(all(target_family = "wasm", target_os = "unknown")))]
    fn trampoline_reports_panic_message() {
        fn panic_str(_conn: &mut SqliteConnection) -> QueryResult<()> {
            panic!("boom-str");
        }
        fn panic_string(_conn: &mut SqliteConnection) -> QueryResult<()> {
            panic!("boom-{}", 7);
        }
        fn panic_other(_conn: &mut SqliteConnection) -> QueryResult<()> {
            std::panic::panic_any(7_u8);
        }

        let cases: [(RawAutoExtension, &str); 3] = [
            (entry_point(panic_str), "auto-extension panicked: boom-str"),
            (entry_point(panic_string), "auto-extension panicked: boom-7"),
            (entry_point(panic_other), "auto-extension panicked"),
        ];

        let mut conn = open_memory_connection();
        // SAFETY: the pointer is only used while `conn` is alive.
        unsafe {
            conn.with_raw_connection(|db| {
                for (tramp, expected) in cases {
                    let mut err: *mut c_char = core::ptr::null_mut();
                    let rc = tramp(db, &mut err, core::ptr::null());
                    assert_eq!(rc, ffi::SQLITE_ERROR);
                    assert!(!err.is_null());
                    let message = core::ffi::CStr::from_ptr(err)
                        .to_string_lossy()
                        .into_owned();
                    assert_eq!(message, expected);
                    ffi::sqlite3_free(err as *mut core::ffi::c_void);
                }
            })
        }
    }

    #[test]
    #[allow(unsafe_code)]
    fn error_message_truncates_at_interior_nul() {
        let mut err: *mut c_char = core::ptr::null_mut();
        set_error_message(&mut err, "before\0after");
        assert!(!err.is_null());
        // SAFETY: `err` is a sqlite-allocated C string we own until we free it.
        unsafe {
            let truncated = core::ffi::CStr::from_ptr(err).to_str().unwrap();
            assert_eq!(truncated, "before");
            ffi::sqlite3_free(err as *mut core::ffi::c_void);
        }

        // A null out-pointer is a no-op (must not write through it or crash).
        set_error_message(core::ptr::null_mut(), "ignored");
    }
}