Skip to main content

diesel/sqlite/connection/
functions.rs

1#[cfg(not(all(target_family = "wasm", target_os = "unknown")))]
2extern crate libsqlite3_sys as ffi;
3
4#[cfg(all(target_family = "wasm", target_os = "unknown"))]
5use sqlite_wasm_rs as ffi;
6
7use super::raw::RawConnection;
8use super::{Sqlite, SqliteAggregateFunction, SqliteBindValue};
9use crate::backend::Backend;
10use crate::deserialize::{FromSqlRow, StaticallySizedRow};
11use crate::result::{DatabaseErrorKind, Error, QueryResult};
12use crate::row::{Field, PartialRow, Row, RowIndex, RowSealed};
13use crate::serialize::{IsNull, Output, ToSql};
14use crate::sql_types::HasSqlType;
15use crate::sqlite::SqliteFunctionBehavior;
16use crate::sqlite::SqliteValue;
17use crate::sqlite::connection::bind_collector::InternalSqliteBindValue;
18use crate::sqlite::connection::sqlite_value::OwnedSqliteValue;
19use alloc::boxed::Box;
20use alloc::string::ToString;
21
22pub(super) fn register<ArgsSqlType, RetSqlType, Args, Ret, F>(
23    conn: &RawConnection,
24    fn_name: &str,
25    behavior: SqliteFunctionBehavior,
26    mut f: F,
27) -> QueryResult<()>
28where
29    F: FnMut(&RawConnection, Args) -> Ret + core::panic::UnwindSafe + Send + 'static,
30    Args: FromSqlRow<ArgsSqlType, Sqlite> + StaticallySizedRow<ArgsSqlType, Sqlite>,
31    Ret: ToSql<RetSqlType, Sqlite>,
32    Sqlite: HasSqlType<RetSqlType>,
33{
34    let fields_needed = Args::FIELD_COUNT;
35    if fields_needed > 127 {
36        return Err(Error::DatabaseError(
37            DatabaseErrorKind::UnableToSendCommand,
38            Box::new("SQLite functions cannot take more than 127 parameters".to_string()),
39        ));
40    }
41
42    conn.register_sql_function(fn_name, fields_needed, behavior, move |conn, args| {
43        let args = build_sql_function_args::<ArgsSqlType, Args>(args)?;
44
45        Ok(f(conn, args))
46    })?;
47    Ok(())
48}
49
50pub(super) fn register_noargs<RetSqlType, Ret, F>(
51    conn: &RawConnection,
52    fn_name: &str,
53    behavior: SqliteFunctionBehavior,
54    mut f: F,
55) -> QueryResult<()>
56where
57    F: FnMut() -> Ret + core::panic::UnwindSafe + Send + 'static,
58    Ret: ToSql<RetSqlType, Sqlite>,
59    Sqlite: HasSqlType<RetSqlType>,
60{
61    conn.register_sql_function(fn_name, 0, behavior, move |_, _| Ok(f()))?;
62    Ok(())
63}
64
65pub(super) fn register_aggregate<ArgsSqlType, RetSqlType, Args, Ret, A>(
66    conn: &RawConnection,
67    fn_name: &str,
68    behavior: SqliteFunctionBehavior,
69) -> QueryResult<()>
70where
71    A: SqliteAggregateFunction<Args, Output = Ret> + 'static + Send + core::panic::UnwindSafe,
72    Args: FromSqlRow<ArgsSqlType, Sqlite> + StaticallySizedRow<ArgsSqlType, Sqlite>,
73    Ret: ToSql<RetSqlType, Sqlite>,
74    Sqlite: HasSqlType<RetSqlType>,
75{
76    let fields_needed = Args::FIELD_COUNT;
77    if fields_needed > 127 {
78        return Err(Error::DatabaseError(
79            DatabaseErrorKind::UnableToSendCommand,
80            Box::new("SQLite functions cannot take more than 127 parameters".to_string()),
81        ));
82    }
83
84    conn.register_aggregate_function::<ArgsSqlType, RetSqlType, Args, Ret, A>(
85        fn_name,
86        fields_needed,
87        behavior,
88    )?;
89
90    Ok(())
91}
92
93pub(super) fn build_sql_function_args<ArgsSqlType, Args>(
94    args: &mut [*mut ffi::sqlite3_value],
95) -> Result<Args, Error>
96where
97    Args: FromSqlRow<ArgsSqlType, Sqlite>,
98{
99    let row = FunctionRow::new(args);
100    Args::build_from_row(&row).map_err(Error::DeserializationError)
101}
102
103// clippy is wrong here, the let binding is required
104// for lifetime reasons
105#[allow(clippy::let_unit_value)]
106pub(super) fn process_sql_function_result<RetSqlType, Ret>(
107    result: &'_ Ret,
108) -> QueryResult<InternalSqliteBindValue<'_>>
109where
110    Ret: ToSql<RetSqlType, Sqlite>,
111    Sqlite: HasSqlType<RetSqlType>,
112{
113    let mut metadata_lookup = ();
114    let value = SqliteBindValue {
115        inner: InternalSqliteBindValue::Null,
116    };
117    let mut buf = Output::new(value, &mut metadata_lookup);
118    let is_null = result.to_sql(&mut buf).map_err(Error::SerializationError)?;
119
120    if let IsNull::Yes = is_null {
121        Ok(InternalSqliteBindValue::Null)
122    } else {
123        Ok(buf.into_inner().inner)
124    }
125}
126
127struct FunctionRow<'a> {
128    args: &'a [Option<OwnedSqliteValue>],
129    field_count: usize,
130}
131
132impl FunctionRow<'_> {
133    #[allow(unsafe_code)] // complicated ptr cast
134    fn new(args: &mut [*mut ffi::sqlite3_value]) -> Self {
135        let lengths = args.len();
136        let args = unsafe {
137            core::slice::from_raw_parts(
138                // This cast is safe because:
139                // * Casting from a pointer to an array to a pointer to the first array
140                // element is safe
141                // * Casting from a raw pointer to `NonNull<T>` is safe,
142                // because `NonNull` is #[repr(transparent)]
143                // * Casting from `NonNull<T>` to `OwnedSqliteValue` is safe,
144                // as the struct is `#[repr(transparent)]
145                // * Casting from `NonNull<T>` to `Option<NonNull<T>>` as the documentation
146                // states: "This is so that enums may use this forbidden value as a discriminant –
147                // Option<NonNull<T>> has the same size as *mut T"
148                // * The last point remains true for `OwnedSqliteValue` as `#[repr(transparent)]
149                // guarantees the same layout as the inner type
150                args as *mut [*mut ffi::sqlite3_value] as *mut ffi::sqlite3_value
151                    as *mut Option<OwnedSqliteValue>,
152                lengths,
153            )
154        };
155
156        Self {
157            field_count: lengths,
158            args,
159        }
160    }
161}
162
163impl RowSealed for FunctionRow<'_> {}
164
165impl<'a> Row<'a, Sqlite> for FunctionRow<'a> {
166    type Field<'f>
167        = FunctionArgument<'f>
168    where
169        'a: 'f,
170        Self: 'f;
171    type InnerPartialRow = Self;
172
173    fn field_count(&self) -> usize {
174        self.field_count
175    }
176
177    fn get<'b, I>(&'b self, idx: I) -> Option<Self::Field<'b>>
178    where
179        'a: 'b,
180        Self: crate::row::RowIndex<I>,
181    {
182        let col_idx = self.idx(idx)?;
183        Some(FunctionArgument {
184            args: self.args,
185            col_idx,
186        })
187    }
188
189    fn partial_row(&self, range: core::ops::Range<usize>) -> PartialRow<'_, Self::InnerPartialRow> {
190        PartialRow::new(self, range)
191    }
192}
193
194impl RowIndex<usize> for FunctionRow<'_> {
195    fn idx(&self, idx: usize) -> Option<usize> {
196        if idx < self.field_count() {
197            Some(idx)
198        } else {
199            None
200        }
201    }
202}
203
204impl<'a> RowIndex<&'a str> for FunctionRow<'_> {
205    fn idx(&self, _idx: &'a str) -> Option<usize> {
206        None
207    }
208}
209
210struct FunctionArgument<'a> {
211    args: &'a [Option<OwnedSqliteValue>],
212    col_idx: usize,
213}
214
215impl<'a> Field<'a, Sqlite> for FunctionArgument<'a> {
216    fn field_name(&self) -> Option<&str> {
217        None
218    }
219
220    fn is_null(&self) -> bool {
221        self.value().is_none()
222    }
223
224    fn value(&self) -> Option<<Sqlite as Backend>::RawValue<'_>> {
225        SqliteValue::from_function_row(self.args, self.col_idx)
226    }
227}