1pub use r2d2::*;
174
/// The error type returned by r2d2's own pool operations (e.g. building a pool
/// or checking out a connection).
///
/// This is `r2d2::Error`. Because this module defines its own `Error` (the
/// connection manager's error), the glob re-export of `r2d2::Error` is shadowed,
/// and this alias keeps the pool error nameable.
pub type PoolError = r2d2::Error;
180
181use std::fmt;
182use std::marker::PhantomData;
183
184use crate::backend::Backend;
185use crate::connection::{
186 ConnectionSealed, LoadConnection, SimpleConnection, TransactionManager,
187 TransactionManagerStatus,
188};
189use crate::expression::QueryMetadata;
190use crate::prelude::*;
191use crate::query_builder::{Query, QueryFragment, QueryId};
192
/// An r2d2 connection manager that opens Diesel connections of type `T`
/// from a database URL.
///
/// `Clone` is implemented by hand instead of derived: `#[derive(Clone)]` would
/// add a `T: Clone` bound, which connection types typically do not satisfy.
/// The manager never stores a `T` (only a `PhantomData<T>`), so cloning it
/// only needs to copy the URL.
pub struct ConnectionManager<T> {
    database_url: String,
    _marker: PhantomData<T>,
}

impl<T> Clone for ConnectionManager<T> {
    fn clone(&self) -> Self {
        ConnectionManager {
            database_url: self.database_url.clone(),
            _marker: PhantomData,
        }
    }
}

impl<T> fmt::Debug for ConnectionManager<T> {
    /// Prints only the connection type, as `ConnectionManager<type name>`.
    /// The database URL is left out, presumably because connection URLs
    /// commonly embed credentials.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "ConnectionManager<{}>", std::any::type_name::<T>())
    }
}

// SAFETY: the only data in a `ConnectionManager<T>` is a `String` (which is
// `Sync`) and a zero-sized `PhantomData<T>`. No `T` value is ever stored in the
// manager, so sharing `&ConnectionManager<T>` across threads cannot expose a
// shared `T`. Connections produced by the manager are returned by value, which
// only requires `T: Send`.
#[allow(unsafe_code)]
unsafe impl<T: Send + 'static> Sync for ConnectionManager<T> {}

impl<T> ConnectionManager<T> {
    /// Creates a manager that connects to `database_url`.
    ///
    /// No connection is opened here. The URL is only stored for later use.
    pub fn new<S: Into<String>>(database_url: S) -> Self {
        ConnectionManager {
            database_url: database_url.into(),
            _marker: PhantomData,
        }
    }

    /// Replaces the database URL used for connections opened after this call.
    ///
    /// Connections that were already established are not affected.
    pub fn update_database_url<S: Into<String>>(&mut self, database_url: S) {
        self.database_url = database_url.into();
    }
}
229
/// The error type of [`ConnectionManager`]'s r2d2 `ManageConnection` impl.
#[derive(Debug)]
pub enum Error {
    /// Establishing a new connection failed (produced by `connect`).
    ConnectionError(ConnectionError),

    /// A query run on an existing connection failed, e.g. the `ping` performed
    /// by `is_valid`, or the `begin_test_transaction` run by `TestCustomizer`.
    QueryError(crate::result::Error),
}
239
240impl fmt::Display for Error {
241 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
242 match *self {
243 Error::ConnectionError(ref e) => e.fmt(f),
244 Error::QueryError(ref e) => e.fmt(f),
245 }
246 }
247}
248
249impl ::std::error::Error for Error {}
250
251impl From<crate::result::Error> for Error {
252 fn from(other: crate::result::Error) -> Self {
253 Self::QueryError(other)
254 }
255}
256
257impl From<ConnectionError> for Error {
258 fn from(other: ConnectionError) -> Self {
259 Self::ConnectionError(other)
260 }
261}
262
/// A Diesel [`Connection`] that can be managed by an r2d2 pool through
/// [`ConnectionManager`].
pub trait R2D2Connection: Connection {
    /// Checks that the connection is still usable.
    ///
    /// `ConnectionManager`'s `ManageConnection::is_valid` calls this and
    /// reports a failure as `Error::QueryError`.
    fn ping(&mut self) -> QueryResult<()>;

    /// Returns `true` if the connection must not be handed out again.
    ///
    /// `ConnectionManager`'s `ManageConnection::has_broken` consults this when
    /// the current thread is not panicking (a connection returned while
    /// panicking is treated as broken regardless). The default implementation
    /// reports every connection as healthy. Implementations can override it
    /// to flag connections left in an unusable state.
    fn is_broken(&mut self) -> bool {
        false
    }
}
279
280impl<T> ManageConnection for ConnectionManager<T>
281where
282 T: R2D2Connection + Send + 'static,
283{
284 type Connection = T;
285 type Error = Error;
286
287 fn connect(&self) -> Result<T, Error> {
288 T::establish(&self.database_url).map_err(Error::ConnectionError)
289 }
290
291 fn is_valid(&self, conn: &mut T) -> Result<(), Error> {
292 conn.ping().map_err(Error::QueryError)
293 }
294
295 fn has_broken(&self, conn: &mut T) -> bool {
296 std::thread::panicking() || conn.is_broken()
297 }
298}
299
300impl<M> SimpleConnection for PooledConnection<M>
301where
302 M: ManageConnection,
303 M::Connection: R2D2Connection + Send + 'static,
304{
305 fn batch_execute(&mut self, query: &str) -> QueryResult<()> {
306 (**self).batch_execute(query)
307 }
308}
309
// Empty marker impl: a `PooledConnection<M>` implements `ConnectionSealed`
// whenever the connection it wraps does. The `Connection` impl for
// `PooledConnection` below presumably depends on this (assumes
// `ConnectionSealed` is a supertrait of `Connection` — confirm).
impl<M> ConnectionSealed for PooledConnection<M>
where
    M: ManageConnection,
    M::Connection: ConnectionSealed,
{
}
316
317impl<M> Connection for PooledConnection<M>
318where
319 M: ManageConnection,
320 M::Connection: Connection + R2D2Connection + Send + 'static,
321{
322 type Backend = <M::Connection as Connection>::Backend;
323 type TransactionManager =
324 PoolTransactionManager<<M::Connection as Connection>::TransactionManager>;
325
326 fn establish(_: &str) -> ConnectionResult<Self> {
327 Err(ConnectionError::BadConnection(String::from(
328 "Cannot directly establish a pooled connection",
329 )))
330 }
331
332 fn begin_test_transaction(&mut self) -> QueryResult<()> {
333 (**self).begin_test_transaction()
334 }
335
336 fn execute_returning_count<T>(&mut self, source: &T) -> QueryResult<usize>
337 where
338 T: QueryFragment<Self::Backend> + QueryId,
339 {
340 (**self).execute_returning_count(source)
341 }
342
343 fn transaction_state(
344 &mut self,
345 ) -> &mut <Self::TransactionManager as TransactionManager<Self>>::TransactionStateData {
346 (**self).transaction_state()
347 }
348
349 fn instrumentation(&mut self) -> &mut dyn crate::connection::Instrumentation {
350 (**self).instrumentation()
351 }
352
353 fn set_instrumentation(&mut self, instrumentation: impl crate::connection::Instrumentation) {
354 (**self).set_instrumentation(instrumentation)
355 }
356
357 fn set_prepared_statement_cache_size(&mut self, size: crate::connection::CacheSize) {
358 (**self).set_prepared_statement_cache_size(size)
359 }
360}
361
362impl<B, M> LoadConnection<B> for PooledConnection<M>
363where
364 M: ManageConnection,
365 M::Connection: LoadConnection<B> + R2D2Connection,
366{
367 type Cursor<'conn, 'query> = <M::Connection as LoadConnection<B>>::Cursor<'conn, 'query>;
368 type Row<'conn, 'query> = <M::Connection as LoadConnection<B>>::Row<'conn, 'query>;
369
370 fn load<'conn, 'query, T>(
371 &'conn mut self,
372 source: T,
373 ) -> QueryResult<Self::Cursor<'conn, 'query>>
374 where
375 T: Query + QueryFragment<Self::Backend> + QueryId + 'query,
376 Self::Backend: QueryMetadata<T::SqlType>,
377 {
378 (**self).load(source)
379 }
380}
381
/// The transaction manager used by [`PooledConnection`].
///
/// Every operation is forwarded to `T`, the transaction manager of the
/// connection type the pool wraps. The type carries no data of its own.
#[doc(hidden)]
#[allow(missing_debug_implementations)]
pub struct PoolTransactionManager<T>(std::marker::PhantomData<T>);
385
386impl<M, T> TransactionManager<PooledConnection<M>> for PoolTransactionManager<T>
387where
388 M: ManageConnection,
389 M::Connection: Connection<TransactionManager = T> + R2D2Connection,
390 T: TransactionManager<M::Connection>,
391{
392 type TransactionStateData = T::TransactionStateData;
393
394 fn begin_transaction(conn: &mut PooledConnection<M>) -> QueryResult<()> {
395 T::begin_transaction(&mut **conn)
396 }
397
398 fn rollback_transaction(conn: &mut PooledConnection<M>) -> QueryResult<()> {
399 T::rollback_transaction(&mut **conn)
400 }
401
402 fn commit_transaction(conn: &mut PooledConnection<M>) -> QueryResult<()> {
403 T::commit_transaction(&mut **conn)
404 }
405
406 fn transaction_manager_status_mut(
407 conn: &mut PooledConnection<M>,
408 ) -> &mut TransactionManagerStatus {
409 T::transaction_manager_status_mut(&mut **conn)
410 }
411}
412
413impl<M> crate::migration::MigrationConnection for PooledConnection<M>
414where
415 M: ManageConnection,
416 M::Connection: crate::migration::MigrationConnection,
417 Self: Connection,
418{
419 fn setup(&mut self) -> QueryResult<usize> {
420 (**self).setup()
421 }
422}
423
424impl<Changes, Output, M> crate::query_dsl::UpdateAndFetchResults<Changes, Output>
425 for PooledConnection<M>
426where
427 M: ManageConnection,
428 M::Connection: crate::query_dsl::UpdateAndFetchResults<Changes, Output>,
429 Self: Connection,
430{
431 fn update_and_fetch(&mut self, changeset: Changes) -> QueryResult<Output> {
432 (**self).update_and_fetch(changeset)
433 }
434}
435
/// A trivial query that renders as `SELECT 1` on any backend.
///
/// NOTE(review): presumably used by backend `R2D2Connection::ping`
/// implementations to probe a connection — confirm at the call sites.
#[derive(QueryId)]
pub(crate) struct CheckConnectionQuery;
438
439impl<DB> QueryFragment<DB> for CheckConnectionQuery
440where
441 DB: Backend,
442{
443 fn walk_ast<'b>(
444 &'b self,
445 mut pass: crate::query_builder::AstPass<'_, 'b, DB>,
446 ) -> QueryResult<()> {
447 pass.push_sql("SELECT 1");
448 Ok(())
449 }
450}
451
/// `SELECT 1` yields a single integer column.
impl Query for CheckConnectionQuery {
    type SqlType = crate::sql_types::Integer;
}
455
/// Makes the `RunQueryDsl` methods available on the check query for any
/// connection type `C`.
impl<C> RunQueryDsl<C> for CheckConnectionQuery {}
457
/// A connection customizer intended for tests.
///
/// When installed on a pool, it calls `begin_test_transaction` on every
/// connection the pool acquires (see its `CustomizeConnection` impl).
#[derive(Debug, Clone, Copy)]
pub struct TestCustomizer;
463
464impl<C: Connection> CustomizeConnection<C, crate::r2d2::Error> for TestCustomizer {
465 fn on_acquire(&self, conn: &mut C) -> Result<(), crate::r2d2::Error> {
466 conn.begin_test_transaction()
467 .map_err(crate::r2d2::Error::QueryError)
468 }
469}
470
#[cfg(test)]
mod tests {
    use std::sync::mpsc;
    use std::sync::Arc;
    use std::thread;
    use std::time::Duration;

    use crate::r2d2::*;
    use crate::test_helpers::*;

    // Two threads each check out a connection from a pool of size 2. Each then
    // waits until the other also holds one, so both connections must be live
    // at the same time. Afterwards the pool must still hand out a connection.
    #[diesel_test_helper::test]
    fn establish_basic_connection() {
        let manager = ConnectionManager::<TestConnection>::new(database_url());
        let pool = Arc::new(Pool::builder().max_size(2).build(manager).unwrap());

        let (s1, r1) = mpsc::channel();
        let (s2, r2) = mpsc::channel();

        let pool1 = Arc::clone(&pool);
        let t1 = thread::spawn(move || {
            let conn = pool1.get().unwrap();
            // Announce that thread 1 holds a connection, then wait for thread 2.
            s1.send(()).unwrap();
            r2.recv().unwrap();
            drop(conn);
        });

        let pool2 = Arc::clone(&pool);
        let t2 = thread::spawn(move || {
            let conn = pool2.get().unwrap();
            // Announce that thread 2 holds a connection, then wait for thread 1.
            s2.send(()).unwrap();
            r1.recv().unwrap();
            drop(conn);
        });

        t1.join().unwrap();
        t2.join().unwrap();

        pool.get().unwrap();
    }

    // With `test_on_check_out(true)` the pool validates the connection on
    // checkout, which goes through `ConnectionManager::is_valid` (a ping).
    #[diesel_test_helper::test]
    fn is_valid() {
        let manager = ConnectionManager::<TestConnection>::new(database_url());
        let pool = Pool::builder()
            .max_size(1)
            .test_on_check_out(true)
            .build(manager)
            .unwrap();

        pool.get().unwrap();
    }

    // A `PooledConnection` can be passed directly to query-running methods,
    // because it implements `Connection` itself.
    #[diesel_test_helper::test]
    fn pooled_connection_impls_connection() {
        use crate::select;
        use crate::sql_types::Text;

        let manager = ConnectionManager::<TestConnection>::new(database_url());
        let pool = Pool::builder()
            .max_size(1)
            .test_on_check_out(true)
            .build(manager)
            .unwrap();
        let mut conn = pool.get().unwrap();

        let query = select("foo".into_sql::<Text>());
        assert_eq!("foo", query.get_result::<String>(&mut conn).unwrap());
    }

    // Counts pool events to check three things:
    // - a healthy connection is reused across checkouts;
    // - a connection returned with an open transaction is discarded;
    // - a connection dropped while panicking is discarded.
    #[diesel_test_helper::test]
    fn check_pool_does_actually_hold_connections() {
        use std::sync::atomic::{AtomicU32, Ordering};

        #[derive(Debug)]
        struct TestEventHandler {
            acquire_count: Arc<AtomicU32>,
            release_count: Arc<AtomicU32>,
            checkin_count: Arc<AtomicU32>,
            checkout_count: Arc<AtomicU32>,
        }

        impl r2d2::HandleEvent for TestEventHandler {
            fn handle_acquire(&self, _event: r2d2::event::AcquireEvent) {
                self.acquire_count.fetch_add(1, Ordering::Relaxed);
            }
            fn handle_release(&self, _event: r2d2::event::ReleaseEvent) {
                self.release_count.fetch_add(1, Ordering::Relaxed);
            }
            fn handle_checkout(&self, _event: r2d2::event::CheckoutEvent) {
                self.checkout_count.fetch_add(1, Ordering::Relaxed);
            }
            fn handle_checkin(&self, _event: r2d2::event::CheckinEvent) {
                self.checkin_count.fetch_add(1, Ordering::Relaxed);
            }
        }

        let acquire_count = Arc::new(AtomicU32::new(0));
        let release_count = Arc::new(AtomicU32::new(0));
        let checkin_count = Arc::new(AtomicU32::new(0));
        let checkout_count = Arc::new(AtomicU32::new(0));

        let handler = Box::new(TestEventHandler {
            acquire_count: acquire_count.clone(),
            release_count: release_count.clone(),
            checkin_count: checkin_count.clone(),
            checkout_count: checkout_count.clone(),
        });

        let manager = ConnectionManager::<TestConnection>::new(database_url());
        let pool = Pool::builder()
            .max_size(1)
            .test_on_check_out(true)
            .event_handler(handler)
            .build(manager)
            .unwrap();

        // Building the pool already acquired its single connection.
        assert_eq!(acquire_count.load(Ordering::Relaxed), 1);
        assert_eq!(release_count.load(Ordering::Relaxed), 0);
        assert_eq!(checkin_count.load(Ordering::Relaxed), 0);
        assert_eq!(checkout_count.load(Ordering::Relaxed), 0);

        {
            // A plain checkout/checkin reuses the existing connection:
            // no new acquire, no release.
            let conn = pool.get().unwrap();

            assert_eq!(acquire_count.load(Ordering::Relaxed), 1);
            assert_eq!(release_count.load(Ordering::Relaxed), 0);
            assert_eq!(checkin_count.load(Ordering::Relaxed), 0);
            assert_eq!(checkout_count.load(Ordering::Relaxed), 1);
            std::mem::drop(conn);
        }

        assert_eq!(acquire_count.load(Ordering::Relaxed), 1);
        assert_eq!(release_count.load(Ordering::Relaxed), 0);
        assert_eq!(checkin_count.load(Ordering::Relaxed), 1);
        assert_eq!(checkout_count.load(Ordering::Relaxed), 1);

        {
            // Open a transaction directly on the inner connection and hand the
            // connection back without committing or rolling back.
            let mut conn = pool.get().unwrap();

            assert_eq!(acquire_count.load(Ordering::Relaxed), 1);
            assert_eq!(release_count.load(Ordering::Relaxed), 0);
            assert_eq!(checkin_count.load(Ordering::Relaxed), 1);
            assert_eq!(checkout_count.load(Ordering::Relaxed), 2);

            <TestConnection as Connection>::TransactionManager::begin_transaction(&mut *conn)
                .unwrap();
        }

        // The connection was checked in but then released rather than kept.
        // Presumably `TestConnection`'s `R2D2Connection::is_broken` reports the
        // still-open transaction — confirm in the backend implementation.
        assert_eq!(release_count.load(Ordering::Relaxed), 1);
        assert_eq!(checkin_count.load(Ordering::Relaxed), 2);
        assert_eq!(checkout_count.load(Ordering::Relaxed), 2);

        // Panic while holding a checked-out connection. `ConnectionManager::has_broken`
        // returns true while the thread is panicking, so the pool must release it.
        // `drop(conn)` after `panic!()` is unreachable; the attribute silences
        // the lints that this causes.
        #[allow(unreachable_code, unused_variables)]
        std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
            let conn = pool.get();
            // A second connection has been acquired to replace the one released above.
            assert_eq!(acquire_count.load(Ordering::Relaxed), 2);
            assert_eq!(release_count.load(Ordering::Relaxed), 1);
            assert_eq!(checkin_count.load(Ordering::Relaxed), 2);
            assert_eq!(checkout_count.load(Ordering::Relaxed), 3);
            panic!();
            std::mem::drop(conn);
        }))
        .unwrap_err();

        // The connection dropped during unwinding was checked in and then released.
        assert_eq!(release_count.load(Ordering::Relaxed), 2);
        assert_eq!(checkin_count.load(Ordering::Relaxed), 3);
        assert_eq!(checkout_count.load(Ordering::Relaxed), 3);
        // NOTE(review): the purpose of this sleep is not evident from the test.
        // Presumably it lets r2d2's background work (e.g. replacing the released
        // connection) finish before the pool is dropped — confirm.
        std::thread::sleep(Duration::from_millis(100));
    }

    // A customizer that starts a test transaction on acquire must not break
    // pooled usage. With `max_size(1)`, the second checkout gets the same
    // underlying connection, so the row inserted through the first checkout is
    // still visible.
    #[cfg(feature = "postgres")]
    #[diesel_test_helper::test]
    fn verify_that_begin_test_transaction_works_with_pools() {
        use crate::prelude::*;
        use crate::r2d2::*;

        table! {
            users {
                id -> Integer,
                name -> Text,
            }
        }

        #[derive(Debug)]
        struct TestConnectionCustomizer;

        impl<E> CustomizeConnection<PgConnection, E> for TestConnectionCustomizer {
            fn on_acquire(&self, conn: &mut PgConnection) -> Result<(), E> {
                conn.begin_test_transaction()
                    .expect("Failed to start test transaction");

                Ok(())
            }
        }

        let manager = ConnectionManager::<PgConnection>::new(database_url());
        let pool = Pool::builder()
            .max_size(1)
            .connection_customizer(Box::new(TestConnectionCustomizer))
            .build(manager)
            .unwrap();

        let mut conn = pool.get().unwrap();

        crate::sql_query(
            "CREATE TABLE IF NOT EXISTS users (id SERIAL PRIMARY KEY, name TEXT NOT NULL)",
        )
        .execute(&mut conn)
        .unwrap();

        crate::insert_into(users::table)
            .values(users::name.eq("John"))
            .execute(&mut conn)
            .unwrap();

        std::mem::drop(conn);

        let mut conn2 = pool.get().unwrap();

        let user_count = users::table.count().get_result::<i64>(&mut conn2).unwrap();
        assert_eq!(user_count, 1);
    }
}