Skip to main content

diesel/sqlite/connection/
stmt.rs

1#![allow(unsafe_code)] // fii code
2use super::bind_collector::{InternalSqliteBindValue, SqliteBindCollector};
3use super::raw::RawConnection;
4use super::sqlite_value::OwnedSqliteValue;
5use crate::connection::Instrumentation;
6use crate::connection::statement_cache::{MaybeCached, PrepareForCache};
7use crate::query_builder::{QueryFragment, QueryId};
8use crate::result::Error::DatabaseError;
9use crate::result::*;
10use crate::sqlite::{Sqlite, SqliteType};
11use alloc::boxed::Box;
12use alloc::ffi::CString;
13use alloc::string::String;
14use alloc::vec::Vec;
15use core::cell::OnceCell;
16use core::ffi as libc;
17use core::ffi::CStr;
18use core::ptr::{self, NonNull};
19#[cfg(not(all(target_family = "wasm", target_os = "unknown")))]
20use libsqlite3_sys as ffi;
21#[cfg(all(target_family = "wasm", target_os = "unknown"))]
22use sqlite_wasm_rs as ffi;
23
/// Owning wrapper around a prepared SQLite statement handle.
///
/// A value is only created by `Statement::prepare` and the handle is released
/// with `sqlite3_finalize` in the `Drop` impl.
pub(super) struct Statement {
    // Never null: `prepare` turns the null pointer sqlite returns for empty SQL
    // into an `EmptyQuery` error instead of constructing a `Statement`.
    inner_statement: NonNull<ffi::sqlite3_stmt>,
}

// This relies on the invariant that RawConnection or Statement are never
// leaked. If a reference to one of those was held on a different thread, this
// would not be thread safe.
#[allow(unsafe_code)]
unsafe impl Send for Statement {}
33
34impl Statement {
35    pub(super) fn prepare(
36        raw_connection: &RawConnection,
37        sql: &str,
38        is_cached: PrepareForCache,
39        _: &[SqliteType],
40    ) -> QueryResult<Self> {
41        let mut stmt = ptr::null_mut();
42        let mut unused_portion = ptr::null();
43        let n_byte = sql
44            .len()
45            .try_into()
46            .map_err(|e| Error::SerializationError(Box::new(e)))?;
47        // the cast for `ffi::SQLITE_PREPARE_PERSISTENT` is required for old libsqlite3-sys versions
48        #[allow(clippy::unnecessary_cast)]
49        let prepare_result = unsafe {
50            ffi::sqlite3_prepare_v3(
51                raw_connection.internal_connection.as_ptr(),
52                CString::new(sql)?.as_ptr(),
53                n_byte,
54                if #[allow(non_exhaustive_omitted_patterns)] match is_cached {
    PrepareForCache::Yes { counter: _ } => true,
    _ => false,
}matches!(is_cached, PrepareForCache::Yes { counter: _ }) {
55                    ffi::SQLITE_PREPARE_PERSISTENT as u32
56                } else {
57                    0
58                },
59                &mut stmt,
60                &mut unused_portion,
61            )
62        };
63
64        ensure_sqlite_ok(prepare_result, raw_connection.internal_connection.as_ptr())?;
65
66        // sqlite3_prepare_v3 returns a null pointer for empty statements. This includes
67        // empty or only whitespace strings or any other non-op query string like a comment
68        let inner_statement = NonNull::new(stmt).ok_or_else(|| {
69            crate::result::Error::QueryBuilderError(Box::new(crate::result::EmptyQuery))
70        })?;
71        Ok(Statement { inner_statement })
72    }
73
74    // The caller of this function has to ensure that:
75    // * Any buffer provided as `SqliteBindValue::BorrowedBinary`, `SqliteBindValue::Binary`
76    // `SqliteBindValue::String` or `SqliteBindValue::BorrowedString` is valid
77    // till either a new value is bound to the same parameter or the underlying
78    // prepared statement is dropped.
79    unsafe fn bind(
80        &mut self,
81        tpe: SqliteType,
82        value: InternalSqliteBindValue<'_>,
83        bind_index: i32,
84    ) -> QueryResult<Option<NonNull<[u8]>>> {
85        let mut ret_ptr = None;
86        let result = match (tpe, value) {
87            (_, InternalSqliteBindValue::Null) => unsafe {
88                ffi::sqlite3_bind_null(self.inner_statement.as_ptr(), bind_index)
89            },
90            (SqliteType::Binary, InternalSqliteBindValue::BorrowedBinary(bytes)) => {
91                let n = bytes
92                    .len()
93                    .try_into()
94                    .map_err(|e| Error::SerializationError(Box::new(e)))?;
95                unsafe {
96                    ffi::sqlite3_bind_blob(
97                        self.inner_statement.as_ptr(),
98                        bind_index,
99                        bytes.as_ptr() as *const libc::c_void,
100                        n,
101                        ffi::SQLITE_STATIC(),
102                    )
103                }
104            }
105            (SqliteType::Binary, InternalSqliteBindValue::Binary(mut bytes)) => {
106                let len = bytes
107                    .len()
108                    .try_into()
109                    .map_err(|e| Error::SerializationError(Box::new(e)))?;
110                // We need a separate pointer here to pass it to sqlite
111                // as the returned pointer is a pointer to a dyn sized **slice**
112                // and not the pointer to the first element of the slice
113                let ptr = bytes.as_mut_ptr();
114                ret_ptr = NonNull::new(Box::into_raw(bytes));
115                unsafe {
116                    ffi::sqlite3_bind_blob(
117                        self.inner_statement.as_ptr(),
118                        bind_index,
119                        ptr as *const libc::c_void,
120                        len,
121                        ffi::SQLITE_STATIC(),
122                    )
123                }
124            }
125            (SqliteType::Text, InternalSqliteBindValue::BorrowedString(bytes)) => {
126                let len = bytes
127                    .len()
128                    .try_into()
129                    .map_err(|e| Error::SerializationError(Box::new(e)))?;
130                unsafe {
131                    ffi::sqlite3_bind_text(
132                        self.inner_statement.as_ptr(),
133                        bind_index,
134                        bytes.as_ptr() as *const libc::c_char,
135                        len,
136                        ffi::SQLITE_STATIC(),
137                    )
138                }
139            }
140            (SqliteType::Text, InternalSqliteBindValue::String(bytes)) => {
141                let mut bytes = Box::<[u8]>::from(bytes);
142                let len = bytes
143                    .len()
144                    .try_into()
145                    .map_err(|e| Error::SerializationError(Box::new(e)))?;
146                // We need a separate pointer here to pass it to sqlite
147                // as the returned pointer is a pointer to a dyn sized **slice**
148                // and not the pointer to the first element of the slice
149                let ptr = bytes.as_mut_ptr();
150                ret_ptr = NonNull::new(Box::into_raw(bytes));
151                unsafe {
152                    ffi::sqlite3_bind_text(
153                        self.inner_statement.as_ptr(),
154                        bind_index,
155                        ptr as *const libc::c_char,
156                        len,
157                        ffi::SQLITE_STATIC(),
158                    )
159                }
160            }
161            (SqliteType::Float, InternalSqliteBindValue::F64(value))
162            | (SqliteType::Double, InternalSqliteBindValue::F64(value)) => unsafe {
163                ffi::sqlite3_bind_double(
164                    self.inner_statement.as_ptr(),
165                    bind_index,
166                    value as libc::c_double,
167                )
168            },
169            (SqliteType::SmallInt, InternalSqliteBindValue::I32(value))
170            | (SqliteType::Integer, InternalSqliteBindValue::I32(value)) => unsafe {
171                ffi::sqlite3_bind_int(self.inner_statement.as_ptr(), bind_index, value)
172            },
173            (SqliteType::Long, InternalSqliteBindValue::I64(value)) => unsafe {
174                ffi::sqlite3_bind_int64(self.inner_statement.as_ptr(), bind_index, value)
175            },
176            (t, b) => {
177                return Err(Error::SerializationError(
178                    ::alloc::__export::must_use({
        ::alloc::fmt::format(format_args!("Type mismatch: Expected {0:?}, got {1}",
                t, b))
    })alloc::format!("Type mismatch: Expected {t:?}, got {b}").into(),
179                ));
180            }
181        };
182        match ensure_sqlite_ok(result, self.raw_connection()) {
183            Ok(()) => Ok(ret_ptr),
184            Err(e) => {
185                if let Some(ptr) = ret_ptr {
186                    // This is a `NonNul` ptr so it cannot be null
187                    // It points to a slice internally as we did not apply
188                    // any cast above.
189                    core::mem::drop(unsafe { Box::from_raw(ptr.as_ptr()) })
190                }
191                Err(e)
192            }
193        }
194    }
195
196    fn reset(&mut self) {
197        unsafe { ffi::sqlite3_reset(self.inner_statement.as_ptr()) };
198    }
199
200    fn raw_connection(&self) -> *mut ffi::sqlite3 {
201        unsafe { ffi::sqlite3_db_handle(self.inner_statement.as_ptr()) }
202    }
203}
204
205pub(super) fn ensure_sqlite_ok(
206    code: libc::c_int,
207    raw_connection: *mut ffi::sqlite3,
208) -> QueryResult<()> {
209    if code == ffi::SQLITE_OK {
210        Ok(())
211    } else {
212        Err(last_error(raw_connection))
213    }
214}
215
216fn last_error(raw_connection: *mut ffi::sqlite3) -> Error {
217    let error_message = last_error_message(raw_connection);
218    let error_code = last_error_code(raw_connection);
219    let error_kind = match error_code {
220        ffi::SQLITE_CONSTRAINT_UNIQUE | ffi::SQLITE_CONSTRAINT_PRIMARYKEY => {
221            DatabaseErrorKind::UniqueViolation
222        }
223        ffi::SQLITE_CONSTRAINT_FOREIGNKEY => DatabaseErrorKind::ForeignKeyViolation,
224        // SQLITE_CONSTRAINT_TRIGGER is returned for ON DELETE RESTRICT violations,
225        // which are actually foreign key violations. We check the error message
226        // to distinguish from user-defined trigger failures.
227        ffi::SQLITE_CONSTRAINT_TRIGGER
228            if error_message.contains("FOREIGN KEY constraint failed") =>
229        {
230            DatabaseErrorKind::ForeignKeyViolation
231        }
232        ffi::SQLITE_CONSTRAINT_NOTNULL => DatabaseErrorKind::NotNullViolation,
233        ffi::SQLITE_CONSTRAINT_CHECK => DatabaseErrorKind::CheckViolation,
234        _ => DatabaseErrorKind::Unknown,
235    };
236    let error_information = Box::new(error_message);
237    DatabaseError(error_kind, error_information)
238}
239
240fn last_error_message(conn: *mut ffi::sqlite3) -> String {
241    let c_str = unsafe { CStr::from_ptr(ffi::sqlite3_errmsg(conn)) };
242    c_str.to_string_lossy().into_owned()
243}
244
/// Returns the extended result code of the most recent SQLite API call on `conn`.
fn last_error_code(conn: *mut ffi::sqlite3) -> libc::c_int {
    // NOTE(review): assumes `conn` is a valid connection handle; callers in this
    // file pass either a `RawConnection` handle or `sqlite3_db_handle` of a live statement.
    unsafe { ffi::sqlite3_extended_errcode(conn) }
}
248
249impl Drop for Statement {
250    fn drop(&mut self) {
251        use crate::util::std_compat::panicking;
252
253        let raw_connection = self.raw_connection();
254        let finalize_result = unsafe { ffi::sqlite3_finalize(self.inner_statement.as_ptr()) };
255        if let Err(e) = ensure_sqlite_ok(finalize_result, raw_connection) {
256            if panicking() {
257                #[cfg(feature = "std")]
258                {
    ::std::io::_eprint(format_args!("Error finalizing SQLite prepared statement: {0:?}\n",
            e));
};eprintln!("Error finalizing SQLite prepared statement: {e:?}");
259            } else {
260                {
    ::core::panicking::panic_fmt(format_args!("Error finalizing SQLite prepared statement: {0:?}",
            e));
};panic!("Error finalizing SQLite prepared statement: {e:?}");
261            }
262        }
263    }
264}
265
// A warning for future editors:
// Changing this code to something "simpler" may
// introduce undefined behaviour. Make sure you read
// the following discussions for details about
// the current version:
//
// * https://github.com/weiznich/diesel/pull/7
// * https://users.rust-lang.org/t/code-review-for-unsafe-code-in-diesel/66798/
// * https://github.com/rust-lang/unsafe-code-guidelines/issues/194
/// A prepared statement with all bind values of one query applied.
///
/// On drop it resets the statement, rebinds every tracked parameter to `NULL`,
/// frees the owned bind buffers and finally frees the query. It also emits a
/// successful `FinishQuery` event unless an error was already reported.
struct BoundStatement<'stmt, 'query> {
    // The (possibly cached) statement the binds were applied to.
    statement: MaybeCached<'stmt, Statement>,
    // we need to store the query here to ensure no one
    // drops it until the end of the statement.
    // We use a boxed queryfragment here just to erase the
    // generic type, we use NonNull to communicate
    // that this is a shared buffer
    query: Option<NonNull<dyn QueryFragment<Sqlite> + 'query>>,
    // we need to store any owned bind values separately, as they are not
    // contained in the query itself. We use NonNull to
    // communicate that this is a shared buffer.
    // Each entry is a 1-based bind index; `Some` marks an owned buffer that must
    // be freed, while `None` marks a borrowed bind that only needs unbinding.
    binds_to_free: Vec<(i32, Option<NonNull<[u8]>>)>,
    instrumentation: &'stmt mut dyn Instrumentation,
    // Set by `finish_query_with_error` so that `Drop` does not report a second,
    // successful `FinishQuery` event.
    has_error: bool,
}
290
impl<'stmt, 'query> BoundStatement<'stmt, 'query> {
    /// Collects the bind values of `query` and applies them to `statement`.
    ///
    /// The query is moved onto the heap first, so borrowed bind values (which
    /// point into the query) stay at a stable address while the returned
    /// `BoundStatement` keeps the query alive.
    ///
    /// # Errors
    ///
    /// Returns any error from collecting the binds or from binding them. If
    /// binding fails, `ret` is dropped before the boxed `query` (reverse
    /// declaration order). Its `Drop` impl therefore unbinds the already bound
    /// parameters while their buffers are still alive.
    fn bind<T>(
        statement: MaybeCached<'stmt, Statement>,
        query: T,
        instrumentation: &'stmt mut dyn Instrumentation,
    ) -> QueryResult<BoundStatement<'stmt, 'query>>
    where
        T: QueryFragment<Sqlite> + QueryId + 'query,
    {
        // Don't use a trait object here to prevent using a virtual function call
        // For sqlite this can introduce a measurable overhead
        // Query is boxed here to make sure it won't move in memory anymore, so any bind
        // it could output would stay valid.
        let query = Box::new(query);

        let mut bind_collector = SqliteBindCollector::new();
        query.collect_binds(&mut bind_collector, &mut (), &Sqlite)?;
        let SqliteBindCollector { binds } = bind_collector;

        let mut ret = BoundStatement {
            statement,
            query: None,
            binds_to_free: Vec::new(),
            instrumentation,
            has_error: false,
        };

        ret.bind_buffers(binds)?;

        // Only now erase the query type and transfer ownership to `ret`; the
        // `Drop` impl rebuilds the box with `Box::from_raw`.
        let query = query as Box<dyn QueryFragment<Sqlite> + 'query>;
        ret.query = NonNull::new(Box::into_raw(query));

        Ok(ret)
    }

    // This is a separated function so that
    // not the whole constructor is generic over the query type T.
    // This hopefully prevents binary bloat.
    /// Binds each value in `binds` to its 1-based position in the statement.
    ///
    /// Every bind that refers to a buffer is recorded in `binds_to_free`, so that
    /// `Drop` can unbind it (and free it, if owned).
    fn bind_buffers(
        &mut self,
        binds: Vec<(InternalSqliteBindValue<'_>, SqliteType)>,
    ) -> QueryResult<()> {
        // It is useful to preallocate `binds_to_free` because it
        // - Guarantees that pushing inside it cannot panic, which guarantees the `Drop`
        //   impl of `BoundStatement` will always re-`bind` as needed
        // - Avoids reallocations
        self.binds_to_free.reserve(
            binds
                .iter()
                .filter(|&(b, _)| {
                    matches!(
                        b,
                        InternalSqliteBindValue::BorrowedBinary(_)
                            | InternalSqliteBindValue::BorrowedString(_)
                            | InternalSqliteBindValue::String(_)
                            | InternalSqliteBindValue::Binary(_)
                    )
                })
                .count(),
        );
        for (bind_idx, (bind, tpe)) in (1..).zip(binds) {
            let is_borrowed_bind = matches!(
                bind,
                InternalSqliteBindValue::BorrowedString(_)
                    | InternalSqliteBindValue::BorrowedBinary(_)
            );

            // It's safe to call bind here as:
            // * The type and value matches
            // * We ensure that corresponding buffers lives long enough below
            // * The statement is not used yet by `step` or anything else
            let res = unsafe { self.statement.bind(tpe, bind, bind_idx) }?;

            // it's important to push these only after
            // the call to bind succeeded, otherwise we might attempt to
            // call bind to an non-existing bind position in
            // the destructor
            if let Some(ptr) = res {
                // Store the id + pointer for a owned bind
                // as we must unbind and free them on drop
                self.binds_to_free.push((bind_idx, Some(ptr)));
            } else if is_borrowed_bind {
                // Store the id's of borrowed binds to unbind them on drop
                self.binds_to_free.push((bind_idx, None));
            }
        }
        Ok(())
    }

    /// Reports `e` to the instrumentation as the outcome of this query and then
    /// drops `self`.
    ///
    /// `has_error` is set first, so the `Drop` impl does not emit a second,
    /// successful `FinishQuery` event.
    fn finish_query_with_error(mut self, e: &Error) {
        self.has_error = true;
        if let Some(q) = self.query {
            // it's safe to get a reference from this ptr as it's guaranteed to not be null
            let q = unsafe { q.as_ref() };
            self.instrumentation.on_connection_event(
                crate::connection::InstrumentationEvent::FinishQuery {
                    query: &crate::debug_query(&q),
                    error: Some(e),
                },
            );
        }
    }
}
394
impl Drop for BoundStatement<'_, '_> {
    /// Releases the binds and the query in an order that never leaves SQLite
    /// holding a pointer to freed memory:
    /// 1. reset the statement,
    /// 2. rebind every tracked parameter to `NULL`, then free its owned buffer (if any),
    /// 3. emit `FinishQuery` (unless an error was already reported) and free the query.
    fn drop(&mut self) {
        // First reset the statement, otherwise the bind calls
        // below will fail
        self.statement.reset();

        for (idx, buffer) in core::mem::take(&mut self.binds_to_free) {
            unsafe {
                // It's always safe to bind null values, as there is no buffer that needs to outlive something
                self.statement
                    .bind(SqliteType::Text, InternalSqliteBindValue::Null, idx)
                    .expect(
                        "Binding a null value should never fail. \
                             If you ever see this error message please open \
                             an issue at diesels issue tracker containing \
                             code how to trigger this message.",
                    );
            }

            if let Some(buffer) = buffer {
                unsafe {
                    // Constructing the `Box` here is safe as we
                    // got the pointer from a box + it is guaranteed to be not null.
                    // The parameter was rebound to NULL above, so sqlite no longer
                    // references this buffer.
                    core::mem::drop(Box::from_raw(buffer.as_ptr()));
                }
            }
        }

        if let Some(query) = self.query {
            let query = unsafe {
                // Constructing the `Box` here is safe as we
                // got the pointer from a box + it is guaranteed to be not null.
                Box::from_raw(query.as_ptr())
            };
            // `finish_query_with_error` already reported the outcome when
            // `has_error` is set.
            if !self.has_error {
                self.instrumentation.on_connection_event(
                    crate::connection::InstrumentationEvent::FinishQuery {
                        query: &crate::debug_query(&query),
                        error: None,
                    },
                );
            }
            core::mem::drop(query);
            self.query = None;
        }
    }
}
442
/// A bound statement that is ready to be stepped, together with a lazily
/// populated cache of its column names.
#[allow(missing_debug_implementations)]
pub struct StatementUse<'stmt, 'query> {
    statement: BoundStatement<'stmt, 'query>,
    // Raw pointers to column names owned by sqlite (filled by `field_name`).
    // Cleared by `step(true)`, because those pointers may be invalidated when
    // sqlite re-prepares the statement on the first step of a run.
    column_names: OnceCell<Vec<*const str>>,
}
448
449impl<'stmt, 'query> StatementUse<'stmt, 'query> {
450    pub(super) fn bind<T>(
451        statement: MaybeCached<'stmt, Statement>,
452        query: T,
453        instrumentation: &'stmt mut dyn Instrumentation,
454    ) -> QueryResult<StatementUse<'stmt, 'query>>
455    where
456        T: QueryFragment<Sqlite> + QueryId + 'query,
457    {
458        Ok(Self {
459            statement: BoundStatement::bind(statement, query, instrumentation)?,
460            column_names: OnceCell::new(),
461        })
462    }
463
464    pub(super) fn run(mut self) -> QueryResult<()> {
465        let r = unsafe {
466            // This is safe as we pass `first_step = true`
467            // and we consume the statement so nobody could
468            // access the columns later on anyway.
469            self.step(true).map(|_| ())
470        };
471        if let Err(ref e) = r {
472            self.statement.finish_query_with_error(e);
473        }
474        r
475    }
476
477    // This function is marked as unsafe incorrectly passing `false` to `first_step`
478    // for a first call to this function could cause access to freed memory via
479    // the cached column names.
480    //
481    // It's always safe to call this function with `first_step = true` as this removes
482    // the cached column names
483    pub(super) unsafe fn step(&mut self, first_step: bool) -> QueryResult<bool> {
484        let step_result =
485            unsafe { ffi::sqlite3_step(self.statement.statement.inner_statement.as_ptr()) };
486        let res = match step_result {
487            ffi::SQLITE_DONE => Ok(false),
488            ffi::SQLITE_ROW => Ok(true),
489            _ => Err(last_error(self.statement.statement.raw_connection())),
490        };
491        if first_step {
492            self.column_names = OnceCell::new();
493        }
494        res
495    }
496
497    // The returned string pointer is valid until either the prepared statement is
498    // destroyed by sqlite3_finalize() or until the statement is automatically
499    // reprepared by the first call to sqlite3_step() for a particular run or
500    // until the next call to sqlite3_column_name() or sqlite3_column_name16()
501    // on the same column.
502    //
503    // https://sqlite.org/c3ref/column_name.html
504    //
505    // Note: This function is marked as unsafe, as calling it can invalidate
506    // other existing column name pointers on the same column. To prevent that,
507    // it should maximally be called once per column at all.
508    unsafe fn column_name(&self, idx: i32) -> *const str {
509        let name = {
510            let column_name = unsafe {
511                ffi::sqlite3_column_name(self.statement.statement.inner_statement.as_ptr(), idx)
512            };
513            if !!column_name.is_null() {
    {
        ::core::panicking::panic_fmt(format_args!("The Sqlite documentation states that it only returns a null pointer here if we are in a OOM condition."));
    }
};assert!(
514                !column_name.is_null(),
515                "The Sqlite documentation states that it only returns a \
516                 null pointer here if we are in a OOM condition."
517            );
518            unsafe { CStr::from_ptr(column_name) }
519        };
520        name.to_str().expect(
521            "The Sqlite documentation states that this is UTF8. \
522             If you see this error message something has gone \
523             horribly wrong. Please open an issue at the \
524             diesel repository.",
525        ) as *const str
526    }
527
528    pub(super) fn column_count(&self) -> i32 {
529        unsafe { ffi::sqlite3_column_count(self.statement.statement.inner_statement.as_ptr()) }
530    }
531
532    pub(super) fn index_for_column_name(&mut self, field_name: &str) -> Option<usize> {
533        (0..self.column_count())
534            .find(|idx| self.field_name(*idx) == Some(field_name))
535            .map(|v| {
536                v.try_into()
537                    .expect("Diesel expects to run at least on a 32 bit platform")
538            })
539    }
540
541    pub(super) fn field_name(&self, idx: i32) -> Option<&str> {
542        let column_names = self.column_names.get_or_init(|| {
543            let count = self.column_count();
544            (0..count)
545                .map(|idx| unsafe {
546                    // By initializing the whole vec at once we ensure that
547                    // we really call this only once.
548                    self.column_name(idx)
549                })
550                .collect()
551        });
552
553        column_names
554            .get(usize::try_from(idx).expect("Diesel expects to run at least on a 32 bit platform"))
555            .and_then(|c| unsafe { c.as_ref() })
556    }
557
558    pub(super) fn copy_value(&self, idx: i32) -> Option<OwnedSqliteValue> {
559        OwnedSqliteValue::copy_from_ptr(self.column_value(idx)?)
560    }
561
562    pub(super) fn column_value(&self, idx: i32) -> Option<NonNull<ffi::sqlite3_value>> {
563        let ptr = unsafe {
564            ffi::sqlite3_column_value(self.statement.statement.inner_statement.as_ptr(), idx)
565        };
566        NonNull::new(ptr)
567    }
568}
569
#[cfg(test)]
mod tests {
    use crate::prelude::*;
    use crate::sql_types::Text;

    // Regression test for https://github.com/diesel-rs/diesel/issues/3558:
    // binding more values than the query has placeholders must return an error
    // and must not panic when the bound statement is dropped.
    #[diesel_test_helper::test]
    fn check_out_of_bounds_bind_does_not_panic_on_drop() {
        let mut conn = SqliteConnection::establish(":memory:").unwrap();

        // `'?'` is a string literal, not a placeholder, so there is no slot to bind.
        let e = crate::sql_query("SELECT '?'")
            .bind::<Text, _>("foo")
            .execute(&mut conn);

        assert!(e.is_err());
        match e.unwrap_err() {
            crate::result::Error::DatabaseError(crate::result::DatabaseErrorKind::Unknown, m) => {
                assert_eq!(m.message(), "column index out of range");
            }
            _ => panic!("Wrong error returned"),
        }
    }
}
594}