diesel/mysql/connection/stmt/
iterator.rs
1#![allow(unsafe_code)] use std::cell::{Ref, RefCell};
3use std::rc::Rc;
4
5use super::{OutputBinds, Statement, StatementMetadata, StatementUse};
6use crate::backend::Backend;
7use crate::connection::statement_cache::MaybeCached;
8use crate::mysql::{Mysql, MysqlType};
9use crate::result::QueryResult;
10use crate::row::*;
11
/// Iterator over the rows produced by an executed MySQL prepared statement.
///
/// Every yielded [`MysqlRow`] shares the output bind buffer stored in
/// `last_row`; `next` either reuses that buffer in place or copies the
/// previous row out of it before fetching the following one.
// No `Debug` impl is provided for this type, so the lint is silenced.
#[allow(missing_debug_implementations)]
pub struct StatementIterator<'a> {
    // The executed statement; `next` asks it to populate the output binds
    // with the following row.
    stmt: StatementUse<'a>,
    // Output binds of the most recently fetched row. Clones of this `Rc` are
    // handed out inside every `MysqlRow`.
    last_row: Rc<RefCell<PrivateMysqlRow>>,
    // Column metadata, shared with every yielded row.
    metadata: Rc<StatementMetadata>,
    // Remaining row count: initialized from `result_size()` and decremented
    // for each fetched (or failed) row.
    len: usize,
}
19
20impl<'a> StatementIterator<'a> {
21 pub fn from_stmt(
22 stmt: MaybeCached<'a, Statement>,
23 types: &[Option<MysqlType>],
24 ) -> QueryResult<Self> {
25 let metadata = stmt.metadata()?;
26
27 let mut output_binds = OutputBinds::from_output_types(types, &metadata);
28
29 let mut stmt = stmt.execute_statement(&mut output_binds)?;
30 let size = unsafe { stmt.result_size() }?;
31
32 Ok(StatementIterator {
33 metadata: Rc::new(metadata),
34 last_row: Rc::new(RefCell::new(PrivateMysqlRow::Direct(output_binds))),
35 len: size,
36 stmt,
37 })
38 }
39}
40
impl Iterator for StatementIterator<'_> {
    type Item = QueryResult<MysqlRow>;

    /// Fetches the next row from the statement.
    ///
    /// Returns `None` when `populate_row_buffers` yields no further row.
    /// Returns `Some(Err(_))` if fetching fails, or if a previously yielded
    /// row is still borrowed (e.g. through a live `MysqlField`) so that its
    /// buffer cannot be copied out of the way.
    fn next(&mut self) -> Option<Self::Item> {
        // Check whether this iterator holds the only handle to the bind
        // buffer. If so, its allocations can be reused directly. Otherwise a
        // previously returned `MysqlRow` still references the buffer, and its
        // values must be copied elsewhere before the statement is advanced.
        let res = if let Some(binds) = Rc::get_mut(&mut self.last_row) {
            if let PrivateMysqlRow::Direct(ref mut binds) = RefCell::get_mut(binds) {
                self.stmt.populate_row_buffers(binds)
            } else {
                // `self.last_row` is only ever set to `PrivateMysqlRow::Direct`
                // (in `from_stmt` and in the branch below), so any other
                // variant here is a logic error in this module.
                unreachable!(
                    "You've reached an impossible internal state. \
                     If you ever see this error message please open \
                     an issue at https://github.com/diesel-rs/diesel \
                     providing example code how to trigger this error."
                )
            }
        } else {
            // The buffer is shared with a previously returned row. Leave a
            // `Copied` duplicate behind for that row and take the original
            // `Direct` buffer back, so the statement can write into it again.
            let mut last_row = {
                // Fails while the shared row is borrowed, e.g. by a
                // `MysqlField` obtained through `Row::get`.
                let mut last_row = match self.last_row.try_borrow_mut() {
                    Ok(o) => o,
                    Err(_e) => {
                        return Some(Err(crate::result::Error::DeserializationError(
                            "Failed to reborrow row. Try to release any `MysqlField` or `MysqlValue` \
                             that exists at this point"
                                .into(),
                        )));
                    }
                };
                let last_row = &mut *last_row;
                let duplicated = last_row.duplicate();
                // Swap the copy into the shared slot and move the original
                // (`Direct`) buffer out.
                std::mem::replace(last_row, duplicated)
            };
            let res = if let PrivateMysqlRow::Direct(ref mut binds) = last_row {
                self.stmt.populate_row_buffers(binds)
            } else {
                // `last_row` holds the former contents of `self.last_row`,
                // which is always `Direct`; see the comment above.
                unreachable!(
                    "You've reached an impossible internal state. \
                     If you ever see this error message please open \
                     an issue at https://github.com/diesel-rs/diesel \
                     providing example code how to trigger this error."
                )
            };
            // Store the buffer under a fresh, uniquely owned handle. The
            // old `Rc` now belongs only to the previously returned rows.
            self.last_row = Rc::new(RefCell::new(last_row));
            res
        };

        match res {
            Ok(Some(())) => {
                self.len = self.len.saturating_sub(1);
                // The yielded row shares the buffer with `self.last_row`.
                Some(Ok(MysqlRow {
                    metadata: self.metadata.clone(),
                    row: self.last_row.clone(),
                }))
            }
            // `populate_row_buffers` produced no row; presumably the result
            // set is exhausted.
            Ok(None) => None,
            Err(e) => {
                // A failed fetch still counts against the remaining length.
                self.len = self.len.saturating_sub(1);
                Some(Err(e))
            }
        }
    }

    /// Both bounds equal the tracked remaining row count.
    fn size_hint(&self) -> (usize, Option<usize>) {
        (self.len(), Some(self.len()))
    }

    /// Returns the tracked remaining row count without fetching any rows.
    fn count(self) -> usize
    where
        Self: Sized,
    {
        self.len()
    }
}
127
128impl ExactSizeIterator for StatementIterator<'_> {
129 fn len(&self) -> usize {
130 self.len
131 }
132}
133
/// A single row yielded by [`StatementIterator`].
///
/// The row's values live in a bind buffer that may still be shared with the
/// iterator. Cloning a `MysqlRow` only clones its two `Rc` handles.
#[derive(Clone)]
// No `Debug` impl is provided for this type, so the lint is silenced.
#[allow(missing_debug_implementations)]
pub struct MysqlRow {
    // Output binds holding this row's values.
    row: Rc<RefCell<PrivateMysqlRow>>,
    // Column metadata shared by all rows of the same statement.
    metadata: Rc<StatementMetadata>,
}
140
/// Storage backing a [`MysqlRow`].
enum PrivateMysqlRow {
    /// The output binds the statement populates. The iterator reuses this
    /// buffer for each fetched row.
    Direct(OutputBinds),
    /// An independent clone of a row's binds, created by `duplicate` when the
    /// iterator needs the `Direct` buffer back for the next row.
    Copied(OutputBinds),
}
145
146impl PrivateMysqlRow {
147 fn duplicate(&self) -> Self {
148 match self {
149 Self::Copied(b) | Self::Direct(b) => Self::Copied(b.clone()),
150 }
151 }
152}
153
// Marker impl of the crate-internal `RowSealed` trait (presumably the seal
// that permits the `Row` impl below).
impl RowSealed for MysqlRow {}
155
156impl<'a> Row<'a, Mysql> for MysqlRow {
157 type Field<'f>
158 = MysqlField<'f>
159 where
160 'a: 'f,
161 Self: 'f;
162 type InnerPartialRow = Self;
163
164 fn field_count(&self) -> usize {
165 self.metadata.fields().len()
166 }
167
168 fn get<'b, I>(&'b self, idx: I) -> Option<Self::Field<'b>>
169 where
170 'a: 'b,
171 Self: RowIndex<I>,
172 {
173 let idx = self.idx(idx)?;
174 Some(MysqlField {
175 binds: self.row.borrow(),
176 metadata: self.metadata.clone(),
177 idx,
178 })
179 }
180
181 fn partial_row(&self, range: std::ops::Range<usize>) -> PartialRow<'_, Self::InnerPartialRow> {
182 PartialRow::new(self, range)
183 }
184}
185
186impl RowIndex<usize> for MysqlRow {
187 fn idx(&self, idx: usize) -> Option<usize> {
188 if idx < self.field_count() {
189 Some(idx)
190 } else {
191 None
192 }
193 }
194}
195
196impl<'a> RowIndex<&'a str> for MysqlRow {
197 fn idx(&self, idx: &'a str) -> Option<usize> {
198 self.metadata
199 .fields()
200 .iter()
201 .enumerate()
202 .find(|(_, field_meta)| field_meta.field_name() == Some(idx))
203 .map(|(idx, _)| idx)
204 }
205}
206
/// A single column of a [`MysqlRow`].
///
/// Holds an immutable `RefCell` borrow of the row's bind buffer for its whole
/// lifetime. While such a borrow is alive, `StatementIterator::next` cannot
/// copy the row out of the shared buffer and returns an error instead.
// No `Debug` impl is provided for this type, so the lint is silenced.
#[allow(missing_debug_implementations)]
pub struct MysqlField<'a> {
    // Borrow of the row's output binds.
    binds: Ref<'a, PrivateMysqlRow>,
    // Column metadata for the whole statement.
    metadata: Rc<StatementMetadata>,
    // Position of this field's column within the row.
    idx: usize,
}
213
214impl<'a> Field<'a, Mysql> for MysqlField<'a> {
215 fn field_name(&self) -> Option<&str> {
216 self.metadata.fields()[self.idx].field_name()
217 }
218
219 fn is_null(&self) -> bool {
220 match &*self.binds {
221 PrivateMysqlRow::Copied(b) | PrivateMysqlRow::Direct(b) => b[self.idx].is_null(),
222 }
223 }
224
225 fn value(&self) -> Option<<Mysql as Backend>::RawValue<'_>> {
226 match &*self.binds {
227 PrivateMysqlRow::Copied(b) | PrivateMysqlRow::Direct(b) => b[self.idx].value(),
228 }
229 }
230}
231
/// Exercises how rows yielded by `StatementIterator` share the bind buffer:
/// deserializing rows one at a time, keeping all rows alive at once, and the
/// errors `next` reports while fields or values of an earlier row are still
/// borrowed.
#[cfg(test)]
#[diesel_test_helper::test]
// Some of the `drop` calls below are on values without drop glue. They are
// kept deliberately to mark where each borrow ends.
#[allow(clippy::drop_non_drop)]
fn fun_with_row_iters() {
    crate::table! {
        users(id) {
            id -> Integer,
            name -> Text,
        }
    }

    use crate::connection::LoadConnection;
    use crate::deserialize::{FromSql, FromSqlRow};
    use crate::prelude::*;
    use crate::row::{Field, Row};
    use crate::sql_types;

    let conn = &mut crate::test_helpers::connection();

    crate::sql_query(
        "CREATE TEMPORARY TABLE IF NOT EXISTS users(id INTEGER PRIMARY KEY, name TEXT NOT NULL);",
    )
    .execute(conn)
    .unwrap();

    crate::insert_into(users::table)
        .values(vec![
            (users::id.eq(1), users::name.eq("Sean")),
            (users::id.eq(2), users::name.eq("Tess")),
        ])
        .execute(conn)
        .unwrap();

    let query = users::table.select((users::id, users::name));

    let expected = vec![(1, String::from("Sean")), (2, String::from("Tess"))];

    // Each row is dropped at the end of the loop body, before the next one is
    // fetched, so the iterator can reuse the bind buffer in place.
    {
        let row_iter = conn.load(query).unwrap();
        for (row, expected) in row_iter.zip(&expected) {
            let row = row.unwrap();

            let deserialized = <(i32, String) as FromSqlRow<
                (sql_types::Integer, sql_types::Text),
                _,
            >>::build_from_row(&row)
            .unwrap();

            assert_eq!(&deserialized, expected);
        }
    }

    // All rows stay alive at once. Because earlier rows still hold the shared
    // buffer, `next` copies their values out before fetching the following row.
    {
        let collected_rows = conn.load(query).unwrap().collect::<Vec<_>>();
        assert_eq!(collected_rows.len(), 2);
        for (row, expected) in collected_rows.iter().zip(&expected) {
            let deserialized = row
                .as_ref()
                .map(|row| {
                    <(i32, String) as FromSqlRow<
                        (sql_types::Integer, sql_types::Text),
                        _,
                    >>::build_from_row(row).unwrap()
                })
                .unwrap();
            assert_eq!(&deserialized, expected);
        }
    }

    let mut row_iter = conn.load(query).unwrap();

    let first_row = row_iter.next().unwrap().unwrap();
    let first_fields = (
        Row::get(&first_row, 0).unwrap(),
        Row::get(&first_row, 1).unwrap(),
    );
    let first_values = (first_fields.0.value(), first_fields.1.value());

    // While fields of `first_row` borrow the shared buffer, `next` cannot
    // copy the row out and reports an error instead of advancing.
    assert!(row_iter.next().unwrap().is_err());
    std::mem::drop(first_values);
    assert!(row_iter.next().unwrap().is_err());
    std::mem::drop(first_fields);

    // All borrows of `first_row` are released, so the iterator can advance.
    let second_row = row_iter.next().unwrap().unwrap();
    let second_fields = (
        Row::get(&second_row, 0).unwrap(),
        Row::get(&second_row, 1).unwrap(),
    );
    let second_values = (second_fields.0.value(), second_fields.1.value());

    // Same behavior for the second row.
    assert!(row_iter.next().unwrap().is_err());
    std::mem::drop(second_values);
    assert!(row_iter.next().unwrap().is_err());
    std::mem::drop(second_fields);

    assert!(row_iter.next().is_none());

    // Both rows remain readable after the iterator is exhausted.
    let first_fields = (
        Row::get(&first_row, 0).unwrap(),
        Row::get(&first_row, 1).unwrap(),
    );
    let second_fields = (
        Row::get(&second_row, 0).unwrap(),
        Row::get(&second_row, 1).unwrap(),
    );

    let first_values = (first_fields.0.value(), first_fields.1.value());
    let second_values = (second_fields.0.value(), second_fields.1.value());

    assert_eq!(
        <i32 as FromSql<sql_types::Integer, Mysql>>::from_nullable_sql(first_values.0).unwrap(),
        expected[0].0
    );
    assert_eq!(
        <String as FromSql<sql_types::Text, Mysql>>::from_nullable_sql(first_values.1).unwrap(),
        expected[0].1
    );

    assert_eq!(
        <i32 as FromSql<sql_types::Integer, Mysql>>::from_nullable_sql(second_values.0).unwrap(),
        expected[1].0
    );
    assert_eq!(
        <String as FromSql<sql_types::Text, Mysql>>::from_nullable_sql(second_values.1).unwrap(),
        expected[1].1
    );

    // The shadowed `first_fields` above are still alive. Taking more shared
    // borrows of the same row at the same time is allowed.
    let first_fields = (
        Row::get(&first_row, 0).unwrap(),
        Row::get(&first_row, 1).unwrap(),
    );
    let first_values = (first_fields.0.value(), first_fields.1.value());

    assert_eq!(
        <i32 as FromSql<sql_types::Integer, Mysql>>::from_nullable_sql(first_values.0).unwrap(),
        expected[0].0
    );
    assert_eq!(
        <String as FromSql<sql_types::Text, Mysql>>::from_nullable_sql(first_values.1).unwrap(),
        expected[0].1
    );
}