Skip to main content

diesel/sqlite/connection/
functions.rs

1#[cfg(not(all(target_family = "wasm", target_os = "unknown")))]
2extern crate libsqlite3_sys as ffi;
3
4#[cfg(all(target_family = "wasm", target_os = "unknown"))]
5use sqlite_wasm_rs as ffi;
6
7use super::raw::RawConnection;
8use super::row::PrivateSqliteRow;
9use super::{Sqlite, SqliteAggregateFunction, SqliteBindValue};
10use crate::backend::Backend;
11use crate::deserialize::{FromSqlRow, StaticallySizedRow};
12use crate::result::{DatabaseErrorKind, Error, QueryResult};
13use crate::row::{Field, PartialRow, Row, RowIndex, RowSealed};
14use crate::serialize::{IsNull, Output, ToSql};
15use crate::sql_types::HasSqlType;
16use crate::sqlite::SqliteValue;
17use crate::sqlite::connection::bind_collector::InternalSqliteBindValue;
18use crate::sqlite::connection::sqlite_value::OwnedSqliteValue;
19use alloc::boxed::Box;
20use alloc::rc::Rc;
21use alloc::string::ToString;
22use alloc::vec::Vec;
23use core::cell::{Ref, RefCell};
24use core::marker::PhantomData;
25use core::mem::ManuallyDrop;
26use core::ops::DerefMut;
27
28pub(super) fn register<ArgsSqlType, RetSqlType, Args, Ret, F>(
29    conn: &RawConnection,
30    fn_name: &str,
31    deterministic: bool,
32    mut f: F,
33) -> QueryResult<()>
34where
35    F: FnMut(&RawConnection, Args) -> Ret + core::panic::UnwindSafe + Send + 'static,
36    Args: FromSqlRow<ArgsSqlType, Sqlite> + StaticallySizedRow<ArgsSqlType, Sqlite>,
37    Ret: ToSql<RetSqlType, Sqlite>,
38    Sqlite: HasSqlType<RetSqlType>,
39{
40    let fields_needed = Args::FIELD_COUNT;
41    if fields_needed > 127 {
42        return Err(Error::DatabaseError(
43            DatabaseErrorKind::UnableToSendCommand,
44            Box::new("SQLite functions cannot take more than 127 parameters".to_string()),
45        ));
46    }
47
48    conn.register_sql_function(fn_name, fields_needed, deterministic, move |conn, args| {
49        let args = build_sql_function_args::<ArgsSqlType, Args>(args)?;
50
51        Ok(f(conn, args))
52    })?;
53    Ok(())
54}
55
56pub(super) fn register_noargs<RetSqlType, Ret, F>(
57    conn: &RawConnection,
58    fn_name: &str,
59    deterministic: bool,
60    mut f: F,
61) -> QueryResult<()>
62where
63    F: FnMut() -> Ret + core::panic::UnwindSafe + Send + 'static,
64    Ret: ToSql<RetSqlType, Sqlite>,
65    Sqlite: HasSqlType<RetSqlType>,
66{
67    conn.register_sql_function(fn_name, 0, deterministic, move |_, _| Ok(f()))?;
68    Ok(())
69}
70
71pub(super) fn register_aggregate<ArgsSqlType, RetSqlType, Args, Ret, A>(
72    conn: &RawConnection,
73    fn_name: &str,
74) -> QueryResult<()>
75where
76    A: SqliteAggregateFunction<Args, Output = Ret> + 'static + Send + core::panic::UnwindSafe,
77    Args: FromSqlRow<ArgsSqlType, Sqlite> + StaticallySizedRow<ArgsSqlType, Sqlite>,
78    Ret: ToSql<RetSqlType, Sqlite>,
79    Sqlite: HasSqlType<RetSqlType>,
80{
81    let fields_needed = Args::FIELD_COUNT;
82    if fields_needed > 127 {
83        return Err(Error::DatabaseError(
84            DatabaseErrorKind::UnableToSendCommand,
85            Box::new("SQLite functions cannot take more than 127 parameters".to_string()),
86        ));
87    }
88
89    conn.register_aggregate_function::<ArgsSqlType, RetSqlType, Args, Ret, A>(
90        fn_name,
91        fields_needed,
92    )?;
93
94    Ok(())
95}
96
97pub(super) fn build_sql_function_args<ArgsSqlType, Args>(
98    args: &mut [*mut ffi::sqlite3_value],
99) -> Result<Args, Error>
100where
101    Args: FromSqlRow<ArgsSqlType, Sqlite>,
102{
103    let row = FunctionRow::new(args);
104    Args::build_from_row(&row).map_err(Error::DeserializationError)
105}
106
107// clippy is wrong here, the let binding is required
108// for lifetime reasons
109#[allow(clippy::let_unit_value)]
110pub(super) fn process_sql_function_result<RetSqlType, Ret>(
111    result: &'_ Ret,
112) -> QueryResult<InternalSqliteBindValue<'_>>
113where
114    Ret: ToSql<RetSqlType, Sqlite>,
115    Sqlite: HasSqlType<RetSqlType>,
116{
117    let mut metadata_lookup = ();
118    let value = SqliteBindValue {
119        inner: InternalSqliteBindValue::Null,
120    };
121    let mut buf = Output::new(value, &mut metadata_lookup);
122    let is_null = result.to_sql(&mut buf).map_err(Error::SerializationError)?;
123
124    if let IsNull::Yes = is_null {
125        Ok(InternalSqliteBindValue::Null)
126    } else {
127        Ok(buf.into_inner().inner)
128    }
129}
130
131struct FunctionRow<'a> {
132    // we use `ManuallyDrop` to prevent dropping the content of the internal vector
133    // as this buffer is owned by sqlite not by diesel
134    args: Rc<RefCell<ManuallyDrop<PrivateSqliteRow<'a, 'static>>>>,
135    field_count: usize,
136    marker: PhantomData<&'a ffi::sqlite3_value>,
137}
138
139impl Drop for FunctionRow<'_> {
140    #[allow(unsafe_code)] // manual drop calls
141    fn drop(&mut self) {
142        if let Some(args) = Rc::get_mut(&mut self.args)
143            && let PrivateSqliteRow::Duplicated { column_names, .. } =
144                DerefMut::deref_mut(RefCell::get_mut(args))
145            && Rc::strong_count(column_names) == 1
146        {
147            // According the https://doc.rust-lang.org/std/mem/struct.ManuallyDrop.html#method.drop
148            // it's fine to just drop the values here
149            unsafe { core::ptr::drop_in_place(column_names as *mut _) }
150        }
151    }
152}
153
154impl FunctionRow<'_> {
155    #[allow(unsafe_code)] // complicated ptr cast
156    fn new(args: &mut [*mut ffi::sqlite3_value]) -> Self {
157        let lengths = args.len();
158        let args = unsafe {
159            Vec::from_raw_parts(
160                // This cast is safe because:
161                // * Casting from a pointer to an array to a pointer to the first array
162                // element is safe
163                // * Casting from a raw pointer to `NonNull<T>` is safe,
164                // because `NonNull` is #[repr(transparent)]
165                // * Casting from `NonNull<T>` to `OwnedSqliteValue` is safe,
166                // as the struct is `#[repr(transparent)]
167                // * Casting from `NonNull<T>` to `Option<NonNull<T>>` as the documentation
168                // states: "This is so that enums may use this forbidden value as a discriminant –
169                // Option<NonNull<T>> has the same size as *mut T"
170                // * The last point remains true for `OwnedSqliteValue` as `#[repr(transparent)]
171                // guarantees the same layout as the inner type
172                // * It's unsafe to drop the vector (and the vector elements)
173                // because of this we wrap the vector (or better the Row)
174                // Into `ManualDrop` to prevent the dropping
175                args as *mut [*mut ffi::sqlite3_value] as *mut ffi::sqlite3_value
176                    as *mut Option<OwnedSqliteValue>,
177                lengths,
178                lengths,
179            )
180        };
181
182        Self {
183            field_count: lengths,
184            args: Rc::new(RefCell::new(ManuallyDrop::new(
185                PrivateSqliteRow::Duplicated {
186                    values: args,
187                    column_names: Rc::from(::alloc::vec::from_elem(None, lengths)alloc::vec![None; lengths]),
188                },
189            ))),
190            marker: PhantomData,
191        }
192    }
193}
194
195impl RowSealed for FunctionRow<'_> {}
196
197impl<'a> Row<'a, Sqlite> for FunctionRow<'a> {
198    type Field<'f>
199        = FunctionArgument<'f>
200    where
201        'a: 'f,
202        Self: 'f;
203    type InnerPartialRow = Self;
204
205    fn field_count(&self) -> usize {
206        self.field_count
207    }
208
209    fn get<'b, I>(&'b self, idx: I) -> Option<Self::Field<'b>>
210    where
211        'a: 'b,
212        Self: crate::row::RowIndex<I>,
213    {
214        let col_idx = self.idx(idx)?;
215        Some(FunctionArgument {
216            args: self.args.borrow(),
217            col_idx,
218        })
219    }
220
221    fn partial_row(&self, range: core::ops::Range<usize>) -> PartialRow<'_, Self::InnerPartialRow> {
222        PartialRow::new(self, range)
223    }
224}
225
226impl RowIndex<usize> for FunctionRow<'_> {
227    fn idx(&self, idx: usize) -> Option<usize> {
228        if idx < self.field_count() {
229            Some(idx)
230        } else {
231            None
232        }
233    }
234}
235
236impl<'a> RowIndex<&'a str> for FunctionRow<'_> {
237    fn idx(&self, _idx: &'a str) -> Option<usize> {
238        None
239    }
240}
241
242struct FunctionArgument<'a> {
243    args: Ref<'a, ManuallyDrop<PrivateSqliteRow<'a, 'static>>>,
244    col_idx: usize,
245}
246
247impl<'a> Field<'a, Sqlite> for FunctionArgument<'a> {
248    fn field_name(&self) -> Option<&str> {
249        None
250    }
251
252    fn is_null(&self) -> bool {
253        self.value().is_none()
254    }
255
256    fn value(&self) -> Option<<Sqlite as Backend>::RawValue<'_>> {
257        SqliteValue::new(
258            Ref::map(Ref::clone(&self.args), |drop| core::ops::Deref::deref(drop)),
259            self.col_idx,
260        )
261    }
262}