diesel/sqlite/connection/
stmt.rs

1#![allow(unsafe_code)] // ffi code
2use super::bind_collector::{InternalSqliteBindValue, SqliteBindCollector};
3use super::raw::RawConnection;
4use super::sqlite_value::OwnedSqliteValue;
5use crate::connection::statement_cache::{MaybeCached, PrepareForCache};
6use crate::connection::Instrumentation;
7use crate::query_builder::{QueryFragment, QueryId};
8use crate::result::Error::DatabaseError;
9use crate::result::*;
10use crate::sqlite::{Sqlite, SqliteType};
11#[cfg(not(all(target_family = "wasm", target_os = "unknown")))]
12use libsqlite3_sys as ffi;
13#[cfg(all(target_family = "wasm", target_os = "unknown"))]
14use sqlite_wasm_rs::export as ffi;
15use std::cell::OnceCell;
16use std::ffi::{CStr, CString};
17use std::io::{stderr, Write};
18use std::os::raw as libc;
19use std::ptr::{self, NonNull};
20
21pub(super) struct Statement {
22    inner_statement: NonNull<ffi::sqlite3_stmt>,
23}
24
25// This relies on the invariant that RawConnection or Statement are never
26// leaked. If a reference to one of those was held on a different thread, this
27// would not be thread safe.
28#[allow(unsafe_code)]
29unsafe impl Send for Statement {}
30
31impl Statement {
32    pub(super) fn prepare(
33        raw_connection: &RawConnection,
34        sql: &str,
35        is_cached: PrepareForCache,
36        _: &[SqliteType],
37    ) -> QueryResult<Self> {
38        let mut stmt = ptr::null_mut();
39        let mut unused_portion = ptr::null();
40        let n_byte = sql
41            .len()
42            .try_into()
43            .map_err(|e| Error::SerializationError(Box::new(e)))?;
44        // the cast for `ffi::SQLITE_PREPARE_PERSISTENT` is required for old libsqlite3-sys versions
45        #[allow(clippy::unnecessary_cast)]
46        let prepare_result = unsafe {
47            ffi::sqlite3_prepare_v3(
48                raw_connection.internal_connection.as_ptr(),
49                CString::new(sql)?.as_ptr(),
50                n_byte,
51                if matches!(is_cached, PrepareForCache::Yes { counter: _ }) {
52                    ffi::SQLITE_PREPARE_PERSISTENT as u32
53                } else {
54                    0
55                },
56                &mut stmt,
57                &mut unused_portion,
58            )
59        };
60
61        ensure_sqlite_ok(prepare_result, raw_connection.internal_connection.as_ptr())?;
62
63        // sqlite3_prepare_v3 returns a null pointer for empty statements. This includes
64        // empty or only whitespace strings or any other non-op query string like a comment
65        let inner_statement = NonNull::new(stmt).ok_or_else(|| {
66            crate::result::Error::QueryBuilderError(Box::new(crate::result::EmptyQuery))
67        })?;
68        Ok(Statement { inner_statement })
69    }
70
71    // The caller of this function has to ensure that:
72    // * Any buffer provided as `SqliteBindValue::BorrowedBinary`, `SqliteBindValue::Binary`
73    // `SqliteBindValue::String` or `SqliteBindValue::BorrowedString` is valid
74    // till either a new value is bound to the same parameter or the underlying
75    // prepared statement is dropped.
76    unsafe fn bind(
77        &mut self,
78        tpe: SqliteType,
79        value: InternalSqliteBindValue<'_>,
80        bind_index: i32,
81    ) -> QueryResult<Option<NonNull<[u8]>>> {
82        let mut ret_ptr = None;
83        let result = match (tpe, value) {
84            (_, InternalSqliteBindValue::Null) => {
85                ffi::sqlite3_bind_null(self.inner_statement.as_ptr(), bind_index)
86            }
87            (SqliteType::Binary, InternalSqliteBindValue::BorrowedBinary(bytes)) => {
88                let n = bytes
89                    .len()
90                    .try_into()
91                    .map_err(|e| Error::SerializationError(Box::new(e)))?;
92                ffi::sqlite3_bind_blob(
93                    self.inner_statement.as_ptr(),
94                    bind_index,
95                    bytes.as_ptr() as *const libc::c_void,
96                    n,
97                    ffi::SQLITE_STATIC(),
98                )
99            }
100            (SqliteType::Binary, InternalSqliteBindValue::Binary(mut bytes)) => {
101                let len = bytes
102                    .len()
103                    .try_into()
104                    .map_err(|e| Error::SerializationError(Box::new(e)))?;
105                // We need a separate pointer here to pass it to sqlite
106                // as the returned pointer is a pointer to a dyn sized **slice**
107                // and not the pointer to the first element of the slice
108                let ptr = bytes.as_mut_ptr();
109                ret_ptr = NonNull::new(Box::into_raw(bytes));
110                ffi::sqlite3_bind_blob(
111                    self.inner_statement.as_ptr(),
112                    bind_index,
113                    ptr as *const libc::c_void,
114                    len,
115                    ffi::SQLITE_STATIC(),
116                )
117            }
118            (SqliteType::Text, InternalSqliteBindValue::BorrowedString(bytes)) => {
119                let len = bytes
120                    .len()
121                    .try_into()
122                    .map_err(|e| Error::SerializationError(Box::new(e)))?;
123                ffi::sqlite3_bind_text(
124                    self.inner_statement.as_ptr(),
125                    bind_index,
126                    bytes.as_ptr() as *const libc::c_char,
127                    len,
128                    ffi::SQLITE_STATIC(),
129                )
130            }
131            (SqliteType::Text, InternalSqliteBindValue::String(bytes)) => {
132                let mut bytes = Box::<[u8]>::from(bytes);
133                let len = bytes
134                    .len()
135                    .try_into()
136                    .map_err(|e| Error::SerializationError(Box::new(e)))?;
137                // We need a separate pointer here to pass it to sqlite
138                // as the returned pointer is a pointer to a dyn sized **slice**
139                // and not the pointer to the first element of the slice
140                let ptr = bytes.as_mut_ptr();
141                ret_ptr = NonNull::new(Box::into_raw(bytes));
142                ffi::sqlite3_bind_text(
143                    self.inner_statement.as_ptr(),
144                    bind_index,
145                    ptr as *const libc::c_char,
146                    len,
147                    ffi::SQLITE_STATIC(),
148                )
149            }
150            (SqliteType::Float, InternalSqliteBindValue::F64(value))
151            | (SqliteType::Double, InternalSqliteBindValue::F64(value)) => {
152                ffi::sqlite3_bind_double(
153                    self.inner_statement.as_ptr(),
154                    bind_index,
155                    value as libc::c_double,
156                )
157            }
158            (SqliteType::SmallInt, InternalSqliteBindValue::I32(value))
159            | (SqliteType::Integer, InternalSqliteBindValue::I32(value)) => {
160                ffi::sqlite3_bind_int(self.inner_statement.as_ptr(), bind_index, value)
161            }
162            (SqliteType::Long, InternalSqliteBindValue::I64(value)) => {
163                ffi::sqlite3_bind_int64(self.inner_statement.as_ptr(), bind_index, value)
164            }
165            (t, b) => {
166                return Err(Error::SerializationError(
167                    format!("Type mismatch: Expected {t:?}, got {b}").into(),
168                ))
169            }
170        };
171        match ensure_sqlite_ok(result, self.raw_connection()) {
172            Ok(()) => Ok(ret_ptr),
173            Err(e) => {
174                if let Some(ptr) = ret_ptr {
175                    // This is a `NonNul` ptr so it cannot be null
176                    // It points to a slice internally as we did not apply
177                    // any cast above.
178                    std::mem::drop(Box::from_raw(ptr.as_ptr()))
179                }
180                Err(e)
181            }
182        }
183    }
184
185    fn reset(&mut self) {
186        unsafe { ffi::sqlite3_reset(self.inner_statement.as_ptr()) };
187    }
188
189    fn raw_connection(&self) -> *mut ffi::sqlite3 {
190        unsafe { ffi::sqlite3_db_handle(self.inner_statement.as_ptr()) }
191    }
192}
193
194pub(super) fn ensure_sqlite_ok(
195    code: libc::c_int,
196    raw_connection: *mut ffi::sqlite3,
197) -> QueryResult<()> {
198    if code == ffi::SQLITE_OK {
199        Ok(())
200    } else {
201        Err(last_error(raw_connection))
202    }
203}
204
205fn last_error(raw_connection: *mut ffi::sqlite3) -> Error {
206    let error_message = last_error_message(raw_connection);
207    let error_information = Box::new(error_message);
208    let error_kind = match last_error_code(raw_connection) {
209        ffi::SQLITE_CONSTRAINT_UNIQUE | ffi::SQLITE_CONSTRAINT_PRIMARYKEY => {
210            DatabaseErrorKind::UniqueViolation
211        }
212        ffi::SQLITE_CONSTRAINT_FOREIGNKEY => DatabaseErrorKind::ForeignKeyViolation,
213        ffi::SQLITE_CONSTRAINT_NOTNULL => DatabaseErrorKind::NotNullViolation,
214        ffi::SQLITE_CONSTRAINT_CHECK => DatabaseErrorKind::CheckViolation,
215        _ => DatabaseErrorKind::Unknown,
216    };
217    DatabaseError(error_kind, error_information)
218}
219
220fn last_error_message(conn: *mut ffi::sqlite3) -> String {
221    let c_str = unsafe { CStr::from_ptr(ffi::sqlite3_errmsg(conn)) };
222    c_str.to_string_lossy().into_owned()
223}
224
225fn last_error_code(conn: *mut ffi::sqlite3) -> libc::c_int {
226    unsafe { ffi::sqlite3_extended_errcode(conn) }
227}
228
229impl Drop for Statement {
230    fn drop(&mut self) {
231        use std::thread::panicking;
232
233        let raw_connection = self.raw_connection();
234        let finalize_result = unsafe { ffi::sqlite3_finalize(self.inner_statement.as_ptr()) };
235        if let Err(e) = ensure_sqlite_ok(finalize_result, raw_connection) {
236            if panicking() {
237                write!(
238                    stderr(),
239                    "Error finalizing SQLite prepared statement: {e:?}"
240                )
241                .expect("Error writing to `stderr`");
242            } else {
243                panic!("Error finalizing SQLite prepared statement: {:?}", e);
244            }
245        }
246    }
247}
248
249// A warning for future editors:
250// Changing this code to something "simpler" may
251// introduce undefined behaviour. Make sure you read
252// the following discussions for details about
253// the current version:
254//
255// * https://github.com/weiznich/diesel/pull/7
256// * https://users.rust-lang.org/t/code-review-for-unsafe-code-in-diesel/66798/
257// * https://github.com/rust-lang/unsafe-code-guidelines/issues/194
258struct BoundStatement<'stmt, 'query> {
259    statement: MaybeCached<'stmt, Statement>,
260    // we need to store the query here to ensure no one does
261    // drop it till the end of the statement
262    // We use a boxed queryfragment here just to erase the
263    // generic type, we use NonNull to communicate
264    // that this is a shared buffer
265    query: Option<NonNull<dyn QueryFragment<Sqlite> + 'query>>,
266    // we need to store any owned bind values separately, as they are not
267    // contained in the query itself. We use NonNull to
268    // communicate that this is a shared buffer
269    binds_to_free: Vec<(i32, Option<NonNull<[u8]>>)>,
270    instrumentation: &'stmt mut dyn Instrumentation,
271    has_error: bool,
272}
273
274impl<'stmt, 'query> BoundStatement<'stmt, 'query> {
275    fn bind<T>(
276        statement: MaybeCached<'stmt, Statement>,
277        query: T,
278        instrumentation: &'stmt mut dyn Instrumentation,
279    ) -> QueryResult<BoundStatement<'stmt, 'query>>
280    where
281        T: QueryFragment<Sqlite> + QueryId + 'query,
282    {
283        // Don't use a trait object here to prevent using a virtual function call
284        // For sqlite this can introduce a measurable overhead
285        // Query is boxed here to make sure it won't move in memory anymore, so any bind
286        // it could output would stay valid.
287        let query = Box::new(query);
288
289        let mut bind_collector = SqliteBindCollector::new();
290        query.collect_binds(&mut bind_collector, &mut (), &Sqlite)?;
291        let SqliteBindCollector { binds } = bind_collector;
292
293        let mut ret = BoundStatement {
294            statement,
295            query: None,
296            binds_to_free: Vec::new(),
297            instrumentation,
298            has_error: false,
299        };
300
301        ret.bind_buffers(binds)?;
302
303        let query = query as Box<dyn QueryFragment<Sqlite> + 'query>;
304        ret.query = NonNull::new(Box::into_raw(query));
305
306        Ok(ret)
307    }
308
309    // This is a separated function so that
310    // not the whole constructor is generic over the query type T.
311    // This hopefully prevents binary bloat.
312    fn bind_buffers(
313        &mut self,
314        binds: Vec<(InternalSqliteBindValue<'_>, SqliteType)>,
315    ) -> QueryResult<()> {
316        // It is useful to preallocate `binds_to_free` because it
317        // - Guarantees that pushing inside it cannot panic, which guarantees the `Drop`
318        //   impl of `BoundStatement` will always re-`bind` as needed
319        // - Avoids reallocations
320        self.binds_to_free.reserve(
321            binds
322                .iter()
323                .filter(|&(b, _)| {
324                    matches!(
325                        b,
326                        InternalSqliteBindValue::BorrowedBinary(_)
327                            | InternalSqliteBindValue::BorrowedString(_)
328                            | InternalSqliteBindValue::String(_)
329                            | InternalSqliteBindValue::Binary(_)
330                    )
331                })
332                .count(),
333        );
334        for (bind_idx, (bind, tpe)) in (1..).zip(binds) {
335            let is_borrowed_bind = matches!(
336                bind,
337                InternalSqliteBindValue::BorrowedString(_)
338                    | InternalSqliteBindValue::BorrowedBinary(_)
339            );
340
341            // It's safe to call bind here as:
342            // * The type and value matches
343            // * We ensure that corresponding buffers lives long enough below
344            // * The statement is not used yet by `step` or anything else
345            let res = unsafe { self.statement.bind(tpe, bind, bind_idx) }?;
346
347            // it's important to push these only after
348            // the call to bind succeeded, otherwise we might attempt to
349            // call bind to an non-existing bind position in
350            // the destructor
351            if let Some(ptr) = res {
352                // Store the id + pointer for a owned bind
353                // as we must unbind and free them on drop
354                self.binds_to_free.push((bind_idx, Some(ptr)));
355            } else if is_borrowed_bind {
356                // Store the id's of borrowed binds to unbind them on drop
357                self.binds_to_free.push((bind_idx, None));
358            }
359        }
360        Ok(())
361    }
362
363    fn finish_query_with_error(mut self, e: &Error) {
364        self.has_error = true;
365        if let Some(q) = self.query {
366            // it's safe to get a reference from this ptr as it's guaranteed to not be null
367            let q = unsafe { q.as_ref() };
368            self.instrumentation.on_connection_event(
369                crate::connection::InstrumentationEvent::FinishQuery {
370                    query: &crate::debug_query(&q),
371                    error: Some(e),
372                },
373            );
374        }
375    }
376}
377
378impl Drop for BoundStatement<'_, '_> {
379    fn drop(&mut self) {
380        // First reset the statement, otherwise the bind calls
381        // below will fails
382        self.statement.reset();
383
384        for (idx, buffer) in std::mem::take(&mut self.binds_to_free) {
385            unsafe {
386                // It's always safe to bind null values, as there is no buffer that needs to outlife something
387                self.statement
388                    .bind(SqliteType::Text, InternalSqliteBindValue::Null, idx)
389                    .expect(
390                        "Binding a null value should never fail. \
391                             If you ever see this error message please open \
392                             an issue at diesels issue tracker containing \
393                             code how to trigger this message.",
394                    );
395            }
396
397            if let Some(buffer) = buffer {
398                unsafe {
399                    // Constructing the `Box` here is safe as we
400                    // got the pointer from a box + it is guaranteed to be not null.
401                    std::mem::drop(Box::from_raw(buffer.as_ptr()));
402                }
403            }
404        }
405
406        if let Some(query) = self.query {
407            let query = unsafe {
408                // Constructing the `Box` here is safe as we
409                // got the pointer from a box + it is guaranteed to be not null.
410                Box::from_raw(query.as_ptr())
411            };
412            if !self.has_error {
413                self.instrumentation.on_connection_event(
414                    crate::connection::InstrumentationEvent::FinishQuery {
415                        query: &crate::debug_query(&query),
416                        error: None,
417                    },
418                );
419            }
420            std::mem::drop(query);
421            self.query = None;
422        }
423    }
424}
425
426#[allow(missing_debug_implementations)]
427pub struct StatementUse<'stmt, 'query> {
428    statement: BoundStatement<'stmt, 'query>,
429    column_names: OnceCell<Vec<*const str>>,
430}
431
432impl<'stmt, 'query> StatementUse<'stmt, 'query> {
433    pub(super) fn bind<T>(
434        statement: MaybeCached<'stmt, Statement>,
435        query: T,
436        instrumentation: &'stmt mut dyn Instrumentation,
437    ) -> QueryResult<StatementUse<'stmt, 'query>>
438    where
439        T: QueryFragment<Sqlite> + QueryId + 'query,
440    {
441        Ok(Self {
442            statement: BoundStatement::bind(statement, query, instrumentation)?,
443            column_names: OnceCell::new(),
444        })
445    }
446
447    pub(super) fn run(mut self) -> QueryResult<()> {
448        let r = unsafe {
449            // This is safe as we pass `first_step = true`
450            // and we consume the statement so nobody could
451            // access the columns later on anyway.
452            self.step(true).map(|_| ())
453        };
454        if let Err(ref e) = r {
455            self.statement.finish_query_with_error(e);
456        }
457        r
458    }
459
460    // This function is marked as unsafe incorrectly passing `false` to `first_step`
461    // for a first call to this function could cause access to freed memory via
462    // the cached column names.
463    //
464    // It's always safe to call this function with `first_step = true` as this removes
465    // the cached column names
466    pub(super) unsafe fn step(&mut self, first_step: bool) -> QueryResult<bool> {
467        let res = match ffi::sqlite3_step(self.statement.statement.inner_statement.as_ptr()) {
468            ffi::SQLITE_DONE => Ok(false),
469            ffi::SQLITE_ROW => Ok(true),
470            _ => Err(last_error(self.statement.statement.raw_connection())),
471        };
472        if first_step {
473            self.column_names = OnceCell::new();
474        }
475        res
476    }
477
478    // The returned string pointer is valid until either the prepared statement is
479    // destroyed by sqlite3_finalize() or until the statement is automatically
480    // reprepared by the first call to sqlite3_step() for a particular run or
481    // until the next call to sqlite3_column_name() or sqlite3_column_name16()
482    // on the same column.
483    //
484    // https://sqlite.org/c3ref/column_name.html
485    //
486    // Note: This function is marked as unsafe, as calling it can invalidate
487    // other existing column name pointers on the same column. To prevent that,
488    // it should maximally be called once per column at all.
489    unsafe fn column_name(&self, idx: i32) -> *const str {
490        let name = {
491            let column_name =
492                ffi::sqlite3_column_name(self.statement.statement.inner_statement.as_ptr(), idx);
493            assert!(
494                !column_name.is_null(),
495                "The Sqlite documentation states that it only returns a \
496                 null pointer here if we are in a OOM condition."
497            );
498            CStr::from_ptr(column_name)
499        };
500        name.to_str().expect(
501            "The Sqlite documentation states that this is UTF8. \
502             If you see this error message something has gone \
503             horribly wrong. Please open an issue at the \
504             diesel repository.",
505        ) as *const str
506    }
507
508    pub(super) fn column_count(&self) -> i32 {
509        unsafe { ffi::sqlite3_column_count(self.statement.statement.inner_statement.as_ptr()) }
510    }
511
512    pub(super) fn index_for_column_name(&mut self, field_name: &str) -> Option<usize> {
513        (0..self.column_count())
514            .find(|idx| self.field_name(*idx) == Some(field_name))
515            .map(|v| {
516                v.try_into()
517                    .expect("Diesel expects to run at least on a 32 bit platform")
518            })
519    }
520
521    pub(super) fn field_name(&self, idx: i32) -> Option<&str> {
522        let column_names = self.column_names.get_or_init(|| {
523            let count = self.column_count();
524            (0..count)
525                .map(|idx| unsafe {
526                    // By initializing the whole vec at once we ensure that
527                    // we really call this only once.
528                    self.column_name(idx)
529                })
530                .collect()
531        });
532
533        column_names
534            .get(usize::try_from(idx).expect("Diesel expects to run at least on a 32 bit platform"))
535            .and_then(|c| unsafe { c.as_ref() })
536    }
537
538    pub(super) fn copy_value(&self, idx: i32) -> Option<OwnedSqliteValue> {
539        OwnedSqliteValue::copy_from_ptr(self.column_value(idx)?)
540    }
541
542    pub(super) fn column_value(&self, idx: i32) -> Option<NonNull<ffi::sqlite3_value>> {
543        let ptr = unsafe {
544            ffi::sqlite3_column_value(self.statement.statement.inner_statement.as_ptr(), idx)
545        };
546        NonNull::new(ptr)
547    }
548}
549
#[cfg(test)]
mod tests {
    use crate::prelude::*;
    use crate::sql_types::Text;

    // Regression test for
    // https://github.com/diesel-rs/diesel/issues/3558
    //
    // The query has no bind placeholder (the `?` is inside a string literal),
    // so binding a value must produce an error instead of a panic on drop.
    #[diesel_test_helper::test]
    fn check_out_of_bounds_bind_does_not_panic_on_drop() {
        let mut conn = SqliteConnection::establish(":memory:").unwrap();

        let result = crate::sql_query("SELECT '?'")
            .bind::<Text, _>("foo")
            .execute(&mut conn);

        assert!(result.is_err());
        match result.unwrap_err() {
            crate::result::Error::DatabaseError(
                crate::result::DatabaseErrorKind::Unknown,
                info,
            ) => {
                assert_eq!(info.message(), "column index out of range");
            }
            _ => panic!("Wrong error returned"),
        }
    }
}
574}