Skip to main content

diesel/sqlite/connection/
functions.rs

1#[cfg(not(all(target_family = "wasm", target_os = "unknown")))]
2extern crate libsqlite3_sys as ffi;
3
4#[cfg(all(target_family = "wasm", target_os = "unknown"))]
5use sqlite_wasm_rs as ffi;
6
7use super::raw::RawConnection;
8use super::{Sqlite, SqliteAggregateFunction, SqliteBindValue};
9use crate::backend::Backend;
10use crate::deserialize::{FromSqlRow, StaticallySizedRow};
11use crate::result::{DatabaseErrorKind, Error, QueryResult};
12use crate::row::{Field, PartialRow, Row, RowIndex, RowSealed};
13use crate::serialize::{IsNull, Output, ToSql};
14use crate::sql_types::HasSqlType;
15use crate::sqlite::connection::bind_collector::InternalSqliteBindValue;
16use crate::sqlite::connection::sqlite_value::OwnedSqliteValue;
17use crate::sqlite::SqliteValue;
18
19pub(super) fn register<ArgsSqlType, RetSqlType, Args, Ret, F>(
20    conn: &RawConnection,
21    fn_name: &str,
22    deterministic: bool,
23    mut f: F,
24) -> QueryResult<()>
25where
26    F: FnMut(&RawConnection, Args) -> Ret + std::panic::UnwindSafe + Send + 'static,
27    Args: FromSqlRow<ArgsSqlType, Sqlite> + StaticallySizedRow<ArgsSqlType, Sqlite>,
28    Ret: ToSql<RetSqlType, Sqlite>,
29    Sqlite: HasSqlType<RetSqlType>,
30{
31    let fields_needed = Args::FIELD_COUNT;
32    if fields_needed > 127 {
33        return Err(Error::DatabaseError(
34            DatabaseErrorKind::UnableToSendCommand,
35            Box::new("SQLite functions cannot take more than 127 parameters".to_string()),
36        ));
37    }
38
39    conn.register_sql_function(fn_name, fields_needed, deterministic, move |conn, args| {
40        let args = build_sql_function_args::<ArgsSqlType, Args>(args)?;
41
42        Ok(f(conn, args))
43    })?;
44    Ok(())
45}
46
47pub(super) fn register_noargs<RetSqlType, Ret, F>(
48    conn: &RawConnection,
49    fn_name: &str,
50    deterministic: bool,
51    mut f: F,
52) -> QueryResult<()>
53where
54    F: FnMut() -> Ret + std::panic::UnwindSafe + Send + 'static,
55    Ret: ToSql<RetSqlType, Sqlite>,
56    Sqlite: HasSqlType<RetSqlType>,
57{
58    conn.register_sql_function(fn_name, 0, deterministic, move |_, _| Ok(f()))?;
59    Ok(())
60}
61
62pub(super) fn register_aggregate<ArgsSqlType, RetSqlType, Args, Ret, A>(
63    conn: &RawConnection,
64    fn_name: &str,
65) -> QueryResult<()>
66where
67    A: SqliteAggregateFunction<Args, Output = Ret> + 'static + Send + std::panic::UnwindSafe,
68    Args: FromSqlRow<ArgsSqlType, Sqlite> + StaticallySizedRow<ArgsSqlType, Sqlite>,
69    Ret: ToSql<RetSqlType, Sqlite>,
70    Sqlite: HasSqlType<RetSqlType>,
71{
72    let fields_needed = Args::FIELD_COUNT;
73    if fields_needed > 127 {
74        return Err(Error::DatabaseError(
75            DatabaseErrorKind::UnableToSendCommand,
76            Box::new("SQLite functions cannot take more than 127 parameters".to_string()),
77        ));
78    }
79
80    conn.register_aggregate_function::<ArgsSqlType, RetSqlType, Args, Ret, A>(
81        fn_name,
82        fields_needed,
83    )?;
84
85    Ok(())
86}
87
88pub(super) fn build_sql_function_args<ArgsSqlType, Args>(
89    args: &mut [*mut ffi::sqlite3_value],
90) -> Result<Args, Error>
91where
92    Args: FromSqlRow<ArgsSqlType, Sqlite>,
93{
94    let row = FunctionRow::new(args);
95    Args::build_from_row(&row).map_err(Error::DeserializationError)
96}
97
98// clippy is wrong here, the let binding is required
99// for lifetime reasons
100#[allow(clippy::let_unit_value)]
101pub(super) fn process_sql_function_result<RetSqlType, Ret>(
102    result: &'_ Ret,
103) -> QueryResult<InternalSqliteBindValue<'_>>
104where
105    Ret: ToSql<RetSqlType, Sqlite>,
106    Sqlite: HasSqlType<RetSqlType>,
107{
108    let mut metadata_lookup = ();
109    let value = SqliteBindValue {
110        inner: InternalSqliteBindValue::Null,
111    };
112    let mut buf = Output::new(value, &mut metadata_lookup);
113    let is_null = result.to_sql(&mut buf).map_err(Error::SerializationError)?;
114
115    if let IsNull::Yes = is_null {
116        Ok(InternalSqliteBindValue::Null)
117    } else {
118        Ok(buf.into_inner().inner)
119    }
120}
121
122struct FunctionRow<'a> {
123    args: &'a [Option<OwnedSqliteValue>],
124    field_count: usize,
125}
126
127impl FunctionRow<'_> {
128    #[allow(unsafe_code)] // complicated ptr cast
129    fn new(args: &mut [*mut ffi::sqlite3_value]) -> Self {
130        let lengths = args.len();
131        let args = unsafe {
132            core::slice::from_raw_parts(
133                // This cast is safe because:
134                // * Casting from a pointer to an array to a pointer to the first array
135                // element is safe
136                // * Casting from a raw pointer to `NonNull<T>` is safe,
137                // because `NonNull` is #[repr(transparent)]
138                // * Casting from `NonNull<T>` to `OwnedSqliteValue` is safe,
139                // as the struct is `#[repr(transparent)]
140                // * Casting from `NonNull<T>` to `Option<NonNull<T>>` as the documentation
141                // states: "This is so that enums may use this forbidden value as a discriminant –
142                // Option<NonNull<T>> has the same size as *mut T"
143                // * The last point remains true for `OwnedSqliteValue` as `#[repr(transparent)]
144                // guarantees the same layout as the inner type
145                args as *mut [*mut ffi::sqlite3_value] as *mut ffi::sqlite3_value
146                    as *mut Option<OwnedSqliteValue>,
147                lengths,
148            )
149        };
150
151        Self {
152            field_count: lengths,
153            args,
154        }
155    }
156}
157
158impl RowSealed for FunctionRow<'_> {}
159
160impl<'a> Row<'a, Sqlite> for FunctionRow<'a> {
161    type Field<'f>
162        = FunctionArgument<'f>
163    where
164        'a: 'f,
165        Self: 'f;
166    type InnerPartialRow = Self;
167
168    fn field_count(&self) -> usize {
169        self.field_count
170    }
171
172    fn get<'b, I>(&'b self, idx: I) -> Option<Self::Field<'b>>
173    where
174        'a: 'b,
175        Self: crate::row::RowIndex<I>,
176    {
177        let col_idx = self.idx(idx)?;
178        Some(FunctionArgument {
179            args: self.args,
180            col_idx,
181        })
182    }
183
184    fn partial_row(&self, range: std::ops::Range<usize>) -> PartialRow<'_, Self::InnerPartialRow> {
185        PartialRow::new(self, range)
186    }
187}
188
189impl RowIndex<usize> for FunctionRow<'_> {
190    fn idx(&self, idx: usize) -> Option<usize> {
191        if idx < self.field_count() {
192            Some(idx)
193        } else {
194            None
195        }
196    }
197}
198
199impl<'a> RowIndex<&'a str> for FunctionRow<'_> {
200    fn idx(&self, _idx: &'a str) -> Option<usize> {
201        None
202    }
203}
204
205struct FunctionArgument<'a> {
206    args: &'a [Option<OwnedSqliteValue>],
207    col_idx: usize,
208}
209
210impl<'a> Field<'a, Sqlite> for FunctionArgument<'a> {
211    fn field_name(&self) -> Option<&str> {
212        None
213    }
214
215    fn is_null(&self) -> bool {
216        self.value().is_none()
217    }
218
219    fn value(&self) -> Option<<Sqlite as Backend>::RawValue<'_>> {
220        SqliteValue::from_function_row(self.args, self.col_idx)
221    }
222}