1#![allow(unsafe_code)] #[cfg(not(all(target_family = "wasm", target_os = "unknown")))]
3extern crate libsqlite3_sys as ffi;
4
5#[cfg(all(target_family = "wasm", target_os = "unknown"))]
6use sqlite_wasm_rs::export as ffi;
7
8use std::ffi::{CString, NulError};
9use std::io::{stderr, Write};
10use std::os::raw as libc;
11use std::ptr::NonNull;
12use std::{mem, ptr, slice, str};
13
14use super::functions::{build_sql_function_args, process_sql_function_result};
15use super::serialized_database::SerializedDatabase;
16use super::stmt::ensure_sqlite_ok;
17use super::{Sqlite, SqliteAggregateFunction};
18use crate::deserialize::FromSqlRow;
19use crate::result::Error::DatabaseError;
20use crate::result::*;
21use crate::serialize::ToSql;
22use crate::sql_types::HasSqlType;
23
// Prints a formatted diagnostic to stderr, followed by a request to report the problem and
// the source location of the invocation, and then aborts the process.
//
// `$fmt` is a format string literal; every following token is forwarded as a format argument.
// The macro never returns because it ends in `std::process::abort()`.
macro_rules! assert_fail {
    ($fmt:expr $(,$args:tt)*) => {
        eprint!(concat!(
            $fmt,
            // Keep the caller's message on its own line; without a separator it runs
            // straight into the bug-report hint.
            "\n",
            "If you see this message, please open an issue at https://github.com/diesel-rs/diesel/issues/new.\n",
            "Source location: {}:{}\n",
        ), $($args,)* file!(), line!());
        std::process::abort()
    };
}
36
/// Wrapper around a raw, non-null `sqlite3` connection handle.
///
/// The `Drop` implementation of this type passes the handle to `sqlite3_close`, so a value
/// of this type must only be dropped by whoever is supposed to close the connection.
#[allow(missing_debug_implementations, missing_copy_implementations)]
pub(super) struct RawConnection {
    /// The underlying SQLite connection handle handed to every FFI call.
    pub(super) internal_connection: NonNull<ffi::sqlite3>,
}
41
42impl RawConnection {
43 pub(super) fn establish(database_url: &str) -> ConnectionResult<Self> {
44 let mut conn_pointer = ptr::null_mut();
45
46 let database_url = if database_url.starts_with("sqlite://") {
47 CString::new(database_url.replacen("sqlite://", "file:", 1))?
48 } else {
49 CString::new(database_url)?
50 };
51 let flags = ffi::SQLITE_OPEN_READWRITE | ffi::SQLITE_OPEN_CREATE | ffi::SQLITE_OPEN_URI;
52 let connection_status = unsafe {
53 ffi::sqlite3_open_v2(database_url.as_ptr(), &mut conn_pointer, flags, ptr::null())
54 };
55
56 match connection_status {
57 ffi::SQLITE_OK => {
58 let conn_pointer = unsafe { NonNull::new_unchecked(conn_pointer) };
59 Ok(RawConnection {
60 internal_connection: conn_pointer,
61 })
62 }
63 err_code => {
64 let message = super::error_message(err_code);
65 Err(ConnectionError::BadConnection(message.into()))
66 }
67 }
68 }
69
70 pub(super) fn exec(&self, query: &str) -> QueryResult<()> {
71 let query = CString::new(query)?;
72 let callback_fn = None;
73 let callback_arg = ptr::null_mut();
74 let result = unsafe {
75 ffi::sqlite3_exec(
76 self.internal_connection.as_ptr(),
77 query.as_ptr(),
78 callback_fn,
79 callback_arg,
80 ptr::null_mut(),
81 )
82 };
83
84 ensure_sqlite_ok(result, self.internal_connection.as_ptr())
85 }
86
87 pub(super) fn rows_affected_by_last_query(
88 &self,
89 ) -> Result<usize, Box<dyn std::error::Error + Send + Sync>> {
90 let r = unsafe { ffi::sqlite3_changes(self.internal_connection.as_ptr()) };
91
92 Ok(r.try_into()?)
93 }
94
95 pub(super) fn register_sql_function<F, Ret, RetSqlType>(
96 &self,
97 fn_name: &str,
98 num_args: usize,
99 deterministic: bool,
100 f: F,
101 ) -> QueryResult<()>
102 where
103 F: FnMut(&Self, &mut [*mut ffi::sqlite3_value]) -> QueryResult<Ret>
104 + std::panic::UnwindSafe
105 + Send
106 + 'static,
107 Ret: ToSql<RetSqlType, Sqlite>,
108 Sqlite: HasSqlType<RetSqlType>,
109 {
110 let callback_fn = Box::into_raw(Box::new(CustomFunctionUserPtr {
111 callback: f,
112 function_name: fn_name.to_owned(),
113 }));
114 let fn_name = Self::get_fn_name(fn_name)?;
115 let flags = Self::get_flags(deterministic);
116 let num_args = num_args
117 .try_into()
118 .map_err(|e| Error::SerializationError(Box::new(e)))?;
119
120 let result = unsafe {
121 ffi::sqlite3_create_function_v2(
122 self.internal_connection.as_ptr(),
123 fn_name.as_ptr(),
124 num_args,
125 flags,
126 callback_fn as *mut _,
127 Some(run_custom_function::<F, Ret, RetSqlType>),
128 None,
129 None,
130 Some(destroy_boxed::<CustomFunctionUserPtr<F>>),
131 )
132 };
133
134 Self::process_sql_function_result(result)
135 }
136
137 pub(super) fn register_aggregate_function<ArgsSqlType, RetSqlType, Args, Ret, A>(
138 &self,
139 fn_name: &str,
140 num_args: usize,
141 ) -> QueryResult<()>
142 where
143 A: SqliteAggregateFunction<Args, Output = Ret> + 'static + Send + std::panic::UnwindSafe,
144 Args: FromSqlRow<ArgsSqlType, Sqlite>,
145 Ret: ToSql<RetSqlType, Sqlite>,
146 Sqlite: HasSqlType<RetSqlType>,
147 {
148 let fn_name = Self::get_fn_name(fn_name)?;
149 let flags = Self::get_flags(false);
150 let num_args = num_args
151 .try_into()
152 .map_err(|e| Error::SerializationError(Box::new(e)))?;
153
154 let result = unsafe {
155 ffi::sqlite3_create_function_v2(
156 self.internal_connection.as_ptr(),
157 fn_name.as_ptr(),
158 num_args,
159 flags,
160 ptr::null_mut(),
161 None,
162 Some(run_aggregator_step_function::<_, _, _, _, A>),
163 Some(run_aggregator_final_function::<_, _, _, _, A>),
164 None,
165 )
166 };
167
168 Self::process_sql_function_result(result)
169 }
170
171 pub(super) fn register_collation_function<F>(
172 &self,
173 collation_name: &str,
174 collation: F,
175 ) -> QueryResult<()>
176 where
177 F: Fn(&str, &str) -> std::cmp::Ordering + std::panic::UnwindSafe + Send + 'static,
178 {
179 let callback_fn = Box::into_raw(Box::new(CollationUserPtr {
180 callback: collation,
181 collation_name: collation_name.to_owned(),
182 }));
183 let collation_name = Self::get_fn_name(collation_name)?;
184
185 let result = unsafe {
186 ffi::sqlite3_create_collation_v2(
187 self.internal_connection.as_ptr(),
188 collation_name.as_ptr(),
189 ffi::SQLITE_UTF8,
190 callback_fn as *mut _,
191 Some(run_collation_function::<F>),
192 Some(destroy_boxed::<CollationUserPtr<F>>),
193 )
194 };
195
196 let result = Self::process_sql_function_result(result);
197 if result.is_err() {
198 destroy_boxed::<CollationUserPtr<F>>(callback_fn as *mut _);
199 }
200 result
201 }
202
203 pub(super) fn serialize(&mut self) -> SerializedDatabase {
204 unsafe {
205 let mut size: ffi::sqlite3_int64 = 0;
206 let data_ptr = ffi::sqlite3_serialize(
207 self.internal_connection.as_ptr(),
208 std::ptr::null(),
209 &mut size as *mut _,
210 0,
211 );
212 SerializedDatabase::new(
213 data_ptr,
214 size.try_into()
215 .expect("Cannot fit the serialized database into memory"),
216 )
217 }
218 }
219
220 pub(super) fn deserialize(&mut self, data: &[u8]) -> QueryResult<()> {
221 let db_size = data
222 .len()
223 .try_into()
224 .map_err(|e| Error::DeserializationError(Box::new(e)))?;
225 #[allow(clippy::unnecessary_cast)]
227 unsafe {
228 let result = ffi::sqlite3_deserialize(
229 self.internal_connection.as_ptr(),
230 std::ptr::null(),
231 data.as_ptr() as *mut u8,
232 db_size,
233 db_size,
234 ffi::SQLITE_DESERIALIZE_READONLY as u32,
235 );
236
237 ensure_sqlite_ok(result, self.internal_connection.as_ptr())
238 }
239 }
240
241 fn get_fn_name(fn_name: &str) -> Result<CString, NulError> {
242 CString::new(fn_name)
243 }
244
245 fn get_flags(deterministic: bool) -> i32 {
246 let mut flags = ffi::SQLITE_UTF8;
247 if deterministic {
248 flags |= ffi::SQLITE_DETERMINISTIC;
249 }
250 flags
251 }
252
253 fn process_sql_function_result(result: i32) -> Result<(), Error> {
254 if result == ffi::SQLITE_OK {
255 Ok(())
256 } else {
257 let error_message = super::error_message(result);
258 Err(DatabaseError(
259 DatabaseErrorKind::Unknown,
260 Box::new(error_message.to_string()),
261 ))
262 }
263 }
264}
265
266impl Drop for RawConnection {
267 fn drop(&mut self) {
268 use std::thread::panicking;
269
270 let close_result = unsafe { ffi::sqlite3_close(self.internal_connection.as_ptr()) };
271 if close_result != ffi::SQLITE_OK {
272 let error_message = super::error_message(close_result);
273 if panicking() {
274 write!(stderr(), "Error closing SQLite connection: {error_message}")
275 .expect("Error writing to `stderr`");
276 } else {
277 panic!("Error closing SQLite connection: {}", error_message);
278 }
279 }
280 }
281}
282
/// Failure raised inside one of the `extern "C"` callbacks handed to SQLite.
enum SqliteCallbackError {
    /// An internal invariant was violated; carries a static description.
    Abort(&'static str),
    /// A regular diesel error produced while running the callback.
    DieselError(crate::result::Error),
    /// The user-supplied code panicked; carries the text reported for that panic.
    Panic(String),
}
288
289impl SqliteCallbackError {
290 fn emit(&self, ctx: *mut ffi::sqlite3_context) {
291 let s;
292 let msg = match self {
293 SqliteCallbackError::Abort(msg) => *msg,
294 SqliteCallbackError::DieselError(e) => {
295 s = e.to_string();
296 &s
297 }
298 SqliteCallbackError::Panic(msg) => msg,
299 };
300 unsafe {
301 context_error_str(ctx, msg);
302 }
303 }
304}
305
306impl From<crate::result::Error> for SqliteCallbackError {
307 fn from(e: crate::result::Error) -> Self {
308 Self::DieselError(e)
309 }
310}
311
/// User data attached to a registered scalar SQL function.
struct CustomFunctionUserPtr<F> {
    /// The closure invoked by `run_custom_function`.
    callback: F,
    /// Name the function was registered under; reported when the closure panics.
    function_name: String,
}
316
317#[allow(warnings)]
318extern "C" fn run_custom_function<F, Ret, RetSqlType>(
319 ctx: *mut ffi::sqlite3_context,
320 num_args: libc::c_int,
321 value_ptr: *mut *mut ffi::sqlite3_value,
322) where
323 F: FnMut(&RawConnection, &mut [*mut ffi::sqlite3_value]) -> QueryResult<Ret>
324 + std::panic::UnwindSafe
325 + Send
326 + 'static,
327 Ret: ToSql<RetSqlType, Sqlite>,
328 Sqlite: HasSqlType<RetSqlType>,
329{
330 use std::ops::Deref;
331 static NULL_DATA_ERR: &str = "An unknown error occurred. sqlite3_user_data returned a null pointer. This should never happen.";
332 static NULL_CONN_ERR: &str = "An unknown error occurred. sqlite3_context_db_handle returned a null pointer. This should never happen.";
333
334 let conn = match unsafe { NonNull::new(ffi::sqlite3_context_db_handle(ctx)) } {
335 Some(conn) => mem::ManuallyDrop::new(RawConnection {
338 internal_connection: conn,
339 }),
340 None => {
341 unsafe { context_error_str(ctx, NULL_CONN_ERR) };
342 return;
343 }
344 };
345
346 let data_ptr = unsafe { ffi::sqlite3_user_data(ctx) };
347
348 let mut data_ptr = match NonNull::new(data_ptr as *mut CustomFunctionUserPtr<F>) {
349 None => unsafe {
350 context_error_str(ctx, NULL_DATA_ERR);
351 return;
352 },
353 Some(mut f) => f,
354 };
355 let data_ptr = unsafe { data_ptr.as_mut() };
356
357 let callback = std::panic::AssertUnwindSafe(&mut data_ptr.callback);
360
361 let result = std::panic::catch_unwind(move || {
362 let _ = &callback;
363 let args = unsafe { slice::from_raw_parts_mut(value_ptr, num_args as _) };
364 let res = (callback.0)(&*conn, args)?;
365 let value = process_sql_function_result(&res)?;
366 unsafe {
368 value.result_of(&mut *ctx);
369 }
370 Ok(())
371 })
372 .unwrap_or_else(|p| Err(SqliteCallbackError::Panic(data_ptr.function_name.clone())));
373 if let Err(e) = result {
374 e.emit(ctx);
375 }
376}
377
/// Option-like slot stored inside the memory returned by `sqlite3_aggregate_context`.
///
/// `#[repr(u8)]` fixes the tag to a single byte and `None`, being the first variant, has
/// the tag value 0. `run_aggregator_step` relies on this when it matches freshly obtained
/// context memory against `None` (SQLite documents that this memory is zero-filled on the
/// first call).
#[repr(u8)]
enum OptionalAggregator<A> {
    /// No aggregator has been written into the slot yet.
    None,
    /// The aggregator state accumulated so far.
    Some(A),
}
386
387#[allow(warnings)]
388extern "C" fn run_aggregator_step_function<ArgsSqlType, RetSqlType, Args, Ret, A>(
389 ctx: *mut ffi::sqlite3_context,
390 num_args: libc::c_int,
391 value_ptr: *mut *mut ffi::sqlite3_value,
392) where
393 A: SqliteAggregateFunction<Args, Output = Ret> + 'static + Send + std::panic::UnwindSafe,
394 Args: FromSqlRow<ArgsSqlType, Sqlite>,
395 Ret: ToSql<RetSqlType, Sqlite>,
396 Sqlite: HasSqlType<RetSqlType>,
397{
398 let result = std::panic::catch_unwind(move || {
399 let args = unsafe { slice::from_raw_parts_mut(value_ptr, num_args as _) };
400 run_aggregator_step::<A, Args, ArgsSqlType>(ctx, args)
401 })
402 .unwrap_or_else(|e| {
403 Err(SqliteCallbackError::Panic(format!(
404 "{}::step() panicked",
405 std::any::type_name::<A>()
406 )))
407 });
408
409 match result {
410 Ok(()) => {}
411 Err(e) => e.emit(ctx),
412 }
413}
414
415fn run_aggregator_step<A, Args, ArgsSqlType>(
416 ctx: *mut ffi::sqlite3_context,
417 args: &mut [*mut ffi::sqlite3_value],
418) -> Result<(), SqliteCallbackError>
419where
420 A: SqliteAggregateFunction<Args>,
421 Args: FromSqlRow<ArgsSqlType, Sqlite>,
422{
423 static NULL_AG_CTX_ERR: &str = "An unknown error occurred. sqlite3_aggregate_context returned a null pointer. This should never happen.";
424 static NULL_CTX_ERR: &str =
425 "We've written the aggregator to the aggregate context, but it could not be retrieved.";
426
427 let n_bytes: i32 = std::mem::size_of::<OptionalAggregator<A>>()
428 .try_into()
429 .expect("Aggregate context should be larger than 2^32");
430 let aggregate_context = unsafe {
431 ffi::sqlite3_aggregate_context(ctx, n_bytes)
453 };
454 let aggregate_context = NonNull::new(aggregate_context as *mut OptionalAggregator<A>);
455 let aggregator = unsafe {
456 match aggregate_context.map(|a| &mut *a.as_ptr()) {
457 Some(&mut OptionalAggregator::Some(ref mut agg)) => agg,
458 Some(a_ptr @ &mut OptionalAggregator::None) => {
459 ptr::write_unaligned(a_ptr as *mut _, OptionalAggregator::Some(A::default()));
460 if let OptionalAggregator::Some(ref mut agg) = a_ptr {
461 agg
462 } else {
463 return Err(SqliteCallbackError::Abort(NULL_CTX_ERR));
464 }
465 }
466 None => {
467 return Err(SqliteCallbackError::Abort(NULL_AG_CTX_ERR));
468 }
469 }
470 };
471 let args = build_sql_function_args::<ArgsSqlType, Args>(args)?;
472
473 aggregator.step(args);
474 Ok(())
475}
476
477extern "C" fn run_aggregator_final_function<ArgsSqlType, RetSqlType, Args, Ret, A>(
478 ctx: *mut ffi::sqlite3_context,
479) where
480 A: SqliteAggregateFunction<Args, Output = Ret> + 'static + Send,
481 Args: FromSqlRow<ArgsSqlType, Sqlite>,
482 Ret: ToSql<RetSqlType, Sqlite>,
483 Sqlite: HasSqlType<RetSqlType>,
484{
485 static NO_AGGREGATOR_FOUND: &str = "We've written to the aggregator in the xStep callback. If xStep was never called, then ffi::sqlite_aggregate_context() would have returned a NULL pointer.";
486 let aggregate_context = unsafe {
487 ffi::sqlite3_aggregate_context(ctx, 0)
494 };
495
496 let result = std::panic::catch_unwind(|| {
497 let mut aggregate_context = NonNull::new(aggregate_context as *mut OptionalAggregator<A>);
498
499 let aggregator = if let Some(a) = aggregate_context.as_mut() {
500 let a = unsafe { a.as_mut() };
501 match std::mem::replace(a, OptionalAggregator::None) {
502 OptionalAggregator::None => {
503 return Err(SqliteCallbackError::Abort(NO_AGGREGATOR_FOUND));
504 }
505 OptionalAggregator::Some(a) => Some(a),
506 }
507 } else {
508 None
509 };
510
511 let res = A::finalize(aggregator);
512 let value = process_sql_function_result(&res)?;
513 let r = unsafe { value.result_of(&mut *ctx) };
515 r.map_err(|e| {
516 SqliteCallbackError::DieselError(crate::result::Error::SerializationError(Box::new(e)))
517 })?;
518 Ok(())
519 })
520 .unwrap_or_else(|_e| {
521 Err(SqliteCallbackError::Panic(format!(
522 "{}::finalize() panicked",
523 std::any::type_name::<A>()
524 )))
525 });
526 if let Err(e) = result {
527 e.emit(ctx);
528 }
529}
530
531unsafe fn context_error_str(ctx: *mut ffi::sqlite3_context, error: &str) {
532 let len: i32 = error
533 .len()
534 .try_into()
535 .expect("Trying to set a error message with more than 2^32 byte is not supported");
536 ffi::sqlite3_result_error(ctx, error.as_ptr() as *const _, len);
537}
538
/// User data attached to a registered collation.
struct CollationUserPtr<F> {
    /// The comparison closure invoked by `run_collation_function`.
    callback: F,
    /// Name the collation was registered under; used in diagnostics.
    collation_name: String,
}
543
544#[allow(warnings)]
545extern "C" fn run_collation_function<F>(
546 user_ptr: *mut libc::c_void,
547 lhs_len: libc::c_int,
548 lhs_ptr: *const libc::c_void,
549 rhs_len: libc::c_int,
550 rhs_ptr: *const libc::c_void,
551) -> libc::c_int
552where
553 F: Fn(&str, &str) -> std::cmp::Ordering + Send + std::panic::UnwindSafe + 'static,
554{
555 let user_ptr = user_ptr as *const CollationUserPtr<F>;
556 let user_ptr = std::panic::AssertUnwindSafe(unsafe { user_ptr.as_ref() });
557
558 let result = std::panic::catch_unwind(|| {
559 let user_ptr = user_ptr.ok_or_else(|| {
560 SqliteCallbackError::Abort(
561 "Got a null pointer as data pointer. This should never happen",
562 )
563 })?;
564 for (ptr, len, side) in &[(rhs_ptr, rhs_len, "rhs"), (lhs_ptr, lhs_len, "lhs")] {
565 if *len < 0 {
566 assert_fail!(
567 "An unknown error occurred. {}_len is negative. This should never happen.",
568 side
569 );
570 }
571 if ptr.is_null() {
572 assert_fail!(
573 "An unknown error occurred. {}_ptr is a null pointer. This should never happen.",
574 side
575 );
576 }
577 }
578
579 let (rhs, lhs) = unsafe {
580 (
584 str::from_utf8(slice::from_raw_parts(rhs_ptr as *const u8, rhs_len as _)),
585 str::from_utf8(slice::from_raw_parts(lhs_ptr as *const u8, lhs_len as _)),
586 )
587 };
588
589 let rhs =
590 rhs.map_err(|_| SqliteCallbackError::Abort("Got an invalid UTF-8 string for rhs"))?;
591 let lhs =
592 lhs.map_err(|_| SqliteCallbackError::Abort("Got an invalid UTF-8 string for lhs"))?;
593
594 Ok((user_ptr.callback)(rhs, lhs))
595 })
596 .unwrap_or_else(|p| {
597 Err(SqliteCallbackError::Panic(
598 user_ptr
599 .map(|u| u.collation_name.clone())
600 .unwrap_or_default(),
601 ))
602 });
603
604 match result {
605 Ok(std::cmp::Ordering::Less) => -1,
606 Ok(std::cmp::Ordering::Equal) => 0,
607 Ok(std::cmp::Ordering::Greater) => 1,
608 Err(SqliteCallbackError::Abort(a)) => {
609 eprintln!(
610 "Collation function {} failed with: {}",
611 user_ptr
612 .map(|c| &c.collation_name as &str)
613 .unwrap_or_default(),
614 a
615 );
616 std::process::abort()
617 }
618 Err(SqliteCallbackError::DieselError(e)) => {
619 eprintln!(
620 "Collation function {} failed with: {}",
621 user_ptr
622 .map(|c| &c.collation_name as &str)
623 .unwrap_or_default(),
624 e
625 );
626 std::process::abort()
627 }
628 Err(SqliteCallbackError::Panic(msg)) => {
629 eprintln!("Collation function {} panicked", msg);
630 std::process::abort()
631 }
632 }
633}
634
/// Destructor callback handed to SQLite: reclaims a `Box<F>` that was previously leaked
/// with `Box::into_raw` and drops it.
///
/// `data` must be a pointer obtained from `Box::<F>::into_raw` that has not been freed yet.
extern "C" fn destroy_boxed<F>(data: *mut libc::c_void) {
    // SAFETY: see the requirement on `data` above; ownership moves back into the box.
    let reclaimed: Box<F> = unsafe { Box::from_raw(data.cast::<F>()) };
    drop(reclaimed);
}