diff --git a/limitador/src/limit.rs b/limitador/src/limit.rs index 7fe15fd2..dc77f159 100644 --- a/limitador/src/limit.rs +++ b/limitador/src/limit.rs @@ -1,5 +1,5 @@ use crate::limit::conditions::{ErrorType, Literal, SyntaxError, Token, TokenType}; -use serde::{Deserialize, Serialize, Serializer}; +use serde::{Deserialize, Serialize}; use std::cmp::Ordering; use std::collections::{BTreeSet, HashMap, HashSet}; use std::error::Error; @@ -28,7 +28,7 @@ mod deprecated { #[cfg(feature = "lenient_conditions")] pub use deprecated::check_deprecated_syntax_usages_and_reset; -#[derive(Debug, Hash, Eq, PartialEq, Clone, Serialize, Deserialize)] +#[derive(Debug, Hash, Eq, PartialEq, Clone, PartialOrd, Ord, Serialize, Deserialize)] pub struct Namespace(String); impl From<&str> for Namespace { @@ -49,7 +49,7 @@ impl From for Namespace { } } -#[derive(Eq, Debug, Clone, Serialize, Deserialize)] +#[derive(Eq, Debug, Clone, PartialOrd, Ord, Serialize, Deserialize)] pub struct Limit { #[serde(skip_serializing, default)] id: Option, @@ -62,13 +62,11 @@ pub struct Limit { // Need to sort to generate the same object when using the JSON as a key or // value in Redis. - #[serde(serialize_with = "ordered_condition_set")] - conditions: HashSet, - #[serde(serialize_with = "ordered_set")] - variables: HashSet, + conditions: BTreeSet, + variables: BTreeSet, } -#[derive(Deserialize, Serialize, PartialEq, Eq, Debug, Clone, Hash)] +#[derive(Deserialize, Serialize, PartialEq, Eq, Debug, Clone, Hash, PartialOrd, Ord)] #[serde(try_from = "String", into = "String")] pub struct Condition { var_name: String, @@ -267,7 +265,7 @@ impl From for String { } } -#[derive(PartialEq, Eq, Debug, Clone, Hash)] +#[derive(PartialEq, Eq, PartialOrd, Ord, Debug, Clone, Hash)] pub enum Predicate { Equal, NotEqual, @@ -291,22 +289,6 @@ impl From for String { } } -fn ordered_condition_set(value: &HashSet, serializer: S) -> Result -where - S: Serializer, -{ - let ordered: BTreeSet = value.iter().map(|c| c.clone().into()).collect(); - ordered.serialize(serializer) -} - -fn ordered_set(value: &HashSet, serializer: S) -> Result -where - S: Serializer, -{ - let ordered: BTreeSet<_> = value.iter().collect(); - ordered.serialize(serializer) -} - impl Limit { pub fn new, T: TryInto>( namespace: N, diff --git a/limitador/src/storage/in_memory.rs b/limitador/src/storage/in_memory.rs index 2fedb7e8..b01eeb88 100644 --- a/limitador/src/storage/in_memory.rs +++ b/limitador/src/storage/in_memory.rs @@ -3,35 +3,32 @@ use crate::limit::{Limit, Namespace}; use crate::storage::atomic_expiring_value::AtomicExpiringValue; use crate::storage::{Authorization, CounterStorage, StorageErr}; use moka::sync::Cache; -use std::collections::hash_map::Entry; -use std::collections::{HashMap, HashSet}; +use std::collections::btree_map::Entry; +use std::collections::{BTreeMap, HashMap, HashSet}; use std::ops::Deref; use std::sync::{Arc, RwLock}; use std::time::{Duration, SystemTime}; -type NamespacedLimitCounters = HashMap>; - pub struct InMemoryStorage { - limits_for_namespace: RwLock>, + simple_limits: RwLock>, qualified_counters: Cache>, } impl CounterStorage for InMemoryStorage { #[tracing::instrument(skip_all)] fn is_within_limits(&self, counter: &Counter, delta: u64) -> Result { - let limits_by_namespace = self.limits_for_namespace.read().unwrap(); - - let mut value = 0; - - if counter.is_qualified() { - if let Some(counter) = self.qualified_counters.get(counter) { - value = counter.value(); - } - } else if let Some(limits) = limits_by_namespace.get(counter.limit().namespace()) { - if let Some(counter) = limits.get(counter.limit()) { - value = counter.value(); - } - } + let value = if counter.is_qualified() { + self.qualified_counters + .get(counter) + .map(|c| c.value()) + .unwrap_or_default() + } else { + let limits_by_namespace = self.simple_limits.read().unwrap(); + limits_by_namespace + .get(counter.limit()) + .map(|c| c.value()) + .unwrap_or_default() + }; Ok(counter.max_value() >= value + delta) } @@ -39,19 +36,15 @@ impl CounterStorage for InMemoryStorage { #[tracing::instrument(skip_all)] fn add_counter(&self, limit: &Limit) -> Result<(), StorageErr> { if limit.variables().is_empty() { - let mut limits_by_namespace = self.limits_for_namespace.write().unwrap(); - limits_by_namespace - .entry(limit.namespace().clone()) - .or_default() - .entry(limit.clone()) - .or_default(); + let mut limits_by_namespace = self.simple_limits.write().unwrap(); + limits_by_namespace.entry(limit.clone()).or_default(); } Ok(()) } #[tracing::instrument(skip_all)] fn update_counter(&self, counter: &Counter, delta: u64) -> Result<(), StorageErr> { - let mut limits_by_namespace = self.limits_for_namespace.write().unwrap(); + let mut counters = self.simple_limits.write().unwrap(); let now = SystemTime::now(); if counter.is_qualified() { let value = match self.qualified_counters.get(counter) { @@ -62,23 +55,13 @@ impl CounterStorage for InMemoryStorage { }; value.update(delta, counter.window(), now); } else { - match limits_by_namespace.entry(counter.limit().namespace().clone()) { + match counters.entry(counter.limit().clone()) { Entry::Vacant(v) => { - let mut limits = HashMap::new(); - limits.insert( - counter.limit().clone(), - AtomicExpiringValue::new(delta, now + counter.window()), - ); - v.insert(limits); + v.insert(AtomicExpiringValue::new(delta, now + counter.window())); + } + Entry::Occupied(o) => { + o.get().update(delta, counter.window(), now); } - Entry::Occupied(mut o) => match o.get_mut().entry(counter.limit().clone()) { - Entry::Vacant(v) => { - v.insert(AtomicExpiringValue::new(delta, now + counter.window())); - } - Entry::Occupied(o) => { - o.get().update(delta, counter.window(), now); - } - }, } } Ok(()) @@ -91,7 +74,7 @@ impl CounterStorage for InMemoryStorage { delta: u64, load_counters: bool, ) -> Result { - let limits_by_namespace = self.limits_for_namespace.read().unwrap(); + let limits_by_namespace = self.simple_limits.read().unwrap(); let mut first_limited = None; let mut counter_values_to_update: Vec<(&AtomicExpiringValue, Duration)> = Vec::new(); let mut qualified_counter_values_to_updated: Vec<(Arc, Duration)> = @@ -119,10 +102,8 @@ impl CounterStorage for InMemoryStorage { // Process simple counters for counter in counters.iter_mut().filter(|c| !c.is_qualified()) { - let atomic_expiring_value: &AtomicExpiringValue = limits_by_namespace - .get(counter.limit().namespace()) - .and_then(|limits| limits.get(counter.limit())) - .unwrap(); + let atomic_expiring_value: &AtomicExpiringValue = + limits_by_namespace.get(counter.limit()).unwrap(); if let Some(limited) = process_counter(counter, atomic_expiring_value.value(), delta) { if !load_counters { @@ -135,7 +116,7 @@ impl CounterStorage for InMemoryStorage { // Process qualified counters for counter in counters.iter_mut().filter(|c| c.is_qualified()) { let value = match self.qualified_counters.get(counter) { - None => self.qualified_counters.get_with(counter.clone(), || { + None => self.qualified_counters.get_with_by_ref(counter, || { Arc::new(AtomicExpiringValue::new(0, now + counter.window())) }), Some(counter) => counter, @@ -171,24 +152,14 @@ impl CounterStorage for InMemoryStorage { fn get_counters(&self, limits: &HashSet>) -> Result, StorageErr> { let mut res = HashSet::new(); - let namespaces: HashSet<&Namespace> = limits.iter().map(|l| l.namespace()).collect(); - let limits_by_namespace = self.limits_for_namespace.read().unwrap(); - - for namespace in namespaces { - if let Some(limits) = limits_by_namespace.get(namespace) { - for limit in limits.keys() { - if limits.contains_key(limit) { - for (counter, expiring_value) in self.counters_in_namespace(namespace) { - let mut counter_with_val = counter.clone(); - counter_with_val.set_remaining( - counter_with_val.max_value() - expiring_value.value(), - ); - counter_with_val.set_expires_in(expiring_value.ttl()); - if counter_with_val.expires_in().unwrap() > Duration::ZERO { - res.insert(counter_with_val); - } - } - } + for limit in limits { + for (counter, expiring_value) in self.counters_in_namespace(limit.namespace()) { + let mut counter_with_val = counter.clone(); + counter_with_val + .set_remaining(counter_with_val.max_value() - expiring_value.value()); + counter_with_val.set_expires_in(expiring_value.ttl()); + if counter_with_val.expires_in().unwrap() > Duration::ZERO { + res.insert(counter_with_val); } } } @@ -218,7 +189,7 @@ impl CounterStorage for InMemoryStorage { #[tracing::instrument(skip_all)] fn clear(&self) -> Result<(), StorageErr> { - self.limits_for_namespace.write().unwrap().clear(); + self.simple_limits.write().unwrap().clear(); Ok(()) } } @@ -226,7 +197,7 @@ impl CounterStorage for InMemoryStorage { impl InMemoryStorage { pub fn new(cache_size: u64) -> Self { Self { - limits_for_namespace: RwLock::new(HashMap::new()), + simple_limits: RwLock::new(BTreeMap::new()), qualified_counters: Cache::new(cache_size), } } @@ -237,11 +208,11 @@ impl InMemoryStorage { ) -> HashMap { let mut res: HashMap = HashMap::new(); - if let Some(counters_by_limit) = self.limits_for_namespace.read().unwrap().get(namespace) { - for (limit, value) in counters_by_limit { + for (limit, counter) in self.simple_limits.read().unwrap().iter() { + if limit.namespace() == namespace { res.insert( Counter::new(limit.clone(), HashMap::default()), - value.clone(), + counter.clone(), ); } } @@ -256,14 +227,7 @@ impl InMemoryStorage { } fn delete_counters_of_limit(&self, limit: &Limit) { - if let Some(counters_by_limit) = self - .limits_for_namespace - .write() - .unwrap() - .get_mut(limit.namespace()) - { - counters_by_limit.remove(limit); - } + self.simple_limits.write().unwrap().remove(limit); } fn counter_is_within_limits(counter: &Counter, current_val: Option<&u64>, delta: u64) -> bool {