// diesel/sqlite/connection/functions.rs
1extern crate libsqlite3_sys as ffi;
2
3use super::raw::RawConnection;
4use super::row::PrivateSqliteRow;
5use super::{Sqlite, SqliteAggregateFunction, SqliteBindValue};
6use crate::backend::Backend;
7use crate::deserialize::{FromSqlRow, StaticallySizedRow};
8use crate::result::{DatabaseErrorKind, Error, QueryResult};
9use crate::row::{Field, PartialRow, Row, RowIndex, RowSealed};
10use crate::serialize::{IsNull, Output, ToSql};
11use crate::sql_types::HasSqlType;
12use crate::sqlite::connection::bind_collector::InternalSqliteBindValue;
13use crate::sqlite::connection::sqlite_value::OwnedSqliteValue;
14use crate::sqlite::SqliteValue;
15use std::cell::{Ref, RefCell};
16use std::marker::PhantomData;
17use std::mem::ManuallyDrop;
18use std::ops::DerefMut;
19use std::rc::Rc;
20
21pub(super) fn register<ArgsSqlType, RetSqlType, Args, Ret, F>(
22 conn: &RawConnection,
23 fn_name: &str,
24 deterministic: bool,
25 mut f: F,
26) -> QueryResult<()>
27where
28 F: FnMut(&RawConnection, Args) -> Ret + std::panic::UnwindSafe + Send + 'static,
29 Args: FromSqlRow<ArgsSqlType, Sqlite> + StaticallySizedRow<ArgsSqlType, Sqlite>,
30 Ret: ToSql<RetSqlType, Sqlite>,
31 Sqlite: HasSqlType<RetSqlType>,
32{
33 let fields_needed = Args::FIELD_COUNT;
34 if fields_needed > 127 {
35 return Err(Error::DatabaseError(
36 DatabaseErrorKind::UnableToSendCommand,
37 Box::new("SQLite functions cannot take more than 127 parameters".to_string()),
38 ));
39 }
40
41 conn.register_sql_function(fn_name, fields_needed, deterministic, move |conn, args| {
42 let args = build_sql_function_args::<ArgsSqlType, Args>(args)?;
43
44 Ok(f(conn, args))
45 })?;
46 Ok(())
47}
48
49pub(super) fn register_noargs<RetSqlType, Ret, F>(
50 conn: &RawConnection,
51 fn_name: &str,
52 deterministic: bool,
53 mut f: F,
54) -> QueryResult<()>
55where
56 F: FnMut() -> Ret + std::panic::UnwindSafe + Send + 'static,
57 Ret: ToSql<RetSqlType, Sqlite>,
58 Sqlite: HasSqlType<RetSqlType>,
59{
60 conn.register_sql_function(fn_name, 0, deterministic, move |_, _| Ok(f()))?;
61 Ok(())
62}
63
64pub(super) fn register_aggregate<ArgsSqlType, RetSqlType, Args, Ret, A>(
65 conn: &RawConnection,
66 fn_name: &str,
67) -> QueryResult<()>
68where
69 A: SqliteAggregateFunction<Args, Output = Ret> + 'static + Send + std::panic::UnwindSafe,
70 Args: FromSqlRow<ArgsSqlType, Sqlite> + StaticallySizedRow<ArgsSqlType, Sqlite>,
71 Ret: ToSql<RetSqlType, Sqlite>,
72 Sqlite: HasSqlType<RetSqlType>,
73{
74 let fields_needed = Args::FIELD_COUNT;
75 if fields_needed > 127 {
76 return Err(Error::DatabaseError(
77 DatabaseErrorKind::UnableToSendCommand,
78 Box::new("SQLite functions cannot take more than 127 parameters".to_string()),
79 ));
80 }
81
82 conn.register_aggregate_function::<ArgsSqlType, RetSqlType, Args, Ret, A>(
83 fn_name,
84 fields_needed,
85 )?;
86
87 Ok(())
88}
89
90pub(super) fn build_sql_function_args<ArgsSqlType, Args>(
91 args: &mut [*mut ffi::sqlite3_value],
92) -> Result<Args, Error>
93where
94 Args: FromSqlRow<ArgsSqlType, Sqlite>,
95{
96 let row = FunctionRow::new(args);
97 Args::build_from_row(&row).map_err(Error::DeserializationError)
98}
99
100#[allow(clippy::let_unit_value)]
103pub(super) fn process_sql_function_result<RetSqlType, Ret>(
104 result: &'_ Ret,
105) -> QueryResult<InternalSqliteBindValue<'_>>
106where
107 Ret: ToSql<RetSqlType, Sqlite>,
108 Sqlite: HasSqlType<RetSqlType>,
109{
110 let mut metadata_lookup = ();
111 let value = SqliteBindValue {
112 inner: InternalSqliteBindValue::Null,
113 };
114 let mut buf = Output::new(value, &mut metadata_lookup);
115 let is_null = result.to_sql(&mut buf).map_err(Error::SerializationError)?;
116
117 if let IsNull::Yes = is_null {
118 Ok(InternalSqliteBindValue::Null)
119 } else {
120 Ok(buf.into_inner().inner)
121 }
122}
123
/// Adapts the raw argument array of one SQL function invocation to diesel's
/// `Row` trait, so that `FromSqlRow` can deserialize the arguments.
struct FunctionRow<'a> {
    // The arguments, viewed as a `PrivateSqliteRow`. `ManuallyDrop` stops the
    // inner `Vec` from being dropped, because its buffer is the argument slice
    // handed to `FunctionRow::new`, not a Rust allocation. The `Drop` impl
    // releases the parts that Rust does own.
    args: Rc<RefCell<ManuallyDrop<PrivateSqliteRow<'a, 'static>>>>,
    // Number of arguments, i.e. the length of the original argument slice.
    field_count: usize,
    // Ties the row to the lifetime `'a` of the underlying `sqlite3_value`s.
    marker: PhantomData<&'a ffi::sqlite3_value>,
}
131
impl Drop for FunctionRow<'_> {
    /// Releases the Rust-owned parts of the argument row.
    ///
    /// `args` is wrapped in `ManuallyDrop`, so nothing inside it is dropped
    /// automatically. The `values` vector must stay undropped because its buffer
    /// is the argument slice borrowed in `FunctionRow::new`, not a Rust
    /// allocation. Only the `column_names` `Rc` allocated in `new` is released
    /// here.
    #[allow(unsafe_code)] // manual drop of a field that lives inside `ManuallyDrop`
    fn drop(&mut self) {
        // `Rc::get_mut` only succeeds when no other `Rc`/`Weak` to the row exists,
        // i.e. this `FunctionRow` is the last owner of the shared row.
        if let Some(args) = Rc::get_mut(&mut self.args) {
            if let PrivateSqliteRow::Duplicated { column_names, .. } =
                DerefMut::deref_mut(RefCell::get_mut(args))
            {
                // NOTE(review): if `column_names` has other strong references, this
                // branch is skipped and our reference is never released, which would
                // leak the allocation. Dropping an `Rc` in place only decrements the
                // count, so the guard may be unnecessary — confirm intent.
                if Rc::strong_count(column_names) == 1 {
                    // SAFETY: `column_names` sits inside `ManuallyDrop`, so it is never
                    // dropped automatically. It is dropped here exactly once and is not
                    // read afterwards, which `ManuallyDrop` permits.
                    unsafe { std::ptr::drop_in_place(column_names as *mut _) }
                }
            }
        }
    }
}
148
impl FunctionRow<'_> {
    /// Wraps the argument pointers of a single SQL function call in a row.
    ///
    /// The pointer slice is reinterpreted in place, without copying any values,
    /// as the `values` of a `PrivateSqliteRow::Duplicated`. Function arguments
    /// carry no names, so every column name is `None`.
    #[allow(unsafe_code)] // pointer cast + `Vec::from_raw_parts` over a borrowed buffer
    fn new(args: &mut [*mut ffi::sqlite3_value]) -> Self {
        let lengths = args.len();
        // SAFETY (invariants this code relies on):
        // * The slice pointer is cast to a pointer to its first element and then
        //   to `*mut Option<OwnedSqliteValue>`. This assumes that
        //   `Option<OwnedSqliteValue>` has the same size and layout as
        //   `*mut ffi::sqlite3_value`, with a null pointer reading as `None`.
        //   NOTE(review): presumably this holds because `OwnedSqliteValue` is a
        //   `#[repr(transparent)]` wrapper around a `NonNull`; confirm in
        //   `sqlite_value.rs`.
        // * The buffer was not allocated by Rust's global allocator, which
        //   `Vec::from_raw_parts` normally requires. The resulting `Vec` must
        //   therefore never be dropped, grown or shrunk. It is wrapped in
        //   `ManuallyDrop` below, and `Drop for FunctionRow` only releases
        //   `column_names`.
        let args = unsafe {
            Vec::from_raw_parts(
                args as *mut [*mut ffi::sqlite3_value] as *mut ffi::sqlite3_value
                    as *mut Option<OwnedSqliteValue>,
                lengths,
                lengths,
            )
        };

        Self {
            field_count: lengths,
            // `ManuallyDrop` keeps the borrowed `values` buffer from being freed.
            args: Rc::new(RefCell::new(ManuallyDrop::new(
                PrivateSqliteRow::Duplicated {
                    values: args,
                    column_names: Rc::from(vec![None; lengths]),
                },
            ))),
            marker: PhantomData,
        }
    }
}
189
190impl RowSealed for FunctionRow<'_> {}
191
192impl<'a> Row<'a, Sqlite> for FunctionRow<'a> {
193 type Field<'f>
194 = FunctionArgument<'f>
195 where
196 'a: 'f,
197 Self: 'f;
198 type InnerPartialRow = Self;
199
200 fn field_count(&self) -> usize {
201 self.field_count
202 }
203
204 fn get<'b, I>(&'b self, idx: I) -> Option<Self::Field<'b>>
205 where
206 'a: 'b,
207 Self: crate::row::RowIndex<I>,
208 {
209 let col_idx = self.idx(idx)?;
210 Some(FunctionArgument {
211 args: self.args.borrow(),
212 col_idx,
213 })
214 }
215
216 fn partial_row(&self, range: std::ops::Range<usize>) -> PartialRow<'_, Self::InnerPartialRow> {
217 PartialRow::new(self, range)
218 }
219}
220
221impl RowIndex<usize> for FunctionRow<'_> {
222 fn idx(&self, idx: usize) -> Option<usize> {
223 if idx < self.field_count() {
224 Some(idx)
225 } else {
226 None
227 }
228 }
229}
230
231impl<'a> RowIndex<&'a str> for FunctionRow<'_> {
232 fn idx(&self, _idx: &'a str) -> Option<usize> {
233 None
234 }
235}
236
/// A single argument of a `FunctionRow`, exposed through diesel's `Field` trait.
struct FunctionArgument<'a> {
    // Shared borrow of the row's argument storage. While this field is alive,
    // the `RefCell` cannot be borrowed mutably.
    args: Ref<'a, ManuallyDrop<PrivateSqliteRow<'a, 'static>>>,
    // Zero-based position of this argument within the row.
    col_idx: usize,
}
241
242impl<'a> Field<'a, Sqlite> for FunctionArgument<'a> {
243 fn field_name(&self) -> Option<&str> {
244 None
245 }
246
247 fn is_null(&self) -> bool {
248 self.value().is_none()
249 }
250
251 fn value(&self) -> Option<<Sqlite as Backend>::RawValue<'_>> {
252 SqliteValue::new(
253 Ref::map(Ref::clone(&self.args), |drop| std::ops::Deref::deref(drop)),
254 self.col_idx,
255 )
256 }
257}