From c07bac52800b22ba556455d64fac1f044d1e7466 Mon Sep 17 00:00:00 2001 From: Hiram Chirino Date: Fri, 24 May 2024 09:25:54 -0400 Subject: [PATCH] [distributed store] Batch up updates per session. Signed-off-by: Hiram Chirino --- limitador/src/storage/distributed/grpc/mod.rs | 98 +++++++++++++++++-- limitador/src/storage/distributed/mod.rs | 47 ++++----- 2 files changed, 115 insertions(+), 30 deletions(-) diff --git a/limitador/src/storage/distributed/grpc/mod.rs b/limitador/src/storage/distributed/grpc/mod.rs index 79d3d985..671aa3d6 100644 --- a/limitador/src/storage/distributed/grpc/mod.rs +++ b/limitador/src/storage/distributed/grpc/mod.rs @@ -5,14 +5,14 @@ use std::sync::Arc; use std::time::{Duration, SystemTime, UNIX_EPOCH}; use std::{error::Error, io::ErrorKind, pin::Pin}; -use tokio::sync::mpsc::Sender; -use tokio::sync::{broadcast, mpsc, RwLock}; +use tokio::sync::mpsc::{Permit, Sender}; +use tokio::sync::{broadcast, mpsc, Notify, RwLock}; use tokio::time::sleep; - use tokio_stream::{wrappers::ReceiverStream, Stream, StreamExt}; use tonic::{Code, Request, Response, Status, Streaming}; use tracing::debug; +use crate::storage::distributed::cr_counter_value::CrCounterValue; use crate::storage::distributed::grpc::v1::packet::Message; use crate::storage::distributed::grpc::v1::replication_client::ReplicationClient; use crate::storage::distributed::grpc::v1::replication_server::{Replication, ReplicationServer}; @@ -106,12 +106,46 @@ impl Session { .await?; let mut udpates_to_send = self.broker_state.publisher.subscribe(); + let mut tx_updates_by_key = HashMap::new(); + let mut tx_updates_order = vec![]; + let notifier = Notify::default(); loop { tokio::select! { update = udpates_to_send.recv() => { let update = update.map_err(|_| Status::unknown("broadcast error"))?; - self.send(Message::CounterUpdate(update)).await?; + // Multiple updates collapse into a single update for the same key + if !tx_updates_by_key.contains_key(&update.key) { + tx_updates_by_key.insert(update.key.clone(), update.value); + tx_updates_order.push(update.key); + notifier.notify_one(); + } + } + _ = notifier.notified() => { + // while we have pending updates to send... + while !tx_updates_order.is_empty() { + // and we have space on the transmission channel to send the update... + match self.out_stream.clone().try_reserve() { + Err(_) => { + break + }, + Ok(permit) => { + + let key = tx_updates_order.remove(0); + let cr_counter_value = tx_updates_by_key.remove(&key).unwrap().clone(); + let (expiry, values) = (&*cr_counter_value).clone().into_inner(); + + // only send the update if it has not expired. + if expiry > SystemTime::now() { + permit.send(Message::CounterUpdate(CounterUpdate { + key, + values: values.into_iter().collect(), + expires_at: expiry.duration_since(UNIX_EPOCH).unwrap().as_secs(), + })); + } + } + } + } } result = in_stream.next() => { match result { @@ -315,14 +349,64 @@ impl MessageSender { }, } } + fn try_reserve(self) -> Result, Status> { + match self { + MessageSender::Client(sender) => { + let permit = sender + .clone() + .try_reserve() + .map_err(|_| Status::unknown("send error"))?; + Ok(MessagePermit::Client(permit)) + } + MessageSender::Server(sender) => { + let permit = sender + .clone() + .try_reserve() + .map_err(|_| Status::unknown("send error"))?; + Ok(MessagePermit::Server(permit)) + } + } + } +} + +enum MessagePermit<'a> { + Server(Permit<'a, Result>), + Client(Permit<'a, Packet>), +} +impl<'a> MessagePermit<'a> { + fn send(self, message: Message) { + match self { + MessagePermit::Server(permit) => { + permit.send(Ok(Packet { + message: Some(message), + })); + } + MessagePermit::Client(permit) => { + permit.send(Packet { + message: Some(message), + }); + } + } + } } type CounterUpdateFn = Pin>; +#[derive(Clone, Debug)] +pub struct CounterEntry { + pub key: Vec, + pub value: Arc>, +} + +impl CounterEntry { + pub fn new(key: Vec, value: Arc>) -> Self { + Self { key, value } + } +} #[derive(Clone)] struct BrokerState { id: String, - publisher: broadcast::Sender, + publisher: broadcast::Sender, on_counter_update: Arc, } @@ -342,7 +426,7 @@ impl Broker { on_counter_update: CounterUpdateFn, ) -> Broker { let (tx, _) = broadcast::channel(16); - let publisher: broadcast::Sender = tx; + let publisher: broadcast::Sender = tx; Broker { listen_address, @@ -359,7 +443,7 @@ impl Broker { } } - pub fn publish(&self, counter_update: CounterUpdate) { + pub fn publish(&self, counter_update: CounterEntry) { // ignore the send error, it just means there are no active subscribers _ = self.broker_state.publisher.send(counter_update); } diff --git a/limitador/src/storage/distributed/mod.rs b/limitador/src/storage/distributed/mod.rs index 452b3aa2..4993c563 100644 --- a/limitador/src/storage/distributed/mod.rs +++ b/limitador/src/storage/distributed/mod.rs @@ -10,13 +10,13 @@ use crate::counter::Counter; use crate::limit::{Limit, Namespace}; use crate::storage::distributed::cr_counter_value::CrCounterValue; use crate::storage::distributed::grpc::v1::CounterUpdate; -use crate::storage::distributed::grpc::Broker; +use crate::storage::distributed::grpc::{Broker, CounterEntry}; use crate::storage::{Authorization, CounterStorage, StorageErr}; mod cr_counter_value; mod grpc; -pub type LimitsMap = HashMap, CrCounterValue>; +pub type LimitsMap = HashMap, Arc>>; pub struct CrInMemoryStorage { identifier: String, @@ -42,11 +42,11 @@ impl CounterStorage for CrInMemoryStorage { if limit.variables().is_empty() { let mut limits = self.limits.write().unwrap(); let key = encode_limit_to_key(limit); - limits.entry(key).or_insert(CrCounterValue::new( + limits.entry(key).or_insert(Arc::new(CrCounterValue::new( self.identifier.clone(), limit.max_value(), Duration::from_secs(limit.seconds()), - )); + ))); } Ok(()) } @@ -60,13 +60,16 @@ impl CounterStorage for CrInMemoryStorage { match limits.entry(key.clone()) { Entry::Vacant(entry) => { let duration = counter.window(); - let store_value = - CrCounterValue::new(self.identifier.clone(), counter.max_value(), duration); - self.increment_counter(counter, key, &store_value, delta, now); + let store_value = Arc::new(CrCounterValue::new( + self.identifier.clone(), + counter.max_value(), + duration, + )); + self.increment_counter(counter, key, store_value.clone(), delta, now); entry.insert(store_value); } Entry::Occupied(entry) => { - self.increment_counter(counter, key, entry.get(), delta, now); + self.increment_counter(counter, key, entry.get().clone(), delta, now); } }; Ok(()) @@ -129,11 +132,14 @@ impl CounterStorage for CrInMemoryStorage { if !counter_existed { // try again with a write lock to create the counter if it's still missing. let mut limits = self.limits.write().unwrap(); - let store_value = limits.entry(key.clone()).or_insert(CrCounterValue::new( - self.identifier.clone(), - counter.max_value(), - counter.window(), - )); + let store_value = + limits + .entry(key.clone()) + .or_insert(Arc::new(CrCounterValue::new( + self.identifier.clone(), + counter.max_value(), + counter.window(), + ))); if let Some(limited) = process_counter(counter, store_value.read(), delta) { if !load_counters { @@ -154,7 +160,7 @@ impl CounterStorage for CrInMemoryStorage { .into_iter() .for_each(|(counter, key)| { let store_value = limits.get(&key).unwrap(); - self.increment_counter(&counter, key, store_value, delta, now); + self.increment_counter(&counter, key, store_value.clone(), delta, now); }); Ok(Authorization::Ok) @@ -178,7 +184,7 @@ impl CounterStorage for CrInMemoryStorage { }; if limits.contains(&limit_key) { - let counter = (&counter_key, counter_value); + let counter = (&counter_key, &*counter_value.clone()); let mut counter: Counter = counter.into(); counter.set_remaining(counter.max_value() - counter_value.read()); counter.set_expires_in(counter_value.ttl()); @@ -264,18 +270,13 @@ impl CrInMemoryStorage { &self, counter: &Counter, store_key: Vec, - store_value: &CrCounterValue, + store_value: Arc>, delta: u64, when: SystemTime, ) { store_value.inc_at(delta, counter.window(), when); - - let (expiry, values) = store_value.clone().into_inner(); - self.broker.publish(CounterUpdate { - key: store_key, - values: values.into_iter().collect(), - expires_at: expiry.duration_since(UNIX_EPOCH).unwrap().as_secs(), - }) + self.broker + .publish(CounterEntry::new(store_key, store_value)) } }