diesel/sqlite/connection/
stmt.rs1#![allow(unsafe_code)] use super::bind_collector::{InternalSqliteBindValue, SqliteBindCollector};
3use super::raw::RawConnection;
4use super::sqlite_value::OwnedSqliteValue;
5use crate::connection::statement_cache::{MaybeCached, PrepareForCache};
6use crate::connection::Instrumentation;
7use crate::query_builder::{QueryFragment, QueryId};
8use crate::result::Error::DatabaseError;
9use crate::result::*;
10use crate::sqlite::{Sqlite, SqliteType};
11#[cfg(not(all(target_family = "wasm", target_os = "unknown")))]
12use libsqlite3_sys as ffi;
13#[cfg(all(target_family = "wasm", target_os = "unknown"))]
14use sqlite_wasm_rs as ffi;
15use std::cell::OnceCell;
16use std::ffi::{CStr, CString};
17use std::io::{stderr, Write};
18use std::os::raw as libc;
19use std::ptr::{self, NonNull};
20
/// Owning wrapper around a raw SQLite prepared-statement handle.
///
/// Instances are created by `Statement::prepare` and the handle is finalized in
/// `Drop`.
pub(super) struct Statement {
    // Never null: `prepare` turns a null handle from SQLite into an `EmptyQuery`
    // error instead of constructing a `Statement`.
    inner_statement: NonNull<ffi::sqlite3_stmt>,
}

// SAFETY: NOTE(review): moving the raw statement handle to another thread is only
// sound if it is never used concurrently with its owning connection (e.g. it is
// always moved/used together with that connection). That invariant is not visible
// in this file — confirm against the connection's own `Send` reasoning.
#[allow(unsafe_code)]
unsafe impl Send for Statement {}
30
31impl Statement {
32 pub(super) fn prepare(
33 raw_connection: &RawConnection,
34 sql: &str,
35 is_cached: PrepareForCache,
36 _: &[SqliteType],
37 ) -> QueryResult<Self> {
38 let mut stmt = ptr::null_mut();
39 let mut unused_portion = ptr::null();
40 let n_byte = sql
41 .len()
42 .try_into()
43 .map_err(|e| Error::SerializationError(Box::new(e)))?;
44 #[allow(clippy::unnecessary_cast)]
46 let prepare_result = unsafe {
47 ffi::sqlite3_prepare_v3(
48 raw_connection.internal_connection.as_ptr(),
49 CString::new(sql)?.as_ptr(),
50 n_byte,
51 if matches!(is_cached, PrepareForCache::Yes { counter: _ }) {
52 ffi::SQLITE_PREPARE_PERSISTENT as u32
53 } else {
54 0
55 },
56 &mut stmt,
57 &mut unused_portion,
58 )
59 };
60
61 ensure_sqlite_ok(prepare_result, raw_connection.internal_connection.as_ptr())?;
62
63 let inner_statement = NonNull::new(stmt).ok_or_else(|| {
66 crate::result::Error::QueryBuilderError(Box::new(crate::result::EmptyQuery))
67 })?;
68 Ok(Statement { inner_statement })
69 }
70
71 unsafe fn bind(
77 &mut self,
78 tpe: SqliteType,
79 value: InternalSqliteBindValue<'_>,
80 bind_index: i32,
81 ) -> QueryResult<Option<NonNull<[u8]>>> {
82 let mut ret_ptr = None;
83 let result = match (tpe, value) {
84 (_, InternalSqliteBindValue::Null) => unsafe {
85 ffi::sqlite3_bind_null(self.inner_statement.as_ptr(), bind_index)
86 },
87 (SqliteType::Binary, InternalSqliteBindValue::BorrowedBinary(bytes)) => {
88 let n = bytes
89 .len()
90 .try_into()
91 .map_err(|e| Error::SerializationError(Box::new(e)))?;
92 unsafe {
93 ffi::sqlite3_bind_blob(
94 self.inner_statement.as_ptr(),
95 bind_index,
96 bytes.as_ptr() as *const libc::c_void,
97 n,
98 ffi::SQLITE_STATIC(),
99 )
100 }
101 }
102 (SqliteType::Binary, InternalSqliteBindValue::Binary(mut bytes)) => {
103 let len = bytes
104 .len()
105 .try_into()
106 .map_err(|e| Error::SerializationError(Box::new(e)))?;
107 let ptr = bytes.as_mut_ptr();
111 ret_ptr = NonNull::new(Box::into_raw(bytes));
112 unsafe {
113 ffi::sqlite3_bind_blob(
114 self.inner_statement.as_ptr(),
115 bind_index,
116 ptr as *const libc::c_void,
117 len,
118 ffi::SQLITE_STATIC(),
119 )
120 }
121 }
122 (SqliteType::Text, InternalSqliteBindValue::BorrowedString(bytes)) => {
123 let len = bytes
124 .len()
125 .try_into()
126 .map_err(|e| Error::SerializationError(Box::new(e)))?;
127 unsafe {
128 ffi::sqlite3_bind_text(
129 self.inner_statement.as_ptr(),
130 bind_index,
131 bytes.as_ptr() as *const libc::c_char,
132 len,
133 ffi::SQLITE_STATIC(),
134 )
135 }
136 }
137 (SqliteType::Text, InternalSqliteBindValue::String(bytes)) => {
138 let mut bytes = Box::<[u8]>::from(bytes);
139 let len = bytes
140 .len()
141 .try_into()
142 .map_err(|e| Error::SerializationError(Box::new(e)))?;
143 let ptr = bytes.as_mut_ptr();
147 ret_ptr = NonNull::new(Box::into_raw(bytes));
148 unsafe {
149 ffi::sqlite3_bind_text(
150 self.inner_statement.as_ptr(),
151 bind_index,
152 ptr as *const libc::c_char,
153 len,
154 ffi::SQLITE_STATIC(),
155 )
156 }
157 }
158 (SqliteType::Float, InternalSqliteBindValue::F64(value))
159 | (SqliteType::Double, InternalSqliteBindValue::F64(value)) => unsafe {
160 ffi::sqlite3_bind_double(
161 self.inner_statement.as_ptr(),
162 bind_index,
163 value as libc::c_double,
164 )
165 },
166 (SqliteType::SmallInt, InternalSqliteBindValue::I32(value))
167 | (SqliteType::Integer, InternalSqliteBindValue::I32(value)) => unsafe {
168 ffi::sqlite3_bind_int(self.inner_statement.as_ptr(), bind_index, value)
169 },
170 (SqliteType::Long, InternalSqliteBindValue::I64(value)) => unsafe {
171 ffi::sqlite3_bind_int64(self.inner_statement.as_ptr(), bind_index, value)
172 },
173 (t, b) => {
174 return Err(Error::SerializationError(
175 format!("Type mismatch: Expected {t:?}, got {b}").into(),
176 ))
177 }
178 };
179 match ensure_sqlite_ok(result, self.raw_connection()) {
180 Ok(()) => Ok(ret_ptr),
181 Err(e) => {
182 if let Some(ptr) = ret_ptr {
183 std::mem::drop(unsafe { Box::from_raw(ptr.as_ptr()) })
187 }
188 Err(e)
189 }
190 }
191 }
192
193 fn reset(&mut self) {
194 unsafe { ffi::sqlite3_reset(self.inner_statement.as_ptr()) };
195 }
196
197 fn raw_connection(&self) -> *mut ffi::sqlite3 {
198 unsafe { ffi::sqlite3_db_handle(self.inner_statement.as_ptr()) }
199 }
200}
201
202pub(super) fn ensure_sqlite_ok(
203 code: libc::c_int,
204 raw_connection: *mut ffi::sqlite3,
205) -> QueryResult<()> {
206 if code == ffi::SQLITE_OK {
207 Ok(())
208 } else {
209 Err(last_error(raw_connection))
210 }
211}
212
213fn last_error(raw_connection: *mut ffi::sqlite3) -> Error {
214 let error_message = last_error_message(raw_connection);
215 let error_information = Box::new(error_message);
216 let error_kind = match last_error_code(raw_connection) {
217 ffi::SQLITE_CONSTRAINT_UNIQUE | ffi::SQLITE_CONSTRAINT_PRIMARYKEY => {
218 DatabaseErrorKind::UniqueViolation
219 }
220 ffi::SQLITE_CONSTRAINT_FOREIGNKEY => DatabaseErrorKind::ForeignKeyViolation,
221 ffi::SQLITE_CONSTRAINT_NOTNULL => DatabaseErrorKind::NotNullViolation,
222 ffi::SQLITE_CONSTRAINT_CHECK => DatabaseErrorKind::CheckViolation,
223 _ => DatabaseErrorKind::Unknown,
224 };
225 DatabaseError(error_kind, error_information)
226}
227
228fn last_error_message(conn: *mut ffi::sqlite3) -> String {
229 let c_str = unsafe { CStr::from_ptr(ffi::sqlite3_errmsg(conn)) };
230 c_str.to_string_lossy().into_owned()
231}
232
233fn last_error_code(conn: *mut ffi::sqlite3) -> libc::c_int {
234 unsafe { ffi::sqlite3_extended_errcode(conn) }
235}
236
237impl Drop for Statement {
238 fn drop(&mut self) {
239 use std::thread::panicking;
240
241 let raw_connection = self.raw_connection();
242 let finalize_result = unsafe { ffi::sqlite3_finalize(self.inner_statement.as_ptr()) };
243 if let Err(e) = ensure_sqlite_ok(finalize_result, raw_connection) {
244 if panicking() {
245 write!(
246 stderr(),
247 "Error finalizing SQLite prepared statement: {e:?}"
248 )
249 .expect("Error writing to `stderr`");
250 } else {
251 panic!("Error finalizing SQLite prepared statement: {e:?}");
252 }
253 }
254 }
255}
256
/// A prepared statement with its parameters bound, plus the bookkeeping needed to
/// release those parameters and the query they were collected from.
struct BoundStatement<'stmt, 'query> {
    // The (possibly cached) prepared statement the parameters are bound to.
    statement: MaybeCached<'stmt, Statement>,
    // The query the binds were collected from, leaked from a `Box` in
    // `BoundStatement::bind` so borrowed bind buffers keep a stable address.
    // Set only after every bind succeeded; reclaimed and dropped in `Drop`.
    query: Option<NonNull<dyn QueryFragment<Sqlite> + 'query>>,
    // Parameter indices that `Drop` must re-bind to NULL. Each index is paired with
    // the owned buffer backing that binding (freed after unbinding), or `None` when
    // the binding borrows from `query`.
    binds_to_free: Vec<(i32, Option<NonNull<[u8]>>)>,
    // Receives the `FinishQuery` event for this statement.
    instrumentation: &'stmt mut dyn Instrumentation,
    // Set by `finish_query_with_error` so `Drop` does not report the query again.
    has_error: bool,
}
281
impl<'stmt, 'query> BoundStatement<'stmt, 'query> {
    /// Collects the bind values of `query` and binds them to `statement`.
    ///
    /// `query` is boxed first, so that binds borrowing from it point into a heap
    /// allocation that never moves. After every bind succeeded, the box is leaked
    /// into `ret.query` and later reclaimed in `Drop`.
    ///
    /// # Errors
    ///
    /// Propagates errors from collecting or binding the parameters. On a bind error,
    /// `ret` is dropped before the boxed `query` because it is declared later. That
    /// drop resets and unbinds the statement, so no bound parameter outlives the
    /// buffers it borrows.
    fn bind<T>(
        statement: MaybeCached<'stmt, Statement>,
        query: T,
        instrumentation: &'stmt mut dyn Instrumentation,
    ) -> QueryResult<BoundStatement<'stmt, 'query>>
    where
        T: QueryFragment<Sqlite> + QueryId + 'query,
    {
        // Still the concrete `T` (not yet a trait object), so `collect_binds` below
        // is statically dispatched; the box gives borrowed binds a stable address.
        let query = Box::new(query);

        let mut bind_collector = SqliteBindCollector::new();
        query.collect_binds(&mut bind_collector, &mut (), &Sqlite)?;
        let SqliteBindCollector { binds } = bind_collector;

        let mut ret = BoundStatement {
            statement,
            query: None,
            binds_to_free: Vec::new(),
            instrumentation,
            has_error: false,
        };

        ret.bind_buffers(binds)?;

        // All binds succeeded: hand ownership of the boxed query to `ret`.
        let query = query as Box<dyn QueryFragment<Sqlite> + 'query>;
        ret.query = NonNull::new(Box::into_raw(query));

        Ok(ret)
    }

    /// Binds `binds` to the statement as parameters `1..=binds.len()`.
    ///
    /// Every string/blob binding is recorded in `binds_to_free` so `Drop` can unbind
    /// it. An owned buffer returned by `Statement::bind` is stored with its index so
    /// it can be freed after unbinding. NULL and numeric binds are not recorded.
    /// If binding fails part way, the entries recorded so far remain and are cleaned
    /// up when `self` is dropped.
    fn bind_buffers(
        &mut self,
        binds: Vec<(InternalSqliteBindValue<'_>, SqliteType)>,
    ) -> QueryResult<()> {
        // One slot per string/blob bind: these are the only entries pushed below.
        self.binds_to_free.reserve(
            binds
                .iter()
                .filter(|&(b, _)| {
                    matches!(
                        b,
                        InternalSqliteBindValue::BorrowedBinary(_)
                            | InternalSqliteBindValue::BorrowedString(_)
                            | InternalSqliteBindValue::String(_)
                            | InternalSqliteBindValue::Binary(_)
                    )
                })
                .count(),
        );
        // SQLite parameter indices are 1-based.
        for (bind_idx, (bind, tpe)) in (1..).zip(binds) {
            let is_borrowed_bind = matches!(
                bind,
                InternalSqliteBindValue::BorrowedString(_)
                    | InternalSqliteBindValue::BorrowedBinary(_)
            );

            // SAFETY: borrowed buffers point into the boxed query, which is only freed
            // in `Drop` after unbinding; owned buffers are tracked in `binds_to_free`
            // and likewise freed only after their parameter is re-bound to NULL.
            let res = unsafe { self.statement.bind(tpe, bind, bind_idx) }?;

            if let Some(ptr) = res {
                // Owned buffer: unbind, then free it on drop.
                self.binds_to_free.push((bind_idx, Some(ptr)));
            } else if is_borrowed_bind {
                // Borrowed buffer: only needs to be unbound on drop.
                self.binds_to_free.push((bind_idx, None));
            }
        }
        Ok(())
    }

    /// Reports the query to the instrumentation as finished with error `e`, then
    /// drops `self`.
    ///
    /// `has_error` is set first so the `Drop` that runs at the end of this method
    /// does not emit a second, successful `FinishQuery` event. If no query has been
    /// recorded yet, nothing is reported.
    fn finish_query_with_error(mut self, e: &Error) {
        self.has_error = true;
        if let Some(q) = self.query {
            // SAFETY: `q` points to the boxed query, which is only freed in `Drop`.
            let q = unsafe { q.as_ref() };
            self.instrumentation.on_connection_event(
                crate::connection::InstrumentationEvent::FinishQuery {
                    query: &crate::debug_query(&q),
                    error: Some(e),
                },
            );
        }
    }
}
385
impl Drop for BoundStatement<'_, '_> {
    /// Releases the bound parameters, the owned buffers and the boxed query.
    ///
    /// The order is deliberate:
    /// 1. Reset the statement.
    /// 2. Re-bind every recorded parameter to NULL, so the statement no longer
    ///    points at our buffers.
    /// 3. Free the owned buffers.
    /// 4. Free the query that borrowed buffers point into.
    fn drop(&mut self) {
        self.statement.reset();

        for (idx, buffer) in std::mem::take(&mut self.binds_to_free) {
            // SAFETY: binding NULL passes no buffer to SQLite. The `SqliteType` is
            // irrelevant here because `Statement::bind` accepts NULL for any type.
            unsafe {
                self.statement
                    .bind(SqliteType::Text, InternalSqliteBindValue::Null, idx)
                    .expect(
                        "Binding a null value should never fail. \
                         If you ever see this error message please open \
                         an issue at diesels issue tracker containing \
                         code how to trigger this message.",
                    );
            }

            if let Some(buffer) = buffer {
                // SAFETY: the parameter was just re-bound to NULL, and `buffer` is
                // the allocation leaked via `Box::into_raw` in `Statement::bind`.
                unsafe {
                    std::mem::drop(Box::from_raw(buffer.as_ptr()));
                }
            }
        }

        if let Some(query) = self.query {
            // SAFETY: reclaims the box leaked in `BoundStatement::bind`. Every
            // parameter that borrowed from it has been unbound above.
            let query = unsafe {
                Box::from_raw(query.as_ptr())
            };
            // Report success unless `finish_query_with_error` already reported an error.
            if !self.has_error {
                self.instrumentation.on_connection_event(
                    crate::connection::InstrumentationEvent::FinishQuery {
                        query: &crate::debug_query(&query),
                        error: None,
                    },
                );
            }
            std::mem::drop(query);
            self.query = None;
        }
    }
}
433
/// A bound statement that is being executed (stepped) and read from.
// NOTE(review): `Debug` is presumably omitted because the fields hold raw FFI
// state (statement handles, raw `str` pointers) — confirm.
#[allow(missing_debug_implementations)]
pub struct StatementUse<'stmt, 'query> {
    // The statement, its bound parameters and the owned query.
    statement: BoundStatement<'stmt, 'query>,
    // Lazily filled cache of column names. These are raw pointers into memory owned
    // by SQLite. The cache is cleared by `step(true)`; NOTE(review): per the
    // `sqlite3_column_name` docs these pointers are invalidated when the statement
    // is re-prepared or finalized, which is why the cache must not outlive a run.
    column_names: OnceCell<Vec<*const str>>,
}
439
440impl<'stmt, 'query> StatementUse<'stmt, 'query> {
441 pub(super) fn bind<T>(
442 statement: MaybeCached<'stmt, Statement>,
443 query: T,
444 instrumentation: &'stmt mut dyn Instrumentation,
445 ) -> QueryResult<StatementUse<'stmt, 'query>>
446 where
447 T: QueryFragment<Sqlite> + QueryId + 'query,
448 {
449 Ok(Self {
450 statement: BoundStatement::bind(statement, query, instrumentation)?,
451 column_names: OnceCell::new(),
452 })
453 }
454
455 pub(super) fn run(mut self) -> QueryResult<()> {
456 let r = unsafe {
457 self.step(true).map(|_| ())
461 };
462 if let Err(ref e) = r {
463 self.statement.finish_query_with_error(e);
464 }
465 r
466 }
467
468 pub(super) unsafe fn step(&mut self, first_step: bool) -> QueryResult<bool> {
475 let step_result =
476 unsafe { ffi::sqlite3_step(self.statement.statement.inner_statement.as_ptr()) };
477 let res = match step_result {
478 ffi::SQLITE_DONE => Ok(false),
479 ffi::SQLITE_ROW => Ok(true),
480 _ => Err(last_error(self.statement.statement.raw_connection())),
481 };
482 if first_step {
483 self.column_names = OnceCell::new();
484 }
485 res
486 }
487
488 unsafe fn column_name(&self, idx: i32) -> *const str {
500 let name = {
501 let column_name = unsafe {
502 ffi::sqlite3_column_name(self.statement.statement.inner_statement.as_ptr(), idx)
503 };
504 assert!(
505 !column_name.is_null(),
506 "The Sqlite documentation states that it only returns a \
507 null pointer here if we are in a OOM condition."
508 );
509 unsafe { CStr::from_ptr(column_name) }
510 };
511 name.to_str().expect(
512 "The Sqlite documentation states that this is UTF8. \
513 If you see this error message something has gone \
514 horribly wrong. Please open an issue at the \
515 diesel repository.",
516 ) as *const str
517 }
518
519 pub(super) fn column_count(&self) -> i32 {
520 unsafe { ffi::sqlite3_column_count(self.statement.statement.inner_statement.as_ptr()) }
521 }
522
523 pub(super) fn index_for_column_name(&mut self, field_name: &str) -> Option<usize> {
524 (0..self.column_count())
525 .find(|idx| self.field_name(*idx) == Some(field_name))
526 .map(|v| {
527 v.try_into()
528 .expect("Diesel expects to run at least on a 32 bit platform")
529 })
530 }
531
532 pub(super) fn field_name(&self, idx: i32) -> Option<&str> {
533 let column_names = self.column_names.get_or_init(|| {
534 let count = self.column_count();
535 (0..count)
536 .map(|idx| unsafe {
537 self.column_name(idx)
540 })
541 .collect()
542 });
543
544 column_names
545 .get(usize::try_from(idx).expect("Diesel expects to run at least on a 32 bit platform"))
546 .and_then(|c| unsafe { c.as_ref() })
547 }
548
549 pub(super) fn copy_value(&self, idx: i32) -> Option<OwnedSqliteValue> {
550 OwnedSqliteValue::copy_from_ptr(self.column_value(idx)?)
551 }
552
553 pub(super) fn column_value(&self, idx: i32) -> Option<NonNull<ffi::sqlite3_value>> {
554 let ptr = unsafe {
555 ffi::sqlite3_column_value(self.statement.statement.inner_statement.as_ptr(), idx)
556 };
557 NonNull::new(ptr)
558 }
559}
560
#[cfg(test)]
mod tests {
    use crate::prelude::*;
    use crate::sql_types::Text;

    // `'?'` is a string literal, not a placeholder, so the statement has no bind
    // parameters. Binding one anyway must surface a SQLite error rather than
    // panicking while the partially bound statement is dropped.
    #[diesel_test_helper::test]
    fn check_out_of_bounds_bind_does_not_panic_on_drop() {
        let mut conn = SqliteConnection::establish(":memory:").unwrap();

        let outcome = crate::sql_query("SELECT '?'")
            .bind::<Text, _>("foo")
            .execute(&mut conn);
        assert!(outcome.is_err());

        let crate::result::Error::DatabaseError(crate::result::DatabaseErrorKind::Unknown, info) =
            outcome.unwrap_err()
        else {
            panic!("Wrong error returned");
        };
        assert_eq!(info.message(), "column index out of range");
    }
}