feat(server): proper rate limiting behind reverse-proxy
Nuhvi committed May 21, 2024
1 parent 4a13e97 commit 795f7e6
Showing 5 changed files with 109 additions and 63 deletions.
1 change: 1 addition & 0 deletions server/src/config.toml
@@ -6,5 +6,6 @@ resolvers = []
 minimum_ttl = 300
 maximum_ttl = 86400
 [rate_limiter]
+behind_proxy = false
 per_second = 2
 burst_size = 10
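
For a relay actually running behind a reverse proxy, the new flag would presumably be switched on so the limiter keys on the forwarded client address instead of the proxy's own IP. A hypothetical operator config (the values other than behind_proxy are just the shipped defaults):

    [rate_limiter]
    # Trust the client IP forwarded by the proxy (e.g. X-Forwarded-For) instead of the socket peer.
    behind_proxy = true
    # Replenish one request every 2 seconds per IP, allowing bursts of up to 10.
    per_second = 2
    burst_size = 10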
20 changes: 7 additions & 13 deletions server/src/dht_server.rs
@@ -21,7 +21,7 @@ use pkarr::{
 
 use tracing::debug;
 
-use crate::{cache::HeedPkarrCache, rate_limiting::RateLimiterLayer};
+use crate::{cache::HeedPkarrCache, rate_limiting::IpRateLimiter};
 
 /// DhtServer with Rate limiting
 pub struct DhtServer {
@@ -30,7 +30,7 @@ pub struct DhtServer {
     cache: Box<crate::cache::HeedPkarrCache>,
     minimum_ttl: u32,
     maximum_ttl: u32,
-    rate_limiter_layer: RateLimiterLayer,
+    rate_limiter: IpRateLimiter,
 }
 
 impl Debug for DhtServer {
@@ -45,7 +45,7 @@ impl DhtServer {
         resolvers: Option<Vec<String>>,
         minimum_ttl: u32,
         maximum_ttl: u32,
-        rate_limiter_layer: RateLimiterLayer,
+        rate_limiter: IpRateLimiter,
     ) -> Self {
         Self {
             // Default DhtServer used to stay a good citizen servicing the Dht.
@@ -60,7 +60,7 @@ }),
             }),
             minimum_ttl,
             maximum_ttl,
-            rate_limiter_layer,
+            rate_limiter,
         }
     }
 }
@@ -131,13 +131,9 @@ impl Server for DhtServer {
         if should_query {
             // Rate limit nodes that are making too many requests, forcing us to make too
             // many queries, either by querying the same non-existent key or many unique keys.
-            if self
-                .rate_limiter_layer
-                .config
-                .limiter()
-                .check_key(&from.ip())
-                .is_ok()
-            {
+            if self.rate_limiter.is_limited(&from.ip()) {
+                debug!(?from, "Resolver rate limiting");
+            } else {
                 rpc.get(
                     *target,
                     RequestTypeSpecific::GetValue(GetValueRequestArguments {
@@ -148,8 +144,6 @@
                     None,
                     self.resolvers.to_owned(),
                 );
-            } else {
-                debug!(?from, "Resolver rate limiting");
             };
         }
     };
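
For a sense of what the new is_limited check does for the resolver: with the default RateLimiterConfig (burst_size = 10, per_second = 2), a single IP can make a short burst of queries and is then throttled until its quota replenishes. A minimal sketch, assuming the crate's IpRateLimiter and RateLimiterConfig from this diff are in scope (the address is a made-up example):

    use std::net::IpAddr;

    let limiter = IpRateLimiter::new(&RateLimiterConfig::default());
    let ip: IpAddr = "203.0.113.7".parse().unwrap();

    // The first few requests fall inside the burst allowance and consume it...
    for _ in 0..10 {
        assert!(!limiter.is_limited(&ip));
    }
    // ...after which further requests from the same IP are rejected until the
    // quota refills (roughly one request every `per_second` seconds).
    assert!(limiter.is_limited(&ip));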
17 changes: 10 additions & 7 deletions server/src/http_server.rs
@@ -9,7 +9,7 @@ use tracing::{info, warn};
 
 use pkarr::PkarrClientAsync;
 
-use crate::rate_limiting::RateLimiterLayer;
+use crate::rate_limiting::IpRateLimiter;
 
 pub struct HttpServer {
     tasks: JoinSet<std::io::Result<()>>,
@@ -20,9 +20,9 @@ impl HttpServer {
     pub async fn spawn(
         client: PkarrClientAsync,
         port: u16,
-        rate_limiter_layer: RateLimiterLayer,
+        rate_limiter: IpRateLimiter,
     ) -> Result<HttpServer> {
-        let app = create_app(AppState { client }, rate_limiter_layer);
+        let app = create_app(AppState { client }, rate_limiter);
 
         let mut tasks = JoinSet::new();
 
@@ -75,22 +75,25 @@ impl HttpServer {
     }
 }
 
-pub(crate) fn create_app(state: AppState, rate_limiter_layer: RateLimiterLayer) -> Router {
+pub fn create_app(state: AppState, rate_limiter: IpRateLimiter) -> Router {
     let cors = CorsLayer::new()
         .allow_methods([Method::GET, Method::PUT])
         .allow_origin(cors::Any);
 
-    Router::new()
+    let router = Router::new()
        .route("/:key", get(crate::handlers::get).put(crate::handlers::put))
        .route(
            "/",
            get(|| async { "This is a Pkarr relay: pkarr.org/relays.\n" }),
        )
        .with_state(state)
-       .layer(rate_limiter_layer)
       .layer(DefaultBodyLimit::max(1104))
       .layer(cors)
-      .layer(TraceLayer::new_for_http())
+      .layer(TraceLayer::new_for_http());
+
+    rate_limiter.layer(&router);
+
+    router
 }
 
 #[derive(Debug, Clone)]
6 changes: 3 additions & 3 deletions server/src/main.rs
@@ -50,7 +50,7 @@ async fn main() -> Result<()> {
     fs::create_dir_all(env_path)?;
     let cache = Box::new(HeedPkarrCache::new(env_path, config.cache_size()).unwrap());
 
-    let rate_limiter_layer = rate_limiting::create(config.rate_limiter());
+    let rate_limiter = rate_limiting::IpRateLimiter::new(config.rate_limiter());
 
     let client = PkarrClient::builder()
         .dht_settings(DhtSettings {
@@ -60,7 +60,7 @@
                 config.resolvers(),
                 config.minimum_ttl(),
                 config.maximum_ttl(),
-                rate_limiter_layer.clone(),
+                rate_limiter.clone(),
             ))),
             ..DhtSettings::default()
         })
@@ -75,7 +75,7 @@
 
     info!("Running as a resolver on UDP socket {udp_address}");
 
-    let http_server = HttpServer::spawn(client, config.relay_port(), rate_limiter_layer).await?;
+    let http_server = HttpServer::spawn(client, config.relay_port(), rate_limiter).await?;
 
     tokio::signal::ctrl_c().await?;
 
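
Note that a single limiter instance serves both paths: IpRateLimiter is Clone and each variant wraps an Arc'd governor config, so the clone handed to DhtServer and the value passed to HttpServer::spawn appear to share one per-IP budget rather than keeping separate counters. Schematically (paraphrasing the lines above, not new API):

    let rate_limiter = rate_limiting::IpRateLimiter::new(config.rate_limiter());
    let for_dht = rate_limiter.clone(); // handed to DhtServer::new(..)
    let for_http = rate_limiter;        // handed to HttpServer::spawn(..)
    // Both handles point at the same Arc-wrapped governor state.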
128 changes: 88 additions & 40 deletions server/src/rate_limiting.rs
@@ -1,63 +1,111 @@
-use std::{sync::Arc, time::Duration};
+use std::{net::IpAddr, sync::Arc, time::Duration};
 
+use axum::Router;
+use governor::middleware::StateInformationMiddleware;
 use serde::{Deserialize, Serialize};
 
-use governor::middleware::StateInformationMiddleware;
 use tower_governor::{
-    governor::GovernorConfigBuilder, key_extractor::PeerIpKeyExtractor, GovernorLayer,
+    governor::{GovernorConfig, GovernorConfigBuilder},
+    key_extractor::{PeerIpKeyExtractor, SmartIpKeyExtractor},
 };
 
+pub use tower_governor::GovernorLayer;
+
 #[derive(Serialize, Deserialize, Debug)]
 pub struct RateLimiterConfig {
+    pub(crate) behind_proxy: bool,
     pub(crate) per_second: u64,
     pub(crate) burst_size: u32,
 }
 
 impl Default for RateLimiterConfig {
     fn default() -> Self {
         Self {
+            behind_proxy: false,
             per_second: 2,
             burst_size: 10,
         }
     }
 }
 
-pub type RateLimiterLayer = GovernorLayer<PeerIpKeyExtractor, StateInformationMiddleware>;
-
-/// Create the default rate-limiting layer.
-///
-/// This will be used by the [crate::http_server::HttpServer] to guard all endpoints (GET and PUT)
-/// and in [crate::dht_server::DhtServer] before calling [pkarr::mainline::rpc::Rpc::get]
-/// after a cache miss or if its cached packet is expired.
-///
-/// The purpose is to limit DHT queries as much as possible, while serving honest clients still.
-///
-/// This spawns a background thread to clean up the rate limiting cache.
-///
-/// # Limits
-///
-/// * allow a burst of `10 requests` per IP address
-/// * replenish `1 request` every `2 seconds`
-pub fn create(config: &RateLimiterConfig) -> RateLimiterLayer {
-    let governor_config = GovernorConfigBuilder::default()
-        .use_headers()
-        .per_second(config.per_second)
-        .burst_size(config.burst_size)
-        .finish()
-        .expect("failed to build rate-limiting governor");
-
-    let governor_config = Arc::new(governor_config);
-
-    // The governor needs a background task for garbage collection (to clear expired records)
-    let gc_interval = Duration::from_secs(60);
-    let governor_limiter = governor_config.limiter().clone();
-    std::thread::spawn(move || loop {
-        std::thread::sleep(gc_interval);
-        tracing::debug!("rate limiting storage size: {}", governor_limiter.len());
-        governor_limiter.retain_recent();
-    });
-
-    GovernorLayer {
-        config: governor_config,
+#[derive(Debug, Clone)]
+/// A rate limiter that works for direct connections (Peer) or behind reverse-proxy (Proxy)
+pub enum IpRateLimiter {
+    Peer(Arc<GovernorConfig<PeerIpKeyExtractor, StateInformationMiddleware>>),
+    Proxy(Arc<GovernorConfig<SmartIpKeyExtractor, StateInformationMiddleware>>),
+}
+
+impl IpRateLimiter {
+    /// Create an [IpRateLimiter]
+    ///
+    /// This spawns a background thread to clean up the rate limiting cache.
+    pub fn new(config: &RateLimiterConfig) -> Self {
+        match config.behind_proxy {
+            true => {
+                let config = Arc::new(
+                    GovernorConfigBuilder::default()
+                        .use_headers()
+                        .per_second(config.per_second)
+                        .burst_size(config.burst_size)
+                        .key_extractor(SmartIpKeyExtractor)
+                        .finish()
+                        .expect("failed to build rate-limiting governor"),
+                );
+
+                // The governor needs a background task for garbage collection (to clear expired records)
+                let gc_interval = Duration::from_secs(60);
+
+                let governor_limiter = config.limiter().clone();
+                std::thread::spawn(move || loop {
+                    std::thread::sleep(gc_interval);
+                    tracing::debug!("rate limiting storage size: {}", governor_limiter.len());
+                    governor_limiter.retain_recent();
+                });
+
+                Self::Proxy(config)
+            }
+            false => {
+                let config = Arc::new(
+                    GovernorConfigBuilder::default()
+                        .use_headers()
+                        .per_second(config.per_second)
+                        .burst_size(config.burst_size)
+                        .finish()
+                        .expect("failed to build rate-limiting governor"),
+                );
+
+                // The governor needs a background task for garbage collection (to clear expired records)
+                let gc_interval = Duration::from_secs(60);
+
+                let governor_limiter = config.limiter().clone();
+                std::thread::spawn(move || loop {
+                    std::thread::sleep(gc_interval);
+                    tracing::debug!("rate limiting storage size: {}", governor_limiter.len());
+                    governor_limiter.retain_recent();
+                });
+
+                Self::Peer(config)
+            }
+        }
+    }
+
+    /// Check if the Ip is allowed to make more requests
+    pub fn is_limited(&self, ip: &IpAddr) -> bool {
+        match self {
+            IpRateLimiter::Peer(config) => config.limiter().check_key(ip).is_err(),
+            IpRateLimiter::Proxy(config) => config.limiter().check_key(ip).is_err(),
+        }
+    }
+
+    /// Add a [GovernorLayer] on the provided [Router]
+    pub fn layer(&self, router: &Router) {
+        let _ = match self {
+            IpRateLimiter::Peer(config) => router.clone().layer(GovernorLayer {
+                config: config.clone(),
+            }),
+            IpRateLimiter::Proxy(config) => router.clone().layer(GovernorLayer {
+                config: config.clone(),
+            }),
+        };
+    }
+}
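
The point of the two variants: behind a reverse proxy every connection reaches the server from the proxy's own address, so keying on the socket peer (PeerIpKeyExtractor) would lump all clients into one bucket, while SmartIpKeyExtractor recovers the client IP from common forwarding headers (X-Forwarded-For and similar) before falling back to the peer address. A standalone sketch of the proxy-mode wiring, built the same way as the behind_proxy arm above with the crate's default numbers; the route and handler are placeholders, not part of this commit:

    use std::sync::Arc;

    use axum::{routing::get, Router};
    use tower_governor::{
        governor::GovernorConfigBuilder, key_extractor::SmartIpKeyExtractor, GovernorLayer,
    };

    fn proxied_router() -> Router {
        // Same construction as the `behind_proxy: true` arm of IpRateLimiter::new.
        let config = Arc::new(
            GovernorConfigBuilder::default()
                .use_headers()                      // report rate-limit state via response headers
                .per_second(2)                      // replenish one request every 2 seconds
                .burst_size(10)                     // allow bursts of up to 10 requests
                .key_extractor(SmartIpKeyExtractor) // key on the forwarded client IP
                .finish()
                .expect("failed to build rate-limiting governor"),
        );

        // Attach the governor to an axum Router; placeholder route for illustration.
        Router::new()
            .route("/", get(|| async { "ok" }))
            .layer(GovernorLayer { config })
    }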
