diesel/sqlite/connection/functions.rs

1extern crate libsqlite3_sys as ffi;
2
3use super::raw::RawConnection;
4use super::row::PrivateSqliteRow;
5use super::{Sqlite, SqliteAggregateFunction, SqliteBindValue};
6use crate::backend::Backend;
7use crate::deserialize::{FromSqlRow, StaticallySizedRow};
8use crate::result::{DatabaseErrorKind, Error, QueryResult};
9use crate::row::{Field, PartialRow, Row, RowIndex, RowSealed};
10use crate::serialize::{IsNull, Output, ToSql};
11use crate::sql_types::HasSqlType;
12use crate::sqlite::connection::bind_collector::InternalSqliteBindValue;
13use crate::sqlite::connection::sqlite_value::OwnedSqliteValue;
14use crate::sqlite::SqliteValue;
15use std::cell::{Ref, RefCell};
16use std::marker::PhantomData;
17use std::mem::ManuallyDrop;
18use std::ops::DerefMut;
19use std::rc::Rc;
20
21pub(super) fn register<ArgsSqlType, RetSqlType, Args, Ret, F>(
22    conn: &RawConnection,
23    fn_name: &str,
24    deterministic: bool,
25    mut f: F,
26) -> QueryResult<()>
27where
28    F: FnMut(&RawConnection, Args) -> Ret + std::panic::UnwindSafe + Send + 'static,
29    Args: FromSqlRow<ArgsSqlType, Sqlite> + StaticallySizedRow<ArgsSqlType, Sqlite>,
30    Ret: ToSql<RetSqlType, Sqlite>,
31    Sqlite: HasSqlType<RetSqlType>,
32{
33    let fields_needed = Args::FIELD_COUNT;
34    if fields_needed > 127 {
35        return Err(Error::DatabaseError(
36            DatabaseErrorKind::UnableToSendCommand,
37            Box::new("SQLite functions cannot take more than 127 parameters".to_string()),
38        ));
39    }
40
41    conn.register_sql_function(fn_name, fields_needed, deterministic, move |conn, args| {
42        let args = build_sql_function_args::<ArgsSqlType, Args>(args)?;
43
44        Ok(f(conn, args))
45    })?;
46    Ok(())
47}
48
49pub(super) fn register_noargs<RetSqlType, Ret, F>(
50    conn: &RawConnection,
51    fn_name: &str,
52    deterministic: bool,
53    mut f: F,
54) -> QueryResult<()>
55where
56    F: FnMut() -> Ret + std::panic::UnwindSafe + Send + 'static,
57    Ret: ToSql<RetSqlType, Sqlite>,
58    Sqlite: HasSqlType<RetSqlType>,
59{
60    conn.register_sql_function(fn_name, 0, deterministic, move |_, _| Ok(f()))?;
61    Ok(())
62}
63
64pub(super) fn register_aggregate<ArgsSqlType, RetSqlType, Args, Ret, A>(
65    conn: &RawConnection,
66    fn_name: &str,
67) -> QueryResult<()>
68where
69    A: SqliteAggregateFunction<Args, Output = Ret> + 'static + Send + std::panic::UnwindSafe,
70    Args: FromSqlRow<ArgsSqlType, Sqlite> + StaticallySizedRow<ArgsSqlType, Sqlite>,
71    Ret: ToSql<RetSqlType, Sqlite>,
72    Sqlite: HasSqlType<RetSqlType>,
73{
74    let fields_needed = Args::FIELD_COUNT;
75    if fields_needed > 127 {
76        return Err(Error::DatabaseError(
77            DatabaseErrorKind::UnableToSendCommand,
78            Box::new("SQLite functions cannot take more than 127 parameters".to_string()),
79        ));
80    }
81
82    conn.register_aggregate_function::<ArgsSqlType, RetSqlType, Args, Ret, A>(
83        fn_name,
84        fields_needed,
85    )?;
86
87    Ok(())
88}
89
90pub(super) fn build_sql_function_args<ArgsSqlType, Args>(
91    args: &mut [*mut ffi::sqlite3_value],
92) -> Result<Args, Error>
93where
94    Args: FromSqlRow<ArgsSqlType, Sqlite>,
95{
96    let row = FunctionRow::new(args);
97    Args::build_from_row(&row).map_err(Error::DeserializationError)
98}
99
100// clippy is wrong here, the let binding is required
101// for lifetime reasons
102#[allow(clippy::let_unit_value)]
103pub(super) fn process_sql_function_result<RetSqlType, Ret>(
104    result: &'_ Ret,
105) -> QueryResult<InternalSqliteBindValue<'_>>
106where
107    Ret: ToSql<RetSqlType, Sqlite>,
108    Sqlite: HasSqlType<RetSqlType>,
109{
110    let mut metadata_lookup = ();
111    let value = SqliteBindValue {
112        inner: InternalSqliteBindValue::Null,
113    };
114    let mut buf = Output::new(value, &mut metadata_lookup);
115    let is_null = result.to_sql(&mut buf).map_err(Error::SerializationError)?;
116
117    if let IsNull::Yes = is_null {
118        Ok(InternalSqliteBindValue::Null)
119    } else {
120        Ok(buf.into_inner().inner)
121    }
122}
123
124struct FunctionRow<'a> {
125    // we use `ManuallyDrop` to prevent dropping the content of the internal vector
126    // as this buffer is owned by sqlite not by diesel
127    args: Rc<RefCell<ManuallyDrop<PrivateSqliteRow<'a, 'static>>>>,
128    field_count: usize,
129    marker: PhantomData<&'a ffi::sqlite3_value>,
130}
131
132impl Drop for FunctionRow<'_> {
133    #[allow(unsafe_code)] // manual drop calls
134    fn drop(&mut self) {
135        if let Some(args) = Rc::get_mut(&mut self.args) {
136            if let PrivateSqliteRow::Duplicated { column_names, .. } =
137                DerefMut::deref_mut(RefCell::get_mut(args))
138            {
139                if Rc::strong_count(column_names) == 1 {
140                    // According the https://doc.rust-lang.org/std/mem/struct.ManuallyDrop.html#method.drop
141                    // it's fine to just drop the values here
142                    unsafe { std::ptr::drop_in_place(column_names as *mut _) }
143                }
144            }
145        }
146    }
147}
148
149impl FunctionRow<'_> {
150    #[allow(unsafe_code)] // complicated ptr cast
151    fn new(args: &mut [*mut ffi::sqlite3_value]) -> Self {
152        let lengths = args.len();
153        let args = unsafe {
154            Vec::from_raw_parts(
155                // This cast is safe because:
156                // * Casting from a pointer to an array to a pointer to the first array
157                // element is safe
158                // * Casting from a raw pointer to `NonNull<T>` is safe,
159                // because `NonNull` is #[repr(transparent)]
160                // * Casting from `NonNull<T>` to `OwnedSqliteValue` is safe,
161                // as the struct is `#[repr(transparent)]
162                // * Casting from `NonNull<T>` to `Option<NonNull<T>>` as the documentation
163                // states: "This is so that enums may use this forbidden value as a discriminant –
164                // Option<NonNull<T>> has the same size as *mut T"
165                // * The last point remains true for `OwnedSqliteValue` as `#[repr(transparent)]
166                // guarantees the same layout as the inner type
167                // * It's unsafe to drop the vector (and the vector elements)
168                // because of this we wrap the vector (or better the Row)
169                // Into `ManualDrop` to prevent the dropping
170                args as *mut [*mut ffi::sqlite3_value] as *mut ffi::sqlite3_value
171                    as *mut Option<OwnedSqliteValue>,
172                lengths,
173                lengths,
174            )
175        };
176
177        Self {
178            field_count: lengths,
179            args: Rc::new(RefCell::new(ManuallyDrop::new(
180                PrivateSqliteRow::Duplicated {
181                    values: args,
182                    column_names: Rc::from(vec![None; lengths]),
183                },
184            ))),
185            marker: PhantomData,
186        }
187    }
188}
189
190impl RowSealed for FunctionRow<'_> {}
191
192impl<'a> Row<'a, Sqlite> for FunctionRow<'a> {
193    type Field<'f>
194        = FunctionArgument<'f>
195    where
196        'a: 'f,
197        Self: 'f;
198    type InnerPartialRow = Self;
199
200    fn field_count(&self) -> usize {
201        self.field_count
202    }
203
204    fn get<'b, I>(&'b self, idx: I) -> Option<Self::Field<'b>>
205    where
206        'a: 'b,
207        Self: crate::row::RowIndex<I>,
208    {
209        let col_idx = self.idx(idx)?;
210        Some(FunctionArgument {
211            args: self.args.borrow(),
212            col_idx,
213        })
214    }
215
216    fn partial_row(&self, range: std::ops::Range<usize>) -> PartialRow<'_, Self::InnerPartialRow> {
217        PartialRow::new(self, range)
218    }
219}
220
221impl RowIndex<usize> for FunctionRow<'_> {
222    fn idx(&self, idx: usize) -> Option<usize> {
223        if idx < self.field_count() {
224            Some(idx)
225        } else {
226            None
227        }
228    }
229}
230
231impl<'a> RowIndex<&'a str> for FunctionRow<'_> {
232    fn idx(&self, _idx: &'a str) -> Option<usize> {
233        None
234    }
235}
236
237struct FunctionArgument<'a> {
238    args: Ref<'a, ManuallyDrop<PrivateSqliteRow<'a, 'static>>>,
239    col_idx: usize,
240}
241
242impl<'a> Field<'a, Sqlite> for FunctionArgument<'a> {
243    fn field_name(&self) -> Option<&str> {
244        None
245    }
246
247    fn is_null(&self) -> bool {
248        self.value().is_none()
249    }
250
251    fn value(&self) -> Option<<Sqlite as Backend>::RawValue<'_>> {
252        SqliteValue::new(
253            Ref::map(Ref::clone(&self.args), |drop| std::ops::Deref::deref(drop)),
254            self.col_idx,
255        )
256    }
257}