diesel/sqlite/connection/
functions.rs
1#[cfg(not(all(target_family = "wasm", target_os = "unknown")))]
2extern crate libsqlite3_sys as ffi;
3
4#[cfg(all(target_family = "wasm", target_os = "unknown"))]
5use sqlite_wasm_rs::export as ffi;
6
7use super::raw::RawConnection;
8use super::row::PrivateSqliteRow;
9use super::{Sqlite, SqliteAggregateFunction, SqliteBindValue};
10use crate::backend::Backend;
11use crate::deserialize::{FromSqlRow, StaticallySizedRow};
12use crate::result::{DatabaseErrorKind, Error, QueryResult};
13use crate::row::{Field, PartialRow, Row, RowIndex, RowSealed};
14use crate::serialize::{IsNull, Output, ToSql};
15use crate::sql_types::HasSqlType;
16use crate::sqlite::connection::bind_collector::InternalSqliteBindValue;
17use crate::sqlite::connection::sqlite_value::OwnedSqliteValue;
18use crate::sqlite::SqliteValue;
19use std::cell::{Ref, RefCell};
20use std::marker::PhantomData;
21use std::mem::ManuallyDrop;
22use std::ops::DerefMut;
23use std::rc::Rc;
24
25pub(super) fn register<ArgsSqlType, RetSqlType, Args, Ret, F>(
26 conn: &RawConnection,
27 fn_name: &str,
28 deterministic: bool,
29 mut f: F,
30) -> QueryResult<()>
31where
32 F: FnMut(&RawConnection, Args) -> Ret + std::panic::UnwindSafe + Send + 'static,
33 Args: FromSqlRow<ArgsSqlType, Sqlite> + StaticallySizedRow<ArgsSqlType, Sqlite>,
34 Ret: ToSql<RetSqlType, Sqlite>,
35 Sqlite: HasSqlType<RetSqlType>,
36{
37 let fields_needed = Args::FIELD_COUNT;
38 if fields_needed > 127 {
39 return Err(Error::DatabaseError(
40 DatabaseErrorKind::UnableToSendCommand,
41 Box::new("SQLite functions cannot take more than 127 parameters".to_string()),
42 ));
43 }
44
45 conn.register_sql_function(fn_name, fields_needed, deterministic, move |conn, args| {
46 let args = build_sql_function_args::<ArgsSqlType, Args>(args)?;
47
48 Ok(f(conn, args))
49 })?;
50 Ok(())
51}
52
53pub(super) fn register_noargs<RetSqlType, Ret, F>(
54 conn: &RawConnection,
55 fn_name: &str,
56 deterministic: bool,
57 mut f: F,
58) -> QueryResult<()>
59where
60 F: FnMut() -> Ret + std::panic::UnwindSafe + Send + 'static,
61 Ret: ToSql<RetSqlType, Sqlite>,
62 Sqlite: HasSqlType<RetSqlType>,
63{
64 conn.register_sql_function(fn_name, 0, deterministic, move |_, _| Ok(f()))?;
65 Ok(())
66}
67
68pub(super) fn register_aggregate<ArgsSqlType, RetSqlType, Args, Ret, A>(
69 conn: &RawConnection,
70 fn_name: &str,
71) -> QueryResult<()>
72where
73 A: SqliteAggregateFunction<Args, Output = Ret> + 'static + Send + std::panic::UnwindSafe,
74 Args: FromSqlRow<ArgsSqlType, Sqlite> + StaticallySizedRow<ArgsSqlType, Sqlite>,
75 Ret: ToSql<RetSqlType, Sqlite>,
76 Sqlite: HasSqlType<RetSqlType>,
77{
78 let fields_needed = Args::FIELD_COUNT;
79 if fields_needed > 127 {
80 return Err(Error::DatabaseError(
81 DatabaseErrorKind::UnableToSendCommand,
82 Box::new("SQLite functions cannot take more than 127 parameters".to_string()),
83 ));
84 }
85
86 conn.register_aggregate_function::<ArgsSqlType, RetSqlType, Args, Ret, A>(
87 fn_name,
88 fields_needed,
89 )?;
90
91 Ok(())
92}
93
94pub(super) fn build_sql_function_args<ArgsSqlType, Args>(
95 args: &mut [*mut ffi::sqlite3_value],
96) -> Result<Args, Error>
97where
98 Args: FromSqlRow<ArgsSqlType, Sqlite>,
99{
100 let row = FunctionRow::new(args);
101 Args::build_from_row(&row).map_err(Error::DeserializationError)
102}
103
104#[allow(clippy::let_unit_value)]
107pub(super) fn process_sql_function_result<RetSqlType, Ret>(
108 result: &'_ Ret,
109) -> QueryResult<InternalSqliteBindValue<'_>>
110where
111 Ret: ToSql<RetSqlType, Sqlite>,
112 Sqlite: HasSqlType<RetSqlType>,
113{
114 let mut metadata_lookup = ();
115 let value = SqliteBindValue {
116 inner: InternalSqliteBindValue::Null,
117 };
118 let mut buf = Output::new(value, &mut metadata_lookup);
119 let is_null = result.to_sql(&mut buf).map_err(Error::SerializationError)?;
120
121 if let IsNull::Yes = is_null {
122 Ok(InternalSqliteBindValue::Null)
123 } else {
124 Ok(buf.into_inner().inner)
125 }
126}
127
128struct FunctionRow<'a> {
129 args: Rc<RefCell<ManuallyDrop<PrivateSqliteRow<'a, 'static>>>>,
132 field_count: usize,
133 marker: PhantomData<&'a ffi::sqlite3_value>,
134}
135
136impl Drop for FunctionRow<'_> {
137 #[allow(unsafe_code)] fn drop(&mut self) {
139 if let Some(args) = Rc::get_mut(&mut self.args) {
140 if let PrivateSqliteRow::Duplicated { column_names, .. } =
141 DerefMut::deref_mut(RefCell::get_mut(args))
142 {
143 if Rc::strong_count(column_names) == 1 {
144 unsafe { std::ptr::drop_in_place(column_names as *mut _) }
147 }
148 }
149 }
150 }
151}
152
153impl FunctionRow<'_> {
154 #[allow(unsafe_code)] fn new(args: &mut [*mut ffi::sqlite3_value]) -> Self {
156 let lengths = args.len();
157 let args = unsafe {
158 Vec::from_raw_parts(
159 args as *mut [*mut ffi::sqlite3_value] as *mut ffi::sqlite3_value
175 as *mut Option<OwnedSqliteValue>,
176 lengths,
177 lengths,
178 )
179 };
180
181 Self {
182 field_count: lengths,
183 args: Rc::new(RefCell::new(ManuallyDrop::new(
184 PrivateSqliteRow::Duplicated {
185 values: args,
186 column_names: Rc::from(vec![None; lengths]),
187 },
188 ))),
189 marker: PhantomData,
190 }
191 }
192}
193
194impl RowSealed for FunctionRow<'_> {}
195
196impl<'a> Row<'a, Sqlite> for FunctionRow<'a> {
197 type Field<'f>
198 = FunctionArgument<'f>
199 where
200 'a: 'f,
201 Self: 'f;
202 type InnerPartialRow = Self;
203
204 fn field_count(&self) -> usize {
205 self.field_count
206 }
207
208 fn get<'b, I>(&'b self, idx: I) -> Option<Self::Field<'b>>
209 where
210 'a: 'b,
211 Self: crate::row::RowIndex<I>,
212 {
213 let col_idx = self.idx(idx)?;
214 Some(FunctionArgument {
215 args: self.args.borrow(),
216 col_idx,
217 })
218 }
219
220 fn partial_row(&self, range: std::ops::Range<usize>) -> PartialRow<'_, Self::InnerPartialRow> {
221 PartialRow::new(self, range)
222 }
223}
224
225impl RowIndex<usize> for FunctionRow<'_> {
226 fn idx(&self, idx: usize) -> Option<usize> {
227 if idx < self.field_count() {
228 Some(idx)
229 } else {
230 None
231 }
232 }
233}
234
235impl<'a> RowIndex<&'a str> for FunctionRow<'_> {
236 fn idx(&self, _idx: &'a str) -> Option<usize> {
237 None
238 }
239}
240
241struct FunctionArgument<'a> {
242 args: Ref<'a, ManuallyDrop<PrivateSqliteRow<'a, 'static>>>,
243 col_idx: usize,
244}
245
246impl<'a> Field<'a, Sqlite> for FunctionArgument<'a> {
247 fn field_name(&self) -> Option<&str> {
248 None
249 }
250
251 fn is_null(&self) -> bool {
252 self.value().is_none()
253 }
254
255 fn value(&self) -> Option<<Sqlite as Backend>::RawValue<'_>> {
256 SqliteValue::new(
257 Ref::map(Ref::clone(&self.args), |drop| std::ops::Deref::deref(drop)),
258 self.col_idx,
259 )
260 }
261}