diff --git a/limitador/src/storage/keys.rs b/limitador/src/storage/keys.rs index 8ad9a2cf..46b8292d 100644 --- a/limitador/src/storage/keys.rs +++ b/limitador/src/storage/keys.rs @@ -14,37 +14,56 @@ use crate::counter::Counter; use crate::limit::Limit; +use serde::{Deserialize, Serialize}; use std::sync::Arc; -pub fn key_for_counter(counter: &Counter) -> String { - if counter.remaining().is_some() || counter.expires_in().is_some() { - format!( - "{},counter:{}", - prefix_for_namespace(counter.namespace().as_ref()), - serde_json::to_string(&counter.key()).unwrap() - ) +pub fn key_for_counter(counter: &Counter) -> Vec { + if counter.id().is_none() { + // continue to use the legacy text encoding... + let namespace = counter.namespace().as_ref(); + let key = if counter.remaining().is_some() || counter.expires_in().is_some() { + format!( + "namespace:{{{namespace}}},counter:{}", + serde_json::to_string(&counter.key()).unwrap() + ) + } else { + format!( + "namespace:{{{namespace}}},counter:{}", + serde_json::to_string(counter).unwrap() + ) + }; + key.into_bytes() } else { - format!( - "{},counter:{}", - prefix_for_namespace(counter.namespace().as_ref()), - serde_json::to_string(counter).unwrap() - ) + // if the id is set, use the new binary encoding... + bin::key_for_counter_v2(counter) } } -pub fn key_for_counters_of_limit(limit: &Limit) -> String { - format!( - "namespace:{{{}}},counters_of_limit:{}", - limit.namespace().as_ref(), - serde_json::to_string(limit).unwrap() - ) -} +pub fn key_for_counters_of_limit(limit: &Limit) -> Vec { + if limit.id().is_none() { + let namespace = limit.namespace().as_ref(); + format!( + "namespace:{{{namespace}}},counters_of_limit:{}", + serde_json::to_string(limit).unwrap() + ) + .into_bytes() + } else { + #[derive(PartialEq, Debug, Serialize, Deserialize)] + struct IdLimitKey<'a> { + id: &'a str, + } + + let id = limit.id().as_ref().unwrap(); + let key = IdLimitKey { id: id.as_ref() }; -pub fn prefix_for_namespace(namespace: &str) -> String { - format!("namespace:{{{namespace}}},") + let mut encoded_key = Vec::new(); + encoded_key = postcard::to_extend(&2u8, encoded_key).unwrap(); + encoded_key = postcard::to_extend(&key, encoded_key).unwrap(); + encoded_key + } } -pub fn counter_from_counter_key(key: &str, limit: Arc) -> Counter { +pub fn counter_from_counter_key(key: &Vec, limit: Arc) -> Counter { let mut counter = partial_counter_from_counter_key(key); if !counter.update_to_limit(Arc::clone(&limit)) { // this means some kind of data corruption _or_ most probably @@ -58,33 +77,38 @@ pub fn counter_from_counter_key(key: &str, limit: Arc) -> Counter { counter } -pub fn partial_counter_from_counter_key(key: &str) -> Counter { - let namespace_prefix = "namespace:"; - let counter_prefix = ",counter:"; - - // Find the start position of the counter portion - let start_pos_namespace = key - .find(namespace_prefix) - .expect("Namespace not found in the key"); - let start_pos_counter = key[start_pos_namespace..] - .find(counter_prefix) - .expect("Counter not found in the key") - + start_pos_namespace - + counter_prefix.len(); - - // Extract counter JSON substring and deserialize it - let counter_str = &key[start_pos_counter..]; - let counter: Counter = - serde_json::from_str(counter_str).expect("Failed to deserialize counter JSON"); - counter +pub fn partial_counter_from_counter_key(key: &Vec) -> Counter { + if key.starts_with(b"namespace:") { + let key = String::from_utf8_lossy(key.as_ref()); + + // It's using to the legacy text encoding... + let namespace_prefix = "namespace:"; + let counter_prefix = ",counter:"; + + // Find the start position of the counter portion + let start_pos_namespace = key + .find(namespace_prefix) + .expect("Namespace not found in the key"); + let start_pos_counter = key[start_pos_namespace..] + .find(counter_prefix) + .expect("Counter not found in the key") + + start_pos_namespace + + counter_prefix.len(); + + // Extract counter JSON substring and deserialize it + let counter_str = &key[start_pos_counter..]; + let counter: Counter = + serde_json::from_str(counter_str).expect("Failed to deserialize counter JSON"); + counter + } else { + // It's using to the new binary encoding... + bin::partial_counter_from_counter_key_v2(key) + } } #[cfg(test)] mod tests { - use super::{ - key_for_counter, key_for_counters_of_limit, partial_counter_from_counter_key, - prefix_for_namespace, - }; + use super::{key_for_counter, key_for_counters_of_limit, partial_counter_from_counter_key}; use crate::counter::Counter; use crate::Limit; use std::collections::HashMap; @@ -100,7 +124,7 @@ mod tests { vec!["app_id"], ); assert_eq!( - "namespace:{example.com},counters_of_limit:{\"namespace\":\"example.com\",\"seconds\":60,\"conditions\":[\"req.method == \\\"GET\\\"\"],\"variables\":[\"app_id\"]}", + "namespace:{example.com},counters_of_limit:{\"namespace\":\"example.com\",\"seconds\":60,\"conditions\":[\"req.method == \\\"GET\\\"\"],\"variables\":[\"app_id\"]}".as_bytes(), key_for_counters_of_limit(&limit)) } @@ -111,8 +135,6 @@ mod tests { let counter = Counter::new(limit.clone(), HashMap::default()); let raw = key_for_counter(&counter); assert_eq!(counter, partial_counter_from_counter_key(&raw)); - let prefix = prefix_for_namespace(namespace); - assert_eq!(&raw[0..prefix.len()], &prefix); } #[test] @@ -204,18 +226,52 @@ pub mod bin { pub fn key_for_counter_v2(counter: &Counter) -> Vec { let mut encoded_key = Vec::new(); - if counter.id().is_some() { - let key: IdCounterKey = counter.into(); + if counter.id().is_none() { + let key: CounterKey = counter.into(); encoded_key = postcard::to_extend(&1u8, encoded_key).unwrap(); - encoded_key = postcard::to_extend(&key, encoded_key).unwrap(); + encoded_key = postcard::to_extend(&key, encoded_key).unwrap() } else { - let key: CounterKey = counter.into(); + let key: IdCounterKey = counter.into(); encoded_key = postcard::to_extend(&2u8, encoded_key).unwrap(); - encoded_key = postcard::to_extend(&key, encoded_key).unwrap() + encoded_key = postcard::to_extend(&key, encoded_key).unwrap(); } encoded_key } + pub fn partial_counter_from_counter_key_v2(key: &[u8]) -> Counter { + let (version, key) = postcard::take_from_bytes::(key).unwrap(); + match version { + 1u8 => { + let CounterKey { + ns, + seconds, + conditions, + variables, + } = postcard::from_bytes(key).unwrap(); + + let map: HashMap = variables + .into_iter() + .map(|(var, value)| (var.to_string(), value.to_string())) + .collect(); + let limit = Limit::new(ns, u64::default(), seconds, conditions, map.keys()); + Counter::new(limit, map) + } + 2u8 => { + let IdCounterKey { id, variables } = postcard::from_bytes(key).unwrap(); + let map: HashMap = variables + .into_iter() + .map(|(var, value)| (var.to_string(), value.to_string())) + .collect(); + + // we are not able to rebuild the full limit since we only have the id and variables. + let mut limit = Limit::new::<&str, &str>("", u64::default(), 0, vec![], map.keys()); + limit.set_id(id.to_string()); + Counter::new(limit, map) + } + _ => panic!("Unknown version: {}", version), + } + } + pub fn key_for_counter(counter: &Counter) -> Vec { let key: CounterKey = counter.into(); postcard::to_stdvec(&key).unwrap() diff --git a/limitador/src/storage/redis/redis_async.rs b/limitador/src/storage/redis/redis_async.rs index d29e7b3a..4e5aca48 100644 --- a/limitador/src/storage/redis/redis_async.rs +++ b/limitador/src/storage/redis/redis_async.rs @@ -38,7 +38,7 @@ impl AsyncCounterStorage for AsyncRedisStorage { let mut con = self.conn_manager.clone(); match con - .get::>(key_for_counter(counter)) + .get::, Option>(key_for_counter(counter)) .instrument(debug_span!("datastore")) .await? { @@ -71,7 +71,7 @@ impl AsyncCounterStorage for AsyncRedisStorage { load_counters: bool, ) -> Result { let mut con = self.conn_manager.clone(); - let counter_keys: Vec = counters.iter().map(key_for_counter).collect(); + let counter_keys: Vec> = counters.iter().map(key_for_counter).collect(); if load_counters { let script = redis::Script::new(VALUES_AND_TTLS); @@ -139,7 +139,7 @@ impl AsyncCounterStorage for AsyncRedisStorage { for limit in limits { let counter_keys = { - con.smembers::>(key_for_counters_of_limit(limit)) + con.smembers::, HashSet>>(key_for_counters_of_limit(limit)) .instrument(debug_span!("datastore")) .await? }; @@ -156,7 +156,7 @@ impl AsyncCounterStorage for AsyncRedisStorage { // This does not cause any bugs, but consumes memory // unnecessarily. let option = { - con.get::>(counter_key.clone()) + con.get::, Option>(counter_key.clone()) .instrument(debug_span!("datastore")) .await? }; @@ -218,7 +218,7 @@ impl AsyncRedisStorage { let mut con = self.conn_manager.clone(); let counter_keys = { - con.smembers::>(key_for_counters_of_limit(limit)) + con.smembers::, HashSet>>(key_for_counters_of_limit(limit)) .instrument(debug_span!("datastore")) .await? }; diff --git a/limitador/src/storage/redis/redis_sync.rs b/limitador/src/storage/redis/redis_sync.rs index 81eb3f11..66ad9abc 100644 --- a/limitador/src/storage/redis/redis_sync.rs +++ b/limitador/src/storage/redis/redis_sync.rs @@ -29,7 +29,7 @@ impl CounterStorage for RedisStorage { fn is_within_limits(&self, counter: &Counter, delta: u64) -> Result { let mut con = self.conn_pool.get()?; - match con.get::>(key_for_counter(counter))? { + match con.get::, Option>(key_for_counter(counter))? { Some(val) => Ok(u64::try_from(val).unwrap_or(0) + delta <= counter.max_value()), None => Ok(counter.max_value().checked_sub(delta).is_some()), } @@ -62,7 +62,7 @@ impl CounterStorage for RedisStorage { load_counters: bool, ) -> Result { let mut con = self.conn_pool.get()?; - let counter_keys: Vec = counters.iter().map(key_for_counter).collect(); + let counter_keys: Vec> = counters.iter().map(key_for_counter).collect(); if load_counters { let script = redis::Script::new(VALUES_AND_TTLS); @@ -115,7 +115,7 @@ impl CounterStorage for RedisStorage { for limit in limits { let counter_keys = - con.smembers::>(key_for_counters_of_limit(limit))?; + con.smembers::, HashSet>>(key_for_counters_of_limit(limit))?; for counter_key in counter_keys { let mut counter: Counter = @@ -128,7 +128,7 @@ impl CounterStorage for RedisStorage { // do the "get" + "delete if none" atomically. // This does not cause any bugs, but consumes memory // unnecessarily. - if let Some(val) = con.get::>(counter_key.clone())? { + if let Some(val) = con.get::, Option>(counter_key.clone())? { counter.set_remaining( limit .max_value() @@ -150,8 +150,8 @@ impl CounterStorage for RedisStorage { let mut con = self.conn_pool.get()?; for limit in limits { - let counter_keys = - con.smembers::>(key_for_counters_of_limit(limit.deref()))?; + let counter_keys = con + .smembers::, HashSet>>(key_for_counters_of_limit(limit.deref()))?; for counter_key in counter_keys { con.del(counter_key)?;