Skip to main content

diesel/sqlite/connection/
row.rs

1use super::owned_row::OwnedSqliteRow;
2use super::sqlite_value::{OwnedSqliteValue, SqliteValue};
3use super::stmt::StatementUse;
4use crate::backend::Backend;
5use crate::row::{Field, IntoOwnedRow, PartialRow, Row, RowIndex, RowSealed};
6use crate::sqlite::Sqlite;
7use alloc::borrow::ToOwned;
8use alloc::rc::Rc;
9use alloc::string::String;
10use alloc::sync::Arc;
11use alloc::vec::Vec;
12use core::cell::{Ref, RefCell};
13
14#[allow(missing_debug_implementations)]
15pub struct SqliteRow<'stmt, 'query> {
16    pub(super) inner: Rc<RefCell<PrivateSqliteRow<'stmt, 'query>>>,
17    pub(super) field_count: usize,
18}
19
20pub(super) enum PrivateSqliteRow<'stmt, 'query> {
21    Direct(StatementUse<'stmt, 'query>),
22    Duplicated {
23        values: Vec<Option<OwnedSqliteValue>>,
24        column_names: Rc<[Option<String>]>,
25    },
26}
27
28impl<'stmt> IntoOwnedRow<'stmt, Sqlite> for SqliteRow<'stmt, '_> {
29    type OwnedRow = OwnedSqliteRow;
30
31    type Cache = Option<Arc<[Option<String>]>>;
32
33    fn into_owned(self, column_name_cache: &mut Self::Cache) -> Self::OwnedRow {
34        self.inner.borrow().moveable(column_name_cache)
35    }
36}
37
38impl<'stmt, 'query> PrivateSqliteRow<'stmt, 'query> {
39    pub(super) fn duplicate(
40        &mut self,
41        column_names: &mut Option<Rc<[Option<String>]>>,
42    ) -> PrivateSqliteRow<'stmt, 'query> {
43        match self {
44            PrivateSqliteRow::Direct(stmt) => {
45                let column_names = if let Some(column_names) = column_names {
46                    column_names.clone()
47                } else {
48                    let c: Rc<[Option<String>]> = Rc::from(
49                        (0..stmt.column_count())
50                            .map(|idx| stmt.field_name(idx).map(|s| s.to_owned()))
51                            .collect::<Vec<_>>(),
52                    );
53                    *column_names = Some(c.clone());
54                    c
55                };
56                PrivateSqliteRow::Duplicated {
57                    values: (0..stmt.column_count())
58                        .map(|idx| stmt.copy_value(idx))
59                        .collect(),
60                    column_names,
61                }
62            }
63            PrivateSqliteRow::Duplicated {
64                values,
65                column_names,
66            } => PrivateSqliteRow::Duplicated {
67                values: values
68                    .iter()
69                    .map(|v| v.as_ref().map(|v| v.duplicate()))
70                    .collect(),
71                column_names: column_names.clone(),
72            },
73        }
74    }
75
76    pub(super) fn moveable(
77        &self,
78        column_name_cache: &mut Option<Arc<[Option<String>]>>,
79    ) -> OwnedSqliteRow {
80        match self {
81            PrivateSqliteRow::Direct(stmt) => {
82                if column_name_cache.is_none() {
83                    *column_name_cache = Some(
84                        (0..stmt.column_count())
85                            .map(|idx| stmt.field_name(idx).map(|s| s.to_owned()))
86                            .collect::<Vec<_>>()
87                            .into(),
88                    );
89                }
90                let column_names = Arc::clone(
91                    column_name_cache
92                        .as_ref()
93                        .expect("This is initialized above"),
94                );
95                OwnedSqliteRow::new(
96                    (0..stmt.column_count())
97                        .map(|idx| stmt.copy_value(idx))
98                        .collect(),
99                    column_names,
100                )
101            }
102            PrivateSqliteRow::Duplicated {
103                values,
104                column_names,
105            } => {
106                if column_name_cache.is_none() {
107                    *column_name_cache = Some(
108                        (*column_names)
109                            .iter()
110                            .map(|s| s.to_owned())
111                            .collect::<Vec<_>>()
112                            .into(),
113                    );
114                }
115                let column_names = Arc::clone(
116                    column_name_cache
117                        .as_ref()
118                        .expect("This is initialized above"),
119                );
120                OwnedSqliteRow::new(
121                    values
122                        .iter()
123                        .map(|v| v.as_ref().map(|v| v.duplicate()))
124                        .collect(),
125                    column_names,
126                )
127            }
128        }
129    }
130}
131
132impl RowSealed for SqliteRow<'_, '_> {}
133
134impl<'stmt> Row<'stmt, Sqlite> for SqliteRow<'stmt, '_> {
135    type Field<'field>
136        = SqliteField<'field, 'field>
137    where
138        'stmt: 'field,
139        Self: 'field;
140    type InnerPartialRow = Self;
141
142    fn field_count(&self) -> usize {
143        self.field_count
144    }
145
146    fn get<'field, I>(&'field self, idx: I) -> Option<Self::Field<'field>>
147    where
148        'stmt: 'field,
149        Self: RowIndex<I>,
150    {
151        let idx = self.idx(idx)?;
152        Some(SqliteField {
153            row: self.inner.borrow(),
154            col_idx: idx,
155        })
156    }
157
158    fn partial_row(&self, range: core::ops::Range<usize>) -> PartialRow<'_, Self::InnerPartialRow> {
159        PartialRow::new(self, range)
160    }
161}
162
163impl RowIndex<usize> for SqliteRow<'_, '_> {
164    fn idx(&self, idx: usize) -> Option<usize> {
165        if idx < self.field_count {
166            Some(idx)
167        } else {
168            None
169        }
170    }
171}
172
173impl<'idx> RowIndex<&'idx str> for SqliteRow<'_, '_> {
174    fn idx(&self, field_name: &'idx str) -> Option<usize> {
175        match &mut *self.inner.borrow_mut() {
176            PrivateSqliteRow::Direct(stmt) => stmt.index_for_column_name(field_name),
177            PrivateSqliteRow::Duplicated { column_names, .. } => column_names
178                .iter()
179                .position(|n| n.as_ref().map(|s| s as &str) == Some(field_name)),
180        }
181    }
182}
183
184#[allow(missing_debug_implementations)]
185pub struct SqliteField<'stmt, 'query> {
186    pub(super) row: Ref<'stmt, PrivateSqliteRow<'stmt, 'query>>,
187    pub(super) col_idx: usize,
188}
189
190impl<'stmt> Field<'stmt, Sqlite> for SqliteField<'stmt, '_> {
191    fn field_name(&self) -> Option<&str> {
192        match &*self.row {
193            PrivateSqliteRow::Direct(stmt) => stmt.field_name(
194                self.col_idx
195                    .try_into()
196                    .expect("Diesel expects to run at least on a 32 bit platform"),
197            ),
198            PrivateSqliteRow::Duplicated { column_names, .. } => column_names
199                .get(self.col_idx)
200                .and_then(|t| t.as_ref().map(|n| n as &str)),
201        }
202    }
203
204    fn is_null(&self) -> bool {
205        self.value().is_none()
206    }
207
208    fn value(&self) -> Option<<Sqlite as Backend>::RawValue<'_>> {
209        SqliteValue::new(Ref::clone(&self.row), self.col_idx)
210    }
211}
212
#[cfg(test)]
mod tests {
    use super::*;

    /// Loads the same two-row query in several ways and checks the decoded
    /// values, including that advancing the iterator fails while fields or
    /// values of the previously returned row are still alive.
    #[diesel_test_helper::test]
    fn fun_with_row_iters() {
        crate::table! {
            #[allow(unused_parens)]
            users(id) {
                id -> Integer,
                name -> Text,
            }
        }

        use crate::connection::LoadConnection;
        use crate::deserialize::{FromSql, FromSqlRow};
        use crate::prelude::*;
        use crate::row::{Field, Row};
        use crate::sql_types;

        // Rust-side and SQL-side shape of one selected `users` row.
        type UserTuple = (i32, String);
        type UserSqlType = (sql_types::Integer, sql_types::Text);

        let conn = &mut crate::test_helpers::connection();

        crate::sql_query("CREATE TABLE users(id INTEGER PRIMARY KEY, name TEXT NOT NULL);")
            .execute(conn)
            .unwrap();

        let seed_rows = vec![
            (users::id.eq(1), users::name.eq("Sean")),
            (users::id.eq(2), users::name.eq("Tess")),
        ];
        crate::insert_into(users::table)
            .values(seed_rows)
            .execute(conn)
            .unwrap();

        let query = users::table.select((users::id, users::name));

        let expected = vec![(1, String::from("Sean")), (2, String::from("Tess"))];

        // Pass 1: deserialize each row while streaming through the iterator.
        let streamed = conn.load(query).unwrap();
        for (row, expected) in streamed.zip(&expected) {
            let row = row.unwrap();

            let deserialized =
                <UserTuple as FromSqlRow<UserSqlType, _>>::build_from_row(&row).unwrap();

            assert_eq!(&deserialized, expected);
        }

        // Pass 2: collect every row first, then deserialize.
        {
            let collected_rows = conn.load(query).unwrap().collect::<Vec<_>>();

            for (row, expected) in collected_rows.iter().zip(&expected) {
                let row = row.as_ref().unwrap();
                let deserialized =
                    <UserTuple as FromSqlRow<UserSqlType, _>>::build_from_row(row).unwrap();

                assert_eq!(&deserialized, expected);
            }
        }

        // Pass 3: hold on to fields and values while trying to advance.
        let mut row_iter = conn.load(query).unwrap();

        let first_row = row_iter.next().unwrap().unwrap();
        let first_fields = (first_row.get(0).unwrap(), first_row.get(1).unwrap());
        let first_values = (first_fields.0.value(), first_fields.1.value());

        // Advancing errors while values, and then fields, are still alive.
        assert!(row_iter.next().unwrap().is_err());
        drop(first_values);
        assert!(row_iter.next().unwrap().is_err());
        drop(first_fields);

        let second_row = row_iter.next().unwrap().unwrap();
        let second_fields = (second_row.get(0).unwrap(), second_row.get(1).unwrap());
        let second_values = (second_fields.0.value(), second_fields.1.value());

        assert!(row_iter.next().unwrap().is_err());
        drop(second_values);
        assert!(row_iter.next().unwrap().is_err());
        drop(second_fields);

        assert!(row_iter.next().is_none());

        // Both rows must still be readable after the iterator is exhausted.
        let first_fields = (first_row.get(0).unwrap(), first_row.get(1).unwrap());
        let second_fields = (second_row.get(0).unwrap(), second_row.get(1).unwrap());

        let (first_id, first_name) = (first_fields.0.value(), first_fields.1.value());
        let (second_id, second_name) = (second_fields.0.value(), second_fields.1.value());

        assert_eq!(
            <i32 as FromSql<sql_types::Integer, Sqlite>>::from_nullable_sql(first_id).unwrap(),
            expected[0].0
        );
        assert_eq!(
            <String as FromSql<sql_types::Text, Sqlite>>::from_nullable_sql(first_name).unwrap(),
            expected[0].1
        );

        assert_eq!(
            <i32 as FromSql<sql_types::Integer, Sqlite>>::from_nullable_sql(second_id).unwrap(),
            expected[1].0
        );
        assert_eq!(
            <String as FromSql<sql_types::Text, Sqlite>>::from_nullable_sql(second_name).unwrap(),
            expected[1].1
        );

        // Reading the first row a second time yields the same values.
        let first_fields = (first_row.get(0).unwrap(), first_row.get(1).unwrap());
        let (first_id, first_name) = (first_fields.0.value(), first_fields.1.value());

        assert_eq!(
            <i32 as FromSql<sql_types::Integer, Sqlite>>::from_nullable_sql(first_id).unwrap(),
            expected[0].0
        );
        assert_eq!(
            <String as FromSql<sql_types::Text, Sqlite>>::from_nullable_sql(first_name).unwrap(),
            expected[0].1
        );
    }

    // SQL function used by `parallel_iter_with_error`; its implementation is
    // registered inside that test.
    #[cfg(feature = "returning_clauses_for_sqlite_3_35")]
    #[crate::declare_sql_function]
    extern "SQL" {
        fn sleep(a: diesel::sql_types::Integer) -> diesel::sql_types::Integer;
    }

    /// Checks that a load whose first step fails yields that error once and
    /// then reports the iterator as exhausted.
    #[diesel_test_helper::test]
    #[cfg(feature = "returning_clauses_for_sqlite_3_35")]
    #[allow(clippy::cast_sign_loss)]
    fn parallel_iter_with_error() {
        use crate::SqliteConnection;
        use crate::connection::Connection;
        use crate::connection::LoadConnection;
        use crate::connection::SimpleConnection;
        use crate::expression_methods::ExpressionMethods;
        use std::sync::{Arc, Barrier};
        use std::time::Duration;

        let temp_dir = tempfile::tempdir().unwrap();
        let db_path = format!("{}/test.db", temp_dir.path().display());
        let mut inserting_conn = SqliteConnection::establish(&db_path).unwrap();
        let mut locking_conn = SqliteConnection::establish(&db_path).unwrap();

        crate::table! {
            users {
                id -> Integer,
                name -> Text,
            }
        }

        inserting_conn
            .batch_execute("CREATE TABLE users(id INTEGER NOT NULL PRIMARY KEY, name TEXT)")
            .unwrap();

        let sleep_barrier = Arc::new(Barrier::new(2));
        let main_barrier = Arc::clone(&sleep_barrier);

        // we unblock the main thread from the sleep function
        sleep_utils::register_impl(&mut locking_conn, move |secs: i32| {
            sleep_barrier.wait();
            std::thread::sleep(Duration::from_secs(secs as u64));
            secs
        })
        .unwrap();

        // spawn a background thread that locks the database file
        let handle = std::thread::spawn(move || {
            use crate::query_dsl::RunQueryDsl;

            locking_conn
                .immediate_transaction(|conn| diesel::select(sleep(1)).execute(conn))
                .unwrap();
        });
        main_barrier.wait();

        // execute some action that also requires a lock
        let insert_returning_id = diesel::insert_into(users::table)
            .values((users::id.eq(1), users::name.eq("John")))
            .returning(users::id);
        let mut iter = inserting_conn.load(insert_returning_id).unwrap();

        // get the first iterator result, that should return the lock error
        let first = iter.next().unwrap();
        assert!(first.is_err());

        // check that the iterator is now empty
        let after_error = iter.next();
        assert!(after_error.is_none());

        // join the background thread
        handle.join().unwrap();
    }
}