Skip to main content

diesel/
r2d2.rs

1//! Connection pooling via r2d2.
2//!
3//! Note: This module requires enabling the `r2d2` feature
4//!
5//! # Example
6//!
7//! The below snippet is a contrived example emulating a web application,
8//! where one would first initialize the pool in the `main()` function
9//! (at the start of a long-running process). One would then pass this
10//! pool struct around as shared state, which, here, we've emulated using
11//! threads instead of routes.
12//!
13//! ```rust
14//! # include!("doctest_setup.rs");
15//! use diesel::prelude::*;
16//! use diesel::r2d2::ConnectionManager;
17//! # use diesel::r2d2::CustomizeConnection;
18//! # use diesel::r2d2::Error as R2D2Error;
19//! use diesel::r2d2::Pool;
20//! use diesel::result::Error;
21//! use std::thread;
22//!
23//! # #[derive(Copy, Clone, Debug)]
24//! # pub struct SetupUserTableCustomizer;
25//! #
26//! # impl CustomizeConnection<DbConnection, R2D2Error> for SetupUserTableCustomizer
27//! # {
28//! #     fn on_acquire(&self, conn: &mut DbConnection) -> Result<(), R2D2Error> {
29//! #         setup_database(conn);
30//! #         Ok(())
31//! #     }
32//! # }
33//!
34//! pub fn get_connection_pool() -> Pool<ConnectionManager<DbConnection>> {
35//!     let url = database_url_for_env();
36//!     let manager = ConnectionManager::<DbConnection>::new(url);
37//!     // Refer to the `r2d2` documentation for more methods to use
38//!     // when building a connection pool
39//!     Pool::builder()
40//! #         .max_size(1)
41//!         .test_on_check_out(true)
42//! #         .connection_customizer(Box::new(SetupUserTableCustomizer))
43//!         .build(manager)
44//!         .expect("Could not build connection pool")
45//! }
46//!
47//! pub fn create_user(conn: &mut DbConnection, user_name: &str) -> Result<usize, Error> {
48//!     use schema::users::dsl::*;
49//!
50//!     diesel::insert_into(users)
51//!         .values(name.eq(user_name))
52//!         .execute(conn)
53//! }
54//!
55//! fn main() {
56//!     let pool = get_connection_pool();
57//!     let mut threads = vec![];
58//!     let max_users_to_create = 1;
59//!
60//!     for i in 0..max_users_to_create {
61//!         let pool = pool.clone();
62//!         threads.push(thread::spawn({
63//!             move || {
64//!                 let conn = &mut pool.get().unwrap();
65//!                 let name = format!("Person {}", i);
66//!                 create_user(conn, &name).unwrap();
67//!             }
68//!         }))
69//!     }
70//!
71//!     for handle in threads {
72//!         handle.join().unwrap();
73//!     }
74//! }
75//! ```
76//!
77//! # A note on error handling
78//!
79//! When used inside a pool, if an individual connection becomes
80//! broken (as determined by the [R2D2Connection::is_broken] method)
81//! then, when the connection goes out of scope, `r2d2` will close it
82//! and remove it from the pool rather than returning it for reuse.
83//!
84//! `diesel` determines broken connections by whether or not the current
85//! thread is panicking or if individual `Connection` structs are
86//! broken (determined by the `is_broken()` method). Generically, these
87//! are left to individual backends to implement themselves.
88//!
89//! For SQLite, PG, and MySQL backends `is_broken()` is determined
90//! by whether or not the `TransactionManagerStatus` (as a part
91//! of the `AnsiTransactionManager` struct) is in an `InError` state
92//! or contains an open transaction when the connection goes out of scope.
93//!
94//!
95//! # Testing with connection pools
96//!
97//! When testing with connection pools, it is recommended to set the pool size to 1,
98//! and use a customizer to ensure that the transactions are never committed.
99//! The tests using a pool prepared this way can be run in parallel, because
100//! the changes are never committed to the database and are local to each test.
101//!
102//! # Example
103//!
104//! ```rust
105//! # include!("doctest_setup.rs");
106//! use diesel::prelude::*;
107//! use diesel::r2d2::ConnectionManager;
108//! use diesel::r2d2::CustomizeConnection;
109//! use diesel::r2d2::TestCustomizer;
110//! # use diesel::r2d2::Error as R2D2Error;
111//! use diesel::r2d2::Pool;
112//! use diesel::result::Error;
113//! use std::thread;
114//!
115//! # fn main() {}
116//!
117//! pub fn get_testing_pool() -> Pool<ConnectionManager<DbConnection>> {
118//!     let url = database_url_for_env();
119//!     let manager = ConnectionManager::<DbConnection>::new(url);
120//!
121//!     Pool::builder()
122//!         .test_on_check_out(true)
123//!         .max_size(1) // Max pool size set to 1
124//!         .connection_customizer(Box::new(TestCustomizer)) // Test customizer
125//!         .build(manager)
126//!         .expect("Could not build connection pool")
127//! }
128//!
129//! table! {
130//!     users {
131//!         id -> Integer,
132//!         name -> Text,
133//!     }
134//! }
135//!
136//! #[cfg(test)]
137//! mod tests {
138//!     use super::*;
139//!
140//!     #[diesel_test_helper::test]
141//!     fn test_1() {
142//!         let pool = get_testing_pool();
143//!         let mut conn = pool.get().unwrap();
144//!
145//!         diesel::sql_query(
146//!             "CREATE TABLE IF NOT EXISTS users (id SERIAL PRIMARY KEY, name TEXT NOT NULL)",
147//!         )
148//!         .execute(&mut conn)
149//!         .unwrap();
150//!
151//!         diesel::insert_into(users::table)
152//!             .values(users::name.eq("John"))
153//!             .execute(&mut conn)
154//!             .unwrap();
155//!     }
156//!
157//!     #[diesel_test_helper::test]
158//!     fn test_2() {
159//!         let pool = get_testing_pool();
160//!         let mut conn = pool.get().unwrap();
161//!
162//!         diesel::sql_query(
163//!             "CREATE TABLE IF NOT EXISTS users (id SERIAL PRIMARY KEY, name TEXT NOT NULL)",
164//!         )
165//!         .execute(&mut conn)
166//!         .unwrap();
167//!
168//!         let user_count = users::table.count().get_result::<i64>(&mut conn).unwrap();
169//!         assert_eq!(user_count, 0); // Because the transaction from test_1 was never committed
170//!     }
171//! }
172//! ```
173pub use r2d2::*;
174
175/// A re-export of [`r2d2::Error`], which is only used by methods on [`r2d2::Pool`].
176///
177/// [`r2d2::Error`]: r2d2::Error
178/// [`r2d2::Pool`]: r2d2::Pool
179pub type PoolError = r2d2::Error;
180
181use alloc::fmt;
182use core::marker::PhantomData;
183
184use crate::backend::Backend;
185use crate::connection::{
186    ConnectionSealed, LoadConnection, SimpleConnection, TransactionManager,
187    TransactionManagerStatus,
188};
189use crate::expression::QueryMetadata;
190use crate::prelude::*;
191use crate::query_builder::{Query, QueryFragment, QueryId};
192use crate::query_dsl::RunQueryDslSupport;
193
/// An r2d2 connection manager for use with Diesel.
///
/// `T` is the connection type the manager creates (see its `ManageConnection`
/// impl). See the [r2d2 documentation](https://docs.rs/r2d2/latest/r2d2/) for
/// usage examples.
pub struct ConnectionManager<T> {
    database_url: String,
    _marker: PhantomData<T>,
}

// `Clone` is implemented by hand instead of derived: `#[derive(Clone)]` would
// add a `T: Clone` bound, but `T` is only a type-level marker here and
// connection types need not be `Clone`. Cloning a manager only copies the URL.
impl<T> Clone for ConnectionManager<T> {
    fn clone(&self) -> Self {
        ConnectionManager {
            database_url: self.database_url.clone(),
            _marker: PhantomData,
        }
    }
}
202
203impl<T> fmt::Debug for ConnectionManager<T> {
204    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
205        f.write_fmt(format_args!("ConnectionManager<{0}>",
        core::any::type_name::<T>()))write!(f, "ConnectionManager<{}>", core::any::type_name::<T>())
206    }
207}
208
// SAFETY: `ConnectionManager<T>` only stores a `String` and a `PhantomData<T>`;
// it never holds a `T` value, so sharing `&ConnectionManager<T>` between
// threads cannot give concurrent access to any `T`. Without this impl, the
// `PhantomData<T>` field would make the manager `Sync` only when `T: Sync`.
#[allow(unsafe_code)] // we do not actually hold a reference to `T`
unsafe impl<T: Send + 'static> Sync for ConnectionManager<T> {}
211
212impl<T> ConnectionManager<T> {
213    /// Returns a new connection manager,
214    /// which establishes connections to the given database URL.
215    pub fn new<S: Into<String>>(database_url: S) -> Self {
216        ConnectionManager {
217            database_url: database_url.into(),
218            _marker: PhantomData,
219        }
220    }
221
222    /// Modifies the URL which was supplied at initialization.
223    ///
224    /// This does not update any state for existing connections,
225    /// but this new URL is used for new connections that are created.
226    pub fn update_database_url<S: Into<String>>(&mut self, database_url: S) {
227        self.database_url = database_url.into();
228    }
229}
230
231/// The error used when managing connections with `r2d2`.
232#[derive(#[automatically_derived]
impl ::core::fmt::Debug for Error {
    #[inline]
    fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
        match self {
            Error::ConnectionError(__self_0) =>
                ::core::fmt::Formatter::debug_tuple_field1_finish(f,
                    "ConnectionError", &__self_0),
            Error::QueryError(__self_0) =>
                ::core::fmt::Formatter::debug_tuple_field1_finish(f,
                    "QueryError", &__self_0),
        }
    }
}Debug)]
233pub enum Error {
234    /// An error occurred establishing the connection
235    ConnectionError(ConnectionError),
236
237    /// An error occurred pinging the database
238    QueryError(crate::result::Error),
239}
240
241impl fmt::Display for Error {
242    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
243        match *self {
244            Error::ConnectionError(ref e) => e.fmt(f),
245            Error::QueryError(ref e) => e.fmt(f),
246        }
247    }
248}
249
250impl ::core::error::Error for Error {}
251
252impl From<crate::result::Error> for Error {
253    fn from(other: crate::result::Error) -> Self {
254        Self::QueryError(other)
255    }
256}
257
258impl From<ConnectionError> for Error {
259    fn from(other: ConnectionError) -> Self {
260        Self::ConnectionError(other)
261    }
262}
263
264/// A trait indicating a connection could be used inside a r2d2 pool
265pub trait R2D2Connection: Connection {
266    /// Check if a connection is still valid
267    fn ping(&mut self) -> QueryResult<()>;
268
269    /// Checks if the connection is broken and should not be reused
270    ///
271    /// This method should return only contain a fast non-blocking check
272    /// if the connection is considered to be broken or not. See
273    /// [ManageConnection::has_broken] for details.
274    ///
275    /// The default implementation does not consider any connection as broken
276    fn is_broken(&mut self) -> bool {
277        false
278    }
279}
280
281impl<T> ManageConnection for ConnectionManager<T>
282where
283    T: R2D2Connection + Send + 'static,
284{
285    type Connection = T;
286    type Error = Error;
287
288    fn connect(&self) -> Result<T, Error> {
289        T::establish(&self.database_url).map_err(Error::ConnectionError)
290    }
291
292    fn is_valid(&self, conn: &mut T) -> Result<(), Error> {
293        conn.ping().map_err(Error::QueryError)
294    }
295
296    fn has_broken(&self, conn: &mut T) -> bool {
297        std::thread::panicking() || conn.is_broken()
298    }
299}
300
301impl<M> SimpleConnection for PooledConnection<M>
302where
303    M: ManageConnection,
304    M::Connection: R2D2Connection + Send + 'static,
305{
306    fn batch_execute(&mut self, query: &str) -> QueryResult<()> {
307        (**self).batch_execute(query)
308    }
309}
310
311impl<M> ConnectionSealed for PooledConnection<M>
312where
313    M: ManageConnection,
314    M::Connection: ConnectionSealed,
315{
316}
317
318impl<M> Connection for PooledConnection<M>
319where
320    M: ManageConnection,
321    M::Connection: Connection + R2D2Connection + Send + 'static,
322{
323    type Backend = <M::Connection as Connection>::Backend;
324    type TransactionManager =
325        PoolTransactionManager<<M::Connection as Connection>::TransactionManager>;
326
327    fn establish(_: &str) -> ConnectionResult<Self> {
328        Err(ConnectionError::BadConnection(String::from(
329            "Cannot directly establish a pooled connection",
330        )))
331    }
332
333    fn begin_test_transaction(&mut self) -> QueryResult<()> {
334        (**self).begin_test_transaction()
335    }
336
337    fn execute_returning_count<T>(&mut self, source: &T) -> QueryResult<usize>
338    where
339        T: QueryFragment<Self::Backend> + QueryId,
340    {
341        (**self).execute_returning_count(source)
342    }
343
344    fn transaction_state(
345        &mut self,
346    ) -> &mut <Self::TransactionManager as TransactionManager<Self>>::TransactionStateData {
347        (**self).transaction_state()
348    }
349
350    fn instrumentation(&mut self) -> &mut dyn crate::connection::Instrumentation {
351        (**self).instrumentation()
352    }
353
354    fn set_instrumentation(&mut self, instrumentation: impl crate::connection::Instrumentation) {
355        (**self).set_instrumentation(instrumentation)
356    }
357
358    fn set_prepared_statement_cache_size(&mut self, size: crate::connection::CacheSize) {
359        (**self).set_prepared_statement_cache_size(size)
360    }
361}
362
363impl<B, M> LoadConnection<B> for PooledConnection<M>
364where
365    M: ManageConnection,
366    M::Connection: LoadConnection<B> + R2D2Connection,
367{
368    type Cursor<'conn, 'query> = <M::Connection as LoadConnection<B>>::Cursor<'conn, 'query>;
369    type Row<'conn, 'query> = <M::Connection as LoadConnection<B>>::Row<'conn, 'query>;
370
371    fn load<'conn, 'query, T>(
372        &'conn mut self,
373        source: T,
374    ) -> QueryResult<Self::Cursor<'conn, 'query>>
375    where
376        T: Query + QueryFragment<Self::Backend> + QueryId + 'query,
377        Self::Backend: QueryMetadata<T::SqlType>,
378    {
379        (**self).load(source)
380    }
381}
382
/// Transaction manager for pooled connections. It carries no state and
/// forwards every call to the wrapped connection's transaction manager `T`.
#[doc(hidden)]
// Only a type-level marker; there is nothing useful to print.
#[allow(missing_debug_implementations)]
pub struct PoolTransactionManager<T>(PhantomData<T>);
386
387impl<M, T> TransactionManager<PooledConnection<M>> for PoolTransactionManager<T>
388where
389    M: ManageConnection,
390    M::Connection: Connection<TransactionManager = T> + R2D2Connection,
391    T: TransactionManager<M::Connection>,
392{
393    type TransactionStateData = T::TransactionStateData;
394
395    fn begin_transaction(conn: &mut PooledConnection<M>) -> QueryResult<()> {
396        T::begin_transaction(&mut **conn)
397    }
398
399    fn rollback_transaction(conn: &mut PooledConnection<M>) -> QueryResult<()> {
400        T::rollback_transaction(&mut **conn)
401    }
402
403    fn commit_transaction(conn: &mut PooledConnection<M>) -> QueryResult<()> {
404        T::commit_transaction(&mut **conn)
405    }
406
407    fn transaction_manager_status_mut(
408        conn: &mut PooledConnection<M>,
409    ) -> &mut TransactionManagerStatus {
410        T::transaction_manager_status_mut(&mut **conn)
411    }
412}
413
414impl<M> crate::migration::MigrationConnection for PooledConnection<M>
415where
416    M: ManageConnection,
417    M::Connection: crate::migration::MigrationConnection,
418    Self: Connection,
419{
420    fn setup(&mut self) -> QueryResult<usize> {
421        (**self).setup()
422    }
423}
424
425impl<Changes, Output, M> crate::query_dsl::UpdateAndFetchResults<Changes, Output>
426    for PooledConnection<M>
427where
428    M: ManageConnection,
429    M::Connection: crate::query_dsl::UpdateAndFetchResults<Changes, Output>,
430    Self: Connection,
431{
432    fn update_and_fetch(&mut self, changeset: Changes) -> QueryResult<Output> {
433        (**self).update_and_fetch(changeset)
434    }
435}
436
437#[derive(const _: () =
    {
        use diesel;
        #[allow(non_camel_case_types)]
        impl diesel::query_builder::QueryId for CheckConnectionQuery {
            type QueryId = CheckConnectionQuery<>;
            const HAS_STATIC_QUERY_ID: bool = true;
            const IS_WINDOW_FUNCTION: bool = false;
        }
    };QueryId)]
438pub(crate) struct CheckConnectionQuery;
439
440impl<DB> QueryFragment<DB> for CheckConnectionQuery
441where
442    DB: Backend,
443{
444    fn walk_ast<'b>(
445        &'b self,
446        mut pass: crate::query_builder::AstPass<'_, 'b, DB>,
447    ) -> QueryResult<()> {
448        pass.push_sql("SELECT 1");
449        Ok(())
450    }
451}
452
453impl Query for CheckConnectionQuery {
454    type SqlType = crate::sql_types::Integer;
455}
456
457impl RunQueryDslSupport for CheckConnectionQuery {}
458
/// A connection customizer designed for use in tests. Implements
/// [CustomizeConnection] in a way that ensures transactions
/// in a pool customized by it are never committed.
///
/// Every connection the pool acquires immediately starts a test transaction
/// via `Connection::begin_test_transaction`.
#[derive(Debug, Clone, Copy)]
pub struct TestCustomizer;
464
465impl<C: Connection> CustomizeConnection<C, crate::r2d2::Error> for TestCustomizer {
466    fn on_acquire(&self, conn: &mut C) -> Result<(), crate::r2d2::Error> {
467        conn.begin_test_transaction()
468            .map_err(crate::r2d2::Error::QueryError)
469    }
470}
471
// Integration tests for pooling Diesel connections with r2d2. They connect to
// the database returned by `database_url()` from `crate::test_helpers`.
// NOTE(review): excluded on wasm32-unknown-unknown, presumably because these
// tests spawn threads and open real connections; confirm.
#[cfg(all(test, not(all(target_family = "wasm", target_os = "unknown"))))]
mod tests {
    use std::sync::Arc;
    use std::sync::mpsc;
    use std::thread;
    use std::time::Duration;

    use crate::r2d2::*;
    use crate::test_helpers::*;

    // Two threads each check out one of the pool's two connections. Each holds
    // it until the other thread has also checked one out, so both connections
    // are in use at the same time. Afterwards the pool must still hand out a
    // connection.
    #[diesel_test_helper::test]
    fn establish_basic_connection() {
        let manager = ConnectionManager::<TestConnection>::new(database_url());
        let pool = Arc::new(Pool::builder().max_size(2).build(manager).unwrap());

        let (s1, r1) = mpsc::channel();
        let (s2, r2) = mpsc::channel();

        let pool1 = Arc::clone(&pool);
        let t1 = thread::spawn(move || {
            let conn = pool1.get().unwrap();
            s1.send(()).unwrap();
            // wait until the other thread also holds a connection
            r2.recv().unwrap();
            drop(conn);
        });

        let pool2 = Arc::clone(&pool);
        let t2 = thread::spawn(move || {
            let conn = pool2.get().unwrap();
            s2.send(()).unwrap();
            // wait until the other thread also holds a connection
            r1.recv().unwrap();
            drop(conn);
        });

        t1.join().unwrap();
        t2.join().unwrap();

        pool.get().unwrap();
    }

    // Checks out a connection with `test_on_check_out(true)`, which exercises
    // the `is_valid` (-> `ping`) path of `ConnectionManager`.
    #[diesel_test_helper::test]
    fn is_valid() {
        let manager = ConnectionManager::<TestConnection>::new(database_url());
        let pool = Pool::builder()
            .max_size(1)
            .test_on_check_out(true)
            .build(manager)
            .unwrap();

        pool.get().unwrap();
    }

    // A `PooledConnection` can be used directly wherever a connection is
    // expected, here to load the result of a query.
    #[diesel_test_helper::test]
    fn pooled_connection_impls_connection() {
        use crate::select;
        use crate::sql_types::Text;

        let manager = ConnectionManager::<TestConnection>::new(database_url());
        let pool = Pool::builder()
            .max_size(1)
            .test_on_check_out(true)
            .build(manager)
            .unwrap();
        let mut conn = pool.get().unwrap();

        let query = select("foo".into_sql::<Text>());
        assert_eq!("foo", query.get_result::<String>(&mut conn).unwrap());
    }

    // Counts r2d2 acquire/release/checkout/checkin events to verify that:
    // a healthy connection is reused, a connection dropped with an open
    // transaction is released from the pool, and a connection dropped while
    // panicking is released as well.
    #[diesel_test_helper::test]
    fn check_pool_does_actually_hold_connections() {
        use std::sync::atomic::{AtomicU32, Ordering};

        #[derive(Debug)]
        struct TestEventHandler {
            acquire_count: Arc<AtomicU32>,
            release_count: Arc<AtomicU32>,
            checkin_count: Arc<AtomicU32>,
            checkout_count: Arc<AtomicU32>,
        }

        impl r2d2::HandleEvent for TestEventHandler {
            fn handle_acquire(&self, _event: r2d2::event::AcquireEvent) {
                self.acquire_count.fetch_add(1, Ordering::Relaxed);
            }
            fn handle_release(&self, _event: r2d2::event::ReleaseEvent) {
                self.release_count.fetch_add(1, Ordering::Relaxed);
            }
            fn handle_checkout(&self, _event: r2d2::event::CheckoutEvent) {
                self.checkout_count.fetch_add(1, Ordering::Relaxed);
            }
            fn handle_checkin(&self, _event: r2d2::event::CheckinEvent) {
                self.checkin_count.fetch_add(1, Ordering::Relaxed);
            }
        }

        let acquire_count = Arc::new(AtomicU32::new(0));
        let release_count = Arc::new(AtomicU32::new(0));
        let checkin_count = Arc::new(AtomicU32::new(0));
        let checkout_count = Arc::new(AtomicU32::new(0));

        let handler = Box::new(TestEventHandler {
            acquire_count: acquire_count.clone(),
            release_count: release_count.clone(),
            checkin_count: checkin_count.clone(),
            checkout_count: checkout_count.clone(),
        });

        let manager = ConnectionManager::<TestConnection>::new(database_url());
        let pool = Pool::builder()
            .max_size(1)
            .test_on_check_out(true)
            .event_handler(handler)
            .build(manager)
            .unwrap();

        // building the pool already acquired its single connection
        assert_eq!(acquire_count.load(Ordering::Relaxed), 1);
        assert_eq!(release_count.load(Ordering::Relaxed), 0);
        assert_eq!(checkin_count.load(Ordering::Relaxed), 0);
        assert_eq!(checkout_count.load(Ordering::Relaxed), 0);

        // check that we reuse connections with the pool
        {
            let conn = pool.get().unwrap();

            assert_eq!(acquire_count.load(Ordering::Relaxed), 1);
            assert_eq!(release_count.load(Ordering::Relaxed), 0);
            assert_eq!(checkin_count.load(Ordering::Relaxed), 0);
            assert_eq!(checkout_count.load(Ordering::Relaxed), 1);
            std::mem::drop(conn);
        }

        assert_eq!(acquire_count.load(Ordering::Relaxed), 1);
        assert_eq!(release_count.load(Ordering::Relaxed), 0);
        assert_eq!(checkin_count.load(Ordering::Relaxed), 1);
        assert_eq!(checkout_count.load(Ordering::Relaxed), 1);

        // check that we remove a connection with open transactions from the pool
        {
            let mut conn = pool.get().unwrap();

            assert_eq!(acquire_count.load(Ordering::Relaxed), 1);
            assert_eq!(release_count.load(Ordering::Relaxed), 0);
            assert_eq!(checkin_count.load(Ordering::Relaxed), 1);
            assert_eq!(checkout_count.load(Ordering::Relaxed), 2);

            // leave a transaction open when `conn` goes out of scope
            <TestConnection as Connection>::TransactionManager::begin_transaction(&mut *conn)
                .unwrap();
        }

        // we are not interested in the acquire count here
        // as the pool opens a new connection in the background
        // that could lead to this test failing if that happens too fast
        // (which is sometimes the case for sqlite)
        //assert_eq!(acquire_count.load(Ordering::Relaxed), 1);
        assert_eq!(release_count.load(Ordering::Relaxed), 1);
        assert_eq!(checkin_count.load(Ordering::Relaxed), 2);
        assert_eq!(checkout_count.load(Ordering::Relaxed), 2);

        // check that we remove a connection from the pool that was
        // open during panicking
        // (`conn` is still alive when `panic!` unwinds, so `has_broken` sees
        // `std::thread::panicking()`; the trailing `drop` is never reached)
        #[allow(unreachable_code, unused_variables)]
        std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
            let conn = pool.get();
            assert_eq!(acquire_count.load(Ordering::Relaxed), 2);
            assert_eq!(release_count.load(Ordering::Relaxed), 1);
            assert_eq!(checkin_count.load(Ordering::Relaxed), 2);
            assert_eq!(checkout_count.load(Ordering::Relaxed), 3);
            panic!();
            std::mem::drop(conn);
        }))
        .unwrap_err();

        // we are not interested in the acquire count here
        // as the pool opens a new connection in the background
        // that could lead to this test failing if that happens too fast
        // (which is sometimes the case for sqlite)
        //assert_eq!(acquire_count.load(Ordering::Relaxed), 2);
        assert_eq!(release_count.load(Ordering::Relaxed), 2);
        assert_eq!(checkin_count.load(Ordering::Relaxed), 3);
        assert_eq!(checkout_count.load(Ordering::Relaxed), 3);
        // this is required to workaround a segfault while shutting down
        // the pool
        std::thread::sleep(Duration::from_millis(100));
    }

    // With a customizer that opens a test transaction on every acquired
    // connection and a pool of size 1, the second checkout gets the same
    // connection back. The uncommitted insert is therefore still visible.
    #[cfg(feature = "postgres")]
    #[diesel_test_helper::test]
    fn verify_that_begin_test_transaction_works_with_pools() {
        use crate::prelude::*;
        use crate::r2d2::*;

        table! {
            users {
                id -> Integer,
                name -> Text,
            }
        }

        #[derive(Debug)]
        struct TestConnectionCustomizer;

        impl<E> CustomizeConnection<PgConnection, E> for TestConnectionCustomizer {
            fn on_acquire(&self, conn: &mut PgConnection) -> Result<(), E> {
                conn.begin_test_transaction()
                    .expect("Failed to start test transaction");

                Ok(())
            }
        }

        let manager = ConnectionManager::<PgConnection>::new(database_url());
        let pool = Pool::builder()
            .max_size(1)
            .connection_customizer(Box::new(TestConnectionCustomizer))
            .build(manager)
            .unwrap();

        let mut conn = pool.get().unwrap();

        crate::sql_query(
            "CREATE TABLE IF NOT EXISTS users (id SERIAL PRIMARY KEY, name TEXT NOT NULL)",
        )
        .execute(&mut conn)
        .unwrap();

        crate::insert_into(users::table)
            .values(users::name.eq("John"))
            .execute(&mut conn)
            .unwrap();

        std::mem::drop(conn);

        let mut conn2 = pool.get().unwrap();

        let user_count = users::table.count().get_result::<i64>(&mut conn2).unwrap();
        assert_eq!(user_count, 1);
    }
}
710}