1pub use r2d2::*;
96
/// Alias for [`r2d2::Error`], the error type of the r2d2 pool itself.
///
/// This module defines its own [`Error`] enum for connection-manager failures.
/// That local definition shadows the `Error` brought in by `pub use r2d2::*`,
/// so this alias keeps the pool-level error reachable under a distinct name.
/// NOTE(review): presumably returned by pool operations such as `Pool::get`
/// and `Builder::build`; confirm against the r2d2 docs.
pub type PoolError = r2d2::Error;
102
103use std::fmt;
104use std::marker::PhantomData;
105
106use crate::backend::Backend;
107use crate::connection::{
108 ConnectionSealed, LoadConnection, SimpleConnection, TransactionManager,
109 TransactionManagerStatus,
110};
111use crate::expression::QueryMetadata;
112use crate::prelude::*;
113use crate::query_builder::{Query, QueryFragment, QueryId};
114
/// An r2d2 connection manager for use with Diesel.
///
/// `T` is the Diesel connection type the manager opens. The manager itself only
/// stores the database URL; connections are created on demand by the pool.
pub struct ConnectionManager<T> {
    database_url: String,
    // `T` is only used at the type level; no connection value is stored here.
    _marker: PhantomData<T>,
}

// Implemented by hand instead of `#[derive(Clone)]`: the derive would add a
// `T: Clone` bound, and connection types typically are not `Clone`. Cloning a
// manager only needs to duplicate the URL, so no bound on `T` is required.
impl<T> Clone for ConnectionManager<T> {
    fn clone(&self) -> Self {
        ConnectionManager {
            database_url: self.database_url.clone(),
            _marker: PhantomData,
        }
    }
}
123
124impl<T> fmt::Debug for ConnectionManager<T> {
125 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
126 write!(f, "ConnectionManager<{}>", std::any::type_name::<T>())
127 }
128}
129
// SAFETY: `ConnectionManager<T>` stores only a `String` and a `PhantomData<T>`.
// It never holds a `T` value, so sharing `&ConnectionManager<T>` across threads
// cannot give two threads concurrent access to a `T`. The auto-derived `Sync`
// would needlessly require `T: Sync` because of the `PhantomData<T>` field.
// NOTE(review): presumably needed so the manager satisfies r2d2's
// `ManageConnection` bounds for connection types that are not `Sync` — confirm.
#[allow(unsafe_code)] // required for the manual `Sync` impl justified above
unsafe impl<T: Send + 'static> Sync for ConnectionManager<T> {}
132
133impl<T> ConnectionManager<T> {
134 pub fn new<S: Into<String>>(database_url: S) -> Self {
137 ConnectionManager {
138 database_url: database_url.into(),
139 _marker: PhantomData,
140 }
141 }
142
143 pub fn update_database_url<S: Into<String>>(&mut self, database_url: S) {
148 self.database_url = database_url.into();
149 }
150}
151
/// Errors produced by [`ConnectionManager`] while servicing the pool.
#[derive(Debug)]
pub enum Error {
    /// Opening a new connection failed (returned by `connect`).
    ConnectionError(ConnectionError),

    /// The health check (`R2D2Connection::ping`, run by `is_valid`) failed.
    QueryError(crate::result::Error),
}
161
162impl fmt::Display for Error {
163 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
164 match *self {
165 Error::ConnectionError(ref e) => e.fmt(f),
166 Error::QueryError(ref e) => e.fmt(f),
167 }
168 }
169}
170
171impl ::std::error::Error for Error {}
172
/// A Diesel connection that can be managed by [`ConnectionManager`].
pub trait R2D2Connection: Connection {
    /// Checks whether the connection is still usable.
    ///
    /// [`ConnectionManager`] forwards r2d2's `is_valid` check to this method and
    /// wraps any failure in [`Error::QueryError`].
    fn ping(&mut self) -> QueryResult<()>;

    /// Reports whether this connection is broken and must not be reused.
    ///
    /// [`ConnectionManager`] forwards r2d2's `has_broken` check to this method.
    /// The manager also reports every connection as broken while the current
    /// thread is panicking, without consulting this method. The default
    /// implementation never reports a connection as broken.
    fn is_broken(&mut self) -> bool {
        false
    }
}
189
190impl<T> ManageConnection for ConnectionManager<T>
191where
192 T: R2D2Connection + Send + 'static,
193{
194 type Connection = T;
195 type Error = Error;
196
197 fn connect(&self) -> Result<T, Error> {
198 T::establish(&self.database_url).map_err(Error::ConnectionError)
199 }
200
201 fn is_valid(&self, conn: &mut T) -> Result<(), Error> {
202 conn.ping().map_err(Error::QueryError)
203 }
204
205 fn has_broken(&self, conn: &mut T) -> bool {
206 std::thread::panicking() || conn.is_broken()
207 }
208}
209
210impl<M> SimpleConnection for PooledConnection<M>
211where
212 M: ManageConnection,
213 M::Connection: R2D2Connection + Send + 'static,
214{
215 fn batch_execute(&mut self, query: &str) -> QueryResult<()> {
216 (**self).batch_execute(query)
217 }
218}
219
// A pooled connection is "sealed" exactly when the connection it wraps is.
// NOTE(review): presumably required because `Connection` uses `ConnectionSealed`
// as a sealing supertrait — confirm in `crate::connection`.
impl<M> ConnectionSealed for PooledConnection<M>
where
    M: ManageConnection,
    M::Connection: ConnectionSealed,
{
}
226
227impl<M> Connection for PooledConnection<M>
228where
229 M: ManageConnection,
230 M::Connection: Connection + R2D2Connection + Send + 'static,
231{
232 type Backend = <M::Connection as Connection>::Backend;
233 type TransactionManager =
234 PoolTransactionManager<<M::Connection as Connection>::TransactionManager>;
235
236 fn establish(_: &str) -> ConnectionResult<Self> {
237 Err(ConnectionError::BadConnection(String::from(
238 "Cannot directly establish a pooled connection",
239 )))
240 }
241
242 fn begin_test_transaction(&mut self) -> QueryResult<()> {
243 (**self).begin_test_transaction()
244 }
245
246 fn execute_returning_count<T>(&mut self, source: &T) -> QueryResult<usize>
247 where
248 T: QueryFragment<Self::Backend> + QueryId,
249 {
250 (**self).execute_returning_count(source)
251 }
252
253 fn transaction_state(
254 &mut self,
255 ) -> &mut <Self::TransactionManager as TransactionManager<Self>>::TransactionStateData {
256 (**self).transaction_state()
257 }
258
259 fn instrumentation(&mut self) -> &mut dyn crate::connection::Instrumentation {
260 (**self).instrumentation()
261 }
262
263 fn set_instrumentation(&mut self, instrumentation: impl crate::connection::Instrumentation) {
264 (**self).set_instrumentation(instrumentation)
265 }
266}
267
268impl<B, M> LoadConnection<B> for PooledConnection<M>
269where
270 M: ManageConnection,
271 M::Connection: LoadConnection<B> + R2D2Connection,
272{
273 type Cursor<'conn, 'query> = <M::Connection as LoadConnection<B>>::Cursor<'conn, 'query>;
274 type Row<'conn, 'query> = <M::Connection as LoadConnection<B>>::Row<'conn, 'query>;
275
276 fn load<'conn, 'query, T>(
277 &'conn mut self,
278 source: T,
279 ) -> QueryResult<Self::Cursor<'conn, 'query>>
280 where
281 T: Query + QueryFragment<Self::Backend> + QueryId + 'query,
282 Self::Backend: QueryMetadata<T::SqlType>,
283 {
284 (**self).load(source)
285 }
286}
287
/// Transaction manager used by `PooledConnection`.
///
/// `T` is the wrapped connection's own transaction manager; every operation is
/// forwarded to it with the pooled wrapper dereferenced to the inner connection.
#[doc(hidden)]
// The type only holds `PhantomData<T>`, so a `Debug` impl would carry no
// information.
#[allow(missing_debug_implementations)]
pub struct PoolTransactionManager<T>(std::marker::PhantomData<T>);
291
292impl<M, T> TransactionManager<PooledConnection<M>> for PoolTransactionManager<T>
293where
294 M: ManageConnection,
295 M::Connection: Connection<TransactionManager = T> + R2D2Connection,
296 T: TransactionManager<M::Connection>,
297{
298 type TransactionStateData = T::TransactionStateData;
299
300 fn begin_transaction(conn: &mut PooledConnection<M>) -> QueryResult<()> {
301 T::begin_transaction(&mut **conn)
302 }
303
304 fn rollback_transaction(conn: &mut PooledConnection<M>) -> QueryResult<()> {
305 T::rollback_transaction(&mut **conn)
306 }
307
308 fn commit_transaction(conn: &mut PooledConnection<M>) -> QueryResult<()> {
309 T::commit_transaction(&mut **conn)
310 }
311
312 fn transaction_manager_status_mut(
313 conn: &mut PooledConnection<M>,
314 ) -> &mut TransactionManagerStatus {
315 T::transaction_manager_status_mut(&mut **conn)
316 }
317}
318
319impl<M> crate::migration::MigrationConnection for PooledConnection<M>
320where
321 M: ManageConnection,
322 M::Connection: crate::migration::MigrationConnection,
323 Self: Connection,
324{
325 fn setup(&mut self) -> QueryResult<usize> {
326 (**self).setup()
327 }
328}
329
330impl<Changes, Output, M> crate::query_dsl::UpdateAndFetchResults<Changes, Output>
331 for PooledConnection<M>
332where
333 M: ManageConnection,
334 M::Connection: crate::query_dsl::UpdateAndFetchResults<Changes, Output>,
335 Self: Connection,
336{
337 fn update_and_fetch(&mut self, changeset: Changes) -> QueryResult<Output> {
338 (**self).update_and_fetch(changeset)
339 }
340}
341
/// A minimal `SELECT 1` query whose SQL is identical for every backend.
///
/// NOTE(review): presumably used by `R2D2Connection::ping` implementations
/// elsewhere in the crate as a cheap liveness check — confirm at the call sites.
#[derive(QueryId)]
pub(crate) struct CheckConnectionQuery;
344
345impl<DB> QueryFragment<DB> for CheckConnectionQuery
346where
347 DB: Backend,
348{
349 fn walk_ast<'b>(
350 &'b self,
351 mut pass: crate::query_builder::AstPass<'_, 'b, DB>,
352 ) -> QueryResult<()> {
353 pass.push_sql("SELECT 1");
354 Ok(())
355 }
356}
357
// `SELECT 1` yields a single integer column.
impl Query for CheckConnectionQuery {
    type SqlType = crate::sql_types::Integer;
}
361
// Allows running the check query directly (e.g. `.execute(conn)`) on any connection type.
impl<C> RunQueryDsl<C> for CheckConnectionQuery {}
363
#[cfg(test)]
mod tests {
    use std::sync::mpsc;
    use std::sync::Arc;
    use std::thread;

    use crate::r2d2::*;
    use crate::test_helpers::*;

    // Two threads each check out a connection and then wait for the other to do
    // the same. With `max_size(2)` this only completes if both connections are
    // live at the same time.
    #[test]
    fn establish_basic_connection() {
        let manager = ConnectionManager::<TestConnection>::new(database_url());
        let pool = Arc::new(Pool::builder().max_size(2).build(manager).unwrap());

        let (s1, r1) = mpsc::channel();
        let (s2, r2) = mpsc::channel();

        let pool1 = Arc::clone(&pool);
        let t1 = thread::spawn(move || {
            let conn = pool1.get().unwrap();
            s1.send(()).unwrap();
            r2.recv().unwrap();
            drop(conn);
        });

        let pool2 = Arc::clone(&pool);
        let t2 = thread::spawn(move || {
            let conn = pool2.get().unwrap();
            s2.send(()).unwrap();
            r1.recv().unwrap();
            drop(conn);
        });

        t1.join().unwrap();
        t2.join().unwrap();

        // Both connections were checked back in, so the pool can hand one out again.
        pool.get().unwrap();
    }

    // `test_on_check_out(true)` asks r2d2 to validate the connection on checkout,
    // which exercises `ConnectionManager::is_valid` (and so `ping`).
    #[test]
    fn is_valid() {
        let manager = ConnectionManager::<TestConnection>::new(database_url());
        let pool = Pool::builder()
            .max_size(1)
            .test_on_check_out(true)
            .build(manager)
            .unwrap();

        pool.get().unwrap();
    }

    // A `PooledConnection` can be passed directly wherever a Diesel connection
    // is expected (here: `get_result`).
    #[test]
    fn pooled_connection_impls_connection() {
        use crate::select;
        use crate::sql_types::Text;

        let manager = ConnectionManager::<TestConnection>::new(database_url());
        let pool = Pool::builder()
            .max_size(1)
            .test_on_check_out(true)
            .build(manager)
            .unwrap();
        let mut conn = pool.get().unwrap();

        let query = select("foo".into_sql::<Text>());
        assert_eq!("foo", query.get_result::<String>(&mut conn).unwrap());
    }

    // Counts r2d2 pool events to check when connections are reused and when
    // they are discarded (released) and replaced.
    #[test]
    fn check_pool_does_actually_hold_connections() {
        use std::sync::atomic::{AtomicU32, Ordering};

        #[derive(Debug)]
        struct TestEventHandler {
            acquire_count: Arc<AtomicU32>,
            release_count: Arc<AtomicU32>,
            checkin_count: Arc<AtomicU32>,
            checkout_count: Arc<AtomicU32>,
        }

        impl r2d2::HandleEvent for TestEventHandler {
            fn handle_acquire(&self, _event: r2d2::event::AcquireEvent) {
                self.acquire_count.fetch_add(1, Ordering::Relaxed);
            }
            fn handle_release(&self, _event: r2d2::event::ReleaseEvent) {
                self.release_count.fetch_add(1, Ordering::Relaxed);
            }
            fn handle_checkout(&self, _event: r2d2::event::CheckoutEvent) {
                self.checkout_count.fetch_add(1, Ordering::Relaxed);
            }
            fn handle_checkin(&self, _event: r2d2::event::CheckinEvent) {
                self.checkin_count.fetch_add(1, Ordering::Relaxed);
            }
        }

        let acquire_count = Arc::new(AtomicU32::new(0));
        let release_count = Arc::new(AtomicU32::new(0));
        let checkin_count = Arc::new(AtomicU32::new(0));
        let checkout_count = Arc::new(AtomicU32::new(0));

        let handler = Box::new(TestEventHandler {
            acquire_count: acquire_count.clone(),
            release_count: release_count.clone(),
            checkin_count: checkin_count.clone(),
            checkout_count: checkout_count.clone(),
        });

        let manager = ConnectionManager::<TestConnection>::new(database_url());
        let pool = Pool::builder()
            .max_size(1)
            .test_on_check_out(true)
            .event_handler(handler)
            .build(manager)
            .unwrap();

        // Building the pool already opened its single connection.
        assert_eq!(acquire_count.load(Ordering::Relaxed), 1);
        assert_eq!(release_count.load(Ordering::Relaxed), 0);
        assert_eq!(checkin_count.load(Ordering::Relaxed), 0);
        assert_eq!(checkout_count.load(Ordering::Relaxed), 0);

        {
            // Checking out reuses the existing connection; nothing new is acquired.
            let conn = pool.get().unwrap();

            assert_eq!(acquire_count.load(Ordering::Relaxed), 1);
            assert_eq!(release_count.load(Ordering::Relaxed), 0);
            assert_eq!(checkin_count.load(Ordering::Relaxed), 0);
            assert_eq!(checkout_count.load(Ordering::Relaxed), 1);
            std::mem::drop(conn);
        }

        // A clean drop checks the connection back in instead of releasing it.
        assert_eq!(acquire_count.load(Ordering::Relaxed), 1);
        assert_eq!(release_count.load(Ordering::Relaxed), 0);
        assert_eq!(checkin_count.load(Ordering::Relaxed), 1);
        assert_eq!(checkout_count.load(Ordering::Relaxed), 1);

        {
            // Check out the same connection again and leave it inside an open
            // transaction when it goes out of scope.
            let mut conn = pool.get().unwrap();

            assert_eq!(acquire_count.load(Ordering::Relaxed), 1);
            assert_eq!(release_count.load(Ordering::Relaxed), 0);
            assert_eq!(checkin_count.load(Ordering::Relaxed), 1);
            assert_eq!(checkout_count.load(Ordering::Relaxed), 2);

            <TestConnection as Connection>::TransactionManager::begin_transaction(&mut *conn)
                .unwrap();
        }

        // The connection with the dangling transaction was checked in but then
        // released rather than kept. NOTE(review): this relies on
        // `TestConnection`'s `R2D2Connection::is_broken` treating an open
        // transaction as broken — confirm in its impl.
        assert_eq!(release_count.load(Ordering::Relaxed), 1);
        assert_eq!(checkin_count.load(Ordering::Relaxed), 2);
        assert_eq!(checkout_count.load(Ordering::Relaxed), 2);

        // `conn` is only dropped by unwinding: `std::mem::drop(conn)` after
        // `panic!()` is unreachable, and `conn` is otherwise unused, hence the
        // allowed lints. Dropping while `std::thread::panicking()` makes
        // `ConnectionManager::has_broken` return `true`.
        // NOTE(review): `pool.get()` is not unwrapped here, so a checkout error
        // would only surface through the counter assertions below.
        #[allow(unreachable_code, unused_variables)]
        std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
            let conn = pool.get();
            // A replacement connection was acquired for the one released above.
            assert_eq!(acquire_count.load(Ordering::Relaxed), 2);
            assert_eq!(release_count.load(Ordering::Relaxed), 1);
            assert_eq!(checkin_count.load(Ordering::Relaxed), 2);
            assert_eq!(checkout_count.load(Ordering::Relaxed), 3);
            panic!();
            std::mem::drop(conn);
        }))
        .unwrap_err();

        // The connection dropped during the panic was checked in and then
        // released as broken.
        assert_eq!(release_count.load(Ordering::Relaxed), 2);
        assert_eq!(checkin_count.load(Ordering::Relaxed), 3);
        assert_eq!(checkout_count.load(Ordering::Relaxed), 3);
    }

    // A connection customizer opens a test transaction on every newly acquired
    // connection. The insert must still be visible after a check-in/check-out
    // round trip, which shows that the single pooled connection (and its open
    // test transaction) was reused rather than discarded.
    #[cfg(feature = "postgres")]
    #[test]
    fn verify_that_begin_test_transaction_works_with_pools() {
        use crate::prelude::*;
        use crate::r2d2::*;

        table! {
            users {
                id -> Integer,
                name -> Text,
            }
        }

        #[derive(Debug)]
        struct TestConnectionCustomizer;

        impl<E> CustomizeConnection<PgConnection, E> for TestConnectionCustomizer {
            fn on_acquire(&self, conn: &mut PgConnection) -> Result<(), E> {
                conn.begin_test_transaction()
                    .expect("Failed to start test transaction");

                Ok(())
            }
        }

        let manager = ConnectionManager::<PgConnection>::new(database_url());
        let pool = Pool::builder()
            .max_size(1)
            .connection_customizer(Box::new(TestConnectionCustomizer))
            .build(manager)
            .unwrap();

        let mut conn = pool.get().unwrap();

        crate::sql_query(
            "CREATE TABLE IF NOT EXISTS users (id SERIAL PRIMARY KEY, name TEXT NOT NULL)",
        )
        .execute(&mut conn)
        .unwrap();

        crate::insert_into(users::table)
            .values(users::name.eq("John"))
            .execute(&mut conn)
            .unwrap();

        std::mem::drop(conn);

        let mut conn2 = pool.get().unwrap();

        let user_count = users::table.count().get_result::<i64>(&mut conn2).unwrap();
        assert_eq!(user_count, 1);
    }
}