diesel/connection/
transaction_manager.rs

1use crate::connection::Connection;
2use crate::result::{Error, QueryResult};
3use std::borrow::Cow;
4use std::num::NonZeroU32;
5
6/// Manages the internal transaction state for a connection.
7///
8/// You will not need to interact with this trait, unless you are writing an
9/// implementation of [`Connection`].
10pub trait TransactionManager<Conn: Connection> {
11    /// Data stored as part of the connection implementation
12    /// to track the current transaction state of a connection
13    type TransactionStateData;
14
15    /// Begin a new transaction or savepoint
16    ///
17    /// If the transaction depth is greater than 0,
18    /// this should create a savepoint instead.
19    /// This function is expected to increment the transaction depth by 1.
20    fn begin_transaction(conn: &mut Conn) -> QueryResult<()>;
21
22    /// Rollback the inner-most transaction or savepoint
23    ///
24    /// If the transaction depth is greater than 1,
25    /// this should rollback to the most recent savepoint.
26    /// This function is expected to decrement the transaction depth by 1.
27    fn rollback_transaction(conn: &mut Conn) -> QueryResult<()>;
28
29    /// Commit the inner-most transaction or savepoint
30    ///
31    /// If the transaction depth is greater than 1,
32    /// this should release the most recent savepoint.
33    /// This function is expected to decrement the transaction depth by 1.
34    fn commit_transaction(conn: &mut Conn) -> QueryResult<()>;
35
36    /// Fetch the current transaction status as mutable
37    ///
38    /// Used to ensure that `begin_test_transaction` is not called when already
39    /// inside of a transaction, and that operations are not run in a `InError`
40    /// transaction manager.
41    #[diesel_derives::__diesel_public_if(
42        feature = "i-implement-a-third-party-backend-and-opt-into-breaking-changes"
43    )]
44    fn transaction_manager_status_mut(conn: &mut Conn) -> &mut TransactionManagerStatus;
45
46    /// Executes the given function inside of a database transaction
47    ///
48    /// Each implementation of this function needs to fulfill the documented
49    /// behaviour of [`Connection::transaction`]
50    fn transaction<F, R, E>(conn: &mut Conn, callback: F) -> Result<R, E>
51    where
52        F: FnOnce(&mut Conn) -> Result<R, E>,
53        E: From<Error>,
54    {
55        Self::begin_transaction(conn)?;
56        match callback(&mut *conn) {
57            Ok(value) => {
58                Self::commit_transaction(conn)?;
59                Ok(value)
60            }
61            Err(user_error) => match Self::rollback_transaction(conn) {
62                Ok(()) => Err(user_error),
63                Err(Error::BrokenTransactionManager) => {
64                    // In this case we are probably more interested by the
65                    // original error, which likely caused this
66                    Err(user_error)
67                }
68                Err(rollback_error) => Err(rollback_error.into()),
69            },
70        }
71    }
72
73    /// This methods checks if the connection manager is considered to be broken
74    /// by connection pool implementations
75    ///
76    /// A connection manager is considered to be broken by default if it either
77    /// contains an open transaction (because you don't want to have connections
78    /// with open transactions in your pool) or when the transaction manager is
79    /// in an error state.
80    #[diesel_derives::__diesel_public_if(
81        feature = "i-implement-a-third-party-backend-and-opt-into-breaking-changes"
82    )]
83    fn is_broken_transaction_manager(conn: &mut Conn) -> bool {
84        match Self::transaction_manager_status_mut(conn).transaction_state() {
85            // all transactions are closed
86            // so we don't consider this connection broken
87            Ok(ValidTransactionManagerStatus {
88                in_transaction: None,
89            }) => false,
90            // The transaction manager is in an error state
91            // Therefore we consider this connection broken
92            Err(_) => true,
93            // The transaction manager contains a open transaction
94            // we do consider this connection broken
95            // if that transaction was not opened by `begin_test_transaction`
96            Ok(ValidTransactionManagerStatus {
97                in_transaction: Some(s),
98            }) => !s.test_transaction,
99        }
100    }
101}
102
103/// An implementation of `TransactionManager` which can be used for backends
104/// which use ANSI standard syntax for savepoints such as SQLite and PostgreSQL.
105#[derive(Default, Debug)]
106pub struct AnsiTransactionManager {
107    pub(crate) status: TransactionManagerStatus,
108}
109
110/// Status of the transaction manager
111#[diesel_derives::__diesel_public_if(
112    feature = "i-implement-a-third-party-backend-and-opt-into-breaking-changes"
113)]
114#[derive(Debug)]
115pub enum TransactionManagerStatus {
116    /// Valid status, the manager can run operations
117    Valid(ValidTransactionManagerStatus),
118    /// Error status, probably following a broken connection. The manager will no longer run operations
119    InError,
120}
121
122impl Default for TransactionManagerStatus {
123    fn default() -> Self {
124        TransactionManagerStatus::Valid(ValidTransactionManagerStatus::default())
125    }
126}
127
128impl TransactionManagerStatus {
129    /// Returns the transaction depth if the transaction manager's status is valid, or returns
130    /// [`Error::BrokenTransactionManager`] if the transaction manager is in error.
131    pub fn transaction_depth(&self) -> QueryResult<Option<NonZeroU32>> {
132        match self {
133            TransactionManagerStatus::Valid(valid_status) => Ok(valid_status.transaction_depth()),
134            TransactionManagerStatus::InError => Err(Error::BrokenTransactionManager),
135        }
136    }
137
138    #[cfg(any(
139        feature = "i-implement-a-third-party-backend-and-opt-into-breaking-changes",
140        feature = "postgres",
141        feature = "mysql",
142        test
143    ))]
144    #[diesel_derives::__diesel_public_if(
145        feature = "i-implement-a-third-party-backend-and-opt-into-breaking-changes"
146    )]
147    /// If in transaction and transaction manager is not broken, registers that it's possible that
148    /// the connection can not be used anymore until top-level transaction is rolled back.
149    ///
150    /// If that is registered, savepoints rollbacks will still be attempted, but failure to do so
151    /// will not result in an error. (Some may succeed, some may not.)
152    pub(crate) fn set_requires_rollback_maybe_up_to_top_level(&mut self, to: bool) {
153        if let TransactionManagerStatus::Valid(ValidTransactionManagerStatus {
154            in_transaction:
155                Some(InTransactionStatus {
156                    requires_rollback_maybe_up_to_top_level,
157                    ..
158                }),
159        }) = self
160        {
161            *requires_rollback_maybe_up_to_top_level = to;
162        }
163    }
164
165    /// Sets the transaction manager status to InError
166    ///
167    /// Subsequent attempts to use transaction-related features will result in a
168    /// [`Error::BrokenTransactionManager`] error
169    pub fn set_in_error(&mut self) {
170        *self = TransactionManagerStatus::InError
171    }
172
173    /// Expose access to the inner transaction state
174    ///
175    /// This function returns an error if the Transaction manager is in a broken
176    /// state
177    #[diesel_derives::__diesel_public_if(
178        feature = "i-implement-a-third-party-backend-and-opt-into-breaking-changes"
179    )]
180    pub(self) fn transaction_state(&mut self) -> QueryResult<&mut ValidTransactionManagerStatus> {
181        match self {
182            TransactionManagerStatus::Valid(valid_status) => Ok(valid_status),
183            TransactionManagerStatus::InError => Err(Error::BrokenTransactionManager),
184        }
185    }
186
187    /// This function allows to flag a transaction manager
188    /// in such a way that it contains a test transaction.
189    ///
190    /// This will disable some checks in regards to open transactions
191    /// to allow `Connection::begin_test_transaction` to work with
192    /// pooled connections as well
193    #[diesel_derives::__diesel_public_if(
194        feature = "i-implement-a-third-party-backend-and-opt-into-breaking-changes"
195    )]
196    pub(crate) fn set_test_transaction_flag(&mut self) {
197        if let TransactionManagerStatus::Valid(ValidTransactionManagerStatus {
198            in_transaction: Some(s),
199        }) = self
200        {
201            s.test_transaction = true;
202        }
203    }
204}
205
206/// Valid transaction status for the manager. Can return the current transaction depth
207#[allow(missing_copy_implementations)]
208#[derive(Debug, Default)]
209#[diesel_derives::__diesel_public_if(
210    feature = "i-implement-a-third-party-backend-and-opt-into-breaking-changes",
211    public_fields(in_transaction)
212)]
213pub struct ValidTransactionManagerStatus {
214    /// Inner status, or `None` if no transaction is running
215    in_transaction: Option<InTransactionStatus>,
216}
217
218/// Various status fields to track the status of
219/// a transaction manager with a started transaction
220#[allow(missing_copy_implementations)]
221#[derive(Debug)]
222#[diesel_derives::__diesel_public_if(
223    feature = "i-implement-a-third-party-backend-and-opt-into-breaking-changes",
224    public_fields(
225        test_transaction,
226        transaction_depth,
227        requires_rollback_maybe_up_to_top_level
228    )
229)]
230pub struct InTransactionStatus {
231    /// The current depth of nested transactions
232    transaction_depth: NonZeroU32,
233    /// If that is registered, savepoints rollbacks will still be attempted, but failure to do so
234    /// will not result in an error. (Some may succeed, some may not.)
235    requires_rollback_maybe_up_to_top_level: bool,
236    /// Is this transaction manager status marked as test-transaction?
237    test_transaction: bool,
238}
239
240impl ValidTransactionManagerStatus {
241    /// Return the current transaction depth
242    ///
243    /// This value is `None` if no current transaction is running
244    /// otherwise the number of nested transactions is returned.
245    pub fn transaction_depth(&self) -> Option<NonZeroU32> {
246        self.in_transaction.as_ref().map(|it| it.transaction_depth)
247    }
248
249    /// Update the transaction depth by adding the value of the `transaction_depth_change` parameter if the `query` is
250    /// `Ok(())`
251    pub fn change_transaction_depth(
252        &mut self,
253        transaction_depth_change: TransactionDepthChange,
254    ) -> QueryResult<()> {
255        match (&mut self.in_transaction, transaction_depth_change) {
256            (Some(in_transaction), TransactionDepthChange::IncreaseDepth) => {
257                // Can be replaced with saturating_add directly on NonZeroU32 once
258                // <https://github.com/rust-lang/rust/issues/84186> is stable
259                in_transaction.transaction_depth =
260                    NonZeroU32::new(in_transaction.transaction_depth.get().saturating_add(1))
261                        .expect("nz + nz is always non-zero");
262                Ok(())
263            }
264            (Some(in_transaction), TransactionDepthChange::DecreaseDepth) => {
265                // This sets `transaction_depth` to `None` as soon as we reach zero
266                match NonZeroU32::new(in_transaction.transaction_depth.get() - 1) {
267                    Some(depth) => in_transaction.transaction_depth = depth,
268                    None => self.in_transaction = None,
269                }
270                Ok(())
271            }
272            (None, TransactionDepthChange::IncreaseDepth) => {
273                self.in_transaction = Some(InTransactionStatus {
274                    transaction_depth: NonZeroU32::new(1).expect("1 is non-zero"),
275                    requires_rollback_maybe_up_to_top_level: false,
276                    test_transaction: false,
277                });
278                Ok(())
279            }
280            (None, TransactionDepthChange::DecreaseDepth) => {
281                // We screwed up something somewhere
282                // we cannot decrease the transaction count if
283                // we are not inside a transaction
284                Err(Error::NotInTransaction)
285            }
286        }
287    }
288}
289
/// Direction in which the depth of a transaction should change
#[derive(Clone, Copy, Debug)]
pub enum TransactionDepthChange {
    /// Go one level deeper (corresponds to `BEGIN` or `SAVEPOINT`)
    IncreaseDepth,
    /// Go one level up (corresponds to `COMMIT`/`RELEASE SAVEPOINT` or `ROLLBACK`)
    DecreaseDepth,
}
298
299impl AnsiTransactionManager {
300    fn get_transaction_state<Conn>(
301        conn: &mut Conn,
302    ) -> QueryResult<&mut ValidTransactionManagerStatus>
303    where
304        Conn: Connection<TransactionManager = Self>,
305    {
306        conn.transaction_state().status.transaction_state()
307    }
308
309    /// Begin a transaction with custom SQL
310    ///
311    /// This is used by connections to implement more complex transaction APIs
312    /// to set things such as isolation levels.
313    /// Returns an error if already inside of a transaction.
314    pub fn begin_transaction_sql<Conn>(conn: &mut Conn, sql: &str) -> QueryResult<()>
315    where
316        Conn: Connection<TransactionManager = Self>,
317    {
318        let state = Self::get_transaction_state(conn)?;
319        if let Some(_depth) = state.transaction_depth() {
320            return Err(Error::AlreadyInTransaction);
321        }
322        let instrumentation_depth = NonZeroU32::new(1);
323        // Keep remainder of this method in sync with `begin_transaction()`.
324
325        conn.instrumentation().on_connection_event(
326            super::instrumentation::InstrumentationEvent::BeginTransaction {
327                depth: instrumentation_depth.expect("We know that 1 is not zero"),
328            },
329        );
330        conn.batch_execute(sql)?;
331        Self::get_transaction_state(conn)?
332            .change_transaction_depth(TransactionDepthChange::IncreaseDepth)?;
333
334        Ok(())
335    }
336}
337
338impl<Conn> TransactionManager<Conn> for AnsiTransactionManager
339where
340    Conn: Connection<TransactionManager = Self>,
341{
342    type TransactionStateData = Self;
343
344    fn begin_transaction(conn: &mut Conn) -> QueryResult<()> {
345        let transaction_state = Self::get_transaction_state(conn)?;
346        let transaction_depth = transaction_state.transaction_depth();
347        let start_transaction_sql = match transaction_depth {
348            None => Cow::from("BEGIN"),
349            Some(transaction_depth) => {
350                Cow::from(format!("SAVEPOINT diesel_savepoint_{transaction_depth}"))
351            }
352        };
353        let instrumentation_depth =
354            NonZeroU32::new(transaction_depth.map_or(0, NonZeroU32::get).wrapping_add(1));
355        let sql = &start_transaction_sql;
356        // Keep remainder of this method in sync with `begin_transaction_sql()`.
357
358        conn.instrumentation().on_connection_event(
359            super::instrumentation::InstrumentationEvent::BeginTransaction {
360                depth: instrumentation_depth.expect("Transaction depth is too large"),
361            },
362        );
363        conn.batch_execute(sql)?;
364        Self::get_transaction_state(conn)?
365            .change_transaction_depth(TransactionDepthChange::IncreaseDepth)?;
366
367        Ok(())
368    }
369
370    fn rollback_transaction(conn: &mut Conn) -> QueryResult<()> {
371        let transaction_state = Self::get_transaction_state(conn)?;
372
373        let (
374            (rollback_sql, rolling_back_top_level),
375            requires_rollback_maybe_up_to_top_level_before_execute,
376        ) = match transaction_state.in_transaction {
377            Some(ref in_transaction) => (
378                match in_transaction.transaction_depth.get() {
379                    1 => (Cow::Borrowed("ROLLBACK"), true),
380                    depth_gt1 => (
381                        Cow::Owned(format!(
382                            "ROLLBACK TO SAVEPOINT diesel_savepoint_{}",
383                            depth_gt1 - 1
384                        )),
385                        false,
386                    ),
387                },
388                in_transaction.requires_rollback_maybe_up_to_top_level,
389            ),
390            None => return Err(Error::NotInTransaction),
391        };
392        let depth = transaction_state
393            .transaction_depth()
394            .expect("We know that we are in a transaction here");
395        conn.instrumentation().on_connection_event(
396            super::instrumentation::InstrumentationEvent::RollbackTransaction { depth },
397        );
398
399        match conn.batch_execute(&rollback_sql) {
400            Ok(()) => {
401                match Self::get_transaction_state(conn)?
402                    .change_transaction_depth(TransactionDepthChange::DecreaseDepth)
403                {
404                    Ok(()) => {}
405                    Err(Error::NotInTransaction) if rolling_back_top_level => {
406                        // Transaction exit may have already been detected by connection
407                        // implementation. It's fine.
408                    }
409                    Err(e) => return Err(e),
410                }
411                Ok(())
412            }
413            Err(rollback_error) => {
414                let tm_status = Self::transaction_manager_status_mut(conn);
415                match tm_status {
416                    TransactionManagerStatus::Valid(ValidTransactionManagerStatus {
417                        in_transaction:
418                            Some(InTransactionStatus {
419                                transaction_depth,
420                                requires_rollback_maybe_up_to_top_level,
421                                ..
422                            }),
423                    }) if transaction_depth.get() > 1 => {
424                        // A savepoint failed to rollback - we may still attempt to repair
425                        // the connection by rolling back higher levels.
426
427                        // To make it easier on the user (that they don't have to really
428                        // look at actual transaction depth and can just rely on the number
429                        // of times they have called begin/commit/rollback) we still
430                        // decrement here:
431                        *transaction_depth = NonZeroU32::new(transaction_depth.get() - 1)
432                            .expect("Depth was checked to be > 1");
433                        *requires_rollback_maybe_up_to_top_level = true;
434                        if requires_rollback_maybe_up_to_top_level_before_execute {
435                            // In that case, we tolerate that savepoint releases fail
436                            // -> we should ignore errors
437                            return Ok(());
438                        }
439                    }
440                    TransactionManagerStatus::Valid(ValidTransactionManagerStatus {
441                        in_transaction: None,
442                    }) => {
443                        // we would have returned `NotInTransaction` if that was already the state
444                        // before we made our call
445                        // => Transaction manager status has been fixed by the underlying connection
446                        // so we don't need to set_in_error
447                    }
448                    _ => tm_status.set_in_error(),
449                }
450                Err(rollback_error)
451            }
452        }
453    }
454
455    /// If the transaction fails to commit due to a `SerializationFailure` or a
456    /// `ReadOnlyTransaction` a rollback will be attempted. If the rollback succeeds,
457    /// the original error will be returned, otherwise the error generated by the rollback
458    /// will be returned. In the second case the connection will be considered broken
459    /// as it contains a uncommitted unabortable open transaction.
460    fn commit_transaction(conn: &mut Conn) -> QueryResult<()> {
461        let transaction_state = Self::get_transaction_state(conn)?;
462        let transaction_depth = transaction_state.transaction_depth();
463        let (commit_sql, committing_top_level) = match transaction_depth {
464            None => return Err(Error::NotInTransaction),
465            Some(transaction_depth) if transaction_depth.get() == 1 => {
466                (Cow::Borrowed("COMMIT"), true)
467            }
468            Some(transaction_depth) => (
469                Cow::Owned(format!(
470                    "RELEASE SAVEPOINT diesel_savepoint_{}",
471                    transaction_depth.get() - 1
472                )),
473                false,
474            ),
475        };
476        let depth = transaction_state
477            .transaction_depth()
478            .expect("We know that we are in a transaction here");
479        conn.instrumentation().on_connection_event(
480            super::instrumentation::InstrumentationEvent::CommitTransaction { depth },
481        );
482        match conn.batch_execute(&commit_sql) {
483            Ok(()) => {
484                match Self::get_transaction_state(conn)?
485                    .change_transaction_depth(TransactionDepthChange::DecreaseDepth)
486                {
487                    Ok(()) => {}
488                    Err(Error::NotInTransaction) if committing_top_level => {
489                        // Transaction exit may have already been detected by connection.
490                        // It's fine
491                    }
492                    Err(e) => return Err(e),
493                }
494                Ok(())
495            }
496            Err(commit_error) => {
497                if let TransactionManagerStatus::Valid(ValidTransactionManagerStatus {
498                    in_transaction:
499                        Some(InTransactionStatus {
500                            requires_rollback_maybe_up_to_top_level: true,
501                            ..
502                        }),
503                }) = conn.transaction_state().status
504                {
505                    match Self::rollback_transaction(conn) {
506                        Ok(()) => {}
507                        Err(rollback_error) => {
508                            conn.transaction_state().status.set_in_error();
509                            return Err(Error::RollbackErrorOnCommit {
510                                rollback_error: Box::new(rollback_error),
511                                commit_error: Box::new(commit_error),
512                            });
513                        }
514                    }
515                }
516                Err(commit_error)
517            }
518        }
519    }
520
521    fn transaction_manager_status_mut(conn: &mut Conn) -> &mut TransactionManagerStatus {
522        &mut conn.transaction_state().status
523    }
524}
525
526#[cfg(test)]
527// that's a false positive for `panic!`/`assert!` on rust 2018
528#[allow(clippy::uninlined_format_args)]
529mod test {
530    // Mock connection.
531    mod mock {
532        use crate::connection::transaction_manager::AnsiTransactionManager;
533        use crate::connection::Instrumentation;
534        use crate::connection::{
535            Connection, ConnectionSealed, SimpleConnection, TransactionManager,
536        };
537        use crate::result::QueryResult;
538        use crate::test_helpers::TestConnection;
539        use std::collections::VecDeque;
540
541        pub(crate) struct MockConnection {
542            pub(crate) next_results: VecDeque<QueryResult<usize>>,
543            pub(crate) next_batch_execute_results: VecDeque<QueryResult<()>>,
544            pub(crate) top_level_requires_rollback_after_next_batch_execute: bool,
545            transaction_state: AnsiTransactionManager,
546            instrumentation: Option<Box<dyn Instrumentation>>,
547        }
548
549        impl SimpleConnection for MockConnection {
550            fn batch_execute(&mut self, _query: &str) -> QueryResult<()> {
551                let res = self
552                    .next_batch_execute_results
553                    .pop_front()
554                    .expect("No next result");
555                if self.top_level_requires_rollback_after_next_batch_execute {
556                    self.transaction_state
557                        .status
558                        .set_requires_rollback_maybe_up_to_top_level(true);
559                }
560                res
561            }
562        }
563
564        impl ConnectionSealed for MockConnection {}
565
566        impl Connection for MockConnection {
567            type Backend = <TestConnection as Connection>::Backend;
568
569            type TransactionManager = AnsiTransactionManager;
570
571            fn establish(_database_url: &str) -> crate::ConnectionResult<Self> {
572                Ok(Self {
573                    next_results: VecDeque::new(),
574                    next_batch_execute_results: VecDeque::new(),
575                    top_level_requires_rollback_after_next_batch_execute: false,
576                    transaction_state: AnsiTransactionManager::default(),
577                    instrumentation: None,
578                })
579            }
580
581            fn execute_returning_count<T>(&mut self, _source: &T) -> QueryResult<usize>
582            where
583                T: crate::query_builder::QueryFragment<Self::Backend>
584                    + crate::query_builder::QueryId,
585            {
586                self.next_results.pop_front().expect("No next result")
587            }
588
589            fn transaction_state(
590                &mut self,
591            ) -> &mut <Self::TransactionManager as TransactionManager<Self>>::TransactionStateData
592            {
593                &mut self.transaction_state
594            }
595
596            fn instrumentation(&mut self) -> &mut dyn crate::connection::Instrumentation {
597                &mut self.instrumentation
598            }
599
600            fn set_instrumentation(
601                &mut self,
602                instrumentation: impl crate::connection::Instrumentation,
603            ) {
604                self.instrumentation = Some(Box::new(instrumentation));
605            }
606        }
607    }
608
609    #[test]
610    #[cfg(feature = "postgres")]
611    fn transaction_manager_returns_an_error_when_attempting_to_commit_outside_of_a_transaction() {
612        use crate::connection::transaction_manager::AnsiTransactionManager;
613        use crate::connection::transaction_manager::TransactionManager;
614        use crate::result::Error;
615        use crate::PgConnection;
616
617        let conn = &mut crate::test_helpers::pg_connection_no_transaction();
618        assert_eq!(
619            None,
620            <AnsiTransactionManager as TransactionManager<PgConnection>>::transaction_manager_status_mut(
621                conn
622            ).transaction_depth().expect("Transaction depth")
623        );
624        let result = AnsiTransactionManager::commit_transaction(conn);
625        assert!(matches!(result, Err(Error::NotInTransaction)))
626    }
627
628    #[test]
629    #[cfg(feature = "postgres")]
630    fn transaction_manager_returns_an_error_when_attempting_to_rollback_outside_of_a_transaction() {
631        use crate::connection::transaction_manager::AnsiTransactionManager;
632        use crate::connection::transaction_manager::TransactionManager;
633        use crate::result::Error;
634        use crate::PgConnection;
635
636        let conn = &mut crate::test_helpers::pg_connection_no_transaction();
637        assert_eq!(
638            None,
639            <AnsiTransactionManager as TransactionManager<PgConnection>>::transaction_manager_status_mut(
640                conn
641            ).transaction_depth().expect("Transaction depth")
642        );
643        let result = AnsiTransactionManager::rollback_transaction(conn);
644        assert!(matches!(result, Err(Error::NotInTransaction)))
645    }
646
647    #[test]
648    fn transaction_manager_enters_broken_state_when_connection_is_broken() {
649        use crate::connection::transaction_manager::AnsiTransactionManager;
650        use crate::connection::transaction_manager::TransactionManager;
651        use crate::connection::TransactionManagerStatus;
652        use crate::result::{DatabaseErrorKind, Error};
653        use crate::*;
654
655        let mut conn = mock::MockConnection::establish("mock").expect("Mock connection");
656
657        // Set result for BEGIN
658        conn.next_batch_execute_results.push_back(Ok(()));
659        let result = conn.transaction(|conn| {
660            conn.next_results.push_back(Ok(1));
661            let query_result = sql_query("SELECT 1").execute(conn);
662            assert!(query_result.is_ok());
663            // Set result for COMMIT attempt
664            conn.next_batch_execute_results
665                .push_back(Err(Error::DatabaseError(
666                    DatabaseErrorKind::Unknown,
667                    Box::new("commit fails".to_string()),
668                )));
669            conn.top_level_requires_rollback_after_next_batch_execute = true;
670            conn.next_batch_execute_results
671                .push_back(Err(Error::DatabaseError(
672                    DatabaseErrorKind::Unknown,
673                    Box::new("rollback also fails".to_string()),
674                )));
675            Ok(())
676        });
677        assert!(
678            matches!(
679                &result,
680                Err(Error::RollbackErrorOnCommit {
681                    rollback_error,
682                    commit_error
683                }) if matches!(**commit_error, Error::DatabaseError(DatabaseErrorKind::Unknown, _))
684                    && matches!(&**rollback_error,
685                        Error::DatabaseError(DatabaseErrorKind::Unknown, msg)
686                            if msg.message() == "rollback also fails"
687                    )
688            ),
689            "Got {:?}",
690            result
691        );
692        assert!(matches!(
693            *AnsiTransactionManager::transaction_manager_status_mut(&mut conn),
694            TransactionManagerStatus::InError
695        ));
696        // Ensure the transaction manager is unusable
697        let result = conn.transaction(|_conn| Ok(()));
698        assert!(matches!(result, Err(Error::BrokenTransactionManager)))
699    }
700
701    #[test]
702    #[cfg(feature = "mysql")]
703    fn mysql_transaction_is_rolled_back_upon_syntax_error() {
704        use crate::connection::transaction_manager::AnsiTransactionManager;
705        use crate::connection::transaction_manager::TransactionManager;
706        use crate::*;
707        use std::num::NonZeroU32;
708
709        let conn = &mut crate::test_helpers::connection_no_transaction();
710        assert_eq!(
711            None,
712            <AnsiTransactionManager as TransactionManager<MysqlConnection>>::transaction_manager_status_mut(
713                conn
714            ).transaction_depth().expect("Transaction depth")
715        );
716        let _result = conn.transaction(|conn| {
717            assert_eq!(
718                NonZeroU32::new(1),
719                <AnsiTransactionManager as TransactionManager<MysqlConnection>>::transaction_manager_status_mut(
720                    conn
721            ).transaction_depth().expect("Transaction depth")
722            );
723            // In MySQL, a syntax error does not break the transaction block
724            let query_result = sql_query("SELECT_SYNTAX_ERROR 1").execute(conn);
725            assert!(query_result.is_err());
726            query_result
727        });
728        assert_eq!(
729            None,
730            <AnsiTransactionManager as TransactionManager<MysqlConnection>>::transaction_manager_status_mut(
731                conn
732            ).transaction_depth().expect("Transaction depth")
733        );
734    }
735
736    #[test]
737    #[cfg(feature = "sqlite")]
738    fn sqlite_transaction_is_rolled_back_upon_syntax_error() {
739        use crate::connection::transaction_manager::AnsiTransactionManager;
740        use crate::connection::transaction_manager::TransactionManager;
741        use crate::*;
742        use std::num::NonZeroU32;
743
744        let conn = &mut crate::test_helpers::connection();
745        assert_eq!(
746            None,
747            <AnsiTransactionManager as TransactionManager<SqliteConnection>>::transaction_manager_status_mut(
748                conn
749            ).transaction_depth().expect("Transaction depth")
750        );
751        let _result = conn.transaction(|conn| {
752            assert_eq!(
753                NonZeroU32::new(1),
754                <AnsiTransactionManager as TransactionManager<SqliteConnection>>::transaction_manager_status_mut(
755                    conn
756            ).transaction_depth().expect("Transaction depth")
757            );
758            // In Sqlite, a syntax error does not break the transaction block
759            let query_result = sql_query("SELECT_SYNTAX_ERROR 1").execute(conn);
760            assert!(query_result.is_err());
761            query_result
762        });
763        assert_eq!(
764            None,
765            <AnsiTransactionManager as TransactionManager<SqliteConnection>>::transaction_manager_status_mut(
766                conn
767            ).transaction_depth().expect("Transaction depth")
768        );
769    }
770
771    #[test]
772    #[cfg(feature = "mysql")]
773    fn nested_mysql_transaction_is_rolled_back_upon_syntax_error() {
774        use crate::connection::transaction_manager::AnsiTransactionManager;
775        use crate::connection::transaction_manager::TransactionManager;
776        use crate::*;
777        use std::num::NonZeroU32;
778
779        let conn = &mut crate::test_helpers::connection_no_transaction();
780        assert_eq!(
781            None,
782            <AnsiTransactionManager as TransactionManager<MysqlConnection>>::transaction_manager_status_mut(
783                conn
784            ).transaction_depth().expect("Transaction depth")
785        );
786        let result = conn.transaction(|conn| {
787            assert_eq!(
788                NonZeroU32::new(1),
789                <AnsiTransactionManager as TransactionManager<MysqlConnection>>::transaction_manager_status_mut(
790                    conn
791            ).transaction_depth().expect("Transaction depth")
792            );
793            let result = conn.transaction(|conn| {
794                assert_eq!(
795                    NonZeroU32::new(2),
796                    <AnsiTransactionManager as TransactionManager<MysqlConnection>>::transaction_manager_status_mut(
797                        conn
798            ).transaction_depth().expect("Transaction depth")
799                );
800                // In MySQL, a syntax error does not break the transaction block
801                sql_query("SELECT_SYNTAX_ERROR 1").execute(conn)
802            });
803            assert!(result.is_err());
804            assert_eq!(
805                NonZeroU32::new(1),
806                <AnsiTransactionManager as TransactionManager<MysqlConnection>>::transaction_manager_status_mut(
807                    conn
808            ).transaction_depth().expect("Transaction depth")
809            );
810            let query_result = sql_query("SELECT 1").execute(conn);
811            assert!(query_result.is_ok());
812            query_result
813        });
814        assert!(result.is_ok());
815        assert_eq!(
816            None,
817            <AnsiTransactionManager as TransactionManager<MysqlConnection>>::transaction_manager_status_mut(
818                conn
819            ).transaction_depth().expect("Transaction depth")
820        );
821    }
822
823    #[test]
824    #[cfg(feature = "mysql")]
825    // This function uses a collect with side effects (spawning threads)
826    // so clippy is wrong here
827    #[allow(clippy::needless_collect)]
828    fn mysql_transaction_depth_commits_tracked_properly_on_serialization_failure() {
829        use crate::result::DatabaseErrorKind::SerializationFailure;
830        use crate::result::Error::DatabaseError;
831        use crate::*;
832        use std::num::NonZeroU32;
833        use std::sync::{Arc, Barrier};
834        use std::thread;
835
836        table! {
837            #[sql_name = "mysql_transaction_depth_is_tracked_properly_on_commit_failure"]
838            serialization_example {
839                id -> Integer,
840                class -> Integer,
841            }
842        }
843
844        let conn = &mut crate::test_helpers::connection_no_transaction();
845
846        sql_query(
847            "DROP TABLE IF EXISTS mysql_transaction_depth_is_tracked_properly_on_commit_failure;",
848        )
849        .execute(conn)
850        .unwrap();
851        sql_query(
852            r#"
853            CREATE TABLE mysql_transaction_depth_is_tracked_properly_on_commit_failure (
854                id INT AUTO_INCREMENT PRIMARY KEY,
855                class INTEGER NOT NULL
856            )
857        "#,
858        )
859        .execute(conn)
860        .unwrap();
861
862        insert_into(serialization_example::table)
863            .values(&vec![
864                serialization_example::class.eq(1),
865                serialization_example::class.eq(2),
866            ])
867            .execute(conn)
868            .unwrap();
869
870        let before_barrier = Arc::new(Barrier::new(2));
871        let after_barrier = Arc::new(Barrier::new(2));
872
873        let threads = (1..3)
874            .map(|i| {
875                let before_barrier = before_barrier.clone();
876                let after_barrier = after_barrier.clone();
877                thread::spawn(move || {
878                    use crate::connection::transaction_manager::AnsiTransactionManager;
879                    use crate::connection::transaction_manager::TransactionManager;
880                    let conn = &mut crate::test_helpers::connection_no_transaction();
881                    assert_eq!(None, <AnsiTransactionManager as TransactionManager<MysqlConnection>>::transaction_manager_status_mut(conn).transaction_depth().expect("Transaction depth"));
882                    crate::sql_query("SET TRANSACTION ISOLATION LEVEL SERIALIZABLE").execute(conn)?;
883
884                    let result =
885                    conn.transaction(|conn| {
886                        assert_eq!(NonZeroU32::new(1), <AnsiTransactionManager as TransactionManager<MysqlConnection>>::transaction_manager_status_mut(conn).transaction_depth().expect("Transaction depth"));
887                        let _ = serialization_example::table
888                            .filter(serialization_example::class.eq(i))
889                            .count()
890                            .execute(conn)?;
891
892                        let other_i = if i == 1 { 2 } else { 1 };
893                        let q = insert_into(serialization_example::table)
894                            .values(serialization_example::class.eq(other_i));
895                        before_barrier.wait();
896
897                        let r = q.execute(conn);
898                        after_barrier.wait();
899                        r
900                    });
901
902                    assert_eq!(None, <AnsiTransactionManager as TransactionManager<MysqlConnection>>::transaction_manager_status_mut(conn).transaction_depth().expect("Transaction depth"));
903
904                    let second_trans_result = conn.transaction(|conn| crate::sql_query("SELECT 1").execute(conn));
905                    assert!(second_trans_result.is_ok(), "Expected the thread connections to have been rolled back or committed, but second transaction exited with {:?}", second_trans_result);
906                    result
907                })
908            })
909            .collect::<Vec<_>>();
910        let second_trans_result =
911            conn.transaction(|conn| crate::sql_query("SELECT 1").execute(conn));
912        assert!(second_trans_result.is_ok(), "Expected the main connection to have been rolled back or committed, but second transaction exited with {:?}", second_trans_result);
913
914        let mut results = threads
915            .into_iter()
916            .map(|t| t.join().unwrap())
917            .collect::<Vec<_>>();
918
919        results.sort_by_key(|r| r.is_err());
920        assert!(results[0].is_ok(), "Got {:?} instead", results);
921        // Note that contrary to Postgres, this is not a commit failure
922        assert!(
923            matches!(&results[1], Err(DatabaseError(SerializationFailure, _))),
924            "Got {:?} instead",
925            results
926        );
927    }
928
929    #[test]
930    #[cfg(feature = "mysql")]
931    // This function uses a collect with side effects (spawning threads)
932    // so clippy is wrong here
933    #[allow(clippy::needless_collect)]
934    fn mysql_nested_transaction_depth_commits_tracked_properly_on_serialization_failure() {
935        use crate::result::DatabaseErrorKind::SerializationFailure;
936        use crate::result::Error::DatabaseError;
937        use crate::*;
938        use std::num::NonZeroU32;
939        use std::sync::{Arc, Barrier};
940        use std::thread;
941
942        table! {
943            #[sql_name = "mysql_nested_trans_depth_is_tracked_properly_on_commit_failure"]
944            serialization_example {
945                id -> Integer,
946                class -> Integer,
947            }
948        }
949
950        let conn = &mut crate::test_helpers::connection_no_transaction();
951
952        sql_query(
953            "DROP TABLE IF EXISTS mysql_nested_trans_depth_is_tracked_properly_on_commit_failure;",
954        )
955        .execute(conn)
956        .unwrap();
957        sql_query(
958            r#"
959            CREATE TABLE mysql_nested_trans_depth_is_tracked_properly_on_commit_failure (
960                id INT AUTO_INCREMENT PRIMARY KEY,
961                class INTEGER NOT NULL
962            )
963        "#,
964        )
965        .execute(conn)
966        .unwrap();
967
968        insert_into(serialization_example::table)
969            .values(&vec![
970                serialization_example::class.eq(1),
971                serialization_example::class.eq(2),
972            ])
973            .execute(conn)
974            .unwrap();
975
976        let before_barrier = Arc::new(Barrier::new(2));
977        let after_barrier = Arc::new(Barrier::new(2));
978
979        let threads = (1..3)
980            .map(|i| {
981                let before_barrier = before_barrier.clone();
982                let after_barrier = after_barrier.clone();
983                thread::spawn(move || {
984                    use crate::connection::transaction_manager::AnsiTransactionManager;
985                    use crate::connection::transaction_manager::TransactionManager;
986                    let conn = &mut crate::test_helpers::connection_no_transaction();
987                    assert_eq!(None, <AnsiTransactionManager as TransactionManager<MysqlConnection>>::transaction_manager_status_mut(conn).transaction_depth().expect("Transaction depth"));
988                    crate::sql_query("SET TRANSACTION ISOLATION LEVEL SERIALIZABLE").execute(conn)?;
989
990                    let result =
991                    conn.transaction(|conn| {
992                        assert_eq!(NonZeroU32::new(1), <AnsiTransactionManager as TransactionManager<MysqlConnection>>::transaction_manager_status_mut(conn).transaction_depth().expect("Transaction depth"));
993                       conn.transaction(|conn| {
994                            assert_eq!(NonZeroU32::new(2), <AnsiTransactionManager as TransactionManager<MysqlConnection>>::transaction_manager_status_mut(conn).transaction_depth().expect("Transaction depth"));
995                            let _ = serialization_example::table
996                                .filter(serialization_example::class.eq(i))
997                                .count()
998                                .execute(conn)?;
999
1000                            let other_i = if i == 1 { 2 } else { 1 };
1001                            let q = insert_into(serialization_example::table)
1002                                .values(serialization_example::class.eq(other_i));
1003                            before_barrier.wait();
1004
1005                            let r = q.execute(conn);
1006                            after_barrier.wait();
1007                            r
1008                        })
1009                    });
1010
1011                    assert_eq!(None, <AnsiTransactionManager as TransactionManager<MysqlConnection>>::transaction_manager_status_mut(conn).transaction_depth().expect("Transaction depth"));
1012
1013                    let second_trans_result = conn.transaction(|conn| crate::sql_query("SELECT 1").execute(conn));
1014                    assert!(second_trans_result.is_ok(), "Expected the thread connections to have been rolled back or committed, but second transaction exited with {:?}", second_trans_result);
1015                    result
1016                })
1017            })
1018            .collect::<Vec<_>>();
1019        let second_trans_result =
1020            conn.transaction(|conn| crate::sql_query("SELECT 1").execute(conn));
1021        assert!(second_trans_result.is_ok(), "Expected the main connection to have been rolled back or committed, but second transaction exited with {:?}", second_trans_result);
1022
1023        let mut results = threads
1024            .into_iter()
1025            .map(|t| t.join().unwrap())
1026            .collect::<Vec<_>>();
1027
1028        results.sort_by_key(|r| r.is_err());
1029        assert!(results[0].is_ok(), "Got {:?} instead", results);
1030        assert!(
1031            matches!(&results[1], Err(DatabaseError(SerializationFailure, _))),
1032            "Got {:?} instead",
1033            results
1034        );
1035    }
1036
1037    #[test]
1038    #[cfg(feature = "sqlite")]
1039    fn sqlite_transaction_is_rolled_back_upon_deferred_constraint_failure() {
1040        use crate::connection::transaction_manager::AnsiTransactionManager;
1041        use crate::connection::transaction_manager::TransactionManager;
1042        use crate::result::Error;
1043        use crate::*;
1044        use std::num::NonZeroU32;
1045
1046        let conn = &mut crate::test_helpers::connection();
1047        assert_eq!(
1048            None,
1049            <AnsiTransactionManager as TransactionManager<SqliteConnection>>::transaction_manager_status_mut(
1050                conn
1051            ).transaction_depth().expect("Transaction depth")
1052        );
1053        let result: Result<_, Error> = conn.transaction(|conn| {
1054            assert_eq!(
1055                NonZeroU32::new(1),
1056                <AnsiTransactionManager as TransactionManager<SqliteConnection>>::transaction_manager_status_mut(
1057                    conn
1058            ).transaction_depth().expect("Transaction depth")
1059            );
1060            sql_query("DROP TABLE IF EXISTS deferred_commit").execute(conn)?;
1061            sql_query("CREATE TABLE deferred_commit(id INT UNIQUE INITIALLY DEFERRED)").execute(conn)?;
1062            sql_query("INSERT INTO deferred_commit VALUES(1)").execute(conn)?;
1063            let result = sql_query("INSERT INTO deferred_commit VALUES(1)").execute(conn);
1064            assert!(result.is_ok());
1065            Ok(())
1066        });
1067        assert!(result.is_err());
1068        assert_eq!(
1069            None,
1070            <AnsiTransactionManager as TransactionManager<SqliteConnection>>::transaction_manager_status_mut(
1071                conn
1072            ).transaction_depth().expect("Transaction depth")
1073        );
1074    }
1075
1076    // regression test for #3470
1077    // crates.io depends on this behaviour
1078    #[test]
1079    #[cfg(feature = "postgres")]
1080    fn some_libpq_failures_are_recoverable_by_rolling_back_the_savepoint_only() {
1081        use crate::connection::{AnsiTransactionManager, TransactionManager};
1082        use crate::prelude::*;
1083        use crate::sql_query;
1084
1085        crate::table! {
1086            rollback_test (id) {
1087                id -> Int4,
1088                value -> Int4,
1089            }
1090        }
1091
1092        let conn = &mut crate::test_helpers::pg_connection_no_transaction();
1093        assert_eq!(
1094            None,
1095            <AnsiTransactionManager as TransactionManager<PgConnection>>::transaction_manager_status_mut(
1096                conn
1097            ).transaction_depth().expect("Transaction depth")
1098        );
1099
1100        let res = conn.transaction(|conn| {
1101            sql_query(
1102                "CREATE TABLE IF NOT EXISTS rollback_test (id INT PRIMARY KEY, value INT NOT NULL)",
1103            )
1104            .execute(conn)?;
1105            conn.transaction(|conn| {
1106                sql_query("SET TRANSACTION READ ONLY").execute(conn)?;
1107                crate::update(rollback_test::table)
1108                    .set(rollback_test::value.eq(0))
1109                    .execute(conn)
1110            })
1111            .map(|_| {
1112                panic!("Should use the `or_else` branch");
1113            })
1114            .or_else(|_| sql_query("SELECT 1").execute(conn))
1115            .map(|_| ())
1116        });
1117        assert!(res.is_ok());
1118
1119        assert_eq!(
1120            None,
1121            <AnsiTransactionManager as TransactionManager<PgConnection>>::transaction_manager_status_mut(
1122                conn
1123            ).transaction_depth().expect("Transaction depth")
1124        );
1125    }
1126
1127    #[test]
1128    #[cfg(feature = "postgres")]
1129    fn other_libpq_failures_are_not_recoverable_by_rolling_back_the_savepoint_only() {
1130        use crate::connection::{AnsiTransactionManager, TransactionManager};
1131        use crate::prelude::*;
1132        use crate::sql_query;
1133        use std::num::NonZeroU32;
1134        use std::sync::{Arc, Barrier};
1135
1136        crate::table! {
1137            rollback_test2 (id) {
1138                id -> Int4,
1139                value -> Int4,
1140            }
1141        }
1142        let conn = &mut crate::test_helpers::pg_connection_no_transaction();
1143
1144        sql_query(
1145            "CREATE TABLE IF NOT EXISTS rollback_test2 (id INT PRIMARY KEY, value INT NOT NULL)",
1146        )
1147        .execute(conn)
1148        .unwrap();
1149
1150        let start_barrier = Arc::new(Barrier::new(2));
1151        let commit_barrier = Arc::new(Barrier::new(2));
1152
1153        let other_start_barrier = start_barrier.clone();
1154        let other_commit_barrier = commit_barrier.clone();
1155
1156        let t1 = std::thread::spawn(move || {
1157            let conn = &mut crate::test_helpers::pg_connection_no_transaction();
1158            assert_eq!(
1159                None,
1160                <AnsiTransactionManager as TransactionManager<PgConnection>>::transaction_manager_status_mut(
1161                    conn
1162                ).transaction_depth().expect("Transaction depth")
1163            );
1164            let r = conn.build_transaction().serializable().run::<_, crate::result::Error, _>(|conn| {
1165                assert_eq!(
1166                    NonZeroU32::new(1),
1167                    <AnsiTransactionManager as TransactionManager<PgConnection>>::transaction_manager_status_mut(
1168                        conn
1169                    ).transaction_depth().expect("Transaction depth")
1170                );
1171                rollback_test2::table.load::<(i32, i32)>(conn)?;
1172                crate::insert_into(rollback_test2::table)
1173                    .values((rollback_test2::id.eq(1), rollback_test2::value.eq(42)))
1174                    .execute(conn)?;
1175                let r = conn.transaction(|conn| {
1176                    assert_eq!(
1177                        NonZeroU32::new(2),
1178                        <AnsiTransactionManager as TransactionManager<PgConnection>>::transaction_manager_status_mut(
1179                            conn
1180                        ).transaction_depth().expect("Transaction depth")
1181                    );
1182                    start_barrier.wait();
1183                    commit_barrier.wait();
1184                    let r = rollback_test2::table.load::<(i32, i32)>(conn);
1185                    assert!(r.is_err());
1186                    Err::<(), _>(crate::result::Error::RollbackTransaction)
1187                });
1188                assert_eq!(
1189                    NonZeroU32::new(1),
1190                    <AnsiTransactionManager as TransactionManager<PgConnection>>::transaction_manager_status_mut(
1191                        conn
1192                    ).transaction_depth().expect("Transaction depth")
1193                );
1194                assert!(
1195                    matches!(r, Err(crate::result::Error::RollbackTransaction)),
1196                    "rollback failed (such errors should be ignored by transaction manager): {}",
1197                    r.unwrap_err()
1198                );
1199                let r = rollback_test2::table.load::<(i32, i32)>(conn);
1200                assert!(r.is_err());
1201                // fun fact: if hitting "commit" after receiving a serialization failure, PG
1202                // returns that the commit has succeeded, but in fact it was actually rolled back.
1203                // soo.. one should avoid doing that
1204                r
1205            });
1206            assert!(r.is_err());
1207            assert_eq!(
1208                None,
1209                <AnsiTransactionManager as TransactionManager<PgConnection>>::transaction_manager_status_mut(
1210                    conn
1211                ).transaction_depth().expect("Transaction depth")
1212            );
1213        });
1214
1215        let t2 = std::thread::spawn(move || {
1216            other_start_barrier.wait();
1217            let conn = &mut crate::test_helpers::pg_connection_no_transaction();
1218            assert_eq!(
1219                None,
1220                <AnsiTransactionManager as TransactionManager<PgConnection>>::transaction_manager_status_mut(
1221                    conn
1222                ).transaction_depth().expect("Transaction depth")
1223            );
1224            let r = conn.build_transaction().serializable().run::<_, crate::result::Error, _>(|conn| {
1225                assert_eq!(
1226                    NonZeroU32::new(1),
1227                    <AnsiTransactionManager as TransactionManager<PgConnection>>::transaction_manager_status_mut(
1228                        conn
1229                    ).transaction_depth().expect("Transaction depth")
1230                );
1231                let _ = rollback_test2::table.load::<(i32, i32)>(conn)?;
1232                crate::insert_into(rollback_test2::table)
1233                    .values((rollback_test2::id.eq(23), rollback_test2::value.eq(42)))
1234                    .execute(conn)?;
1235                Ok(())
1236            });
1237            other_commit_barrier.wait();
1238            assert!(r.is_ok(), "{:?}", r.unwrap_err());
1239            assert_eq!(
1240                None,
1241                <AnsiTransactionManager as TransactionManager<PgConnection>>::transaction_manager_status_mut(
1242                    conn
1243                ).transaction_depth().expect("Transaction depth")
1244            );
1245        });
1246        crate::sql_query("DELETE FROM rollback_test2")
1247            .execute(conn)
1248            .unwrap();
1249        t1.join().unwrap();
1250        t2.join().unwrap();
1251    }
1252}