1#![allow(unsafe_code)] #[cfg(not(all(target_family = "wasm", target_os = "unknown")))]
3extern crate libsqlite3_sys as ffi;
4
5#[cfg(all(target_family = "wasm", target_os = "unknown"))]
6use sqlite_wasm_rs as ffi;
7
8use std::ffi::{CString, NulError};
9use std::io::{stderr, Write};
10use std::os::raw as libc;
11use std::ptr::NonNull;
12use std::{mem, ptr, slice, str};
13
14use super::functions::{build_sql_function_args, process_sql_function_result};
15use super::serialized_database::SerializedDatabase;
16use super::stmt::ensure_sqlite_ok;
17use super::{Sqlite, SqliteAggregateFunction};
18use crate::deserialize::FromSqlRow;
19use crate::result::Error::DatabaseError;
20use crate::result::*;
21use crate::serialize::ToSql;
22use crate::sql_types::HasSqlType;
23
/// Prints an internal-invariant-violation message to stderr and aborts the process.
///
/// The caller's message is followed by a request to report the bug and by the
/// source location of the macro invocation. `$fmt` must be a string literal
/// (it is fed to `concat!`) and every `$args` must be a single token tree.
/// Aborting rather than panicking matters in `extern "C"` callbacks, where
/// unwinding must not cross the FFI boundary.
macro_rules! assert_fail {
    ($fmt:expr $(,$args:tt)*) => {
        eprint!(concat!(
            $fmt,
            // Keep the caller's message and the boilerplate on separate lines.
            "\n",
            "If you see this message, please open an issue at https://github.com/diesel-rs/diesel/issues/new.\n",
            "Source location: {}:{}\n",
        ), $($args,)* file!(), line!());
        std::process::abort()
    };
}
36
37#[allow(missing_debug_implementations, missing_copy_implementations)]
38pub(super) struct RawConnection {
39 pub(super) internal_connection: NonNull<ffi::sqlite3>,
40}
41
42impl RawConnection {
43 pub(super) fn establish(database_url: &str) -> ConnectionResult<Self> {
44 let mut conn_pointer = ptr::null_mut();
45
46 let database_url = if database_url.starts_with("sqlite://") {
47 CString::new(database_url.replacen("sqlite://", "file:", 1))?
48 } else {
49 CString::new(database_url)?
50 };
51 let flags = ffi::SQLITE_OPEN_READWRITE | ffi::SQLITE_OPEN_CREATE | ffi::SQLITE_OPEN_URI;
52 let connection_status = unsafe {
53 ffi::sqlite3_open_v2(database_url.as_ptr(), &mut conn_pointer, flags, ptr::null())
54 };
55
56 match connection_status {
57 ffi::SQLITE_OK => {
58 let conn_pointer = unsafe { NonNull::new_unchecked(conn_pointer) };
59 Ok(RawConnection {
60 internal_connection: conn_pointer,
61 })
62 }
63 err_code => {
64 let message = super::error_message(err_code);
65 unsafe { ffi::sqlite3_close(conn_pointer) };
71 Err(ConnectionError::BadConnection(message.into()))
72 }
73 }
74 }
75
76 pub(super) fn exec(&self, query: &str) -> QueryResult<()> {
77 let query = CString::new(query)?;
78 let callback_fn = None;
79 let callback_arg = ptr::null_mut();
80 let result = unsafe {
81 ffi::sqlite3_exec(
82 self.internal_connection.as_ptr(),
83 query.as_ptr(),
84 callback_fn,
85 callback_arg,
86 ptr::null_mut(),
87 )
88 };
89
90 ensure_sqlite_ok(result, self.internal_connection.as_ptr())
91 }
92
93 pub(super) fn rows_affected_by_last_query(
94 &self,
95 ) -> Result<usize, Box<dyn std::error::Error + Send + Sync>> {
96 let r = unsafe { ffi::sqlite3_changes(self.internal_connection.as_ptr()) };
97
98 Ok(r.try_into()?)
99 }
100
101 pub(super) fn register_sql_function<F, Ret, RetSqlType>(
102 &self,
103 fn_name: &str,
104 num_args: usize,
105 deterministic: bool,
106 f: F,
107 ) -> QueryResult<()>
108 where
109 F: FnMut(&Self, &mut [*mut ffi::sqlite3_value]) -> QueryResult<Ret>
110 + std::panic::UnwindSafe
111 + Send
112 + 'static,
113 Ret: ToSql<RetSqlType, Sqlite>,
114 Sqlite: HasSqlType<RetSqlType>,
115 {
116 let c_fn_name = Self::get_fn_name(fn_name)?;
117 let flags = Self::get_flags(deterministic);
118 let num_args = num_args
119 .try_into()
120 .map_err(|e| Error::SerializationError(Box::new(e)))?;
121 let callback_fn = Box::into_raw(Box::new(CustomFunctionUserPtr {
124 callback: f,
125 function_name: fn_name.to_owned(),
126 }));
127
128 let result = unsafe {
129 ffi::sqlite3_create_function_v2(
130 self.internal_connection.as_ptr(),
131 c_fn_name.as_ptr(),
132 num_args,
133 flags,
134 callback_fn as *mut _,
135 Some(run_custom_function::<F, Ret, RetSqlType>),
136 None,
137 None,
138 Some(destroy_boxed::<CustomFunctionUserPtr<F>>),
139 )
140 };
141
142 Self::process_sql_function_result(result)
143 }
144
145 pub(super) fn register_aggregate_function<ArgsSqlType, RetSqlType, Args, Ret, A>(
146 &self,
147 fn_name: &str,
148 num_args: usize,
149 ) -> QueryResult<()>
150 where
151 A: SqliteAggregateFunction<Args, Output = Ret> + 'static + Send + std::panic::UnwindSafe,
152 Args: FromSqlRow<ArgsSqlType, Sqlite>,
153 Ret: ToSql<RetSqlType, Sqlite>,
154 Sqlite: HasSqlType<RetSqlType>,
155 {
156 let fn_name = Self::get_fn_name(fn_name)?;
157 let flags = Self::get_flags(false);
158 let num_args = num_args
159 .try_into()
160 .map_err(|e| Error::SerializationError(Box::new(e)))?;
161
162 let ctx_ptr = Box::into_raw(Box::new(None::<A>));
169
170 let result = unsafe {
171 ffi::sqlite3_create_function_v2(
172 self.internal_connection.as_ptr(),
173 fn_name.as_ptr(),
174 num_args,
175 flags,
176 ctx_ptr.cast(),
177 None,
178 Some(run_aggregator_step_function::<_, _, _, _, A>),
179 Some(run_aggregator_final_function::<_, _, _, _, A>),
180 Some(destroy_boxed::<Option<A>>),
181 )
182 };
183
184 Self::process_sql_function_result(result)
185 }
186
187 pub(super) fn register_collation_function<F>(
188 &self,
189 collation_name: &str,
190 collation: F,
191 ) -> QueryResult<()>
192 where
193 F: Fn(&str, &str) -> std::cmp::Ordering + std::panic::UnwindSafe + Send + 'static,
194 {
195 let c_collation_name = Self::get_fn_name(collation_name)?;
196 let callback_fn = Box::into_raw(Box::new(CollationUserPtr {
198 callback: collation,
199 collation_name: collation_name.to_owned(),
200 }));
201
202 let result = unsafe {
203 ffi::sqlite3_create_collation_v2(
204 self.internal_connection.as_ptr(),
205 c_collation_name.as_ptr(),
206 ffi::SQLITE_UTF8,
207 callback_fn as *mut _,
208 Some(run_collation_function::<F>),
209 Some(destroy_boxed::<CollationUserPtr<F>>),
210 )
211 };
212
213 let result = Self::process_sql_function_result(result);
214 if result.is_err() {
215 destroy_boxed::<CollationUserPtr<F>>(callback_fn as *mut _);
216 }
217 result
218 }
219
220 pub(super) fn serialize(&mut self) -> SerializedDatabase {
221 unsafe {
222 let mut size: ffi::sqlite3_int64 = 0;
223 let data_ptr = ffi::sqlite3_serialize(
224 self.internal_connection.as_ptr(),
225 std::ptr::null(),
226 &mut size as *mut _,
227 0,
228 );
229 SerializedDatabase::new(
230 data_ptr,
231 size.try_into()
232 .expect("Cannot fit the serialized database into memory"),
233 )
234 }
235 }
236
237 pub(super) fn deserialize(&mut self, data: &[u8]) -> QueryResult<()> {
238 let db_size = data
239 .len()
240 .try_into()
241 .map_err(|e| Error::DeserializationError(Box::new(e)))?;
242 #[allow(clippy::unnecessary_cast)]
244 unsafe {
245 let result = ffi::sqlite3_deserialize(
246 self.internal_connection.as_ptr(),
247 std::ptr::null(),
248 data.as_ptr() as *mut u8,
249 db_size,
250 db_size,
251 ffi::SQLITE_DESERIALIZE_READONLY as u32,
252 );
253
254 ensure_sqlite_ok(result, self.internal_connection.as_ptr())
255 }
256 }
257
258 fn get_fn_name(fn_name: &str) -> Result<CString, NulError> {
259 CString::new(fn_name)
260 }
261
262 fn get_flags(deterministic: bool) -> i32 {
263 let mut flags = ffi::SQLITE_UTF8;
264 if deterministic {
265 flags |= ffi::SQLITE_DETERMINISTIC;
266 }
267 flags
268 }
269
270 fn process_sql_function_result(result: i32) -> Result<(), Error> {
271 if result == ffi::SQLITE_OK {
272 Ok(())
273 } else {
274 let error_message = super::error_message(result);
275 Err(DatabaseError(
276 DatabaseErrorKind::Unknown,
277 Box::new(error_message.to_string()),
278 ))
279 }
280 }
281}
282
283impl Drop for RawConnection {
284 fn drop(&mut self) {
285 use std::thread::panicking;
286
287 let close_result = unsafe { ffi::sqlite3_close(self.internal_connection.as_ptr()) };
288 if close_result != ffi::SQLITE_OK {
289 let error_message = super::error_message(close_result);
290 if panicking() {
291 stderr().write_fmt(format_args!("Error closing SQLite connection: {0}",
error_message))write!(stderr(), "Error closing SQLite connection: {error_message}")
292 .expect("Error writing to `stderr`");
293 } else {
294 {
::core::panicking::panic_fmt(format_args!("Error closing SQLite connection: {0}",
error_message));
};panic!("Error closing SQLite connection: {error_message}");
295 }
296 }
297 }
298}
299
300enum SqliteCallbackError {
301 Abort(&'static str),
302 DieselError(crate::result::Error),
303 Panic(String),
304}
305
306impl SqliteCallbackError {
307 fn emit(&self, ctx: *mut ffi::sqlite3_context) {
308 let s;
309 let msg = match self {
310 SqliteCallbackError::Abort(msg) => *msg,
311 SqliteCallbackError::DieselError(e) => {
312 s = e.to_string();
313 &s
314 }
315 SqliteCallbackError::Panic(msg) => msg,
316 };
317 unsafe {
318 context_error_str(ctx, msg);
319 }
320 }
321}
322
323impl From<crate::result::Error> for SqliteCallbackError {
324 fn from(e: crate::result::Error) -> Self {
325 Self::DieselError(e)
326 }
327}
328
/// User data attached to a registered scalar SQL function.
struct CustomFunctionUserPtr<F> {
    /// The Rust implementation of the SQL function.
    callback: F,
    /// Name the function was registered under; reported if `callback` panics.
    function_name: String,
}
333
334#[allow(warnings)]
335extern "C" fn run_custom_function<F, Ret, RetSqlType>(
336 ctx: *mut ffi::sqlite3_context,
337 num_args: libc::c_int,
338 value_ptr: *mut *mut ffi::sqlite3_value,
339) where
340 F: FnMut(&RawConnection, &mut [*mut ffi::sqlite3_value]) -> QueryResult<Ret>
341 + std::panic::UnwindSafe
342 + Send
343 + 'static,
344 Ret: ToSql<RetSqlType, Sqlite>,
345 Sqlite: HasSqlType<RetSqlType>,
346{
347 use std::ops::Deref;
348 static NULL_DATA_ERR: &str = "An unknown error occurred. sqlite3_user_data returned a null pointer. This should never happen.";
349 static NULL_CONN_ERR: &str = "An unknown error occurred. sqlite3_context_db_handle returned a null pointer. This should never happen.";
350
351 let conn = match unsafe { NonNull::new(ffi::sqlite3_context_db_handle(ctx)) } {
352 Some(conn) => mem::ManuallyDrop::new(RawConnection {
355 internal_connection: conn,
356 }),
357 None => {
358 unsafe { context_error_str(ctx, NULL_CONN_ERR) };
359 return;
360 }
361 };
362
363 let data_ptr = unsafe { ffi::sqlite3_user_data(ctx) };
364
365 let mut data_ptr = match NonNull::new(data_ptr as *mut CustomFunctionUserPtr<F>) {
366 None => unsafe {
367 context_error_str(ctx, NULL_DATA_ERR);
368 return;
369 },
370 Some(mut f) => f,
371 };
372 let data_ptr = unsafe { data_ptr.as_mut() };
373
374 let callback = std::panic::AssertUnwindSafe(&mut data_ptr.callback);
377
378 let result = std::panic::catch_unwind(move || {
379 let _ = &callback;
380 let args = unsafe { slice::from_raw_parts_mut(value_ptr, num_args as _) };
381 let res = (callback.0)(&*conn, args)?;
382 let value = process_sql_function_result(&res)?;
383 unsafe {
385 value.result_of(&mut *ctx);
386 }
387 Ok(())
388 })
389 .unwrap_or_else(|p| Err(SqliteCallbackError::Panic(data_ptr.function_name.clone())));
390 if let Err(e) = result {
391 e.emit(ctx);
392 }
393}
394
395#[allow(warnings)]
396extern "C" fn run_aggregator_step_function<ArgsSqlType, RetSqlType, Args, Ret, A>(
397 ctx: *mut ffi::sqlite3_context,
398 num_args: libc::c_int,
399 value_ptr: *mut *mut ffi::sqlite3_value,
400) where
401 A: SqliteAggregateFunction<Args, Output = Ret> + 'static + Send + std::panic::UnwindSafe,
402 Args: FromSqlRow<ArgsSqlType, Sqlite>,
403 Ret: ToSql<RetSqlType, Sqlite>,
404 Sqlite: HasSqlType<RetSqlType>,
405{
406 let result = std::panic::catch_unwind(move || {
407 let args = unsafe { slice::from_raw_parts_mut(value_ptr, num_args as _) };
408 run_aggregator_step::<A, Args, ArgsSqlType>(ctx, args)
409 })
410 .unwrap_or_else(|e| {
411 Err(SqliteCallbackError::Panic(::alloc::__export::must_use({
::alloc::fmt::format(format_args!("{0}::step() panicked",
std::any::type_name::<A>()))
})format!(
412 "{}::step() panicked",
413 std::any::type_name::<A>()
414 )))
415 });
416
417 match result {
418 Ok(()) => {}
419 Err(e) => e.emit(ctx),
420 }
421}
422
423fn run_aggregator_step<A, Args, ArgsSqlType>(
424 ctx: *mut ffi::sqlite3_context,
425 args: &mut [*mut ffi::sqlite3_value],
426) -> Result<(), SqliteCallbackError>
427where
428 A: SqliteAggregateFunction<Args>,
429 Args: FromSqlRow<ArgsSqlType, Sqlite>,
430{
431 static NULL_CTX_ERR: &str =
432 "We've written the aggregator to the user data, but it could not be retrieved.";
433
434 let aggregator = unsafe {
435 let aggregate_context = ffi::sqlite3_user_data(ctx) as *mut Option<A>;
439 match aggregate_context.as_mut() {
443 Some(Some(agg)) => agg,
444 Some(r) => {
445 *r = Some(A::default());
446 r.as_mut().expect("Initialised literally above")
447 }
448 None => return Err(SqliteCallbackError::Abort(NULL_CTX_ERR)),
449 }
450 };
451
452 let args = build_sql_function_args::<ArgsSqlType, Args>(args)?;
453
454 aggregator.step(args);
455 Ok(())
456}
457
458extern "C" fn run_aggregator_final_function<ArgsSqlType, RetSqlType, Args, Ret, A>(
459 ctx: *mut ffi::sqlite3_context,
460) where
461 A: SqliteAggregateFunction<Args, Output = Ret> + 'static + Send,
462 Args: FromSqlRow<ArgsSqlType, Sqlite>,
463 Ret: ToSql<RetSqlType, Sqlite>,
464 Sqlite: HasSqlType<RetSqlType>,
465{
466 let result = std::panic::catch_unwind(|| {
467 let aggregator = unsafe {
468 let aggregate_context = ffi::sqlite3_user_data(ctx) as *mut Option<A>;
472 aggregate_context.as_mut()
476 }
477 .and_then(|a| a.take());
478
479 let res = A::finalize(aggregator);
480 let value = process_sql_function_result(&res)?;
481 let r = unsafe { value.result_of(&mut *ctx) };
483 r.map_err(|e| {
484 SqliteCallbackError::DieselError(crate::result::Error::SerializationError(Box::new(e)))
485 })?;
486 Ok(())
487 })
488 .unwrap_or_else(|_e| {
489 Err(SqliteCallbackError::Panic(::alloc::__export::must_use({
::alloc::fmt::format(format_args!("{0}::finalize() panicked",
std::any::type_name::<A>()))
})format!(
490 "{}::finalize() panicked",
491 std::any::type_name::<A>()
492 )))
493 });
494 if let Err(e) = result {
495 e.emit(ctx);
496 }
497}
498
499unsafe fn context_error_str(ctx: *mut ffi::sqlite3_context, error: &str) {
500 let len: i32 = error.len().try_into().unwrap_or(i32::MAX);
501 unsafe {
502 ffi::sqlite3_result_error(ctx, error.as_ptr() as *const _, len);
503 }
504}
505
/// User data attached to a registered collation.
struct CollationUserPtr<F> {
    /// The Rust comparison function implementing the collation.
    callback: F,
    /// Name the collation was registered under; used in failure messages.
    collation_name: String,
}
510
511#[allow(warnings)]
512extern "C" fn run_collation_function<F>(
513 user_ptr: *mut libc::c_void,
514 lhs_len: libc::c_int,
515 lhs_ptr: *const libc::c_void,
516 rhs_len: libc::c_int,
517 rhs_ptr: *const libc::c_void,
518) -> libc::c_int
519where
520 F: Fn(&str, &str) -> std::cmp::Ordering + Send + std::panic::UnwindSafe + 'static,
521{
522 let user_ptr = user_ptr as *const CollationUserPtr<F>;
523 let user_ptr = std::panic::AssertUnwindSafe(unsafe { user_ptr.as_ref() });
524
525 let result = std::panic::catch_unwind(|| {
526 let user_ptr = user_ptr.ok_or_else(|| {
527 SqliteCallbackError::Abort(
528 "Got a null pointer as data pointer. This should never happen",
529 )
530 })?;
531 for (ptr, len, side) in &[(rhs_ptr, rhs_len, "rhs"), (lhs_ptr, lhs_len, "lhs")] {
532 if *len < 0 {
533 {
::std::io::_eprint(format_args!("An unknown error occurred. {0}_len is negative. This should never happen.If you see this message, please open an issue at https://github.com/diesel-rs/diesel/issues/new.\nSource location: {1}:{2}\n",
side, "diesel/src/sqlite/connection/raw.rs", 533u32));
};
std::process::abort();assert_fail!(
534 "An unknown error occurred. {}_len is negative. This should never happen.",
535 side
536 );
537 }
538 if ptr.is_null() {
539 {
::std::io::_eprint(format_args!("An unknown error occurred. {0}_ptr is a null pointer. This should never happen.If you see this message, please open an issue at https://github.com/diesel-rs/diesel/issues/new.\nSource location: {1}:{2}\n",
side, "diesel/src/sqlite/connection/raw.rs", 539u32));
};
std::process::abort();assert_fail!(
540 "An unknown error occurred. {}_ptr is a null pointer. This should never happen.",
541 side
542 );
543 }
544 }
545
546 let (rhs, lhs) = unsafe {
547 (
551 str::from_utf8(slice::from_raw_parts(rhs_ptr as *const u8, rhs_len as _)),
552 str::from_utf8(slice::from_raw_parts(lhs_ptr as *const u8, lhs_len as _)),
553 )
554 };
555
556 let rhs =
557 rhs.map_err(|_| SqliteCallbackError::Abort("Got an invalid UTF-8 string for rhs"))?;
558 let lhs =
559 lhs.map_err(|_| SqliteCallbackError::Abort("Got an invalid UTF-8 string for lhs"))?;
560
561 Ok((user_ptr.callback)(rhs, lhs))
562 })
563 .unwrap_or_else(|p| {
564 Err(SqliteCallbackError::Panic(
565 user_ptr
566 .map(|u| u.collation_name.clone())
567 .unwrap_or_default(),
568 ))
569 });
570
571 match result {
572 Ok(std::cmp::Ordering::Less) => -1,
573 Ok(std::cmp::Ordering::Equal) => 0,
574 Ok(std::cmp::Ordering::Greater) => 1,
575 Err(SqliteCallbackError::Abort(a)) => {
576 {
::std::io::_eprint(format_args!("Collation function {0} failed with: {1}\n",
user_ptr.map(|c| &c.collation_name as &str).unwrap_or_default(),
a));
};eprintln!(
577 "Collation function {} failed with: {}",
578 user_ptr
579 .map(|c| &c.collation_name as &str)
580 .unwrap_or_default(),
581 a
582 );
583 std::process::abort()
584 }
585 Err(SqliteCallbackError::DieselError(e)) => {
586 {
::std::io::_eprint(format_args!("Collation function {0} failed with: {1}\n",
user_ptr.map(|c| &c.collation_name as &str).unwrap_or_default(),
e));
};eprintln!(
587 "Collation function {} failed with: {}",
588 user_ptr
589 .map(|c| &c.collation_name as &str)
590 .unwrap_or_default(),
591 e
592 );
593 std::process::abort()
594 }
595 Err(SqliteCallbackError::Panic(msg)) => {
596 {
::std::io::_eprint(format_args!("Collation function {0} panicked\n",
msg));
};eprintln!("Collation function {} panicked", msg);
597 std::process::abort()
598 }
599 }
600}
601
/// Destructor handed to SQLite for user data created with
/// `Box::into_raw(Box::new(value))`, where `value: F`.
///
/// Rebuilds the box and drops it, running `F`'s destructor and freeing the
/// allocation.
extern "C" fn destroy_boxed<F>(data: *mut libc::c_void) {
    let boxed: Box<F> = unsafe { Box::from_raw(data.cast::<F>()) };
    drop(boxed);
}