1pub use r2d2::*;
174
/// r2d2's own pool-level error type, re-exported under a distinct name.
///
/// Inside this module the glob-imported `r2d2::Error` is shadowed by the local [`Error`]
/// enum (the connection manager's error). NOTE(review): presumably this alias exists so
/// callers can still name r2d2's error; confirm.
pub type PoolError = r2d2::Error;
180
181use alloc::fmt;
182use core::marker::PhantomData;
183
184use crate::backend::Backend;
185use crate::connection::{
186 ConnectionSealed, LoadConnection, SimpleConnection, TransactionManager,
187 TransactionManagerStatus,
188};
189use crate::expression::QueryMetadata;
190use crate::prelude::*;
191use crate::query_builder::{Query, QueryFragment, QueryId};
192use crate::query_dsl::RunQueryDslSupport;
193
/// An r2d2 connection manager that opens Diesel connections of type `T`.
///
/// Only the database URL is stored. `T` is a pure marker for the connection type; no
/// value of type `T` is ever held by the manager.
pub struct ConnectionManager<T> {
    database_url: String,
    _marker: PhantomData<T>,
}

// Implemented by hand instead of `#[derive(Clone)]`: the derive would add a `T: Clone`
// bound, making the manager un-cloneable whenever the connection type itself is not
// `Clone` (Diesel connection types are not expected to be). Cloning only needs the URL.
impl<T> Clone for ConnectionManager<T> {
    fn clone(&self) -> Self {
        ConnectionManager {
            database_url: self.database_url.clone(),
            _marker: PhantomData,
        }
    }
}
202
203impl<T> fmt::Debug for ConnectionManager<T> {
204 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
205 f.write_fmt(format_args!("ConnectionManager<{0}>",
core::any::type_name::<T>()))write!(f, "ConnectionManager<{}>", core::any::type_name::<T>())
206 }
207}
208
// SAFETY: `ConnectionManager<T>` stores only a `String` and a `PhantomData<T>`. No value of
// type `T` is ever held, so sharing `&ConnectionManager<T>` between threads cannot give
// access to a `T`. Without this impl, the `PhantomData<T>` field would make the type
// `Sync` only when `T: Sync`.
// NOTE(review): presumably needed because r2d2's `ManageConnection` requires `Sync`,
// while connection types are often not `Sync`. Confirm.
#[allow(unsafe_code)] // explicit opt-in; presumably the crate lints against `unsafe_code`
unsafe impl<T: Send + 'static> Sync for ConnectionManager<T> {}
211
212impl<T> ConnectionManager<T> {
213 pub fn new<S: Into<String>>(database_url: S) -> Self {
216 ConnectionManager {
217 database_url: database_url.into(),
218 _marker: PhantomData,
219 }
220 }
221
222 pub fn update_database_url<S: Into<String>>(&mut self, database_url: S) {
227 self.database_url = database_url.into();
228 }
229}
230
231#[derive(#[automatically_derived]
impl ::core::fmt::Debug for Error {
#[inline]
fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
match self {
Error::ConnectionError(__self_0) =>
::core::fmt::Formatter::debug_tuple_field1_finish(f,
"ConnectionError", &__self_0),
Error::QueryError(__self_0) =>
::core::fmt::Formatter::debug_tuple_field1_finish(f,
"QueryError", &__self_0),
}
}
}Debug)]
233pub enum Error {
234 ConnectionError(ConnectionError),
236
237 QueryError(crate::result::Error),
239}
240
241impl fmt::Display for Error {
242 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
243 match *self {
244 Error::ConnectionError(ref e) => e.fmt(f),
245 Error::QueryError(ref e) => e.fmt(f),
246 }
247 }
248}
249
250impl ::core::error::Error for Error {}
251
252impl From<crate::result::Error> for Error {
253 fn from(other: crate::result::Error) -> Self {
254 Self::QueryError(other)
255 }
256}
257
258impl From<ConnectionError> for Error {
259 fn from(other: ConnectionError) -> Self {
260 Self::ConnectionError(other)
261 }
262}
263
/// A [`Connection`] that [`ConnectionManager`] can manage inside an r2d2 pool.
pub trait R2D2Connection: Connection {
    /// Checks that the connection is still usable.
    ///
    /// `ConnectionManager::is_valid` calls this and wraps any error in `Error::QueryError`.
    fn ping(&mut self) -> QueryResult<()>;

    /// Returns `true` if the connection should not be reused.
    ///
    /// `ConnectionManager::has_broken` calls this unless the current thread is already
    /// panicking. NOTE(review): per r2d2's `has_broken` contract, the pool then discards
    /// the connection; presumably this check should therefore be fast and non-blocking.
    ///
    /// The default implementation reports every connection as healthy.
    fn is_broken(&mut self) -> bool {
        false
    }
}
280
281impl<T> ManageConnection for ConnectionManager<T>
282where
283 T: R2D2Connection + Send + 'static,
284{
285 type Connection = T;
286 type Error = Error;
287
288 fn connect(&self) -> Result<T, Error> {
289 T::establish(&self.database_url).map_err(Error::ConnectionError)
290 }
291
292 fn is_valid(&self, conn: &mut T) -> Result<(), Error> {
293 conn.ping().map_err(Error::QueryError)
294 }
295
296 fn has_broken(&self, conn: &mut T) -> bool {
297 std::thread::panicking() || conn.is_broken()
298 }
299}
300
301impl<M> SimpleConnection for PooledConnection<M>
302where
303 M: ManageConnection,
304 M::Connection: R2D2Connection + Send + 'static,
305{
306 fn batch_execute(&mut self, query: &str) -> QueryResult<()> {
307 (**self).batch_execute(query)
308 }
309}
310
311impl<M> ConnectionSealed for PooledConnection<M>
312where
313 M: ManageConnection,
314 M::Connection: ConnectionSealed,
315{
316}
317
318impl<M> Connection for PooledConnection<M>
319where
320 M: ManageConnection,
321 M::Connection: Connection + R2D2Connection + Send + 'static,
322{
323 type Backend = <M::Connection as Connection>::Backend;
324 type TransactionManager =
325 PoolTransactionManager<<M::Connection as Connection>::TransactionManager>;
326
327 fn establish(_: &str) -> ConnectionResult<Self> {
328 Err(ConnectionError::BadConnection(String::from(
329 "Cannot directly establish a pooled connection",
330 )))
331 }
332
333 fn begin_test_transaction(&mut self) -> QueryResult<()> {
334 (**self).begin_test_transaction()
335 }
336
337 fn execute_returning_count<T>(&mut self, source: &T) -> QueryResult<usize>
338 where
339 T: QueryFragment<Self::Backend> + QueryId,
340 {
341 (**self).execute_returning_count(source)
342 }
343
344 fn transaction_state(
345 &mut self,
346 ) -> &mut <Self::TransactionManager as TransactionManager<Self>>::TransactionStateData {
347 (**self).transaction_state()
348 }
349
350 fn instrumentation(&mut self) -> &mut dyn crate::connection::Instrumentation {
351 (**self).instrumentation()
352 }
353
354 fn set_instrumentation(&mut self, instrumentation: impl crate::connection::Instrumentation) {
355 (**self).set_instrumentation(instrumentation)
356 }
357
358 fn set_prepared_statement_cache_size(&mut self, size: crate::connection::CacheSize) {
359 (**self).set_prepared_statement_cache_size(size)
360 }
361}
362
363impl<B, M> LoadConnection<B> for PooledConnection<M>
364where
365 M: ManageConnection,
366 M::Connection: LoadConnection<B> + R2D2Connection,
367{
368 type Cursor<'conn, 'query> = <M::Connection as LoadConnection<B>>::Cursor<'conn, 'query>;
369 type Row<'conn, 'query> = <M::Connection as LoadConnection<B>>::Row<'conn, 'query>;
370
371 fn load<'conn, 'query, T>(
372 &'conn mut self,
373 source: T,
374 ) -> QueryResult<Self::Cursor<'conn, 'query>>
375 where
376 T: Query + QueryFragment<Self::Backend> + QueryId + 'query,
377 Self::Backend: QueryMetadata<T::SqlType>,
378 {
379 (**self).load(source)
380 }
381}
382
/// Zero-sized transaction manager used by pooled connections.
///
/// `T` is the transaction manager type of the connection the pool wraps.
#[doc(hidden)]
// Pure type-level marker with nothing worth printing, so it deliberately has no `Debug` impl.
#[allow(missing_debug_implementations)]
pub struct PoolTransactionManager<T>(PhantomData<T>);
386
387impl<M, T> TransactionManager<PooledConnection<M>> for PoolTransactionManager<T>
388where
389 M: ManageConnection,
390 M::Connection: Connection<TransactionManager = T> + R2D2Connection,
391 T: TransactionManager<M::Connection>,
392{
393 type TransactionStateData = T::TransactionStateData;
394
395 fn begin_transaction(conn: &mut PooledConnection<M>) -> QueryResult<()> {
396 T::begin_transaction(&mut **conn)
397 }
398
399 fn rollback_transaction(conn: &mut PooledConnection<M>) -> QueryResult<()> {
400 T::rollback_transaction(&mut **conn)
401 }
402
403 fn commit_transaction(conn: &mut PooledConnection<M>) -> QueryResult<()> {
404 T::commit_transaction(&mut **conn)
405 }
406
407 fn transaction_manager_status_mut(
408 conn: &mut PooledConnection<M>,
409 ) -> &mut TransactionManagerStatus {
410 T::transaction_manager_status_mut(&mut **conn)
411 }
412}
413
414impl<M> crate::migration::MigrationConnection for PooledConnection<M>
415where
416 M: ManageConnection,
417 M::Connection: crate::migration::MigrationConnection,
418 Self: Connection,
419{
420 fn setup(&mut self) -> QueryResult<usize> {
421 (**self).setup()
422 }
423}
424
425impl<Changes, Output, M> crate::query_dsl::UpdateAndFetchResults<Changes, Output>
426 for PooledConnection<M>
427where
428 M: ManageConnection,
429 M::Connection: crate::query_dsl::UpdateAndFetchResults<Changes, Output>,
430 Self: Connection,
431{
432 fn update_and_fetch(&mut self, changeset: Changes) -> QueryResult<Output> {
433 (**self).update_and_fetch(changeset)
434 }
435}
436
437#[derive(const _: () =
{
use diesel;
#[allow(non_camel_case_types)]
impl diesel::query_builder::QueryId for CheckConnectionQuery {
type QueryId = CheckConnectionQuery<>;
const HAS_STATIC_QUERY_ID: bool = true;
const IS_WINDOW_FUNCTION: bool = false;
}
};QueryId)]
438pub(crate) struct CheckConnectionQuery;
439
440impl<DB> QueryFragment<DB> for CheckConnectionQuery
441where
442 DB: Backend,
443{
444 fn walk_ast<'b>(
445 &'b self,
446 mut pass: crate::query_builder::AstPass<'_, 'b, DB>,
447 ) -> QueryResult<()> {
448 pass.push_sql("SELECT 1");
449 Ok(())
450 }
451}
452
453impl Query for CheckConnectionQuery {
454 type SqlType = crate::sql_types::Integer;
455}
456
457impl RunQueryDslSupport for CheckConnectionQuery {}
458
/// A pool connection customizer for tests.
///
/// Its `CustomizeConnection::on_acquire` calls `begin_test_transaction` on every
/// connection the pool acquires. Diesel documents that such a transaction is never
/// committed.
#[derive(Debug, Clone, Copy)]
pub struct TestCustomizer;
464
465impl<C: Connection> CustomizeConnection<C, crate::r2d2::Error> for TestCustomizer {
466 fn on_acquire(&self, conn: &mut C) -> Result<(), crate::r2d2::Error> {
467 conn.begin_test_transaction()
468 .map_err(crate::r2d2::Error::QueryError)
469 }
470}
471
// Integration tests for the r2d2 integration. They need a live database reachable via
// `database_url()`, so they are skipped on wasm32-unknown-unknown.
#[cfg(all(test, not(all(target_family = "wasm", target_os = "unknown"))))]
mod tests {
    use std::sync::Arc;
    use std::sync::mpsc;
    use std::thread;
    use std::time::Duration;

    use crate::r2d2::*;
    use crate::test_helpers::*;

    // Two threads each check out a connection from a `max_size(2)` pool. Each thread then
    // signals the other and waits for the other's signal before dropping its connection.
    // The test therefore only completes if both connections can be checked out at the
    // same time. Afterwards the pool must still hand out a connection.
    #[diesel_test_helper::test]
    fn establish_basic_connection() {
        let manager = ConnectionManager::<TestConnection>::new(database_url());
        let pool = Arc::new(Pool::builder().max_size(2).build(manager).unwrap());

        let (s1, r1) = mpsc::channel();
        let (s2, r2) = mpsc::channel();

        let pool1 = Arc::clone(&pool);
        let t1 = thread::spawn(move || {
            let conn = pool1.get().unwrap();
            s1.send(()).unwrap();
            r2.recv().unwrap();
            drop(conn);
        });

        let pool2 = Arc::clone(&pool);
        let t2 = thread::spawn(move || {
            let conn = pool2.get().unwrap();
            s2.send(()).unwrap();
            r1.recv().unwrap();
            drop(conn);
        });

        t1.join().unwrap();
        t2.join().unwrap();

        pool.get().unwrap();
    }

    // With `test_on_check_out(true)`, r2d2 validates a connection via
    // `ManageConnection::is_valid` (here: `R2D2Connection::ping`) before handing it out.
    // The checkout must succeed.
    #[diesel_test_helper::test]
    fn is_valid() {
        let manager = ConnectionManager::<TestConnection>::new(database_url());
        let pool = Pool::builder()
            .max_size(1)
            .test_on_check_out(true)
            .build(manager)
            .unwrap();

        pool.get().unwrap();
    }

    // A `PooledConnection` can be passed directly (as `&mut conn`) to query-execution
    // methods, through its `Connection`/`LoadConnection` impls.
    #[diesel_test_helper::test]
    fn pooled_connection_impls_connection() {
        use crate::select;
        use crate::sql_types::Text;

        let manager = ConnectionManager::<TestConnection>::new(database_url());
        let pool = Pool::builder()
            .max_size(1)
            .test_on_check_out(true)
            .build(manager)
            .unwrap();
        let mut conn = pool.get().unwrap();

        let query = select("foo".into_sql::<Text>());
        assert_eq!("foo", query.get_result::<String>(&mut conn).unwrap());
    }

    // Tracks r2d2 pool events (acquire/release of physical connections, checkout/checkin
    // of pooled handles) to verify when a returned connection is kept for reuse and when
    // it is released.
    #[diesel_test_helper::test]
    fn check_pool_does_actually_hold_connections() {
        use std::sync::atomic::{AtomicU32, Ordering};

        // Event handler that only counts how often each r2d2 event fires.
        #[derive(Debug)]
        struct TestEventHandler {
            acquire_count: Arc<AtomicU32>,
            release_count: Arc<AtomicU32>,
            checkin_count: Arc<AtomicU32>,
            checkout_count: Arc<AtomicU32>,
        }

        impl r2d2::HandleEvent for TestEventHandler {
            fn handle_acquire(&self, _event: r2d2::event::AcquireEvent) {
                self.acquire_count.fetch_add(1, Ordering::Relaxed);
            }
            fn handle_release(&self, _event: r2d2::event::ReleaseEvent) {
                self.release_count.fetch_add(1, Ordering::Relaxed);
            }
            fn handle_checkout(&self, _event: r2d2::event::CheckoutEvent) {
                self.checkout_count.fetch_add(1, Ordering::Relaxed);
            }
            fn handle_checkin(&self, _event: r2d2::event::CheckinEvent) {
                self.checkin_count.fetch_add(1, Ordering::Relaxed);
            }
        }

        let acquire_count = Arc::new(AtomicU32::new(0));
        let release_count = Arc::new(AtomicU32::new(0));
        let checkin_count = Arc::new(AtomicU32::new(0));
        let checkout_count = Arc::new(AtomicU32::new(0));

        let handler = Box::new(TestEventHandler {
            acquire_count: acquire_count.clone(),
            release_count: release_count.clone(),
            checkin_count: checkin_count.clone(),
            checkout_count: checkout_count.clone(),
        });

        let manager = ConnectionManager::<TestConnection>::new(database_url());
        let pool = Pool::builder()
            .max_size(1)
            .test_on_check_out(true)
            .event_handler(handler)
            .build(manager)
            .unwrap();

        // By the time `build` returns, the single connection has already been acquired.
        assert_eq!(acquire_count.load(Ordering::Relaxed), 1);
        assert_eq!(release_count.load(Ordering::Relaxed), 0);
        assert_eq!(checkin_count.load(Ordering::Relaxed), 0);
        assert_eq!(checkout_count.load(Ordering::Relaxed), 0);

        // Case 1: a clean checkout/checkin keeps the connection in the pool.
        {
            let conn = pool.get().unwrap();

            assert_eq!(acquire_count.load(Ordering::Relaxed), 1);
            assert_eq!(release_count.load(Ordering::Relaxed), 0);
            assert_eq!(checkin_count.load(Ordering::Relaxed), 0);
            assert_eq!(checkout_count.load(Ordering::Relaxed), 1);
            std::mem::drop(conn);
        }

        // Checked back in and not released: the same physical connection is retained.
        assert_eq!(acquire_count.load(Ordering::Relaxed), 1);
        assert_eq!(release_count.load(Ordering::Relaxed), 0);
        assert_eq!(checkin_count.load(Ordering::Relaxed), 1);
        assert_eq!(checkout_count.load(Ordering::Relaxed), 1);

        // Case 2: the connection is returned while a transaction is still open.
        {
            let mut conn = pool.get().unwrap();

            assert_eq!(acquire_count.load(Ordering::Relaxed), 1);
            assert_eq!(release_count.load(Ordering::Relaxed), 0);
            assert_eq!(checkin_count.load(Ordering::Relaxed), 1);
            assert_eq!(checkout_count.load(Ordering::Relaxed), 2);

            // Begin a transaction on the underlying connection and never finish it.
            <TestConnection as Connection>::TransactionManager::begin_transaction(&mut *conn)
                .unwrap();
        }

        // The connection was released rather than kept. `has_broken` defers to
        // `R2D2Connection::is_broken` when not panicking. NOTE(review): presumably
        // `TestConnection::is_broken` reports an open transaction as broken; confirm in the
        // backend impls. `acquire_count` is not asserted here; presumably the replacement
        // connection may be acquired asynchronously by the pool.
        assert_eq!(release_count.load(Ordering::Relaxed), 1);
        assert_eq!(checkin_count.load(Ordering::Relaxed), 2);
        assert_eq!(checkout_count.load(Ordering::Relaxed), 2);

        // Case 3: the connection is dropped while the thread is unwinding from a panic.
        // The trailing `drop(conn)` is unreachable after `panic!()`; the attribute silences
        // the resulting lints.
        #[allow(unreachable_code, unused_variables)]
        std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
            let conn = pool.get();
            // A second physical connection had to be acquired to replace the released one.
            assert_eq!(acquire_count.load(Ordering::Relaxed), 2);
            assert_eq!(release_count.load(Ordering::Relaxed), 1);
            assert_eq!(checkin_count.load(Ordering::Relaxed), 2);
            assert_eq!(checkout_count.load(Ordering::Relaxed), 3);
            panic!();
            std::mem::drop(conn);
        }))
        .unwrap_err();

        // `has_broken` returns `true` whenever `std::thread::panicking()` is true, so the
        // connection dropped during unwinding was released as well.
        assert_eq!(release_count.load(Ordering::Relaxed), 2);
        assert_eq!(checkin_count.load(Ordering::Relaxed), 3);
        assert_eq!(checkout_count.load(Ordering::Relaxed), 3);
        // NOTE(review): the reason for this sleep is not visible here. Presumably it lets
        // the pool's background work settle before the pool is dropped; confirm or remove.
        std::thread::sleep(Duration::from_millis(100));
    }

    // A customizer that begins a test transaction on acquire must keep working across
    // checkouts of a single-connection pool.
    #[cfg(feature = "postgres")]
    #[diesel_test_helper::test]
    fn verify_that_begin_test_transaction_works_with_pools() {
        use crate::prelude::*;
        use crate::r2d2::*;

        table! {
            users {
                id -> Integer,
                name -> Text,
            }
        }

        // Starts a test transaction on every acquired connection; panics if that fails.
        #[derive(Debug)]
        struct TestConnectionCustomizer;

        impl<E> CustomizeConnection<PgConnection, E> for TestConnectionCustomizer {
            fn on_acquire(&self, conn: &mut PgConnection) -> Result<(), E> {
                conn.begin_test_transaction()
                    .expect("Failed to start test transaction");

                Ok(())
            }
        }

        let manager = ConnectionManager::<PgConnection>::new(database_url());
        let pool = Pool::builder()
            .max_size(1)
            .connection_customizer(Box::new(TestConnectionCustomizer))
            .build(manager)
            .unwrap();

        let mut conn = pool.get().unwrap();

        crate::sql_query(
            "CREATE TABLE IF NOT EXISTS users (id SERIAL PRIMARY KEY, name TEXT NOT NULL)",
        )
        .execute(&mut conn)
        .unwrap();

        crate::insert_into(users::table)
            .values(users::name.eq("John"))
            .execute(&mut conn)
            .unwrap();

        std::mem::drop(conn);

        // The row inserted through the first checkout must be visible through the second.
        // A test transaction is never committed, so this only holds if the pool (max size
        // 1) hands back the same underlying connection, still inside its test transaction.
        let mut conn2 = pool.get().unwrap();

        let user_count = users::table.count().get_result::<i64>(&mut conn2).unwrap();
        assert_eq!(user_count, 1);
    }
}