1#![allow(unsafe_code)] use super::bind_collector::{InternalSqliteBindValue, SqliteBindCollector};
3use super::raw::RawConnection;
4use super::sqlite_value::OwnedSqliteValue;
5use crate::connection::statement_cache::{MaybeCached, PrepareForCache};
6use crate::connection::Instrumentation;
7use crate::query_builder::{QueryFragment, QueryId};
8use crate::result::Error::DatabaseError;
9use crate::result::*;
10use crate::sqlite::{Sqlite, SqliteType};
11#[cfg(not(all(target_family = "wasm", target_os = "unknown")))]
12use libsqlite3_sys as ffi;
13#[cfg(all(target_family = "wasm", target_os = "unknown"))]
14use sqlite_wasm_rs::export as ffi;
15use std::cell::OnceCell;
16use std::ffi::{CStr, CString};
17use std::io::{stderr, Write};
18use std::os::raw as libc;
19use std::ptr::{self, NonNull};
20
/// Owning wrapper around a prepared SQLite statement handle.
///
/// Instances are created by `Statement::prepare` and the handle is released with
/// `sqlite3_finalize` when the wrapper is dropped.
pub(super) struct Statement {
    // Never null: `Statement::prepare` turns a null handle into an `EmptyQuery` error.
    inner_statement: NonNull<ffi::sqlite3_stmt>,
}

// SAFETY: `Statement` holds the only handle produced by `prepare` and finalizes it in
// `Drop`; it is not `Clone`, so moving it to another thread moves all access with it.
// NOTE(review): soundness additionally assumes SQLite is built with a threading mode
// that allows using a connection from different threads (one at a time) and that the
// owning connection is never used concurrently with this statement — confirm against
// the `Send` reasoning of the connection type.
#[allow(unsafe_code)]
unsafe impl Send for Statement {}
30
31impl Statement {
32 pub(super) fn prepare(
33 raw_connection: &RawConnection,
34 sql: &str,
35 is_cached: PrepareForCache,
36 _: &[SqliteType],
37 ) -> QueryResult<Self> {
38 let mut stmt = ptr::null_mut();
39 let mut unused_portion = ptr::null();
40 let n_byte = sql
41 .len()
42 .try_into()
43 .map_err(|e| Error::SerializationError(Box::new(e)))?;
44 #[allow(clippy::unnecessary_cast)]
46 let prepare_result = unsafe {
47 ffi::sqlite3_prepare_v3(
48 raw_connection.internal_connection.as_ptr(),
49 CString::new(sql)?.as_ptr(),
50 n_byte,
51 if matches!(is_cached, PrepareForCache::Yes { counter: _ }) {
52 ffi::SQLITE_PREPARE_PERSISTENT as u32
53 } else {
54 0
55 },
56 &mut stmt,
57 &mut unused_portion,
58 )
59 };
60
61 ensure_sqlite_ok(prepare_result, raw_connection.internal_connection.as_ptr())?;
62
63 let inner_statement = NonNull::new(stmt).ok_or_else(|| {
66 crate::result::Error::QueryBuilderError(Box::new(crate::result::EmptyQuery))
67 })?;
68 Ok(Statement { inner_statement })
69 }
70
71 unsafe fn bind(
77 &mut self,
78 tpe: SqliteType,
79 value: InternalSqliteBindValue<'_>,
80 bind_index: i32,
81 ) -> QueryResult<Option<NonNull<[u8]>>> {
82 let mut ret_ptr = None;
83 let result = match (tpe, value) {
84 (_, InternalSqliteBindValue::Null) => {
85 ffi::sqlite3_bind_null(self.inner_statement.as_ptr(), bind_index)
86 }
87 (SqliteType::Binary, InternalSqliteBindValue::BorrowedBinary(bytes)) => {
88 let n = bytes
89 .len()
90 .try_into()
91 .map_err(|e| Error::SerializationError(Box::new(e)))?;
92 ffi::sqlite3_bind_blob(
93 self.inner_statement.as_ptr(),
94 bind_index,
95 bytes.as_ptr() as *const libc::c_void,
96 n,
97 ffi::SQLITE_STATIC(),
98 )
99 }
100 (SqliteType::Binary, InternalSqliteBindValue::Binary(mut bytes)) => {
101 let len = bytes
102 .len()
103 .try_into()
104 .map_err(|e| Error::SerializationError(Box::new(e)))?;
105 let ptr = bytes.as_mut_ptr();
109 ret_ptr = NonNull::new(Box::into_raw(bytes));
110 ffi::sqlite3_bind_blob(
111 self.inner_statement.as_ptr(),
112 bind_index,
113 ptr as *const libc::c_void,
114 len,
115 ffi::SQLITE_STATIC(),
116 )
117 }
118 (SqliteType::Text, InternalSqliteBindValue::BorrowedString(bytes)) => {
119 let len = bytes
120 .len()
121 .try_into()
122 .map_err(|e| Error::SerializationError(Box::new(e)))?;
123 ffi::sqlite3_bind_text(
124 self.inner_statement.as_ptr(),
125 bind_index,
126 bytes.as_ptr() as *const libc::c_char,
127 len,
128 ffi::SQLITE_STATIC(),
129 )
130 }
131 (SqliteType::Text, InternalSqliteBindValue::String(bytes)) => {
132 let mut bytes = Box::<[u8]>::from(bytes);
133 let len = bytes
134 .len()
135 .try_into()
136 .map_err(|e| Error::SerializationError(Box::new(e)))?;
137 let ptr = bytes.as_mut_ptr();
141 ret_ptr = NonNull::new(Box::into_raw(bytes));
142 ffi::sqlite3_bind_text(
143 self.inner_statement.as_ptr(),
144 bind_index,
145 ptr as *const libc::c_char,
146 len,
147 ffi::SQLITE_STATIC(),
148 )
149 }
150 (SqliteType::Float, InternalSqliteBindValue::F64(value))
151 | (SqliteType::Double, InternalSqliteBindValue::F64(value)) => {
152 ffi::sqlite3_bind_double(
153 self.inner_statement.as_ptr(),
154 bind_index,
155 value as libc::c_double,
156 )
157 }
158 (SqliteType::SmallInt, InternalSqliteBindValue::I32(value))
159 | (SqliteType::Integer, InternalSqliteBindValue::I32(value)) => {
160 ffi::sqlite3_bind_int(self.inner_statement.as_ptr(), bind_index, value)
161 }
162 (SqliteType::Long, InternalSqliteBindValue::I64(value)) => {
163 ffi::sqlite3_bind_int64(self.inner_statement.as_ptr(), bind_index, value)
164 }
165 (t, b) => {
166 return Err(Error::SerializationError(
167 format!("Type mismatch: Expected {t:?}, got {b}").into(),
168 ))
169 }
170 };
171 match ensure_sqlite_ok(result, self.raw_connection()) {
172 Ok(()) => Ok(ret_ptr),
173 Err(e) => {
174 if let Some(ptr) = ret_ptr {
175 std::mem::drop(Box::from_raw(ptr.as_ptr()))
179 }
180 Err(e)
181 }
182 }
183 }
184
185 fn reset(&mut self) {
186 unsafe { ffi::sqlite3_reset(self.inner_statement.as_ptr()) };
187 }
188
189 fn raw_connection(&self) -> *mut ffi::sqlite3 {
190 unsafe { ffi::sqlite3_db_handle(self.inner_statement.as_ptr()) }
191 }
192}
193
194pub(super) fn ensure_sqlite_ok(
195 code: libc::c_int,
196 raw_connection: *mut ffi::sqlite3,
197) -> QueryResult<()> {
198 if code == ffi::SQLITE_OK {
199 Ok(())
200 } else {
201 Err(last_error(raw_connection))
202 }
203}
204
205fn last_error(raw_connection: *mut ffi::sqlite3) -> Error {
206 let error_message = last_error_message(raw_connection);
207 let error_information = Box::new(error_message);
208 let error_kind = match last_error_code(raw_connection) {
209 ffi::SQLITE_CONSTRAINT_UNIQUE | ffi::SQLITE_CONSTRAINT_PRIMARYKEY => {
210 DatabaseErrorKind::UniqueViolation
211 }
212 ffi::SQLITE_CONSTRAINT_FOREIGNKEY => DatabaseErrorKind::ForeignKeyViolation,
213 ffi::SQLITE_CONSTRAINT_NOTNULL => DatabaseErrorKind::NotNullViolation,
214 ffi::SQLITE_CONSTRAINT_CHECK => DatabaseErrorKind::CheckViolation,
215 _ => DatabaseErrorKind::Unknown,
216 };
217 DatabaseError(error_kind, error_information)
218}
219
220fn last_error_message(conn: *mut ffi::sqlite3) -> String {
221 let c_str = unsafe { CStr::from_ptr(ffi::sqlite3_errmsg(conn)) };
222 c_str.to_string_lossy().into_owned()
223}
224
/// Returns the extended result code of the most recent SQLite API call on `conn`.
///
/// Extended codes (e.g. `SQLITE_CONSTRAINT_UNIQUE`) are what `last_error` matches on.
fn last_error_code(conn: *mut ffi::sqlite3) -> libc::c_int {
    // SAFETY: presumably `conn` is a live connection handle supplied by the caller —
    // all visible callers pass the handle of an open connection or statement.
    unsafe { ffi::sqlite3_extended_errcode(conn) }
}
228
229impl Drop for Statement {
230 fn drop(&mut self) {
231 use std::thread::panicking;
232
233 let raw_connection = self.raw_connection();
234 let finalize_result = unsafe { ffi::sqlite3_finalize(self.inner_statement.as_ptr()) };
235 if let Err(e) = ensure_sqlite_ok(finalize_result, raw_connection) {
236 if panicking() {
237 write!(
238 stderr(),
239 "Error finalizing SQLite prepared statement: {e:?}"
240 )
241 .expect("Error writing to `stderr`");
242 } else {
243 panic!("Error finalizing SQLite prepared statement: {:?}", e);
244 }
245 }
246 }
247}
248
/// A prepared statement with the parameters of one query bound to it.
///
/// Owns (or tracks) everything SQLite may read through the bound parameters, and
/// releases it in `Drop` only after the statement has been reset and the parameters
/// rebound to NULL.
struct BoundStatement<'stmt, 'query> {
    // The (possibly cached) prepared statement the parameters are bound to.
    statement: MaybeCached<'stmt, Statement>,
    // Type-erased query, leaked from a `Box` in `BoundStatement::bind`. Borrowed bind
    // values point into it, so it must stay alive until they are unbound; `Drop` turns
    // it back into a `Box` and frees it. `None` until all parameters were bound.
    query: Option<NonNull<dyn QueryFragment<Sqlite> + 'query>>,
    // Parameter positions that must be rebound to NULL in `Drop`, each with the leaked
    // owned buffer (if any) that is freed right after unbinding.
    binds_to_free: Vec<(i32, Option<NonNull<[u8]>>)>,
    // Receives the `FinishQuery` event for this query.
    instrumentation: &'stmt mut dyn Instrumentation,
    // Set by `finish_query_with_error`, which already emitted `FinishQuery` with the
    // error; `Drop` then skips the success event.
    has_error: bool,
}
273
274impl<'stmt, 'query> BoundStatement<'stmt, 'query> {
275 fn bind<T>(
276 statement: MaybeCached<'stmt, Statement>,
277 query: T,
278 instrumentation: &'stmt mut dyn Instrumentation,
279 ) -> QueryResult<BoundStatement<'stmt, 'query>>
280 where
281 T: QueryFragment<Sqlite> + QueryId + 'query,
282 {
283 let query = Box::new(query);
288
289 let mut bind_collector = SqliteBindCollector::new();
290 query.collect_binds(&mut bind_collector, &mut (), &Sqlite)?;
291 let SqliteBindCollector { binds } = bind_collector;
292
293 let mut ret = BoundStatement {
294 statement,
295 query: None,
296 binds_to_free: Vec::new(),
297 instrumentation,
298 has_error: false,
299 };
300
301 ret.bind_buffers(binds)?;
302
303 let query = query as Box<dyn QueryFragment<Sqlite> + 'query>;
304 ret.query = NonNull::new(Box::into_raw(query));
305
306 Ok(ret)
307 }
308
309 fn bind_buffers(
313 &mut self,
314 binds: Vec<(InternalSqliteBindValue<'_>, SqliteType)>,
315 ) -> QueryResult<()> {
316 self.binds_to_free.reserve(
321 binds
322 .iter()
323 .filter(|&(b, _)| {
324 matches!(
325 b,
326 InternalSqliteBindValue::BorrowedBinary(_)
327 | InternalSqliteBindValue::BorrowedString(_)
328 | InternalSqliteBindValue::String(_)
329 | InternalSqliteBindValue::Binary(_)
330 )
331 })
332 .count(),
333 );
334 for (bind_idx, (bind, tpe)) in (1..).zip(binds) {
335 let is_borrowed_bind = matches!(
336 bind,
337 InternalSqliteBindValue::BorrowedString(_)
338 | InternalSqliteBindValue::BorrowedBinary(_)
339 );
340
341 let res = unsafe { self.statement.bind(tpe, bind, bind_idx) }?;
346
347 if let Some(ptr) = res {
352 self.binds_to_free.push((bind_idx, Some(ptr)));
355 } else if is_borrowed_bind {
356 self.binds_to_free.push((bind_idx, None));
358 }
359 }
360 Ok(())
361 }
362
363 fn finish_query_with_error(mut self, e: &Error) {
364 self.has_error = true;
365 if let Some(q) = self.query {
366 let q = unsafe { q.as_ref() };
368 self.instrumentation.on_connection_event(
369 crate::connection::InstrumentationEvent::FinishQuery {
370 query: &crate::debug_query(&q),
371 error: Some(e),
372 },
373 );
374 }
375 }
376}
377
378impl Drop for BoundStatement<'_, '_> {
379 fn drop(&mut self) {
380 self.statement.reset();
383
384 for (idx, buffer) in std::mem::take(&mut self.binds_to_free) {
385 unsafe {
386 self.statement
388 .bind(SqliteType::Text, InternalSqliteBindValue::Null, idx)
389 .expect(
390 "Binding a null value should never fail. \
391 If you ever see this error message please open \
392 an issue at diesels issue tracker containing \
393 code how to trigger this message.",
394 );
395 }
396
397 if let Some(buffer) = buffer {
398 unsafe {
399 std::mem::drop(Box::from_raw(buffer.as_ptr()));
402 }
403 }
404 }
405
406 if let Some(query) = self.query {
407 let query = unsafe {
408 Box::from_raw(query.as_ptr())
411 };
412 if !self.has_error {
413 self.instrumentation.on_connection_event(
414 crate::connection::InstrumentationEvent::FinishQuery {
415 query: &crate::debug_query(&query),
416 error: None,
417 },
418 );
419 }
420 std::mem::drop(query);
421 self.query = None;
422 }
423 }
424}
425
/// A bound statement that can be stepped, together with a lazily filled cache of its
/// column names.
#[allow(missing_debug_implementations)]
pub struct StatementUse<'stmt, 'query> {
    // The prepared statement with all parameters of the query bound.
    statement: BoundStatement<'stmt, 'query>,
    // Pointers to column names owned by SQLite, filled on first use by `field_name`.
    // `step(true)` replaces the cache with an empty one at the start of each run.
    column_names: OnceCell<Vec<*const str>>,
}
431
432impl<'stmt, 'query> StatementUse<'stmt, 'query> {
433 pub(super) fn bind<T>(
434 statement: MaybeCached<'stmt, Statement>,
435 query: T,
436 instrumentation: &'stmt mut dyn Instrumentation,
437 ) -> QueryResult<StatementUse<'stmt, 'query>>
438 where
439 T: QueryFragment<Sqlite> + QueryId + 'query,
440 {
441 Ok(Self {
442 statement: BoundStatement::bind(statement, query, instrumentation)?,
443 column_names: OnceCell::new(),
444 })
445 }
446
447 pub(super) fn run(mut self) -> QueryResult<()> {
448 let r = unsafe {
449 self.step(true).map(|_| ())
453 };
454 if let Err(ref e) = r {
455 self.statement.finish_query_with_error(e);
456 }
457 r
458 }
459
460 pub(super) unsafe fn step(&mut self, first_step: bool) -> QueryResult<bool> {
467 let res = match ffi::sqlite3_step(self.statement.statement.inner_statement.as_ptr()) {
468 ffi::SQLITE_DONE => Ok(false),
469 ffi::SQLITE_ROW => Ok(true),
470 _ => Err(last_error(self.statement.statement.raw_connection())),
471 };
472 if first_step {
473 self.column_names = OnceCell::new();
474 }
475 res
476 }
477
478 unsafe fn column_name(&self, idx: i32) -> *const str {
490 let name = {
491 let column_name =
492 ffi::sqlite3_column_name(self.statement.statement.inner_statement.as_ptr(), idx);
493 assert!(
494 !column_name.is_null(),
495 "The Sqlite documentation states that it only returns a \
496 null pointer here if we are in a OOM condition."
497 );
498 CStr::from_ptr(column_name)
499 };
500 name.to_str().expect(
501 "The Sqlite documentation states that this is UTF8. \
502 If you see this error message something has gone \
503 horribly wrong. Please open an issue at the \
504 diesel repository.",
505 ) as *const str
506 }
507
508 pub(super) fn column_count(&self) -> i32 {
509 unsafe { ffi::sqlite3_column_count(self.statement.statement.inner_statement.as_ptr()) }
510 }
511
512 pub(super) fn index_for_column_name(&mut self, field_name: &str) -> Option<usize> {
513 (0..self.column_count())
514 .find(|idx| self.field_name(*idx) == Some(field_name))
515 .map(|v| {
516 v.try_into()
517 .expect("Diesel expects to run at least on a 32 bit platform")
518 })
519 }
520
521 pub(super) fn field_name(&self, idx: i32) -> Option<&str> {
522 let column_names = self.column_names.get_or_init(|| {
523 let count = self.column_count();
524 (0..count)
525 .map(|idx| unsafe {
526 self.column_name(idx)
529 })
530 .collect()
531 });
532
533 column_names
534 .get(usize::try_from(idx).expect("Diesel expects to run at least on a 32 bit platform"))
535 .and_then(|c| unsafe { c.as_ref() })
536 }
537
538 pub(super) fn copy_value(&self, idx: i32) -> Option<OwnedSqliteValue> {
539 OwnedSqliteValue::copy_from_ptr(self.column_value(idx)?)
540 }
541
542 pub(super) fn column_value(&self, idx: i32) -> Option<NonNull<ffi::sqlite3_value>> {
543 let ptr = unsafe {
544 ffi::sqlite3_column_value(self.statement.statement.inner_statement.as_ptr(), idx)
545 };
546 NonNull::new(ptr)
547 }
548}
549
#[cfg(test)]
mod tests {
    use crate::prelude::*;
    use crate::sql_types::Text;

    /// Binding a value to a statement without placeholders must return the
    /// "column index out of range" error, and dropping the bound statement must not panic.
    #[diesel_test_helper::test]
    fn check_out_of_bounds_bind_does_not_panic_on_drop() {
        let mut conn = SqliteConnection::establish(":memory:").unwrap();

        // The `?` sits inside a string literal, so the statement has no parameter slot.
        let e = crate::sql_query("SELECT '?'")
            .bind::<Text, _>("foo")
            .execute(&mut conn);
        assert!(e.is_err());

        match e.unwrap_err() {
            crate::result::Error::DatabaseError(
                crate::result::DatabaseErrorKind::Unknown,
                info,
            ) => assert_eq!(info.message(), "column index out of range"),
            _ => panic!("Wrong error returned"),
        }
    }
}