1#![allow(unsafe_code)] #[cfg(not(all(target_family = "wasm", target_os = "unknown")))]
3extern crate libsqlite3_sys as ffi;
4
5#[cfg(all(target_family = "wasm", target_os = "unknown"))]
6use sqlite_wasm_rs as ffi;
7
8use super::functions::{build_sql_function_args, process_sql_function_result};
9use super::serialized_database::SerializedDatabase;
10use super::stmt::ensure_sqlite_ok;
11use super::{Sqlite, SqliteAggregateFunction};
12use crate::deserialize::FromSqlRow;
13use crate::result::Error::DatabaseError;
14use crate::result::*;
15use crate::serialize::ToSql;
16use crate::sql_types::HasSqlType;
17use alloc::borrow::ToOwned;
18use alloc::boxed::Box;
19use alloc::ffi::{CString, NulError};
20use alloc::string::{String, ToString};
21use core::ffi as libc;
22use core::ptr::NonNull;
23use core::{mem, ptr, slice, str};
24
// Reports a broken internal invariant and aborts the process.
//
// `$fmt` must be a string literal; it is formatted with `$args`, followed by a request to
// open a diesel issue and the source location of the macro invocation. The message is only
// printed when the `std` feature is enabled; the abort happens unconditionally through
// `crate::util::std_compat::abort()`. Use it only for conditions that indicate a bug, never
// for errors caused by user input.
macro_rules! assert_fail {
    ($fmt:expr_2021 $(,$args:tt)*) => {
        #[cfg(feature = "std")]
        eprint!(concat!(
            $fmt,
            // Keep the caller's message and the boilerplate below on separate lines; without
            // this separator they ran together ("...happen.If you see this message...").
            "\n",
            "If you see this message, please open an issue at https://github.com/diesel-rs/diesel/issues/new.\n",
            "Source location: {}:{}\n",
        ), $($args,)* file!(), line!());
        crate::util::std_compat::abort()
    };
}
38
39#[allow(missing_debug_implementations, missing_copy_implementations)]
40pub(super) struct RawConnection {
41 pub(super) internal_connection: NonNull<ffi::sqlite3>,
42}
43
44impl RawConnection {
45 pub(super) fn establish(database_url: &str) -> ConnectionResult<Self> {
46 let mut conn_pointer = ptr::null_mut();
47
48 let database_url = if database_url.starts_with("sqlite://") {
49 CString::new(database_url.replacen("sqlite://", "file:", 1))?
50 } else {
51 CString::new(database_url)?
52 };
53 let flags = ffi::SQLITE_OPEN_READWRITE | ffi::SQLITE_OPEN_CREATE | ffi::SQLITE_OPEN_URI;
54 let connection_status = unsafe {
55 ffi::sqlite3_open_v2(database_url.as_ptr(), &mut conn_pointer, flags, ptr::null())
56 };
57
58 match connection_status {
59 ffi::SQLITE_OK => {
60 let conn_pointer = unsafe { NonNull::new_unchecked(conn_pointer) };
61 Ok(RawConnection {
62 internal_connection: conn_pointer,
63 })
64 }
65 err_code => {
66 let message = super::error_message(err_code);
67 Err(ConnectionError::BadConnection(message.into()))
68 }
69 }
70 }
71
72 pub(super) fn exec(&self, query: &str) -> QueryResult<()> {
73 let query = CString::new(query)?;
74 let callback_fn = None;
75 let callback_arg = ptr::null_mut();
76 let result = unsafe {
77 ffi::sqlite3_exec(
78 self.internal_connection.as_ptr(),
79 query.as_ptr(),
80 callback_fn,
81 callback_arg,
82 ptr::null_mut(),
83 )
84 };
85
86 ensure_sqlite_ok(result, self.internal_connection.as_ptr())
87 }
88
89 pub(super) fn rows_affected_by_last_query(
90 &self,
91 ) -> Result<usize, Box<dyn core::error::Error + Send + Sync>> {
92 let r = unsafe { ffi::sqlite3_changes(self.internal_connection.as_ptr()) };
93
94 Ok(r.try_into()?)
95 }
96
97 pub(super) fn last_insert_rowid(&self) -> i64 {
98 unsafe { ffi::sqlite3_last_insert_rowid(self.internal_connection.as_ptr()) }
99 }
100
101 pub(super) fn register_sql_function<F, Ret, RetSqlType>(
102 &self,
103 fn_name: &str,
104 num_args: usize,
105 deterministic: bool,
106 f: F,
107 ) -> QueryResult<()>
108 where
109 F: FnMut(&Self, &mut [*mut ffi::sqlite3_value]) -> QueryResult<Ret>
110 + core::panic::UnwindSafe
111 + Send
112 + 'static,
113 Ret: ToSql<RetSqlType, Sqlite>,
114 Sqlite: HasSqlType<RetSqlType>,
115 {
116 let callback_fn = Box::into_raw(Box::new(CustomFunctionUserPtr {
117 callback: f,
118 function_name: fn_name.to_owned(),
119 }));
120 let fn_name = Self::get_fn_name(fn_name)?;
121 let flags = Self::get_flags(deterministic);
122 let num_args = num_args
123 .try_into()
124 .map_err(|e| Error::SerializationError(Box::new(e)))?;
125
126 let result = unsafe {
127 ffi::sqlite3_create_function_v2(
128 self.internal_connection.as_ptr(),
129 fn_name.as_ptr(),
130 num_args,
131 flags,
132 callback_fn as *mut _,
133 Some(run_custom_function::<F, Ret, RetSqlType>),
134 None,
135 None,
136 Some(destroy_boxed::<CustomFunctionUserPtr<F>>),
137 )
138 };
139
140 Self::process_sql_function_result(result)
141 }
142
143 pub(super) fn register_aggregate_function<ArgsSqlType, RetSqlType, Args, Ret, A>(
144 &self,
145 fn_name: &str,
146 num_args: usize,
147 ) -> QueryResult<()>
148 where
149 A: SqliteAggregateFunction<Args, Output = Ret> + 'static + Send + core::panic::UnwindSafe,
150 Args: FromSqlRow<ArgsSqlType, Sqlite>,
151 Ret: ToSql<RetSqlType, Sqlite>,
152 Sqlite: HasSqlType<RetSqlType>,
153 {
154 let fn_name = Self::get_fn_name(fn_name)?;
155 let flags = Self::get_flags(false);
156 let num_args = num_args
157 .try_into()
158 .map_err(|e| Error::SerializationError(Box::new(e)))?;
159
160 let result = unsafe {
161 ffi::sqlite3_create_function_v2(
162 self.internal_connection.as_ptr(),
163 fn_name.as_ptr(),
164 num_args,
165 flags,
166 ptr::null_mut(),
167 None,
168 Some(run_aggregator_step_function::<_, _, _, _, A>),
169 Some(run_aggregator_final_function::<_, _, _, _, A>),
170 None,
171 )
172 };
173
174 Self::process_sql_function_result(result)
175 }
176
177 pub(super) fn register_collation_function<F>(
178 &self,
179 collation_name: &str,
180 collation: F,
181 ) -> QueryResult<()>
182 where
183 F: Fn(&str, &str) -> core::cmp::Ordering + core::panic::UnwindSafe + Send + 'static,
184 {
185 let callback_fn = Box::into_raw(Box::new(CollationUserPtr {
186 callback: collation,
187 collation_name: collation_name.to_owned(),
188 }));
189 let collation_name = Self::get_fn_name(collation_name)?;
190
191 let result = unsafe {
192 ffi::sqlite3_create_collation_v2(
193 self.internal_connection.as_ptr(),
194 collation_name.as_ptr(),
195 ffi::SQLITE_UTF8,
196 callback_fn as *mut _,
197 Some(run_collation_function::<F>),
198 Some(destroy_boxed::<CollationUserPtr<F>>),
199 )
200 };
201
202 let result = Self::process_sql_function_result(result);
203 if result.is_err() {
204 destroy_boxed::<CollationUserPtr<F>>(callback_fn as *mut _);
205 }
206 result
207 }
208
209 pub(super) fn serialize(&mut self) -> SerializedDatabase {
210 unsafe {
211 let mut size: ffi::sqlite3_int64 = 0;
212 let data_ptr = ffi::sqlite3_serialize(
213 self.internal_connection.as_ptr(),
214 core::ptr::null(),
215 &mut size as *mut _,
216 0,
217 );
218 SerializedDatabase::new(
219 data_ptr,
220 size.try_into()
221 .expect("Cannot fit the serialized database into memory"),
222 )
223 }
224 }
225
226 pub(super) fn deserialize(&mut self, data: &[u8]) -> QueryResult<()> {
227 let db_size = data
228 .len()
229 .try_into()
230 .map_err(|e| Error::DeserializationError(Box::new(e)))?;
231 #[allow(clippy::unnecessary_cast)]
233 unsafe {
234 let result = ffi::sqlite3_deserialize(
235 self.internal_connection.as_ptr(),
236 core::ptr::null(),
237 data.as_ptr() as *mut u8,
238 db_size,
239 db_size,
240 ffi::SQLITE_DESERIALIZE_READONLY as u32,
241 );
242
243 ensure_sqlite_ok(result, self.internal_connection.as_ptr())
244 }
245 }
246
247 fn get_fn_name(fn_name: &str) -> Result<CString, NulError> {
248 CString::new(fn_name)
249 }
250
251 fn get_flags(deterministic: bool) -> i32 {
252 let mut flags = ffi::SQLITE_UTF8;
253 if deterministic {
254 flags |= ffi::SQLITE_DETERMINISTIC;
255 }
256 flags
257 }
258
259 fn process_sql_function_result(result: i32) -> Result<(), Error> {
260 if result == ffi::SQLITE_OK {
261 Ok(())
262 } else {
263 let error_message = super::error_message(result);
264 Err(DatabaseError(
265 DatabaseErrorKind::Unknown,
266 Box::new(error_message.to_string()),
267 ))
268 }
269 }
270}
271
272impl Drop for RawConnection {
273 fn drop(&mut self) {
274 use crate::util::std_compat::panicking;
275
276 let close_result = unsafe { ffi::sqlite3_close(self.internal_connection.as_ptr()) };
277 if close_result != ffi::SQLITE_OK {
278 let error_message = super::error_message(close_result);
279 if panicking() {
280 #[cfg(feature = "std")]
281 {
::std::io::_eprint(format_args!("Error closing SQLite connection: {0}\n",
error_message));
};eprintln!("Error closing SQLite connection: {error_message}");
282 } else {
283 {
::core::panicking::panic_fmt(format_args!("Error closing SQLite connection: {0}",
error_message));
};panic!("Error closing SQLite connection: {error_message}");
284 }
285 }
286 }
287}
288
289enum SqliteCallbackError {
290 Abort(&'static str),
291 DieselError(crate::result::Error),
292 Panic(String),
293}
294
295impl SqliteCallbackError {
296 fn emit(&self, ctx: *mut ffi::sqlite3_context) {
297 let s;
298 let msg = match self {
299 SqliteCallbackError::Abort(msg) => *msg,
300 SqliteCallbackError::DieselError(e) => {
301 s = e.to_string();
302 &s
303 }
304 SqliteCallbackError::Panic(msg) => msg,
305 };
306 unsafe {
307 context_error_str(ctx, msg);
308 }
309 }
310}
311
312impl From<crate::result::Error> for SqliteCallbackError {
313 fn from(e: crate::result::Error) -> Self {
314 Self::DieselError(e)
315 }
316}
317
/// User data registered with SQLite for a scalar SQL function.
struct CustomFunctionUserPtr<F> {
    /// The Rust closure implementing the function.
    callback: F,
    /// The SQL-visible function name, used when reporting a panic.
    function_name: String,
}
322
323#[allow(warnings)]
324extern "C" fn run_custom_function<F, Ret, RetSqlType>(
325 ctx: *mut ffi::sqlite3_context,
326 num_args: libc::c_int,
327 value_ptr: *mut *mut ffi::sqlite3_value,
328) where
329 F: FnMut(&RawConnection, &mut [*mut ffi::sqlite3_value]) -> QueryResult<Ret>
330 + core::panic::UnwindSafe
331 + Send
332 + 'static,
333 Ret: ToSql<RetSqlType, Sqlite>,
334 Sqlite: HasSqlType<RetSqlType>,
335{
336 use core::ops::Deref;
337 static NULL_DATA_ERR: &str = "An unknown error occurred. sqlite3_user_data returned a null pointer. This should never happen.";
338 static NULL_CONN_ERR: &str = "An unknown error occurred. sqlite3_context_db_handle returned a null pointer. This should never happen.";
339
340 let conn = match unsafe { NonNull::new(ffi::sqlite3_context_db_handle(ctx)) } {
341 Some(conn) => mem::ManuallyDrop::new(RawConnection {
344 internal_connection: conn,
345 }),
346 None => {
347 unsafe { context_error_str(ctx, NULL_CONN_ERR) };
348 return;
349 }
350 };
351
352 let data_ptr = unsafe { ffi::sqlite3_user_data(ctx) };
353
354 let mut data_ptr = match NonNull::new(data_ptr as *mut CustomFunctionUserPtr<F>) {
355 None => unsafe {
356 context_error_str(ctx, NULL_DATA_ERR);
357 return;
358 },
359 Some(mut f) => f,
360 };
361 let data_ptr = unsafe { data_ptr.as_mut() };
362
363 let callback = core::panic::AssertUnwindSafe(&mut data_ptr.callback);
366
367 let result = crate::util::std_compat::catch_unwind(move || {
368 let _ = &callback;
369 let args = unsafe { slice::from_raw_parts_mut(value_ptr, num_args as _) };
370 let res = (callback.0)(&*conn, args)?;
371 let value = process_sql_function_result(&res)?;
372 unsafe {
374 value.result_of(&mut *ctx);
375 }
376 Ok(())
377 })
378 .unwrap_or_else(|p| Err(SqliteCallbackError::Panic(data_ptr.function_name.clone())));
379 if let Err(e) = result {
380 e.emit(ctx);
381 }
382}
383
/// Layout of an aggregate's state inside SQLite's aggregate context.
///
/// `#[repr(u8)]` gives the enum a leading `u8` tag with `None` = 0. The zero-filled buffer
/// SQLite returns on the first `sqlite3_aggregate_context` call therefore already reads as
/// `None`. `None` must stay the first variant.
#[repr(u8)]
enum OptionalAggregator<A> {
    /// No state has been created yet, or it was already moved out by the final callback.
    None,
    /// The live aggregate state.
    Some(A),
}
392
393#[allow(warnings)]
394extern "C" fn run_aggregator_step_function<ArgsSqlType, RetSqlType, Args, Ret, A>(
395 ctx: *mut ffi::sqlite3_context,
396 num_args: libc::c_int,
397 value_ptr: *mut *mut ffi::sqlite3_value,
398) where
399 A: SqliteAggregateFunction<Args, Output = Ret> + 'static + Send + core::panic::UnwindSafe,
400 Args: FromSqlRow<ArgsSqlType, Sqlite>,
401 Ret: ToSql<RetSqlType, Sqlite>,
402 Sqlite: HasSqlType<RetSqlType>,
403{
404 let result = crate::util::std_compat::catch_unwind(move || {
405 let args = unsafe { slice::from_raw_parts_mut(value_ptr, num_args as _) };
406 run_aggregator_step::<A, Args, ArgsSqlType>(ctx, args)
407 })
408 .unwrap_or_else(|e| {
409 Err(SqliteCallbackError::Panic(::alloc::__export::must_use({
::alloc::fmt::format(format_args!("{0}::step() panicked",
core::any::type_name::<A>()))
})alloc::format!(
410 "{}::step() panicked",
411 core::any::type_name::<A>()
412 )))
413 });
414
415 match result {
416 Ok(()) => {}
417 Err(e) => e.emit(ctx),
418 }
419}
420
421fn run_aggregator_step<A, Args, ArgsSqlType>(
422 ctx: *mut ffi::sqlite3_context,
423 args: &mut [*mut ffi::sqlite3_value],
424) -> Result<(), SqliteCallbackError>
425where
426 A: SqliteAggregateFunction<Args>,
427 Args: FromSqlRow<ArgsSqlType, Sqlite>,
428{
429 static NULL_AG_CTX_ERR: &str = "An unknown error occurred. sqlite3_aggregate_context returned a null pointer. This should never happen.";
430 static NULL_CTX_ERR: &str =
431 "We've written the aggregator to the aggregate context, but it could not be retrieved.";
432
433 let n_bytes: i32 = core::mem::size_of::<OptionalAggregator<A>>()
434 .try_into()
435 .expect("Aggregate context should be larger than 2^32");
436 let aggregate_context = unsafe {
437 ffi::sqlite3_aggregate_context(ctx, n_bytes)
459 };
460 let aggregate_context = NonNull::new(aggregate_context as *mut OptionalAggregator<A>);
461 let aggregator = unsafe {
462 match aggregate_context.map(|a| &mut *a.as_ptr()) {
463 Some(&mut OptionalAggregator::Some(ref mut agg)) => agg,
464 Some(a_ptr @ &mut OptionalAggregator::None) => {
465 ptr::write_unaligned(a_ptr as *mut _, OptionalAggregator::Some(A::default()));
466 if let OptionalAggregator::Some(agg) = a_ptr {
467 agg
468 } else {
469 return Err(SqliteCallbackError::Abort(NULL_CTX_ERR));
470 }
471 }
472 None => {
473 return Err(SqliteCallbackError::Abort(NULL_AG_CTX_ERR));
474 }
475 }
476 };
477 let args = build_sql_function_args::<ArgsSqlType, Args>(args)?;
478
479 aggregator.step(args);
480 Ok(())
481}
482
483extern "C" fn run_aggregator_final_function<ArgsSqlType, RetSqlType, Args, Ret, A>(
484 ctx: *mut ffi::sqlite3_context,
485) where
486 A: SqliteAggregateFunction<Args, Output = Ret> + 'static + Send,
487 Args: FromSqlRow<ArgsSqlType, Sqlite>,
488 Ret: ToSql<RetSqlType, Sqlite>,
489 Sqlite: HasSqlType<RetSqlType>,
490{
491 static NO_AGGREGATOR_FOUND: &str = "We've written to the aggregator in the xStep callback. If xStep was never called, then ffi::sqlite_aggregate_context() would have returned a NULL pointer.";
492 let aggregate_context = unsafe {
493 ffi::sqlite3_aggregate_context(ctx, 0)
500 };
501
502 let result = crate::util::std_compat::catch_unwind(|| {
503 let mut aggregate_context = NonNull::new(aggregate_context as *mut OptionalAggregator<A>);
504
505 let aggregator = if let Some(a) = aggregate_context.as_mut() {
506 let a = unsafe { a.as_mut() };
507 match core::mem::replace(a, OptionalAggregator::None) {
508 OptionalAggregator::None => {
509 return Err(SqliteCallbackError::Abort(NO_AGGREGATOR_FOUND));
510 }
511 OptionalAggregator::Some(a) => Some(a),
512 }
513 } else {
514 None
515 };
516
517 let res = A::finalize(aggregator);
518 let value = process_sql_function_result(&res)?;
519 let r = unsafe { value.result_of(&mut *ctx) };
521 r.map_err(|e| {
522 SqliteCallbackError::DieselError(crate::result::Error::SerializationError(Box::new(e)))
523 })?;
524 Ok(())
525 })
526 .unwrap_or_else(|_e| {
527 Err(SqliteCallbackError::Panic(::alloc::__export::must_use({
::alloc::fmt::format(format_args!("{0}::finalize() panicked",
core::any::type_name::<A>()))
})alloc::format!(
528 "{}::finalize() panicked",
529 core::any::type_name::<A>()
530 )))
531 });
532 if let Err(e) = result {
533 e.emit(ctx);
534 }
535}
536
537unsafe fn context_error_str(ctx: *mut ffi::sqlite3_context, error: &str) {
538 let len: i32 = error
539 .len()
540 .try_into()
541 .expect("Trying to set a error message with more than 2^32 byte is not supported");
542 unsafe {
543 ffi::sqlite3_result_error(ctx, error.as_ptr() as *const _, len);
544 }
545}
546
/// User data registered with SQLite for a custom collation.
struct CollationUserPtr<F> {
    /// The Rust comparison closure implementing the collation.
    callback: F,
    /// The SQL-visible collation name, used in diagnostics.
    collation_name: String,
}
551
552#[allow(warnings)]
553extern "C" fn run_collation_function<F>(
554 user_ptr: *mut libc::c_void,
555 lhs_len: libc::c_int,
556 lhs_ptr: *const libc::c_void,
557 rhs_len: libc::c_int,
558 rhs_ptr: *const libc::c_void,
559) -> libc::c_int
560where
561 F: Fn(&str, &str) -> core::cmp::Ordering + Send + core::panic::UnwindSafe + 'static,
562{
563 let user_ptr = user_ptr as *const CollationUserPtr<F>;
564 let user_ptr = core::panic::AssertUnwindSafe(unsafe { user_ptr.as_ref() });
565
566 let result = crate::util::std_compat::catch_unwind(|| {
567 let user_ptr = user_ptr.ok_or_else(|| {
568 SqliteCallbackError::Abort(
569 "Got a null pointer as data pointer. This should never happen",
570 )
571 })?;
572 for (ptr, len, side) in &[(rhs_ptr, rhs_len, "rhs"), (lhs_ptr, lhs_len, "lhs")] {
573 if *len < 0 {
574 {
::std::io::_eprint(format_args!("An unknown error occurred. {0}_len is negative. This should never happen.If you see this message, please open an issue at https://github.com/diesel-rs/diesel/issues/new.\nSource location: {1}:{2}\n",
side, "diesel/src/sqlite/connection/raw.rs", 574u32));
};
crate::util::std_compat::abort();assert_fail!(
575 "An unknown error occurred. {}_len is negative. This should never happen.",
576 side
577 );
578 }
579 if ptr.is_null() {
580 {
::std::io::_eprint(format_args!("An unknown error occurred. {0}_ptr is a null pointer. This should never happen.If you see this message, please open an issue at https://github.com/diesel-rs/diesel/issues/new.\nSource location: {1}:{2}\n",
side, "diesel/src/sqlite/connection/raw.rs", 580u32));
};
crate::util::std_compat::abort();assert_fail!(
581 "An unknown error occurred. {}_ptr is a null pointer. This should never happen.",
582 side
583 );
584 }
585 }
586
587 let (rhs, lhs) = unsafe {
588 (
592 str::from_utf8(slice::from_raw_parts(rhs_ptr as *const u8, rhs_len as _)),
593 str::from_utf8(slice::from_raw_parts(lhs_ptr as *const u8, lhs_len as _)),
594 )
595 };
596
597 let rhs =
598 rhs.map_err(|_| SqliteCallbackError::Abort("Got an invalid UTF-8 string for rhs"))?;
599 let lhs =
600 lhs.map_err(|_| SqliteCallbackError::Abort("Got an invalid UTF-8 string for lhs"))?;
601
602 Ok((user_ptr.callback)(rhs, lhs))
603 })
604 .unwrap_or_else(|p| {
605 Err(SqliteCallbackError::Panic(
606 user_ptr
607 .map(|u| u.collation_name.clone())
608 .unwrap_or_default(),
609 ))
610 });
611
612 match result {
613 Ok(core::cmp::Ordering::Less) => -1,
614 Ok(core::cmp::Ordering::Equal) => 0,
615 Ok(core::cmp::Ordering::Greater) => 1,
616 Err(SqliteCallbackError::Abort(a)) => {
617 #[cfg(feature = "std")]
618 {
::std::io::_eprint(format_args!("Collation function {0} failed with: {1}\n",
user_ptr.map(|c| &c.collation_name as &str).unwrap_or_default(),
a));
};eprintln!(
619 "Collation function {} failed with: {}",
620 user_ptr
621 .map(|c| &c.collation_name as &str)
622 .unwrap_or_default(),
623 a
624 );
625 crate::util::std_compat::abort()
626 }
627 Err(SqliteCallbackError::DieselError(e)) => {
628 #[cfg(feature = "std")]
629 {
::std::io::_eprint(format_args!("Collation function {0} failed with: {1}\n",
user_ptr.map(|c| &c.collation_name as &str).unwrap_or_default(),
e));
};eprintln!(
630 "Collation function {} failed with: {}",
631 user_ptr
632 .map(|c| &c.collation_name as &str)
633 .unwrap_or_default(),
634 e
635 );
636 crate::util::std_compat::abort()
637 }
638 Err(SqliteCallbackError::Panic(msg)) => {
639 #[cfg(feature = "std")]
640 {
::std::io::_eprint(format_args!("Collation function {0} panicked\n",
msg));
};eprintln!("Collation function {} panicked", msg);
641 crate::util::std_compat::abort()
642 }
643 }
644}
645
/// Destructor callback handed to SQLite for user data created with `Box::into_raw`.
///
/// Rebuilds the `Box<F>` from `data` and drops it, running `F`'s destructor.
extern "C" fn destroy_boxed<F>(data: *mut libc::c_void) {
    // SAFETY: `data` must come from `Box::into_raw` of a `Box<F>` and must not be used
    // afterwards; the registration code only passes such pointers.
    let owned: Box<F> = unsafe { Box::from_raw(data.cast::<F>()) };
    drop(owned);
}