// diesel/pg/connection/cursor.rs

1use super::raw::RawConnection;
2use super::result::PgResult;
3use super::row::PgRow;
4use crate::pg::Pg;
5use crate::query_builder::QueryFragment;
6use std::rc::Rc;
7
8#[allow(missing_debug_implementations)]
9pub struct Cursor {
10    current_row: usize,
11    db_result: Rc<PgResult>,
12}
13
14impl Cursor {
15    pub(super) fn new(result: PgResult, conn: &mut RawConnection) -> crate::QueryResult<Cursor> {
16        let next_res = conn.get_next_result()?;
17        debug_assert!(next_res.is_none());
18        Ok(Self {
19            current_row: 0,
20            db_result: Rc::new(result),
21        })
22    }
23}
24
25impl ExactSizeIterator for Cursor {
26    fn len(&self) -> usize {
27        self.db_result.num_rows() - self.current_row
28    }
29}
30
31impl Iterator for Cursor {
32    type Item = crate::QueryResult<PgRow>;
33
34    fn next(&mut self) -> Option<Self::Item> {
35        if self.current_row < self.db_result.num_rows() {
36            let row = self.db_result.clone().get_row(self.current_row);
37            self.current_row += 1;
38            Some(Ok(row))
39        } else {
40            None
41        }
42    }
43
44    fn nth(&mut self, n: usize) -> Option<Self::Item> {
45        self.current_row = (self.current_row + n).min(self.db_result.num_rows());
46        self.next()
47    }
48
49    fn size_hint(&self) -> (usize, Option<usize>) {
50        let len = self.len();
51        (len, Some(len))
52    }
53
54    fn count(self) -> usize
55    where
56        Self: Sized,
57    {
58        self.len()
59    }
60}
61
62/// The type returned by various [`Connection`] methods.
63/// Acts as an iterator over `T`.
64#[allow(missing_debug_implementations)]
65pub struct RowByRowCursor<'conn, 'query> {
66    first_row: bool,
67    db_result: Rc<PgResult>,
68    conn: &'conn mut super::ConnectionAndTransactionManager,
69    query: Box<dyn QueryFragment<Pg> + 'query>,
70}
71
72impl<'conn, 'query> RowByRowCursor<'conn, 'query> {
73    pub(super) fn new(
74        db_result: PgResult,
75        conn: &'conn mut super::ConnectionAndTransactionManager,
76        query: Box<dyn QueryFragment<Pg> + 'query>,
77    ) -> Self {
78        RowByRowCursor {
79            first_row: true,
80            db_result: Rc::new(db_result),
81            conn,
82            query,
83        }
84    }
85}
86
87impl Iterator for RowByRowCursor<'_, '_> {
88    type Item = crate::QueryResult<PgRow>;
89
90    fn next(&mut self) -> Option<Self::Item> {
91        if !self.first_row {
92            let get_next_result = super::update_transaction_manager_status(
93                self.conn.raw_connection.get_next_result(),
94                self.conn,
95                &crate::debug_query(&self.query),
96                false,
97            );
98            match get_next_result {
99                Ok(Some(res)) => {
100                    // we try to reuse the existing allocation here
101                    if let Some(old_res) = Rc::get_mut(&mut self.db_result) {
102                        *old_res = res;
103                    } else {
104                        self.db_result = Rc::new(res);
105                    }
106                }
107                Ok(None) => {
108                    return None;
109                }
110                Err(e) => return Some(Err(e)),
111            }
112        }
113        // This contains either 1 (for a row containing data) or 0 (for the last one) rows
114        if self.db_result.num_rows() > 0 {
115            debug_assert_eq!(self.db_result.num_rows(), 1);
116            self.first_row = false;
117            Some(Ok(self.db_result.clone().get_row(0)))
118        } else {
119            None
120        }
121    }
122}
123
124impl Drop for RowByRowCursor<'_, '_> {
125    fn drop(&mut self) {
126        loop {
127            let res = super::update_transaction_manager_status(
128                self.conn.raw_connection.get_next_result(),
129                self.conn,
130                &crate::debug_query(&self.query),
131                false,
132            );
133            if matches!(res, Err(_) | Ok(None)) {
134                // the error case is handled in update_transaction_manager_status
135                if res.is_ok() {
136                    self.conn.instrumentation.on_connection_event(
137                        crate::connection::InstrumentationEvent::FinishQuery {
138                            query: &crate::debug_query(&self.query),
139                            error: None,
140                        },
141                    );
142                }
143                break;
144            }
145        }
146    }
147}
148
#[cfg(test)]
mod tests {
    // Tests for loading rows through both PostgreSQL loading modes used here:
    // `DefaultLoadingMode` and `PgRowByRowLoadingMode`.
    // NOTE(review): presumably these are backed by `Cursor` and `RowByRowCursor`
    // respectively; that mapping lives in the `LoadConnection` impls, not in this file.
    use crate::connection::DefaultLoadingMode;
    use crate::pg::PgRowByRowLoadingMode;

    // Loads a two-row `users` table through `DefaultLoadingMode` in three ways:
    // deserializing while iterating, deserializing after collecting every row, and
    // holding field values of an earlier row while the cursor keeps advancing.
    #[diesel_test_helper::test]
    fn fun_with_row_iters() {
        crate::table! {
            #[allow(unused_parens)]
            users(id) {
                id -> Integer,
                name -> Text,
            }
        }

        use crate::connection::LoadConnection;
        use crate::deserialize::{FromSql, FromSqlRow};
        use crate::pg::Pg;
        use crate::prelude::*;
        use crate::row::{Field, Row};
        use crate::sql_types;

        let conn = &mut crate::test_helpers::connection();

        crate::sql_query(
            "CREATE TABLE IF NOT EXISTS users(id INTEGER PRIMARY KEY, name TEXT NOT NULL);",
        )
        .execute(conn)
        .unwrap();

        crate::insert_into(users::table)
            .values(vec![
                (users::id.eq(1), users::name.eq("Sean")),
                (users::id.eq(2), users::name.eq("Tess")),
            ])
            .execute(conn)
            .unwrap();

        let query = users::table.select((users::id, users::name));

        let expected = vec![(1, String::from("Sean")), (2, String::from("Tess"))];

        // Deserialize each row as soon as the cursor yields it; `zip` stops at
        // whichever of the two sequences runs out first.
        let row_iter = LoadConnection::<DefaultLoadingMode>::load(conn, query).unwrap();
        for (row, expected) in row_iter.zip(&expected) {
            let row = row.unwrap();

            let deserialized = <(i32, String) as FromSqlRow<
                (sql_types::Integer, sql_types::Text),
                _,
            >>::build_from_row(&row)
            .unwrap();

            assert_eq!(&deserialized, expected);
        }

        // Collect first, then deserialize: the rows must stay readable after the
        // cursor that produced them has been consumed by `collect`.
        {
            let collected_rows = LoadConnection::<DefaultLoadingMode>::load(conn, query)
                .unwrap()
                .collect::<Vec<_>>();

            for (row, expected) in collected_rows.iter().zip(&expected) {
                let deserialized = row
                    .as_ref()
                    .map(|row| {
                        <(i32, String) as FromSqlRow<
                                (sql_types::Integer, sql_types::Text),
                            _,
                            >>::build_from_row(row).unwrap()
                    })
                    .unwrap();

                assert_eq!(&deserialized, expected);
            }
        }

        // Statement order below is deliberate. Field values of `first_row` are taken
        // before the cursor advances. They are only decoded after `second_row` has been
        // fetched and the cursor has reported the end of the result.
        let mut row_iter = LoadConnection::<DefaultLoadingMode>::load(conn, query).unwrap();

        let first_row = row_iter.next().unwrap().unwrap();
        let first_fields = (first_row.get(0).unwrap(), first_row.get(1).unwrap());
        let first_values = (first_fields.0.value(), first_fields.1.value());

        let second_row = row_iter.next().unwrap().unwrap();
        let second_fields = (second_row.get(0).unwrap(), second_row.get(1).unwrap());
        let second_values = (second_fields.0.value(), second_fields.1.value());

        // Exactly two rows were inserted, so the cursor must be exhausted now.
        assert!(row_iter.next().is_none());

        assert_eq!(
            <i32 as FromSql<sql_types::Integer, Pg>>::from_nullable_sql(first_values.0).unwrap(),
            expected[0].0
        );
        assert_eq!(
            <String as FromSql<sql_types::Text, Pg>>::from_nullable_sql(first_values.1).unwrap(),
            expected[0].1
        );

        assert_eq!(
            <i32 as FromSql<sql_types::Integer, Pg>>::from_nullable_sql(second_values.0).unwrap(),
            expected[1].0
        );
        assert_eq!(
            <String as FromSql<sql_types::Text, Pg>>::from_nullable_sql(second_values.1).unwrap(),
            expected[1].1
        );

        // Re-read the fields of `first_row` once more now that the cursor is exhausted.
        let first_fields = (first_row.get(0).unwrap(), first_row.get(1).unwrap());
        let first_values = (first_fields.0.value(), first_fields.1.value());

        assert_eq!(
            <i32 as FromSql<sql_types::Integer, Pg>>::from_nullable_sql(first_values.0).unwrap(),
            expected[0].0
        );
        assert_eq!(
            <String as FromSql<sql_types::Text, Pg>>::from_nullable_sql(first_values.1).unwrap(),
            expected[0].1
        );
    }

    // Both loading modes must return the same `name` values for the same query.
    // The test also expects them in insertion order, although the query has no ORDER BY.
    #[diesel_test_helper::test]
    fn loading_modes_return_the_same_result() {
        use crate::prelude::*;

        crate::table! {
            #[allow(unused_parens)]
            users(id) {
                id -> Integer,
                name -> Text,
            }
        }

        let conn = &mut crate::test_helpers::connection();

        crate::sql_query(
            "CREATE TABLE IF NOT EXISTS users(id INTEGER PRIMARY KEY, name TEXT NOT NULL);",
        )
        .execute(conn)
        .unwrap();

        crate::insert_into(users::table)
            .values(vec![
                (users::id.eq(1), users::name.eq("Sean")),
                (users::id.eq(2), users::name.eq("Tess")),
            ])
            .execute(conn)
            .unwrap();

        let users_by_default_mode = users::table
            .select(users::name)
            .load_iter::<String, DefaultLoadingMode>(conn)
            .unwrap()
            .collect::<QueryResult<Vec<_>>>()
            .unwrap();
        let users_row_by_row = users::table
            .select(users::name)
            .load_iter::<String, PgRowByRowLoadingMode>(conn)
            .unwrap()
            .collect::<QueryResult<Vec<_>>>()
            .unwrap();
        assert_eq!(users_by_default_mode, users_row_by_row);
        assert_eq!(users_by_default_mode, vec!["Sean", "Tess"]);
    }

    // Same scenario as `fun_with_row_iters`, run through `PgRowByRowLoadingMode`.
    #[diesel_test_helper::test]
    fn fun_with_row_iters_row_by_row() {
        crate::table! {
            #[allow(unused_parens)]
            users(id) {
                id -> Integer,
                name -> Text,
            }
        }

        use crate::connection::LoadConnection;
        use crate::deserialize::{FromSql, FromSqlRow};
        use crate::pg::Pg;
        use crate::prelude::*;
        use crate::row::{Field, Row};
        use crate::sql_types;

        let conn = &mut crate::test_helpers::connection();

        crate::sql_query(
            "CREATE TABLE IF NOT EXISTS users(id INTEGER PRIMARY KEY, name TEXT NOT NULL);",
        )
        .execute(conn)
        .unwrap();

        crate::insert_into(users::table)
            .values(vec![
                (users::id.eq(1), users::name.eq("Sean")),
                (users::id.eq(2), users::name.eq("Tess")),
            ])
            .execute(conn)
            .unwrap();

        let query = users::table.select((users::id, users::name));

        let expected = vec![(1, String::from("Sean")), (2, String::from("Tess"))];

        // Deserialize each row as soon as the cursor yields it; `zip` stops at
        // whichever of the two sequences runs out first.
        let row_iter = LoadConnection::<PgRowByRowLoadingMode>::load(conn, query).unwrap();
        for (row, expected) in row_iter.zip(&expected) {
            let row = row.unwrap();

            let deserialized = <(i32, String) as FromSqlRow<
                (sql_types::Integer, sql_types::Text),
                _,
            >>::build_from_row(&row)
            .unwrap();

            assert_eq!(&deserialized, expected);
        }

        // Collect first, then deserialize: the rows must stay readable after the
        // cursor that produced them has been consumed by `collect`.
        {
            let collected_rows = LoadConnection::<PgRowByRowLoadingMode>::load(conn, query)
                .unwrap()
                .collect::<Vec<_>>();

            for (row, expected) in collected_rows.iter().zip(&expected) {
                let deserialized = row
                    .as_ref()
                    .map(|row| {
                        <(i32, String) as FromSqlRow<
                                (sql_types::Integer, sql_types::Text),
                            _,
                            >>::build_from_row(row).unwrap()
                    })
                    .unwrap();

                assert_eq!(&deserialized, expected);
            }
        }

        // Statement order below is deliberate. Field values of `first_row` are taken
        // before the cursor fetches the next result. They are only decoded after
        // `second_row` has been fetched and the cursor has reported the end of the result.
        let mut row_iter = LoadConnection::<PgRowByRowLoadingMode>::load(conn, query).unwrap();

        let first_row = row_iter.next().unwrap().unwrap();
        let first_fields = (first_row.get(0).unwrap(), first_row.get(1).unwrap());
        let first_values = (first_fields.0.value(), first_fields.1.value());

        let second_row = row_iter.next().unwrap().unwrap();
        let second_fields = (second_row.get(0).unwrap(), second_row.get(1).unwrap());
        let second_values = (second_fields.0.value(), second_fields.1.value());

        // Exactly two rows were inserted, so the cursor must be exhausted now.
        assert!(row_iter.next().is_none());

        assert_eq!(
            <i32 as FromSql<sql_types::Integer, Pg>>::from_nullable_sql(first_values.0).unwrap(),
            expected[0].0
        );
        assert_eq!(
            <String as FromSql<sql_types::Text, Pg>>::from_nullable_sql(first_values.1).unwrap(),
            expected[0].1
        );

        assert_eq!(
            <i32 as FromSql<sql_types::Integer, Pg>>::from_nullable_sql(second_values.0).unwrap(),
            expected[1].0
        );
        assert_eq!(
            <String as FromSql<sql_types::Text, Pg>>::from_nullable_sql(second_values.1).unwrap(),
            expected[1].1
        );

        // Re-read the fields of `first_row` once more now that the cursor is exhausted.
        let first_fields = (first_row.get(0).unwrap(), first_row.get(1).unwrap());
        let first_values = (first_fields.0.value(), first_fields.1.value());

        assert_eq!(
            <i32 as FromSql<sql_types::Integer, Pg>>::from_nullable_sql(first_values.0).unwrap(),
            expected[0].0
        );
        assert_eq!(
            <String as FromSql<sql_types::Text, Pg>>::from_nullable_sql(first_values.1).unwrap(),
            expected[0].1
        );
    }
}