diesel/pg/connection/
cursor.rs

1use super::raw::RawConnection;
2use super::result::PgResult;
3use super::row::PgRow;
4use crate::connection::Instrumentation;
5use crate::pg::Pg;
6use crate::query_builder::QueryFragment;
7use std::rc::Rc;
8
9#[allow(missing_debug_implementations)]
10pub struct Cursor {
11    current_row: usize,
12    db_result: Rc<PgResult>,
13}
14
15impl Cursor {
16    pub(super) fn new(result: PgResult, conn: &mut RawConnection) -> crate::QueryResult<Cursor> {
17        let next_res = conn.get_next_result()?;
18        debug_assert!(next_res.is_none());
19        Ok(Self {
20            current_row: 0,
21            db_result: Rc::new(result),
22        })
23    }
24}
25
26impl ExactSizeIterator for Cursor {
27    fn len(&self) -> usize {
28        self.db_result.num_rows() - self.current_row
29    }
30}
31
32impl Iterator for Cursor {
33    type Item = crate::QueryResult<PgRow>;
34
35    fn next(&mut self) -> Option<Self::Item> {
36        if self.current_row < self.db_result.num_rows() {
37            let row = self.db_result.clone().get_row(self.current_row);
38            self.current_row += 1;
39            Some(Ok(row))
40        } else {
41            None
42        }
43    }
44
45    fn nth(&mut self, n: usize) -> Option<Self::Item> {
46        self.current_row = (self.current_row + n).min(self.db_result.num_rows());
47        self.next()
48    }
49
50    fn size_hint(&self) -> (usize, Option<usize>) {
51        let len = self.len();
52        (len, Some(len))
53    }
54
55    fn count(self) -> usize
56    where
57        Self: Sized,
58    {
59        self.len()
60    }
61}
62
63/// The type returned by various [`Connection`] methods.
64/// Acts as an iterator over `T`.
65#[allow(missing_debug_implementations)]
66pub struct RowByRowCursor<'conn, 'query> {
67    first_row: bool,
68    db_result: Rc<PgResult>,
69    conn: &'conn mut super::ConnectionAndTransactionManager,
70    query: Box<dyn QueryFragment<Pg> + 'query>,
71}
72
73impl<'conn, 'query> RowByRowCursor<'conn, 'query> {
74    pub(super) fn new(
75        db_result: PgResult,
76        conn: &'conn mut super::ConnectionAndTransactionManager,
77        query: Box<dyn QueryFragment<Pg> + 'query>,
78    ) -> Self {
79        RowByRowCursor {
80            first_row: true,
81            db_result: Rc::new(db_result),
82            conn,
83            query,
84        }
85    }
86}
87
88impl Iterator for RowByRowCursor<'_, '_> {
89    type Item = crate::QueryResult<PgRow>;
90
91    fn next(&mut self) -> Option<Self::Item> {
92        if !self.first_row {
93            let get_next_result = super::update_transaction_manager_status(
94                self.conn.raw_connection.get_next_result(),
95                self.conn,
96                &crate::debug_query(&self.query),
97                false,
98            );
99            match get_next_result {
100                Ok(Some(res)) => {
101                    // we try to reuse the existing allocation here
102                    if let Some(old_res) = Rc::get_mut(&mut self.db_result) {
103                        *old_res = res;
104                    } else {
105                        self.db_result = Rc::new(res);
106                    }
107                }
108                Ok(None) => {
109                    return None;
110                }
111                Err(e) => return Some(Err(e)),
112            }
113        }
114        // This contains either 1 (for a row containing data) or 0 (for the last one) rows
115        if self.db_result.num_rows() > 0 {
116            debug_assert_eq!(self.db_result.num_rows(), 1);
117            self.first_row = false;
118            Some(Ok(self.db_result.clone().get_row(0)))
119        } else {
120            None
121        }
122    }
123}
124
125impl Drop for RowByRowCursor<'_, '_> {
126    fn drop(&mut self) {
127        loop {
128            let res = super::update_transaction_manager_status(
129                self.conn.raw_connection.get_next_result(),
130                self.conn,
131                &crate::debug_query(&self.query),
132                false,
133            );
134            if matches!(res, Err(_) | Ok(None)) {
135                // the error case is handled in update_transaction_manager_status
136                if res.is_ok() {
137                    self.conn.instrumentation.on_connection_event(
138                        crate::connection::InstrumentationEvent::FinishQuery {
139                            query: &crate::debug_query(&self.query),
140                            error: None,
141                        },
142                    );
143                }
144                break;
145            }
146        }
147    }
148}
149
#[cfg(test)]
mod tests {
    use crate::connection::DefaultLoadingMode;
    use crate::pg::PgRowByRowLoadingMode;

    /// Generates a test named `$test_name` that exercises row iteration using
    /// the loading mode `$mode`.
    ///
    /// The checks are the same for every loading mode, so the body exists
    /// once here instead of being copied per mode. Each generated test:
    ///
    /// 1. deserializes rows while iterating the cursor,
    /// 2. deserializes rows after collecting the cursor into a `Vec`,
    /// 3. reads raw field values from rows that are kept alive while the
    ///    cursor advances and after it is exhausted.
    macro_rules! row_iter_test {
        ($test_name:ident, $mode:ty) => {
            #[test]
            fn $test_name() {
                crate::table! {
                    #[allow(unused_parens)]
                    users(id) {
                        id -> Integer,
                        name -> Text,
                    }
                }

                use crate::connection::LoadConnection;
                use crate::deserialize::{FromSql, FromSqlRow};
                use crate::pg::Pg;
                use crate::prelude::*;
                use crate::row::{Field, Row};
                use crate::sql_types;

                let conn = &mut crate::test_helpers::connection();

                crate::sql_query(
                    "CREATE TABLE IF NOT EXISTS users(id INTEGER PRIMARY KEY, name TEXT NOT NULL);",
                )
                .execute(conn)
                .unwrap();

                crate::insert_into(users::table)
                    .values(vec![
                        (users::id.eq(1), users::name.eq("Sean")),
                        (users::id.eq(2), users::name.eq("Tess")),
                    ])
                    .execute(conn)
                    .unwrap();

                let query = users::table.select((users::id, users::name));

                let expected = vec![(1, String::from("Sean")), (2, String::from("Tess"))];

                // Deserialize each row while the cursor is being iterated.
                let row_iter = LoadConnection::<$mode>::load(conn, query).unwrap();
                for (row, expected) in row_iter.zip(&expected) {
                    let row = row.unwrap();

                    let deserialized = <(i32, String) as FromSqlRow<
                        (sql_types::Integer, sql_types::Text),
                        _,
                    >>::build_from_row(&row)
                    .unwrap();

                    assert_eq!(&deserialized, expected);
                }

                // Collect all rows first and deserialize them afterwards.
                {
                    let collected_rows = LoadConnection::<$mode>::load(conn, query)
                        .unwrap()
                        .collect::<Vec<_>>();

                    for (row, expected) in collected_rows.iter().zip(&expected) {
                        let deserialized = row
                            .as_ref()
                            .map(|row| {
                                <(i32, String) as FromSqlRow<
                                    (sql_types::Integer, sql_types::Text),
                                    _,
                                >>::build_from_row(row)
                                .unwrap()
                            })
                            .unwrap();

                        assert_eq!(&deserialized, expected);
                    }
                }

                // Keep rows and their field values around while the cursor
                // moves on to later rows.
                let mut row_iter = LoadConnection::<$mode>::load(conn, query).unwrap();

                let first_row = row_iter.next().unwrap().unwrap();
                let first_fields = (first_row.get(0).unwrap(), first_row.get(1).unwrap());
                let first_values = (first_fields.0.value(), first_fields.1.value());

                let second_row = row_iter.next().unwrap().unwrap();
                let second_fields = (second_row.get(0).unwrap(), second_row.get(1).unwrap());
                let second_values = (second_fields.0.value(), second_fields.1.value());

                assert!(row_iter.next().is_none());

                assert_eq!(
                    <i32 as FromSql<sql_types::Integer, Pg>>::from_nullable_sql(first_values.0)
                        .unwrap(),
                    expected[0].0
                );
                assert_eq!(
                    <String as FromSql<sql_types::Text, Pg>>::from_nullable_sql(first_values.1)
                        .unwrap(),
                    expected[0].1
                );

                assert_eq!(
                    <i32 as FromSql<sql_types::Integer, Pg>>::from_nullable_sql(second_values.0)
                        .unwrap(),
                    expected[1].0
                );
                assert_eq!(
                    <String as FromSql<sql_types::Text, Pg>>::from_nullable_sql(second_values.1)
                        .unwrap(),
                    expected[1].1
                );

                // The first row must still be readable after the cursor has
                // been exhausted.
                let first_fields = (first_row.get(0).unwrap(), first_row.get(1).unwrap());
                let first_values = (first_fields.0.value(), first_fields.1.value());

                assert_eq!(
                    <i32 as FromSql<sql_types::Integer, Pg>>::from_nullable_sql(first_values.0)
                        .unwrap(),
                    expected[0].0
                );
                assert_eq!(
                    <String as FromSql<sql_types::Text, Pg>>::from_nullable_sql(first_values.1)
                        .unwrap(),
                    expected[0].1
                );
            }
        };
    }

    row_iter_test!(fun_with_row_iters, DefaultLoadingMode);
    row_iter_test!(fun_with_row_iters_row_by_row, PgRowByRowLoadingMode);

    /// Loading the same query with the default mode and with the row-by-row
    /// mode must produce the same values in the same order.
    #[test]
    fn loading_modes_return_the_same_result() {
        use crate::prelude::*;

        crate::table! {
            #[allow(unused_parens)]
            users(id) {
                id -> Integer,
                name -> Text,
            }
        }

        let conn = &mut crate::test_helpers::connection();

        crate::sql_query(
            "CREATE TABLE IF NOT EXISTS users(id INTEGER PRIMARY KEY, name TEXT NOT NULL);",
        )
        .execute(conn)
        .unwrap();

        crate::insert_into(users::table)
            .values(vec![
                (users::id.eq(1), users::name.eq("Sean")),
                (users::id.eq(2), users::name.eq("Tess")),
            ])
            .execute(conn)
            .unwrap();

        let users_by_default_mode = users::table
            .select(users::name)
            .load_iter::<String, DefaultLoadingMode>(conn)
            .unwrap()
            .collect::<QueryResult<Vec<_>>>()
            .unwrap();
        let users_row_by_row = users::table
            .select(users::name)
            .load_iter::<String, PgRowByRowLoadingMode>(conn)
            .unwrap()
            .collect::<QueryResult<Vec<_>>>()
            .unwrap();
        assert_eq!(users_by_default_mode, users_row_by_row);
        assert_eq!(users_by_default_mode, vec!["Sean", "Tess"]);
    }
}