Skip to content

Commit

Permalink
Revert on transient redis error
Browse files Browse the repository at this point in the history
  • Loading branch information
alexsnaps committed May 9, 2024
1 parent 87f65dc commit 27af568
Show file tree
Hide file tree
Showing 2 changed files with 145 additions and 25 deletions.
80 changes: 64 additions & 16 deletions limitador/src/storage/redis/counters_cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,22 @@ impl CachedCounterValue {
value - start == 0
}

fn revert_writes(&self, writes: i64) -> Result<(), ()> {
let newer = self.initial_value.load(Ordering::SeqCst);
if newer > writes {
return match self.initial_value.compare_exchange(
newer,
newer - writes,
Ordering::SeqCst,
Ordering::SeqCst,
) {
Err(expiry) if expiry != 0 => Err(()),
_ => Ok(()),
};
}
Ok(())
}

pub fn hits(&self, _: &Counter) -> i64 {
self.value.value_at(SystemTime::now())
}
Expand Down Expand Up @@ -160,10 +176,10 @@ impl Batcher {
self.notifier.notify_one();
}

pub async fn consume<F, Fut, O>(&self, max: usize, consumer: F) -> O
pub async fn consume<F, Fut, T, E>(&self, max: usize, consumer: F) -> Result<T, E>
where
F: FnOnce(HashMap<Counter, Arc<CachedCounterValue>>) -> Fut,
Fut: Future<Output = O>,
Fut: Future<Output = Result<T, E>>,
{
let mut ready = self.batch_ready(max);
loop {
Expand All @@ -185,10 +201,12 @@ impl Batcher {
result.insert(counter.clone(), value);
}
let result = consumer(result).await;
batch.iter().for_each(|counter| {
self.updates
.remove_if(counter, |_, v| v.no_pending_writes());
});
if result.is_ok() {
batch.iter().for_each(|counter| {
self.updates
.remove_if(counter, |_, v| v.no_pending_writes());
});
}
return result;
} else {
ready = select! {
Expand Down Expand Up @@ -240,6 +258,31 @@ impl CountersCache {
&self.batcher
}

pub fn return_pending_writes(
&self,
counter: &Counter,
value: i64,
writes: i64,
) -> Result<(), ()> {
if writes != 0 {
let mut miss = false;
let value = self.cache.get_with_by_ref(counter, || {
if let Some(entry) = self.batcher.updates.get(counter) {
entry.value().clone()
} else {
miss = true;
let value = Arc::new(CachedCounterValue::from_authority(counter, value));
value.delta(counter, writes);
value
}
});
if miss.not() {
return value.revert_writes(writes);
}
}
Ok(())
}

pub fn apply_remote_delta(
&self,
counter: Counter,
Expand Down Expand Up @@ -415,9 +458,10 @@ mod tests {
.consume(2, |items| {
assert!(items.is_empty());
assert!(SystemTime::now().duration_since(start).unwrap() >= duration);
async {}
async { Ok::<(), ()>(()) }
})
.await;
.await
.expect("Always Ok!");
}

#[tokio::test]
Expand All @@ -441,9 +485,10 @@ mod tests {
SystemTime::now().duration_since(start).unwrap()
>= Duration::from_millis(100)
);
async {}
async { Ok::<(), ()>(()) }
})
.await;
.await
.expect("Always Ok!");
}

#[tokio::test]
Expand All @@ -466,9 +511,10 @@ mod tests {
let wait_period = SystemTime::now().duration_since(start).unwrap();
assert!(wait_period >= Duration::from_millis(40));
assert!(wait_period < Duration::from_millis(50));
async {}
async { Ok::<(), ()>(()) }
})
.await;
.await
.expect("Always Ok!");
}

#[tokio::test]
Expand All @@ -487,9 +533,10 @@ mod tests {
assert!(
SystemTime::now().duration_since(start).unwrap() < Duration::from_millis(5)
);
async {}
async { Ok::<(), ()>(()) }
})
.await;
.await
.expect("Always Ok!");
}

#[tokio::test]
Expand All @@ -512,9 +559,10 @@ mod tests {
let wait_period = SystemTime::now().duration_since(start).unwrap();
assert!(wait_period >= Duration::from_millis(40));
assert!(wait_period < Duration::from_millis(50));
async {}
async { Ok::<(), ()>(()) }
})
.await;
.await
.expect("Always Ok!");
}
}

Expand Down
90 changes: 81 additions & 9 deletions limitador/src/storage/redis/redis_cached.rs
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ impl CachedRedisStorageBuilder {
async fn update_counters<C: ConnectionLike>(
redis_conn: &mut C,
counters_and_deltas: HashMap<Counter, Arc<CachedCounterValue>>,
) -> Result<Vec<(Counter, i64, i64, i64)>, StorageErr> {
) -> Result<Vec<(Counter, i64, i64, i64)>, (Vec<(Counter, i64, i64, i64)>, StorageErr)> {
let redis_script = redis::Script::new(BATCH_UPDATE_COUNTERS);
let mut script_invocation = redis_script.prepare_invoke();

Expand All @@ -299,25 +299,31 @@ async fn update_counters<C: ConnectionLike>(
script_invocation.arg(counter.seconds());
script_invocation.arg(delta);
// We need to store the counter in the actual order we are sending it to the script
res.push((counter, 0, last_value_from_redis, 0));
res.push((counter, last_value_from_redis, delta, 0));
}
}

let span = debug_span!("datastore");
// The redis crate is not working with tables, thus the response will be a Vec of counter values
let script_res: Vec<i64> = script_invocation
let script_res: Vec<i64> = match script_invocation
.invoke_async(redis_conn)
.instrument(span)
.await?;
.await
{
Ok(res) => res,
Err(err) => {
return Err((res, err.into()));
}
};

// We need to update the values and ttls returned by redis
let counters_range = 0..res.len();
let script_res_range = (0..script_res.len()).step_by(2);

for (i, j) in counters_range.zip(script_res_range) {
let (_, val, delta, expires_at) = &mut res[i];
*val = script_res[j];
*delta = script_res[j] - *delta;
*delta = script_res[j] - *val; // new value - previous one = remote writes
*val = script_res[j]; // update to value to newest
*expires_at = script_res[j + 1];
}
res
Expand All @@ -344,9 +350,22 @@ async fn flush_batcher_and_update_counters<C: ConnectionLike>(
update_counters(&mut redis_conn, counters)
})
.await
.or_else(|err| {
.or_else(|(data, err)| {
if err.is_transient() {
flip_partitioned(&partitioned, true);
let counters = data.len();
let mut reverted = 0;
for (counter, old_value, pending_writes, _) in data {
if cached_counters
.return_pending_writes(&counter, old_value, pending_writes)
.is_err()
{
tracing::log::error!("Couldn't revert writes back to {:?}", &counter);
} else {
reverted += 1;
}
}
tracing::log::warn!("Reverted {} of {} counter increments", reverted, counters);
Ok(Vec::new())
} else {
Err(err)
Expand All @@ -370,9 +389,10 @@ mod tests {
};
use crate::storage::redis::redis_cached::{flush_batcher_and_update_counters, update_counters};
use crate::storage::redis::CachedRedisStorage;
use redis::{ErrorKind, Value};
use redis::{Cmd, ErrorKind, RedisError, Value};
use redis_test::{MockCmd, MockRedisConnection};
use std::collections::HashMap;
use std::io;
use std::ops::Add;
use std::sync::atomic::AtomicBool;
use std::sync::Arc;
Expand Down Expand Up @@ -510,8 +530,60 @@ mod tests {
)
.await;

let c = cached_counters.get(&counter).unwrap();
assert_eq!(c.hits(&counter), 8);
assert_eq!(c.pending_writes(), Ok(0));
}

#[tokio::test]
async fn flush_batcher_reverts_on_err() {
let counter = Counter::new(
Limit::new(
"test_namespace",
10,
60,
vec!["req.method == 'POST'"],
vec!["app_id"],
),
Default::default(),
);

let error: RedisError = io::Error::new(io::ErrorKind::TimedOut, "That was long!").into();
assert!(error.is_timeout());
let mock_client = MockRedisConnection::new(vec![MockCmd::new::<&mut Cmd, Value>(
redis::cmd("EVALSHA")
.arg("95a717e821d8fbdd667b5e4c6fede4c9cad16006")
.arg("2")
.arg(key_for_counter(&counter))
.arg(key_for_counters_of_limit(counter.limit()))
.arg(60)
.arg(3),
Err(error),
)]);

let cache = CountersCacheBuilder::new().build(Duration::from_millis(10));
let value = Arc::new(CachedCounterValue::from_authority(&counter, 2));
value.delta(&counter, 3);
cache.batcher().add(counter.clone(), value);

let cached_counters: Arc<CountersCache> = Arc::new(cache);
let partitioned = Arc::new(AtomicBool::new(false));

if let Some(c) = cached_counters.get(&counter) {
assert_eq!(c.hits(&counter), 8);
assert_eq!(c.hits(&counter), 5);
}

flush_batcher_and_update_counters(
mock_client,
true,
cached_counters.clone(),
partitioned,
100,
)
.await;

let c = cached_counters.get(&counter).unwrap();
assert_eq!(c.hits(&counter), 5);
assert_eq!(c.pending_writes(), Ok(3));
}
}

0 comments on commit 27af568

Please sign in to comment.