// Module: diesel/sqlite/connection/stmt.rs
#![allow(unsafe_code)] // this module talks to SQLite through raw FFI calls
use super::bind_collector::{InternalSqliteBindValue, SqliteBindCollector};
use super::raw::RawConnection;
use super::sqlite_value::OwnedSqliteValue;
use crate::connection::statement_cache::{MaybeCached, PrepareForCache};
use crate::connection::Instrumentation;
use crate::query_builder::{QueryFragment, QueryId};
use crate::result::Error::DatabaseError;
use crate::result::*;
use crate::sqlite::{Sqlite, SqliteType};
#[cfg(not(all(target_family = "wasm", target_os = "unknown")))]
use libsqlite3_sys as ffi;
#[cfg(all(target_family = "wasm", target_os = "unknown"))]
use sqlite_wasm_rs as ffi;
use std::cell::OnceCell;
use std::ffi::{CStr, CString};
use std::io::{stderr, Write};
use std::os::raw as libc;
use std::ptr::{self, NonNull};

21pub(super) struct Statement {
22    inner_statement: NonNull<ffi::sqlite3_stmt>,
23}
24
25#[allow(unsafe_code)]
29unsafe impl Send for Statement {}
30
31impl Statement {
32    pub(super) fn prepare(
33        raw_connection: &RawConnection,
34        sql: &str,
35        is_cached: PrepareForCache,
36        _: &[SqliteType],
37    ) -> QueryResult<Self> {
38        let mut stmt = ptr::null_mut();
39        let mut unused_portion = ptr::null();
40        let n_byte = sql
41            .len()
42            .try_into()
43            .map_err(|e| Error::SerializationError(Box::new(e)))?;
44        #[allow(clippy::unnecessary_cast)]
46        let prepare_result = unsafe {
47            ffi::sqlite3_prepare_v3(
48                raw_connection.internal_connection.as_ptr(),
49                CString::new(sql)?.as_ptr(),
50                n_byte,
51                if #[allow(non_exhaustive_omitted_patterns)] match is_cached {
    PrepareForCache::Yes { counter: _ } => true,
    _ => false,
}matches!(is_cached, PrepareForCache::Yes { counter: _ }) {
52                    ffi::SQLITE_PREPARE_PERSISTENT as u32
53                } else {
54                    0
55                },
56                &mut stmt,
57                &mut unused_portion,
58            )
59        };
60
61        ensure_sqlite_ok(prepare_result, raw_connection.internal_connection.as_ptr())?;
62
63        let inner_statement = NonNull::new(stmt).ok_or_else(|| {
66            crate::result::Error::QueryBuilderError(Box::new(crate::result::EmptyQuery))
67        })?;
68        Ok(Statement { inner_statement })
69    }
70
71    unsafe fn bind(
77        &mut self,
78        tpe: SqliteType,
79        value: InternalSqliteBindValue<'_>,
80        bind_index: i32,
81    ) -> QueryResult<Option<NonNull<[u8]>>> {
82        let mut ret_ptr = None;
83        let result = match (tpe, value) {
84            (_, InternalSqliteBindValue::Null) => unsafe {
85                ffi::sqlite3_bind_null(self.inner_statement.as_ptr(), bind_index)
86            },
87            (SqliteType::Binary, InternalSqliteBindValue::BorrowedBinary(bytes)) => {
88                let n = bytes
89                    .len()
90                    .try_into()
91                    .map_err(|e| Error::SerializationError(Box::new(e)))?;
92                unsafe {
93                    ffi::sqlite3_bind_blob(
94                        self.inner_statement.as_ptr(),
95                        bind_index,
96                        bytes.as_ptr() as *const libc::c_void,
97                        n,
98                        ffi::SQLITE_STATIC(),
99                    )
100                }
101            }
102            (SqliteType::Binary, InternalSqliteBindValue::Binary(mut bytes)) => {
103                let len = bytes
104                    .len()
105                    .try_into()
106                    .map_err(|e| Error::SerializationError(Box::new(e)))?;
107                let ptr = bytes.as_mut_ptr();
111                ret_ptr = NonNull::new(Box::into_raw(bytes));
112                unsafe {
113                    ffi::sqlite3_bind_blob(
114                        self.inner_statement.as_ptr(),
115                        bind_index,
116                        ptr as *const libc::c_void,
117                        len,
118                        ffi::SQLITE_STATIC(),
119                    )
120                }
121            }
122            (SqliteType::Text, InternalSqliteBindValue::BorrowedString(bytes)) => {
123                let len = bytes
124                    .len()
125                    .try_into()
126                    .map_err(|e| Error::SerializationError(Box::new(e)))?;
127                unsafe {
128                    ffi::sqlite3_bind_text(
129                        self.inner_statement.as_ptr(),
130                        bind_index,
131                        bytes.as_ptr() as *const libc::c_char,
132                        len,
133                        ffi::SQLITE_STATIC(),
134                    )
135                }
136            }
137            (SqliteType::Text, InternalSqliteBindValue::String(bytes)) => {
138                let mut bytes = Box::<[u8]>::from(bytes);
139                let len = bytes
140                    .len()
141                    .try_into()
142                    .map_err(|e| Error::SerializationError(Box::new(e)))?;
143                let ptr = bytes.as_mut_ptr();
147                ret_ptr = NonNull::new(Box::into_raw(bytes));
148                unsafe {
149                    ffi::sqlite3_bind_text(
150                        self.inner_statement.as_ptr(),
151                        bind_index,
152                        ptr as *const libc::c_char,
153                        len,
154                        ffi::SQLITE_STATIC(),
155                    )
156                }
157            }
158            (SqliteType::Float, InternalSqliteBindValue::F64(value))
159            | (SqliteType::Double, InternalSqliteBindValue::F64(value)) => unsafe {
160                ffi::sqlite3_bind_double(
161                    self.inner_statement.as_ptr(),
162                    bind_index,
163                    value as libc::c_double,
164                )
165            },
166            (SqliteType::SmallInt, InternalSqliteBindValue::I32(value))
167            | (SqliteType::Integer, InternalSqliteBindValue::I32(value)) => unsafe {
168                ffi::sqlite3_bind_int(self.inner_statement.as_ptr(), bind_index, value)
169            },
170            (SqliteType::Long, InternalSqliteBindValue::I64(value)) => unsafe {
171                ffi::sqlite3_bind_int64(self.inner_statement.as_ptr(), bind_index, value)
172            },
173            (t, b) => {
174                return Err(Error::SerializationError(
175                    ::alloc::__export::must_use({
        ::alloc::fmt::format(format_args!("Type mismatch: Expected {0:?}, got {1}",
                t, b))
    })format!("Type mismatch: Expected {t:?}, got {b}").into(),
176                ))
177            }
178        };
179        match ensure_sqlite_ok(result, self.raw_connection()) {
180            Ok(()) => Ok(ret_ptr),
181            Err(e) => {
182                if let Some(ptr) = ret_ptr {
183                    std::mem::drop(unsafe { Box::from_raw(ptr.as_ptr()) })
187                }
188                Err(e)
189            }
190        }
191    }
192
193    fn reset(&mut self) {
194        unsafe { ffi::sqlite3_reset(self.inner_statement.as_ptr()) };
195    }
196
197    fn raw_connection(&self) -> *mut ffi::sqlite3 {
198        unsafe { ffi::sqlite3_db_handle(self.inner_statement.as_ptr()) }
199    }
200}
201
202pub(super) fn ensure_sqlite_ok(
203    code: libc::c_int,
204    raw_connection: *mut ffi::sqlite3,
205) -> QueryResult<()> {
206    if code == ffi::SQLITE_OK {
207        Ok(())
208    } else {
209        Err(last_error(raw_connection))
210    }
211}
212
213fn last_error(raw_connection: *mut ffi::sqlite3) -> Error {
214    let error_message = last_error_message(raw_connection);
215    let error_information = Box::new(error_message);
216    let error_kind = match last_error_code(raw_connection) {
217        ffi::SQLITE_CONSTRAINT_UNIQUE | ffi::SQLITE_CONSTRAINT_PRIMARYKEY => {
218            DatabaseErrorKind::UniqueViolation
219        }
220        ffi::SQLITE_CONSTRAINT_FOREIGNKEY => DatabaseErrorKind::ForeignKeyViolation,
221        ffi::SQLITE_CONSTRAINT_NOTNULL => DatabaseErrorKind::NotNullViolation,
222        ffi::SQLITE_CONSTRAINT_CHECK => DatabaseErrorKind::CheckViolation,
223        _ => DatabaseErrorKind::Unknown,
224    };
225    DatabaseError(error_kind, error_information)
226}
227
228fn last_error_message(conn: *mut ffi::sqlite3) -> String {
229    let c_str = unsafe { CStr::from_ptr(ffi::sqlite3_errmsg(conn)) };
230    c_str.to_string_lossy().into_owned()
231}
232
233fn last_error_code(conn: *mut ffi::sqlite3) -> libc::c_int {
234    unsafe { ffi::sqlite3_extended_errcode(conn) }
235}
236
237impl Drop for Statement {
238    fn drop(&mut self) {
239        use std::thread::panicking;
240
241        let raw_connection = self.raw_connection();
242        let finalize_result = unsafe { ffi::sqlite3_finalize(self.inner_statement.as_ptr()) };
243        if let Err(e) = ensure_sqlite_ok(finalize_result, raw_connection) {
244            if panicking() {
245                stderr().write_fmt(format_args!("Error finalizing SQLite prepared statement: {0:?}",
        e))write!(
246                    stderr(),
247                    "Error finalizing SQLite prepared statement: {e:?}"
248                )
249                .expect("Error writing to `stderr`");
250            } else {
251                {
    ::core::panicking::panic_fmt(format_args!("Error finalizing SQLite prepared statement: {0:?}",
            e));
};panic!("Error finalizing SQLite prepared statement: {e:?}");
252            }
253        }
254    }
255}
256
257struct BoundStatement<'stmt, 'query> {
267    statement: MaybeCached<'stmt, Statement>,
268    query: Option<NonNull<dyn QueryFragment<Sqlite> + 'query>>,
274    binds_to_free: Vec<(i32, Option<NonNull<[u8]>>)>,
278    instrumentation: &'stmt mut dyn Instrumentation,
279    has_error: bool,
280}
281
282impl<'stmt, 'query> BoundStatement<'stmt, 'query> {
283    fn bind<T>(
284        statement: MaybeCached<'stmt, Statement>,
285        query: T,
286        instrumentation: &'stmt mut dyn Instrumentation,
287    ) -> QueryResult<BoundStatement<'stmt, 'query>>
288    where
289        T: QueryFragment<Sqlite> + QueryId + 'query,
290    {
291        let query = Box::new(query);
296
297        let mut bind_collector = SqliteBindCollector::new();
298        query.collect_binds(&mut bind_collector, &mut (), &Sqlite)?;
299        let SqliteBindCollector { binds } = bind_collector;
300
301        let mut ret = BoundStatement {
302            statement,
303            query: None,
304            binds_to_free: Vec::new(),
305            instrumentation,
306            has_error: false,
307        };
308
309        ret.bind_buffers(binds)?;
310
311        let query = query as Box<dyn QueryFragment<Sqlite> + 'query>;
312        ret.query = NonNull::new(Box::into_raw(query));
313
314        Ok(ret)
315    }
316
317    fn bind_buffers(
321        &mut self,
322        binds: Vec<(InternalSqliteBindValue<'_>, SqliteType)>,
323    ) -> QueryResult<()> {
324        self.binds_to_free.reserve(
329            binds
330                .iter()
331                .filter(|&(b, _)| {
332                    #[allow(non_exhaustive_omitted_patterns)] match b {
    InternalSqliteBindValue::BorrowedBinary(_) |
        InternalSqliteBindValue::BorrowedString(_) |
        InternalSqliteBindValue::String(_) |
        InternalSqliteBindValue::Binary(_) => true,
    _ => false,
}matches!(
333                        b,
334                        InternalSqliteBindValue::BorrowedBinary(_)
335                            | InternalSqliteBindValue::BorrowedString(_)
336                            | InternalSqliteBindValue::String(_)
337                            | InternalSqliteBindValue::Binary(_)
338                    )
339                })
340                .count(),
341        );
342        for (bind_idx, (bind, tpe)) in (1..).zip(binds) {
343            let is_borrowed_bind = #[allow(non_exhaustive_omitted_patterns)] match bind {
    InternalSqliteBindValue::BorrowedString(_) |
        InternalSqliteBindValue::BorrowedBinary(_) => true,
    _ => false,
}matches!(
344                bind,
345                InternalSqliteBindValue::BorrowedString(_)
346                    | InternalSqliteBindValue::BorrowedBinary(_)
347            );
348
349            let res = unsafe { self.statement.bind(tpe, bind, bind_idx) }?;
354
355            if let Some(ptr) = res {
360                self.binds_to_free.push((bind_idx, Some(ptr)));
363            } else if is_borrowed_bind {
364                self.binds_to_free.push((bind_idx, None));
366            }
367        }
368        Ok(())
369    }
370
371    fn finish_query_with_error(mut self, e: &Error) {
372        self.has_error = true;
373        if let Some(q) = self.query {
374            let q = unsafe { q.as_ref() };
376            self.instrumentation.on_connection_event(
377                crate::connection::InstrumentationEvent::FinishQuery {
378                    query: &crate::debug_query(&q),
379                    error: Some(e),
380                },
381            );
382        }
383    }
384}
385
386impl Drop for BoundStatement<'_, '_> {
387    fn drop(&mut self) {
388        self.statement.reset();
391
392        for (idx, buffer) in std::mem::take(&mut self.binds_to_free) {
393            unsafe {
394                self.statement
396                    .bind(SqliteType::Text, InternalSqliteBindValue::Null, idx)
397                    .expect(
398                        "Binding a null value should never fail. \
399                             If you ever see this error message please open \
400                             an issue at diesels issue tracker containing \
401                             code how to trigger this message.",
402                    );
403            }
404
405            if let Some(buffer) = buffer {
406                unsafe {
407                    std::mem::drop(Box::from_raw(buffer.as_ptr()));
410                }
411            }
412        }
413
414        if let Some(query) = self.query {
415            let query = unsafe {
416                Box::from_raw(query.as_ptr())
419            };
420            if !self.has_error {
421                self.instrumentation.on_connection_event(
422                    crate::connection::InstrumentationEvent::FinishQuery {
423                        query: &crate::debug_query(&query),
424                        error: None,
425                    },
426                );
427            }
428            std::mem::drop(query);
429            self.query = None;
430        }
431    }
432}
433
434#[allow(missing_debug_implementations)]
435pub struct StatementUse<'stmt, 'query> {
436    statement: BoundStatement<'stmt, 'query>,
437    column_names: OnceCell<Vec<*const str>>,
438}
439
440impl<'stmt, 'query> StatementUse<'stmt, 'query> {
441    pub(super) fn bind<T>(
442        statement: MaybeCached<'stmt, Statement>,
443        query: T,
444        instrumentation: &'stmt mut dyn Instrumentation,
445    ) -> QueryResult<StatementUse<'stmt, 'query>>
446    where
447        T: QueryFragment<Sqlite> + QueryId + 'query,
448    {
449        Ok(Self {
450            statement: BoundStatement::bind(statement, query, instrumentation)?,
451            column_names: OnceCell::new(),
452        })
453    }
454
455    pub(super) fn run(mut self) -> QueryResult<()> {
456        let r = unsafe {
457            self.step(true).map(|_| ())
461        };
462        if let Err(ref e) = r {
463            self.statement.finish_query_with_error(e);
464        }
465        r
466    }
467
468    pub(super) unsafe fn step(&mut self, first_step: bool) -> QueryResult<bool> {
475        let step_result =
476            unsafe { ffi::sqlite3_step(self.statement.statement.inner_statement.as_ptr()) };
477        let res = match step_result {
478            ffi::SQLITE_DONE => Ok(false),
479            ffi::SQLITE_ROW => Ok(true),
480            _ => Err(last_error(self.statement.statement.raw_connection())),
481        };
482        if first_step {
483            self.column_names = OnceCell::new();
484        }
485        res
486    }
487
488    unsafe fn column_name(&self, idx: i32) -> *const str {
500        let name = {
501            let column_name = unsafe {
502                ffi::sqlite3_column_name(self.statement.statement.inner_statement.as_ptr(), idx)
503            };
504            if !!column_name.is_null() {
    {
        ::core::panicking::panic_fmt(format_args!("The Sqlite documentation states that it only returns a null pointer here if we are in a OOM condition."));
    }
};assert!(
505                !column_name.is_null(),
506                "The Sqlite documentation states that it only returns a \
507                 null pointer here if we are in a OOM condition."
508            );
509            unsafe { CStr::from_ptr(column_name) }
510        };
511        name.to_str().expect(
512            "The Sqlite documentation states that this is UTF8. \
513             If you see this error message something has gone \
514             horribly wrong. Please open an issue at the \
515             diesel repository.",
516        ) as *const str
517    }
518
519    pub(super) fn column_count(&self) -> i32 {
520        unsafe { ffi::sqlite3_column_count(self.statement.statement.inner_statement.as_ptr()) }
521    }
522
523    pub(super) fn index_for_column_name(&mut self, field_name: &str) -> Option<usize> {
524        (0..self.column_count())
525            .find(|idx| self.field_name(*idx) == Some(field_name))
526            .map(|v| {
527                v.try_into()
528                    .expect("Diesel expects to run at least on a 32 bit platform")
529            })
530    }
531
532    pub(super) fn field_name(&self, idx: i32) -> Option<&str> {
533        let column_names = self.column_names.get_or_init(|| {
534            let count = self.column_count();
535            (0..count)
536                .map(|idx| unsafe {
537                    self.column_name(idx)
540                })
541                .collect()
542        });
543
544        column_names
545            .get(usize::try_from(idx).expect("Diesel expects to run at least on a 32 bit platform"))
546            .and_then(|c| unsafe { c.as_ref() })
547    }
548
549    pub(super) fn copy_value(&self, idx: i32) -> Option<OwnedSqliteValue> {
550        OwnedSqliteValue::copy_from_ptr(self.column_value(idx)?)
551    }
552
553    pub(super) fn column_value(&self, idx: i32) -> Option<NonNull<ffi::sqlite3_value>> {
554        let ptr = unsafe {
555            ffi::sqlite3_column_value(self.statement.statement.inner_statement.as_ptr(), idx)
556        };
557        NonNull::new(ptr)
558    }
559}
560
#[cfg(test)]
mod tests {
    use crate::prelude::*;
    use crate::sql_types::Text;

    /// Binding a value to a query without any placeholder (the `?` below is
    /// part of a string literal) must surface SQLite's error instead of
    /// panicking while the bound statement is dropped.
    #[diesel_test_helper::test]
    fn check_out_of_bounds_bind_does_not_panic_on_drop() {
        let mut conn = SqliteConnection::establish(":memory:").unwrap();

        let e = crate::sql_query("SELECT '?'")
            .bind::<Text, _>("foo")
            .execute(&mut conn);

        assert!(e.is_err());
        let e = e.unwrap_err();
        if let crate::result::Error::DatabaseError(crate::result::DatabaseErrorKind::Unknown, m) = e
        {
            assert_eq!(m.message(), "column index out of range");
        } else {
            panic!("Wrong error returned");
        }
    }
}