perf(Backend): ⚡ alternative rate limit policy
Eason0729 committed Feb 7, 2024
1 parent 6acf303 commit 042c0fe
Showing 7 changed files with 257 additions and 87 deletions.
13 changes: 11 additions & 2 deletions backend/src/controller/judger/score.rs
@@ -1,3 +1,5 @@
use std::cmp;

use super::{submit, user_contest};
use crate::{
entity::{contest, problem, user},
@@ -153,8 +155,15 @@ impl ScoreUpload {
.one(&txn)
.await?;

score = score.saturating_sub(submit.map(|x| x.score).unwrap_or_default());
score = score.saturating_add(self.submit.score);
let original_score = submit.map(|x| x.score).unwrap_or_default();

if original_score >= self.submit.score {
tracing::trace!(reason = "unchanged score", "score_contest");
return Ok(());
}

score = score.saturating_add(cmp::max(self.submit.score, original_score));
score = score.saturating_sub(original_score);

linker.score = ActiveValue::Set(score);
linker.update(&txn).await.map_err(Into::<Error>::into)?;
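
For context, here is a minimal, self-contained sketch of the new scoring rule (illustrative types, not the actual sea_orm entities): a resubmission only changes the stored contest score when it beats the user's previous best on that problem; otherwise the update is skipped entirely.

use std::cmp;

// contest_score:  the user's current total score in the contest
// previous_best:  score of the user's earlier submit on this problem (0 if none)
// new_score:      score of the submit being uploaded
// Returns None when nothing should be written back to the linker row.
fn updated_contest_score(contest_score: u32, previous_best: u32, new_score: u32) -> Option<u32> {
    if previous_best >= new_score {
        // mirrors the "unchanged score" trace + early return above
        return None;
    }
    Some(
        contest_score
            .saturating_add(cmp::max(new_score, previous_best))
            .saturating_sub(previous_best),
    )
}

fn main() {
    assert_eq!(updated_contest_score(100, 30, 80), Some(150)); // improved by 50
    assert_eq!(updated_contest_score(100, 80, 30), None);      // worse: total untouched
}
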
250 changes: 186 additions & 64 deletions backend/src/controller/rate_limit.rs
@@ -1,101 +1,223 @@
use core::time;
use std::{net::IpAddr, sync::Arc};
use std::{hash::Hash, marker::PhantomData, net::IpAddr, str::FromStr, sync::Arc, time::Duration};

use futures::Future;
use ip_network::IpNetwork;
use leaky_bucket::RateLimiter;
use quick_cache::sync::Cache;

use tracing::instrument;
use tracing::{instrument, Instrument};

use crate::util::error::Error;

const BUCKET_WIDTH: usize = 256;
const BUCKET_WIDTH: usize = 512;

pub struct RateLimitController {
limiter: Cache<IpAddr, Arc<RateLimiter>>,
trusts: Vec<IpNetwork>,
/// Policy (numbers) for rate limiting
trait LimitPolicy {
const BUCKET_WIDTH: usize = BUCKET_WIDTH;
const INIT_CAP: usize;
const MAX_CAP: usize;
/// How many tokens are refilled every 30 seconds
const FILL_RATE: usize;

fn into_limiter() -> RateLimiter {
RateLimiter::builder()
.interval(Duration::from_secs(30))
.initial(Self::INIT_CAP)
.max(Self::MAX_CAP)
.refill(Self::FILL_RATE)
.build()
}
}

/// policy for [`TrafficType::Login`]
struct LoginPolicy;

impl LimitPolicy for LoginPolicy {
const INIT_CAP: usize = 32;
const MAX_CAP: usize = 64;
const FILL_RATE: usize = 6;
}

/// policy for [`TrafficType::Guest`]
struct GuestPolicy;

impl LimitPolicy for GuestPolicy {
const INIT_CAP: usize = 32;
const MAX_CAP: usize = 64;
const FILL_RATE: usize = 6;
}

/// policy for [`TrafficType::Blacklist`]
///
/// note that this is a global rate limit:
/// all blacklisted users share the same [`leaky_bucket::RateLimiter`],
/// so the numbers are significantly higher
struct BlacklistPolicy;

impl LimitPolicy for BlacklistPolicy {
const INIT_CAP: usize = 128;
const MAX_CAP: usize = 256;
const FILL_RATE: usize = 32;
}

macro_rules! check_rate_limit {
($s:expr) => {{
use futures::Future;
struct LimitMap<K, P: LimitPolicy>
where
K: Send + Eq + Hash + Clone,
{
cache: Cache<K, Arc<RateLimiter>>,
_policy: PhantomData<P>,
}

/// interface that is able to check a rate limit and store state (by key)
trait Limit<K: Send> {
/// return true if limited
fn check(&self, key: &K) -> bool;
/// return `Err(Error::RateLimit)` when the limit is reached,
/// `Ok(())` otherwise.
fn check_error(&self, key: &K) -> Result<(), Error> {
match self.check(key) {
true => Err(Error::RateLimit),
false => Ok(()),
}
}
}

impl Limit<()> for Arc<RateLimiter> {
fn check(&self, _: &()) -> bool {
struct Waker;
impl std::task::Wake for Waker {
fn wake(self: Arc<Self>) {
log::error!("waker wake");
unreachable!("waker wake");
}
}

let waker = Arc::new(Waker).into();
let mut cx = std::task::Context::from_waker(&waker);

let ac = $s;
let ac = self.clone().acquire_owned(1);
tokio::pin!(ac);
if ac.as_mut().poll(&mut cx).is_pending() {
return Err(Error::RateLimit);

ac.as_mut().poll(&mut cx).is_pending()
}
}

impl<K, P: LimitPolicy> Limit<K> for LimitMap<K, P>
where
K: Send + Eq + Hash + Clone,
{
fn check(&self, key: &K) -> bool {
self.cache
.get_or_insert_with(key, || Result::<_, ()>::Ok(Arc::new(P::into_limiter())))
.unwrap()
.check(&())
}
}

impl<K, P: LimitPolicy> Default for LimitMap<K, P>
where
K: Send + Eq + Hash + Clone,
{
fn default() -> Self {
Self {
cache: Cache::new(P::BUCKET_WIDTH),
_policy: Default::default(),
}
}};
}
}

pub struct RateLimitController {
ip_blacklist: Cache<IpAddr, ()>,
user_limiter: LimitMap<i32, LoginPolicy>,
ip_limiter: LimitMap<IpAddr, GuestPolicy>,
blacklist_limiter: Arc<RateLimiter>,
trusts: Vec<IpNetwork>,
}

/// Type of traffic
pub enum TrafficType {
/// Login user (with a valid token)
Login(i32),
/// Guest (without a token)
Guest,
/// traffic with a token from a blacklisted ip
///
/// see [`RateLimitController::check`]
Blacklist(crate::controller::token::Error),
}

impl RateLimitController {
pub fn new(trusts: &[IpNetwork]) -> Self {
Self {
limiter: Cache::new(BUCKET_WIDTH),
ip_blacklist: Cache::new(BUCKET_WIDTH),
user_limiter: Default::default(),
ip_limiter: Default::default(),
blacklist_limiter: Arc::new(BlacklistPolicy::into_limiter()),
trusts: trusts.to_vec(),
}
}
#[instrument(skip_all, level = "debug")]
pub fn check_ip<T>(&self, req: &tonic::Request<T>, permits: usize) -> Result<(), Error> {
if self.trusts.is_empty() {
return Ok(());
}
if req.remote_addr().is_none() {
tracing::warn!(msg = "cannot not retrieve remote address", "config");
return Ok(());
}
let remote = req.remote_addr().unwrap().ip();
/// retrieve ip address from request
///
/// if used on a unix socket, returns 0.0.0.0
///
/// if the upstream is trusted but sent no `X-Forwarded-For`, the remote address is used
#[instrument(skip_all, level = "trace")]
fn ip<T>(&self, req: &tonic::Request<T>) -> Result<IpAddr, Error> {
let mut remote = req
.remote_addr()
.map(|x| x.ip())
.unwrap_or_else(|| IpAddr::from_str("0.0.0.0").unwrap());

tracing::trace!(remote = remote.to_string());

for trust in &self.trusts {
if !trust.contains(remote) {
continue;
}
if let Some(ip) = req.metadata().get("X-Forwarded-For") {
let ip = ip
.to_str()
.map_err(|_| Error::Unreachable("header must not contain non-ascii char"))?
.parse()
.map_err(|_| Error::Unreachable("MalFormatted header"))?;
return self.acquire(ip, permits);
} else {
tracing::warn!(msg = "No \"X-Forwarded-For\" found", "config");
if let Some(addr) = req.metadata().get("X-Forwarded-For") {
remote = addr
.to_str()
.map_err(|_| Error::Unreachable("header must not contain non-ascii char"))?
.parse()
.map_err(|_| Error::Unreachable("Malformed header"))?;
}
}
}
Err(Error::RateLimit)
}

Ok(remote)
}
/// check rate limit
///
/// `f` should be a `FnOnce` that emits a future yielding [`TrafficType`]
///
/// There are three types of traffic:
///
/// - [`TrafficType::Login`]: faster rate, rate limited by user id
/// - [`TrafficType::Guest`]: slower rate, rate limited by ip address
/// - [`TrafficType::Blacklist`]: dedicated rate limit (because verifying the token takes time)
///
/// We identify [`TrafficType::Blacklist`] by an ip blacklist,
/// whose entries are added when a user fails to login or sends an invalid token
#[instrument(skip_all, level = "debug")]
fn acquire(&self, ip: IpAddr, permits: usize) -> Result<(), Error> {
let limiter = self
.limiter
.get_or_insert_with::<_, ()>(&ip, || {
Ok(Arc::new(
RateLimiter::builder()
.max(40)
.initial(10)
.interval(time::Duration::from_secs(3))
.build(),
))
})
.map_err(|_| Error::Unreachable("creation function for limiter shouldn't panic"))?;
let owned = limiter.acquire_owned(permits);

check_rate_limit!(owned);
Ok(())
pub async fn check<'a, T, F, Fut>(&self, req: &'a tonic::Request<T>, f: F) -> Result<(), Error>
where
F: FnOnce(&'a tonic::Request<T>) -> Fut,
Fut: Future<Output = TrafficType>,
{
let addr = self.ip(req)?;

if self.ip_blacklist.get(&addr).is_some() {
return self.ip_limiter.check_error(&addr);
}

match f(req)
.instrument(tracing::debug_span!("token_verify"))
.await
{
TrafficType::Login(x) => self.user_limiter.check_error(&x),
TrafficType::Guest => self.ip_limiter.check_error(&addr),
TrafficType::Blacklist(err) => {
tracing::warn!(msg = err.to_string(), "ip_blacklist");
self.ip_blacklist.insert(addr, ());
self.blacklist_limiter.check_error(&())
}
}
}
}

// impl Default for RateLimitController {
// fn default() -> Self {
// Self {
// limiter: Cache::new(256),
// }
// }
// }
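
A rough, self-contained sketch of how one of these policies behaves, using GuestPolicy's numbers (32 initial tokens, capped at 64, 6 refilled every 30 seconds). `futures`' `now_or_never()` is used here as a stand-in for the controller's manual Waker-and-poll trick: a `Pending` acquire is treated as "rate limited" instead of being awaited, so the check never blocks a request.

use std::{sync::Arc, time::Duration};

use futures::FutureExt; // for now_or_never()
use leaky_bucket::RateLimiter;

// Same shape as LimitPolicy::into_limiter(), hard-coded with GuestPolicy's numbers.
fn guest_limiter() -> Arc<RateLimiter> {
    Arc::new(
        RateLimiter::builder()
            .interval(Duration::from_secs(30))
            .initial(32)
            .max(64)
            .refill(6)
            .build(),
    )
}

// true if the caller should be rejected: the acquire future is polled exactly
// once, and Pending counts as "no token available right now".
fn is_limited(limiter: &Arc<RateLimiter>) -> bool {
    limiter.clone().acquire_owned(1).now_or_never().is_none()
}

#[tokio::main]
async fn main() {
    let limiter = guest_limiter();
    let rejected = (0..100).filter(|_| is_limited(&limiter)).count();
    // with an initial capacity of 32, the first 32 calls pass and the rest are limited
    println!("rejected {rejected} of 100 back-to-back requests");
}

In the controller itself the same idea is split between `impl Limit<()> for Arc<RateLimiter>` (poll once, treat Pending as limited) and `LimitMap`, which lazily creates one such bucket per user id or ip address inside a bounded `quick_cache::sync::Cache`, while all traffic classified as `TrafficType::Blacklist` shares a single BlacklistPolicy bucket.
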
4 changes: 1 addition & 3 deletions backend/src/controller/token.rs
@@ -131,8 +131,6 @@ impl TokenController {
base64::Engine::decode(&base64::engine::general_purpose::STANDARD_NO_PAD, token)?;
let rand: Rand = rand.try_into().map_err(|_| Error::InvalidTokenLength)?;

let token: CachedToken;

let cache_result = {
match self.cache.get(&rand) {
Some(cc) => {
@@ -154,7 +152,7 @@ impl TokenController {
token
}
None => {
token = (token::Entity::find()
let token: CachedToken = (token::Entity::find()
.filter(token::Column::Rand.eq(rand.to_vec()))
.one(self.db.deref())
.in_current_span()
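
The token.rs change is purely a scoping cleanup: instead of pre-declaring `let token: CachedToken;` outside the match and assigning it in the cache-miss arm, the binding now lives only in the arm that needs it. A tiny illustrative sketch of the same idiom (stand-in types, not the real `CachedToken` or the sea_orm query):

fn resolve(cache_hit: Option<u64>) -> u64 {
    match cache_hit {
        Some(token) => token,
        None => {
            // previously this assigned to an outer `let token: CachedToken;`
            let token: u64 = load_from_db();
            token
        }
    }
}

fn load_from_db() -> u64 {
    42 // stand-in for the token::Entity::find() lookup
}

fn main() {
    assert_eq!(resolve(Some(7)), 7);
    assert_eq!(resolve(None), 42);
}
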
4 changes: 3 additions & 1 deletion backend/src/endpoint/submit.rs
@@ -280,7 +280,9 @@ impl SubmitSet for Arc<Server> {
}

#[instrument(skip_all, level = "debug")]
async fn list_langs(&self, _: Request<()>) -> Result<Response<Languages>, Status> {
async fn list_langs(&self, req: Request<()>) -> Result<Response<Languages>, Status> {
self.parse_auth(&req).await?;

let list: Vec<_> = self
.judger
.list_lang()
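
The only functional change to submit.rs is that `list_langs` now calls `parse_auth` before answering, presumably so the request is classified (login / guest / blacklist) and charged to the right rate-limit bucket even though the endpoint does not need the caller's identity. A minimal stand-in sketch of that pattern (the `Server` and `Request` types below are illustrative, not the real tonic service); user.rs below makes a related swap from `parse_request` to `parse_auth`.

// Illustrative stand-ins for the real tonic request/server types.
struct Request<T>(T);

#[derive(Debug)]
enum Error {
    RateLimit,
}

struct Server;

impl Server {
    // stand-in for Server::parse_auth: classify the caller, charge the
    // appropriate rate-limit bucket, and return the (optional) identity
    fn parse_auth<T>(&self, _req: &Request<T>) -> Result<Option<i32>, Error> {
        Ok(None) // pretend the caller is an unauthenticated guest under the limit
    }

    fn list_langs(&self, req: Request<()>) -> Result<Vec<&'static str>, Error> {
        // the identity is discarded; only the rate-limit/auth side effect matters
        self.parse_auth(&req)?;
        Ok(vec!["c", "cpp", "rust"])
    }
}

fn main() {
    let server = Server;
    println!("{:?}", server.list_langs(Request(())));
}
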
5 changes: 3 additions & 2 deletions backend/src/endpoint/token.rs
@@ -122,11 +122,12 @@ impl TokenSet for Arc<Server> {
}
#[instrument(skip_all, level = "debug")]
async fn logout(&self, req: Request<()>) -> Result<Response<()>, Status> {
let meta = req.metadata();
self.parse_auth(&req).await?.ok_or_default()?;

if let Some(x) = meta.get("token") {
if let Some(x) = req.metadata().get("token") {
let token = x.to_str().unwrap();


self.token.remove(token.to_string()).await?;
tracing::event!(Level::TRACE, token = token);

2 changes: 1 addition & 1 deletion backend/src/endpoint/user.rs
@@ -296,7 +296,7 @@ impl UserSet for Arc<Server> {

#[instrument(skip_all, level = "debug")]
async fn my_info(&self, req: Request<()>) -> Result<Response<UserInfo>, Status> {
let (auth, _req) = self.parse_request(req).await?;
let auth = self.parse_auth(&req).await?;
let (user_id, _) = auth.ok_or_default()?;

let model = Entity::find_by_id(user_id)