From 312369be34e3c9a7ff99bd104d67873918e84df5 Mon Sep 17 00:00:00 2001 From: Hiram Chirino Date: Wed, 12 Jun 2024 08:35:43 -0400 Subject: [PATCH] Support providing an optional id to limits/counters MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * add a key_for_counters_v2 function that uses the id as the key if set, otherwise uses the previous key encoding strategy. * updated the distributed store to use key_for_counters_v2. Since we can’t decode a partial counter from id based keys, we now also keep in memory the Counter in a counter field of the limits map. Signed-off-by: Hiram Chirino --- .../src/http_api/request_types.rs | 3 + limitador/Cargo.toml | 2 +- limitador/benches/bench.rs | 1 + limitador/src/counter.rs | 5 + limitador/src/limit.rs | 92 +++++++-- limitador/src/storage/disk/rocksdb_storage.rs | 9 +- limitador/src/storage/distributed/grpc/mod.rs | 27 ++- limitador/src/storage/distributed/mod.rs | 184 ++++++------------ limitador/src/storage/in_memory.rs | 22 ++- limitador/src/storage/keys.rs | 111 +++++++++-- limitador/src/storage/redis/counters_cache.rs | 1 + limitador/src/storage/redis/redis_cached.rs | 3 + limitador/tests/integration_tests.rs | 49 ++++- 13 files changed, 343 insertions(+), 166 deletions(-) diff --git a/limitador-server/src/http_api/request_types.rs b/limitador-server/src/http_api/request_types.rs index f8d7dc45..8db6b079 100644 --- a/limitador-server/src/http_api/request_types.rs +++ b/limitador-server/src/http_api/request_types.rs @@ -18,6 +18,7 @@ pub struct CheckAndReportInfo { #[derive(Debug, Eq, PartialEq, Serialize, Deserialize, Apiv2Schema)] pub struct Limit { + id: Option, namespace: String, max_value: u64, seconds: u64, @@ -29,6 +30,7 @@ pub struct Limit { impl From<&LimitadorLimit> for Limit { fn from(ll: &LimitadorLimit) -> Self { Self { + id: ll.id().clone(), namespace: ll.namespace().as_ref().to_string(), max_value: ll.max_value(), seconds: ll.seconds(), @@ -42,6 +44,7 @@ impl From<&LimitadorLimit> for Limit { impl From for LimitadorLimit { fn from(limit: Limit) -> Self { let mut limitador_limit = Self::new( + limit.id, limit.namespace, limit.max_value, limit.seconds, diff --git a/limitador/Cargo.toml b/limitador/Cargo.toml index 0bb26894..c04faeca 100644 --- a/limitador/Cargo.toml +++ b/limitador/Cargo.toml @@ -23,7 +23,7 @@ lenient_conditions = [] moka = { version = "0.12", features = ["sync"] } dashmap = "5.5.3" getrandom = { version = "0.2", features = ["js"] } -serde = { version = "1", features = ["derive"] } +serde = { version = "1", features = ["derive", "rc"] } postcard = { version = "1.0.4", features = ["use-std"] } serde_json = "1" rmp-serde = "1.1.0" diff --git a/limitador/benches/bench.rs b/limitador/benches/bench.rs index 601fa923..e23198d8 100644 --- a/limitador/benches/bench.rs +++ b/limitador/benches/bench.rs @@ -548,6 +548,7 @@ fn generate_test_limits(scenario: &TestScenario) -> (Vec, Vec, remaining: Option, + expires_in: Option, } @@ -72,6 +73,10 @@ impl Counter { Duration::from_secs(self.limit.seconds()) } + pub fn id(&self) -> &Option { + self.limit.id() + } + pub fn namespace(&self) -> &Namespace { self.limit.namespace() } diff --git a/limitador/src/limit.rs b/limitador/src/limit.rs index 12adb7ff..86f1d55e 100644 --- a/limitador/src/limit.rs +++ b/limitador/src/limit.rs @@ -1,11 +1,16 @@ -use crate::limit::conditions::{ErrorType, Literal, SyntaxError, Token, TokenType}; -use serde::{Deserialize, Serialize, Serializer}; use std::cmp::Ordering; use std::collections::{BTreeSet, HashMap, HashSet}; use std::error::Error; use std::fmt::{Debug, Display, Formatter}; use std::hash::{Hash, Hasher}; +use serde::{Deserialize, Serialize, Serializer}; + +#[cfg(feature = "lenient_conditions")] +pub use deprecated::check_deprecated_syntax_usages_and_reset; + +use crate::limit::conditions::{ErrorType, Literal, SyntaxError, Token, TokenType}; + #[cfg(feature = "lenient_conditions")] mod deprecated { use std::sync::atomic::{AtomicBool, Ordering}; @@ -25,9 +30,6 @@ mod deprecated { } } -#[cfg(feature = "lenient_conditions")] -pub use deprecated::check_deprecated_syntax_usages_and_reset; - #[derive(Debug, Hash, Eq, PartialEq, Clone, Serialize, Deserialize)] pub struct Namespace(String); @@ -51,6 +53,8 @@ impl From for Namespace { #[derive(Eq, Debug, Clone, Serialize, Deserialize)] pub struct Limit { + #[serde(skip_serializing, default)] + id: Option, namespace: Namespace, #[serde(skip_serializing, default)] max_value: u64, @@ -307,6 +311,7 @@ where impl Limit { pub fn new, T: TryInto>( + id: Option, namespace: N, max_value: u64, seconds: u64, @@ -319,6 +324,7 @@ impl Limit { { // the above where-clause is needed in order to call unwrap(). Self { + id: id, namespace: namespace.into(), max_value, seconds, @@ -335,6 +341,10 @@ impl Limit { &self.namespace } + pub fn id(&self) -> &Option { + &self.id + } + pub fn max_value(&self) -> u64 { self.max_value } @@ -821,7 +831,14 @@ mod tests { #[test] fn limit_can_have_an_optional_name() { - let mut limit = Limit::new("test_namespace", 10, 60, vec!["x == \"5\""], vec!["y"]); + let mut limit = Limit::new( + None, + "test_namespace", + 10, + 60, + vec!["x == \"5\""], + vec!["y"], + ); assert!(limit.name.is_none()); let name = "Test Limit"; @@ -831,7 +848,14 @@ mod tests { #[test] fn limit_applies() { - let limit = Limit::new("test_namespace", 10, 60, vec!["x == \"5\""], vec!["y"]); + let limit = Limit::new( + None, + "test_namespace", + 10, + 60, + vec!["x == \"5\""], + vec!["y"], + ); let mut values: HashMap = HashMap::new(); values.insert("x".into(), "5".into()); @@ -842,7 +866,14 @@ mod tests { #[test] fn limit_does_not_apply_when_cond_is_false() { - let limit = Limit::new("test_namespace", 10, 60, vec!["x == \"5\""], vec!["y"]); + let limit = Limit::new( + None, + "test_namespace", + 10, + 60, + vec!["x == \"5\""], + vec!["y"], + ); let mut values: HashMap = HashMap::new(); values.insert("x".into(), "1".into()); @@ -854,7 +885,7 @@ mod tests { #[test] #[cfg(feature = "lenient_conditions")] fn limit_does_not_apply_when_cond_is_false_deprecated_style() { - let limit = Limit::new("test_namespace", 10, 60, vec!["x == 5"], vec!["y"]); + let limit = Limit::new(None, "test_namespace", 10, 60, vec!["x == 5"], vec!["y"]); let mut values: HashMap = HashMap::new(); values.insert("x".into(), "1".into()); @@ -864,7 +895,14 @@ mod tests { assert!(check_deprecated_syntax_usages_and_reset()); assert!(!check_deprecated_syntax_usages_and_reset()); - let limit = Limit::new("test_namespace", 10, 60, vec!["x == foobar"], vec!["y"]); + let limit = Limit::new( + None, + "test_namespace", + 10, + 60, + vec!["x == foobar"], + vec!["y"], + ); let mut values: HashMap = HashMap::new(); values.insert("x".into(), "foobar".into()); @@ -877,7 +915,14 @@ mod tests { #[test] fn limit_does_not_apply_when_cond_var_is_not_set() { - let limit = Limit::new("test_namespace", 10, 60, vec!["x == \"5\""], vec!["y"]); + let limit = Limit::new( + None, + "test_namespace", + 10, + 60, + vec!["x == \"5\""], + vec!["y"], + ); // Notice that "x" is not set let mut values: HashMap = HashMap::new(); @@ -889,7 +934,14 @@ mod tests { #[test] fn limit_does_not_apply_when_var_not_set() { - let limit = Limit::new("test_namespace", 10, 60, vec!["x == \"5\""], vec!["y"]); + let limit = Limit::new( + None, + "test_namespace", + 10, + 60, + vec!["x == \"5\""], + vec!["y"], + ); // Notice that "y" is not set let mut values: HashMap = HashMap::new(); @@ -901,6 +953,7 @@ mod tests { #[test] fn limit_applies_when_all_its_conditions_apply() { let limit = Limit::new( + None, "test_namespace", 10, 60, @@ -919,6 +972,7 @@ mod tests { #[test] fn limit_does_not_apply_if_one_cond_doesnt() { let limit = Limit::new( + None, "test_namespace", 10, 60, @@ -998,4 +1052,18 @@ mod tests { let result = serde_json::to_string(&condition).expect("Should serialize"); assert_eq!(result, r#""foobar == \"ok\"""#.to_string()); } + + #[test] + fn limit_id() { + let limit = Limit::new( + Some("test_id".to_string()), + "test_namespace", + 10, + 60, + vec!["req.method == 'GET'"], + vec!["app_id"], + ); + + assert_eq!(limit.id().clone(), Some("test_id".to_string())) + } } diff --git a/limitador/src/storage/disk/rocksdb_storage.rs b/limitador/src/storage/disk/rocksdb_storage.rs index 148af984..8ce80364 100644 --- a/limitador/src/storage/disk/rocksdb_storage.rs +++ b/limitador/src/storage/disk/rocksdb_storage.rs @@ -242,7 +242,14 @@ mod tests { #[test] fn opens_db_on_disk() { let namespace = "test_namespace"; - let limit = Limit::new(namespace, 1, 2, vec!["req.method == 'GET'"], vec!["app_id"]); + let limit = Limit::new( + None, + namespace, + 1, + 2, + vec!["req.method == 'GET'"], + vec!["app_id"], + ); let counter = Counter::new(limit, HashMap::default()); let tmp = TempDir::new().expect("We should have a dir!"); diff --git a/limitador/src/storage/distributed/grpc/mod.rs b/limitador/src/storage/distributed/grpc/mod.rs index 20548850..d5fbc8b8 100644 --- a/limitador/src/storage/distributed/grpc/mod.rs +++ b/limitador/src/storage/distributed/grpc/mod.rs @@ -5,6 +5,7 @@ use std::sync::Arc; use std::time::{Duration, SystemTime, UNIX_EPOCH}; use std::{error::Error, io::ErrorKind, pin::Pin}; +use crate::counter::Counter; use tokio::sync::mpsc::error::TrySendError; use tokio::sync::mpsc::{Permit, Sender}; use tokio::sync::{broadcast, mpsc, Notify, RwLock}; @@ -156,9 +157,10 @@ impl Session { update = udpates_to_send.recv() => { let update = update.map_err(|_| Status::unknown("broadcast error"))?; // Multiple updates collapse into a single update for the same key - if !tx_updates_by_key.contains_key(&update.key) { - tx_updates_by_key.insert(update.key.clone(), update.value); - tx_updates_order.push(update.key); + let key = &update.key.clone(); + if !tx_updates_by_key.contains_key(key) { + tx_updates_by_key.insert(key.clone(), update); + tx_updates_order.push(key.clone()); notifier.notify_one(); } } @@ -174,7 +176,7 @@ impl Session { let key = tx_updates_order.remove(0); let cr_counter_value = tx_updates_by_key.remove(&key).unwrap().clone(); - let (expiry, values) = (*cr_counter_value).clone().into_inner(); + let (expiry, values) = (*cr_counter_value).value.clone().into_inner(); // only send the update if it has not expired. if expiry > SystemTime::now() { @@ -437,19 +439,24 @@ type CounterUpdateFn = Pin>; #[derive(Clone, Debug)] pub struct CounterEntry { pub key: Vec, - pub value: Arc>, + pub counter: Counter, + pub value: CrCounterValue, } impl CounterEntry { - pub fn new(key: Vec, value: Arc>) -> Self { - Self { key, value } + pub fn new(key: Vec, counter: Counter, value: CrCounterValue) -> Self { + Self { + key, + counter, + value, + } } } #[derive(Clone)] struct BrokerState { id: String, - publisher: broadcast::Sender, + publisher: broadcast::Sender>, on_counter_update: Arc, on_re_sync: Arc>>>, } @@ -471,7 +478,7 @@ impl Broker { on_re_sync: Sender>>, ) -> Broker { let (tx, _) = broadcast::channel(16); - let publisher: broadcast::Sender = tx; + let publisher: broadcast::Sender> = tx; Broker { listen_address, @@ -489,7 +496,7 @@ impl Broker { } } - pub fn publish(&self, counter_update: CounterEntry) { + pub fn publish(&self, counter_update: Arc) { // ignore the send error, it just means there are no active subscribers _ = self.broker_state.publisher.send(counter_update); } diff --git a/limitador/src/storage/distributed/mod.rs b/limitador/src/storage/distributed/mod.rs index 6deda533..020918c5 100644 --- a/limitador/src/storage/distributed/mod.rs +++ b/limitador/src/storage/distributed/mod.rs @@ -4,22 +4,22 @@ use std::net::ToSocketAddrs; use std::sync::{Arc, RwLock}; use std::time::{Duration, SystemTime, UNIX_EPOCH}; -use serde::{Deserialize, Serialize}; use tokio::sync::mpsc; use tokio::sync::mpsc::Sender; use tracing::debug; use crate::counter::Counter; -use crate::limit::{Limit, Namespace}; +use crate::limit::Limit; use crate::storage::distributed::cr_counter_value::CrCounterValue; use crate::storage::distributed::grpc::v1::CounterUpdate; use crate::storage::distributed::grpc::{Broker, CounterEntry}; +use crate::storage::keys::bin::key_for_counter_v2; use crate::storage::{Authorization, CounterStorage, StorageErr}; mod cr_counter_value; mod grpc; -pub type LimitsMap = HashMap, Arc>>; +pub type LimitsMap = HashMap, Arc>; pub struct CrInMemoryStorage { identifier: String, @@ -35,7 +35,7 @@ impl CounterStorage for CrInMemoryStorage { let mut value = 0; let key = encode_counter_to_key(counter); if let Some(counter_value) = limits.get(&key) { - value = counter_value.read() + value = counter_value.value.read() } Ok(counter.max_value() >= value + delta) } @@ -45,11 +45,15 @@ impl CounterStorage for CrInMemoryStorage { if limit.variables().is_empty() { let mut limits = self.limits.write().unwrap(); let key = encode_limit_to_key(limit); - limits.entry(key).or_insert(Arc::new(CrCounterValue::new( - self.identifier.clone(), - limit.max_value(), - Duration::from_secs(limit.seconds()), - ))); + limits.entry(key.clone()).or_insert(Arc::new(CounterEntry { + key, + counter: Counter::new(limit.clone(), HashMap::default()), + value: CrCounterValue::new( + self.identifier.clone(), + limit.max_value(), + Duration::from_secs(limit.seconds()), + ), + })); } Ok(()) } @@ -63,16 +67,20 @@ impl CounterStorage for CrInMemoryStorage { match limits.entry(key.clone()) { Entry::Vacant(entry) => { let duration = counter.window(); - let store_value = Arc::new(CrCounterValue::new( - self.identifier.clone(), - counter.max_value(), - duration, - )); - self.increment_counter(counter, key, store_value.clone(), delta, now); - entry.insert(store_value); + let value = Arc::new(CounterEntry { + key: key.clone(), + counter: counter.clone(), + value: CrCounterValue::new( + self.identifier.clone(), + counter.max_value(), + duration, + ), + }); + self.increment_counter(value.clone(), delta, now); + entry.insert(value); } Entry::Occupied(entry) => { - self.increment_counter(counter, key, entry.get().clone(), delta, now); + self.increment_counter(entry.get().clone(), delta, now); } }; Ok(()) @@ -86,7 +94,7 @@ impl CounterStorage for CrInMemoryStorage { load_counters: bool, ) -> Result { let mut first_limited = None; - let mut counter_values_to_update: Vec<(Counter, Vec)> = Vec::new(); + let mut counter_values_to_update: Vec> = Vec::new(); let now = SystemTime::now(); let mut process_counter = @@ -120,12 +128,14 @@ impl CounterStorage for CrInMemoryStorage { match limits.get(&key) { None => false, Some(store_value) => { - if let Some(limited) = process_counter(counter, store_value.read(), delta) { + if let Some(limited) = + process_counter(counter, store_value.value.read(), delta) + { if !load_counters { return Ok(limited); } } - counter_values_to_update.push((counter.clone(), key)); + counter_values_to_update.push(key); true } } @@ -135,21 +145,22 @@ impl CounterStorage for CrInMemoryStorage { if !counter_existed { // try again with a write lock to create the counter if it's still missing. let mut limits = self.limits.write().unwrap(); - let store_value = - limits - .entry(key.clone()) - .or_insert(Arc::new(CrCounterValue::new( - self.identifier.clone(), - counter.max_value(), - counter.window(), - ))); - - if let Some(limited) = process_counter(counter, store_value.read(), delta) { + let store_value = limits.entry(key.clone()).or_insert(Arc::new(CounterEntry { + key: key.clone(), + counter: counter.clone(), + value: CrCounterValue::new( + self.identifier.clone(), + counter.max_value(), + counter.window(), + ), + })); + + if let Some(limited) = process_counter(counter, store_value.value.read(), delta) { if !load_counters { return Ok(limited); } } - counter_values_to_update.push((counter.clone(), key)); + counter_values_to_update.push(key); } } @@ -159,12 +170,10 @@ impl CounterStorage for CrInMemoryStorage { // Update counters let limits = self.limits.read().unwrap(); - counter_values_to_update - .into_iter() - .for_each(|(counter, key)| { - let store_value = limits.get(&key).unwrap(); - self.increment_counter(&counter, key, store_value.clone(), delta, now); - }); + counter_values_to_update.into_iter().for_each(|key| { + let store_value = limits.get(&key).unwrap(); + self.increment_counter(store_value.clone(), delta, now); + }); Ok(Authorization::Ok) } @@ -172,25 +181,12 @@ impl CounterStorage for CrInMemoryStorage { #[tracing::instrument(skip_all)] fn get_counters(&self, limits: &HashSet>) -> Result, StorageErr> { let mut res = HashSet::new(); - - let limits: HashSet<_> = limits.iter().map(|l| encode_limit_to_key(l)).collect(); - let limits_map = self.limits.read().unwrap(); - for (key, counter_value) in limits_map.iter() { - let counter_key = decode_counter_key(key).unwrap(); - let limit_key = if !counter_key.vars.is_empty() { - let mut cloned = counter_key.clone(); - cloned.vars = HashMap::default(); - cloned.encode() - } else { - key.clone() - }; - - if limits.contains(&limit_key) { - let counter = (&counter_key, &*counter_value.clone()); - let mut counter: Counter = counter.into(); - counter.set_remaining(counter.max_value() - counter_value.read()); - counter.set_expires_in(counter_value.ttl()); + for (_, counter_entry) in limits_map.iter() { + if limits.contains(counter_entry.counter.limit()) { + let mut counter: Counter = counter_entry.counter.clone(); + counter.set_remaining(counter.max_value() - counter_entry.value.read()); + counter.set_expires_in(counter_entry.value.ttl()); if counter.expires_in().unwrap() > Duration::ZERO { res.insert(counter); } @@ -241,7 +237,9 @@ impl CrInMemoryStorage { ); let limits = limits_clone.read().unwrap(); let value = limits.get(&update.key).unwrap(); - value.merge((UNIX_EPOCH + Duration::from_secs(update.expires_at), values).into()); + value + .value + .merge((UNIX_EPOCH + Duration::from_secs(update.expires_at), values).into()); }), re_sync_queue_tx, ); @@ -282,17 +280,11 @@ impl CrInMemoryStorage { } } - fn increment_counter( - &self, - counter: &Counter, - store_key: Vec, - store_value: Arc>, - delta: u64, - when: SystemTime, - ) { - store_value.inc_at(delta, counter.window(), when); - self.broker - .publish(CounterEntry::new(store_key, store_value)) + fn increment_counter(&self, counter_entry: Arc, delta: u64, when: SystemTime) { + counter_entry + .value + .inc_at(delta, counter_entry.counter.window(), when); + self.broker.publish(counter_entry) } } @@ -308,7 +300,7 @@ async fn process_re_sync(limits: &Arc>, sender: Sender>, sender: Sender, - variables: HashSet, - vars: HashMap, -} - -impl CounterKey { - fn new(limit: &Limit, vars: HashMap) -> Self { - CounterKey { - namespace: limit.namespace().clone(), - seconds: limit.seconds(), - variables: limit.variables().clone(), - conditions: limit.conditions().clone(), - vars, - } - } - - fn encode(&self) -> Vec { - postcard::to_stdvec(self).unwrap() - } -} - -impl From<(&CounterKey, &CrCounterValue)> for Counter { - fn from(value: (&CounterKey, &CrCounterValue)) -> Self { - let (counter_key, store_value) = value; - let max_value = store_value.max_value(); - let mut counter = Self::new( - Limit::new( - counter_key.namespace.clone(), - max_value, - counter_key.seconds, - counter_key.conditions.clone(), - counter_key.vars.keys(), - ), - counter_key.vars.clone(), - ); - counter.set_remaining(max_value - store_value.read()); - counter.set_expires_in(store_value.ttl()); - counter - } -} - fn encode_counter_to_key(counter: &Counter) -> Vec { - let key = CounterKey::new(counter.limit(), counter.set_variables().clone()); - postcard::to_stdvec(&key).unwrap() + key_for_counter_v2(counter) } fn encode_limit_to_key(limit: &Limit) -> Vec { - let key = CounterKey::new(limit, HashMap::default()); - postcard::to_stdvec(&key).unwrap() -} - -fn decode_counter_key(key: &Vec) -> postcard::Result { - postcard::from_bytes(key.as_slice()) + let counter = Counter::new(limit.clone(), HashMap::default()); + key_for_counter_v2(&counter) } diff --git a/limitador/src/storage/in_memory.rs b/limitador/src/storage/in_memory.rs index 2fedb7e8..c1cfc688 100644 --- a/limitador/src/storage/in_memory.rs +++ b/limitador/src/storage/in_memory.rs @@ -1,14 +1,16 @@ -use crate::counter::Counter; -use crate::limit::{Limit, Namespace}; -use crate::storage::atomic_expiring_value::AtomicExpiringValue; -use crate::storage::{Authorization, CounterStorage, StorageErr}; -use moka::sync::Cache; use std::collections::hash_map::Entry; use std::collections::{HashMap, HashSet}; use std::ops::Deref; use std::sync::{Arc, RwLock}; use std::time::{Duration, SystemTime}; +use moka::sync::Cache; + +use crate::counter::Counter; +use crate::limit::{Limit, Namespace}; +use crate::storage::atomic_expiring_value::AtomicExpiringValue; +use crate::storage::{Authorization, CounterStorage, StorageErr}; + type NamespacedLimitCounters = HashMap>; pub struct InMemoryStorage { @@ -288,8 +290,16 @@ mod tests { fn counters_for_multiple_limit_per_ns() { let storage = InMemoryStorage::default(); let namespace = "test_namespace"; - let limit_1 = Limit::new(namespace, 1, 1, vec!["req.method == 'GET'"], vec!["app_id"]); + let limit_1 = Limit::new( + None, + namespace, + 1, + 1, + vec!["req.method == 'GET'"], + vec!["app_id"], + ); let limit_2 = Limit::new( + None, namespace, 1, 10, diff --git a/limitador/src/storage/keys.rs b/limitador/src/storage/keys.rs index 81d818c6..f39c10f8 100644 --- a/limitador/src/storage/keys.rs +++ b/limitador/src/storage/keys.rs @@ -12,9 +12,10 @@ // reusing this module for other storage implementations make sure that using // "{}" for sharding applies. +use std::sync::Arc; + use crate::counter::Counter; use crate::limit::Limit; -use std::sync::Arc; pub fn key_for_counter(counter: &Counter) -> String { if counter.remaining().is_some() || counter.expires_in().is_some() { @@ -81,18 +82,21 @@ pub fn partial_counter_from_counter_key(key: &str) -> Counter { #[cfg(test)] mod tests { + use std::collections::HashMap; + use std::time::Duration; + + use crate::counter::Counter; + use crate::Limit; + use super::{ key_for_counter, key_for_counters_of_limit, partial_counter_from_counter_key, prefix_for_namespace, }; - use crate::counter::Counter; - use crate::Limit; - use std::collections::HashMap; - use std::time::Duration; #[test] fn key_for_limit_format() { let limit = Limit::new( + None, "example.com", 10, 60, @@ -107,7 +111,14 @@ mod tests { #[test] fn counter_key_and_counter_are_symmetric() { let namespace = "ns_counter:"; - let limit = Limit::new(namespace, 1, 1, vec!["req.method == 'GET'"], vec!["app_id"]); + let limit = Limit::new( + None, + namespace, + 1, + 1, + vec!["req.method == 'GET'"], + vec!["app_id"], + ); let counter = Counter::new(limit.clone(), HashMap::default()); let raw = key_for_counter(&counter); assert_eq!(counter, partial_counter_from_counter_key(&raw)); @@ -118,7 +129,14 @@ mod tests { #[test] fn counter_key_does_not_include_transient_state() { let namespace = "ns_counter:"; - let limit = Limit::new(namespace, 1, 1, vec!["req.method == 'GET'"], vec!["app_id"]); + let limit = Limit::new( + None, + namespace, + 1, + 1, + vec!["req.method == 'GET'"], + vec!["app_id"], + ); let counter = Counter::new(limit.clone(), HashMap::default()); let mut other = counter.clone(); other.set_remaining(123); @@ -129,12 +147,28 @@ mod tests { #[cfg(feature = "disk_storage")] pub mod bin { - use serde::{Deserialize, Serialize}; use std::collections::HashMap; + use serde::{Deserialize, Serialize}; + use crate::counter::Counter; use crate::limit::Limit; + #[derive(PartialEq, Debug, Serialize, Deserialize)] + struct IdCounterKey<'a> { + id: &'a str, + variables: Vec<(&'a str, &'a str)>, + } + + impl<'a> From<&'a Counter> for IdCounterKey<'a> { + fn from(counter: &'a Counter) -> Self { + IdCounterKey { + id: counter.id().as_ref().unwrap().as_ref(), + variables: counter.variables_for_key(), + } + } + } + #[derive(PartialEq, Debug, Serialize, Deserialize)] struct CounterKey<'a> { ns: &'a str, @@ -187,6 +221,20 @@ pub mod bin { } } + pub fn key_for_counter_v2(counter: &Counter) -> Vec { + let mut encoded_key = Vec::new(); + if counter.id().is_some() { + let key: IdCounterKey = counter.into(); + encoded_key = postcard::to_extend(&1u8, encoded_key).unwrap(); + encoded_key = postcard::to_extend(&key, encoded_key).unwrap(); + } else { + let key: CounterKey = counter.into(); + encoded_key = postcard::to_extend(&2u8, encoded_key).unwrap(); + encoded_key = postcard::to_extend(&key, encoded_key).unwrap() + } + encoded_key + } + pub fn key_for_counter(counter: &Counter) -> Vec { let key: CounterKey = counter.into(); postcard::to_stdvec(&key).unwrap() @@ -209,23 +257,26 @@ pub mod bin { .into_iter() .map(|(var, value)| (var.to_string(), value.to_string())) .collect(); - let limit = Limit::new(ns, u64::default(), seconds, conditions, map.keys()); + let limit = Limit::new(None, ns, u64::default(), seconds, conditions, map.keys()); Counter::new(limit, map) } #[cfg(test)] mod tests { + use std::collections::HashMap; + + use crate::counter::Counter; + use crate::Limit; + use super::{ key_for_counter, partial_counter_from_counter_key, prefix_for_namespace, CounterKey, }; - use crate::counter::Counter; - use crate::Limit; - use std::collections::HashMap; #[test] fn counter_key_serializes_and_back() { let namespace = "ns_counter:"; let limit = Limit::new( + None, namespace, 1, 2, @@ -248,7 +299,14 @@ pub mod bin { #[test] fn counter_key_and_counter_are_symmetric() { let namespace = "ns_counter:"; - let limit = Limit::new(namespace, 1, 1, vec!["req.method == 'GET'"], vec!["app_id"]); + let limit = Limit::new( + None, + namespace, + 1, + 1, + vec!["req.method == 'GET'"], + vec!["app_id"], + ); let mut variables = HashMap::default(); variables.insert("app_id".to_string(), "123".to_string()); let counter = Counter::new(limit.clone(), variables); @@ -259,7 +317,32 @@ pub mod bin { #[test] fn counter_key_starts_with_namespace_prefix() { let namespace = "ns_counter:"; - let limit = Limit::new(namespace, 1, 1, vec!["req.method == 'GET'"], vec!["app_id"]); + let limit = Limit::new( + None, + namespace, + 1, + 1, + vec!["req.method == 'GET'"], + vec!["app_id"], + ); + let counter = Counter::new(limit, HashMap::default()); + let serialized_counter = key_for_counter(&counter); + + let prefix = prefix_for_namespace(namespace); + assert_eq!(&serialized_counter[..prefix.len()], &prefix); + } + + #[test] + fn counters_with_id() { + let namespace = "ns_counter:"; + let limit = Limit::new( + Some("id200".to_string()), + namespace, + 1, + 1, + vec!["req.method == 'GET'"], + vec!["app_id"], + ); let counter = Counter::new(limit, HashMap::default()); let serialized_counter = key_for_counter(&counter); diff --git a/limitador/src/storage/redis/counters_cache.rs b/limitador/src/storage/redis/counters_cache.rs index 35152537..dabe7004 100644 --- a/limitador/src/storage/redis/counters_cache.rs +++ b/limitador/src/storage/redis/counters_cache.rs @@ -677,6 +677,7 @@ mod tests { } Counter::new( Limit::new( + None, "test_namespace", max_val, 60, diff --git a/limitador/src/storage/redis/redis_cached.rs b/limitador/src/storage/redis/redis_cached.rs index 9a3ae681..a824846f 100644 --- a/limitador/src/storage/redis/redis_cached.rs +++ b/limitador/src/storage/redis/redis_cached.rs @@ -437,6 +437,7 @@ mod tests { let mut counters_and_deltas = HashMap::new(); let counter = Counter::new( Limit::new( + None, "test_namespace", 10, 60, @@ -499,6 +500,7 @@ mod tests { async fn flush_batcher_and_update_counters_test() { let counter = Counter::new( Limit::new( + None, "test_namespace", 10, 60, @@ -558,6 +560,7 @@ mod tests { async fn flush_batcher_reverts_on_err() { let counter = Counter::new( Limit::new( + None, "test_namespace", 10, 60, diff --git a/limitador/tests/integration_tests.rs b/limitador/tests/integration_tests.rs index da20237a..241c54df 100644 --- a/limitador/tests/integration_tests.rs +++ b/limitador/tests/integration_tests.rs @@ -216,6 +216,7 @@ mod test { async fn get_namespaces(rate_limiter: &mut TestsLimiter) { let limits = vec![ Limit::new( + None, "first_namespace", 10, 60, @@ -223,6 +224,7 @@ mod test { vec!["app_id"], ), Limit::new( + None, "second_namespace", 20, 60, @@ -248,6 +250,7 @@ mod test { rate_limiter: &mut TestsLimiter, ) { let lim1 = Limit::new( + None, "first_namespace", 10, 60, @@ -256,6 +259,7 @@ mod test { ); let lim2 = Limit::new( + None, "second_namespace", 20, 60, @@ -280,6 +284,7 @@ mod test { async fn add_a_limit(rate_limiter: &mut TestsLimiter) { let limit = Limit::new( + None, "test_namespace", 10, 60, @@ -300,6 +305,7 @@ mod test { async fn add_limit_without_vars(rate_limiter: &mut TestsLimiter) { let limit = Limit::new( + None, "test_namespace", 10, 60, @@ -322,6 +328,7 @@ mod test { let namespace = "test_namespace"; let limit_1 = Limit::new( + None, namespace, 10, 60, @@ -330,6 +337,7 @@ mod test { ); let limit_2 = Limit::new( + None, namespace, 5, 60, @@ -349,6 +357,7 @@ mod test { async fn delete_limit(rate_limiter: &mut TestsLimiter) { let limit = Limit::new( + None, "test_namespace", 10, 60, @@ -366,6 +375,7 @@ mod test { async fn delete_limit_also_deletes_associated_counters(rate_limiter: &mut TestsLimiter) { let namespace = "test_namespace"; let limit = Limit::new( + None, namespace, 10, 60, @@ -401,6 +411,7 @@ mod test { let limits = [ Limit::new( + None, namespace, 10, 60, @@ -408,6 +419,7 @@ mod test { vec!["app_id"], ), Limit::new( + None, namespace, 5, 60, @@ -433,6 +445,7 @@ mod test { rate_limiter .add_limit(&Limit::new( + None, namespace1, 10, 60, @@ -441,7 +454,14 @@ mod test { )) .await; rate_limiter - .add_limit(&Limit::new(namespace2, 5, 60, vec!["x == '10'"], vec!["z"])) + .add_limit(&Limit::new( + None, + namespace2, + 5, + 60, + vec!["x == '10'"], + vec!["z"], + )) .await; rate_limiter.delete_limits(namespace1).await.unwrap(); @@ -453,6 +473,7 @@ mod test { async fn delete_limits_of_a_namespace_also_deletes_counters(rate_limiter: &mut TestsLimiter) { let namespace = "test_namespace"; let limit = Limit::new( + None, namespace, 5, 60, @@ -487,6 +508,7 @@ mod test { let namespace = "test_namespace"; let max_hits = 3; let limit = Limit::new( + None, namespace, max_hits, 60, @@ -524,6 +546,7 @@ mod test { let max_hits = 3; let limits = vec![ Limit::new( + None, namespace, max_hits, 60, @@ -531,6 +554,7 @@ mod test { vec!["app_id"], ), Limit::new( + None, namespace, max_hits + 1, 60, @@ -592,6 +616,7 @@ mod test { async fn rate_limited_with_delta_higher_than_one(rate_limiter: &mut TestsLimiter) { let namespace = "test_namespace"; let limit = Limit::new( + None, namespace, 10, 60, @@ -627,6 +652,7 @@ mod test { let max = 10; let namespace = "test_namespace"; let limit = Limit::new( + None, namespace, max, 60, @@ -650,6 +676,7 @@ mod test { let namespace = "test_namespace"; let max_hits = 3; let limit = Limit::new( + None, namespace, max_hits, 60, @@ -704,6 +731,7 @@ mod test { let namespace = "test_namespace"; let limit = Limit::new( + None, namespace, 0, // So reporting 1 more would not be allowed 60, @@ -728,6 +756,7 @@ mod test { let namespace = "test_namespace"; let limit = Limit::new( + None, namespace, 0, // So reporting 1 more would not be allowed 60, @@ -751,6 +780,7 @@ mod test { let max_hits = 3; let limit = Limit::new( + None, namespace, max_hits, 60, @@ -788,6 +818,7 @@ mod test { let max_hits = 3; let limit = Limit::new( + None, namespace, max_hits, 60, @@ -838,6 +869,7 @@ mod test { let namespace = "test_namespace"; let limit = Limit::new( + None, namespace, 10, 60, @@ -867,6 +899,7 @@ mod test { let namespace = "test_namespace"; let limit = Limit::new( + None, namespace, 0, // So reporting 1 more would not be allowed 60, @@ -895,6 +928,7 @@ mod test { let hits_app_2 = 5; let limit = Limit::new( + None, namespace, max_hits, 60, @@ -952,6 +986,7 @@ mod test { // There's a limit, but no counters. The result should be empty. let limit = Limit::new( + None, "test_namespace", 10, 60, @@ -973,6 +1008,7 @@ mod test { let limit_time = 1; let limit = Limit::new( + None, namespace, 10, limit_time, @@ -998,6 +1034,7 @@ mod test { async fn configure_with_creates_the_given_limits(rate_limiter: &mut TestsLimiter) { let first_limit = Limit::new( + None, "first_namespace", 10, 60, @@ -1006,6 +1043,7 @@ mod test { ); let second_limit = Limit::new( + None, "second_namespace", 20, 60, @@ -1037,6 +1075,7 @@ mod test { let hits_to_report = 1; let limit = Limit::new( + None, namespace, max_value, 60, @@ -1076,6 +1115,7 @@ mod test { let namespace = "test_namespace"; let limit_to_be_kept = Limit::new( + None, namespace, 10, 1, @@ -1084,6 +1124,7 @@ mod test { ); let limit_to_be_deleted = Limit::new( + None, namespace, 20, 60, @@ -1110,6 +1151,7 @@ mod test { let namespace = "test_namespace"; let limit_orig = Limit::new( + None, namespace, 10, 60, @@ -1118,6 +1160,7 @@ mod test { ); let limit_update = Limit::new( + None, namespace, 20, 60, @@ -1142,6 +1185,7 @@ mod test { let namespace = "test_namespace"; let limit_1 = Limit::new( + None, namespace, 10, 60, @@ -1150,6 +1194,7 @@ mod test { ); let limit_2 = Limit::new( + None, namespace, 20, 60, @@ -1158,6 +1203,7 @@ mod test { ); let mut limit_3 = Limit::new( + None, namespace, 20, 60, @@ -1187,6 +1233,7 @@ mod test { let namespace = "test_namespace"; let max_hits = 3; let limit = Limit::new( + None, namespace, max_hits, 60,