diesel/sqlite/connection/
functions.rs1#[cfg(not(all(target_family = "wasm", target_os = "unknown")))]
2extern crate libsqlite3_sys as ffi;
3
4#[cfg(all(target_family = "wasm", target_os = "unknown"))]
5use sqlite_wasm_rs as ffi;
6
7use super::raw::RawConnection;
8use super::{Sqlite, SqliteAggregateFunction, SqliteBindValue};
9use crate::backend::Backend;
10use crate::deserialize::{FromSqlRow, StaticallySizedRow};
11use crate::result::{DatabaseErrorKind, Error, QueryResult};
12use crate::row::{Field, PartialRow, Row, RowIndex, RowSealed};
13use crate::serialize::{IsNull, Output, ToSql};
14use crate::sql_types::HasSqlType;
15use crate::sqlite::SqliteFunctionBehavior;
16use crate::sqlite::SqliteValue;
17use crate::sqlite::connection::bind_collector::InternalSqliteBindValue;
18use crate::sqlite::connection::sqlite_value::OwnedSqliteValue;
19use alloc::boxed::Box;
20use alloc::string::ToString;
21
22pub(super) fn register<ArgsSqlType, RetSqlType, Args, Ret, F>(
23 conn: &RawConnection,
24 fn_name: &str,
25 behavior: SqliteFunctionBehavior,
26 mut f: F,
27) -> QueryResult<()>
28where
29 F: FnMut(&RawConnection, Args) -> Ret + core::panic::UnwindSafe + Send + 'static,
30 Args: FromSqlRow<ArgsSqlType, Sqlite> + StaticallySizedRow<ArgsSqlType, Sqlite>,
31 Ret: ToSql<RetSqlType, Sqlite>,
32 Sqlite: HasSqlType<RetSqlType>,
33{
34 let fields_needed = Args::FIELD_COUNT;
35 if fields_needed > 127 {
36 return Err(Error::DatabaseError(
37 DatabaseErrorKind::UnableToSendCommand,
38 Box::new("SQLite functions cannot take more than 127 parameters".to_string()),
39 ));
40 }
41
42 conn.register_sql_function(fn_name, fields_needed, behavior, move |conn, args| {
43 let args = build_sql_function_args::<ArgsSqlType, Args>(args)?;
44
45 Ok(f(conn, args))
46 })?;
47 Ok(())
48}
49
50pub(super) fn register_noargs<RetSqlType, Ret, F>(
51 conn: &RawConnection,
52 fn_name: &str,
53 behavior: SqliteFunctionBehavior,
54 mut f: F,
55) -> QueryResult<()>
56where
57 F: FnMut() -> Ret + core::panic::UnwindSafe + Send + 'static,
58 Ret: ToSql<RetSqlType, Sqlite>,
59 Sqlite: HasSqlType<RetSqlType>,
60{
61 conn.register_sql_function(fn_name, 0, behavior, move |_, _| Ok(f()))?;
62 Ok(())
63}
64
65pub(super) fn register_aggregate<ArgsSqlType, RetSqlType, Args, Ret, A>(
66 conn: &RawConnection,
67 fn_name: &str,
68 behavior: SqliteFunctionBehavior,
69) -> QueryResult<()>
70where
71 A: SqliteAggregateFunction<Args, Output = Ret> + 'static + Send + core::panic::UnwindSafe,
72 Args: FromSqlRow<ArgsSqlType, Sqlite> + StaticallySizedRow<ArgsSqlType, Sqlite>,
73 Ret: ToSql<RetSqlType, Sqlite>,
74 Sqlite: HasSqlType<RetSqlType>,
75{
76 let fields_needed = Args::FIELD_COUNT;
77 if fields_needed > 127 {
78 return Err(Error::DatabaseError(
79 DatabaseErrorKind::UnableToSendCommand,
80 Box::new("SQLite functions cannot take more than 127 parameters".to_string()),
81 ));
82 }
83
84 conn.register_aggregate_function::<ArgsSqlType, RetSqlType, Args, Ret, A>(
85 fn_name,
86 fields_needed,
87 behavior,
88 )?;
89
90 Ok(())
91}
92
93pub(super) fn build_sql_function_args<ArgsSqlType, Args>(
94 args: &mut [*mut ffi::sqlite3_value],
95) -> Result<Args, Error>
96where
97 Args: FromSqlRow<ArgsSqlType, Sqlite>,
98{
99 let row = FunctionRow::new(args);
100 Args::build_from_row(&row).map_err(Error::DeserializationError)
101}
102
103#[allow(clippy::let_unit_value)]
106pub(super) fn process_sql_function_result<RetSqlType, Ret>(
107 result: &'_ Ret,
108) -> QueryResult<InternalSqliteBindValue<'_>>
109where
110 Ret: ToSql<RetSqlType, Sqlite>,
111 Sqlite: HasSqlType<RetSqlType>,
112{
113 let mut metadata_lookup = ();
114 let value = SqliteBindValue {
115 inner: InternalSqliteBindValue::Null,
116 };
117 let mut buf = Output::new(value, &mut metadata_lookup);
118 let is_null = result.to_sql(&mut buf).map_err(Error::SerializationError)?;
119
120 if let IsNull::Yes = is_null {
121 Ok(InternalSqliteBindValue::Null)
122 } else {
123 Ok(buf.into_inner().inner)
124 }
125}
126
127struct FunctionRow<'a> {
128 args: &'a [Option<OwnedSqliteValue>],
129 field_count: usize,
130}
131
132impl FunctionRow<'_> {
133 #[allow(unsafe_code)] fn new(args: &mut [*mut ffi::sqlite3_value]) -> Self {
135 let lengths = args.len();
136 let args = unsafe {
137 core::slice::from_raw_parts(
138 args as *mut [*mut ffi::sqlite3_value] as *mut ffi::sqlite3_value
151 as *mut Option<OwnedSqliteValue>,
152 lengths,
153 )
154 };
155
156 Self {
157 field_count: lengths,
158 args,
159 }
160 }
161}
162
163impl RowSealed for FunctionRow<'_> {}
164
165impl<'a> Row<'a, Sqlite> for FunctionRow<'a> {
166 type Field<'f>
167 = FunctionArgument<'f>
168 where
169 'a: 'f,
170 Self: 'f;
171 type InnerPartialRow = Self;
172
173 fn field_count(&self) -> usize {
174 self.field_count
175 }
176
177 fn get<'b, I>(&'b self, idx: I) -> Option<Self::Field<'b>>
178 where
179 'a: 'b,
180 Self: crate::row::RowIndex<I>,
181 {
182 let col_idx = self.idx(idx)?;
183 Some(FunctionArgument {
184 args: self.args,
185 col_idx,
186 })
187 }
188
189 fn partial_row(&self, range: core::ops::Range<usize>) -> PartialRow<'_, Self::InnerPartialRow> {
190 PartialRow::new(self, range)
191 }
192}
193
194impl RowIndex<usize> for FunctionRow<'_> {
195 fn idx(&self, idx: usize) -> Option<usize> {
196 if idx < self.field_count() {
197 Some(idx)
198 } else {
199 None
200 }
201 }
202}
203
204impl<'a> RowIndex<&'a str> for FunctionRow<'_> {
205 fn idx(&self, _idx: &'a str) -> Option<usize> {
206 None
207 }
208}
209
210struct FunctionArgument<'a> {
211 args: &'a [Option<OwnedSqliteValue>],
212 col_idx: usize,
213}
214
215impl<'a> Field<'a, Sqlite> for FunctionArgument<'a> {
216 fn field_name(&self) -> Option<&str> {
217 None
218 }
219
220 fn is_null(&self) -> bool {
221 self.value().is_none()
222 }
223
224 fn value(&self) -> Option<<Sqlite as Backend>::RawValue<'_>> {
225 SqliteValue::from_function_row(self.args, self.col_idx)
226 }
227}