Skip to content

Commit

Permalink
Allow Redis to use the shorter binary key encodings
Browse files Browse the repository at this point in the history
Use the shorter binary encoding for limits/counters when the id is set, continue to use the older text encoding for backward compatibility when the id is not set.


Signed-off-by: Hiram Chirino <[email protected]>
  • Loading branch information
chirino committed Jul 10, 2024
1 parent 70ebd0e commit f5c9545
Show file tree
Hide file tree
Showing 3 changed files with 120 additions and 64 deletions.
162 changes: 109 additions & 53 deletions limitador/src/storage/keys.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,37 +14,56 @@

use crate::counter::Counter;
use crate::limit::Limit;
use serde::{Deserialize, Serialize};
use std::sync::Arc;

pub fn key_for_counter(counter: &Counter) -> String {
if counter.remaining().is_some() || counter.expires_in().is_some() {
format!(
"{},counter:{}",
prefix_for_namespace(counter.namespace().as_ref()),
serde_json::to_string(&counter.key()).unwrap()
)
pub fn key_for_counter(counter: &Counter) -> Vec<u8> {
if counter.id().is_none() {
// continue to use the legacy text encoding...
let namespace = counter.namespace().as_ref();
let key = if counter.remaining().is_some() || counter.expires_in().is_some() {
format!(
"namespace:{{{namespace}}},counter:{}",
serde_json::to_string(&counter.key()).unwrap()
)
} else {
format!(
"namespace:{{{namespace}}},counter:{}",
serde_json::to_string(counter).unwrap()
)
};
key.into_bytes()
} else {
format!(
"{},counter:{}",
prefix_for_namespace(counter.namespace().as_ref()),
serde_json::to_string(counter).unwrap()
)
// if the id is set, use the new binary encoding...
bin::key_for_counter_v2(counter)
}
}

pub fn key_for_counters_of_limit(limit: &Limit) -> String {
format!(
"namespace:{{{}}},counters_of_limit:{}",
limit.namespace().as_ref(),
serde_json::to_string(limit).unwrap()
)
}
pub fn key_for_counters_of_limit(limit: &Limit) -> Vec<u8> {
if limit.id().is_none() {
let namespace = limit.namespace().as_ref();
format!(
"namespace:{{{namespace}}},counters_of_limit:{}",
serde_json::to_string(limit).unwrap()
)
.into_bytes()
} else {
#[derive(PartialEq, Debug, Serialize, Deserialize)]
struct IdLimitKey<'a> {
id: &'a str,
}

let id = limit.id().as_ref().unwrap();
let key = IdLimitKey { id: id.as_ref() };

pub fn prefix_for_namespace(namespace: &str) -> String {
format!("namespace:{{{namespace}}},")
let mut encoded_key = Vec::new();
encoded_key = postcard::to_extend(&2u8, encoded_key).unwrap();
encoded_key = postcard::to_extend(&key, encoded_key).unwrap();
encoded_key
}
}

pub fn counter_from_counter_key(key: &str, limit: Arc<Limit>) -> Counter {
pub fn counter_from_counter_key(key: &Vec<u8>, limit: Arc<Limit>) -> Counter {
let mut counter = partial_counter_from_counter_key(key);
if !counter.update_to_limit(Arc::clone(&limit)) {
// this means some kind of data corruption _or_ most probably
Expand All @@ -58,33 +77,38 @@ pub fn counter_from_counter_key(key: &str, limit: Arc<Limit>) -> Counter {
counter
}

pub fn partial_counter_from_counter_key(key: &str) -> Counter {
let namespace_prefix = "namespace:";
let counter_prefix = ",counter:";

// Find the start position of the counter portion
let start_pos_namespace = key
.find(namespace_prefix)
.expect("Namespace not found in the key");
let start_pos_counter = key[start_pos_namespace..]
.find(counter_prefix)
.expect("Counter not found in the key")
+ start_pos_namespace
+ counter_prefix.len();

// Extract counter JSON substring and deserialize it
let counter_str = &key[start_pos_counter..];
let counter: Counter =
serde_json::from_str(counter_str).expect("Failed to deserialize counter JSON");
counter
pub fn partial_counter_from_counter_key(key: &Vec<u8>) -> Counter {
if key.starts_with(b"namespace:") {
let key = String::from_utf8_lossy(key.as_ref());

// It's using to the legacy text encoding...
let namespace_prefix = "namespace:";
let counter_prefix = ",counter:";

// Find the start position of the counter portion
let start_pos_namespace = key
.find(namespace_prefix)
.expect("Namespace not found in the key");
let start_pos_counter = key[start_pos_namespace..]
.find(counter_prefix)
.expect("Counter not found in the key")
+ start_pos_namespace
+ counter_prefix.len();

// Extract counter JSON substring and deserialize it
let counter_str = &key[start_pos_counter..];
let counter: Counter =
serde_json::from_str(counter_str).expect("Failed to deserialize counter JSON");
counter
} else {
// It's using to the new binary encoding...
bin::partial_counter_from_counter_key_v2(key)
}
}

#[cfg(test)]
mod tests {
use super::{
key_for_counter, key_for_counters_of_limit, partial_counter_from_counter_key,
prefix_for_namespace,
};
use super::{key_for_counter, key_for_counters_of_limit, partial_counter_from_counter_key};
use crate::counter::Counter;
use crate::Limit;
use std::collections::HashMap;
Expand All @@ -100,7 +124,7 @@ mod tests {
vec!["app_id"],
);
assert_eq!(
"namespace:{example.com},counters_of_limit:{\"namespace\":\"example.com\",\"seconds\":60,\"conditions\":[\"req.method == \\\"GET\\\"\"],\"variables\":[\"app_id\"]}",
"namespace:{example.com},counters_of_limit:{\"namespace\":\"example.com\",\"seconds\":60,\"conditions\":[\"req.method == \\\"GET\\\"\"],\"variables\":[\"app_id\"]}".as_bytes(),
key_for_counters_of_limit(&limit))
}

Expand All @@ -111,8 +135,6 @@ mod tests {
let counter = Counter::new(limit.clone(), HashMap::default());
let raw = key_for_counter(&counter);
assert_eq!(counter, partial_counter_from_counter_key(&raw));
let prefix = prefix_for_namespace(namespace);
assert_eq!(&raw[0..prefix.len()], &prefix);
}

#[test]
Expand Down Expand Up @@ -204,18 +226,52 @@ pub mod bin {

pub fn key_for_counter_v2(counter: &Counter) -> Vec<u8> {
let mut encoded_key = Vec::new();
if counter.id().is_some() {
let key: IdCounterKey = counter.into();
if counter.id().is_none() {
let key: CounterKey = counter.into();
encoded_key = postcard::to_extend(&1u8, encoded_key).unwrap();
encoded_key = postcard::to_extend(&key, encoded_key).unwrap();
encoded_key = postcard::to_extend(&key, encoded_key).unwrap()
} else {
let key: CounterKey = counter.into();
let key: IdCounterKey = counter.into();
encoded_key = postcard::to_extend(&2u8, encoded_key).unwrap();
encoded_key = postcard::to_extend(&key, encoded_key).unwrap()
encoded_key = postcard::to_extend(&key, encoded_key).unwrap();
}
encoded_key
}

pub fn partial_counter_from_counter_key_v2(key: &[u8]) -> Counter {
let (version, key) = postcard::take_from_bytes::<u8>(key).unwrap();
match version {
1u8 => {
let CounterKey {
ns,
seconds,
conditions,
variables,
} = postcard::from_bytes(key).unwrap();

let map: HashMap<String, String> = variables
.into_iter()
.map(|(var, value)| (var.to_string(), value.to_string()))
.collect();
let limit = Limit::new(ns, u64::default(), seconds, conditions, map.keys());
Counter::new(limit, map)
}
2u8 => {
let IdCounterKey { id, variables } = postcard::from_bytes(key).unwrap();
let map: HashMap<String, String> = variables
.into_iter()
.map(|(var, value)| (var.to_string(), value.to_string()))
.collect();

// we are not able to rebuild the full limit since we only have the id and variables.
let mut limit = Limit::new::<&str, &str>("", u64::default(), 0, vec![], map.keys());
limit.set_id(id.to_string());
Counter::new(limit, map)
}
_ => panic!("Unknown version: {}", version),
}
}

pub fn key_for_counter(counter: &Counter) -> Vec<u8> {
let key: CounterKey = counter.into();
postcard::to_stdvec(&key).unwrap()
Expand Down
10 changes: 5 additions & 5 deletions limitador/src/storage/redis/redis_async.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ impl AsyncCounterStorage for AsyncRedisStorage {
let mut con = self.conn_manager.clone();

match con
.get::<String, Option<i64>>(key_for_counter(counter))
.get::<Vec<u8>, Option<i64>>(key_for_counter(counter))
.instrument(debug_span!("datastore"))
.await?
{
Expand Down Expand Up @@ -71,7 +71,7 @@ impl AsyncCounterStorage for AsyncRedisStorage {
load_counters: bool,
) -> Result<Authorization, StorageErr> {
let mut con = self.conn_manager.clone();
let counter_keys: Vec<String> = counters.iter().map(key_for_counter).collect();
let counter_keys: Vec<Vec<u8>> = counters.iter().map(key_for_counter).collect();

if load_counters {
let script = redis::Script::new(VALUES_AND_TTLS);
Expand Down Expand Up @@ -139,7 +139,7 @@ impl AsyncCounterStorage for AsyncRedisStorage {

for limit in limits {
let counter_keys = {
con.smembers::<String, HashSet<String>>(key_for_counters_of_limit(limit))
con.smembers::<Vec<u8>, HashSet<Vec<u8>>>(key_for_counters_of_limit(limit))
.instrument(debug_span!("datastore"))
.await?
};
Expand All @@ -156,7 +156,7 @@ impl AsyncCounterStorage for AsyncRedisStorage {
// This does not cause any bugs, but consumes memory
// unnecessarily.
let option = {
con.get::<String, Option<i64>>(counter_key.clone())
con.get::<Vec<u8>, Option<i64>>(counter_key.clone())
.instrument(debug_span!("datastore"))
.await?
};
Expand Down Expand Up @@ -218,7 +218,7 @@ impl AsyncRedisStorage {
let mut con = self.conn_manager.clone();

let counter_keys = {
con.smembers::<String, HashSet<String>>(key_for_counters_of_limit(limit))
con.smembers::<Vec<u8>, HashSet<Vec<u8>>>(key_for_counters_of_limit(limit))
.instrument(debug_span!("datastore"))
.await?
};
Expand Down
12 changes: 6 additions & 6 deletions limitador/src/storage/redis/redis_sync.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ impl CounterStorage for RedisStorage {
fn is_within_limits(&self, counter: &Counter, delta: u64) -> Result<bool, StorageErr> {
let mut con = self.conn_pool.get()?;

match con.get::<String, Option<i64>>(key_for_counter(counter))? {
match con.get::<Vec<u8>, Option<i64>>(key_for_counter(counter))? {
Some(val) => Ok(u64::try_from(val).unwrap_or(0) + delta <= counter.max_value()),
None => Ok(counter.max_value().checked_sub(delta).is_some()),
}
Expand Down Expand Up @@ -62,7 +62,7 @@ impl CounterStorage for RedisStorage {
load_counters: bool,
) -> Result<Authorization, StorageErr> {
let mut con = self.conn_pool.get()?;
let counter_keys: Vec<String> = counters.iter().map(key_for_counter).collect();
let counter_keys: Vec<Vec<u8>> = counters.iter().map(key_for_counter).collect();

if load_counters {
let script = redis::Script::new(VALUES_AND_TTLS);
Expand Down Expand Up @@ -115,7 +115,7 @@ impl CounterStorage for RedisStorage {

for limit in limits {
let counter_keys =
con.smembers::<String, HashSet<String>>(key_for_counters_of_limit(limit))?;
con.smembers::<Vec<u8>, HashSet<Vec<u8>>>(key_for_counters_of_limit(limit))?;

for counter_key in counter_keys {
let mut counter: Counter =
Expand All @@ -128,7 +128,7 @@ impl CounterStorage for RedisStorage {
// do the "get" + "delete if none" atomically.
// This does not cause any bugs, but consumes memory
// unnecessarily.
if let Some(val) = con.get::<String, Option<i64>>(counter_key.clone())? {
if let Some(val) = con.get::<Vec<u8>, Option<i64>>(counter_key.clone())? {
counter.set_remaining(
limit
.max_value()
Expand All @@ -150,8 +150,8 @@ impl CounterStorage for RedisStorage {
let mut con = self.conn_pool.get()?;

for limit in limits {
let counter_keys =
con.smembers::<String, HashSet<String>>(key_for_counters_of_limit(limit.deref()))?;
let counter_keys = con
.smembers::<Vec<u8>, HashSet<Vec<u8>>>(key_for_counters_of_limit(limit.deref()))?;

for counter_key in counter_keys {
con.del(counter_key)?;
Expand Down

0 comments on commit f5c9545

Please sign in to comment.