diesel/mysql/connection/stmt/
iterator.rs
1#![allow(unsafe_code)] use std::cell::{Ref, RefCell};
3use std::rc::Rc;
4
5use super::{OutputBinds, Statement, StatementMetadata, StatementUse};
6use crate::backend::Backend;
7use crate::connection::statement_cache::MaybeCached;
8use crate::mysql::{Mysql, MysqlType};
9use crate::result::QueryResult;
10use crate::row::*;
11
/// Iterator over the rows produced by executing a MySQL prepared statement.
///
/// Each yielded [`MysqlRow`] shares the iterator's `last_row` buffers; the
/// `Iterator::next` implementation decides whether those buffers can be
/// refilled in place or must first be copied for still-alive rows.
// NOTE(review): presumably `Debug` is not derived because the field types do
// not implement it — confirm.
#[allow(missing_debug_implementations)]
pub struct StatementIterator<'a> {
    // The executed statement that further rows are fetched from.
    stmt: StatementUse<'a>,
    // Bind buffers of the most recently fetched row. Always the `Direct`
    // variant here; cloned into every `MysqlRow` returned for that row.
    last_row: Rc<RefCell<PrivateMysqlRow>>,
    // Column metadata of the result set, shared with every returned row.
    metadata: Rc<StatementMetadata>,
    // Remaining row count: starts at `result_size()` and is decremented for
    // every row or fetch error yielded by `next`.
    len: usize,
}
19
20impl<'a> StatementIterator<'a> {
21 pub fn from_stmt(
22 stmt: MaybeCached<'a, Statement>,
23 types: &[Option<MysqlType>],
24 ) -> QueryResult<Self> {
25 let metadata = stmt.metadata()?;
26
27 let mut output_binds = OutputBinds::from_output_types(types, &metadata);
28
29 let mut stmt = stmt.execute_statement(&mut output_binds)?;
30 let size = unsafe { stmt.result_size() }?;
31
32 Ok(StatementIterator {
33 metadata: Rc::new(metadata),
34 last_row: Rc::new(RefCell::new(PrivateMysqlRow::Direct(output_binds))),
35 len: size,
36 stmt,
37 })
38 }
39}
40
41impl Iterator for StatementIterator<'_> {
42 type Item = QueryResult<MysqlRow>;
43
44 fn next(&mut self) -> Option<Self::Item> {
45 let res = if let Some(binds) = Rc::get_mut(&mut self.last_row) {
50 if let PrivateMysqlRow::Direct(ref mut binds) = RefCell::get_mut(binds) {
51 self.stmt.populate_row_buffers(binds)
52 } else {
53 unreachable!(
57 "You've reached an impossible internal state. \
58 If you ever see this error message please open \
59 an issue at https://github.com/diesel-rs/diesel \
60 providing example code how to trigger this error."
61 )
62 }
63 } else {
64 let mut last_row = {
69 let mut last_row = match self.last_row.try_borrow_mut() {
70 Ok(o) => o,
71 Err(_e) => {
72 return Some(Err(crate::result::Error::DeserializationError(
73 "Failed to reborrow row. Try to release any `MysqlField` or `MysqlValue` \
74 that exists at this point"
75 .into(),
76 )));
77 }
78 };
79 let last_row = &mut *last_row;
80 let duplicated = last_row.duplicate();
81 std::mem::replace(last_row, duplicated)
82 };
83 let res = if let PrivateMysqlRow::Direct(ref mut binds) = last_row {
84 self.stmt.populate_row_buffers(binds)
85 } else {
86 unreachable!(
90 "You've reached an impossible internal state. \
91 If you ever see this error message please open \
92 an issue at https://github.com/diesel-rs/diesel \
93 providing example code how to trigger this error."
94 )
95 };
96 self.last_row = Rc::new(RefCell::new(last_row));
97 res
98 };
99
100 match res {
101 Ok(Some(())) => {
102 self.len = self.len.saturating_sub(1);
103 Some(Ok(MysqlRow {
104 metadata: self.metadata.clone(),
105 row: self.last_row.clone(),
106 }))
107 }
108 Ok(None) => None,
109 Err(e) => {
110 self.len = self.len.saturating_sub(1);
111 Some(Err(e))
112 }
113 }
114 }
115
116 fn size_hint(&self) -> (usize, Option<usize>) {
117 (self.len(), Some(self.len()))
118 }
119
120 fn count(self) -> usize
121 where
122 Self: Sized,
123 {
124 self.len()
125 }
126}
127
128impl ExactSizeIterator for StatementIterator<'_> {
129 fn len(&self) -> usize {
130 self.len
131 }
132}
133
/// A single row yielded by [`StatementIterator`].
///
/// Cloning only bumps reference counts; both fields are `Rc`s.
#[derive(Clone)]
// NOTE(review): presumably `Debug` is omitted because the field types do not
// implement it — confirm.
#[allow(missing_debug_implementations)]
pub struct MysqlRow {
    // Buffers holding this row's values. While this row is the iterator's
    // latest, they are the iterator's `Direct` buffers. Once the iterator
    // advances while this row is alive, the cell holds a detached `Copied`
    // snapshot instead.
    row: Rc<RefCell<PrivateMysqlRow>>,
    // Column metadata shared with the iterator.
    metadata: Rc<StatementMetadata>,
}
140
/// Storage behind a [`MysqlRow`].
enum PrivateMysqlRow {
    /// The buffers the statement iterator fetches rows into.
    Direct(OutputBinds),
    /// A clone of the buffers (made by `duplicate`), kept by rows that were
    /// still alive when the iterator fetched the next row.
    Copied(OutputBinds),
}
145
146impl PrivateMysqlRow {
147 fn duplicate(&self) -> Self {
148 match self {
149 Self::Copied(b) | Self::Direct(b) => Self::Copied(b.clone()),
150 }
151 }
152}
153
154impl RowSealed for MysqlRow {}
155
156impl<'a> Row<'a, Mysql> for MysqlRow {
157 type Field<'f>
158 = MysqlField<'f>
159 where
160 'a: 'f,
161 Self: 'f;
162 type InnerPartialRow = Self;
163
164 fn field_count(&self) -> usize {
165 self.metadata.fields().len()
166 }
167
168 fn get<'b, I>(&'b self, idx: I) -> Option<Self::Field<'b>>
169 where
170 'a: 'b,
171 Self: RowIndex<I>,
172 {
173 let idx = self.idx(idx)?;
174 Some(MysqlField {
175 binds: self.row.borrow(),
176 metadata: self.metadata.clone(),
177 idx,
178 })
179 }
180
181 fn partial_row(&self, range: std::ops::Range<usize>) -> PartialRow<'_, Self::InnerPartialRow> {
182 PartialRow::new(self, range)
183 }
184}
185
186impl RowIndex<usize> for MysqlRow {
187 fn idx(&self, idx: usize) -> Option<usize> {
188 if idx < self.field_count() {
189 Some(idx)
190 } else {
191 None
192 }
193 }
194}
195
196impl<'a> RowIndex<&'a str> for MysqlRow {
197 fn idx(&self, idx: &'a str) -> Option<usize> {
198 self.metadata
199 .fields()
200 .iter()
201 .enumerate()
202 .find(|(_, field_meta)| field_meta.field_name() == Some(idx))
203 .map(|(idx, _)| idx)
204 }
205}
206
/// A single column of a [`MysqlRow`].
///
/// Holds a shared `RefCell` borrow of the row buffers for its whole lifetime.
/// While a field of the iterator's most recent row is alive, advancing the
/// iterator returns an error instead of a new row.
// NOTE(review): presumably `Debug` is omitted because `Ref<PrivateMysqlRow>`
// has no usable `Debug` impl here — confirm.
#[allow(missing_debug_implementations)]
pub struct MysqlField<'a> {
    // Shared borrow of the owning row's bind buffers.
    binds: Ref<'a, PrivateMysqlRow>,
    // Column metadata, used to report the field name.
    metadata: Rc<StatementMetadata>,
    // Column index into both `binds` and `metadata.fields()`.
    idx: usize,
}
213
214impl<'a> Field<'a, Mysql> for MysqlField<'a> {
215 fn field_name(&self) -> Option<&str> {
216 self.metadata.fields()[self.idx].field_name()
217 }
218
219 fn is_null(&self) -> bool {
220 match &*self.binds {
221 PrivateMysqlRow::Copied(b) | PrivateMysqlRow::Direct(b) => b[self.idx].is_null(),
222 }
223 }
224
225 fn value(&self) -> Option<<Mysql as Backend>::RawValue<'_>> {
226 match &*self.binds {
227 PrivateMysqlRow::Copied(b) | PrivateMysqlRow::Direct(b) => b[self.idx].value(),
228 }
229 }
230}
231
/// Exercises how `StatementIterator` and `MysqlRow` share row buffers, using
/// the connection returned by `crate::test_helpers::connection()`.
///
/// Covers plain iteration with deserialization and collecting all rows. It
/// checks the error returned when advancing while fields or values of the
/// current row are still borrowed. It also checks that earlier rows keep their
/// own values after the iterator has moved on.
#[test]
// `std::mem::drop` is used below only to end borrows at precise points. Some
// of the dropped values (e.g. the value tuples) presumably have no `Drop`
// impl, which this lint would otherwise flag.
#[allow(clippy::drop_non_drop)]
fn fun_with_row_iters() {
    crate::table! {
        #[allow(unused_parens)]
        users(id) {
            id -> Integer,
            name -> Text,
        }
    }

    use crate::connection::LoadConnection;
    use crate::deserialize::{FromSql, FromSqlRow};
    use crate::prelude::*;
    use crate::row::{Field, Row};
    use crate::sql_types;

    let conn = &mut crate::test_helpers::connection();

    // Start from a known table state: create the table if needed and remove
    // rows left over from earlier runs.
    crate::sql_query(
        "CREATE TABLE IF NOT EXISTS users(id INTEGER PRIMARY KEY, name TEXT NOT NULL);",
    )
    .execute(conn)
    .unwrap();
    crate::sql_query("DELETE FROM users;")
        .execute(conn)
        .unwrap();

    crate::insert_into(users::table)
        .values(vec![
            (users::id.eq(1), users::name.eq("Sean")),
            (users::id.eq(2), users::name.eq("Tess")),
        ])
        .execute(conn)
        .unwrap();

    let query = users::table.select((users::id, users::name));

    let expected = vec![(1, String::from("Sean")), (2, String::from("Tess"))];

    // Case 1: deserialize each row while iterating.
    {
        let row_iter = conn.load(query).unwrap();
        for (row, expected) in row_iter.zip(&expected) {
            let row = row.unwrap();

            let deserialized = <(i32, String) as FromSqlRow<
                (sql_types::Integer, sql_types::Text),
                _,
            >>::build_from_row(&row)
            .unwrap();

            assert_eq!(&deserialized, expected);
        }
    }

    // Case 2: collect all rows first; every collected row must still hold
    // its own values afterwards.
    {
        let collected_rows = conn.load(query).unwrap().collect::<Vec<_>>();
        assert_eq!(collected_rows.len(), 2);
        for (row, expected) in collected_rows.iter().zip(&expected) {
            let deserialized = row
                .as_ref()
                .map(|row| {
                    <(i32, String) as FromSqlRow<
                        (sql_types::Integer, sql_types::Text),
                        _,
                    >>::build_from_row(row).unwrap()
                })
                .unwrap();
            assert_eq!(&deserialized, expected);
        }
    }

    // Case 3: borrow interaction between the iterator and outstanding
    // fields/values.
    let mut row_iter = conn.load(query).unwrap();

    let first_row = row_iter.next().unwrap().unwrap();
    let first_fields = (
        Row::get(&first_row, 0).unwrap(),
        Row::get(&first_row, 1).unwrap(),
    );
    let first_values = (first_fields.0.value(), first_fields.1.value());

    // The fields hold shared borrows of the row buffers. Advancing therefore
    // fails, without consuming a row, both while the values are alive and
    // after only the values have been dropped.
    assert!(row_iter.next().unwrap().is_err());
    std::mem::drop(first_values);
    assert!(row_iter.next().unwrap().is_err());
    std::mem::drop(first_fields);

    // With no borrows left, the iterator advances to the second row.
    let second_row = row_iter.next().unwrap().unwrap();
    let second_fields = (
        Row::get(&second_row, 0).unwrap(),
        Row::get(&second_row, 1).unwrap(),
    );
    let second_values = (second_fields.0.value(), second_fields.1.value());

    // Same borrow rules for the second row.
    assert!(row_iter.next().unwrap().is_err());
    std::mem::drop(second_values);
    assert!(row_iter.next().unwrap().is_err());
    std::mem::drop(second_fields);

    assert!(row_iter.next().is_none());

    // Once the iterator has moved past them, both rows can be read at the same
    // time, and each still yields its own values.
    let first_fields = (
        Row::get(&first_row, 0).unwrap(),
        Row::get(&first_row, 1).unwrap(),
    );
    let second_fields = (
        Row::get(&second_row, 0).unwrap(),
        Row::get(&second_row, 1).unwrap(),
    );

    let first_values = (first_fields.0.value(), first_fields.1.value());
    let second_values = (second_fields.0.value(), second_fields.1.value());

    assert_eq!(
        <i32 as FromSql<sql_types::Integer, Mysql>>::from_nullable_sql(first_values.0).unwrap(),
        expected[0].0
    );
    assert_eq!(
        <String as FromSql<sql_types::Text, Mysql>>::from_nullable_sql(first_values.1).unwrap(),
        expected[0].1
    );

    assert_eq!(
        <i32 as FromSql<sql_types::Integer, Mysql>>::from_nullable_sql(second_values.0).unwrap(),
        expected[1].0
    );
    assert_eq!(
        <String as FromSql<sql_types::Text, Mysql>>::from_nullable_sql(second_values.1).unwrap(),
        expected[1].1
    );

    // The first row's fields can be borrowed again while the earlier
    // (shadowed) shared borrows are still alive.
    let first_fields = (
        Row::get(&first_row, 0).unwrap(),
        Row::get(&first_row, 1).unwrap(),
    );
    let first_values = (first_fields.0.value(), first_fields.1.value());

    assert_eq!(
        <i32 as FromSql<sql_types::Integer, Mysql>>::from_nullable_sql(first_values.0).unwrap(),
        expected[0].0
    );
    assert_eq!(
        <String as FromSql<sql_types::Text, Mysql>>::from_nullable_sql(first_values.1).unwrap(),
        expected[0].1
    );
}