Skip to main content

diesel/mysql/connection/stmt/
iterator.rs

1#![allow(unsafe_code)] // module uses ffi
2use alloc::rc::Rc;
3use core::cell::{Ref, RefCell};
4
5use super::{OutputBinds, Statement, StatementMetadata, StatementUse};
6use crate::backend::Backend;
7use crate::connection::statement_cache::MaybeCached;
8use crate::mysql::{Mysql, MysqlType};
9use crate::result::QueryResult;
10use crate::row::*;
11
12#[allow(missing_debug_implementations)]
13pub struct StatementIterator<'a> {
14    stmt: StatementUse<'a>,
15    last_row: Rc<RefCell<PrivateMysqlRow>>,
16    metadata: Rc<StatementMetadata>,
17    len: usize,
18}
19
20impl<'a> StatementIterator<'a> {
21    pub fn from_stmt(
22        stmt: MaybeCached<'a, Statement>,
23        types: &[Option<MysqlType>],
24    ) -> QueryResult<Self> {
25        let metadata = stmt.metadata()?;
26
27        let mut output_binds = OutputBinds::from_output_types(types, &metadata)
28            .map_err(crate::result::Error::DeserializationError)?;
29
30        let mut stmt = stmt.execute_statement(&mut output_binds)?;
31        let size = unsafe { stmt.result_size() }?;
32
33        Ok(StatementIterator {
34            metadata: Rc::new(metadata),
35            last_row: Rc::new(RefCell::new(PrivateMysqlRow::Direct(output_binds))),
36            len: size,
37            stmt,
38        })
39    }
40}
41
42impl Iterator for StatementIterator<'_> {
43    type Item = QueryResult<MysqlRow>;
44
45    fn next(&mut self) -> Option<Self::Item> {
46        // check if we own the only instance of the bind buffer
47        // if that's the case we can reuse the underlying allocations
48        // if that's not the case, we need to copy the output bind buffers
49        // to somewhere else
50        let res = if let Some(binds) = Rc::get_mut(&mut self.last_row) {
51            if let PrivateMysqlRow::Direct(binds) = RefCell::get_mut(binds) {
52                self.stmt.populate_row_buffers(binds)
53            } else {
54                // any other state than `PrivateMysqlRow::Direct` is invalid here
55                // and should not happen. If this ever happens this is a logic error
56                // in the code above
57                {
    ::core::panicking::panic_fmt(format_args!("internal error: entered unreachable code: {0}",
            format_args!("You\'ve reached an impossible internal state. If you ever see this error message please open an issue at https://github.com/diesel-rs/diesel providing example code how to trigger this error.")));
}unreachable!(
58                    "You've reached an impossible internal state. \
59                     If you ever see this error message please open \
60                     an issue at https://github.com/diesel-rs/diesel \
61                     providing example code how to trigger this error."
62                )
63            }
64        } else {
65            // The shared bind buffer is in use by someone else,
66            // this means we copy out the values and replace the used reference
67            // by the copied values. After this we can advance the statement
68            // another step
69            let mut last_row = {
70                let mut last_row = match self.last_row.try_borrow_mut() {
71                    Ok(o) => o,
72                    Err(_e) => {
73                        return Some(Err(crate::result::Error::DeserializationError(
74                            "Failed to reborrow row. Try to release any `MysqlField` or `MysqlValue` \
75                             that exists at this point"
76                                .into(),
77                        )));
78                    }
79                };
80                let last_row = &mut *last_row;
81                let duplicated = last_row.duplicate();
82                core::mem::replace(last_row, duplicated)
83            };
84            let res = if let PrivateMysqlRow::Direct(ref mut binds) = last_row {
85                self.stmt.populate_row_buffers(binds)
86            } else {
87                // any other state than `PrivateMysqlRow::Direct` is invalid here
88                // and should not happen. If this ever happens this is a logic error
89                // in the code above
90                {
    ::core::panicking::panic_fmt(format_args!("internal error: entered unreachable code: {0}",
            format_args!("You\'ve reached an impossible internal state. If you ever see this error message please open an issue at https://github.com/diesel-rs/diesel providing example code how to trigger this error.")));
}unreachable!(
91                    "You've reached an impossible internal state. \
92                     If you ever see this error message please open \
93                     an issue at https://github.com/diesel-rs/diesel \
94                     providing example code how to trigger this error."
95                )
96            };
97            self.last_row = Rc::new(RefCell::new(last_row));
98            res
99        };
100
101        match res {
102            Ok(Some(())) => {
103                self.len = self.len.saturating_sub(1);
104                Some(Ok(MysqlRow {
105                    metadata: self.metadata.clone(),
106                    row: self.last_row.clone(),
107                }))
108            }
109            Ok(None) => None,
110            Err(e) => {
111                self.len = self.len.saturating_sub(1);
112                Some(Err(e))
113            }
114        }
115    }
116
117    fn size_hint(&self) -> (usize, Option<usize>) {
118        (self.len(), Some(self.len()))
119    }
120
121    fn count(self) -> usize
122    where
123        Self: Sized,
124    {
125        self.len()
126    }
127}
128
129impl ExactSizeIterator for StatementIterator<'_> {
130    fn len(&self) -> usize {
131        self.len
132    }
133}
134
135#[derive(#[automatically_derived]
#[allow(missing_debug_implementations)]
impl ::core::clone::Clone for MysqlRow {
    #[inline]
    fn clone(&self) -> MysqlRow {
        MysqlRow {
            row: ::core::clone::Clone::clone(&self.row),
            metadata: ::core::clone::Clone::clone(&self.metadata),
        }
    }
}Clone)]
136#[allow(missing_debug_implementations)]
137pub struct MysqlRow {
138    row: Rc<RefCell<PrivateMysqlRow>>,
139    metadata: Rc<StatementMetadata>,
140}
141
142enum PrivateMysqlRow {
143    Direct(OutputBinds),
144    Copied(OutputBinds),
145}
146
147impl PrivateMysqlRow {
148    fn duplicate(&self) -> Self {
149        match self {
150            Self::Copied(b) | Self::Direct(b) => Self::Copied(b.clone()),
151        }
152    }
153}
154
155impl RowSealed for MysqlRow {}
156
157impl<'a> Row<'a, Mysql> for MysqlRow {
158    type Field<'f>
159        = MysqlField<'f>
160    where
161        'a: 'f,
162        Self: 'f;
163    type InnerPartialRow = Self;
164
165    fn field_count(&self) -> usize {
166        self.metadata.fields().len()
167    }
168
169    fn get<'b, I>(&'b self, idx: I) -> Option<Self::Field<'b>>
170    where
171        'a: 'b,
172        Self: RowIndex<I>,
173    {
174        let idx = self.idx(idx)?;
175        Some(MysqlField {
176            binds: self.row.borrow(),
177            metadata: self.metadata.clone(),
178            idx,
179        })
180    }
181
182    fn partial_row(&self, range: core::ops::Range<usize>) -> PartialRow<'_, Self::InnerPartialRow> {
183        PartialRow::new(self, range)
184    }
185}
186
187impl RowIndex<usize> for MysqlRow {
188    fn idx(&self, idx: usize) -> Option<usize> {
189        if idx < self.field_count() {
190            Some(idx)
191        } else {
192            None
193        }
194    }
195}
196
197impl<'a> RowIndex<&'a str> for MysqlRow {
198    fn idx(&self, idx: &'a str) -> Option<usize> {
199        self.metadata
200            .fields()
201            .iter()
202            .enumerate()
203            .find(|(_, field_meta)| field_meta.field_name() == Some(idx))
204            .map(|(idx, _)| idx)
205    }
206}
207
208#[allow(missing_debug_implementations)]
209pub struct MysqlField<'a> {
210    binds: Ref<'a, PrivateMysqlRow>,
211    metadata: Rc<StatementMetadata>,
212    idx: usize,
213}
214
215impl<'a> Field<'a, Mysql> for MysqlField<'a> {
216    fn field_name(&self) -> Option<&str> {
217        self.metadata.fields()[self.idx].field_name()
218    }
219
220    fn is_null(&self) -> bool {
221        match &*self.binds {
222            PrivateMysqlRow::Copied(b) | PrivateMysqlRow::Direct(b) => b[self.idx].is_null(),
223        }
224    }
225
226    fn value(&self) -> Option<<Mysql as Backend>::RawValue<'_>> {
227        match &*self.binds {
228            PrivateMysqlRow::Copied(b) | PrivateMysqlRow::Direct(b) => b[self.idx].value(),
229        }
230    }
231}
232
#[cfg(test)]
#[diesel_test_helper::test]
#[allow(clippy::drop_non_drop)] // the explicit `drop`s end borrows at precise points
fn fun_with_row_iters() {
    crate::table! {
        users(id) {
            id -> Integer,
            name -> Text,
        }
    }

    use crate::connection::LoadConnection;
    use crate::deserialize::{FromSql, FromSqlRow};
    use crate::prelude::*;
    use crate::row::{Field, Row};
    use crate::sql_types;

    /// Deserializes one raw (possibly NULL) column value as `T` using SQL type `ST`,
    /// panicking on failure.
    fn decode<ST, T>(raw: Option<<Mysql as Backend>::RawValue<'_>>) -> T
    where
        T: FromSql<ST, Mysql>,
    {
        <T as FromSql<ST, Mysql>>::from_nullable_sql(raw).unwrap()
    }

    let conn = &mut crate::test_helpers::connection();

    crate::sql_query(
        "CREATE TEMPORARY TABLE IF NOT EXISTS users(id INTEGER PRIMARY KEY, name TEXT NOT NULL);",
    )
    .execute(conn)
    .unwrap();

    let seed = vec![
        (users::id.eq(1), users::name.eq("Sean")),
        (users::id.eq(2), users::name.eq("Tess")),
    ];
    crate::insert_into(users::table)
        .values(seed)
        .execute(conn)
        .unwrap();

    let query = users::table.select((users::id, users::name));
    let expected = vec![(1, String::from("Sean")), (2, String::from("Tess"))];

    // Streaming: deserialize each row while the iterator is still alive.
    {
        let rows = conn.load(query).unwrap();
        for (row, want) in rows.zip(&expected) {
            let row = row.unwrap();
            let got = <(i32, String) as FromSqlRow<
                (sql_types::Integer, sql_types::Text),
                _,
            >>::build_from_row(&row)
            .unwrap();
            assert_eq!(&got, want);
        }
    }

    // Collecting: every row must remain readable after the iterator moved on.
    {
        let rows = conn.load(query).unwrap().collect::<Vec<_>>();
        assert_eq!(rows.len(), 2);
        for (row, want) in rows.iter().zip(&expected) {
            let row = row.as_ref().unwrap();
            let got = <(i32, String) as FromSqlRow<
                (sql_types::Integer, sql_types::Text),
                _,
            >>::build_from_row(row)
            .unwrap();
            assert_eq!(&got, want);
        }
    }

    // Manual stepping: while fields/values of the previous row are alive,
    // `next` must report an error instead of advancing.
    let mut rows = conn.load(query).unwrap();

    let row_a = rows.next().unwrap().unwrap();
    let fields_a = (Row::get(&row_a, 0).unwrap(), Row::get(&row_a, 1).unwrap());
    let values_a = (fields_a.0.value(), fields_a.1.value());

    assert!(rows.next().unwrap().is_err());
    std::mem::drop(values_a);
    assert!(rows.next().unwrap().is_err());
    std::mem::drop(fields_a);

    let row_b = rows.next().unwrap().unwrap();
    let fields_b = (Row::get(&row_b, 0).unwrap(), Row::get(&row_b, 1).unwrap());
    let values_b = (fields_b.0.value(), fields_b.1.value());

    assert!(rows.next().unwrap().is_err());
    std::mem::drop(values_b);
    assert!(rows.next().unwrap().is_err());
    std::mem::drop(fields_b);

    assert!(rows.next().is_none());

    // Both rows stay readable after the iterator is exhausted.
    let fields_a = (Row::get(&row_a, 0).unwrap(), Row::get(&row_a, 1).unwrap());
    let fields_b = (Row::get(&row_b, 0).unwrap(), Row::get(&row_b, 1).unwrap());

    let values_a = (fields_a.0.value(), fields_a.1.value());
    let values_b = (fields_b.0.value(), fields_b.1.value());

    assert_eq!(decode::<sql_types::Integer, i32>(values_a.0), expected[0].0);
    assert_eq!(decode::<sql_types::Text, String>(values_a.1), expected[0].1);

    assert_eq!(decode::<sql_types::Integer, i32>(values_b.0), expected[1].0);
    assert_eq!(decode::<sql_types::Text, String>(values_b.1), expected[1].1);

    // Borrowing the first row again still works.
    let fields_a = (Row::get(&row_a, 0).unwrap(), Row::get(&row_a, 1).unwrap());
    let values_a = (fields_a.0.value(), fields_a.1.value());

    assert_eq!(decode::<sql_types::Integer, i32>(values_a.0), expected[0].0);
    assert_eq!(decode::<sql_types::Text, String>(values_a.1), expected[0].1);
}
374}