Skip to content

Commit

Permalink
Expose Context all the way up
Browse files Browse the repository at this point in the history
Signed-off-by: Alex Snaps <[email protected]>
  • Loading branch information
alexsnaps committed Dec 3, 2024
1 parent 345be88 commit e0ae4f6
Show file tree
Hide file tree
Showing 14 changed files with 192 additions and 196 deletions.
6 changes: 4 additions & 2 deletions limitador-server/src/envoy_rls/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,18 +109,20 @@ impl RateLimitService for MyRateLimiter {
req.hits_addend
};

let ctx = (&values).into();

let rate_limited_resp = match &*self.limiter {
Limiter::Blocking(limiter) => limiter.check_rate_limited_and_update(
&namespace,
&values,
&ctx,
u64::from(hits_addend),
self.rate_limit_headers != RateLimitHeaders::None,
),
Limiter::Async(limiter) => {
limiter
.check_rate_limited_and_update(
&namespace,
&values,
&ctx,
u64::from(hits_addend),
self.rate_limit_headers != RateLimitHeaders::None,
)
Expand Down
20 changes: 9 additions & 11 deletions limitador-server/src/http_api/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,9 +122,10 @@ async fn check(
response_headers: _,
} = request.into_inner();
let namespace = namespace.into();
let ctx = (&values).into();
let is_rate_limited_result = match state.get_ref().limiter() {
Limiter::Blocking(limiter) => limiter.is_rate_limited(&namespace, &values, delta),
Limiter::Async(limiter) => limiter.is_rate_limited(&namespace, &values, delta).await,
Limiter::Blocking(limiter) => limiter.is_rate_limited(&namespace, &ctx, delta),
Limiter::Async(limiter) => limiter.is_rate_limited(&namespace, &ctx, delta).await,
};

match is_rate_limited_result {
Expand Down Expand Up @@ -152,9 +153,10 @@ async fn report(
response_headers: _,
} = request.into_inner();
let namespace = namespace.into();
let ctx = (&values).into();
let update_counters_result = match data.get_ref().limiter() {
Limiter::Blocking(limiter) => limiter.update_counters(&namespace, &values, delta),
Limiter::Async(limiter) => limiter.update_counters(&namespace, &values, delta).await,
Limiter::Blocking(limiter) => limiter.update_counters(&namespace, &ctx, delta),
Limiter::Async(limiter) => limiter.update_counters(&namespace, &ctx, delta).await,
};

match update_counters_result {
Expand All @@ -176,22 +178,18 @@ async fn check_and_report(
response_headers,
} = request.into_inner();
let namespace = namespace.into();
let ctx = (&values).into();
let rate_limit_data = data.get_ref();
let rate_limited_and_update_result = match rate_limit_data.limiter() {
Limiter::Blocking(limiter) => limiter.check_rate_limited_and_update(
&namespace,
&values,
&ctx,
delta,
response_headers.is_some(),
),
Limiter::Async(limiter) => {
limiter
.check_rate_limited_and_update(
&namespace,
&values,
delta,
response_headers.is_some(),
)
.check_rate_limited_and_update(&namespace, &ctx, delta, response_headers.is_some())
.await
}
};
Expand Down
12 changes: 6 additions & 6 deletions limitador/benches/bench.rs
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,7 @@ fn bench_is_rate_limited(
rate_limiter
.is_rate_limited(
&params.namespace.to_owned().into(),
&params.values,
&(&params.values).into(),
params.delta,
)
.unwrap(),
Expand Down Expand Up @@ -357,7 +357,7 @@ fn async_bench_is_rate_limited<F>(
rate_limiter
.is_rate_limited(
&params.namespace.to_owned().into(),
&params.values,
&(&params.values).into(),
params.delta,
)
.await
Expand All @@ -383,7 +383,7 @@ fn bench_update_counters(
rate_limiter
.update_counters(
&params.namespace.to_owned().into(),
&params.values,
&(&params.values).into(),
params.delta,
)
.unwrap();
Expand All @@ -410,7 +410,7 @@ fn async_bench_update_counters<F>(
rate_limiter
.update_counters(
&params.namespace.to_owned().into(),
&params.values,
&(&params.values).into(),
params.delta,
)
.await
Expand All @@ -437,7 +437,7 @@ fn bench_check_rate_limited_and_update(
rate_limiter
.check_rate_limited_and_update(
&params.namespace.to_owned().into(),
&params.values,
&(&params.values).into(),
params.delta,
false,
)
Expand Down Expand Up @@ -467,7 +467,7 @@ fn async_bench_check_rate_limited_and_update<F>(
rate_limiter
.check_rate_limited_and_update(
&params.namespace.to_owned().into(),
&params.values,
&(&params.values).into(),
params.delta,
false,
)
Expand Down
20 changes: 6 additions & 14 deletions limitador/src/counter.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::limit::{Limit, Namespace};
use crate::limit::{Context, Limit, Namespace};
use crate::LimitadorResult;
use serde::{Deserialize, Serialize};
use std::collections::{BTreeMap, HashMap};
Expand All @@ -17,15 +17,9 @@ pub struct Counter {
}

impl Counter {
pub fn new<L: Into<Arc<Limit>>>(
limit: L,
set_variables: HashMap<String, String>,
) -> LimitadorResult<Option<Self>> {
pub fn new<L: Into<Arc<Limit>>>(limit: L, ctx: &Context) -> LimitadorResult<Option<Self>> {
let limit = limit.into();
let mut vars = set_variables;
vars.retain(|var, _| limit.has_variable(var));

let variables = limit.resolve_variables(vars)?;
let variables = limit.resolve_variables(ctx)?;
match variables {
None => Ok(None),
Some(variables) => Ok(Some(Self {
Expand Down Expand Up @@ -159,11 +153,9 @@ mod tests {
Vec::default(),
[var.try_into().expect("failed parsing!")],
);
let counter = Counter::new(
limit,
HashMap::from([("ts".to_string(), "2019-10-12T13:20:50.52Z".to_string())]),
)
.expect("failed creating counter");
let map = HashMap::from([("ts".to_string(), "2019-10-12T13:20:50.52Z".to_string())]);
let ctx = (&map).into();
let counter = Counter::new(limit, &ctx).expect("failed creating counter");
assert_eq!(
counter.unwrap().set_variables.get(var),
Some("13".to_string()).as_ref()
Expand Down
60 changes: 29 additions & 31 deletions limitador/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@
//!
//! ```
//! use limitador::RateLimiter;
//! use limitador::limit::Limit;
//! use limitador::limit::{Limit, Context};
//! use std::collections::HashMap;
//!
//! let mut rate_limiter = RateLimiter::new(1000);
Expand All @@ -116,22 +116,23 @@
//!
//! // Check if we can report
//! let namespace = "my_namespace".into();
//! assert!(!rate_limiter.is_rate_limited(&namespace, &values_to_report, 1).unwrap());
//! let ctx = &values_to_report.into();
//! assert!(!rate_limiter.is_rate_limited(&namespace, &ctx, 1).unwrap());
//!
//! // Report
//! rate_limiter.update_counters(&namespace, &values_to_report, 1).unwrap();
//! rate_limiter.update_counters(&namespace, &ctx, 1).unwrap();
//!
//! // Check and report again
//! assert!(!rate_limiter.is_rate_limited(&namespace, &values_to_report, 1).unwrap());
//! rate_limiter.update_counters(&namespace, &values_to_report, 1).unwrap();
//! assert!(!rate_limiter.is_rate_limited(&namespace, &ctx, 1).unwrap());
//! rate_limiter.update_counters(&namespace, &ctx, 1).unwrap();
//!
//! // We've already reported 2, so reporting another one should not be allowed
//! assert!(rate_limiter.is_rate_limited(&namespace, &values_to_report, 1).unwrap());
//! assert!(rate_limiter.is_rate_limited(&namespace, &ctx, 1).unwrap());
//!
//! // You can also check and report if not limited in a single call. It's useful
//! // for example, when calling Limitador from a proxy. Instead of doing 2
//! // separate calls, we can issue just one:
//! rate_limiter.check_rate_limited_and_update(&namespace, &values_to_report, 1, false).unwrap();
//! rate_limiter.check_rate_limited_and_update(&namespace, &ctx, 1, false).unwrap();
//! ```
//!
//! # Async
Expand Down Expand Up @@ -194,7 +195,7 @@

use crate::counter::Counter;
use crate::errors::LimitadorError;
use crate::limit::{Limit, Namespace};
use crate::limit::{Context, Limit, Namespace};
use crate::storage::in_memory::InMemoryStorage;
use crate::storage::{AsyncCounterStorage, AsyncStorage, Authorization, CounterStorage, Storage};
use std::collections::{HashMap, HashSet};
Expand Down Expand Up @@ -358,7 +359,7 @@ impl RateLimiter {
pub fn is_rate_limited(
&self,
namespace: &Namespace,
values: &HashMap<String, String>,
values: &Context,
delta: u64,
) -> LimitadorResult<bool> {
let counters = self.counters_that_apply(namespace, values)?;
Expand All @@ -380,10 +381,10 @@ impl RateLimiter {
pub fn update_counters(
&self,
namespace: &Namespace,
values: &HashMap<String, String>,
ctx: &Context,
delta: u64,
) -> LimitadorResult<()> {
let counters = self.counters_that_apply(namespace, values)?;
let counters = self.counters_that_apply(namespace, ctx)?;

counters
.iter()
Expand All @@ -394,11 +395,11 @@ impl RateLimiter {
pub fn check_rate_limited_and_update(
&self,
namespace: &Namespace,
values: &HashMap<String, String>,
ctx: &Context,
delta: u64,
load_counters: bool,
) -> LimitadorResult<CheckResult> {
let mut counters = self.counters_that_apply(namespace, values)?;
let mut counters = self.counters_that_apply(namespace, ctx)?;

if counters.is_empty() {
return Ok(CheckResult {
Expand Down Expand Up @@ -476,14 +477,13 @@ impl RateLimiter {
fn counters_that_apply(
&self,
namespace: &Namespace,
values: &HashMap<String, String>,
ctx: &Context,
) -> LimitadorResult<Vec<Counter>> {
let limits = self.storage.get_limits(namespace);
let ctx = values.into();
limits
.iter()
.filter(|lim| lim.applies(&ctx))
.filter_map(|lim| match Counter::new(Arc::clone(lim), values.clone()) {
.filter(|lim| lim.applies(ctx))
.filter_map(|lim| match Counter::new(Arc::clone(lim), ctx) {
Ok(None) => None,
Ok(Some(c)) => Some(Ok(c)),
Err(e) => Some(Err(e)),
Expand Down Expand Up @@ -533,10 +533,10 @@ impl AsyncRateLimiter {
pub async fn is_rate_limited(
&self,
namespace: &Namespace,
values: &HashMap<String, String>,
ctx: &Context<'_>,
delta: u64,
) -> LimitadorResult<bool> {
let counters = self.counters_that_apply(namespace, values).await?;
let counters = self.counters_that_apply(namespace, ctx).await?;

for counter in counters {
match self.storage.is_within_limits(&counter, delta).await {
Expand All @@ -554,10 +554,10 @@ impl AsyncRateLimiter {
pub async fn update_counters(
&self,
namespace: &Namespace,
values: &HashMap<String, String>,
ctx: &Context<'_>,
delta: u64,
) -> LimitadorResult<()> {
let counters = self.counters_that_apply(namespace, values).await?;
let counters = self.counters_that_apply(namespace, ctx).await?;

for counter in counters {
self.storage.update_counter(&counter, delta).await?
Expand All @@ -569,12 +569,12 @@ impl AsyncRateLimiter {
pub async fn check_rate_limited_and_update(
&self,
namespace: &Namespace,
values: &HashMap<String, String>,
ctx: &Context<'_>,
delta: u64,
load_counters: bool,
) -> LimitadorResult<CheckResult> {
// the above where-clause is needed in order to call unwrap().
let mut counters = self.counters_that_apply(namespace, values).await?;
let mut counters = self.counters_that_apply(namespace, ctx).await?;

if counters.is_empty() {
return Ok(CheckResult {
Expand Down Expand Up @@ -657,14 +657,13 @@ impl AsyncRateLimiter {
async fn counters_that_apply(
&self,
namespace: &Namespace,
values: &HashMap<String, String>,
ctx: &Context<'_>,
) -> LimitadorResult<Vec<Counter>> {
let limits = self.storage.get_limits(namespace);
let ctx = values.into();
limits
.iter()
.filter(|lim| lim.applies(&ctx))
.filter_map(|lim| match Counter::new(Arc::clone(lim), values.clone()) {
.filter(|lim| lim.applies(ctx))
.filter_map(|lim| match Counter::new(Arc::clone(lim), ctx) {
Ok(None) => None,
Ok(Some(c)) => Some(Ok(c)),
Err(e) => Some(Err(e)),
Expand Down Expand Up @@ -696,9 +695,8 @@ fn classify_limits_by_namespace(

#[cfg(test)]
mod test {
use crate::limit::{Expression, Limit};
use crate::limit::{Context, Expression, Limit};
use crate::RateLimiter;
use std::collections::HashMap;

#[test]
fn properly_updates_existing_limits() {
Expand All @@ -713,7 +711,7 @@ mod test {
assert_eq!(limits.iter().next().unwrap().max_value(), 42);

let r = rl
.check_rate_limited_and_update(&namespace.into(), &HashMap::default(), 1, true)
.check_rate_limited_and_update(&namespace.into(), &Context::default(), 1, true)
.unwrap();
assert_eq!(r.counters.first().unwrap().max_value(), 42);

Expand All @@ -727,7 +725,7 @@ mod test {
assert_eq!(limits.iter().next().unwrap().max_value(), 50);

let r = rl
.check_rate_limited_and_update(&namespace.into(), &HashMap::default(), 1, true)
.check_rate_limited_and_update(&namespace.into(), &Context::default(), 1, true)
.unwrap();
assert_eq!(r.counters.first().unwrap().max_value(), 50);
}
Expand Down
Loading

0 comments on commit e0ae4f6

Please sign in to comment.