diesel/mysql/connection/stmt/iterator.rs

1#![allow(unsafe_code)] // module uses ffi
2use std::cell::{Ref, RefCell};
3use std::rc::Rc;
4
5use super::{OutputBinds, Statement, StatementMetadata, StatementUse};
6use crate::backend::Backend;
7use crate::connection::statement_cache::MaybeCached;
8use crate::mysql::{Mysql, MysqlType};
9use crate::result::QueryResult;
10use crate::row::*;
11
/// Iterator over the rows produced by an executed MySQL prepared statement.
///
/// Every row is fetched into one set of output bind buffers (`last_row`), which is also
/// shared with the `MysqlRow` returned for it. `next` reuses that buffer in place when no
/// earlier row still references it, and otherwise leaves a copy behind for the old row.
#[allow(missing_debug_implementations)]
pub struct StatementIterator<'a> {
    // The executed statement that rows are fetched from.
    // NOTE(review): fields drop in declaration order, so `stmt` is dropped before the bind
    // buffers in `last_row`; confirm whether that order is required before reordering.
    stmt: StatementUse<'a>,
    // Output bind buffers of the most recently fetched row, shared with returned `MysqlRow`s.
    last_row: Rc<RefCell<PrivateMysqlRow>>,
    // Column metadata, shared with every returned `MysqlRow`.
    metadata: Rc<StatementMetadata>,
    // Remaining length: starts at `result_size()` and is decremented by `next`.
    len: usize,
}
19
20impl<'a> StatementIterator<'a> {
21    pub fn from_stmt(
22        stmt: MaybeCached<'a, Statement>,
23        types: &[Option<MysqlType>],
24    ) -> QueryResult<Self> {
25        let metadata = stmt.metadata()?;
26
27        let mut output_binds = OutputBinds::from_output_types(types, &metadata);
28
29        let mut stmt = stmt.execute_statement(&mut output_binds)?;
30        let size = unsafe { stmt.result_size() }?;
31
32        Ok(StatementIterator {
33            metadata: Rc::new(metadata),
34            last_row: Rc::new(RefCell::new(PrivateMysqlRow::Direct(output_binds))),
35            len: size,
36            stmt,
37        })
38    }
39}
40
41impl Iterator for StatementIterator<'_> {
42    type Item = QueryResult<MysqlRow>;
43
44    fn next(&mut self) -> Option<Self::Item> {
45        // check if we own the only instance of the bind buffer
46        // if that's the case we can reuse the underlying allocations
47        // if that's not the case, we need to copy the output bind buffers
48        // to somewhere else
49        let res = if let Some(binds) = Rc::get_mut(&mut self.last_row) {
50            if let PrivateMysqlRow::Direct(ref mut binds) = RefCell::get_mut(binds) {
51                self.stmt.populate_row_buffers(binds)
52            } else {
53                // any other state than `PrivateMysqlRow::Direct` is invalid here
54                // and should not happen. If this ever happens this is a logic error
55                // in the code above
56                unreachable!(
57                    "You've reached an impossible internal state. \
58                     If you ever see this error message please open \
59                     an issue at https://github.com/diesel-rs/diesel \
60                     providing example code how to trigger this error."
61                )
62            }
63        } else {
64            // The shared bind buffer is in use by someone else,
65            // this means we copy out the values and replace the used reference
66            // by the copied values. After this we can advance the statement
67            // another step
68            let mut last_row = {
69                let mut last_row = match self.last_row.try_borrow_mut() {
70                    Ok(o) => o,
71                    Err(_e) => {
72                        return Some(Err(crate::result::Error::DeserializationError(
73                            "Failed to reborrow row. Try to release any `MysqlField` or `MysqlValue` \
74                             that exists at this point"
75                                .into(),
76                        )));
77                    }
78                };
79                let last_row = &mut *last_row;
80                let duplicated = last_row.duplicate();
81                std::mem::replace(last_row, duplicated)
82            };
83            let res = if let PrivateMysqlRow::Direct(ref mut binds) = last_row {
84                self.stmt.populate_row_buffers(binds)
85            } else {
86                // any other state than `PrivateMysqlRow::Direct` is invalid here
87                // and should not happen. If this ever happens this is a logic error
88                // in the code above
89                unreachable!(
90                    "You've reached an impossible internal state. \
91                     If you ever see this error message please open \
92                     an issue at https://github.com/diesel-rs/diesel \
93                     providing example code how to trigger this error."
94                )
95            };
96            self.last_row = Rc::new(RefCell::new(last_row));
97            res
98        };
99
100        match res {
101            Ok(Some(())) => {
102                self.len = self.len.saturating_sub(1);
103                Some(Ok(MysqlRow {
104                    metadata: self.metadata.clone(),
105                    row: self.last_row.clone(),
106                }))
107            }
108            Ok(None) => None,
109            Err(e) => {
110                self.len = self.len.saturating_sub(1);
111                Some(Err(e))
112            }
113        }
114    }
115
116    fn size_hint(&self) -> (usize, Option<usize>) {
117        (self.len(), Some(self.len()))
118    }
119
120    fn count(self) -> usize
121    where
122        Self: Sized,
123    {
124        self.len()
125    }
126}
127
/// The length is the row count reported by `result_size()` when the iterator was created,
/// minus one for every row (or row error) yielded since.
impl ExactSizeIterator for StatementIterator<'_> {
    fn len(&self) -> usize {
        self.len
    }
}
133
/// A single row yielded by `StatementIterator`.
///
/// Cloning only bumps reference counts: clones share the same bind buffers and metadata.
#[derive(Clone)]
#[allow(missing_debug_implementations)]
pub struct MysqlRow {
    // Bind buffers holding this row's values; may still be shared with the iterator.
    row: Rc<RefCell<PrivateMysqlRow>>,
    // Column metadata of the statement that produced this row.
    metadata: Rc<StatementMetadata>,
}
140
/// Storage behind a `MysqlRow`.
enum PrivateMysqlRow {
    /// The bind buffers the statement fetches into; the only variant the iterator populates.
    Direct(OutputBinds),
    /// A copy of the bind buffers made by `duplicate`, left behind for a row that is still
    /// alive when the iterator needs its buffers for the next fetch.
    Copied(OutputBinds),
}
145
146impl PrivateMysqlRow {
147    fn duplicate(&self) -> Self {
148        match self {
149            Self::Copied(b) | Self::Direct(b) => Self::Copied(b.clone()),
150        }
151    }
152}
153
// Marker impl of `RowSealed` (from `crate::row`); presumably required before `Row` can be
// implemented for `MysqlRow` — confirm against the trait definition.
impl RowSealed for MysqlRow {}
155
156impl<'a> Row<'a, Mysql> for MysqlRow {
157    type Field<'f>
158        = MysqlField<'f>
159    where
160        'a: 'f,
161        Self: 'f;
162    type InnerPartialRow = Self;
163
164    fn field_count(&self) -> usize {
165        self.metadata.fields().len()
166    }
167
168    fn get<'b, I>(&'b self, idx: I) -> Option<Self::Field<'b>>
169    where
170        'a: 'b,
171        Self: RowIndex<I>,
172    {
173        let idx = self.idx(idx)?;
174        Some(MysqlField {
175            binds: self.row.borrow(),
176            metadata: self.metadata.clone(),
177            idx,
178        })
179    }
180
181    fn partial_row(&self, range: std::ops::Range<usize>) -> PartialRow<'_, Self::InnerPartialRow> {
182        PartialRow::new(self, range)
183    }
184}
185
186impl RowIndex<usize> for MysqlRow {
187    fn idx(&self, idx: usize) -> Option<usize> {
188        if idx < self.field_count() {
189            Some(idx)
190        } else {
191            None
192        }
193    }
194}
195
196impl<'a> RowIndex<&'a str> for MysqlRow {
197    fn idx(&self, idx: &'a str) -> Option<usize> {
198        self.metadata
199            .fields()
200            .iter()
201            .enumerate()
202            .find(|(_, field_meta)| field_meta.field_name() == Some(idx))
203            .map(|(idx, _)| idx)
204    }
205}
206
/// A single column of a `MysqlRow`.
///
/// Holds a shared `Ref` borrow of the row's bind buffers while alive. As long as such a
/// borrow exists, the iterator's `try_borrow_mut` on the shared buffers fails, and `next`
/// returns an error instead of advancing.
#[allow(missing_debug_implementations)]
pub struct MysqlField<'a> {
    // Borrow of the row's bind buffers.
    binds: Ref<'a, PrivateMysqlRow>,
    // Column metadata of the statement.
    metadata: Rc<StatementMetadata>,
    // Index of this column within the row.
    idx: usize,
}
213
214impl<'a> Field<'a, Mysql> for MysqlField<'a> {
215    fn field_name(&self) -> Option<&str> {
216        self.metadata.fields()[self.idx].field_name()
217    }
218
219    fn is_null(&self) -> bool {
220        match &*self.binds {
221            PrivateMysqlRow::Copied(b) | PrivateMysqlRow::Direct(b) => b[self.idx].is_null(),
222        }
223    }
224
225    fn value(&self) -> Option<<Mysql as Backend>::RawValue<'_>> {
226        match &*self.binds {
227            PrivateMysqlRow::Copied(b) | PrivateMysqlRow::Direct(b) => b[self.idx].value(),
228        }
229    }
230}
231
/// Exercises bind-buffer sharing between `StatementIterator` and `MysqlRow`, using the
/// connection from `crate::test_helpers::connection()` (presumably a MySQL test database;
/// confirm).
///
/// It checks three things:
/// - rows can be deserialized while iterating;
/// - collected rows keep their own values;
/// - `next()` returns `Err` instead of advancing while `MysqlField`s are still alive.
#[test]
#[allow(clippy::drop_non_drop)] // we want to explicitly extend lifetimes here
fn fun_with_row_iters() {
    crate::table! {
        #[allow(unused_parens)]
        users(id) {
            id -> Integer,
            name -> Text,
        }
    }

    use crate::connection::LoadConnection;
    use crate::deserialize::{FromSql, FromSqlRow};
    use crate::prelude::*;
    use crate::row::{Field, Row};
    use crate::sql_types;

    let conn = &mut crate::test_helpers::connection();

    // Start from a known state: create the table if needed and remove any existing rows.
    crate::sql_query(
        "CREATE TABLE IF NOT EXISTS users(id INTEGER PRIMARY KEY, name TEXT NOT NULL);",
    )
    .execute(conn)
    .unwrap();
    crate::sql_query("DELETE FROM users;")
        .execute(conn)
        .unwrap();

    crate::insert_into(users::table)
        .values(vec![
            (users::id.eq(1), users::name.eq("Sean")),
            (users::id.eq(2), users::name.eq("Tess")),
        ])
        .execute(conn)
        .unwrap();

    let query = users::table.select((users::id, users::name));

    let expected = vec![(1, String::from("Sean")), (2, String::from("Tess"))];

    // Case 1: deserialize each row inside the loop. Each `row` is dropped at the end of its
    // iteration, before the next call to `next()`.
    {
        let row_iter = conn.load(query).unwrap();
        for (row, expected) in row_iter.zip(&expected) {
            let row = row.unwrap();

            let deserialized = <(i32, String) as FromSqlRow<
                (sql_types::Integer, sql_types::Text),
                _,
            >>::build_from_row(&row)
            .unwrap();

            assert_eq!(&deserialized, expected);
        }
    }

    // Case 2: collect all rows first. Earlier rows must still return their own values after
    // the iterator has fetched later ones.
    {
        let collected_rows = conn.load(query).unwrap().collect::<Vec<_>>();
        assert_eq!(collected_rows.len(), 2);
        for (row, expected) in collected_rows.iter().zip(&expected) {
            let deserialized = row
                .as_ref()
                .map(|row| {
                    <(i32, String) as FromSqlRow<
                            (sql_types::Integer, sql_types::Text),
                        _,
                        >>::build_from_row(row).unwrap()
                })
                .unwrap();
            assert_eq!(&deserialized, expected);
        }
    }

    // Case 3: keep fields and values alive across calls to `next()`.
    let mut row_iter = conn.load(query).unwrap();

    let first_row = row_iter.next().unwrap().unwrap();
    let first_fields = (
        Row::get(&first_row, 0).unwrap(),
        Row::get(&first_row, 1).unwrap(),
    );
    let first_values = (first_fields.0.value(), first_fields.1.value());

    // `first_fields` hold `Ref` borrows of the buffers shared with the iterator, so `next()`
    // cannot reborrow them and returns an error.
    assert!(row_iter.next().unwrap().is_err());
    std::mem::drop(first_values);
    // Dropping only the values is not enough: the fields still hold their borrows.
    assert!(row_iter.next().unwrap().is_err());
    std::mem::drop(first_fields);

    // With the borrows released, `next()` leaves a copy for `first_row` and fetches the
    // second row.
    let second_row = row_iter.next().unwrap().unwrap();
    let second_fields = (
        Row::get(&second_row, 0).unwrap(),
        Row::get(&second_row, 1).unwrap(),
    );
    let second_values = (second_fields.0.value(), second_fields.1.value());

    // Same checks for the second row.
    assert!(row_iter.next().unwrap().is_err());
    std::mem::drop(second_values);
    assert!(row_iter.next().unwrap().is_err());
    std::mem::drop(second_fields);

    // Only two rows were inserted, so the iterator is exhausted.
    assert!(row_iter.next().is_none());

    // Both rows still return their own values after iteration has finished.
    let first_fields = (
        Row::get(&first_row, 0).unwrap(),
        Row::get(&first_row, 1).unwrap(),
    );
    let second_fields = (
        Row::get(&second_row, 0).unwrap(),
        Row::get(&second_row, 1).unwrap(),
    );

    let first_values = (first_fields.0.value(), first_fields.1.value());
    let second_values = (second_fields.0.value(), second_fields.1.value());

    assert_eq!(
        <i32 as FromSql<sql_types::Integer, Mysql>>::from_nullable_sql(first_values.0).unwrap(),
        expected[0].0
    );
    assert_eq!(
        <String as FromSql<sql_types::Text, Mysql>>::from_nullable_sql(first_values.1).unwrap(),
        expected[0].1
    );

    assert_eq!(
        <i32 as FromSql<sql_types::Integer, Mysql>>::from_nullable_sql(second_values.0).unwrap(),
        expected[1].0
    );
    assert_eq!(
        <String as FromSql<sql_types::Text, Mysql>>::from_nullable_sql(second_values.1).unwrap(),
        expected[1].1
    );

    // Fields of `first_row` can be borrowed again and still return the first row's values.
    let first_fields = (
        Row::get(&first_row, 0).unwrap(),
        Row::get(&first_row, 1).unwrap(),
    );
    let first_values = (first_fields.0.value(), first_fields.1.value());

    assert_eq!(
        <i32 as FromSql<sql_types::Integer, Mysql>>::from_nullable_sql(first_values.0).unwrap(),
        expected[0].0
    );
    assert_eq!(
        <String as FromSql<sql_types::Text, Mysql>>::from_nullable_sql(first_values.1).unwrap(),
        expected[0].1
    );
}