diff --git a/redis/src/lib.rs b/redis/src/lib.rs index b4510d2..7872bb7 100644 --- a/redis/src/lib.rs +++ b/redis/src/lib.rs @@ -39,9 +39,44 @@ pub use bb8; pub use redis; use async_trait::async_trait; +use redis::aio::PubSub; use redis::{aio::Connection, ErrorKind}; use redis::{Client, IntoConnectionInfo, RedisError}; +/// A `bb8::ManageConnection` for `redis::aio::PubSub` +#[derive(Clone, Debug)] +pub struct RedisPubSubConnectionManager { + client: Client, +} + +impl RedisPubSubConnectionManager { + /// Create a new `RedisConnectionPubSubManager`. + /// See `redis::Client::open` for a description of the parameter types. + pub fn new(info: T) -> Result { + Ok(Self { + client: Client::open(info.into_connection_info()?)?, + }) + } +} + +#[async_trait] +impl bb8::ManageConnection for RedisPubSubConnectionManager { + type Connection = PubSub; + type Error = RedisError; + + async fn connect(&self) -> Result { + Ok(self.client.get_async_connection().await?.into_pubsub()) + } + + async fn is_valid(&self, conn: &mut Self::Connection) -> Result<(), Self::Error> { + conn.punsubscribe("").await + } + + fn has_broken(&self, _: &mut Self::Connection) -> bool { + false + } +} + /// A `bb8::ManageConnection` for `redis::Client::get_async_connection`. #[derive(Clone, Debug)] pub struct RedisConnectionManager {