diesel/sqlite/connection/
functions.rs1#[cfg(not(all(target_family = "wasm", target_os = "unknown")))]
2extern crate libsqlite3_sys as ffi;
3
4#[cfg(all(target_family = "wasm", target_os = "unknown"))]
5use sqlite_wasm_rs as ffi;
6
7use super::raw::RawConnection;
8use super::{Sqlite, SqliteAggregateFunction, SqliteBindValue};
9use crate::backend::Backend;
10use crate::deserialize::{FromSqlRow, StaticallySizedRow};
11use crate::result::{DatabaseErrorKind, Error, QueryResult};
12use crate::row::{Field, PartialRow, Row, RowIndex, RowSealed};
13use crate::serialize::{IsNull, Output, ToSql};
14use crate::sql_types::HasSqlType;
15use crate::sqlite::SqliteValue;
16use crate::sqlite::connection::bind_collector::InternalSqliteBindValue;
17use crate::sqlite::connection::sqlite_value::OwnedSqliteValue;
18use alloc::boxed::Box;
19use alloc::string::ToString;
20
21pub(super) fn register<ArgsSqlType, RetSqlType, Args, Ret, F>(
22 conn: &RawConnection,
23 fn_name: &str,
24 deterministic: bool,
25 mut f: F,
26) -> QueryResult<()>
27where
28 F: FnMut(&RawConnection, Args) -> Ret + core::panic::UnwindSafe + Send + 'static,
29 Args: FromSqlRow<ArgsSqlType, Sqlite> + StaticallySizedRow<ArgsSqlType, Sqlite>,
30 Ret: ToSql<RetSqlType, Sqlite>,
31 Sqlite: HasSqlType<RetSqlType>,
32{
33 let fields_needed = Args::FIELD_COUNT;
34 if fields_needed > 127 {
35 return Err(Error::DatabaseError(
36 DatabaseErrorKind::UnableToSendCommand,
37 Box::new("SQLite functions cannot take more than 127 parameters".to_string()),
38 ));
39 }
40
41 conn.register_sql_function(fn_name, fields_needed, deterministic, move |conn, args| {
42 let args = build_sql_function_args::<ArgsSqlType, Args>(args)?;
43
44 Ok(f(conn, args))
45 })?;
46 Ok(())
47}
48
49pub(super) fn register_noargs<RetSqlType, Ret, F>(
50 conn: &RawConnection,
51 fn_name: &str,
52 deterministic: bool,
53 mut f: F,
54) -> QueryResult<()>
55where
56 F: FnMut() -> Ret + core::panic::UnwindSafe + Send + 'static,
57 Ret: ToSql<RetSqlType, Sqlite>,
58 Sqlite: HasSqlType<RetSqlType>,
59{
60 conn.register_sql_function(fn_name, 0, deterministic, move |_, _| Ok(f()))?;
61 Ok(())
62}
63
64pub(super) fn register_aggregate<ArgsSqlType, RetSqlType, Args, Ret, A>(
65 conn: &RawConnection,
66 fn_name: &str,
67) -> QueryResult<()>
68where
69 A: SqliteAggregateFunction<Args, Output = Ret> + 'static + Send + core::panic::UnwindSafe,
70 Args: FromSqlRow<ArgsSqlType, Sqlite> + StaticallySizedRow<ArgsSqlType, Sqlite>,
71 Ret: ToSql<RetSqlType, Sqlite>,
72 Sqlite: HasSqlType<RetSqlType>,
73{
74 let fields_needed = Args::FIELD_COUNT;
75 if fields_needed > 127 {
76 return Err(Error::DatabaseError(
77 DatabaseErrorKind::UnableToSendCommand,
78 Box::new("SQLite functions cannot take more than 127 parameters".to_string()),
79 ));
80 }
81
82 conn.register_aggregate_function::<ArgsSqlType, RetSqlType, Args, Ret, A>(
83 fn_name,
84 fields_needed,
85 )?;
86
87 Ok(())
88}
89
90pub(super) fn build_sql_function_args<ArgsSqlType, Args>(
91 args: &mut [*mut ffi::sqlite3_value],
92) -> Result<Args, Error>
93where
94 Args: FromSqlRow<ArgsSqlType, Sqlite>,
95{
96 let row = FunctionRow::new(args);
97 Args::build_from_row(&row).map_err(Error::DeserializationError)
98}
99
100#[allow(clippy::let_unit_value)]
103pub(super) fn process_sql_function_result<RetSqlType, Ret>(
104 result: &'_ Ret,
105) -> QueryResult<InternalSqliteBindValue<'_>>
106where
107 Ret: ToSql<RetSqlType, Sqlite>,
108 Sqlite: HasSqlType<RetSqlType>,
109{
110 let mut metadata_lookup = ();
111 let value = SqliteBindValue {
112 inner: InternalSqliteBindValue::Null,
113 };
114 let mut buf = Output::new(value, &mut metadata_lookup);
115 let is_null = result.to_sql(&mut buf).map_err(Error::SerializationError)?;
116
117 if let IsNull::Yes = is_null {
118 Ok(InternalSqliteBindValue::Null)
119 } else {
120 Ok(buf.into_inner().inner)
121 }
122}
123
124struct FunctionRow<'a> {
125 args: &'a [Option<OwnedSqliteValue>],
126 field_count: usize,
127}
128
129impl FunctionRow<'_> {
130 #[allow(unsafe_code)] fn new(args: &mut [*mut ffi::sqlite3_value]) -> Self {
132 let lengths = args.len();
133 let args = unsafe {
134 core::slice::from_raw_parts(
135 args as *mut [*mut ffi::sqlite3_value] as *mut ffi::sqlite3_value
148 as *mut Option<OwnedSqliteValue>,
149 lengths,
150 )
151 };
152
153 Self {
154 field_count: lengths,
155 args,
156 }
157 }
158}
159
160impl RowSealed for FunctionRow<'_> {}
161
162impl<'a> Row<'a, Sqlite> for FunctionRow<'a> {
163 type Field<'f>
164 = FunctionArgument<'f>
165 where
166 'a: 'f,
167 Self: 'f;
168 type InnerPartialRow = Self;
169
170 fn field_count(&self) -> usize {
171 self.field_count
172 }
173
174 fn get<'b, I>(&'b self, idx: I) -> Option<Self::Field<'b>>
175 where
176 'a: 'b,
177 Self: crate::row::RowIndex<I>,
178 {
179 let col_idx = self.idx(idx)?;
180 Some(FunctionArgument {
181 args: self.args,
182 col_idx,
183 })
184 }
185
186 fn partial_row(&self, range: core::ops::Range<usize>) -> PartialRow<'_, Self::InnerPartialRow> {
187 PartialRow::new(self, range)
188 }
189}
190
191impl RowIndex<usize> for FunctionRow<'_> {
192 fn idx(&self, idx: usize) -> Option<usize> {
193 if idx < self.field_count() {
194 Some(idx)
195 } else {
196 None
197 }
198 }
199}
200
201impl<'a> RowIndex<&'a str> for FunctionRow<'_> {
202 fn idx(&self, _idx: &'a str) -> Option<usize> {
203 None
204 }
205}
206
207struct FunctionArgument<'a> {
208 args: &'a [Option<OwnedSqliteValue>],
209 col_idx: usize,
210}
211
212impl<'a> Field<'a, Sqlite> for FunctionArgument<'a> {
213 fn field_name(&self) -> Option<&str> {
214 None
215 }
216
217 fn is_null(&self) -> bool {
218 self.value().is_none()
219 }
220
221 fn value(&self) -> Option<<Sqlite as Backend>::RawValue<'_>> {
222 SqliteValue::from_function_row(self.args, self.col_idx)
223 }
224}