diesel/sqlite/connection/
functions.rs1#[cfg(not(all(target_family = "wasm", target_os = "unknown")))]
2extern crate libsqlite3_sys as ffi;
3
4#[cfg(all(target_family = "wasm", target_os = "unknown"))]
5use sqlite_wasm_rs as ffi;
6
7use super::raw::RawConnection;
8use super::{Sqlite, SqliteAggregateFunction, SqliteBindValue};
9use crate::backend::Backend;
10use crate::deserialize::{FromSqlRow, StaticallySizedRow};
11use crate::result::{DatabaseErrorKind, Error, QueryResult};
12use crate::row::{Field, PartialRow, Row, RowIndex, RowSealed};
13use crate::serialize::{IsNull, Output, ToSql};
14use crate::sql_types::HasSqlType;
15use crate::sqlite::connection::bind_collector::InternalSqliteBindValue;
16use crate::sqlite::connection::sqlite_value::OwnedSqliteValue;
17use crate::sqlite::SqliteValue;
18
19pub(super) fn register<ArgsSqlType, RetSqlType, Args, Ret, F>(
20 conn: &RawConnection,
21 fn_name: &str,
22 deterministic: bool,
23 mut f: F,
24) -> QueryResult<()>
25where
26 F: FnMut(&RawConnection, Args) -> Ret + std::panic::UnwindSafe + Send + 'static,
27 Args: FromSqlRow<ArgsSqlType, Sqlite> + StaticallySizedRow<ArgsSqlType, Sqlite>,
28 Ret: ToSql<RetSqlType, Sqlite>,
29 Sqlite: HasSqlType<RetSqlType>,
30{
31 let fields_needed = Args::FIELD_COUNT;
32 if fields_needed > 127 {
33 return Err(Error::DatabaseError(
34 DatabaseErrorKind::UnableToSendCommand,
35 Box::new("SQLite functions cannot take more than 127 parameters".to_string()),
36 ));
37 }
38
39 conn.register_sql_function(fn_name, fields_needed, deterministic, move |conn, args| {
40 let args = build_sql_function_args::<ArgsSqlType, Args>(args)?;
41
42 Ok(f(conn, args))
43 })?;
44 Ok(())
45}
46
47pub(super) fn register_noargs<RetSqlType, Ret, F>(
48 conn: &RawConnection,
49 fn_name: &str,
50 deterministic: bool,
51 mut f: F,
52) -> QueryResult<()>
53where
54 F: FnMut() -> Ret + std::panic::UnwindSafe + Send + 'static,
55 Ret: ToSql<RetSqlType, Sqlite>,
56 Sqlite: HasSqlType<RetSqlType>,
57{
58 conn.register_sql_function(fn_name, 0, deterministic, move |_, _| Ok(f()))?;
59 Ok(())
60}
61
62pub(super) fn register_aggregate<ArgsSqlType, RetSqlType, Args, Ret, A>(
63 conn: &RawConnection,
64 fn_name: &str,
65) -> QueryResult<()>
66where
67 A: SqliteAggregateFunction<Args, Output = Ret> + 'static + Send + std::panic::UnwindSafe,
68 Args: FromSqlRow<ArgsSqlType, Sqlite> + StaticallySizedRow<ArgsSqlType, Sqlite>,
69 Ret: ToSql<RetSqlType, Sqlite>,
70 Sqlite: HasSqlType<RetSqlType>,
71{
72 let fields_needed = Args::FIELD_COUNT;
73 if fields_needed > 127 {
74 return Err(Error::DatabaseError(
75 DatabaseErrorKind::UnableToSendCommand,
76 Box::new("SQLite functions cannot take more than 127 parameters".to_string()),
77 ));
78 }
79
80 conn.register_aggregate_function::<ArgsSqlType, RetSqlType, Args, Ret, A>(
81 fn_name,
82 fields_needed,
83 )?;
84
85 Ok(())
86}
87
88pub(super) fn build_sql_function_args<ArgsSqlType, Args>(
89 args: &mut [*mut ffi::sqlite3_value],
90) -> Result<Args, Error>
91where
92 Args: FromSqlRow<ArgsSqlType, Sqlite>,
93{
94 let row = FunctionRow::new(args);
95 Args::build_from_row(&row).map_err(Error::DeserializationError)
96}
97
98#[allow(clippy::let_unit_value)]
101pub(super) fn process_sql_function_result<RetSqlType, Ret>(
102 result: &'_ Ret,
103) -> QueryResult<InternalSqliteBindValue<'_>>
104where
105 Ret: ToSql<RetSqlType, Sqlite>,
106 Sqlite: HasSqlType<RetSqlType>,
107{
108 let mut metadata_lookup = ();
109 let value = SqliteBindValue {
110 inner: InternalSqliteBindValue::Null,
111 };
112 let mut buf = Output::new(value, &mut metadata_lookup);
113 let is_null = result.to_sql(&mut buf).map_err(Error::SerializationError)?;
114
115 if let IsNull::Yes = is_null {
116 Ok(InternalSqliteBindValue::Null)
117 } else {
118 Ok(buf.into_inner().inner)
119 }
120}
121
122struct FunctionRow<'a> {
123 args: &'a [Option<OwnedSqliteValue>],
124 field_count: usize,
125}
126
127impl FunctionRow<'_> {
128 #[allow(unsafe_code)] fn new(args: &mut [*mut ffi::sqlite3_value]) -> Self {
130 let lengths = args.len();
131 let args = unsafe {
132 core::slice::from_raw_parts(
133 args as *mut [*mut ffi::sqlite3_value] as *mut ffi::sqlite3_value
146 as *mut Option<OwnedSqliteValue>,
147 lengths,
148 )
149 };
150
151 Self {
152 field_count: lengths,
153 args,
154 }
155 }
156}
157
158impl RowSealed for FunctionRow<'_> {}
159
160impl<'a> Row<'a, Sqlite> for FunctionRow<'a> {
161 type Field<'f>
162 = FunctionArgument<'f>
163 where
164 'a: 'f,
165 Self: 'f;
166 type InnerPartialRow = Self;
167
168 fn field_count(&self) -> usize {
169 self.field_count
170 }
171
172 fn get<'b, I>(&'b self, idx: I) -> Option<Self::Field<'b>>
173 where
174 'a: 'b,
175 Self: crate::row::RowIndex<I>,
176 {
177 let col_idx = self.idx(idx)?;
178 Some(FunctionArgument {
179 args: self.args,
180 col_idx,
181 })
182 }
183
184 fn partial_row(&self, range: std::ops::Range<usize>) -> PartialRow<'_, Self::InnerPartialRow> {
185 PartialRow::new(self, range)
186 }
187}
188
189impl RowIndex<usize> for FunctionRow<'_> {
190 fn idx(&self, idx: usize) -> Option<usize> {
191 if idx < self.field_count() {
192 Some(idx)
193 } else {
194 None
195 }
196 }
197}
198
199impl<'a> RowIndex<&'a str> for FunctionRow<'_> {
200 fn idx(&self, _idx: &'a str) -> Option<usize> {
201 None
202 }
203}
204
205struct FunctionArgument<'a> {
206 args: &'a [Option<OwnedSqliteValue>],
207 col_idx: usize,
208}
209
210impl<'a> Field<'a, Sqlite> for FunctionArgument<'a> {
211 fn field_name(&self) -> Option<&str> {
212 None
213 }
214
215 fn is_null(&self) -> bool {
216 self.value().is_none()
217 }
218
219 fn value(&self) -> Option<<Sqlite as Backend>::RawValue<'_>> {
220 SqliteValue::from_function_row(self.args, self.col_idx)
221 }
222}