Skip to main content

diesel/sqlite/connection/
functions.rs

1#[cfg(not(all(target_family = "wasm", target_os = "unknown")))]
2extern crate libsqlite3_sys as ffi;
3
4#[cfg(all(target_family = "wasm", target_os = "unknown"))]
5use sqlite_wasm_rs as ffi;
6
7use super::raw::RawConnection;
8use super::{Sqlite, SqliteAggregateFunction, SqliteBindValue};
9use crate::backend::Backend;
10use crate::deserialize::{FromSqlRow, StaticallySizedRow};
11use crate::result::{DatabaseErrorKind, Error, QueryResult};
12use crate::row::{Field, PartialRow, Row, RowIndex, RowSealed};
13use crate::serialize::{IsNull, Output, ToSql};
14use crate::sql_types::HasSqlType;
15use crate::sqlite::SqliteValue;
16use crate::sqlite::connection::bind_collector::InternalSqliteBindValue;
17use crate::sqlite::connection::sqlite_value::OwnedSqliteValue;
18use alloc::boxed::Box;
19use alloc::string::ToString;
20
21pub(super) fn register<ArgsSqlType, RetSqlType, Args, Ret, F>(
22    conn: &RawConnection,
23    fn_name: &str,
24    deterministic: bool,
25    mut f: F,
26) -> QueryResult<()>
27where
28    F: FnMut(&RawConnection, Args) -> Ret + core::panic::UnwindSafe + Send + 'static,
29    Args: FromSqlRow<ArgsSqlType, Sqlite> + StaticallySizedRow<ArgsSqlType, Sqlite>,
30    Ret: ToSql<RetSqlType, Sqlite>,
31    Sqlite: HasSqlType<RetSqlType>,
32{
33    let fields_needed = Args::FIELD_COUNT;
34    if fields_needed > 127 {
35        return Err(Error::DatabaseError(
36            DatabaseErrorKind::UnableToSendCommand,
37            Box::new("SQLite functions cannot take more than 127 parameters".to_string()),
38        ));
39    }
40
41    conn.register_sql_function(fn_name, fields_needed, deterministic, move |conn, args| {
42        let args = build_sql_function_args::<ArgsSqlType, Args>(args)?;
43
44        Ok(f(conn, args))
45    })?;
46    Ok(())
47}
48
49pub(super) fn register_noargs<RetSqlType, Ret, F>(
50    conn: &RawConnection,
51    fn_name: &str,
52    deterministic: bool,
53    mut f: F,
54) -> QueryResult<()>
55where
56    F: FnMut() -> Ret + core::panic::UnwindSafe + Send + 'static,
57    Ret: ToSql<RetSqlType, Sqlite>,
58    Sqlite: HasSqlType<RetSqlType>,
59{
60    conn.register_sql_function(fn_name, 0, deterministic, move |_, _| Ok(f()))?;
61    Ok(())
62}
63
64pub(super) fn register_aggregate<ArgsSqlType, RetSqlType, Args, Ret, A>(
65    conn: &RawConnection,
66    fn_name: &str,
67) -> QueryResult<()>
68where
69    A: SqliteAggregateFunction<Args, Output = Ret> + 'static + Send + core::panic::UnwindSafe,
70    Args: FromSqlRow<ArgsSqlType, Sqlite> + StaticallySizedRow<ArgsSqlType, Sqlite>,
71    Ret: ToSql<RetSqlType, Sqlite>,
72    Sqlite: HasSqlType<RetSqlType>,
73{
74    let fields_needed = Args::FIELD_COUNT;
75    if fields_needed > 127 {
76        return Err(Error::DatabaseError(
77            DatabaseErrorKind::UnableToSendCommand,
78            Box::new("SQLite functions cannot take more than 127 parameters".to_string()),
79        ));
80    }
81
82    conn.register_aggregate_function::<ArgsSqlType, RetSqlType, Args, Ret, A>(
83        fn_name,
84        fields_needed,
85    )?;
86
87    Ok(())
88}
89
90pub(super) fn build_sql_function_args<ArgsSqlType, Args>(
91    args: &mut [*mut ffi::sqlite3_value],
92) -> Result<Args, Error>
93where
94    Args: FromSqlRow<ArgsSqlType, Sqlite>,
95{
96    let row = FunctionRow::new(args);
97    Args::build_from_row(&row).map_err(Error::DeserializationError)
98}
99
100// clippy is wrong here, the let binding is required
101// for lifetime reasons
102#[allow(clippy::let_unit_value)]
103pub(super) fn process_sql_function_result<RetSqlType, Ret>(
104    result: &'_ Ret,
105) -> QueryResult<InternalSqliteBindValue<'_>>
106where
107    Ret: ToSql<RetSqlType, Sqlite>,
108    Sqlite: HasSqlType<RetSqlType>,
109{
110    let mut metadata_lookup = ();
111    let value = SqliteBindValue {
112        inner: InternalSqliteBindValue::Null,
113    };
114    let mut buf = Output::new(value, &mut metadata_lookup);
115    let is_null = result.to_sql(&mut buf).map_err(Error::SerializationError)?;
116
117    if let IsNull::Yes = is_null {
118        Ok(InternalSqliteBindValue::Null)
119    } else {
120        Ok(buf.into_inner().inner)
121    }
122}
123
124struct FunctionRow<'a> {
125    args: &'a [Option<OwnedSqliteValue>],
126    field_count: usize,
127}
128
129impl FunctionRow<'_> {
130    #[allow(unsafe_code)] // complicated ptr cast
131    fn new(args: &mut [*mut ffi::sqlite3_value]) -> Self {
132        let lengths = args.len();
133        let args = unsafe {
134            core::slice::from_raw_parts(
135                // This cast is safe because:
136                // * Casting from a pointer to an array to a pointer to the first array
137                // element is safe
138                // * Casting from a raw pointer to `NonNull<T>` is safe,
139                // because `NonNull` is #[repr(transparent)]
140                // * Casting from `NonNull<T>` to `OwnedSqliteValue` is safe,
141                // as the struct is `#[repr(transparent)]
142                // * Casting from `NonNull<T>` to `Option<NonNull<T>>` as the documentation
143                // states: "This is so that enums may use this forbidden value as a discriminant –
144                // Option<NonNull<T>> has the same size as *mut T"
145                // * The last point remains true for `OwnedSqliteValue` as `#[repr(transparent)]
146                // guarantees the same layout as the inner type
147                args as *mut [*mut ffi::sqlite3_value] as *mut ffi::sqlite3_value
148                    as *mut Option<OwnedSqliteValue>,
149                lengths,
150            )
151        };
152
153        Self {
154            field_count: lengths,
155            args,
156        }
157    }
158}
159
160impl RowSealed for FunctionRow<'_> {}
161
162impl<'a> Row<'a, Sqlite> for FunctionRow<'a> {
163    type Field<'f>
164        = FunctionArgument<'f>
165    where
166        'a: 'f,
167        Self: 'f;
168    type InnerPartialRow = Self;
169
170    fn field_count(&self) -> usize {
171        self.field_count
172    }
173
174    fn get<'b, I>(&'b self, idx: I) -> Option<Self::Field<'b>>
175    where
176        'a: 'b,
177        Self: crate::row::RowIndex<I>,
178    {
179        let col_idx = self.idx(idx)?;
180        Some(FunctionArgument {
181            args: self.args,
182            col_idx,
183        })
184    }
185
186    fn partial_row(&self, range: core::ops::Range<usize>) -> PartialRow<'_, Self::InnerPartialRow> {
187        PartialRow::new(self, range)
188    }
189}
190
191impl RowIndex<usize> for FunctionRow<'_> {
192    fn idx(&self, idx: usize) -> Option<usize> {
193        if idx < self.field_count() {
194            Some(idx)
195        } else {
196            None
197        }
198    }
199}
200
201impl<'a> RowIndex<&'a str> for FunctionRow<'_> {
202    fn idx(&self, _idx: &'a str) -> Option<usize> {
203        None
204    }
205}
206
207struct FunctionArgument<'a> {
208    args: &'a [Option<OwnedSqliteValue>],
209    col_idx: usize,
210}
211
212impl<'a> Field<'a, Sqlite> for FunctionArgument<'a> {
213    fn field_name(&self) -> Option<&str> {
214        None
215    }
216
217    fn is_null(&self) -> bool {
218        self.value().is_none()
219    }
220
221    fn value(&self) -> Option<<Sqlite as Backend>::RawValue<'_>> {
222        SqliteValue::from_function_row(self.args, self.col_idx)
223    }
224}