diff --git a/bb8/src/api.rs b/bb8/src/api.rs index 1f2d598..857f592 100644 --- a/bb8/src/api.rs +++ b/bb8/src/api.rs @@ -75,6 +75,14 @@ impl Pool { pub fn state(&self) -> State { self.inner.state() } + + /// Adds a connection to the pool. + /// + /// If the connection is broken, or the pool is at capacity, the + /// connection is not added and instead returned to the caller in Err. + pub fn add(&self, conn: M::Connection) -> Result<(), AddError> { + self.inner.try_put(conn) + } } /// Information about the state of a `Pool`. @@ -526,6 +534,33 @@ where } } +/// Error type returned by `Pool::add(conn)` +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum AddError { + /// The connection was broken before it could be added. + Broken(C), + /// Unable to add the connection to the pool due to insufficient capacity. + NoCapacity(C), +} + +impl fmt::Display for AddError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match *self { + AddError::Broken(_) => write!(f, "The connection was broken before it could be added"), + AddError::NoCapacity(_) => write!( + f, + "Unable to add the connection to the pool due to insufficient capacity" + ), + } + } +} + +impl error::Error for AddError { + fn source(&self) -> Option<&(dyn error::Error + 'static)> { + None + } +} + /// A trait to receive errors generated by connection management that aren't /// tied to any particular caller. pub trait ErrorSink: fmt::Debug + Send + Sync + 'static { diff --git a/bb8/src/inner.rs b/bb8/src/inner.rs index 209d69f..be2ef52 100644 --- a/bb8/src/inner.rs +++ b/bb8/src/inner.rs @@ -9,7 +9,9 @@ use futures_util::TryFutureExt; use tokio::spawn; use tokio::time::{interval_at, sleep, timeout, Interval}; -use crate::api::{Builder, ConnectionState, ManageConnection, PooledConnection, RunError, State}; +use crate::api::{ + AddError, Builder, ConnectionState, ManageConnection, PooledConnection, RunError, State, +}; use crate::internals::{Approval, ApprovalIter, Conn, SharedPool, StatsGetKind, StatsKind}; pub(crate) struct PoolInner @@ -161,6 +163,15 @@ where } } + /// Adds an external connection to the pool if there is capacity for it. + pub(crate) fn try_put(&self, mut conn: M::Connection) -> Result<(), AddError> { + if self.inner.manager.has_broken(&mut conn) { + Err(AddError::Broken(conn)) + } else { + self.inner.try_put(conn).map_err(AddError::NoCapacity) + } + } + /// Returns information about the current state of the pool. pub(crate) fn state(&self) -> State { self.inner diff --git a/bb8/src/internals.rs b/bb8/src/internals.rs index 81fefab..155e21a 100644 --- a/bb8/src/internals.rs +++ b/bb8/src/internals.rs @@ -47,6 +47,17 @@ where (conn, approvals) } + pub(crate) fn try_put(self: &Arc, conn: M::Connection) -> Result<(), M::Connection> { + let mut locked = self.internals.lock(); + let mut approvals = locked.approvals(&self.statics, 1); + let Some(approval) = approvals.next() else { + return Err(conn); + }; + let conn = Conn::new(conn); + locked.put(conn, Some(approval), self.clone()); + Ok(()) + } + pub(crate) fn reap(&self) -> ApprovalIter { let mut locked = self.internals.lock(); let (iter, closed_idle_timeout, closed_max_lifetime) = locked.reap(&self.statics); diff --git a/bb8/src/lib.rs b/bb8/src/lib.rs index 5642cba..df3de74 100644 --- a/bb8/src/lib.rs +++ b/bb8/src/lib.rs @@ -35,7 +35,7 @@ mod api; pub use api::{ - Builder, CustomizeConnection, ErrorSink, ManageConnection, NopErrorSink, Pool, + AddError, Builder, CustomizeConnection, ErrorSink, ManageConnection, NopErrorSink, Pool, PooledConnection, QueueStrategy, RunError, State, Statistics, }; diff --git a/bb8/tests/test.rs b/bb8/tests/test.rs index 178e8d3..0e1225e 100644 --- a/bb8/tests/test.rs +++ b/bb8/tests/test.rs @@ -1020,3 +1020,51 @@ async fn test_statistics_connections_created() { assert_eq!(pool.state().statistics.connections_created, 1); } + +#[tokio::test] +async fn test_can_use_added_connections() { + let pool = Pool::builder() + .connection_timeout(Duration::from_millis(1)) + .build_unchecked(NthConnectionFailManager::::new(0)); + + // Assert pool can't replenish connections on its own + let res = pool.get().await; + assert_eq!(res.unwrap_err(), RunError::TimedOut); + + pool.add(FakeConnection).unwrap(); + let res = pool.get().await; + assert!(res.is_ok()); +} + +#[tokio::test] +async fn test_add_ok_until_max_size() { + let pool = Pool::builder() + .min_idle(1) + .max_size(3) + .build(OkManager::::new()) + .await + .unwrap(); + + for _ in 0..2 { + let conn = pool.dedicated_connection().await.unwrap(); + pool.add(conn).unwrap(); + } + + let conn = pool.dedicated_connection().await.unwrap(); + let res = pool.add(conn); + assert!(matches!(res, Err(AddError::NoCapacity(_)))); +} + +#[tokio::test] +async fn test_add_checks_broken_connections() { + let pool = Pool::builder() + .min_idle(1) + .max_size(3) + .build(BrokenConnectionManager::::default()) + .await + .unwrap(); + + let conn = pool.dedicated_connection().await.unwrap(); + let res = pool.add(conn); + assert!(matches!(res, Err(AddError::Broken(_)))); +}