diesel/
r2d2.rs

1//! Connection pooling via r2d2.
2//!
//! Note: This module requires enabling the `r2d2` feature.
4//!
5//! # Example
6//!
7//! The below snippet is a contrived example emulating a web application,
8//! where one would first initialize the pool in the `main()` function
9//! (at the start of a long-running process). One would then pass this
10//! pool struct around as shared state, which, here, we've emulated using
11//! threads instead of routes.
12//!
13//! ```rust
14//! # include!("doctest_setup.rs");
15//! use diesel::prelude::*;
16//! use diesel::r2d2::ConnectionManager;
17//! # use diesel::r2d2::CustomizeConnection;
18//! # use diesel::r2d2::Error as R2D2Error;
19//! use diesel::r2d2::Pool;
20//! use diesel::result::Error;
21//! use std::thread;
22//!
23//! # #[derive(Copy, Clone, Debug)]
24//! # pub struct SetupUserTableCustomizer;
25//! #
26//! # impl CustomizeConnection<DbConnection, R2D2Error> for SetupUserTableCustomizer
27//! # {
28//! #     fn on_acquire(&self, conn: &mut DbConnection) -> Result<(), R2D2Error> {
29//! #         setup_database(conn);
30//! #         Ok(())
31//! #     }
32//! # }
33//!
34//! pub fn get_connection_pool() -> Pool<ConnectionManager<DbConnection>> {
35//!     let url = database_url_for_env();
36//!     let manager = ConnectionManager::<DbConnection>::new(url);
37//!     // Refer to the `r2d2` documentation for more methods to use
38//!     // when building a connection pool
39//!     Pool::builder()
40//! #         .max_size(1)
41//!         .test_on_check_out(true)
42//! #         .connection_customizer(Box::new(SetupUserTableCustomizer))
43//!         .build(manager)
44//!         .expect("Could not build connection pool")
45//! }
46//!
47//! pub fn create_user(conn: &mut DbConnection, user_name: &str) -> Result<usize, Error> {
48//!     use schema::users::dsl::*;
49//!
50//!     diesel::insert_into(users)
51//!         .values(name.eq(user_name))
52//!         .execute(conn)
53//! }
54//!
55//! fn main() {
56//!     let pool = get_connection_pool();
57//!     let mut threads = vec![];
58//!     let max_users_to_create = 1;
59//!
60//!     for i in 0..max_users_to_create {
61//!         let pool = pool.clone();
62//!         threads.push(thread::spawn({
63//!             move || {
64//!                 let conn = &mut pool.get().unwrap();
65//!                 let name = format!("Person {}", i);
66//!                 create_user(conn, &name).unwrap();
67//!             }
68//!         }))
69//!     }
70//!
71//!     for handle in threads {
72//!         handle.join().unwrap();
73//!     }
74//! }
75//! ```
76//!
77//! # A note on error handling
78//!
//! When used inside a pool, if an individual connection becomes
//! broken (as determined by the [R2D2Connection::is_broken] method)
//! then, when the connection goes out of scope, `r2d2` will close
//! the connection and discard it instead of returning it to the pool.
83//!
//! `diesel` considers a connection broken if the current thread is
//! panicking or if the connection's own `is_broken()` method returns
//! `true`. How `is_broken()` reaches that decision is left to each
//! backend to implement.
88//!
89//! For SQLite, PG, and MySQL backends `is_broken()` is determined
90//! by whether or not the `TransactionManagerStatus` (as a part
91//! of the `AnsiTransactionManager` struct) is in an `InError` state
92//! or contains an open transaction when the connection goes out of scope.
93//!
94
95pub use r2d2::*;
96
97/// A re-export of [`r2d2::Error`], which is only used by methods on [`r2d2::Pool`].
98///
99/// [`r2d2::Error`]: r2d2::Error
100/// [`r2d2::Pool`]: r2d2::Pool
101pub type PoolError = r2d2::Error;
102
103use std::fmt;
104use std::marker::PhantomData;
105
106use crate::backend::Backend;
107use crate::connection::{
108    ConnectionSealed, LoadConnection, SimpleConnection, TransactionManager,
109    TransactionManagerStatus,
110};
111use crate::expression::QueryMetadata;
112use crate::prelude::*;
113use crate::query_builder::{Query, QueryFragment, QueryId};
114
/// An r2d2 connection manager for use with Diesel.
///
/// The manager stores the database URL and establishes new connections of
/// type `T` from it whenever the pool asks for one.
///
/// See the [r2d2 documentation](https://docs.rs/r2d2/latest/r2d2/) for usage examples.
pub struct ConnectionManager<T> {
    database_url: String,
    _marker: PhantomData<T>,
}

// Implemented by hand instead of `#[derive(Clone)]`: the derive would add a
// `T: Clone` bound, but the manager only holds a `String` and a
// `PhantomData<T>`, so it can be cloned for every `T`, even when `T` itself
// is not `Clone`.
impl<T> Clone for ConnectionManager<T> {
    fn clone(&self) -> Self {
        ConnectionManager {
            database_url: self.database_url.clone(),
            _marker: PhantomData,
        }
    }
}
123
124impl<T> fmt::Debug for ConnectionManager<T> {
125    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
126        write!(f, "ConnectionManager<{}>", std::any::type_name::<T>())
127    }
128}
129
130#[allow(unsafe_code)] // we do not actually hold a reference to `T`
131unsafe impl<T: Send + 'static> Sync for ConnectionManager<T> {}
132
133impl<T> ConnectionManager<T> {
134    /// Returns a new connection manager,
135    /// which establishes connections to the given database URL.
136    pub fn new<S: Into<String>>(database_url: S) -> Self {
137        ConnectionManager {
138            database_url: database_url.into(),
139            _marker: PhantomData,
140        }
141    }
142
143    /// Modifies the URL which was supplied at initialization.
144    ///
145    /// This does not update any state for existing connections,
146    /// but this new URL is used for new connections that are created.
147    pub fn update_database_url<S: Into<String>>(&mut self, database_url: S) {
148        self.database_url = database_url.into();
149    }
150}
151
152/// The error used when managing connections with `r2d2`.
153#[derive(Debug)]
154pub enum Error {
155    /// An error occurred establishing the connection
156    ConnectionError(ConnectionError),
157
158    /// An error occurred pinging the database
159    QueryError(crate::result::Error),
160}
161
162impl fmt::Display for Error {
163    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
164        match *self {
165            Error::ConnectionError(ref e) => e.fmt(f),
166            Error::QueryError(ref e) => e.fmt(f),
167        }
168    }
169}
170
171impl ::std::error::Error for Error {}
172
173/// A trait indicating a connection could be used inside a r2d2 pool
174pub trait R2D2Connection: Connection {
175    /// Check if a connection is still valid
176    fn ping(&mut self) -> QueryResult<()>;
177
178    /// Checks if the connection is broken and should not be reused
179    ///
180    /// This method should return only contain a fast non-blocking check
181    /// if the connection is considered to be broken or not. See
182    /// [ManageConnection::has_broken] for details.
183    ///
184    /// The default implementation does not consider any connection as broken
185    fn is_broken(&mut self) -> bool {
186        false
187    }
188}
189
190impl<T> ManageConnection for ConnectionManager<T>
191where
192    T: R2D2Connection + Send + 'static,
193{
194    type Connection = T;
195    type Error = Error;
196
197    fn connect(&self) -> Result<T, Error> {
198        T::establish(&self.database_url).map_err(Error::ConnectionError)
199    }
200
201    fn is_valid(&self, conn: &mut T) -> Result<(), Error> {
202        conn.ping().map_err(Error::QueryError)
203    }
204
205    fn has_broken(&self, conn: &mut T) -> bool {
206        std::thread::panicking() || conn.is_broken()
207    }
208}
209
210impl<M> SimpleConnection for PooledConnection<M>
211where
212    M: ManageConnection,
213    M::Connection: R2D2Connection + Send + 'static,
214{
215    fn batch_execute(&mut self, query: &str) -> QueryResult<()> {
216        (**self).batch_execute(query)
217    }
218}
219
220impl<M> ConnectionSealed for PooledConnection<M>
221where
222    M: ManageConnection,
223    M::Connection: ConnectionSealed,
224{
225}
226
227impl<M> Connection for PooledConnection<M>
228where
229    M: ManageConnection,
230    M::Connection: Connection + R2D2Connection + Send + 'static,
231{
232    type Backend = <M::Connection as Connection>::Backend;
233    type TransactionManager =
234        PoolTransactionManager<<M::Connection as Connection>::TransactionManager>;
235
236    fn establish(_: &str) -> ConnectionResult<Self> {
237        Err(ConnectionError::BadConnection(String::from(
238            "Cannot directly establish a pooled connection",
239        )))
240    }
241
242    fn begin_test_transaction(&mut self) -> QueryResult<()> {
243        (**self).begin_test_transaction()
244    }
245
246    fn execute_returning_count<T>(&mut self, source: &T) -> QueryResult<usize>
247    where
248        T: QueryFragment<Self::Backend> + QueryId,
249    {
250        (**self).execute_returning_count(source)
251    }
252
253    fn transaction_state(
254        &mut self,
255    ) -> &mut <Self::TransactionManager as TransactionManager<Self>>::TransactionStateData {
256        (**self).transaction_state()
257    }
258
259    fn instrumentation(&mut self) -> &mut dyn crate::connection::Instrumentation {
260        (**self).instrumentation()
261    }
262
263    fn set_instrumentation(&mut self, instrumentation: impl crate::connection::Instrumentation) {
264        (**self).set_instrumentation(instrumentation)
265    }
266}
267
268impl<B, M> LoadConnection<B> for PooledConnection<M>
269where
270    M: ManageConnection,
271    M::Connection: LoadConnection<B> + R2D2Connection,
272{
273    type Cursor<'conn, 'query> = <M::Connection as LoadConnection<B>>::Cursor<'conn, 'query>;
274    type Row<'conn, 'query> = <M::Connection as LoadConnection<B>>::Row<'conn, 'query>;
275
276    fn load<'conn, 'query, T>(
277        &'conn mut self,
278        source: T,
279    ) -> QueryResult<Self::Cursor<'conn, 'query>>
280    where
281        T: Query + QueryFragment<Self::Backend> + QueryId + 'query,
282        Self::Backend: QueryMetadata<T::SqlType>,
283    {
284        (**self).load(source)
285    }
286}
287
/// Transaction manager used by pooled connections.
///
/// It carries no state of its own: `T` is the transaction manager of the
/// wrapped connection, to which every operation is delegated.
#[doc(hidden)]
#[allow(missing_debug_implementations)] // zero-sized marker; nothing useful to print
pub struct PoolTransactionManager<T>(PhantomData<T>);
291
292impl<M, T> TransactionManager<PooledConnection<M>> for PoolTransactionManager<T>
293where
294    M: ManageConnection,
295    M::Connection: Connection<TransactionManager = T> + R2D2Connection,
296    T: TransactionManager<M::Connection>,
297{
298    type TransactionStateData = T::TransactionStateData;
299
300    fn begin_transaction(conn: &mut PooledConnection<M>) -> QueryResult<()> {
301        T::begin_transaction(&mut **conn)
302    }
303
304    fn rollback_transaction(conn: &mut PooledConnection<M>) -> QueryResult<()> {
305        T::rollback_transaction(&mut **conn)
306    }
307
308    fn commit_transaction(conn: &mut PooledConnection<M>) -> QueryResult<()> {
309        T::commit_transaction(&mut **conn)
310    }
311
312    fn transaction_manager_status_mut(
313        conn: &mut PooledConnection<M>,
314    ) -> &mut TransactionManagerStatus {
315        T::transaction_manager_status_mut(&mut **conn)
316    }
317}
318
319impl<M> crate::migration::MigrationConnection for PooledConnection<M>
320where
321    M: ManageConnection,
322    M::Connection: crate::migration::MigrationConnection,
323    Self: Connection,
324{
325    fn setup(&mut self) -> QueryResult<usize> {
326        (**self).setup()
327    }
328}
329
330impl<Changes, Output, M> crate::query_dsl::UpdateAndFetchResults<Changes, Output>
331    for PooledConnection<M>
332where
333    M: ManageConnection,
334    M::Connection: crate::query_dsl::UpdateAndFetchResults<Changes, Output>,
335    Self: Connection,
336{
337    fn update_and_fetch(&mut self, changeset: Changes) -> QueryResult<Output> {
338        (**self).update_and_fetch(changeset)
339    }
340}
341
342#[derive(QueryId)]
343pub(crate) struct CheckConnectionQuery;
344
345impl<DB> QueryFragment<DB> for CheckConnectionQuery
346where
347    DB: Backend,
348{
349    fn walk_ast<'b>(
350        &'b self,
351        mut pass: crate::query_builder::AstPass<'_, 'b, DB>,
352    ) -> QueryResult<()> {
353        pass.push_sql("SELECT 1");
354        Ok(())
355    }
356}
357
358impl Query for CheckConnectionQuery {
359    type SqlType = crate::sql_types::Integer;
360}
361
362impl<C> RunQueryDsl<C> for CheckConnectionQuery {}
363
#[cfg(test)]
mod tests {
    use std::sync::mpsc;
    use std::sync::Arc;
    use std::thread;

    use crate::r2d2::*;
    use crate::test_helpers::*;

    /// Two threads hold pooled connections at the same time (the pool has
    /// `max_size(2)`), and the pool still hands out a connection once both
    /// have been returned.
    #[test]
    fn establish_basic_connection() {
        let manager = ConnectionManager::<TestConnection>::new(database_url());
        let pool = Arc::new(Pool::builder().max_size(2).build(manager).unwrap());

        let (s1, r1) = mpsc::channel();
        let (s2, r2) = mpsc::channel();

        // Each thread checks out a connection, signals that it holds it, and
        // then waits for the other thread's signal before releasing it.
        let pool1 = Arc::clone(&pool);
        let t1 = thread::spawn(move || {
            let conn = pool1.get().unwrap();
            s1.send(()).unwrap();
            r2.recv().unwrap();
            drop(conn);
        });

        let pool2 = Arc::clone(&pool);
        let t2 = thread::spawn(move || {
            let conn = pool2.get().unwrap();
            s2.send(()).unwrap();
            r1.recv().unwrap();
            drop(conn);
        });

        t1.join().unwrap();
        t2.join().unwrap();

        pool.get().unwrap();
    }

    /// Checking out a connection succeeds when `test_on_check_out(true)` is
    /// set, so the pool validates the connection before handing it out.
    #[test]
    fn is_valid() {
        let manager = ConnectionManager::<TestConnection>::new(database_url());
        let pool = Pool::builder()
            .max_size(1)
            .test_on_check_out(true)
            .build(manager)
            .unwrap();

        pool.get().unwrap();
    }

    /// A `PooledConnection` can be passed directly to query methods that
    /// expect a diesel connection.
    #[test]
    fn pooled_connection_impls_connection() {
        use crate::select;
        use crate::sql_types::Text;

        let manager = ConnectionManager::<TestConnection>::new(database_url());
        let pool = Pool::builder()
            .max_size(1)
            .test_on_check_out(true)
            .build(manager)
            .unwrap();
        let mut conn = pool.get().unwrap();

        let query = select("foo".into_sql::<Text>());
        assert_eq!("foo", query.get_result::<String>(&mut conn).unwrap());
    }

    /// Uses r2d2's event hooks to check three cases:
    /// - healthy connections are reused;
    /// - a connection returned with an open transaction is released;
    /// - a connection returned while the thread is panicking is released.
    #[test]
    fn check_pool_does_actually_hold_connections() {
        use std::sync::atomic::{AtomicU32, Ordering};

        #[derive(Debug)]
        struct TestEventHandler {
            acquire_count: Arc<AtomicU32>,
            release_count: Arc<AtomicU32>,
            checkin_count: Arc<AtomicU32>,
            checkout_count: Arc<AtomicU32>,
        }

        impl r2d2::HandleEvent for TestEventHandler {
            fn handle_acquire(&self, _event: r2d2::event::AcquireEvent) {
                self.acquire_count.fetch_add(1, Ordering::Relaxed);
            }
            fn handle_release(&self, _event: r2d2::event::ReleaseEvent) {
                self.release_count.fetch_add(1, Ordering::Relaxed);
            }
            fn handle_checkout(&self, _event: r2d2::event::CheckoutEvent) {
                self.checkout_count.fetch_add(1, Ordering::Relaxed);
            }
            fn handle_checkin(&self, _event: r2d2::event::CheckinEvent) {
                self.checkin_count.fetch_add(1, Ordering::Relaxed);
            }
        }

        let acquire_count = Arc::new(AtomicU32::new(0));
        let release_count = Arc::new(AtomicU32::new(0));
        let checkin_count = Arc::new(AtomicU32::new(0));
        let checkout_count = Arc::new(AtomicU32::new(0));

        let handler = Box::new(TestEventHandler {
            acquire_count: acquire_count.clone(),
            release_count: release_count.clone(),
            checkin_count: checkin_count.clone(),
            checkout_count: checkout_count.clone(),
        });

        let manager = ConnectionManager::<TestConnection>::new(database_url());
        let pool = Pool::builder()
            .max_size(1)
            .test_on_check_out(true)
            .event_handler(handler)
            .build(manager)
            .unwrap();

        assert_eq!(acquire_count.load(Ordering::Relaxed), 1);
        assert_eq!(release_count.load(Ordering::Relaxed), 0);
        assert_eq!(checkin_count.load(Ordering::Relaxed), 0);
        assert_eq!(checkout_count.load(Ordering::Relaxed), 0);

        // check that we reuse connections with the pool
        {
            let conn = pool.get().unwrap();

            assert_eq!(acquire_count.load(Ordering::Relaxed), 1);
            assert_eq!(release_count.load(Ordering::Relaxed), 0);
            assert_eq!(checkin_count.load(Ordering::Relaxed), 0);
            assert_eq!(checkout_count.load(Ordering::Relaxed), 1);
            std::mem::drop(conn);
        }

        assert_eq!(acquire_count.load(Ordering::Relaxed), 1);
        assert_eq!(release_count.load(Ordering::Relaxed), 0);
        assert_eq!(checkin_count.load(Ordering::Relaxed), 1);
        assert_eq!(checkout_count.load(Ordering::Relaxed), 1);

        // check that we remove a connection with open transactions from the pool
        {
            let mut conn = pool.get().unwrap();

            assert_eq!(acquire_count.load(Ordering::Relaxed), 1);
            assert_eq!(release_count.load(Ordering::Relaxed), 0);
            assert_eq!(checkin_count.load(Ordering::Relaxed), 1);
            assert_eq!(checkout_count.load(Ordering::Relaxed), 2);

            <TestConnection as Connection>::TransactionManager::begin_transaction(&mut *conn)
                .unwrap();
        }

        // we are not interested in the acquire count here
        // as the pool opens a new connection in the background
        // that could lead to this test failing if that happens too fast
        // (which is sometimes the case for sqlite)
        //assert_eq!(acquire_count.load(Ordering::Relaxed), 1);
        assert_eq!(release_count.load(Ordering::Relaxed), 1);
        assert_eq!(checkin_count.load(Ordering::Relaxed), 2);
        assert_eq!(checkout_count.load(Ordering::Relaxed), 2);

        // check that we remove a connection from the pool that was
        // open during panicking
        //
        // The closure panics with a dedicated `ExpectedPanic` payload. A
        // failing `assert_eq!` inside the closure panics as well; without the
        // payload check below, `catch_unwind(..).unwrap_err()` would accept
        // that panic too and the test would pass silently.
        struct ExpectedPanic;

        let payload = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
            // Keep the connection alive until the panic, so that it is
            // handed back to the pool while the thread is unwinding.
            let _conn = pool.get().unwrap();
            assert_eq!(acquire_count.load(Ordering::Relaxed), 2);
            assert_eq!(release_count.load(Ordering::Relaxed), 1);
            assert_eq!(checkin_count.load(Ordering::Relaxed), 2);
            assert_eq!(checkout_count.load(Ordering::Relaxed), 3);
            std::panic::panic_any(ExpectedPanic);
        }))
        .unwrap_err();
        assert!(
            payload.is::<ExpectedPanic>(),
            "the closure holding the pooled connection panicked unexpectedly"
        );

        // we are not interested in the acquire count here
        // as the pool opens a new connection in the background
        // that could lead to this test failing if that happens too fast
        // (which is sometimes the case for sqlite)
        //assert_eq!(acquire_count.load(Ordering::Relaxed), 2);
        assert_eq!(release_count.load(Ordering::Relaxed), 2);
        assert_eq!(checkin_count.load(Ordering::Relaxed), 3);
        assert_eq!(checkout_count.load(Ordering::Relaxed), 3);
    }

    /// Checks that a connection whose customizer opened a test transaction in
    /// `on_acquire` is reused by the single-connection pool.
    ///
    /// Data written through the first checkout must still be visible through
    /// the second checkout.
    #[cfg(feature = "postgres")]
    #[test]
    fn verify_that_begin_test_transaction_works_with_pools() {
        use crate::prelude::*;
        use crate::r2d2::*;

        table! {
            users {
                id -> Integer,
                name -> Text,
            }
        }

        #[derive(Debug)]
        struct TestConnectionCustomizer;

        impl<E> CustomizeConnection<PgConnection, E> for TestConnectionCustomizer {
            fn on_acquire(&self, conn: &mut PgConnection) -> Result<(), E> {
                conn.begin_test_transaction()
                    .expect("Failed to start test transaction");

                Ok(())
            }
        }

        let manager = ConnectionManager::<PgConnection>::new(database_url());
        let pool = Pool::builder()
            .max_size(1)
            .connection_customizer(Box::new(TestConnectionCustomizer))
            .build(manager)
            .unwrap();

        let mut conn = pool.get().unwrap();

        crate::sql_query(
            "CREATE TABLE IF NOT EXISTS users (id SERIAL PRIMARY KEY, name TEXT NOT NULL)",
        )
        .execute(&mut conn)
        .unwrap();

        crate::insert_into(users::table)
            .values(users::name.eq("John"))
            .execute(&mut conn)
            .unwrap();

        std::mem::drop(conn);

        let mut conn2 = pool.get().unwrap();

        let user_count = users::table.count().get_result::<i64>(&mut conn2).unwrap();
        assert_eq!(user_count, 1);
    }
}