diesel/sqlite/connection/
functions.rs

1#[cfg(not(all(target_family = "wasm", target_os = "unknown")))]
2extern crate libsqlite3_sys as ffi;
3
4#[cfg(all(target_family = "wasm", target_os = "unknown"))]
5use sqlite_wasm_rs::export as ffi;
6
7use super::raw::RawConnection;
8use super::row::PrivateSqliteRow;
9use super::{Sqlite, SqliteAggregateFunction, SqliteBindValue};
10use crate::backend::Backend;
11use crate::deserialize::{FromSqlRow, StaticallySizedRow};
12use crate::result::{DatabaseErrorKind, Error, QueryResult};
13use crate::row::{Field, PartialRow, Row, RowIndex, RowSealed};
14use crate::serialize::{IsNull, Output, ToSql};
15use crate::sql_types::HasSqlType;
16use crate::sqlite::connection::bind_collector::InternalSqliteBindValue;
17use crate::sqlite::connection::sqlite_value::OwnedSqliteValue;
18use crate::sqlite::SqliteValue;
19use std::cell::{Ref, RefCell};
20use std::marker::PhantomData;
21use std::mem::ManuallyDrop;
22use std::ops::DerefMut;
23use std::rc::Rc;
24
25pub(super) fn register<ArgsSqlType, RetSqlType, Args, Ret, F>(
26    conn: &RawConnection,
27    fn_name: &str,
28    deterministic: bool,
29    mut f: F,
30) -> QueryResult<()>
31where
32    F: FnMut(&RawConnection, Args) -> Ret + std::panic::UnwindSafe + Send + 'static,
33    Args: FromSqlRow<ArgsSqlType, Sqlite> + StaticallySizedRow<ArgsSqlType, Sqlite>,
34    Ret: ToSql<RetSqlType, Sqlite>,
35    Sqlite: HasSqlType<RetSqlType>,
36{
37    let fields_needed = Args::FIELD_COUNT;
38    if fields_needed > 127 {
39        return Err(Error::DatabaseError(
40            DatabaseErrorKind::UnableToSendCommand,
41            Box::new("SQLite functions cannot take more than 127 parameters".to_string()),
42        ));
43    }
44
45    conn.register_sql_function(fn_name, fields_needed, deterministic, move |conn, args| {
46        let args = build_sql_function_args::<ArgsSqlType, Args>(args)?;
47
48        Ok(f(conn, args))
49    })?;
50    Ok(())
51}
52
53pub(super) fn register_noargs<RetSqlType, Ret, F>(
54    conn: &RawConnection,
55    fn_name: &str,
56    deterministic: bool,
57    mut f: F,
58) -> QueryResult<()>
59where
60    F: FnMut() -> Ret + std::panic::UnwindSafe + Send + 'static,
61    Ret: ToSql<RetSqlType, Sqlite>,
62    Sqlite: HasSqlType<RetSqlType>,
63{
64    conn.register_sql_function(fn_name, 0, deterministic, move |_, _| Ok(f()))?;
65    Ok(())
66}
67
68pub(super) fn register_aggregate<ArgsSqlType, RetSqlType, Args, Ret, A>(
69    conn: &RawConnection,
70    fn_name: &str,
71) -> QueryResult<()>
72where
73    A: SqliteAggregateFunction<Args, Output = Ret> + 'static + Send + std::panic::UnwindSafe,
74    Args: FromSqlRow<ArgsSqlType, Sqlite> + StaticallySizedRow<ArgsSqlType, Sqlite>,
75    Ret: ToSql<RetSqlType, Sqlite>,
76    Sqlite: HasSqlType<RetSqlType>,
77{
78    let fields_needed = Args::FIELD_COUNT;
79    if fields_needed > 127 {
80        return Err(Error::DatabaseError(
81            DatabaseErrorKind::UnableToSendCommand,
82            Box::new("SQLite functions cannot take more than 127 parameters".to_string()),
83        ));
84    }
85
86    conn.register_aggregate_function::<ArgsSqlType, RetSqlType, Args, Ret, A>(
87        fn_name,
88        fields_needed,
89    )?;
90
91    Ok(())
92}
93
94pub(super) fn build_sql_function_args<ArgsSqlType, Args>(
95    args: &mut [*mut ffi::sqlite3_value],
96) -> Result<Args, Error>
97where
98    Args: FromSqlRow<ArgsSqlType, Sqlite>,
99{
100    let row = FunctionRow::new(args);
101    Args::build_from_row(&row).map_err(Error::DeserializationError)
102}
103
104// clippy is wrong here, the let binding is required
105// for lifetime reasons
106#[allow(clippy::let_unit_value)]
107pub(super) fn process_sql_function_result<RetSqlType, Ret>(
108    result: &'_ Ret,
109) -> QueryResult<InternalSqliteBindValue<'_>>
110where
111    Ret: ToSql<RetSqlType, Sqlite>,
112    Sqlite: HasSqlType<RetSqlType>,
113{
114    let mut metadata_lookup = ();
115    let value = SqliteBindValue {
116        inner: InternalSqliteBindValue::Null,
117    };
118    let mut buf = Output::new(value, &mut metadata_lookup);
119    let is_null = result.to_sql(&mut buf).map_err(Error::SerializationError)?;
120
121    if let IsNull::Yes = is_null {
122        Ok(InternalSqliteBindValue::Null)
123    } else {
124        Ok(buf.into_inner().inner)
125    }
126}
127
128struct FunctionRow<'a> {
129    // we use `ManuallyDrop` to prevent dropping the content of the internal vector
130    // as this buffer is owned by sqlite not by diesel
131    args: Rc<RefCell<ManuallyDrop<PrivateSqliteRow<'a, 'static>>>>,
132    field_count: usize,
133    marker: PhantomData<&'a ffi::sqlite3_value>,
134}
135
136impl Drop for FunctionRow<'_> {
137    #[allow(unsafe_code)] // manual drop calls
138    fn drop(&mut self) {
139        if let Some(args) = Rc::get_mut(&mut self.args) {
140            if let PrivateSqliteRow::Duplicated { column_names, .. } =
141                DerefMut::deref_mut(RefCell::get_mut(args))
142            {
143                if Rc::strong_count(column_names) == 1 {
144                    // According the https://doc.rust-lang.org/std/mem/struct.ManuallyDrop.html#method.drop
145                    // it's fine to just drop the values here
146                    unsafe { std::ptr::drop_in_place(column_names as *mut _) }
147                }
148            }
149        }
150    }
151}
152
impl FunctionRow<'_> {
    /// Wraps the argument array of a SQLite function callback in a `FunctionRow`
    /// without copying the values.
    ///
    /// The result is a `PrivateSqliteRow::Duplicated` row with two parts:
    /// * its `values` vector aliases `args` directly;
    /// * its `column_names` are all `None`, one per argument, because function
    ///   arguments are unnamed.
    ///
    /// The row is wrapped in `ManuallyDrop` because the aliased buffer is not
    /// owned by diesel and must never be freed through the vector.
    #[allow(unsafe_code)] // complicated ptr cast
    fn new(args: &mut [*mut ffi::sqlite3_value]) -> Self {
        let lengths = args.len();
        // SAFETY: see the cast justification below. The vector is created with
        // length == capacity == `args.len()`.
        // NOTE(review): `Vec::from_raw_parts` documents that `ptr` must have been
        // allocated by the global allocator, which this argument array was not.
        // Soundness therefore relies on the vector never being dropped, grown or
        // shrunk. TODO confirm that nothing mutates `values` of this
        // `Duplicated` row.
        let args = unsafe {
            Vec::from_raw_parts(
                // This cast is safe because:
                // * Casting from a pointer to an array to a pointer to the first array
                // element is safe (the intermediate cast names `*mut ffi::sqlite3_value`
                // rather than the element pointer type `*mut *mut ffi::sqlite3_value`;
                // only the final pointer type is used, so this does not matter)
                // * Casting from a raw pointer to `NonNull<T>` is safe,
                // because `NonNull` is `#[repr(transparent)]`
                // * Casting from `NonNull<T>` to `OwnedSqliteValue` is safe,
                // as the struct is `#[repr(transparent)]`
                // * Casting from `NonNull<T>` to `Option<NonNull<T>>` is safe, as the
                // documentation states: "This is so that enums may use this forbidden
                // value as a discriminant – Option<NonNull<T>> has the same size as *mut T"
                // * The last point remains true for `OwnedSqliteValue` as `#[repr(transparent)]`
                // guarantees the same layout as the inner type
                // * It's unsafe to drop the vector (and the vector elements),
                // so we wrap the vector (or rather the row)
                // in `ManuallyDrop` to prevent the dropping
                args as *mut [*mut ffi::sqlite3_value] as *mut ffi::sqlite3_value
                    as *mut Option<OwnedSqliteValue>,
                lengths,
                lengths,
            )
        };

        Self {
            field_count: lengths,
            args: Rc::new(RefCell::new(ManuallyDrop::new(
                PrivateSqliteRow::Duplicated {
                    values: args,
                    // Function arguments carry no names: one `None` per argument.
                    column_names: Rc::from(vec![None; lengths]),
                },
            ))),
            marker: PhantomData,
        }
    }
}
193
194impl RowSealed for FunctionRow<'_> {}
195
196impl<'a> Row<'a, Sqlite> for FunctionRow<'a> {
197    type Field<'f>
198        = FunctionArgument<'f>
199    where
200        'a: 'f,
201        Self: 'f;
202    type InnerPartialRow = Self;
203
204    fn field_count(&self) -> usize {
205        self.field_count
206    }
207
208    fn get<'b, I>(&'b self, idx: I) -> Option<Self::Field<'b>>
209    where
210        'a: 'b,
211        Self: crate::row::RowIndex<I>,
212    {
213        let col_idx = self.idx(idx)?;
214        Some(FunctionArgument {
215            args: self.args.borrow(),
216            col_idx,
217        })
218    }
219
220    fn partial_row(&self, range: std::ops::Range<usize>) -> PartialRow<'_, Self::InnerPartialRow> {
221        PartialRow::new(self, range)
222    }
223}
224
225impl RowIndex<usize> for FunctionRow<'_> {
226    fn idx(&self, idx: usize) -> Option<usize> {
227        if idx < self.field_count() {
228            Some(idx)
229        } else {
230            None
231        }
232    }
233}
234
235impl<'a> RowIndex<&'a str> for FunctionRow<'_> {
236    fn idx(&self, _idx: &'a str) -> Option<usize> {
237        None
238    }
239}
240
241struct FunctionArgument<'a> {
242    args: Ref<'a, ManuallyDrop<PrivateSqliteRow<'a, 'static>>>,
243    col_idx: usize,
244}
245
246impl<'a> Field<'a, Sqlite> for FunctionArgument<'a> {
247    fn field_name(&self) -> Option<&str> {
248        None
249    }
250
251    fn is_null(&self) -> bool {
252        self.value().is_none()
253    }
254
255    fn value(&self) -> Option<<Sqlite as Backend>::RawValue<'_>> {
256        SqliteValue::new(
257            Ref::map(Ref::clone(&self.args), |drop| std::ops::Deref::deref(drop)),
258            self.col_idx,
259        )
260    }
261}