Skip to main content

diesel/sqlite/connection/
hooks.rs

1use super::SqliteConnection;
2use core::num::NonZeroU32;
3
4pub(super) use super::{CommitDecision, ProgressDecision};
5
6impl SqliteConnection {
7    /// Registers a callback invoked when a transaction is about to be
8    /// committed.
9    ///
10    /// The callback returns a [`CommitDecision`]: `Proceed` lets the commit
11    /// complete, `Rollback` converts it into a rollback.
12    ///
13    /// Only one commit hook can be active at a time per connection.
14    /// Registering a new one replaces the previous.
15    ///
16    /// The callback runs synchronously as part of the committing
17    /// `sqlite3_step()` call, on the thread performing the commit, so it is
18    /// never invoked concurrently. Per SQLite, the callback must not use the
19    /// connection that triggered it (running any SQL, including a `SELECT`,
20    /// counts as use) and is not reentrant. A panic in the callback aborts the
21    /// process.
22    ///
23    /// See: [`sqlite3_commit_hook`](https://www.sqlite.org/c3ref/commit_hook.html)
24    ///
25    /// # Example
26    ///
27    /// ```rust
28    /// use diesel::prelude::*;
29    /// use diesel::sqlite::{SqliteConnection, CommitDecision};
30    /// use std::sync::{Arc, Mutex};
31    ///
32    /// diesel::table! {
33    ///     users (id) {
34    ///         id -> Integer,
35    ///         name -> Text,
36    ///     }
37    /// }
38    ///
39    /// # use diesel::connection::SimpleConnection;
40    /// # let conn = &mut SqliteConnection::establish(":memory:").unwrap();
41    /// # conn.batch_execute("CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT NOT NULL)").unwrap();
42    /// let commits = Arc::new(Mutex::new(0u32));
43    /// let commits2 = commits.clone();
44    ///
45    /// conn.on_commit(move || {
46    ///     *commits2.lock().unwrap() += 1;
47    ///     CommitDecision::Proceed
48    /// });
49    ///
50    /// conn.immediate_transaction(|conn| {
51    ///     diesel::insert_into(users::table)
52    ///         .values(users::name.eq("Alice"))
53    ///         .execute(conn)?;
54    ///     Ok::<_, diesel::result::Error>(())
55    /// }).unwrap();
56    ///
57    /// assert_eq!(*commits.lock().unwrap(), 1);
58    /// ```
59    pub fn on_commit<F>(&mut self, hook: F)
60    where
61        F: FnMut() -> CommitDecision + Send + 'static,
62    {
63        self.raw_connection.set_commit_hook(hook);
64    }
65
66    /// Removes the commit hook. Subsequent commits will not invoke any
67    /// callback.
68    ///
69    /// See [`on_commit`](Self::on_commit) for usage example.
70    pub fn remove_commit_hook(&mut self) {
71        self.raw_connection.remove_commit_hook();
72    }
73
74    /// Registers a callback invoked after a transaction is rolled back.
75    ///
76    /// This is **not** invoked for the implicit rollback that occurs when
77    /// the connection is closed. It **is** invoked when a commit hook forces
78    /// a rollback by returning [`CommitDecision::Rollback`].
79    ///
80    /// Only one rollback hook can be active at a time per connection.
81    /// Registering a new one replaces the previous.
82    ///
83    /// The callback must not use the database connection. It is invoked
84    /// synchronously on the thread driving the connection, so it is never
85    /// called concurrently, and like the commit hook it is not reentrant.
86    /// Panics in the callback abort the process.
87    ///
88    /// See: [`sqlite3_rollback_hook`](https://www.sqlite.org/c3ref/commit_hook.html)
89    ///
90    /// # Example
91    ///
92    /// ```rust
93    /// use diesel::prelude::*;
94    /// use diesel::sqlite::SqliteConnection;
95    /// use std::sync::Arc;
96    /// use std::sync::atomic::{AtomicU32, Ordering};
97    ///
98    /// # use diesel::connection::SimpleConnection;
99    /// # let conn = &mut SqliteConnection::establish(":memory:").unwrap();
100    /// # conn.batch_execute("CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT NOT NULL)").unwrap();
101    /// let rollbacks = Arc::new(AtomicU32::new(0));
102    /// let rb2 = rollbacks.clone();
103    ///
104    /// conn.on_rollback(move || {
105    ///     rb2.fetch_add(1, Ordering::Relaxed);
106    /// });
107    ///
108    /// // Force a rollback by returning an error.
109    /// let _ = conn.immediate_transaction(|_conn| {
110    ///     Err::<(), _>(diesel::result::Error::RollbackTransaction)
111    /// });
112    ///
113    /// assert_eq!(rollbacks.load(Ordering::Relaxed), 1);
114    /// ```
115    pub fn on_rollback<F>(&mut self, hook: F)
116    where
117        F: FnMut() + Send + 'static,
118    {
119        self.raw_connection.set_rollback_hook(hook);
120    }
121
122    /// Removes the rollback hook. Subsequent rollbacks will not invoke any
123    /// callback.
124    ///
125    /// See [`on_rollback`](Self::on_rollback) for usage example.
126    pub fn remove_rollback_hook(&mut self) {
127        self.raw_connection.remove_rollback_hook();
128    }
129
130    /// Registers a progress handler that can interrupt long-running queries.
131    ///
132    /// The callback is invoked periodically while a query runs. `n` is the
133    /// approximate number of virtual-machine instructions between callbacks. It
134    /// is a [`NonZeroU32`] so the handler cannot be disabled implicitly by
135    /// passing zero. Use [`remove_progress_handler`](Self::remove_progress_handler)
136    /// to disable it. Since SQLite 3.41.0 the callback may also fire during
137    /// statement preparation.
138    ///
139    /// The callback returns a [`ProgressDecision`]: `Continue` lets the query
140    /// keep executing, `Interrupt` aborts it (causes `SQLITE_INTERRUPT`).
141    ///
142    /// Only one progress handler can be active at a time per connection.
143    /// Registering a new one replaces the previous.
144    ///
145    /// The callback must not use the database connection. It is invoked
146    /// synchronously on the thread driving the connection, so it is never
147    /// called concurrently. Panics in the callback abort the process.
148    ///
149    /// See: [`sqlite3_progress_handler`](https://www.sqlite.org/c3ref/progress_handler.html)
150    ///
151    /// # Example
152    ///
153    /// ```rust
154    /// use diesel::prelude::*;
155    /// use diesel::sqlite::{SqliteConnection, ProgressDecision};
156    /// use std::num::NonZeroU32;
157    /// use std::sync::Arc;
158    /// use std::sync::atomic::{AtomicBool, Ordering};
159    ///
160    /// # let conn = &mut SqliteConnection::establish(":memory:").unwrap();
161    /// let cancelled = Arc::new(AtomicBool::new(false));
162    /// let cancelled2 = cancelled.clone();
163    ///
164    /// conn.on_progress(NonZeroU32::new(1000).unwrap(), move || {
165    ///     if cancelled2.load(Ordering::Relaxed) {
166    ///         ProgressDecision::Interrupt
167    ///     } else {
168    ///         ProgressDecision::Continue
169    ///     }
170    /// });
171    ///
172    /// // Later: remove the handler.
173    /// conn.remove_progress_handler();
174    /// ```
175    pub fn on_progress<F>(&mut self, n: NonZeroU32, hook: F)
176    where
177        F: FnMut() -> ProgressDecision + Send + 'static,
178    {
179        self.raw_connection.set_progress_handler(n, hook);
180    }
181
182    /// Removes the progress handler. Subsequent queries will not invoke any
183    /// callback.
184    ///
185    /// See [`on_progress`](Self::on_progress) for usage example.
186    pub fn remove_progress_handler(&mut self) {
187        self.raw_connection.remove_progress_handler();
188    }
189}
190
#[cfg(test)]
mod tests {
    use super::*;
    use crate::connection::Connection;
    use crate::query_dsl::RunQueryDsl;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU32, Ordering};

    fn connection() -> SqliteConnection {
        SqliteConnection::establish(":memory:").unwrap()
    }

    #[derive(crate::QueryableByName)]
    struct CountResult {
        #[diesel(sql_type = crate::sql_types::BigInt)]
        c: i64,
    }

    #[diesel_test_helper::test]
    fn on_commit_fires_on_commit() {
        let conn = &mut connection();

        let count = Arc::new(AtomicU32::new(0));
        let c2 = count.clone();

        conn.on_commit(move || {
            c2.fetch_add(1, Ordering::Relaxed);
            CommitDecision::Proceed
        });

        conn.immediate_transaction(|conn| {
            crate::sql_query("CREATE TABLE t1 (id INTEGER PRIMARY KEY)")
                .execute(conn)
                .unwrap();
            Ok::<_, crate::result::Error>(())
        })
        .unwrap();

        assert_eq!(count.load(Ordering::Relaxed), 1);
    }

    #[diesel_test_helper::test]
    fn on_commit_returning_rollback_forces_rollback() {
        let conn = &mut connection();

        crate::sql_query("CREATE TABLE t_commit (id INTEGER PRIMARY KEY)")
            .execute(conn)
            .unwrap();

        conn.on_commit(|| CommitDecision::Rollback);

        // The transaction attempts to commit, but the hook returns
        // `CommitDecision::Rollback`, which converts the commit into a
        // rollback. SQLite reports this as an error from the COMMIT
        // statement, which diesel's transaction manager surfaces as `Err`.
        let result = conn.immediate_transaction(|conn| {
            crate::sql_query("INSERT INTO t_commit (id) VALUES (1)")
                .execute(conn)
                .unwrap();
            Ok::<_, crate::result::Error>(())
        });

        // The transaction should have been rolled back.
        assert!(result.is_err());

        // Remove the hook so subsequent queries don't fail.
        conn.remove_commit_hook();

        // Verify the row was not persisted.
        let cnt: i64 = crate::sql_query("SELECT COUNT(*) as c FROM t_commit")
            .get_result::<CountResult>(conn)
            .unwrap()
            .c;
        assert_eq!(cnt, 0);
    }

    #[diesel_test_helper::test]
    fn replacing_commit_hook_drops_old() {
        let conn = &mut connection();

        let old_count = Arc::new(AtomicU32::new(0));
        let new_count = Arc::new(AtomicU32::new(0));
        let oc = old_count.clone();
        let nc = new_count.clone();

        conn.on_commit(move || {
            oc.fetch_add(1, Ordering::Relaxed);
            CommitDecision::Proceed
        });

        // Replace with a new hook.
        conn.on_commit(move || {
            nc.fetch_add(1, Ordering::Relaxed);
            CommitDecision::Proceed
        });

        conn.immediate_transaction(|conn| {
            crate::sql_query("CREATE TABLE t_replace (id INTEGER PRIMARY KEY)")
                .execute(conn)
                .unwrap();
            Ok::<_, crate::result::Error>(())
        })
        .unwrap();

        assert_eq!(old_count.load(Ordering::Relaxed), 0);
        assert_eq!(new_count.load(Ordering::Relaxed), 1);
    }

    #[diesel_test_helper::test]
    fn remove_commit_hook_disables_callback() {
        let conn = &mut connection();

        let count = Arc::new(AtomicU32::new(0));
        let c2 = count.clone();

        conn.on_commit(move || {
            c2.fetch_add(1, Ordering::Relaxed);
            CommitDecision::Proceed
        });

        conn.remove_commit_hook();

        conn.immediate_transaction(|conn| {
            crate::sql_query("CREATE TABLE t_rem (id INTEGER PRIMARY KEY)")
                .execute(conn)
                .unwrap();
            Ok::<_, crate::result::Error>(())
        })
        .unwrap();

        assert_eq!(count.load(Ordering::Relaxed), 0);
    }

    #[diesel_test_helper::test]
    fn on_rollback_fires_on_explicit_rollback() {
        let conn = &mut connection();

        crate::sql_query("CREATE TABLE t_rb (id INTEGER PRIMARY KEY)")
            .execute(conn)
            .unwrap();

        let count = Arc::new(AtomicU32::new(0));
        let c2 = count.clone();

        conn.on_rollback(move || {
            c2.fetch_add(1, Ordering::Relaxed);
        });

        // Force a rollback by returning Err from the transaction closure.
        let _ = conn.immediate_transaction(|conn| {
            crate::sql_query("INSERT INTO t_rb (id) VALUES (1)")
                .execute(conn)
                .unwrap();
            Err::<(), _>(crate::result::Error::RollbackTransaction)
        });

        assert_eq!(count.load(Ordering::Relaxed), 1);
    }

    #[diesel_test_helper::test]
    fn on_rollback_fires_when_commit_hook_forces_rollback() {
        let conn = &mut connection();

        crate::sql_query("CREATE TABLE t_rb2 (id INTEGER PRIMARY KEY)")
            .execute(conn)
            .unwrap();

        let rb_count = Arc::new(AtomicU32::new(0));
        let rb2 = rb_count.clone();

        conn.on_commit(|| CommitDecision::Rollback);
        conn.on_rollback(move || {
            rb2.fetch_add(1, Ordering::Relaxed);
        });

        let _ = conn.immediate_transaction(|conn| {
            crate::sql_query("INSERT INTO t_rb2 (id) VALUES (1)")
                .execute(conn)
                .unwrap();
            Ok::<_, crate::result::Error>(())
        });

        // Rollback hook should have fired.
        assert_eq!(rb_count.load(Ordering::Relaxed), 1);

        conn.remove_commit_hook();
        conn.remove_rollback_hook();

        // Verify the row was not persisted.
        let cnt: i64 = crate::sql_query("SELECT COUNT(*) as c FROM t_rb2")
            .get_result::<CountResult>(conn)
            .unwrap()
            .c;
        assert_eq!(cnt, 0);
    }

    #[diesel_test_helper::test]
    fn on_rollback_does_not_fire_on_connection_close() {
        let count = Arc::new(AtomicU32::new(0));
        let c2 = count.clone();

        {
            let conn = &mut connection();

            crate::sql_query("CREATE TABLE t_close (id INTEGER PRIMARY KEY)")
                .execute(conn)
                .unwrap();

            conn.on_rollback(move || {
                c2.fetch_add(1, Ordering::Relaxed);
            });

            // Leave a write transaction open so that closing the connection
            // actually has something to roll back. Without this the test
            // would pass trivially, because there would be no rollback at all.
            crate::sql_query("BEGIN").execute(conn).unwrap();
            crate::sql_query("INSERT INTO t_close (id) VALUES (1)")
                .execute(conn)
                .unwrap();

            // conn is dropped here: SQLite rolls back the open transaction
            // as part of closing, which must not invoke the rollback hook.
        }

        assert_eq!(count.load(Ordering::Relaxed), 0);
    }

    #[diesel_test_helper::test]
    fn remove_rollback_hook_disables_callback() {
        let conn = &mut connection();

        crate::sql_query("CREATE TABLE t_rem_rb (id INTEGER PRIMARY KEY)")
            .execute(conn)
            .unwrap();

        let count = Arc::new(AtomicU32::new(0));
        let c2 = count.clone();

        conn.on_rollback(move || {
            c2.fetch_add(1, Ordering::Relaxed);
        });

        conn.remove_rollback_hook();

        let _ = conn.immediate_transaction(|conn| {
            crate::sql_query("INSERT INTO t_rem_rb (id) VALUES (1)")
                .execute(conn)
                .unwrap();
            Err::<(), _>(crate::result::Error::RollbackTransaction)
        });

        assert_eq!(count.load(Ordering::Relaxed), 0);
    }

    // A recursive CTE heavy enough that the progress handler fires while it runs.
    const HEAVY_QUERY: &str = "WITH RECURSIVE c(x) AS \
        (SELECT 1 UNION ALL SELECT x + 1 FROM c WHERE x < 100000) SELECT count(*) FROM c";

    #[diesel_test_helper::test]
    fn on_progress_interrupts_query() {
        let conn = &mut connection();

        conn.on_progress(NonZeroU32::new(1).unwrap(), || ProgressDecision::Interrupt);

        let result = crate::sql_query(HEAVY_QUERY).execute(conn);
        assert!(
            result.is_err(),
            "the query should be interrupted by the progress handler"
        );
    }

    #[diesel_test_helper::test]
    fn on_progress_continue_lets_query_complete() {
        let conn = &mut connection();

        let calls = Arc::new(AtomicU32::new(0));
        let calls2 = calls.clone();

        conn.on_progress(NonZeroU32::new(1000).unwrap(), move || {
            calls2.fetch_add(1, Ordering::Relaxed);
            ProgressDecision::Continue
        });

        let result = crate::sql_query(HEAVY_QUERY).execute(conn);
        assert!(
            result.is_ok(),
            "a handler returning Continue must not interrupt the query"
        );
        assert!(
            calls.load(Ordering::Relaxed) > 0,
            "the progress handler should have been invoked while the query ran"
        );
    }

    #[diesel_test_helper::test]
    fn remove_progress_handler_stops_interruption() {
        let conn = &mut connection();

        conn.on_progress(NonZeroU32::new(1).unwrap(), || ProgressDecision::Interrupt);
        conn.remove_progress_handler();

        // With the handler removed the same query runs to completion.
        let result = crate::sql_query(HEAVY_QUERY).execute(conn);
        assert!(
            result.is_ok(),
            "the query should complete after the handler is removed"
        );
    }
}