diesel/mysql/connection/stmt/
iterator.rs

1#![allow(unsafe_code)] // module uses ffi
2use std::cell::{Ref, RefCell};
3use std::rc::Rc;
4
5use super::{OutputBinds, Statement, StatementMetadata, StatementUse};
6use crate::backend::Backend;
7use crate::connection::statement_cache::MaybeCached;
8use crate::mysql::{Mysql, MysqlType};
9use crate::result::QueryResult;
10use crate::row::*;
11
/// Iterator over the rows produced by an executed MySQL statement.
///
/// Created by `StatementIterator::from_stmt`. Each successful call to `next`
/// asks the statement to fill the row buffers and hands them out as a
/// `MysqlRow`.
12#[allow(missing_debug_implementations)]
13pub struct StatementIterator<'a> {
    /// The executed statement that rows are fetched from.
14    stmt: StatementUse<'a>,
    /// Bind buffers of the most recently fetched row. The `Rc` is cloned into
    /// every `MysqlRow` handed out for that row, so `next` has to check
    /// whether it is still the sole owner before reusing the buffers.
15    last_row: Rc<RefCell<PrivateMysqlRow>>,
    /// Statement metadata, shared with every yielded `MysqlRow`.
16    metadata: Rc<StatementMetadata>,
    /// Remaining item count: initialised from `result_size()` in `from_stmt`
    /// and decremented (saturating at zero) by `next` whenever a fetch
    /// produced a row or an error.
17    len: usize,
18}
19
20impl<'a> StatementIterator<'a> {
21    pub fn from_stmt(
22        stmt: MaybeCached<'a, Statement>,
23        types: &[Option<MysqlType>],
24    ) -> QueryResult<Self> {
25        let metadata = stmt.metadata()?;
26
27        let mut output_binds = OutputBinds::from_output_types(types, &metadata);
28
29        let mut stmt = stmt.execute_statement(&mut output_binds)?;
30        let size = unsafe { stmt.result_size() }?;
31
32        Ok(StatementIterator {
33            metadata: Rc::new(metadata),
34            last_row: Rc::new(RefCell::new(PrivateMysqlRow::Direct(output_binds))),
35            len: size,
36            stmt,
37        })
38    }
39}
40
41impl Iterator for StatementIterator<'_> {
42    type Item = QueryResult<MysqlRow>;
43
44    fn next(&mut self) -> Option<Self::Item> {
45        // check if we own the only instance of the bind buffer
46        // if that's the case we can reuse the underlying allocations
47        // if that's not the case, we need to copy the output bind buffers
48        // to somewhere else
49        let res = if let Some(binds) = Rc::get_mut(&mut self.last_row) {
50            if let PrivateMysqlRow::Direct(ref mut binds) = RefCell::get_mut(binds) {
51                self.stmt.populate_row_buffers(binds)
52            } else {
53                // any other state than `PrivateMysqlRow::Direct` is invalid here
54                // and should not happen. If this ever happens this is a logic error
55                // in the code above
56                unreachable!(
57                    "You've reached an impossible internal state. \
58                     If you ever see this error message please open \
59                     an issue at https://github.com/diesel-rs/diesel \
60                     providing example code how to trigger this error."
61                )
62            }
63        } else {
64            // The shared bind buffer is in use by someone else,
65            // this means we copy out the values and replace the used reference
66            // by the copied values. After this we can advance the statement
67            // another step
68            let mut last_row = {
69                let mut last_row = match self.last_row.try_borrow_mut() {
70                    Ok(o) => o,
71                    Err(_e) => {
72                        return Some(Err(crate::result::Error::DeserializationError(
73                            "Failed to reborrow row. Try to release any `MysqlField` or `MysqlValue` \
74                             that exists at this point"
75                                .into(),
76                        )));
77                    }
78                };
79                let last_row = &mut *last_row;
80                let duplicated = last_row.duplicate();
81                std::mem::replace(last_row, duplicated)
82            };
83            let res = if let PrivateMysqlRow::Direct(ref mut binds) = last_row {
84                self.stmt.populate_row_buffers(binds)
85            } else {
86                // any other state than `PrivateMysqlRow::Direct` is invalid here
87                // and should not happen. If this ever happens this is a logic error
88                // in the code above
89                unreachable!(
90                    "You've reached an impossible internal state. \
91                     If you ever see this error message please open \
92                     an issue at https://github.com/diesel-rs/diesel \
93                     providing example code how to trigger this error."
94                )
95            };
96            self.last_row = Rc::new(RefCell::new(last_row));
97            res
98        };
99
100        match res {
101            Ok(Some(())) => {
102                self.len = self.len.saturating_sub(1);
103                Some(Ok(MysqlRow {
104                    metadata: self.metadata.clone(),
105                    row: self.last_row.clone(),
106                }))
107            }
108            Ok(None) => None,
109            Err(e) => {
110                self.len = self.len.saturating_sub(1);
111                Some(Err(e))
112            }
113        }
114    }
115
116    fn size_hint(&self) -> (usize, Option<usize>) {
117        (self.len(), Some(self.len()))
118    }
119
120    fn count(self) -> usize
121    where
122        Self: Sized,
123    {
124        self.len()
125    }
126}
127
// The remaining item count is tracked in the `len` field (set from the
// statement's result size on construction and decremented by `next`), so the
// length can be reported without consuming the iterator.
128impl ExactSizeIterator for StatementIterator<'_> {
129    fn len(&self) -> usize {
130        self.len
131    }
132}
133
/// A single result row yielded by `StatementIterator`.
///
/// Both fields are `Rc`s, so cloning a row only bumps reference counts; the
/// clones share the same bind buffers and metadata.
134#[derive(Clone)]
135#[allow(missing_debug_implementations)]
136pub struct MysqlRow {
    /// Bind buffers holding this row's values, shared with the iterator that
    /// produced the row.
137    row: Rc<RefCell<PrivateMysqlRow>>,
    /// Statement metadata, used for the field count and field names.
138    metadata: Rc<StatementMetadata>,
139}
140
/// Storage state of the bind buffers behind a `MysqlRow`.
141enum PrivateMysqlRow {
    /// Buffers that are handed to the statement to be filled. The iterator
    /// expects to find this state whenever it fetches the next row.
142    Direct(OutputBinds),
    /// A clone of the buffers produced by `duplicate()`. The iterator leaves
    /// it behind for rows that are still alive when it takes the original
    /// buffers back.
143    Copied(OutputBinds),
144}
145
146impl PrivateMysqlRow {
147    fn duplicate(&self) -> Self {
148        match self {
149            Self::Copied(b) | Self::Direct(b) => Self::Copied(b.clone()),
150        }
151    }
152}
153
// Empty marker impl. NOTE(review): presumably `RowSealed` is the sealing
// trait that `Row` implementors must also implement — confirm in `crate::row`.
154impl RowSealed for MysqlRow {}
155
156impl<'a> Row<'a, Mysql> for MysqlRow {
157    type Field<'f>
158        = MysqlField<'f>
159    where
160        'a: 'f,
161        Self: 'f;
162    type InnerPartialRow = Self;
163
164    fn field_count(&self) -> usize {
165        self.metadata.fields().len()
166    }
167
168    fn get<'b, I>(&'b self, idx: I) -> Option<Self::Field<'b>>
169    where
170        'a: 'b,
171        Self: RowIndex<I>,
172    {
173        let idx = self.idx(idx)?;
174        Some(MysqlField {
175            binds: self.row.borrow(),
176            metadata: self.metadata.clone(),
177            idx,
178        })
179    }
180
181    fn partial_row(&self, range: std::ops::Range<usize>) -> PartialRow<'_, Self::InnerPartialRow> {
182        PartialRow::new(self, range)
183    }
184}
185
186impl RowIndex<usize> for MysqlRow {
187    fn idx(&self, idx: usize) -> Option<usize> {
188        if idx < self.field_count() {
189            Some(idx)
190        } else {
191            None
192        }
193    }
194}
195
196impl<'a> RowIndex<&'a str> for MysqlRow {
197    fn idx(&self, idx: &'a str) -> Option<usize> {
198        self.metadata
199            .fields()
200            .iter()
201            .enumerate()
202            .find(|(_, field_meta)| field_meta.field_name() == Some(idx))
203            .map(|(idx, _)| idx)
204    }
205}
206
/// A single field of a `MysqlRow`, as returned by `Row::get`.
207#[allow(missing_debug_implementations)]
208pub struct MysqlField<'a> {
    /// Shared borrow of the row's bind buffers. While it is alive the
    /// `RefCell` cannot be borrowed mutably, which is what makes the
    /// iterator's `try_borrow_mut` fail.
209    binds: Ref<'a, PrivateMysqlRow>,
    /// Statement metadata, used to look up the field name.
210    metadata: Rc<StatementMetadata>,
    /// Position of this field within the row.
211    idx: usize,
212}
213
214impl<'a> Field<'a, Mysql> for MysqlField<'a> {
215    fn field_name(&self) -> Option<&str> {
216        self.metadata.fields()[self.idx].field_name()
217    }
218
219    fn is_null(&self) -> bool {
220        match &*self.binds {
221            PrivateMysqlRow::Copied(b) | PrivateMysqlRow::Direct(b) => b[self.idx].is_null(),
222        }
223    }
224
225    fn value(&self) -> Option<<Mysql as Backend>::RawValue<'_>> {
226        match &*self.binds {
227            PrivateMysqlRow::Copied(b) | PrivateMysqlRow::Direct(b) => b[self.idx].value(),
228        }
229    }
230}
231
// End-to-end test of the borrowing rules between the row iterator, the rows
// it yields and the fields/values obtained from those rows. It covers three
// usage patterns: streaming, collecting, and holding on to fields while
// advancing the iterator by hand.
232#[cfg(test)]
233#[diesel_test_helper::test]
234#[allow(clippy::drop_non_drop)] // we want to explicitly extend lifetimes here
235fn fun_with_row_iters() {
236    crate::table! {
237        users(id) {
238            id -> Integer,
239            name -> Text,
240        }
241    }
242
243    use crate::connection::LoadConnection;
244    use crate::deserialize::{FromSql, FromSqlRow};
245    use crate::prelude::*;
246    use crate::row::{Field, Row};
247    use crate::sql_types;
248
249    let conn = &mut crate::test_helpers::connection();
250
    // Fixture: a temporary `users` table holding exactly two rows.
251    crate::sql_query(
252        "CREATE TEMPORARY TABLE IF NOT EXISTS users(id INTEGER PRIMARY KEY, name TEXT NOT NULL);",
253    )
254    .execute(conn)
255    .unwrap();
256
257    crate::insert_into(users::table)
258        .values(vec![
259            (users::id.eq(1), users::name.eq("Sean")),
260            (users::id.eq(2), users::name.eq("Tess")),
261        ])
262        .execute(conn)
263        .unwrap();
264
265    let query = users::table.select((users::id, users::name));
266
267    let expected = vec![(1, String::from("Sean")), (2, String::from("Tess"))];
268
    // Pattern 1 (streaming): every row is deserialized and dropped before
    // the next one is requested.
269    {
270        let row_iter = conn.load(query).unwrap();
271        for (row, expected) in row_iter.zip(&expected) {
272            let row = row.unwrap();
273
274            let deserialized = <(i32, String) as FromSqlRow<
275                (sql_types::Integer, sql_types::Text),
276                _,
277            >>::build_from_row(&row)
278            .unwrap();
279
280            assert_eq!(&deserialized, expected);
281        }
282    }
283
    // Pattern 2 (collecting): all rows are kept alive at the same time and
    // only deserialized afterwards, so each collected row must still hold
    // its own values.
284    {
285        let collected_rows = conn.load(query).unwrap().collect::<Vec<_>>();
286        assert_eq!(collected_rows.len(), 2);
287        for (row, expected) in collected_rows.iter().zip(&expected) {
288            let deserialized = row
289                .as_ref()
290                .map(|row| {
291                    <(i32, String) as FromSqlRow<
292                            (sql_types::Integer, sql_types::Text),
293                        _,
294                        >>::build_from_row(row).unwrap()
295                })
296                .unwrap();
297            assert_eq!(&deserialized, expected);
298        }
299    }
300
    // Pattern 3 (manual): keep fields and values of a row alive while trying
    // to advance the iterator.
301    let mut row_iter = conn.load(query).unwrap();
302
303    let first_row = row_iter.next().unwrap().unwrap();
304    let first_fields = (
305        Row::get(&first_row, 0).unwrap(),
306        Row::get(&first_row, 1).unwrap(),
307    );
308    let first_values = (first_fields.0.value(), first_fields.1.value());
309
    // The fields still borrow the row buffer, so the iterator reports an
    // error instead of fetching the next row.
310    assert!(row_iter.next().unwrap().is_err());
    // Dropping only the values is not enough: the fields hold the borrow.
311    std::mem::drop(first_values);
312    assert!(row_iter.next().unwrap().is_err());
313    std::mem::drop(first_fields);
314
    // With the fields released the iterator can advance. The failed attempts
    // above did not skip a row: this is still the second row.
315    let second_row = row_iter.next().unwrap().unwrap();
316    let second_fields = (
317        Row::get(&second_row, 0).unwrap(),
318        Row::get(&second_row, 1).unwrap(),
319    );
320    let second_values = (second_fields.0.value(), second_fields.1.value());
321
    // Same borrowing rule for the second row.
322    assert!(row_iter.next().unwrap().is_err());
323    std::mem::drop(second_values);
324    assert!(row_iter.next().unwrap().is_err());
325    std::mem::drop(second_fields);
326
    // Both rows have been yielded; the iterator is exhausted.
327    assert!(row_iter.next().is_none());
328
    // Rows obtained earlier stay readable after the iterator has moved past
    // them, and each one must still report its own values.
329    let first_fields = (
330        Row::get(&first_row, 0).unwrap(),
331        Row::get(&first_row, 1).unwrap(),
332    );
333    let second_fields = (
334        Row::get(&second_row, 0).unwrap(),
335        Row::get(&second_row, 1).unwrap(),
336    );
337
338    let first_values = (first_fields.0.value(), first_fields.1.value());
339    let second_values = (second_fields.0.value(), second_fields.1.value());
340
341    assert_eq!(
342        <i32 as FromSql<sql_types::Integer, Mysql>>::from_nullable_sql(first_values.0).unwrap(),
343        expected[0].0
344    );
345    assert_eq!(
346        <String as FromSql<sql_types::Text, Mysql>>::from_nullable_sql(first_values.1).unwrap(),
347        expected[0].1
348    );
349
350    assert_eq!(
351        <i32 as FromSql<sql_types::Integer, Mysql>>::from_nullable_sql(second_values.0).unwrap(),
352        expected[1].0
353    );
354    assert_eq!(
355        <String as FromSql<sql_types::Text, Mysql>>::from_nullable_sql(second_values.1).unwrap(),
356        expected[1].1
357    );
358
    // Fields can be requested from the same row more than once; a second
    // read of the first row yields the same values again.
359    let first_fields = (
360        Row::get(&first_row, 0).unwrap(),
361        Row::get(&first_row, 1).unwrap(),
362    );
363    let first_values = (first_fields.0.value(), first_fields.1.value());
364
365    assert_eq!(
366        <i32 as FromSql<sql_types::Integer, Mysql>>::from_nullable_sql(first_values.0).unwrap(),
367        expected[0].0
368    );
369    assert_eq!(
370        <String as FromSql<sql_types::Text, Mysql>>::from_nullable_sql(first_values.1).unwrap(),
371        expected[0].1
372    );
373}