From b0b979d58ce25ddbf536dfff836f3561e1d9b621 Mon Sep 17 00:00:00 2001 From: Hiram Chirino Date: Wed, 12 Jun 2024 08:35:43 -0400 Subject: [PATCH 1/2] Support providing an optional id to limits/counters MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * add a key_for_counters_v2 function that uses the id as the key if set, otherwise uses the previous key encoding strategy. * updated the distributed store to use key_for_counters_v2. Since we can’t decode a partial counter from id based keys, we now also keep in memory the Counter in a counter field of the limits map. * Use the shorter binary encoding for limits/counters when the id is set, continue to use the older text encoding for backward compatibility when the id is not set. Signed-off-by: Hiram Chirino --- .github/workflows/rust.yml | 2 +- .../src/http_api/request_types.rs | 6 + limitador/src/counter.rs | 4 + limitador/src/limit.rs | 25 ++ limitador/src/storage/distributed/grpc/mod.rs | 27 ++- limitador/src/storage/distributed/mod.rs | 184 +++++---------- limitador/src/storage/keys.rs | 218 ++++++++++++++---- limitador/src/storage/redis/redis_async.rs | 10 +- limitador/src/storage/redis/redis_sync.rs | 12 +- limitador/tests/integration_tests.rs | 38 +++ 10 files changed, 332 insertions(+), 194 deletions(-) diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index df8931c7..a64392ba 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -22,7 +22,7 @@ jobs: - uses: abelfodil/protoc-action@v1 with: protoc-version: '3.19.4' - - run: cargo check + - run: cargo check --all-features test: name: Test Suite diff --git a/limitador-server/src/http_api/request_types.rs b/limitador-server/src/http_api/request_types.rs index f8d7dc45..9751ee2e 100644 --- a/limitador-server/src/http_api/request_types.rs +++ b/limitador-server/src/http_api/request_types.rs @@ -18,6 +18,7 @@ pub struct CheckAndReportInfo { #[derive(Debug, Eq, PartialEq, Serialize, Deserialize, Apiv2Schema)] pub struct Limit { + id: Option, namespace: String, max_value: u64, seconds: u64, @@ -29,6 +30,7 @@ pub struct Limit { impl From<&LimitadorLimit> for Limit { fn from(ll: &LimitadorLimit) -> Self { Self { + id: ll.id().clone(), namespace: ll.namespace().as_ref().to_string(), max_value: ll.max_value(), seconds: ll.seconds(), @@ -49,6 +51,10 @@ impl From for LimitadorLimit { limit.variables, ); + if let Some(id) = limit.id { + limitador_limit.set_id(id); + } + if let Some(name) = limit.name { limitador_limit.set_name(name) } diff --git a/limitador/src/counter.rs b/limitador/src/counter.rs index 5f5bac49..0b55bd6c 100644 --- a/limitador/src/counter.rs +++ b/limitador/src/counter.rs @@ -72,6 +72,10 @@ impl Counter { Duration::from_secs(self.limit.seconds()) } + pub fn id(&self) -> &Option { + self.limit.id() + } + pub fn namespace(&self) -> &Namespace { self.limit.namespace() } diff --git a/limitador/src/limit.rs b/limitador/src/limit.rs index 12adb7ff..fa0e124c 100644 --- a/limitador/src/limit.rs +++ b/limitador/src/limit.rs @@ -51,6 +51,8 @@ impl From for Namespace { #[derive(Eq, Debug, Clone, Serialize, Deserialize)] pub struct Limit { + #[serde(skip_serializing, default)] + id: Option, namespace: Namespace, #[serde(skip_serializing, default)] max_value: u64, @@ -319,6 +321,7 @@ impl Limit { { // the above where-clause is needed in order to call unwrap(). Self { + id: None, namespace: namespace.into(), max_value, seconds, @@ -335,6 +338,14 @@ impl Limit { &self.namespace } + pub fn set_id(&mut self, value: String) { + self.id = Some(value); + } + + pub fn id(&self) -> &Option { + &self.id + } + pub fn max_value(&self) -> u64 { self.max_value } @@ -998,4 +1009,18 @@ mod tests { let result = serde_json::to_string(&condition).expect("Should serialize"); assert_eq!(result, r#""foobar == \"ok\"""#.to_string()); } + + #[test] + fn limit_id() { + let mut limit = Limit::new( + "test_namespace", + 10, + 60, + vec!["req.method == 'GET'"], + vec!["app_id"], + ); + limit.set_id("test_id".to_string()); + + assert_eq!(limit.id().clone(), Some("test_id".to_string())) + } } diff --git a/limitador/src/storage/distributed/grpc/mod.rs b/limitador/src/storage/distributed/grpc/mod.rs index 20548850..79bc79ae 100644 --- a/limitador/src/storage/distributed/grpc/mod.rs +++ b/limitador/src/storage/distributed/grpc/mod.rs @@ -5,6 +5,7 @@ use std::sync::Arc; use std::time::{Duration, SystemTime, UNIX_EPOCH}; use std::{error::Error, io::ErrorKind, pin::Pin}; +use crate::counter::Counter; use tokio::sync::mpsc::error::TrySendError; use tokio::sync::mpsc::{Permit, Sender}; use tokio::sync::{broadcast, mpsc, Notify, RwLock}; @@ -156,9 +157,10 @@ impl Session { update = udpates_to_send.recv() => { let update = update.map_err(|_| Status::unknown("broadcast error"))?; // Multiple updates collapse into a single update for the same key - if !tx_updates_by_key.contains_key(&update.key) { - tx_updates_by_key.insert(update.key.clone(), update.value); - tx_updates_order.push(update.key); + let key = &update.key.clone(); + if !tx_updates_by_key.contains_key(key) { + tx_updates_by_key.insert(key.clone(), update); + tx_updates_order.push(key.clone()); notifier.notify_one(); } } @@ -174,7 +176,7 @@ impl Session { let key = tx_updates_order.remove(0); let cr_counter_value = tx_updates_by_key.remove(&key).unwrap().clone(); - let (expiry, values) = (*cr_counter_value).clone().into_inner(); + let (expiry, values) = cr_counter_value.value.clone().into_inner(); // only send the update if it has not expired. if expiry > SystemTime::now() { @@ -437,19 +439,24 @@ type CounterUpdateFn = Pin>; #[derive(Clone, Debug)] pub struct CounterEntry { pub key: Vec, - pub value: Arc>, + pub counter: Counter, + pub value: CrCounterValue, } impl CounterEntry { - pub fn new(key: Vec, value: Arc>) -> Self { - Self { key, value } + pub fn new(key: Vec, counter: Counter, value: CrCounterValue) -> Self { + Self { + key, + counter, + value, + } } } #[derive(Clone)] struct BrokerState { id: String, - publisher: broadcast::Sender, + publisher: broadcast::Sender>, on_counter_update: Arc, on_re_sync: Arc>>>, } @@ -471,7 +478,7 @@ impl Broker { on_re_sync: Sender>>, ) -> Broker { let (tx, _) = broadcast::channel(16); - let publisher: broadcast::Sender = tx; + let publisher: broadcast::Sender> = tx; Broker { listen_address, @@ -489,7 +496,7 @@ impl Broker { } } - pub fn publish(&self, counter_update: CounterEntry) { + pub fn publish(&self, counter_update: Arc) { // ignore the send error, it just means there are no active subscribers _ = self.broker_state.publisher.send(counter_update); } diff --git a/limitador/src/storage/distributed/mod.rs b/limitador/src/storage/distributed/mod.rs index 6deda533..020918c5 100644 --- a/limitador/src/storage/distributed/mod.rs +++ b/limitador/src/storage/distributed/mod.rs @@ -4,22 +4,22 @@ use std::net::ToSocketAddrs; use std::sync::{Arc, RwLock}; use std::time::{Duration, SystemTime, UNIX_EPOCH}; -use serde::{Deserialize, Serialize}; use tokio::sync::mpsc; use tokio::sync::mpsc::Sender; use tracing::debug; use crate::counter::Counter; -use crate::limit::{Limit, Namespace}; +use crate::limit::Limit; use crate::storage::distributed::cr_counter_value::CrCounterValue; use crate::storage::distributed::grpc::v1::CounterUpdate; use crate::storage::distributed::grpc::{Broker, CounterEntry}; +use crate::storage::keys::bin::key_for_counter_v2; use crate::storage::{Authorization, CounterStorage, StorageErr}; mod cr_counter_value; mod grpc; -pub type LimitsMap = HashMap, Arc>>; +pub type LimitsMap = HashMap, Arc>; pub struct CrInMemoryStorage { identifier: String, @@ -35,7 +35,7 @@ impl CounterStorage for CrInMemoryStorage { let mut value = 0; let key = encode_counter_to_key(counter); if let Some(counter_value) = limits.get(&key) { - value = counter_value.read() + value = counter_value.value.read() } Ok(counter.max_value() >= value + delta) } @@ -45,11 +45,15 @@ impl CounterStorage for CrInMemoryStorage { if limit.variables().is_empty() { let mut limits = self.limits.write().unwrap(); let key = encode_limit_to_key(limit); - limits.entry(key).or_insert(Arc::new(CrCounterValue::new( - self.identifier.clone(), - limit.max_value(), - Duration::from_secs(limit.seconds()), - ))); + limits.entry(key.clone()).or_insert(Arc::new(CounterEntry { + key, + counter: Counter::new(limit.clone(), HashMap::default()), + value: CrCounterValue::new( + self.identifier.clone(), + limit.max_value(), + Duration::from_secs(limit.seconds()), + ), + })); } Ok(()) } @@ -63,16 +67,20 @@ impl CounterStorage for CrInMemoryStorage { match limits.entry(key.clone()) { Entry::Vacant(entry) => { let duration = counter.window(); - let store_value = Arc::new(CrCounterValue::new( - self.identifier.clone(), - counter.max_value(), - duration, - )); - self.increment_counter(counter, key, store_value.clone(), delta, now); - entry.insert(store_value); + let value = Arc::new(CounterEntry { + key: key.clone(), + counter: counter.clone(), + value: CrCounterValue::new( + self.identifier.clone(), + counter.max_value(), + duration, + ), + }); + self.increment_counter(value.clone(), delta, now); + entry.insert(value); } Entry::Occupied(entry) => { - self.increment_counter(counter, key, entry.get().clone(), delta, now); + self.increment_counter(entry.get().clone(), delta, now); } }; Ok(()) @@ -86,7 +94,7 @@ impl CounterStorage for CrInMemoryStorage { load_counters: bool, ) -> Result { let mut first_limited = None; - let mut counter_values_to_update: Vec<(Counter, Vec)> = Vec::new(); + let mut counter_values_to_update: Vec> = Vec::new(); let now = SystemTime::now(); let mut process_counter = @@ -120,12 +128,14 @@ impl CounterStorage for CrInMemoryStorage { match limits.get(&key) { None => false, Some(store_value) => { - if let Some(limited) = process_counter(counter, store_value.read(), delta) { + if let Some(limited) = + process_counter(counter, store_value.value.read(), delta) + { if !load_counters { return Ok(limited); } } - counter_values_to_update.push((counter.clone(), key)); + counter_values_to_update.push(key); true } } @@ -135,21 +145,22 @@ impl CounterStorage for CrInMemoryStorage { if !counter_existed { // try again with a write lock to create the counter if it's still missing. let mut limits = self.limits.write().unwrap(); - let store_value = - limits - .entry(key.clone()) - .or_insert(Arc::new(CrCounterValue::new( - self.identifier.clone(), - counter.max_value(), - counter.window(), - ))); - - if let Some(limited) = process_counter(counter, store_value.read(), delta) { + let store_value = limits.entry(key.clone()).or_insert(Arc::new(CounterEntry { + key: key.clone(), + counter: counter.clone(), + value: CrCounterValue::new( + self.identifier.clone(), + counter.max_value(), + counter.window(), + ), + })); + + if let Some(limited) = process_counter(counter, store_value.value.read(), delta) { if !load_counters { return Ok(limited); } } - counter_values_to_update.push((counter.clone(), key)); + counter_values_to_update.push(key); } } @@ -159,12 +170,10 @@ impl CounterStorage for CrInMemoryStorage { // Update counters let limits = self.limits.read().unwrap(); - counter_values_to_update - .into_iter() - .for_each(|(counter, key)| { - let store_value = limits.get(&key).unwrap(); - self.increment_counter(&counter, key, store_value.clone(), delta, now); - }); + counter_values_to_update.into_iter().for_each(|key| { + let store_value = limits.get(&key).unwrap(); + self.increment_counter(store_value.clone(), delta, now); + }); Ok(Authorization::Ok) } @@ -172,25 +181,12 @@ impl CounterStorage for CrInMemoryStorage { #[tracing::instrument(skip_all)] fn get_counters(&self, limits: &HashSet>) -> Result, StorageErr> { let mut res = HashSet::new(); - - let limits: HashSet<_> = limits.iter().map(|l| encode_limit_to_key(l)).collect(); - let limits_map = self.limits.read().unwrap(); - for (key, counter_value) in limits_map.iter() { - let counter_key = decode_counter_key(key).unwrap(); - let limit_key = if !counter_key.vars.is_empty() { - let mut cloned = counter_key.clone(); - cloned.vars = HashMap::default(); - cloned.encode() - } else { - key.clone() - }; - - if limits.contains(&limit_key) { - let counter = (&counter_key, &*counter_value.clone()); - let mut counter: Counter = counter.into(); - counter.set_remaining(counter.max_value() - counter_value.read()); - counter.set_expires_in(counter_value.ttl()); + for (_, counter_entry) in limits_map.iter() { + if limits.contains(counter_entry.counter.limit()) { + let mut counter: Counter = counter_entry.counter.clone(); + counter.set_remaining(counter.max_value() - counter_entry.value.read()); + counter.set_expires_in(counter_entry.value.ttl()); if counter.expires_in().unwrap() > Duration::ZERO { res.insert(counter); } @@ -241,7 +237,9 @@ impl CrInMemoryStorage { ); let limits = limits_clone.read().unwrap(); let value = limits.get(&update.key).unwrap(); - value.merge((UNIX_EPOCH + Duration::from_secs(update.expires_at), values).into()); + value + .value + .merge((UNIX_EPOCH + Duration::from_secs(update.expires_at), values).into()); }), re_sync_queue_tx, ); @@ -282,17 +280,11 @@ impl CrInMemoryStorage { } } - fn increment_counter( - &self, - counter: &Counter, - store_key: Vec, - store_value: Arc>, - delta: u64, - when: SystemTime, - ) { - store_value.inc_at(delta, counter.window(), when); - self.broker - .publish(CounterEntry::new(store_key, store_value)) + fn increment_counter(&self, counter_entry: Arc, delta: u64, when: SystemTime) { + counter_entry + .value + .inc_at(delta, counter_entry.counter.window(), when); + self.broker.publish(counter_entry) } } @@ -308,7 +300,7 @@ async fn process_re_sync(limits: &Arc>, sender: Sender>, sender: Sender, - variables: HashSet, - vars: HashMap, -} - -impl CounterKey { - fn new(limit: &Limit, vars: HashMap) -> Self { - CounterKey { - namespace: limit.namespace().clone(), - seconds: limit.seconds(), - variables: limit.variables().clone(), - conditions: limit.conditions().clone(), - vars, - } - } - - fn encode(&self) -> Vec { - postcard::to_stdvec(self).unwrap() - } -} - -impl From<(&CounterKey, &CrCounterValue)> for Counter { - fn from(value: (&CounterKey, &CrCounterValue)) -> Self { - let (counter_key, store_value) = value; - let max_value = store_value.max_value(); - let mut counter = Self::new( - Limit::new( - counter_key.namespace.clone(), - max_value, - counter_key.seconds, - counter_key.conditions.clone(), - counter_key.vars.keys(), - ), - counter_key.vars.clone(), - ); - counter.set_remaining(max_value - store_value.read()); - counter.set_expires_in(store_value.ttl()); - counter - } -} - fn encode_counter_to_key(counter: &Counter) -> Vec { - let key = CounterKey::new(counter.limit(), counter.set_variables().clone()); - postcard::to_stdvec(&key).unwrap() + key_for_counter_v2(counter) } fn encode_limit_to_key(limit: &Limit) -> Vec { - let key = CounterKey::new(limit, HashMap::default()); - postcard::to_stdvec(&key).unwrap() -} - -fn decode_counter_key(key: &Vec) -> postcard::Result { - postcard::from_bytes(key.as_slice()) + let counter = Counter::new(limit.clone(), HashMap::default()); + key_for_counter_v2(&counter) } diff --git a/limitador/src/storage/keys.rs b/limitador/src/storage/keys.rs index 81d818c6..46b8292d 100644 --- a/limitador/src/storage/keys.rs +++ b/limitador/src/storage/keys.rs @@ -14,37 +14,56 @@ use crate::counter::Counter; use crate::limit::Limit; +use serde::{Deserialize, Serialize}; use std::sync::Arc; -pub fn key_for_counter(counter: &Counter) -> String { - if counter.remaining().is_some() || counter.expires_in().is_some() { - format!( - "{},counter:{}", - prefix_for_namespace(counter.namespace().as_ref()), - serde_json::to_string(&counter.key()).unwrap() - ) +pub fn key_for_counter(counter: &Counter) -> Vec { + if counter.id().is_none() { + // continue to use the legacy text encoding... + let namespace = counter.namespace().as_ref(); + let key = if counter.remaining().is_some() || counter.expires_in().is_some() { + format!( + "namespace:{{{namespace}}},counter:{}", + serde_json::to_string(&counter.key()).unwrap() + ) + } else { + format!( + "namespace:{{{namespace}}},counter:{}", + serde_json::to_string(counter).unwrap() + ) + }; + key.into_bytes() } else { - format!( - "{},counter:{}", - prefix_for_namespace(counter.namespace().as_ref()), - serde_json::to_string(counter).unwrap() - ) + // if the id is set, use the new binary encoding... + bin::key_for_counter_v2(counter) } } -pub fn key_for_counters_of_limit(limit: &Limit) -> String { - format!( - "namespace:{{{}}},counters_of_limit:{}", - limit.namespace().as_ref(), - serde_json::to_string(limit).unwrap() - ) -} +pub fn key_for_counters_of_limit(limit: &Limit) -> Vec { + if limit.id().is_none() { + let namespace = limit.namespace().as_ref(); + format!( + "namespace:{{{namespace}}},counters_of_limit:{}", + serde_json::to_string(limit).unwrap() + ) + .into_bytes() + } else { + #[derive(PartialEq, Debug, Serialize, Deserialize)] + struct IdLimitKey<'a> { + id: &'a str, + } -pub fn prefix_for_namespace(namespace: &str) -> String { - format!("namespace:{{{namespace}}},") + let id = limit.id().as_ref().unwrap(); + let key = IdLimitKey { id: id.as_ref() }; + + let mut encoded_key = Vec::new(); + encoded_key = postcard::to_extend(&2u8, encoded_key).unwrap(); + encoded_key = postcard::to_extend(&key, encoded_key).unwrap(); + encoded_key + } } -pub fn counter_from_counter_key(key: &str, limit: Arc) -> Counter { +pub fn counter_from_counter_key(key: &Vec, limit: Arc) -> Counter { let mut counter = partial_counter_from_counter_key(key); if !counter.update_to_limit(Arc::clone(&limit)) { // this means some kind of data corruption _or_ most probably @@ -58,33 +77,38 @@ pub fn counter_from_counter_key(key: &str, limit: Arc) -> Counter { counter } -pub fn partial_counter_from_counter_key(key: &str) -> Counter { - let namespace_prefix = "namespace:"; - let counter_prefix = ",counter:"; - - // Find the start position of the counter portion - let start_pos_namespace = key - .find(namespace_prefix) - .expect("Namespace not found in the key"); - let start_pos_counter = key[start_pos_namespace..] - .find(counter_prefix) - .expect("Counter not found in the key") - + start_pos_namespace - + counter_prefix.len(); - - // Extract counter JSON substring and deserialize it - let counter_str = &key[start_pos_counter..]; - let counter: Counter = - serde_json::from_str(counter_str).expect("Failed to deserialize counter JSON"); - counter +pub fn partial_counter_from_counter_key(key: &Vec) -> Counter { + if key.starts_with(b"namespace:") { + let key = String::from_utf8_lossy(key.as_ref()); + + // It's using to the legacy text encoding... + let namespace_prefix = "namespace:"; + let counter_prefix = ",counter:"; + + // Find the start position of the counter portion + let start_pos_namespace = key + .find(namespace_prefix) + .expect("Namespace not found in the key"); + let start_pos_counter = key[start_pos_namespace..] + .find(counter_prefix) + .expect("Counter not found in the key") + + start_pos_namespace + + counter_prefix.len(); + + // Extract counter JSON substring and deserialize it + let counter_str = &key[start_pos_counter..]; + let counter: Counter = + serde_json::from_str(counter_str).expect("Failed to deserialize counter JSON"); + counter + } else { + // It's using to the new binary encoding... + bin::partial_counter_from_counter_key_v2(key) + } } #[cfg(test)] mod tests { - use super::{ - key_for_counter, key_for_counters_of_limit, partial_counter_from_counter_key, - prefix_for_namespace, - }; + use super::{key_for_counter, key_for_counters_of_limit, partial_counter_from_counter_key}; use crate::counter::Counter; use crate::Limit; use std::collections::HashMap; @@ -100,7 +124,7 @@ mod tests { vec!["app_id"], ); assert_eq!( - "namespace:{example.com},counters_of_limit:{\"namespace\":\"example.com\",\"seconds\":60,\"conditions\":[\"req.method == \\\"GET\\\"\"],\"variables\":[\"app_id\"]}", + "namespace:{example.com},counters_of_limit:{\"namespace\":\"example.com\",\"seconds\":60,\"conditions\":[\"req.method == \\\"GET\\\"\"],\"variables\":[\"app_id\"]}".as_bytes(), key_for_counters_of_limit(&limit)) } @@ -111,8 +135,6 @@ mod tests { let counter = Counter::new(limit.clone(), HashMap::default()); let raw = key_for_counter(&counter); assert_eq!(counter, partial_counter_from_counter_key(&raw)); - let prefix = prefix_for_namespace(namespace); - assert_eq!(&raw[0..prefix.len()], &prefix); } #[test] @@ -135,6 +157,21 @@ pub mod bin { use crate::counter::Counter; use crate::limit::Limit; + #[derive(PartialEq, Debug, Serialize, Deserialize)] + struct IdCounterKey<'a> { + id: &'a str, + variables: Vec<(&'a str, &'a str)>, + } + + impl<'a> From<&'a Counter> for IdCounterKey<'a> { + fn from(counter: &'a Counter) -> Self { + IdCounterKey { + id: counter.id().as_ref().unwrap().as_ref(), + variables: counter.variables_for_key(), + } + } + } + #[derive(PartialEq, Debug, Serialize, Deserialize)] struct CounterKey<'a> { ns: &'a str, @@ -187,6 +224,54 @@ pub mod bin { } } + pub fn key_for_counter_v2(counter: &Counter) -> Vec { + let mut encoded_key = Vec::new(); + if counter.id().is_none() { + let key: CounterKey = counter.into(); + encoded_key = postcard::to_extend(&1u8, encoded_key).unwrap(); + encoded_key = postcard::to_extend(&key, encoded_key).unwrap() + } else { + let key: IdCounterKey = counter.into(); + encoded_key = postcard::to_extend(&2u8, encoded_key).unwrap(); + encoded_key = postcard::to_extend(&key, encoded_key).unwrap(); + } + encoded_key + } + + pub fn partial_counter_from_counter_key_v2(key: &[u8]) -> Counter { + let (version, key) = postcard::take_from_bytes::(key).unwrap(); + match version { + 1u8 => { + let CounterKey { + ns, + seconds, + conditions, + variables, + } = postcard::from_bytes(key).unwrap(); + + let map: HashMap = variables + .into_iter() + .map(|(var, value)| (var.to_string(), value.to_string())) + .collect(); + let limit = Limit::new(ns, u64::default(), seconds, conditions, map.keys()); + Counter::new(limit, map) + } + 2u8 => { + let IdCounterKey { id, variables } = postcard::from_bytes(key).unwrap(); + let map: HashMap = variables + .into_iter() + .map(|(var, value)| (var.to_string(), value.to_string())) + .collect(); + + // we are not able to rebuild the full limit since we only have the id and variables. + let mut limit = Limit::new::<&str, &str>("", u64::default(), 0, vec![], map.keys()); + limit.set_id(id.to_string()); + Counter::new(limit, map) + } + _ => panic!("Unknown version: {}", version), + } + } + pub fn key_for_counter(counter: &Counter) -> Vec { let key: CounterKey = counter.into(); postcard::to_stdvec(&key).unwrap() @@ -215,13 +300,15 @@ pub mod bin { #[cfg(test)] mod tests { - use super::{ - key_for_counter, partial_counter_from_counter_key, prefix_for_namespace, CounterKey, - }; use crate::counter::Counter; use crate::Limit; use std::collections::HashMap; + use super::{ + key_for_counter, key_for_counter_v2, partial_counter_from_counter_key, + prefix_for_namespace, CounterKey, + }; + #[test] fn counter_key_serializes_and_back() { let namespace = "ns_counter:"; @@ -266,5 +353,34 @@ pub mod bin { let prefix = prefix_for_namespace(namespace); assert_eq!(&serialized_counter[..prefix.len()], &prefix); } + + #[test] + fn counters_with_id() { + let namespace = "ns_counter:"; + let limit_without_id = + Limit::new(namespace, 1, 1, vec!["req.method == 'GET'"], vec!["app_id"]); + let mut limit_with_id = + Limit::new(namespace, 1, 1, vec!["req.method == 'GET'"], vec!["app_id"]); + limit_with_id.set_id("id200".to_string()); + + let counter_with_id = Counter::new(limit_with_id, HashMap::default()); + let serialized_with_id_counter = key_for_counter(&counter_with_id); + + let counter_without_id = Counter::new(limit_without_id, HashMap::default()); + let serialized_without_id_counter = key_for_counter(&counter_without_id); + + // the original key_for_counter continues to encode kinda big + assert_eq!(serialized_without_id_counter.len(), 35); + assert_eq!(serialized_with_id_counter.len(), 35); + + // serialized_counter_v2 will only encode the id.... so it will be smaller for + // counters with an id. + let serialized_counter_with_id_v2 = key_for_counter_v2(&counter_with_id); + assert_eq!(serialized_counter_with_id_v2.clone().len(), 8); + + // but continues to be large for counters without an id. + let serialized_counter_without_id_v2 = key_for_counter_v2(&counter_without_id); + assert_eq!(serialized_counter_without_id_v2.clone().len(), 36); + } } } diff --git a/limitador/src/storage/redis/redis_async.rs b/limitador/src/storage/redis/redis_async.rs index 4da003a7..31bf946e 100644 --- a/limitador/src/storage/redis/redis_async.rs +++ b/limitador/src/storage/redis/redis_async.rs @@ -38,7 +38,7 @@ impl AsyncCounterStorage for AsyncRedisStorage { let mut con = self.conn_manager.clone(); match con - .get::>(key_for_counter(counter)) + .get::, Option>(key_for_counter(counter)) .instrument(info_span!("datastore")) .await? { @@ -71,7 +71,7 @@ impl AsyncCounterStorage for AsyncRedisStorage { load_counters: bool, ) -> Result { let mut con = self.conn_manager.clone(); - let counter_keys: Vec = counters.iter().map(key_for_counter).collect(); + let counter_keys: Vec> = counters.iter().map(key_for_counter).collect(); if load_counters { let script = redis::Script::new(VALUES_AND_TTLS); @@ -139,7 +139,7 @@ impl AsyncCounterStorage for AsyncRedisStorage { for limit in limits { let counter_keys = { - con.smembers::>(key_for_counters_of_limit(limit)) + con.smembers::, HashSet>>(key_for_counters_of_limit(limit)) .instrument(info_span!("datastore")) .await? }; @@ -156,7 +156,7 @@ impl AsyncCounterStorage for AsyncRedisStorage { // This does not cause any bugs, but consumes memory // unnecessarily. let option = { - con.get::>(counter_key.clone()) + con.get::, Option>(counter_key.clone()) .instrument(info_span!("datastore")) .await? }; @@ -218,7 +218,7 @@ impl AsyncRedisStorage { let mut con = self.conn_manager.clone(); let counter_keys = { - con.smembers::>(key_for_counters_of_limit(limit)) + con.smembers::, HashSet>>(key_for_counters_of_limit(limit)) .instrument(info_span!("datastore")) .await? }; diff --git a/limitador/src/storage/redis/redis_sync.rs b/limitador/src/storage/redis/redis_sync.rs index 52d66fc3..4c0619ee 100644 --- a/limitador/src/storage/redis/redis_sync.rs +++ b/limitador/src/storage/redis/redis_sync.rs @@ -29,7 +29,7 @@ impl CounterStorage for RedisStorage { fn is_within_limits(&self, counter: &Counter, delta: u64) -> Result { let mut con = self.conn_pool.get()?; - match con.get::>(key_for_counter(counter))? { + match con.get::, Option>(key_for_counter(counter))? { Some(val) => Ok(u64::try_from(val).unwrap_or(0) + delta <= counter.max_value()), None => Ok(counter.max_value().checked_sub(delta).is_some()), } @@ -62,7 +62,7 @@ impl CounterStorage for RedisStorage { load_counters: bool, ) -> Result { let mut con = self.conn_pool.get()?; - let counter_keys: Vec = counters.iter().map(key_for_counter).collect(); + let counter_keys: Vec> = counters.iter().map(key_for_counter).collect(); if load_counters { let script = redis::Script::new(VALUES_AND_TTLS); @@ -115,7 +115,7 @@ impl CounterStorage for RedisStorage { for limit in limits { let counter_keys = - con.smembers::>(key_for_counters_of_limit(limit))?; + con.smembers::, HashSet>>(key_for_counters_of_limit(limit))?; for counter_key in counter_keys { let mut counter: Counter = @@ -128,7 +128,7 @@ impl CounterStorage for RedisStorage { // do the "get" + "delete if none" atomically. // This does not cause any bugs, but consumes memory // unnecessarily. - if let Some(val) = con.get::>(counter_key.clone())? { + if let Some(val) = con.get::, Option>(counter_key.clone())? { counter.set_remaining( limit .max_value() @@ -150,8 +150,8 @@ impl CounterStorage for RedisStorage { let mut con = self.conn_pool.get()?; for limit in limits { - let counter_keys = - con.smembers::>(key_for_counters_of_limit(limit.deref()))?; + let counter_keys = con + .smembers::, HashSet>>(key_for_counters_of_limit(limit.deref()))?; for counter_key in counter_keys { con.del::<_, ()>(counter_key)?; diff --git a/limitador/tests/integration_tests.rs b/limitador/tests/integration_tests.rs index da20237a..5ecf661f 100644 --- a/limitador/tests/integration_tests.rs +++ b/limitador/tests/integration_tests.rs @@ -187,6 +187,7 @@ mod test { test_with_all_storage_impls!(delete_limits_of_a_namespace_also_deletes_counters); test_with_all_storage_impls!(delete_limits_of_an_empty_namespace_does_nothing); test_with_all_storage_impls!(rate_limited); + test_with_all_storage_impls!(rate_limited_id_counter); test_with_all_storage_impls!(multiple_limits_rate_limited); test_with_all_storage_impls!(rate_limited_with_delta_higher_than_one); test_with_all_storage_impls!(rate_limited_with_delta_higher_than_max); @@ -519,6 +520,43 @@ mod test { .unwrap()); } + async fn rate_limited_id_counter(rate_limiter: &mut TestsLimiter) { + let namespace = "test_namespace"; + let max_hits = 3; + let mut limit = Limit::new( + namespace, + max_hits, + 60, + vec!["req.method == 'GET'"], + vec!["app_id"], + ); + limit.set_id("test-rate_limited_id_counter".to_string()); + + rate_limiter.add_limit(&limit).await; + + let mut values: HashMap = HashMap::new(); + values.insert("req.method".to_string(), "GET".to_string()); + values.insert("app_id".to_string(), "test_app_id".to_string()); + + for i in 0..max_hits { + assert!( + !rate_limiter + .is_rate_limited(namespace, &values, 1) + .await + .unwrap(), + "Must not be limited after {i}" + ); + rate_limiter + .update_counters(namespace, &values, 1) + .await + .unwrap(); + } + assert!(rate_limiter + .is_rate_limited(namespace, &values, 1) + .await + .unwrap()); + } + async fn multiple_limits_rate_limited(rate_limiter: &mut TestsLimiter) { let namespace = "test_namespace"; let max_hits = 3; From 8b8d59286924319d93f7db2e6d15a73a3c70eee4 Mon Sep 17 00:00:00 2001 From: Hiram Chirino Date: Mon, 15 Jul 2024 09:40:57 -0400 Subject: [PATCH 2/2] Add test that verifies the encoded bytes for the new method, to make sure they remain stable. Signed-off-by: Hiram Chirino --- limitador/src/storage/keys.rs | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/limitador/src/storage/keys.rs b/limitador/src/storage/keys.rs index 46b8292d..3fe5ded1 100644 --- a/limitador/src/storage/keys.rs +++ b/limitador/src/storage/keys.rs @@ -128,6 +128,22 @@ mod tests { key_for_counters_of_limit(&limit)) } + #[test] + fn key_for_limit_with_id_format() { + let mut limit = Limit::new( + "example.com", + 10, + 60, + vec!["req.method == 'GET'"], + vec!["app_id"], + ); + limit.set_id("test_id".to_string()); + assert_eq!( + "\u{2}\u{7}test_id".as_bytes(), + key_for_counters_of_limit(&limit) + ) + } + #[test] fn counter_key_and_counter_are_symmetric() { let namespace = "ns_counter:";