diesel/sqlite/connection/row.rs

1use std::cell::{Ref, RefCell};
2use std::rc::Rc;
3use std::sync::Arc;
4
5use super::owned_row::OwnedSqliteRow;
6use super::sqlite_value::{OwnedSqliteValue, SqliteValue};
7use super::stmt::StatementUse;
8use crate::backend::Backend;
9use crate::row::{Field, IntoOwnedRow, PartialRow, Row, RowIndex, RowSealed};
10use crate::sqlite::Sqlite;
11
12#[allow(missing_debug_implementations)]
13pub struct SqliteRow<'stmt, 'query> {
14    pub(super) inner: Rc<RefCell<PrivateSqliteRow<'stmt, 'query>>>,
15    pub(super) field_count: usize,
16}
17
18pub(super) enum PrivateSqliteRow<'stmt, 'query> {
19    Direct(StatementUse<'stmt, 'query>),
20    Duplicated {
21        values: Vec<Option<OwnedSqliteValue>>,
22        column_names: Rc<[Option<String>]>,
23    },
24}
25
26impl<'stmt> IntoOwnedRow<'stmt, Sqlite> for SqliteRow<'stmt, '_> {
27    type OwnedRow = OwnedSqliteRow;
28
29    type Cache = Option<Arc<[Option<String>]>>;
30
31    fn into_owned(self, column_name_cache: &mut Self::Cache) -> Self::OwnedRow {
32        self.inner.borrow().moveable(column_name_cache)
33    }
34}
35
36impl<'stmt, 'query> PrivateSqliteRow<'stmt, 'query> {
37    pub(super) fn duplicate(
38        &mut self,
39        column_names: &mut Option<Rc<[Option<String>]>>,
40    ) -> PrivateSqliteRow<'stmt, 'query> {
41        match self {
42            PrivateSqliteRow::Direct(stmt) => {
43                let column_names = if let Some(column_names) = column_names {
44                    column_names.clone()
45                } else {
46                    let c: Rc<[Option<String>]> = Rc::from(
47                        (0..stmt.column_count())
48                            .map(|idx| stmt.field_name(idx).map(|s| s.to_owned()))
49                            .collect::<Vec<_>>(),
50                    );
51                    *column_names = Some(c.clone());
52                    c
53                };
54                PrivateSqliteRow::Duplicated {
55                    values: (0..stmt.column_count())
56                        .map(|idx| stmt.copy_value(idx))
57                        .collect(),
58                    column_names,
59                }
60            }
61            PrivateSqliteRow::Duplicated {
62                values,
63                column_names,
64            } => PrivateSqliteRow::Duplicated {
65                values: values
66                    .iter()
67                    .map(|v| v.as_ref().map(|v| v.duplicate()))
68                    .collect(),
69                column_names: column_names.clone(),
70            },
71        }
72    }
73
74    pub(super) fn moveable(
75        &self,
76        column_name_cache: &mut Option<Arc<[Option<String>]>>,
77    ) -> OwnedSqliteRow {
78        match self {
79            PrivateSqliteRow::Direct(stmt) => {
80                if column_name_cache.is_none() {
81                    *column_name_cache = Some(
82                        (0..stmt.column_count())
83                            .map(|idx| stmt.field_name(idx).map(|s| s.to_owned()))
84                            .collect::<Vec<_>>()
85                            .into(),
86                    );
87                }
88                let column_names = Arc::clone(
89                    column_name_cache
90                        .as_ref()
91                        .expect("This is initialized above"),
92                );
93                OwnedSqliteRow::new(
94                    (0..stmt.column_count())
95                        .map(|idx| stmt.copy_value(idx))
96                        .collect(),
97                    column_names,
98                )
99            }
100            PrivateSqliteRow::Duplicated {
101                values,
102                column_names,
103            } => {
104                if column_name_cache.is_none() {
105                    *column_name_cache = Some(
106                        (*column_names)
107                            .iter()
108                            .map(|s| s.to_owned())
109                            .collect::<Vec<_>>()
110                            .into(),
111                    );
112                }
113                let column_names = Arc::clone(
114                    column_name_cache
115                        .as_ref()
116                        .expect("This is initialized above"),
117                );
118                OwnedSqliteRow::new(
119                    values
120                        .iter()
121                        .map(|v| v.as_ref().map(|v| v.duplicate()))
122                        .collect(),
123                    column_names,
124                )
125            }
126        }
127    }
128}
129
130impl RowSealed for SqliteRow<'_, '_> {}
131
132impl<'stmt> Row<'stmt, Sqlite> for SqliteRow<'stmt, '_> {
133    type Field<'field>
134        = SqliteField<'field, 'field>
135    where
136        'stmt: 'field,
137        Self: 'field;
138    type InnerPartialRow = Self;
139
140    fn field_count(&self) -> usize {
141        self.field_count
142    }
143
144    fn get<'field, I>(&'field self, idx: I) -> Option<Self::Field<'field>>
145    where
146        'stmt: 'field,
147        Self: RowIndex<I>,
148    {
149        let idx = self.idx(idx)?;
150        Some(SqliteField {
151            row: self.inner.borrow(),
152            col_idx: idx,
153        })
154    }
155
156    fn partial_row(&self, range: std::ops::Range<usize>) -> PartialRow<'_, Self::InnerPartialRow> {
157        PartialRow::new(self, range)
158    }
159}
160
161impl RowIndex<usize> for SqliteRow<'_, '_> {
162    fn idx(&self, idx: usize) -> Option<usize> {
163        if idx < self.field_count {
164            Some(idx)
165        } else {
166            None
167        }
168    }
169}
170
171impl<'idx> RowIndex<&'idx str> for SqliteRow<'_, '_> {
172    fn idx(&self, field_name: &'idx str) -> Option<usize> {
173        match &mut *self.inner.borrow_mut() {
174            PrivateSqliteRow::Direct(stmt) => stmt.index_for_column_name(field_name),
175            PrivateSqliteRow::Duplicated { column_names, .. } => column_names
176                .iter()
177                .position(|n| n.as_ref().map(|s| s as &str) == Some(field_name)),
178        }
179    }
180}
181
182#[allow(missing_debug_implementations)]
183pub struct SqliteField<'stmt, 'query> {
184    pub(super) row: Ref<'stmt, PrivateSqliteRow<'stmt, 'query>>,
185    pub(super) col_idx: usize,
186}
187
188impl<'stmt> Field<'stmt, Sqlite> for SqliteField<'stmt, '_> {
189    fn field_name(&self) -> Option<&str> {
190        match &*self.row {
191            PrivateSqliteRow::Direct(stmt) => stmt.field_name(
192                self.col_idx
193                    .try_into()
194                    .expect("Diesel expects to run at least on a 32 bit platform"),
195            ),
196            PrivateSqliteRow::Duplicated { column_names, .. } => column_names
197                .get(self.col_idx)
198                .and_then(|t| t.as_ref().map(|n| n as &str)),
199        }
200    }
201
202    fn is_null(&self) -> bool {
203        self.value().is_none()
204    }
205
206    fn value(&self) -> Option<<Sqlite as Backend>::RawValue<'_>> {
207        SqliteValue::new(Ref::clone(&self.row), self.col_idx)
208    }
209}
210
#[cfg(test)]
mod tests {
    use super::*;

    /// Iterates rows directly and after collecting them. Then checks that an
    /// outstanding field or value borrow of a row makes the iterator refuse to
    /// advance, and that rows stay readable once iteration has finished.
    #[test]
    fn fun_with_row_iters() {
        crate::table! {
            #[allow(unused_parens)]
            users(id) {
                id -> Integer,
                name -> Text,
            }
        }

        use crate::connection::LoadConnection;
        use crate::deserialize::{FromSql, FromSqlRow};
        use crate::prelude::*;
        use crate::row::{Field, Row};
        use crate::sql_types;

        // Deserializes an `(id, name)` pair of raw values and compares it with `$want`.
        macro_rules! assert_values {
            ($values:expr, $want:expr) => {{
                let (id_value, name_value) = $values;
                assert_eq!(
                    <i32 as FromSql<sql_types::Integer, Sqlite>>::from_nullable_sql(id_value)
                        .unwrap(),
                    $want.0
                );
                assert_eq!(
                    <String as FromSql<sql_types::Text, Sqlite>>::from_nullable_sql(name_value)
                        .unwrap(),
                    $want.1
                );
            }};
        }

        let conn = &mut crate::test_helpers::connection();

        crate::sql_query("CREATE TABLE users(id INTEGER PRIMARY KEY, name TEXT NOT NULL);")
            .execute(conn)
            .unwrap();

        let new_users = vec![
            (users::id.eq(1), users::name.eq("Sean")),
            (users::id.eq(2), users::name.eq("Tess")),
        ];
        crate::insert_into(users::table)
            .values(new_users)
            .execute(conn)
            .unwrap();

        let query = users::table.select((users::id, users::name));

        let expected = vec![(1, String::from("Sean")), (2, String::from("Tess"))];

        // Deserialize straight from the streaming iterator.
        for (result, want) in conn.load(query).unwrap().zip(&expected) {
            let row = result.unwrap();
            let got = <(i32, String) as FromSqlRow<
                (sql_types::Integer, sql_types::Text),
                _,
            >>::build_from_row(&row)
            .unwrap();
            assert_eq!(&got, want);
        }

        // Deserialize again after collecting every row up front.
        {
            let collected: Vec<_> = conn.load(query).unwrap().collect();
            for (result, want) in collected.iter().zip(&expected) {
                let row = result.as_ref().unwrap();
                let got = <(i32, String) as FromSqlRow<
                    (sql_types::Integer, sql_types::Text),
                    _,
                >>::build_from_row(row)
                .unwrap();
                assert_eq!(&got, want);
            }
        }

        let mut rows = conn.load(query).unwrap();

        // While any field or value of the current row is alive, `next` yields an error.
        let row_one = rows.next().unwrap().unwrap();
        let fields_one = (row_one.get(0).unwrap(), row_one.get(1).unwrap());
        let values_one = (fields_one.0.value(), fields_one.1.value());

        assert!(rows.next().unwrap().is_err());
        std::mem::drop(values_one);
        assert!(rows.next().unwrap().is_err());
        std::mem::drop(fields_one);

        let row_two = rows.next().unwrap().unwrap();
        let fields_two = (row_two.get(0).unwrap(), row_two.get(1).unwrap());
        let values_two = (fields_two.0.value(), fields_two.1.value());

        assert!(rows.next().unwrap().is_err());
        std::mem::drop(values_two);
        assert!(rows.next().unwrap().is_err());
        std::mem::drop(fields_two);

        assert!(rows.next().is_none());

        // After iteration has finished, both rows can still be read.
        let fields_one = (row_one.get(0).unwrap(), row_one.get(1).unwrap());
        let fields_two = (row_two.get(0).unwrap(), row_two.get(1).unwrap());

        let values_one = (fields_one.0.value(), fields_one.1.value());
        let values_two = (fields_two.0.value(), fields_two.1.value());

        assert_values!(values_one, expected[0]);
        assert_values!(values_two, expected[1]);

        // ... and re-read repeatedly.
        let fields_one = (row_one.get(0).unwrap(), row_one.get(1).unwrap());
        let values_one = (fields_one.0.value(), fields_one.1.value());

        assert_values!(values_one, expected[0]);
    }

    #[cfg(feature = "returning_clauses_for_sqlite_3_35")]
    crate::define_sql_function! {fn sleep(a: diesel::sql_types::Integer) -> diesel::sql_types::Integer}

    /// A `RETURNING` insert that hits a database lock must report the error on
    /// the first `next` call and then end the iteration.
    #[test]
    #[cfg(feature = "returning_clauses_for_sqlite_3_35")]
    #[allow(clippy::cast_sign_loss)]
    fn parallel_iter_with_error() {
        use crate::connection::Connection;
        use crate::connection::LoadConnection;
        use crate::connection::SimpleConnection;
        use crate::expression_methods::ExpressionMethods;
        use crate::SqliteConnection;
        use std::sync::{Arc, Barrier};
        use std::time::Duration;

        let temp_dir = tempfile::tempdir().unwrap();
        let db_path = format!("{}/test.db", temp_dir.path().display());
        let mut writer = SqliteConnection::establish(&db_path).unwrap();
        let mut locker = SqliteConnection::establish(&db_path).unwrap();

        crate::table! {
            users {
                id -> Integer,
                name -> Text,
            }
        }

        writer
            .batch_execute("CREATE TABLE users(id INTEGER NOT NULL PRIMARY KEY, name TEXT)")
            .unwrap();

        let main_gate = Arc::new(Barrier::new(2));
        let sql_gate = Arc::clone(&main_gate);

        // The SQL `sleep` first releases the main thread through the barrier,
        // then blocks for `secs` seconds inside the locking transaction.
        sleep_utils::register_impl(&mut locker, move |secs: i32| {
            sql_gate.wait();
            std::thread::sleep(Duration::from_secs(secs as u64));
            secs
        })
        .unwrap();

        // Background thread: holds an immediate (write-locking) transaction open.
        let background = std::thread::spawn(move || {
            use crate::query_dsl::RunQueryDsl;

            locker
                .immediate_transaction(|conn| diesel::select(sleep(1)).execute(conn))
                .unwrap();
        });
        main_gate.wait();

        // This insert also needs the lock, so it must fail.
        let mut results = writer
            .load(
                diesel::insert_into(users::table)
                    .values((users::id.eq(1), users::name.eq("John")))
                    .returning(users::id),
            )
            .unwrap();

        let first = results.next().unwrap();
        assert!(first.is_err());

        let after_error = results.next();
        assert!(after_error.is_none());

        background.join().unwrap();
    }
}