diff --git a/bb8/src/inner.rs b/bb8/src/inner.rs index d51d2f8..2dbf6af 100644 --- a/bb8/src/inner.rs +++ b/bb8/src/inner.rs @@ -1,7 +1,7 @@ use std::cmp::{max, min}; use std::fmt; use std::future::Future; -use std::sync::atomic::{AtomicU64, Ordering}; +use std::sync::atomic::Ordering; use std::sync::{Arc, Weak}; use std::time::{Duration, Instant}; @@ -11,14 +11,13 @@ use tokio::spawn; use tokio::time::{interval_at, sleep, timeout, Interval}; use crate::api::{Builder, ConnectionState, ManageConnection, PooledConnection, RunError, State}; -use crate::internals::{Approval, ApprovalIter, Conn, SharedPool}; +use crate::internals::{Approval, ApprovalIter, AtomicStatistics, Conn, SharedPool}; pub(crate) struct PoolInner where M: ManageConnection + Send, { inner: Arc>, - pool_inner_stats: Arc, } impl PoolInner @@ -27,7 +26,6 @@ where { pub(crate) fn new(builder: Builder, manager: M) -> Self { let inner = Arc::new(SharedPool::new(builder, manager)); - let pool_inner_stats = Arc::new(SharedPoolInnerStatistics::new()); if inner.statics.max_lifetime.is_some() || inner.statics.idle_timeout.is_some() { let start = Instant::now() + inner.statics.reaper_rate; @@ -36,16 +34,12 @@ where Reaper { interval, pool: Arc::downgrade(&inner), - pool_inner_stats: Arc::downgrade(&pool_inner_stats), } .run(), ); } - Self { - inner, - pool_inner_stats, - } + Self { inner } } pub(crate) async fn start_connections(&self) -> Result<(), M::Error> { @@ -92,7 +86,7 @@ where } pub(crate) async fn get(&self) -> Result, RunError> { - let mut with_contention = false; + let mut get_direct = true; let future = async { loop { @@ -105,7 +99,7 @@ where let mut conn = match conn { Some(conn) => PooledConnection::new(self, conn), None => { - with_contention = true; + get_direct = false; self.inner.notify.notified().await; continue; } @@ -126,14 +120,29 @@ where } }; - let result = match timeout(self.inner.statics.connection_timeout, future).await { - Ok(result) => result, - _ => Err(RunError::TimedOut), - }; - - self.pool_inner_stats.record_get(with_contention); - - result + match timeout(self.inner.statics.connection_timeout, future).await { + Ok(result) => { + if get_direct { + self.inner + .statistics + .get_direct + .fetch_add(1, Ordering::SeqCst); + } else { + self.inner + .statistics + .get_waited + .fetch_add(1, Ordering::SeqCst); + } + result + } + _ => { + self.inner + .statistics + .get_timed_out + .fetch_add(1, Ordering::SeqCst); + Err(RunError::TimedOut) + } + } } pub(crate) async fn connect(&self) -> Result { @@ -162,10 +171,7 @@ where /// Returns statistics about the historical usage of the pool. pub(crate) fn statistics(&self) -> Statistics { - let gets = self.pool_inner_stats.gets.load(Ordering::SeqCst); - let gets_waited = self.pool_inner_stats.gets_waited.load(Ordering::SeqCst); - - Statistics { gets, gets_waited } + (&(self.inner.statistics)).into() } /// Returns information about the current state of the pool. @@ -234,7 +240,6 @@ where fn clone(&self) -> Self { PoolInner { inner: self.inner.clone(), - pool_inner_stats: self.pool_inner_stats.clone(), } } } @@ -251,7 +256,6 @@ where struct Reaper { interval: Interval, pool: Weak>, - pool_inner_stats: Weak, } impl Reaper { @@ -259,10 +263,7 @@ impl Reaper { loop { let _ = self.interval.tick().await; let pool = match self.pool.upgrade() { - Some(inner) => PoolInner { - inner, - pool_inner_stats: self.pool_inner_stats.upgrade().unwrap(), - }, + Some(inner) => PoolInner { inner }, None => break, }; @@ -272,38 +273,25 @@ impl Reaper { } } -struct SharedPoolInnerStatistics { - gets: AtomicU64, - gets_waited: AtomicU64, -} - -impl SharedPoolInnerStatistics { - fn new() -> Self { - Self { - gets: AtomicU64::new(0), - gets_waited: AtomicU64::new(0), - } - } - - fn record_get(&self, with_contention: bool) { - self.gets.fetch_add(1, Ordering::SeqCst); - - if with_contention { - self.gets_waited.fetch_add(1, Ordering::SeqCst); - } - } -} - /// Statistics about the historical usage of the `Pool`. #[derive(Debug)] #[non_exhaustive] pub struct Statistics { - /// Information about gets - /// Total gets performed, you should consider that the - /// value can overflow and start from 0 eventually. - pub gets: u64, - /// Total gets performed that had to wait for having a - /// connection available. The value can overflow and - /// start from 0 eventually. - pub gets_waited: u64, + /// Information about gets. + /// Total gets performed that did not have to wait for a connection. + pub get_direct: u64, + /// Total gets performed that had to wait for a connection available. + pub get_waited: u64, + /// Total gets performed that timed out while waiting for a connection. + pub get_timed_out: u64, +} + +impl From<&AtomicStatistics> for Statistics { + fn from(item: &AtomicStatistics) -> Self { + Statistics { + get_direct: item.get_direct.load(Ordering::SeqCst), + get_waited: item.get_waited.load(Ordering::SeqCst), + get_timed_out: item.get_timed_out.load(Ordering::SeqCst), + } + } } diff --git a/bb8/src/internals.rs b/bb8/src/internals.rs index ed153f0..1cd9039 100644 --- a/bb8/src/internals.rs +++ b/bb8/src/internals.rs @@ -1,4 +1,5 @@ use std::cmp::min; +use std::sync::atomic::AtomicU64; use std::sync::Arc; use std::time::Instant; @@ -18,6 +19,7 @@ where pub(crate) manager: M, pub(crate) internals: Mutex>, pub(crate) notify: Arc, + pub(crate) statistics: AtomicStatistics, } impl SharedPool @@ -30,6 +32,7 @@ where manager, internals: Mutex::new(PoolInternals::default()), notify: Arc::new(Notify::new()), + statistics: AtomicStatistics::default(), } } @@ -246,6 +249,13 @@ impl From> for IdleConn { } } +#[derive(Default)] +pub(crate) struct AtomicStatistics { + pub(crate) get_direct: AtomicU64, + pub(crate) get_waited: AtomicU64, + pub(crate) get_timed_out: AtomicU64, +} + /// Information about the state of a `Pool`. #[derive(Debug)] #[non_exhaustive] diff --git a/bb8/tests/test.rs b/bb8/tests/test.rs index 82c28f7..661144c 100644 --- a/bb8/tests/test.rs +++ b/bb8/tests/test.rs @@ -313,6 +313,10 @@ async fn test_get_timeout() { tx2.send(()).unwrap(); let r: Result<(), ()> = Ok(()); ready(r).await.unwrap(); + + // check that the timed out was tracked + let statistics = pool.statistics(); + assert_eq!(statistics.get_timed_out, 1); } #[tokio::test] @@ -887,7 +891,7 @@ async fn test_broken_connections_dont_starve_pool() { } #[tokio::test] -async fn test_statistics_get_contention() { +async fn test_statistics_get_waited() { let pool = Pool::builder() .max_size(1) .min_idle(1) @@ -922,6 +926,6 @@ async fn test_statistics_get_contention() { f.await.unwrap(); let statistics = pool.statistics(); - assert_eq!(statistics.gets, 2); - assert_eq!(statistics.gets_waited, 1); + assert_eq!(statistics.get_direct, 1); + assert_eq!(statistics.get_waited, 1); }