1#![allow(unsafe_code)] extern crate libsqlite3_sys as ffi;
3
4use std::ffi::{CString, NulError};
5use std::io::{stderr, Write};
6use std::os::raw as libc;
7use std::ptr::NonNull;
8use std::{mem, ptr, slice, str};
9
10use super::functions::{build_sql_function_args, process_sql_function_result};
11use super::serialized_database::SerializedDatabase;
12use super::stmt::ensure_sqlite_ok;
13use super::{Sqlite, SqliteAggregateFunction};
14use crate::deserialize::FromSqlRow;
15use crate::result::Error::DatabaseError;
16use crate::result::*;
17use crate::serialize::ToSql;
18use crate::sql_types::HasSqlType;
19
/// Prints a formatted diagnostic to stderr and aborts the process.
///
/// The caller's message (`$fmt` plus its arguments) is followed by a request
/// to file a bug report and by the source location of the macro invocation.
/// The process is terminated with `std::process::abort`, so no unwinding
/// takes place and no destructors run.
macro_rules! assert_fail {
    ($fmt:expr $(,$args:tt)*) => {
        eprint!(concat!(
            $fmt,
            // Keep the caller's sentence and the boilerplate below on separate
            // lines; without this separator the two run together.
            "\n",
            "If you see this message, please open an issue at https://github.com/diesel-rs/diesel/issues/new.\n",
            "Source location: {}:{}\n",
        ), $($args,)* file!(), line!());
        std::process::abort()
    };
}
32
33#[allow(missing_debug_implementations, missing_copy_implementations)]
34pub(super) struct RawConnection {
35 pub(super) internal_connection: NonNull<ffi::sqlite3>,
36}
37
38impl RawConnection {
39 pub(super) fn establish(database_url: &str) -> ConnectionResult<Self> {
40 let mut conn_pointer = ptr::null_mut();
41
42 let database_url = if database_url.starts_with("sqlite://") {
43 CString::new(database_url.replacen("sqlite://", "file:", 1))?
44 } else {
45 CString::new(database_url)?
46 };
47 let flags = ffi::SQLITE_OPEN_READWRITE | ffi::SQLITE_OPEN_CREATE | ffi::SQLITE_OPEN_URI;
48 let connection_status = unsafe {
49 ffi::sqlite3_open_v2(database_url.as_ptr(), &mut conn_pointer, flags, ptr::null())
50 };
51
52 match connection_status {
53 ffi::SQLITE_OK => {
54 let conn_pointer = unsafe { NonNull::new_unchecked(conn_pointer) };
55 Ok(RawConnection {
56 internal_connection: conn_pointer,
57 })
58 }
59 err_code => {
60 let message = super::error_message(err_code);
61 Err(ConnectionError::BadConnection(message.into()))
62 }
63 }
64 }
65
66 pub(super) fn exec(&self, query: &str) -> QueryResult<()> {
67 let query = CString::new(query)?;
68 let callback_fn = None;
69 let callback_arg = ptr::null_mut();
70 let result = unsafe {
71 ffi::sqlite3_exec(
72 self.internal_connection.as_ptr(),
73 query.as_ptr(),
74 callback_fn,
75 callback_arg,
76 ptr::null_mut(),
77 )
78 };
79
80 ensure_sqlite_ok(result, self.internal_connection.as_ptr())
81 }
82
83 pub(super) fn rows_affected_by_last_query(
84 &self,
85 ) -> Result<usize, Box<dyn std::error::Error + Send + Sync>> {
86 let r = unsafe { ffi::sqlite3_changes(self.internal_connection.as_ptr()) };
87
88 Ok(r.try_into()?)
89 }
90
91 pub(super) fn register_sql_function<F, Ret, RetSqlType>(
92 &self,
93 fn_name: &str,
94 num_args: usize,
95 deterministic: bool,
96 f: F,
97 ) -> QueryResult<()>
98 where
99 F: FnMut(&Self, &mut [*mut ffi::sqlite3_value]) -> QueryResult<Ret>
100 + std::panic::UnwindSafe
101 + Send
102 + 'static,
103 Ret: ToSql<RetSqlType, Sqlite>,
104 Sqlite: HasSqlType<RetSqlType>,
105 {
106 let callback_fn = Box::into_raw(Box::new(CustomFunctionUserPtr {
107 callback: f,
108 function_name: fn_name.to_owned(),
109 }));
110 let fn_name = Self::get_fn_name(fn_name)?;
111 let flags = Self::get_flags(deterministic);
112 let num_args = num_args
113 .try_into()
114 .map_err(|e| Error::SerializationError(Box::new(e)))?;
115
116 let result = unsafe {
117 ffi::sqlite3_create_function_v2(
118 self.internal_connection.as_ptr(),
119 fn_name.as_ptr(),
120 num_args,
121 flags,
122 callback_fn as *mut _,
123 Some(run_custom_function::<F, Ret, RetSqlType>),
124 None,
125 None,
126 Some(destroy_boxed::<CustomFunctionUserPtr<F>>),
127 )
128 };
129
130 Self::process_sql_function_result(result)
131 }
132
133 pub(super) fn register_aggregate_function<ArgsSqlType, RetSqlType, Args, Ret, A>(
134 &self,
135 fn_name: &str,
136 num_args: usize,
137 ) -> QueryResult<()>
138 where
139 A: SqliteAggregateFunction<Args, Output = Ret> + 'static + Send + std::panic::UnwindSafe,
140 Args: FromSqlRow<ArgsSqlType, Sqlite>,
141 Ret: ToSql<RetSqlType, Sqlite>,
142 Sqlite: HasSqlType<RetSqlType>,
143 {
144 let fn_name = Self::get_fn_name(fn_name)?;
145 let flags = Self::get_flags(false);
146 let num_args = num_args
147 .try_into()
148 .map_err(|e| Error::SerializationError(Box::new(e)))?;
149
150 let result = unsafe {
151 ffi::sqlite3_create_function_v2(
152 self.internal_connection.as_ptr(),
153 fn_name.as_ptr(),
154 num_args,
155 flags,
156 ptr::null_mut(),
157 None,
158 Some(run_aggregator_step_function::<_, _, _, _, A>),
159 Some(run_aggregator_final_function::<_, _, _, _, A>),
160 None,
161 )
162 };
163
164 Self::process_sql_function_result(result)
165 }
166
167 pub(super) fn register_collation_function<F>(
168 &self,
169 collation_name: &str,
170 collation: F,
171 ) -> QueryResult<()>
172 where
173 F: Fn(&str, &str) -> std::cmp::Ordering + std::panic::UnwindSafe + Send + 'static,
174 {
175 let callback_fn = Box::into_raw(Box::new(CollationUserPtr {
176 callback: collation,
177 collation_name: collation_name.to_owned(),
178 }));
179 let collation_name = Self::get_fn_name(collation_name)?;
180
181 let result = unsafe {
182 ffi::sqlite3_create_collation_v2(
183 self.internal_connection.as_ptr(),
184 collation_name.as_ptr(),
185 ffi::SQLITE_UTF8,
186 callback_fn as *mut _,
187 Some(run_collation_function::<F>),
188 Some(destroy_boxed::<CollationUserPtr<F>>),
189 )
190 };
191
192 let result = Self::process_sql_function_result(result);
193 if result.is_err() {
194 destroy_boxed::<CollationUserPtr<F>>(callback_fn as *mut _);
195 }
196 result
197 }
198
199 pub(super) fn serialize(&mut self) -> SerializedDatabase {
200 unsafe {
201 let mut size: ffi::sqlite3_int64 = 0;
202 let data_ptr = ffi::sqlite3_serialize(
203 self.internal_connection.as_ptr(),
204 std::ptr::null(),
205 &mut size as *mut _,
206 0,
207 );
208 SerializedDatabase::new(
209 data_ptr,
210 size.try_into()
211 .expect("Cannot fit the serialized database into memory"),
212 )
213 }
214 }
215
216 pub(super) fn deserialize(&mut self, data: &[u8]) -> QueryResult<()> {
217 let db_size = data
218 .len()
219 .try_into()
220 .map_err(|e| Error::DeserializationError(Box::new(e)))?;
221 #[allow(clippy::unnecessary_cast)]
223 unsafe {
224 let result = ffi::sqlite3_deserialize(
225 self.internal_connection.as_ptr(),
226 std::ptr::null(),
227 data.as_ptr() as *mut u8,
228 db_size,
229 db_size,
230 ffi::SQLITE_DESERIALIZE_READONLY as u32,
231 );
232
233 ensure_sqlite_ok(result, self.internal_connection.as_ptr())
234 }
235 }
236
237 fn get_fn_name(fn_name: &str) -> Result<CString, NulError> {
238 CString::new(fn_name)
239 }
240
241 fn get_flags(deterministic: bool) -> i32 {
242 let mut flags = ffi::SQLITE_UTF8;
243 if deterministic {
244 flags |= ffi::SQLITE_DETERMINISTIC;
245 }
246 flags
247 }
248
249 fn process_sql_function_result(result: i32) -> Result<(), Error> {
250 if result == ffi::SQLITE_OK {
251 Ok(())
252 } else {
253 let error_message = super::error_message(result);
254 Err(DatabaseError(
255 DatabaseErrorKind::Unknown,
256 Box::new(error_message.to_string()),
257 ))
258 }
259 }
260}
261
262impl Drop for RawConnection {
263 fn drop(&mut self) {
264 use std::thread::panicking;
265
266 let close_result = unsafe { ffi::sqlite3_close(self.internal_connection.as_ptr()) };
267 if close_result != ffi::SQLITE_OK {
268 let error_message = super::error_message(close_result);
269 if panicking() {
270 write!(stderr(), "Error closing SQLite connection: {error_message}")
271 .expect("Error writing to `stderr`");
272 } else {
273 panic!("Error closing SQLite connection: {}", error_message);
274 }
275 }
276 }
277}
278
279enum SqliteCallbackError {
280 Abort(&'static str),
281 DieselError(crate::result::Error),
282 Panic(String),
283}
284
285impl SqliteCallbackError {
286 fn emit(&self, ctx: *mut ffi::sqlite3_context) {
287 let s;
288 let msg = match self {
289 SqliteCallbackError::Abort(msg) => *msg,
290 SqliteCallbackError::DieselError(e) => {
291 s = e.to_string();
292 &s
293 }
294 SqliteCallbackError::Panic(msg) => msg,
295 };
296 unsafe {
297 context_error_str(ctx, msg);
298 }
299 }
300}
301
302impl From<crate::result::Error> for SqliteCallbackError {
303 fn from(e: crate::result::Error) -> Self {
304 Self::DieselError(e)
305 }
306}
307
/// User data attached to a registered scalar SQL function.
struct CustomFunctionUserPtr<F> {
    /// The closure that implements the function.
    callback: F,
    /// Name the function was registered under; reported when the closure
    /// panics.
    function_name: String,
}
312
313#[allow(warnings)]
314extern "C" fn run_custom_function<F, Ret, RetSqlType>(
315 ctx: *mut ffi::sqlite3_context,
316 num_args: libc::c_int,
317 value_ptr: *mut *mut ffi::sqlite3_value,
318) where
319 F: FnMut(&RawConnection, &mut [*mut ffi::sqlite3_value]) -> QueryResult<Ret>
320 + std::panic::UnwindSafe
321 + Send
322 + 'static,
323 Ret: ToSql<RetSqlType, Sqlite>,
324 Sqlite: HasSqlType<RetSqlType>,
325{
326 use std::ops::Deref;
327 static NULL_DATA_ERR: &str = "An unknown error occurred. sqlite3_user_data returned a null pointer. This should never happen.";
328 static NULL_CONN_ERR: &str = "An unknown error occurred. sqlite3_context_db_handle returned a null pointer. This should never happen.";
329
330 let conn = match unsafe { NonNull::new(ffi::sqlite3_context_db_handle(ctx)) } {
331 Some(conn) => mem::ManuallyDrop::new(RawConnection {
334 internal_connection: conn,
335 }),
336 None => {
337 unsafe { context_error_str(ctx, NULL_CONN_ERR) };
338 return;
339 }
340 };
341
342 let data_ptr = unsafe { ffi::sqlite3_user_data(ctx) };
343
344 let mut data_ptr = match NonNull::new(data_ptr as *mut CustomFunctionUserPtr<F>) {
345 None => unsafe {
346 context_error_str(ctx, NULL_DATA_ERR);
347 return;
348 },
349 Some(mut f) => f,
350 };
351 let data_ptr = unsafe { data_ptr.as_mut() };
352
353 let callback = std::panic::AssertUnwindSafe(&mut data_ptr.callback);
356
357 let result = std::panic::catch_unwind(move || {
358 let _ = &callback;
359 let args = unsafe { slice::from_raw_parts_mut(value_ptr, num_args as _) };
360 let res = (callback.0)(&*conn, args)?;
361 let value = process_sql_function_result(&res)?;
362 unsafe {
364 value.result_of(&mut *ctx);
365 }
366 Ok(())
367 })
368 .unwrap_or_else(|p| Err(SqliteCallbackError::Panic(data_ptr.function_name.clone())));
369 if let Err(e) = result {
370 e.emit(ctx);
371 }
372}
373
/// Aggregator state as stored inside SQLite's aggregate context.
///
/// With `#[repr(u8)]` the first variant, `None`, has discriminant `0`, so a
/// zero-filled allocation reads as `None`. SQLite documents that a freshly
/// allocated aggregate context is zero-filled, which is what the step
/// callback relies on to detect the first invocation.
#[repr(u8)]
enum OptionalAggregator<A> {
    /// No aggregator has been written to the context yet.
    None,
    /// The live aggregator.
    Some(A),
}
382
383#[allow(warnings)]
384extern "C" fn run_aggregator_step_function<ArgsSqlType, RetSqlType, Args, Ret, A>(
385 ctx: *mut ffi::sqlite3_context,
386 num_args: libc::c_int,
387 value_ptr: *mut *mut ffi::sqlite3_value,
388) where
389 A: SqliteAggregateFunction<Args, Output = Ret> + 'static + Send + std::panic::UnwindSafe,
390 Args: FromSqlRow<ArgsSqlType, Sqlite>,
391 Ret: ToSql<RetSqlType, Sqlite>,
392 Sqlite: HasSqlType<RetSqlType>,
393{
394 let result = std::panic::catch_unwind(move || {
395 let args = unsafe { slice::from_raw_parts_mut(value_ptr, num_args as _) };
396 run_aggregator_step::<A, Args, ArgsSqlType>(ctx, args)
397 })
398 .unwrap_or_else(|e| {
399 Err(SqliteCallbackError::Panic(format!(
400 "{}::step() panicked",
401 std::any::type_name::<A>()
402 )))
403 });
404
405 match result {
406 Ok(()) => {}
407 Err(e) => e.emit(ctx),
408 }
409}
410
411fn run_aggregator_step<A, Args, ArgsSqlType>(
412 ctx: *mut ffi::sqlite3_context,
413 args: &mut [*mut ffi::sqlite3_value],
414) -> Result<(), SqliteCallbackError>
415where
416 A: SqliteAggregateFunction<Args>,
417 Args: FromSqlRow<ArgsSqlType, Sqlite>,
418{
419 static NULL_AG_CTX_ERR: &str = "An unknown error occurred. sqlite3_aggregate_context returned a null pointer. This should never happen.";
420 static NULL_CTX_ERR: &str =
421 "We've written the aggregator to the aggregate context, but it could not be retrieved.";
422
423 let n_bytes: i32 = std::mem::size_of::<OptionalAggregator<A>>()
424 .try_into()
425 .expect("Aggregate context should be larger than 2^32");
426 let aggregate_context = unsafe {
427 ffi::sqlite3_aggregate_context(ctx, n_bytes)
449 };
450 let aggregate_context = NonNull::new(aggregate_context as *mut OptionalAggregator<A>);
451 let aggregator = unsafe {
452 match aggregate_context.map(|a| &mut *a.as_ptr()) {
453 Some(&mut OptionalAggregator::Some(ref mut agg)) => agg,
454 Some(a_ptr @ &mut OptionalAggregator::None) => {
455 ptr::write_unaligned(a_ptr as *mut _, OptionalAggregator::Some(A::default()));
456 if let OptionalAggregator::Some(ref mut agg) = a_ptr {
457 agg
458 } else {
459 return Err(SqliteCallbackError::Abort(NULL_CTX_ERR));
460 }
461 }
462 None => {
463 return Err(SqliteCallbackError::Abort(NULL_AG_CTX_ERR));
464 }
465 }
466 };
467 let args = build_sql_function_args::<ArgsSqlType, Args>(args)?;
468
469 aggregator.step(args);
470 Ok(())
471}
472
473extern "C" fn run_aggregator_final_function<ArgsSqlType, RetSqlType, Args, Ret, A>(
474 ctx: *mut ffi::sqlite3_context,
475) where
476 A: SqliteAggregateFunction<Args, Output = Ret> + 'static + Send,
477 Args: FromSqlRow<ArgsSqlType, Sqlite>,
478 Ret: ToSql<RetSqlType, Sqlite>,
479 Sqlite: HasSqlType<RetSqlType>,
480{
481 static NO_AGGREGATOR_FOUND: &str = "We've written to the aggregator in the xStep callback. If xStep was never called, then ffi::sqlite_aggregate_context() would have returned a NULL pointer.";
482 let aggregate_context = unsafe {
483 ffi::sqlite3_aggregate_context(ctx, 0)
490 };
491
492 let result = std::panic::catch_unwind(|| {
493 let mut aggregate_context = NonNull::new(aggregate_context as *mut OptionalAggregator<A>);
494
495 let aggregator = if let Some(a) = aggregate_context.as_mut() {
496 let a = unsafe { a.as_mut() };
497 match std::mem::replace(a, OptionalAggregator::None) {
498 OptionalAggregator::None => {
499 return Err(SqliteCallbackError::Abort(NO_AGGREGATOR_FOUND));
500 }
501 OptionalAggregator::Some(a) => Some(a),
502 }
503 } else {
504 None
505 };
506
507 let res = A::finalize(aggregator);
508 let value = process_sql_function_result(&res)?;
509 let r = unsafe { value.result_of(&mut *ctx) };
511 r.map_err(|e| {
512 SqliteCallbackError::DieselError(crate::result::Error::SerializationError(Box::new(e)))
513 })?;
514 Ok(())
515 })
516 .unwrap_or_else(|_e| {
517 Err(SqliteCallbackError::Panic(format!(
518 "{}::finalize() panicked",
519 std::any::type_name::<A>()
520 )))
521 });
522 if let Err(e) = result {
523 e.emit(ctx);
524 }
525}
526
527unsafe fn context_error_str(ctx: *mut ffi::sqlite3_context, error: &str) {
528 let len: i32 = error
529 .len()
530 .try_into()
531 .expect("Trying to set a error message with more than 2^32 byte is not supported");
532 ffi::sqlite3_result_error(ctx, error.as_ptr() as *const _, len);
533}
534
/// User data attached to a registered collating sequence.
struct CollationUserPtr<F> {
    /// The closure that compares two strings.
    callback: F,
    /// Name the collation was registered under; used in diagnostics.
    collation_name: String,
}
539
540#[allow(warnings)]
541extern "C" fn run_collation_function<F>(
542 user_ptr: *mut libc::c_void,
543 lhs_len: libc::c_int,
544 lhs_ptr: *const libc::c_void,
545 rhs_len: libc::c_int,
546 rhs_ptr: *const libc::c_void,
547) -> libc::c_int
548where
549 F: Fn(&str, &str) -> std::cmp::Ordering + Send + std::panic::UnwindSafe + 'static,
550{
551 let user_ptr = user_ptr as *const CollationUserPtr<F>;
552 let user_ptr = std::panic::AssertUnwindSafe(unsafe { user_ptr.as_ref() });
553
554 let result = std::panic::catch_unwind(|| {
555 let user_ptr = user_ptr.ok_or_else(|| {
556 SqliteCallbackError::Abort(
557 "Got a null pointer as data pointer. This should never happen",
558 )
559 })?;
560 for (ptr, len, side) in &[(rhs_ptr, rhs_len, "rhs"), (lhs_ptr, lhs_len, "lhs")] {
561 if *len < 0 {
562 assert_fail!(
563 "An unknown error occurred. {}_len is negative. This should never happen.",
564 side
565 );
566 }
567 if ptr.is_null() {
568 assert_fail!(
569 "An unknown error occurred. {}_ptr is a null pointer. This should never happen.",
570 side
571 );
572 }
573 }
574
575 let (rhs, lhs) = unsafe {
576 (
580 str::from_utf8(slice::from_raw_parts(rhs_ptr as *const u8, rhs_len as _)),
581 str::from_utf8(slice::from_raw_parts(lhs_ptr as *const u8, lhs_len as _)),
582 )
583 };
584
585 let rhs =
586 rhs.map_err(|_| SqliteCallbackError::Abort("Got an invalid UTF-8 string for rhs"))?;
587 let lhs =
588 lhs.map_err(|_| SqliteCallbackError::Abort("Got an invalid UTF-8 string for lhs"))?;
589
590 Ok((user_ptr.callback)(rhs, lhs))
591 })
592 .unwrap_or_else(|p| {
593 Err(SqliteCallbackError::Panic(
594 user_ptr
595 .map(|u| u.collation_name.clone())
596 .unwrap_or_default(),
597 ))
598 });
599
600 match result {
601 Ok(std::cmp::Ordering::Less) => -1,
602 Ok(std::cmp::Ordering::Equal) => 0,
603 Ok(std::cmp::Ordering::Greater) => 1,
604 Err(SqliteCallbackError::Abort(a)) => {
605 eprintln!(
606 "Collation function {} failed with: {}",
607 user_ptr
608 .map(|c| &c.collation_name as &str)
609 .unwrap_or_default(),
610 a
611 );
612 std::process::abort()
613 }
614 Err(SqliteCallbackError::DieselError(e)) => {
615 eprintln!(
616 "Collation function {} failed with: {}",
617 user_ptr
618 .map(|c| &c.collation_name as &str)
619 .unwrap_or_default(),
620 e
621 );
622 std::process::abort()
623 }
624 Err(SqliteCallbackError::Panic(msg)) => {
625 eprintln!("Collation function {} panicked", msg);
626 std::process::abort()
627 }
628 }
629}
630
/// Destructor handed to SQLite for user data created with `Box::into_raw`.
///
/// Rebuilds the `Box<F>` from `data` and drops it, running `F`'s destructor
/// and freeing the allocation.
extern "C" fn destroy_boxed<F>(data: *mut libc::c_void) {
    let reclaimed: Box<F> = unsafe { Box::from_raw(data as *mut F) };
    drop(reclaimed);
}