// diesel/sqlite/connection/stmt.rs

1#![allow(unsafe_code)] // fii code
2use super::bind_collector::{InternalSqliteBindValue, SqliteBindCollector};
3use super::raw::RawConnection;
4use super::sqlite_value::OwnedSqliteValue;
5use crate::connection::statement_cache::{MaybeCached, PrepareForCache};
6use crate::connection::Instrumentation;
7use crate::query_builder::{QueryFragment, QueryId};
8use crate::result::Error::DatabaseError;
9use crate::result::*;
10use crate::sqlite::{Sqlite, SqliteType};
11use libsqlite3_sys as ffi;
12use std::cell::OnceCell;
13use std::ffi::{CStr, CString};
14use std::io::{stderr, Write};
15use std::os::raw as libc;
16use std::ptr::{self, NonNull};
17
/// Owning wrapper around a prepared SQLite statement handle.
///
/// The handle is created by `Statement::prepare` (via `sqlite3_prepare_v3`) and
/// released with `sqlite3_finalize` in this type's `Drop` impl.
pub(super) struct Statement {
    // Never null: `prepare` rejects the null pointer SQLite returns for empty statements.
    inner_statement: NonNull<ffi::sqlite3_stmt>,
}
21
impl Statement {
    /// Compiles `sql` into a prepared statement on `raw_connection` using
    /// `sqlite3_prepare_v3`.
    ///
    /// When `is_cached` is `PrepareForCache::Yes`, the `SQLITE_PREPARE_PERSISTENT`
    /// flag is passed to SQLite; otherwise no prepare flags are set.
    ///
    /// Only the first SQL statement in `sql` is compiled: the tail pointer written by
    /// SQLite (`unused_portion`) is not inspected, so any trailing SQL is ignored here.
    ///
    /// # Errors
    ///
    /// * `Error::SerializationError` if `sql.len()` does not fit into SQLite's length argument.
    /// * The error converted by `?` from `CString::new` if `sql` contains an interior NUL byte.
    /// * A `DatabaseError` built from the connection's last error if SQLite rejects the SQL.
    /// * `Error::QueryBuilderError(EmptyQuery)` if SQLite produced no statement
    ///   (empty, whitespace-only or comment-only input).
    pub(super) fn prepare(
        raw_connection: &RawConnection,
        sql: &str,
        is_cached: PrepareForCache,
    ) -> QueryResult<Self> {
        let mut stmt = ptr::null_mut();
        let mut unused_portion = ptr::null();
        let n_byte = sql
            .len()
            .try_into()
            .map_err(|e| Error::SerializationError(Box::new(e)))?;
        // the cast for `ffi::SQLITE_PREPARE_PERSISTENT` is required for old libsqlite3-sys versions
        #[allow(clippy::unnecessary_cast)]
        let prepare_result = unsafe {
            // The temporary `CString` outlives the call below, so the pointer handed to
            // SQLite stays valid while it reads (at most `n_byte` bytes of) the SQL.
            ffi::sqlite3_prepare_v3(
                raw_connection.internal_connection.as_ptr(),
                CString::new(sql)?.as_ptr(),
                n_byte,
                if matches!(is_cached, PrepareForCache::Yes) {
                    ffi::SQLITE_PREPARE_PERSISTENT as u32
                } else {
                    0
                },
                &mut stmt,
                &mut unused_portion,
            )
        };

        ensure_sqlite_ok(prepare_result, raw_connection.internal_connection.as_ptr())?;

        // sqlite3_prepare_v3 returns a null pointer for empty statements. This includes
        // empty or only whitespace strings or any other non-op query string like a comment
        let inner_statement = NonNull::new(stmt).ok_or_else(|| {
            crate::result::Error::QueryBuilderError(Box::new(crate::result::EmptyQuery))
        })?;
        Ok(Statement { inner_statement })
    }

    // Binds `value`, declared as SQL type `tpe`, to parameter `bind_index`
    // (SQLite parameter indices start at 1).
    //
    // Every buffer is passed to SQLite with `SQLITE_STATIC`, i.e. SQLite does not copy it.
    // For owned buffers (`Binary` / `String`) the allocation is leaked with
    // `Box::into_raw` and returned as `Some(ptr)`; the caller then owns it and must free
    // it with `Box::from_raw` only after the parameter was rebound or the statement
    // finalized. If SQLite rejects the bind, that buffer is freed here and the error is
    // returned instead.
    //
    // Returns `Err(SerializationError)` when `value` does not match `tpe` or when a
    // buffer length does not fit into SQLite's `c_int` length argument.
    //
    // # Safety
    //
    // The caller of this function has to ensure that:
    // * Any buffer provided as `SqliteBindValue::BorrowedBinary`, `SqliteBindValue::Binary`
    // `SqliteBindValue::String` or `SqliteBindValue::BorrowedString` is valid
    // till either a new value is bound to the same parameter or the underlying
    // prepared statement is dropped.
    unsafe fn bind(
        &mut self,
        tpe: SqliteType,
        value: InternalSqliteBindValue<'_>,
        bind_index: i32,
    ) -> QueryResult<Option<NonNull<[u8]>>> {
        let mut ret_ptr = None;
        let result = match (tpe, value) {
            (_, InternalSqliteBindValue::Null) => {
                ffi::sqlite3_bind_null(self.inner_statement.as_ptr(), bind_index)
            }
            (SqliteType::Binary, InternalSqliteBindValue::BorrowedBinary(bytes)) => {
                let n = bytes
                    .len()
                    .try_into()
                    .map_err(|e| Error::SerializationError(Box::new(e)))?;
                ffi::sqlite3_bind_blob(
                    self.inner_statement.as_ptr(),
                    bind_index,
                    bytes.as_ptr() as *const libc::c_void,
                    n,
                    ffi::SQLITE_STATIC(),
                )
            }
            (SqliteType::Binary, InternalSqliteBindValue::Binary(mut bytes)) => {
                let len = bytes
                    .len()
                    .try_into()
                    .map_err(|e| Error::SerializationError(Box::new(e)))?;
                // We need a separate pointer here to pass it to sqlite
                // as the returned pointer is a pointer to a dyn sized **slice**
                // and not the pointer to the first element of the slice
                let ptr = bytes.as_mut_ptr();
                ret_ptr = NonNull::new(Box::into_raw(bytes));
                ffi::sqlite3_bind_blob(
                    self.inner_statement.as_ptr(),
                    bind_index,
                    ptr as *const libc::c_void,
                    len,
                    ffi::SQLITE_STATIC(),
                )
            }
            (SqliteType::Text, InternalSqliteBindValue::BorrowedString(bytes)) => {
                let len = bytes
                    .len()
                    .try_into()
                    .map_err(|e| Error::SerializationError(Box::new(e)))?;
                ffi::sqlite3_bind_text(
                    self.inner_statement.as_ptr(),
                    bind_index,
                    bytes.as_ptr() as *const libc::c_char,
                    len,
                    ffi::SQLITE_STATIC(),
                )
            }
            (SqliteType::Text, InternalSqliteBindValue::String(bytes)) => {
                let mut bytes = Box::<[u8]>::from(bytes);
                let len = bytes
                    .len()
                    .try_into()
                    .map_err(|e| Error::SerializationError(Box::new(e)))?;
                // We need a separate pointer here to pass it to sqlite
                // as the returned pointer is a pointer to a dyn sized **slice**
                // and not the pointer to the first element of the slice
                let ptr = bytes.as_mut_ptr();
                ret_ptr = NonNull::new(Box::into_raw(bytes));
                ffi::sqlite3_bind_text(
                    self.inner_statement.as_ptr(),
                    bind_index,
                    ptr as *const libc::c_char,
                    len,
                    ffi::SQLITE_STATIC(),
                )
            }
            (SqliteType::Float, InternalSqliteBindValue::F64(value))
            | (SqliteType::Double, InternalSqliteBindValue::F64(value)) => {
                ffi::sqlite3_bind_double(
                    self.inner_statement.as_ptr(),
                    bind_index,
                    value as libc::c_double,
                )
            }
            (SqliteType::SmallInt, InternalSqliteBindValue::I32(value))
            | (SqliteType::Integer, InternalSqliteBindValue::I32(value)) => {
                ffi::sqlite3_bind_int(self.inner_statement.as_ptr(), bind_index, value)
            }
            (SqliteType::Long, InternalSqliteBindValue::I64(value)) => {
                ffi::sqlite3_bind_int64(self.inner_statement.as_ptr(), bind_index, value)
            }
            // Any other combination means the declared SQL type and the value disagree.
            (t, b) => {
                return Err(Error::SerializationError(
                    format!("Type mismatch: Expected {t:?}, got {b}").into(),
                ))
            }
        };
        match ensure_sqlite_ok(result, self.raw_connection()) {
            Ok(()) => Ok(ret_ptr),
            Err(e) => {
                if let Some(ptr) = ret_ptr {
                    // This is a `NonNull` ptr so it cannot be null
                    // It points to a slice internally as we did not apply
                    // any cast above.
                    std::mem::drop(Box::from_raw(ptr.as_ptr()))
                }
                Err(e)
            }
        }
    }

    // Resets the statement via `sqlite3_reset` so it can be stepped again.
    // The return code of `sqlite3_reset` is not checked. Per the SQLite docs this
    // does not clear bound parameter values.
    fn reset(&mut self) {
        unsafe { ffi::sqlite3_reset(self.inner_statement.as_ptr()) };
    }

    // Returns the database connection handle this statement belongs to
    // (`sqlite3_db_handle`); used to look up the connection's last error.
    fn raw_connection(&self) -> *mut ffi::sqlite3 {
        unsafe { ffi::sqlite3_db_handle(self.inner_statement.as_ptr()) }
    }
}
183
184pub(super) fn ensure_sqlite_ok(
185    code: libc::c_int,
186    raw_connection: *mut ffi::sqlite3,
187) -> QueryResult<()> {
188    if code == ffi::SQLITE_OK {
189        Ok(())
190    } else {
191        Err(last_error(raw_connection))
192    }
193}
194
195fn last_error(raw_connection: *mut ffi::sqlite3) -> Error {
196    let error_message = last_error_message(raw_connection);
197    let error_information = Box::new(error_message);
198    let error_kind = match last_error_code(raw_connection) {
199        ffi::SQLITE_CONSTRAINT_UNIQUE | ffi::SQLITE_CONSTRAINT_PRIMARYKEY => {
200            DatabaseErrorKind::UniqueViolation
201        }
202        ffi::SQLITE_CONSTRAINT_FOREIGNKEY => DatabaseErrorKind::ForeignKeyViolation,
203        ffi::SQLITE_CONSTRAINT_NOTNULL => DatabaseErrorKind::NotNullViolation,
204        ffi::SQLITE_CONSTRAINT_CHECK => DatabaseErrorKind::CheckViolation,
205        _ => DatabaseErrorKind::Unknown,
206    };
207    DatabaseError(error_kind, error_information)
208}
209
210fn last_error_message(conn: *mut ffi::sqlite3) -> String {
211    let c_str = unsafe { CStr::from_ptr(ffi::sqlite3_errmsg(conn)) };
212    c_str.to_string_lossy().into_owned()
213}
214
215fn last_error_code(conn: *mut ffi::sqlite3) -> libc::c_int {
216    unsafe { ffi::sqlite3_extended_errcode(conn) }
217}
218
219impl Drop for Statement {
220    fn drop(&mut self) {
221        use std::thread::panicking;
222
223        let raw_connection = self.raw_connection();
224        let finalize_result = unsafe { ffi::sqlite3_finalize(self.inner_statement.as_ptr()) };
225        if let Err(e) = ensure_sqlite_ok(finalize_result, raw_connection) {
226            if panicking() {
227                write!(
228                    stderr(),
229                    "Error finalizing SQLite prepared statement: {e:?}"
230                )
231                .expect("Error writing to `stderr`");
232            } else {
233                panic!("Error finalizing SQLite prepared statement: {:?}", e);
234            }
235        }
236    }
237}
238
// A warning for future editors:
// Changing this code to something "simpler" may
// introduce undefined behaviour. Make sure you read
// the following discussions for details about
// the current version:
//
// * https://github.com/weiznich/diesel/pull/7
// * https://users.rust-lang.org/t/code-review-for-unsafe-code-in-diesel/66798/
// * https://github.com/rust-lang/unsafe-code-guidelines/issues/194
/// A prepared statement together with the values currently bound to it.
///
/// Dropping a `BoundStatement` resets the statement, rebinds every recorded
/// parameter to NULL, frees the owned bind buffers and finally frees the boxed query.
struct BoundStatement<'stmt, 'query> {
    // The (possibly cached) prepared statement the values are bound to.
    statement: MaybeCached<'stmt, Statement>,
    // We need to store the query here to ensure nobody drops it
    // before the statement is done with it, because borrowed bind values point into it.
    // We use a boxed query fragment here just to erase the
    // generic type; we use NonNull to communicate
    // that this is a shared buffer.
    // It stays `None` until all values were bound successfully.
    query: Option<NonNull<dyn QueryFragment<Sqlite> + 'query>>,
    // We need to store any owned bind values separately, as they are not
    // contained in the query itself. We use NonNull to
    // communicate that this is a shared buffer.
    // Each entry is `(bind_index, owned_buffer)`; borrowed binds are recorded with
    // `None` so that they are still rebound to NULL on drop.
    binds_to_free: Vec<(i32, Option<NonNull<[u8]>>)>,
    // Receives the `FinishQuery` event for this query.
    instrumentation: &'stmt mut dyn Instrumentation,
    // Set once `finish_query_with_error` has reported a failed `FinishQuery` event;
    // suppresses the success event in `Drop`.
    has_error: bool,
}
263
impl<'stmt, 'query> BoundStatement<'stmt, 'query> {
    // Collects the bind values of `query`, binds them to `statement` and takes
    // ownership of `query`, so that borrowed bind values stay valid for as long as
    // the returned `BoundStatement` lives.
    //
    // On error the partially built `ret` is dropped before the boxed `query` (locals
    // drop in reverse declaration order), so everything already bound is rebound to
    // NULL before the memory it pointed into is freed.
    fn bind<T>(
        statement: MaybeCached<'stmt, Statement>,
        query: T,
        instrumentation: &'stmt mut dyn Instrumentation,
    ) -> QueryResult<BoundStatement<'stmt, 'query>>
    where
        T: QueryFragment<Sqlite> + QueryId + 'query,
    {
        // Don't use a trait object here to prevent using a virtual function call
        // For sqlite this can introduce a measurable overhead
        // Query is boxed here to make sure it won't move in memory anymore, so any bind
        // it could output would stay valid.
        let query = Box::new(query);

        let mut bind_collector = SqliteBindCollector::new();
        query.collect_binds(&mut bind_collector, &mut (), &Sqlite)?;
        let SqliteBindCollector { binds } = bind_collector;

        let mut ret = BoundStatement {
            statement,
            query: None,
            binds_to_free: Vec::new(),
            instrumentation,
            has_error: false,
        };

        ret.bind_buffers(binds)?;

        // Only now hand the (type-erased) query over to `ret`; from here on `ret`'s
        // `Drop` impl is responsible for freeing it.
        let query = query as Box<dyn QueryFragment<Sqlite> + 'query>;
        ret.query = NonNull::new(Box::into_raw(query));

        Ok(ret)
    }

    // This is a separated function so that
    // not the whole constructor is generic over the query type T.
    // This hopefully prevents binary bloat.
    //
    // Binds `binds` in order to parameters 1, 2, ... and records every parameter
    // whose value refers to a buffer (borrowed or owned) in `binds_to_free`.
    fn bind_buffers(
        &mut self,
        binds: Vec<(InternalSqliteBindValue<'_>, SqliteType)>,
    ) -> QueryResult<()> {
        // It is useful to preallocate `binds_to_free` because it
        // - Guarantees that pushing inside it cannot panic, which guarantees the `Drop`
        //   impl of `BoundStatement` will always re-`bind` as needed
        // - Avoids reallocations
        self.binds_to_free.reserve(
            binds
                .iter()
                .filter(|&(b, _)| {
                    matches!(
                        b,
                        InternalSqliteBindValue::BorrowedBinary(_)
                            | InternalSqliteBindValue::BorrowedString(_)
                            | InternalSqliteBindValue::String(_)
                            | InternalSqliteBindValue::Binary(_)
                    )
                })
                .count(),
        );
        for (bind_idx, (bind, tpe)) in (1..).zip(binds) {
            let is_borrowed_bind = matches!(
                bind,
                InternalSqliteBindValue::BorrowedString(_)
                    | InternalSqliteBindValue::BorrowedBinary(_)
            );

            // It's safe to call bind here as:
            // * The type and value matches
            // * We ensure that corresponding buffers lives long enough below
            // * The statement is not used yet by `step` or anything else
            let res = unsafe { self.statement.bind(tpe, bind, bind_idx) }?;

            // it's important to push these only after
            // the call to bind succeeded, otherwise we might attempt to
            // call bind to a non-existing bind position in
            // the destructor
            if let Some(ptr) = res {
                // Store the id + pointer for an owned bind
                // as we must unbind and free them on drop
                self.binds_to_free.push((bind_idx, Some(ptr)));
            } else if is_borrowed_bind {
                // Store the ids of borrowed binds to unbind them on drop
                self.binds_to_free.push((bind_idx, None));
            }
        }
        Ok(())
    }

    // Reports a `FinishQuery` event carrying `e` to the instrumentation (only if the
    // query has already been stored) and sets `has_error`, so that `Drop` does not
    // report a second, successful `FinishQuery` event. Consumes `self`, so the
    // regular drop logic runs when this function returns.
    fn finish_query_with_error(mut self, e: &Error) {
        self.has_error = true;
        if let Some(q) = self.query {
            // it's safe to get a reference from this ptr as it's guaranteed to not be null
            let q = unsafe { q.as_ref() };
            self.instrumentation.on_connection_event(
                crate::connection::InstrumentationEvent::FinishQuery {
                    query: &crate::debug_query(&q),
                    error: Some(e),
                },
            );
        }
    }
}
367
impl Drop for BoundStatement<'_, '_> {
    // Tears the bound statement down in a fixed order:
    // 1. reset the statement,
    // 2. rebind every recorded parameter to NULL, then free its owned buffer (if any),
    // 3. free the boxed query, reporting a successful `FinishQuery` event first
    //    unless an error event was already reported.
    fn drop(&mut self) {
        // First reset the statement, otherwise the bind calls
        // below will fail
        self.statement.reset();

        for (idx, buffer) in std::mem::take(&mut self.binds_to_free) {
            unsafe {
                // It's always safe to bind null values, as there is no buffer that needs to outlive something
                self.statement
                    .bind(SqliteType::Text, InternalSqliteBindValue::Null, idx)
                    .expect(
                        "Binding a null value should never fail. \
                             If you ever see this error message please open \
                             an issue at diesels issue tracker containing \
                             code how to trigger this message.",
                    );
            }

            // SQLite no longer references the buffer (the parameter now holds NULL),
            // so the owned allocation can be released.
            if let Some(buffer) = buffer {
                unsafe {
                    // Constructing the `Box` here is safe as we
                    // got the pointer from a box + it is guaranteed to be not null.
                    std::mem::drop(Box::from_raw(buffer.as_ptr()));
                }
            }
        }

        if let Some(query) = self.query {
            let query = unsafe {
                // Constructing the `Box` here is safe as we
                // got the pointer from a box + it is guaranteed to be not null.
                Box::from_raw(query.as_ptr())
            };
            if !self.has_error {
                self.instrumentation.on_connection_event(
                    crate::connection::InstrumentationEvent::FinishQuery {
                        query: &crate::debug_query(&query),
                        error: None,
                    },
                );
            }
            std::mem::drop(query);
            // Clear the now-dangling pointer so it can never be freed twice.
            self.query = None;
        }
    }
}
415
/// A bound statement that is being executed, plus a lazily filled cache of its
/// result column names.
#[allow(missing_debug_implementations)]
pub struct StatementUse<'stmt, 'query> {
    statement: BoundStatement<'stmt, 'query>,
    // Raw pointers to SQLite-owned column name strings (see `column_name`).
    // Cleared by `step` when `first_step` is true, since those pointers may be
    // invalidated when SQLite reprepares the statement.
    column_names: OnceCell<Vec<*const str>>,
}
421
422impl<'stmt, 'query> StatementUse<'stmt, 'query> {
423    pub(super) fn bind<T>(
424        statement: MaybeCached<'stmt, Statement>,
425        query: T,
426        instrumentation: &'stmt mut dyn Instrumentation,
427    ) -> QueryResult<StatementUse<'stmt, 'query>>
428    where
429        T: QueryFragment<Sqlite> + QueryId + 'query,
430    {
431        Ok(Self {
432            statement: BoundStatement::bind(statement, query, instrumentation)?,
433            column_names: OnceCell::new(),
434        })
435    }
436
437    pub(super) fn run(mut self) -> QueryResult<()> {
438        let r = unsafe {
439            // This is safe as we pass `first_step = true`
440            // and we consume the statement so nobody could
441            // access the columns later on anyway.
442            self.step(true).map(|_| ())
443        };
444        if let Err(ref e) = r {
445            self.statement.finish_query_with_error(e);
446        }
447        r
448    }
449
450    // This function is marked as unsafe incorrectly passing `false` to `first_step`
451    // for a first call to this function could cause access to freed memory via
452    // the cached column names.
453    //
454    // It's always safe to call this function with `first_step = true` as this removes
455    // the cached column names
456    pub(super) unsafe fn step(&mut self, first_step: bool) -> QueryResult<bool> {
457        let res = match ffi::sqlite3_step(self.statement.statement.inner_statement.as_ptr()) {
458            ffi::SQLITE_DONE => Ok(false),
459            ffi::SQLITE_ROW => Ok(true),
460            _ => Err(last_error(self.statement.statement.raw_connection())),
461        };
462        if first_step {
463            self.column_names = OnceCell::new();
464        }
465        res
466    }
467
468    // The returned string pointer is valid until either the prepared statement is
469    // destroyed by sqlite3_finalize() or until the statement is automatically
470    // reprepared by the first call to sqlite3_step() for a particular run or
471    // until the next call to sqlite3_column_name() or sqlite3_column_name16()
472    // on the same column.
473    //
474    // https://sqlite.org/c3ref/column_name.html
475    //
476    // Note: This function is marked as unsafe, as calling it can invalidate
477    // other existing column name pointers on the same column. To prevent that,
478    // it should maximally be called once per column at all.
479    unsafe fn column_name(&self, idx: i32) -> *const str {
480        let name = {
481            let column_name =
482                ffi::sqlite3_column_name(self.statement.statement.inner_statement.as_ptr(), idx);
483            assert!(
484                !column_name.is_null(),
485                "The Sqlite documentation states that it only returns a \
486                 null pointer here if we are in a OOM condition."
487            );
488            CStr::from_ptr(column_name)
489        };
490        name.to_str().expect(
491            "The Sqlite documentation states that this is UTF8. \
492             If you see this error message something has gone \
493             horribly wrong. Please open an issue at the \
494             diesel repository.",
495        ) as *const str
496    }
497
498    pub(super) fn column_count(&self) -> i32 {
499        unsafe { ffi::sqlite3_column_count(self.statement.statement.inner_statement.as_ptr()) }
500    }
501
502    pub(super) fn index_for_column_name(&mut self, field_name: &str) -> Option<usize> {
503        (0..self.column_count())
504            .find(|idx| self.field_name(*idx) == Some(field_name))
505            .map(|v| {
506                v.try_into()
507                    .expect("Diesel expects to run at least on a 32 bit platform")
508            })
509    }
510
511    pub(super) fn field_name(&self, idx: i32) -> Option<&str> {
512        let column_names = self.column_names.get_or_init(|| {
513            let count = self.column_count();
514            (0..count)
515                .map(|idx| unsafe {
516                    // By initializing the whole vec at once we ensure that
517                    // we really call this only once.
518                    self.column_name(idx)
519                })
520                .collect()
521        });
522
523        column_names
524            .get(usize::try_from(idx).expect("Diesel expects to run at least on a 32 bit platform"))
525            .and_then(|c| unsafe { c.as_ref() })
526    }
527
528    pub(super) fn copy_value(&self, idx: i32) -> Option<OwnedSqliteValue> {
529        OwnedSqliteValue::copy_from_ptr(self.column_value(idx)?)
530    }
531
532    pub(super) fn column_value(&self, idx: i32) -> Option<NonNull<ffi::sqlite3_value>> {
533        let ptr = unsafe {
534            ffi::sqlite3_column_value(self.statement.statement.inner_statement.as_ptr(), idx)
535        };
536        NonNull::new(ptr)
537    }
538}
539
#[cfg(test)]
mod tests {
    use crate::prelude::*;
    use crate::sql_types::Text;

    /// Regression test for <https://github.com/diesel-rs/diesel/issues/3558>.
    ///
    /// `'?'` is a string literal, not a placeholder, so binding a value must fail with
    /// SQLite's "column index out of range" error rather than panicking when the
    /// bound statement is dropped.
    #[test]
    fn check_out_of_bounds_bind_does_not_panic_on_drop() {
        let mut connection = SqliteConnection::establish(":memory:").unwrap();

        let e = crate::sql_query("SELECT '?'")
            .bind::<Text, _>("foo")
            .execute(&mut connection);
        assert!(e.is_err());

        match e.unwrap_err() {
            crate::result::Error::DatabaseError(
                crate::result::DatabaseErrorKind::Unknown,
                info,
            ) => assert_eq!(info.message(), "column index out of range"),
            _ => panic!("Wrong error returned"),
        }
    }
}
564}