1#![allow(unsafe_code)] use super::bind_collector::{InternalSqliteBindValue, SqliteBindCollector};
3use super::raw::RawConnection;
4use super::sqlite_value::OwnedSqliteValue;
5use crate::connection::statement_cache::{MaybeCached, PrepareForCache};
6use crate::connection::Instrumentation;
7use crate::query_builder::{QueryFragment, QueryId};
8use crate::result::Error::DatabaseError;
9use crate::result::*;
10use crate::sqlite::{Sqlite, SqliteType};
11use libsqlite3_sys as ffi;
12use std::cell::OnceCell;
13use std::ffi::{CStr, CString};
14use std::io::{stderr, Write};
15use std::os::raw as libc;
16use std::ptr::{self, NonNull};
17
18pub(super) struct Statement {
19 inner_statement: NonNull<ffi::sqlite3_stmt>,
20}
21
22impl Statement {
23 pub(super) fn prepare(
24 raw_connection: &RawConnection,
25 sql: &str,
26 is_cached: PrepareForCache,
27 ) -> QueryResult<Self> {
28 let mut stmt = ptr::null_mut();
29 let mut unused_portion = ptr::null();
30 let n_byte = sql
31 .len()
32 .try_into()
33 .map_err(|e| Error::SerializationError(Box::new(e)))?;
34 #[allow(clippy::unnecessary_cast)]
36 let prepare_result = unsafe {
37 ffi::sqlite3_prepare_v3(
38 raw_connection.internal_connection.as_ptr(),
39 CString::new(sql)?.as_ptr(),
40 n_byte,
41 if matches!(is_cached, PrepareForCache::Yes) {
42 ffi::SQLITE_PREPARE_PERSISTENT as u32
43 } else {
44 0
45 },
46 &mut stmt,
47 &mut unused_portion,
48 )
49 };
50
51 ensure_sqlite_ok(prepare_result, raw_connection.internal_connection.as_ptr())?;
52
53 let inner_statement = NonNull::new(stmt).ok_or_else(|| {
56 crate::result::Error::QueryBuilderError(Box::new(crate::result::EmptyQuery))
57 })?;
58 Ok(Statement { inner_statement })
59 }
60
61 unsafe fn bind(
67 &mut self,
68 tpe: SqliteType,
69 value: InternalSqliteBindValue<'_>,
70 bind_index: i32,
71 ) -> QueryResult<Option<NonNull<[u8]>>> {
72 let mut ret_ptr = None;
73 let result = match (tpe, value) {
74 (_, InternalSqliteBindValue::Null) => {
75 ffi::sqlite3_bind_null(self.inner_statement.as_ptr(), bind_index)
76 }
77 (SqliteType::Binary, InternalSqliteBindValue::BorrowedBinary(bytes)) => {
78 let n = bytes
79 .len()
80 .try_into()
81 .map_err(|e| Error::SerializationError(Box::new(e)))?;
82 ffi::sqlite3_bind_blob(
83 self.inner_statement.as_ptr(),
84 bind_index,
85 bytes.as_ptr() as *const libc::c_void,
86 n,
87 ffi::SQLITE_STATIC(),
88 )
89 }
90 (SqliteType::Binary, InternalSqliteBindValue::Binary(mut bytes)) => {
91 let len = bytes
92 .len()
93 .try_into()
94 .map_err(|e| Error::SerializationError(Box::new(e)))?;
95 let ptr = bytes.as_mut_ptr();
99 ret_ptr = NonNull::new(Box::into_raw(bytes));
100 ffi::sqlite3_bind_blob(
101 self.inner_statement.as_ptr(),
102 bind_index,
103 ptr as *const libc::c_void,
104 len,
105 ffi::SQLITE_STATIC(),
106 )
107 }
108 (SqliteType::Text, InternalSqliteBindValue::BorrowedString(bytes)) => {
109 let len = bytes
110 .len()
111 .try_into()
112 .map_err(|e| Error::SerializationError(Box::new(e)))?;
113 ffi::sqlite3_bind_text(
114 self.inner_statement.as_ptr(),
115 bind_index,
116 bytes.as_ptr() as *const libc::c_char,
117 len,
118 ffi::SQLITE_STATIC(),
119 )
120 }
121 (SqliteType::Text, InternalSqliteBindValue::String(bytes)) => {
122 let mut bytes = Box::<[u8]>::from(bytes);
123 let len = bytes
124 .len()
125 .try_into()
126 .map_err(|e| Error::SerializationError(Box::new(e)))?;
127 let ptr = bytes.as_mut_ptr();
131 ret_ptr = NonNull::new(Box::into_raw(bytes));
132 ffi::sqlite3_bind_text(
133 self.inner_statement.as_ptr(),
134 bind_index,
135 ptr as *const libc::c_char,
136 len,
137 ffi::SQLITE_STATIC(),
138 )
139 }
140 (SqliteType::Float, InternalSqliteBindValue::F64(value))
141 | (SqliteType::Double, InternalSqliteBindValue::F64(value)) => {
142 ffi::sqlite3_bind_double(
143 self.inner_statement.as_ptr(),
144 bind_index,
145 value as libc::c_double,
146 )
147 }
148 (SqliteType::SmallInt, InternalSqliteBindValue::I32(value))
149 | (SqliteType::Integer, InternalSqliteBindValue::I32(value)) => {
150 ffi::sqlite3_bind_int(self.inner_statement.as_ptr(), bind_index, value)
151 }
152 (SqliteType::Long, InternalSqliteBindValue::I64(value)) => {
153 ffi::sqlite3_bind_int64(self.inner_statement.as_ptr(), bind_index, value)
154 }
155 (t, b) => {
156 return Err(Error::SerializationError(
157 format!("Type mismatch: Expected {t:?}, got {b}").into(),
158 ))
159 }
160 };
161 match ensure_sqlite_ok(result, self.raw_connection()) {
162 Ok(()) => Ok(ret_ptr),
163 Err(e) => {
164 if let Some(ptr) = ret_ptr {
165 std::mem::drop(Box::from_raw(ptr.as_ptr()))
169 }
170 Err(e)
171 }
172 }
173 }
174
175 fn reset(&mut self) {
176 unsafe { ffi::sqlite3_reset(self.inner_statement.as_ptr()) };
177 }
178
179 fn raw_connection(&self) -> *mut ffi::sqlite3 {
180 unsafe { ffi::sqlite3_db_handle(self.inner_statement.as_ptr()) }
181 }
182}
183
184pub(super) fn ensure_sqlite_ok(
185 code: libc::c_int,
186 raw_connection: *mut ffi::sqlite3,
187) -> QueryResult<()> {
188 if code == ffi::SQLITE_OK {
189 Ok(())
190 } else {
191 Err(last_error(raw_connection))
192 }
193}
194
195fn last_error(raw_connection: *mut ffi::sqlite3) -> Error {
196 let error_message = last_error_message(raw_connection);
197 let error_information = Box::new(error_message);
198 let error_kind = match last_error_code(raw_connection) {
199 ffi::SQLITE_CONSTRAINT_UNIQUE | ffi::SQLITE_CONSTRAINT_PRIMARYKEY => {
200 DatabaseErrorKind::UniqueViolation
201 }
202 ffi::SQLITE_CONSTRAINT_FOREIGNKEY => DatabaseErrorKind::ForeignKeyViolation,
203 ffi::SQLITE_CONSTRAINT_NOTNULL => DatabaseErrorKind::NotNullViolation,
204 ffi::SQLITE_CONSTRAINT_CHECK => DatabaseErrorKind::CheckViolation,
205 _ => DatabaseErrorKind::Unknown,
206 };
207 DatabaseError(error_kind, error_information)
208}
209
210fn last_error_message(conn: *mut ffi::sqlite3) -> String {
211 let c_str = unsafe { CStr::from_ptr(ffi::sqlite3_errmsg(conn)) };
212 c_str.to_string_lossy().into_owned()
213}
214
215fn last_error_code(conn: *mut ffi::sqlite3) -> libc::c_int {
216 unsafe { ffi::sqlite3_extended_errcode(conn) }
217}
218
219impl Drop for Statement {
220 fn drop(&mut self) {
221 use std::thread::panicking;
222
223 let raw_connection = self.raw_connection();
224 let finalize_result = unsafe { ffi::sqlite3_finalize(self.inner_statement.as_ptr()) };
225 if let Err(e) = ensure_sqlite_ok(finalize_result, raw_connection) {
226 if panicking() {
227 write!(
228 stderr(),
229 "Error finalizing SQLite prepared statement: {e:?}"
230 )
231 .expect("Error writing to `stderr`");
232 } else {
233 panic!("Error finalizing SQLite prepared statement: {:?}", e);
234 }
235 }
236 }
237}
238
239struct BoundStatement<'stmt, 'query> {
249 statement: MaybeCached<'stmt, Statement>,
250 query: Option<NonNull<dyn QueryFragment<Sqlite> + 'query>>,
256 binds_to_free: Vec<(i32, Option<NonNull<[u8]>>)>,
260 instrumentation: &'stmt mut dyn Instrumentation,
261 has_error: bool,
262}
263
264impl<'stmt, 'query> BoundStatement<'stmt, 'query> {
265 fn bind<T>(
266 statement: MaybeCached<'stmt, Statement>,
267 query: T,
268 instrumentation: &'stmt mut dyn Instrumentation,
269 ) -> QueryResult<BoundStatement<'stmt, 'query>>
270 where
271 T: QueryFragment<Sqlite> + QueryId + 'query,
272 {
273 let query = Box::new(query);
278
279 let mut bind_collector = SqliteBindCollector::new();
280 query.collect_binds(&mut bind_collector, &mut (), &Sqlite)?;
281 let SqliteBindCollector { binds } = bind_collector;
282
283 let mut ret = BoundStatement {
284 statement,
285 query: None,
286 binds_to_free: Vec::new(),
287 instrumentation,
288 has_error: false,
289 };
290
291 ret.bind_buffers(binds)?;
292
293 let query = query as Box<dyn QueryFragment<Sqlite> + 'query>;
294 ret.query = NonNull::new(Box::into_raw(query));
295
296 Ok(ret)
297 }
298
299 fn bind_buffers(
303 &mut self,
304 binds: Vec<(InternalSqliteBindValue<'_>, SqliteType)>,
305 ) -> QueryResult<()> {
306 self.binds_to_free.reserve(
311 binds
312 .iter()
313 .filter(|&(b, _)| {
314 matches!(
315 b,
316 InternalSqliteBindValue::BorrowedBinary(_)
317 | InternalSqliteBindValue::BorrowedString(_)
318 | InternalSqliteBindValue::String(_)
319 | InternalSqliteBindValue::Binary(_)
320 )
321 })
322 .count(),
323 );
324 for (bind_idx, (bind, tpe)) in (1..).zip(binds) {
325 let is_borrowed_bind = matches!(
326 bind,
327 InternalSqliteBindValue::BorrowedString(_)
328 | InternalSqliteBindValue::BorrowedBinary(_)
329 );
330
331 let res = unsafe { self.statement.bind(tpe, bind, bind_idx) }?;
336
337 if let Some(ptr) = res {
342 self.binds_to_free.push((bind_idx, Some(ptr)));
345 } else if is_borrowed_bind {
346 self.binds_to_free.push((bind_idx, None));
348 }
349 }
350 Ok(())
351 }
352
353 fn finish_query_with_error(mut self, e: &Error) {
354 self.has_error = true;
355 if let Some(q) = self.query {
356 let q = unsafe { q.as_ref() };
358 self.instrumentation.on_connection_event(
359 crate::connection::InstrumentationEvent::FinishQuery {
360 query: &crate::debug_query(&q),
361 error: Some(e),
362 },
363 );
364 }
365 }
366}
367
368impl Drop for BoundStatement<'_, '_> {
369 fn drop(&mut self) {
370 self.statement.reset();
373
374 for (idx, buffer) in std::mem::take(&mut self.binds_to_free) {
375 unsafe {
376 self.statement
378 .bind(SqliteType::Text, InternalSqliteBindValue::Null, idx)
379 .expect(
380 "Binding a null value should never fail. \
381 If you ever see this error message please open \
382 an issue at diesels issue tracker containing \
383 code how to trigger this message.",
384 );
385 }
386
387 if let Some(buffer) = buffer {
388 unsafe {
389 std::mem::drop(Box::from_raw(buffer.as_ptr()));
392 }
393 }
394 }
395
396 if let Some(query) = self.query {
397 let query = unsafe {
398 Box::from_raw(query.as_ptr())
401 };
402 if !self.has_error {
403 self.instrumentation.on_connection_event(
404 crate::connection::InstrumentationEvent::FinishQuery {
405 query: &crate::debug_query(&query),
406 error: None,
407 },
408 );
409 }
410 std::mem::drop(query);
411 self.query = None;
412 }
413 }
414}
415
416#[allow(missing_debug_implementations)]
417pub struct StatementUse<'stmt, 'query> {
418 statement: BoundStatement<'stmt, 'query>,
419 column_names: OnceCell<Vec<*const str>>,
420}
421
422impl<'stmt, 'query> StatementUse<'stmt, 'query> {
423 pub(super) fn bind<T>(
424 statement: MaybeCached<'stmt, Statement>,
425 query: T,
426 instrumentation: &'stmt mut dyn Instrumentation,
427 ) -> QueryResult<StatementUse<'stmt, 'query>>
428 where
429 T: QueryFragment<Sqlite> + QueryId + 'query,
430 {
431 Ok(Self {
432 statement: BoundStatement::bind(statement, query, instrumentation)?,
433 column_names: OnceCell::new(),
434 })
435 }
436
437 pub(super) fn run(mut self) -> QueryResult<()> {
438 let r = unsafe {
439 self.step(true).map(|_| ())
443 };
444 if let Err(ref e) = r {
445 self.statement.finish_query_with_error(e);
446 }
447 r
448 }
449
450 pub(super) unsafe fn step(&mut self, first_step: bool) -> QueryResult<bool> {
457 let res = match ffi::sqlite3_step(self.statement.statement.inner_statement.as_ptr()) {
458 ffi::SQLITE_DONE => Ok(false),
459 ffi::SQLITE_ROW => Ok(true),
460 _ => Err(last_error(self.statement.statement.raw_connection())),
461 };
462 if first_step {
463 self.column_names = OnceCell::new();
464 }
465 res
466 }
467
468 unsafe fn column_name(&self, idx: i32) -> *const str {
480 let name = {
481 let column_name =
482 ffi::sqlite3_column_name(self.statement.statement.inner_statement.as_ptr(), idx);
483 assert!(
484 !column_name.is_null(),
485 "The Sqlite documentation states that it only returns a \
486 null pointer here if we are in a OOM condition."
487 );
488 CStr::from_ptr(column_name)
489 };
490 name.to_str().expect(
491 "The Sqlite documentation states that this is UTF8. \
492 If you see this error message something has gone \
493 horribly wrong. Please open an issue at the \
494 diesel repository.",
495 ) as *const str
496 }
497
498 pub(super) fn column_count(&self) -> i32 {
499 unsafe { ffi::sqlite3_column_count(self.statement.statement.inner_statement.as_ptr()) }
500 }
501
502 pub(super) fn index_for_column_name(&mut self, field_name: &str) -> Option<usize> {
503 (0..self.column_count())
504 .find(|idx| self.field_name(*idx) == Some(field_name))
505 .map(|v| {
506 v.try_into()
507 .expect("Diesel expects to run at least on a 32 bit platform")
508 })
509 }
510
511 pub(super) fn field_name(&self, idx: i32) -> Option<&str> {
512 let column_names = self.column_names.get_or_init(|| {
513 let count = self.column_count();
514 (0..count)
515 .map(|idx| unsafe {
516 self.column_name(idx)
519 })
520 .collect()
521 });
522
523 column_names
524 .get(usize::try_from(idx).expect("Diesel expects to run at least on a 32 bit platform"))
525 .and_then(|c| unsafe { c.as_ref() })
526 }
527
528 pub(super) fn copy_value(&self, idx: i32) -> Option<OwnedSqliteValue> {
529 OwnedSqliteValue::copy_from_ptr(self.column_value(idx)?)
530 }
531
532 pub(super) fn column_value(&self, idx: i32) -> Option<NonNull<ffi::sqlite3_value>> {
533 let ptr = unsafe {
534 ffi::sqlite3_column_value(self.statement.statement.inner_statement.as_ptr(), idx)
535 };
536 NonNull::new(ptr)
537 }
538}
539
#[cfg(test)]
mod tests {
    use crate::prelude::*;
    use crate::sql_types::Text;

    /// Binding more parameters than the SQL declares must surface SQLite's
    /// "column index out of range" error rather than panicking when the bound
    /// statement is dropped.
    #[test]
    fn check_out_of_bounds_bind_does_not_panic_on_drop() {
        let mut conn = SqliteConnection::establish(":memory:").unwrap();

        // `'?'` is a string literal, not a placeholder, so the statement has no
        // parameters and binding one is out of range.
        let result = crate::sql_query("SELECT '?'")
            .bind::<Text, _>("foo")
            .execute(&mut conn);

        assert!(result.is_err());
        match result.unwrap_err() {
            crate::result::Error::DatabaseError(
                crate::result::DatabaseErrorKind::Unknown,
                info,
            ) => {
                assert_eq!(info.message(), "column index out of range");
            }
            _ => panic!("Wrong error returned"),
        }
    }
}
564}