diff --git a/Cargo.lock b/Cargo.lock
index fb7b7109..4af88c31 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -416,9 +416,9 @@ checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567"
 
 [[package]]
 name = "base64"
-version = "0.22.0"
+version = "0.22.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "9475866fec1451be56a3c2400fd081ff546538961565ccb5b7142cbd22bc7a51"
+checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6"
 
 [[package]]
 name = "bindgen"
@@ -1555,15 +1555,19 @@ name = "limitador"
 version = "0.6.0-dev"
 dependencies = [
  "async-trait",
+ "base64 0.22.1",
  "cfg-if",
  "criterion",
  "dashmap",
  "futures",
  "getrandom",
+ "h2 0.3.26",
  "metrics",
  "moka",
  "paste",
  "postcard",
+ "prost",
+ "prost-types",
  "r2d2",
  "rand",
  "redis",
@@ -1575,8 +1579,14 @@ dependencies = [
  "serial_test",
  "tempfile",
  "thiserror",
+ "time",
  "tokio",
+ "tokio-stream",
+ "tonic",
+ "tonic-build",
+ "tonic-reflection",
  "tracing",
+ "uuid",
 ]
 
 [[package]]
@@ -1702,7 +1712,7 @@ version = "0.14.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "5d58e362dc7206e9456ddbcdbd53c71ba441020e62104703075a69151e38d85f"
 dependencies = [
- "base64 0.22.0",
+ "base64 0.22.1",
  "http-body-util",
  "hyper 1.3.1",
  "hyper-tls",
@@ -3385,6 +3395,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a183cf7feeba97b4dd1c0d46788634f6221d87fa961b305bed08c851829efcc0"
 dependencies = [
  "getrandom",
+ "rand",
 ]
 
 [[package]]
diff --git a/limitador-server/src/config.rs b/limitador-server/src/config.rs
index dc4ef59c..b7f2d60d 100644
--- a/limitador-server/src/config.rs
+++ b/limitador-server/src/config.rs
@@ -154,8 +154,8 @@ pub struct InMemoryStorageConfiguration {
 pub struct DistributedStorageConfiguration {
     pub name: String,
     pub cache_size: Option<u64>,
-    pub local: String,
-    pub broadcast: String,
+    pub listen_address: String,
+    pub peer_urls: Vec<String>,
 }
 
 #[derive(PartialEq, Eq, Debug)]
diff --git a/limitador-server/src/main.rs b/limitador-server/src/main.rs
index 13c965b1..53534a59 100644
--- a/limitador-server/src/main.rs
+++ b/limitador-server/src/main.rs
@@ -43,6 +43,10 @@ use std::path::Path;
 use std::sync::Arc;
 use std::time::Duration;
 use std::{env, process};
+
+#[cfg(feature = "distributed_storage")]
+use clap::parser::ValuesRef;
+
 use sysinfo::{MemoryRefreshKind, RefreshKind, System};
 use thiserror::Error;
 use tokio::runtime::Handle;
@@ -165,8 +169,8 @@ impl Limiter {
         let storage = DistributedInMemoryStorage::new(
             cfg.name,
             cfg.cache_size.or_else(guess_cache_size).unwrap(),
-            cfg.local,
-            Some(cfg.broadcast),
+            cfg.listen_address,
+            cfg.peer_urls,
         );
         let rate_limiter_builder =
             RateLimiterBuilder::with_storage(Storage::with_counter_storage(Box::new(storage)));
@@ -604,18 +608,18 @@ fn create_config() -> (Configuration, &'static str) {
                 .help("Unique name to identify this Limitador instance"),
         )
         .arg(
-            Arg::new("LOCAL")
+            Arg::new("LISTEN_ADDRESS")
                 .action(ArgAction::Set)
                 .required(true)
                 .display_order(2)
-                .help("Local IP:PORT to send datagrams from"),
+                .help("Local IP:PORT to listen on for replication"),
         )
         .arg(
-            Arg::new("BROADCAST")
-                .action(ArgAction::Set)
-                .required(true)
+            Arg::new("PEER_URLS")
+                .action(ArgAction::Append)
+                .required(false)
                 .display_order(3)
-                .help("Broadcast IP:PORT to send datagrams to"),
+                .help("A replication peer URL that this instance will connect to"),
         )
         .arg(
             Arg::new("CACHE_SIZE")
@@ -697,8 +701,12 @@ fn create_config() -> (Configuration, &'static str) {
         Some(("distributed", sub)) => {
             StorageConfiguration::Distributed(DistributedStorageConfiguration {
                 name: sub.get_one::<String>("NAME").unwrap().to_owned(),
-                local: sub.get_one::<String>("LOCAL").unwrap().to_owned(),
-                broadcast: sub.get_one::<String>("BROADCAST").unwrap().to_owned(),
+                listen_address: sub.get_one::<String>("LISTEN_ADDRESS").unwrap().to_owned(),
+                peer_urls: sub
+                    .get_many::<String>("PEER_URLS")
+                    .unwrap_or(ValuesRef::default())
+                    .map(|x| x.to_owned())
+                    .collect(),
                 cache_size: sub.get_one::<u64>("CACHE_SIZE").copied(),
             })
         }
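With this change the distributed subcommand takes a unique name, a listen address, and zero or more peer URLs, replacing the old LOCAL/BROADCAST datagram pair. A hypothetical three-node invocation (binary name, limits file, and ports are illustrative, not taken from this patch):

    limitador-server limits.yaml distributed n1 0.0.0.0:5001 http://limitador-2:5001 http://limitador-3:5001

Since PEER_URLS is optional and appendable, a first node can start with no peers at all and learn about the rest of the cluster through the MembershipUpdate gossip introduced below.
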
diff --git a/limitador/Cargo.toml b/limitador/Cargo.toml
index 366db867..bdf94940 100644
--- a/limitador/Cargo.toml
+++ b/limitador/Cargo.toml
@@ -15,7 +15,7 @@ edition = "2021"
 [features]
 default = ["disk_storage", "redis_storage"]
 disk_storage = ["rocksdb"]
-distributed_storage = []
+distributed_storage = ["tokio", "tokio-stream", "h2", "base64", "uuid", "tonic", "tonic-reflection", "prost", "prost-types"]
 redis_storage = ["redis", "r2d2", "tokio"]
 lenient_conditions = []
@@ -49,6 +49,16 @@ tokio = { version = "1", optional = true, features = [
     "time",
 ] }
 
+base64 = { version = "0.22", optional = true }
+tokio-stream = { version = "0.1", optional = true }
+h2 = { version = "0.3", optional = true }
+uuid = { version = "1.8.0", features = ["v4", "fast-rng"], optional = true }
+tonic = { version = "0.11", optional = true }
+tonic-reflection = { version = "0.11", optional = true }
+prost = { version = "0.12", optional = true }
+prost-types = { version = "0.12", optional = true }
+time = "0.3.36"
+
 [dev-dependencies]
 serial_test = "3.0"
 criterion = { version = "0.5.1", features = ["html_reports", "async_tokio"] }
@@ -62,6 +72,9 @@ tokio = { version = "1", features = [
     "time",
 ] }
 
+[build-dependencies]
+tonic-build = "0.11"
+
 [[bench]]
 name = "bench"
 path = "benches/bench.rs"
diff --git a/limitador/benches/bench.rs b/limitador/benches/bench.rs
index 111bad1f..fcd3db24 100644
--- a/limitador/benches/bench.rs
+++ b/limitador/benches/bench.rs
@@ -147,7 +147,7 @@ fn bench_distributed(c: &mut Criterion) {
                     "test_node".to_owned(),
                     10_000,
                     "127.0.0.1:0".to_owned(),
-                    None,
+                    vec![],
                 ));
                 bench_is_rate_limited(b, test_scenario, storage);
             })
@@ -162,7 +162,7 @@ fn bench_distributed(c: &mut Criterion) {
                     "test_node".to_owned(),
                     10_000,
                     "127.0.0.1:0".to_owned(),
-                    None,
+                    vec![],
                 ));
                 bench_update_counters(b, test_scenario, storage);
             })
@@ -177,7 +177,7 @@ fn bench_distributed(c: &mut Criterion) {
                     "test_node".to_owned(),
                     10_000,
                     "127.0.0.1:0".to_owned(),
-                    None,
+                    vec![],
                 ));
                 bench_check_rate_limited_and_update(b, test_scenario, storage);
             })
diff --git a/limitador/build.rs b/limitador/build.rs
new file mode 100644
index 00000000..b046a647
--- /dev/null
+++ b/limitador/build.rs
@@ -0,0 +1,22 @@
+use std::error::Error;
+use std::path::Path;
+
+fn main() -> Result<(), Box<dyn Error>> {
+    generate_protobuf()
+}
+
+fn generate_protobuf() -> Result<(), Box<dyn Error>> {
+    if cfg!(feature = "distributed_storage") {
+        let proto_path: &Path = "proto/distributed.proto".as_ref();
+
+        let proto_dir = proto_path
+            .parent()
+            .expect("proto file should reside in a directory");
+
+        tonic_build::configure()
+            .protoc_arg("--experimental_allow_proto3_optional")
+            .compile(&[proto_path], &[proto_dir])?;
+    }
+
+    Ok(())
+}
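The protobuf codegen is feature-gated, so a default cargo build never needs protoc; the stubs are only regenerated when the new backend is compiled in, e.g. (package flag assumed from the workspace layout):

    cargo build -p limitador --features distributed_storage
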
diff --git a/limitador/proto/distributed.proto b/limitador/proto/distributed.proto
new file mode 100644
index 00000000..ff470931
--- /dev/null
+++ b/limitador/proto/distributed.proto
@@ -0,0 +1,61 @@
+syntax = "proto3";
+
+package limitador.service.distributed.v1;
+
+// A packet defines all the types of messages that can be sent between replication peers.
+message Packet {
+  oneof message {
+    // the Hello message is used to introduce a peer to another peer. It is the first message sent by a peer.
+    Hello hello = 1;
+    // the MembershipUpdate message is used to gossip about the other peers in the cluster:
+    // 1) sent after the first Hello message
+    // 2) sent when the membership state changes
+    MembershipUpdate membership_update = 2;
+    // the Ping message is used to request a pong from the other peer.
+    Ping ping = 3;
+    // the Pong message is used to respond to a ping.
+    Pong pong = 4;
+    // the CounterUpdate message is used to send counter updates.
+    CounterUpdate counter_update = 5;
+  }
+}
+
+// this is the first packet sent by a peer to another peer.
+message Hello {
+  // the peer id of the sending peer
+  string sender_peer_id = 1;
+  // urls that the sending peer thinks it can be reached at.
+  repeated string sender_urls = 2;
+  // url the session initiator used to connect to the receiver peer.
+  optional string receiver_url = 3;
+}
+
+// A request to a peer to respond with a Pong message.
+message Ping {}
+
+// Pong is the response to a Ping or Hello message.
+message Pong {
+  // the current time of the peer in milliseconds of UTC time since Unix epoch 1970-01-01T00:00:00Z.
+  uint64 current_time = 3;
+}
+
+message MembershipUpdate {
+  repeated Peer peers = 1;
+}
+
+message Peer {
+  string peer_id = 1;
+  uint32 latency = 2; // the round trip latency to the peer in milliseconds.
+  repeated string urls = 3; // urls that can be used to connect to the peer.
+}
+
+message CounterUpdate {
+  bytes key = 1;
+  map<string, uint64> values = 2;
+  uint64 expires_at = 3;
+}
+
+// Replication is the limitador replication service.
+service Replication {
+  rpc Stream(stream Packet) returns (stream Packet) {}
+}
\ No newline at end of file
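tonic-build compiles this schema into the v1 module included by grpc/mod.rs below. As a minimal sketch of how the generated types compose the handshake opener (hello_packet is a hypothetical helper, not part of the patch):

    use crate::storage::distributed::grpc::v1::{packet::Message, Hello, Packet};

    // Build the first frame a session initiator sends: introduce ourselves and
    // echo the URL we dialed so the receiver learns one of its own reachable URLs.
    fn hello_packet(sender_peer_id: String, sender_urls: Vec<String>, dialed_url: String) -> Packet {
        Packet {
            message: Some(Message::Hello(Hello {
                sender_peer_id,
                sender_urls,
                receiver_url: Some(dialed_url),
            })),
        }
    }
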
diff --git a/limitador/src/storage/distributed/grpc/mod.rs b/limitador/src/storage/distributed/grpc/mod.rs
new file mode 100644
index 00000000..2339e1cf
--- /dev/null
+++ b/limitador/src/storage/distributed/grpc/mod.rs
@@ -0,0 +1,689 @@
+use std::collections::{BTreeMap, HashMap, HashSet};
+use std::net::SocketAddr;
+use std::ops::Add;
+use std::sync::Arc;
+use std::time::{Duration, SystemTime, UNIX_EPOCH};
+use std::{error::Error, io::ErrorKind, pin::Pin};
+
+use moka::sync::Cache;
+use tokio::sync::mpsc::Sender;
+use tokio::sync::{broadcast, mpsc, RwLock};
+use tokio::time::sleep;
+
+use tokio_stream::{wrappers::ReceiverStream, Stream, StreamExt};
+use tonic::{Code, Request, Response, Status, Streaming};
+use tracing::debug;
+
+use crate::counter::Counter;
+use crate::storage::distributed::cr_counter_value::CrCounterValue;
+use crate::storage::distributed::grpc::v1::packet::Message;
+use crate::storage::distributed::grpc::v1::replication_client::ReplicationClient;
+use crate::storage::distributed::grpc::v1::replication_server::{Replication, ReplicationServer};
+use crate::storage::distributed::grpc::v1::{
+    CounterUpdate, Hello, MembershipUpdate, Packet, Peer, Pong,
+};
+use crate::storage::distributed::{CounterKey, LimitsMap};
+
+// clippy will barf on protobuf generated code for enum variants in
+// v3::socket_option::SocketState, so allow this lint
+#[allow(clippy::enum_variant_names, clippy::derive_partial_eq_without_eq)]
+pub mod v1 {
+    tonic::include_proto!("limitador.service.distributed.v1");
+}
+
+#[derive(Copy, Clone, Debug)]
+enum ClockSkew {
+    None(),
+    Slow(Duration),
+    Fast(Duration),
+}
+
+impl ClockSkew {
+    fn new(local: SystemTime, remote: SystemTime) -> ClockSkew {
+        if local == remote {
+            ClockSkew::None()
+        } else if local.gt(&remote) {
+            ClockSkew::Slow(local.duration_since(remote).unwrap())
+        } else {
+            ClockSkew::Fast(remote.duration_since(local).unwrap())
+        }
+    }
+
+    #[allow(dead_code)]
+    fn remote(&self, time: SystemTime) -> SystemTime {
+        match self {
+            ClockSkew::None() => time,
+            ClockSkew::Slow(duration) => time - *duration,
+            ClockSkew::Fast(duration) => time + *duration,
+        }
+    }
+
+    #[allow(dead_code)]
+    fn remote_now(&self) -> SystemTime {
+        self.remote(SystemTime::now())
+    }
+}
+
+impl std::fmt::Display for ClockSkew {
+    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+        match self {
+            ClockSkew::None() => write!(f, "remote time is the same"),
+            ClockSkew::Slow(duration) => {
+                write!(f, "remote time is slow by {}ms", duration.as_millis())
+            }
+            ClockSkew::Fast(duration) => {
+                write!(f, "remote time is fast by {}ms", duration.as_millis())
+            }
+        }
+    }
+}
+
+#[derive(Clone)]
+struct Session {
+    broker_state: BrokerState,
+    replication_state: Arc<RwLock<ReplicationState>>,
+    out_stream: MessageSender,
+    peer_id: String,
+}
+
+impl Session {
+    async fn close(&mut self) {
+        let mut state = self.replication_state.write().await;
+        if let Some(peer) = state.peer_trackers.get_mut(&self.peer_id) {
+            peer.session = None;
+        }
+    }
+
+    async fn send(&self, message: Message) -> Result<(), Status> {
+        self.out_stream.clone().send(Ok(message)).await
+    }
+
+    async fn process(&mut self, in_stream: &mut Streaming<Packet>) -> Result<(), Status> {
+        // Send a MembershipUpdate to inform the peer about all the members.
+        // We should resend it again if we learn of new members.
+        self.send(Message::MembershipUpdate(MembershipUpdate {
+            peers: {
+                let state = self.replication_state.read().await;
+                state.peers().clone()
+            },
+        }))
+        .await?;
+
+        let mut updates_to_send = self.broker_state.publisher.subscribe();
+
+        loop {
+            tokio::select! {
+                update = updates_to_send.recv() => {
+                    let update = update.map_err(|_| Status::unknown("broadcast error"))?;
+                    self.send(Message::CounterUpdate(update)).await?;
+                }
+                result = in_stream.next() => {
+                    match result {
+                        None => {
+                            // signals the end of stream...
+                            return Ok(())
+                        },
+                        Some(Ok(packet)) => {
+                            self.process_packet(packet).await?;
+                        },
+                        Some(Err(err)) => {
+                            if is_disconnect(&err) {
+                                debug!("peer: '{}': disconnected: {:?}", self.peer_id, err);
+                                return Ok(());
+                            } else {
+                                return Err(err);
+                            }
+                        },
+                    }
+                }
+            }
+        }
+    }
+
+    async fn process_packet(&self, packet: Packet) -> Result<(), Status> {
+        match packet.message {
+            Some(Message::Ping(_)) => {
+                debug!("peer: '{}': Ping", self.peer_id);
+                self.out_stream
+                    .clone()
+                    .send(Ok(Message::Pong(Pong {
+                        current_time: SystemTime::now()
+                            .duration_since(UNIX_EPOCH)
+                            .unwrap()
+                            .as_millis() as u64,
+                    })))
+                    .await?;
+            }
+            Some(Message::MembershipUpdate(update)) => {
+                debug!("peer: '{}': MembershipUpdate", self.peer_id);
+                // add any new peers to peer_trackers
+                let mut state = self.replication_state.write().await;
+                for peer in update.peers {
+                    let peer_id = peer.peer_id.clone();
+                    match state.peer_trackers.get(&peer_id) {
+                        None => {
+                            // we are discovering a peer from a neighbor; adding a tracker will
+                            // trigger a connection attempt to it.
+                            state.peer_trackers.insert(
+                                peer_id.clone(),
+                                PeerTracker {
+                                    peer_id,
+                                    url: None,
+                                    discovered_urls: peer.urls.iter().cloned().collect(),
+                                    latency: 0, // todo maybe set this to peer.latency + session.latency
+                                    clock_skew: ClockSkew::None(),
+                                    session: None,
+                                },
+                            );
+                        }
+                        Some(_peer_tracker) => {
+                            // TODO: add discovered urls to the existing tracker.
+                            // peer.urls.clone().iter().for_each(|url| {
+                            //     peer_tracker.discovered_urls.insert(url.clone());
+                            // });
+                        }
+                    }
+                }
+            }
+            Some(Message::CounterUpdate(update)) => {
+                debug!("peer: '{}': CounterUpdate", self.peer_id);
+
+                let counter_key = postcard::from_bytes::<CounterKey>(update.key.as_slice())
+                    .map_err(|err| {
+                        Status::internal(format!("failed to decode counter key: {:?}", err))
+                    })?;
+
+                let values = BTreeMap::from_iter(
+                    update
+                        .values
+                        .iter()
+                        .map(|(k, v)| (k.to_owned(), v.to_owned())),
+                );
+
+                let counter = <CounterKey as Into<Counter>>::into(counter_key);
+                if counter.is_qualified() {
+                    if let Some(counter) = self.broker_state.qualified_counters.get(&counter) {
+                        counter.merge(
+                            (UNIX_EPOCH + Duration::from_secs(update.expires_at), values).into(),
+                        );
+                    }
+                } else {
+                    let counters = self.broker_state.limits_for_namespace.read().unwrap();
+                    let limits = counters.get(counter.namespace()).unwrap();
+                    let value = limits.get(counter.limit()).unwrap();
+                    value.merge(
+                        (UNIX_EPOCH + Duration::from_secs(update.expires_at), values).into(),
+                    );
+                };
+            }
+            _ => {
+                debug!("peer: '{}': unsupported packet: {:?}", self.peer_id, packet);
+                return Err(Status::invalid_argument(format!(
+                    "unsupported packet {:?}",
+                    packet
+                )));
+            }
+        }
+        Ok(())
+    }
+}
+
+#[derive(Clone)]
+struct PeerTracker {
+    peer_id: String,
+    url: Option<String>,
+    discovered_urls: HashSet<String>,
+    latency: u32,
+    // Keep track of the clock skew between us and the peer
+    clock_skew: ClockSkew,
+    // The communication session we have with the peer, may be None if not connected
+    session: Option<Session>,
+}
+
+// Track the replication session with all peers.
+struct ReplicationState {
+    // URLs our peers have used to connect to us.
+    discovered_urls: HashSet<String>,
+    peer_trackers: HashMap<String, PeerTracker>,
+}
+
+impl ReplicationState {
+    fn peers(&self) -> Vec<Peer> {
+        let mut peers = Vec::new();
+        self.peer_trackers.iter().for_each(|(_, peer_tracker)| {
+            peers.push(Peer {
+                peer_id: peer_tracker.peer_id.clone(),
+                latency: peer_tracker.latency,
+                urls: peer_tracker
+                    .discovered_urls
+                    .iter()
+                    .map(String::to_owned)
+                    .collect(),
+            });
+        });
+        peers.sort_by(|a, b| a.peer_id.cmp(&b.peer_id));
+        peers
+    }
+}
+
+fn match_for_io_error(err_status: &Status) -> Option<&std::io::Error> {
+    let mut err: &(dyn Error + 'static) = err_status;
+
+    loop {
+        if let Some(io_err) = err.downcast_ref::<std::io::Error>() {
+            return Some(io_err);
+        }
+
+        // h2::Error does not expose std::io::Error with `source()`
+        // https://github.com/hyperium/h2/pull/462
+        if let Some(h2_err) = err.downcast_ref::<h2::Error>() {
+            if let Some(io_err) = h2_err.get_io() {
+                return Some(io_err);
+            }
+        }
+
+        err = match err.source() {
+            Some(err) => err,
+            None => return None,
+        };
+    }
+}
+
+async fn read_hello(in_stream: &mut Streaming<Packet>) -> Result<Hello, Status> {
+    match in_stream.next().await {
+        Some(Ok(packet)) => match packet.message {
+            Some(Message::Hello(value)) => Ok(value),
+            _ => Err(Status::invalid_argument("expected Hello")),
+        },
+        _ => Err(Status::invalid_argument("expected Hello")),
+    }
+}
+
+async fn read_pong(in_stream: &mut Streaming<Packet>) -> Result<Pong, Status> {
+    match in_stream.next().await {
+        Some(Ok(packet)) => match packet.message {
+            Some(Message::Pong(value)) => Ok(value),
+            _ => Err(Status::invalid_argument("expected Pong")),
+        },
+        _ => Err(Status::invalid_argument("expected Pong")),
+    }
+}
+
+fn is_disconnect(err: &Status) -> bool {
+    if let Some(io_err) = match_for_io_error(err) {
+        if io_err.kind() == ErrorKind::BrokenPipe {
+            return true;
+        }
+    }
+    false
+}
+
+// MessageSender abstracts over the difference between the server and client sender streams.
+#[derive(Clone)]
+enum MessageSender {
+    Server(Sender<Result<Packet, Status>>),
+    Client(Sender<Packet>),
+}
+
+impl MessageSender {
+    async fn send(self, message: Result<Message, Status>) -> Result<(), Status> {
+        match self {
+            MessageSender::Server(sender) => {
+                let value = message.map(|x| Packet { message: Some(x) });
+                let result = sender.send(value).await;
+                result.map_err(|_| Status::unknown("send error"))
+            }
+            MessageSender::Client(sender) => match message {
+                Ok(message) => {
+                    let result = sender
+                        .send(Packet {
+                            message: Some(message),
+                        })
+                        .await;
+                    result.map_err(|_| Status::unknown("send error"))
+                }
+                Err(err) => Err(err),
+            },
+        }
+    }
+}
+
+#[derive(Clone)]
+struct BrokerState {
+    id: String,
+    limits_for_namespace: Arc<std::sync::RwLock<LimitsMap>>,
+    qualified_counters: Arc<Cache<Counter, Arc<CrCounterValue<String>>>>,
+    publisher: broadcast::Sender<CounterUpdate>,
+}
+
+#[derive(Clone)]
+pub struct Broker {
+    listen_address: SocketAddr,
+    peer_urls: Vec<String>,
+    broker_state: BrokerState,
+    replication_state: Arc<RwLock<ReplicationState>>,
+}
+
+impl Broker {
+    pub fn new(
+        id: String,
+        listen_address: SocketAddr,
+        peer_urls: Vec<String>,
+        limits_for_namespace: Arc<std::sync::RwLock<LimitsMap>>,
+        qualified_counters: Arc<Cache<Counter, Arc<CrCounterValue<String>>>>,
+    ) -> Broker {
+        let (tx, _) = broadcast::channel(16);
+        let publisher: broadcast::Sender<CounterUpdate> = tx;
+
+        Broker {
+            listen_address,
+            peer_urls,
+            broker_state: BrokerState {
+                id,
+                publisher,
+                limits_for_namespace,
+                qualified_counters,
+            },
+            replication_state: Arc::new(RwLock::new(ReplicationState {
+                discovered_urls: HashSet::new(),
+                peer_trackers: HashMap::new(),
+            })),
+        }
+    }
+
+    pub fn publish(&self, counter_update: CounterUpdate) {
+        // ignore the send error, it just means there are no active subscribers
+        _ = self.broker_state.publisher.send(counter_update);
+    }
+
+    pub async fn start(&self) {
+        self.clone().peer_urls.into_iter().for_each(|peer_url| {
+            let broker = self.clone();
+            let peer_url = peer_url.clone();
+            _ = tokio::spawn(async move {
+                // Keep trying until we get one successful connection handshake. Once that
+                // happens, we will know the peer_id and can recover by reconnecting to the peer
+                loop {
+                    match broker.connect_to_peer(peer_url.clone()).await {
+                        Ok(_) => return,
+                        Err(err) => {
+                            debug!("failed to connect with peer '{}': {:?}", peer_url, err);
+                            sleep(Duration::from_secs(1)).await
+                        }
+                    }
+                }
+            });
+        });
+
+        // Periodically reconnect to failed peers
+        {
+            let broker = self.clone();
+            tokio::spawn(async move {
+                loop {
+                    sleep(Duration::from_secs(1)).await;
+                    broker.reconnect_to_failed_peers().await;
+                }
+            });
+        }
+
+        debug!(
+            "peer id={} listening on: {}",
+            self.broker_state.id, self.listen_address
+        );
+
+        tonic::transport::Server::builder()
+            .add_service(ReplicationServer::new(self.clone()))
+            .serve(self.listen_address)
+            .await
+            .unwrap();
+    }
+
+    // Connect to a peer and start a replication session. This returns once the session handshake
+    // completes.
+    async fn connect_to_peer(&self, peer_url: String) -> Result<(), Status> {
+        let mut client = match ReplicationClient::connect(peer_url.clone()).await {
+            Ok(client) => client,
+            Err(err) => {
+                return Err(Status::new(Code::Unknown, err.to_string()));
+            }
+        };
+
+        let (tx, rx) = mpsc::channel(1);
+
+        let mut in_stream = client.stream(ReceiverStream::new(rx)).await?.into_inner();
+        let mut sender = MessageSender::Client(tx);
+        let session = self
+            .handshake(&mut in_stream, &mut sender, Some(peer_url))
+            .await?;
+        // handshake() returns None when we already have a session with this peer...
+        let mut session = match session {
+            None => return Ok(()),
+            Some(session) => session,
+        };
+
+        // Session is now established, process the session async...
+        tokio::spawn(async move {
+            match session.process(&mut in_stream).await {
+                Ok(_) => {
+                    debug!("client initiated stream ended");
+                }
+                Err(err) => {
+                    debug!("client initiated stream processing failed {:?}", err);
+                }
+            }
+            session.close().await;
+        });
+
+        Ok(())
+    }
+
+    // Reconnect failed peers periodically
+    async fn reconnect_to_failed_peers(&self) {
+        let failed_peers: Vec<_> = {
+            let state = self.replication_state.read().await;
+            state
+                .peer_trackers
+                .iter()
+                .filter_map(|(_, peer_tracker)| {
+                    if peer_tracker.session.is_none() {
+                        // first try to connect to the configured URL
+                        let mut urls: Vec<_> = peer_tracker.url.iter().cloned().collect();
+                        // Then try to connect to discovered urls.
+                        let mut discovered_urls =
+                            peer_tracker.discovered_urls.iter().cloned().collect();
+                        urls.append(&mut discovered_urls);
+                        Some((peer_tracker.peer_id.clone(), urls))
+                    } else {
+                        None
+                    }
+                })
+                .collect()
+        };
+
+        for (peer_id, urls) in failed_peers {
+            for url in urls {
+                debug!(
+                    "attempting to reconnect to failed peer '{}' over {:?}",
+                    peer_id, url
+                );
+                match self.connect_to_peer(url.clone()).await {
+                    Ok(_) => break,
+                    Err(err) => {
+                        debug!("failed to connect with peer '{}': {:?}", url, err);
+                    }
+                }
+            }
+        }
+    }
+
+    // handshake is called when a new stream is created; it handles the initial handshake
+    // and updates the session state in the state.peer_trackers map. The result is None if an
+    // existing session is already established with the peer.
+    async fn handshake(
+        &self,
+        in_stream: &mut Streaming<Packet>,
+        out_stream: &mut MessageSender,
+        peer_url: Option<String>,
+    ) -> Result<Option<Session>, Status> {
+        // Let the peer know who we are...
+        let start = SystemTime::now();
+        {
+            let state = self.replication_state.read().await;
+            out_stream
+                .clone()
+                .send(Ok(Message::Hello(Hello {
+                    sender_peer_id: self.broker_state.id.clone(),
+                    sender_urls: state.discovered_urls.clone().into_iter().collect(),
+                    receiver_url: peer_url.clone(),
+                })))
+                .await?;
+        }
+
+        // Wait for the peer to tell us who they are...
+        let peer_hello = read_hello(in_stream).await?;
+
+        // respond with a Pong so the peer can calculate the round trip latency
+        out_stream
+            .clone()
+            .send(Ok(Message::Pong(Pong {
+                current_time: start.duration_since(UNIX_EPOCH).unwrap().as_millis() as u64,
+            })))
+            .await?;
+
+        // Get the pong back from the peer...
+        let peer_pong = read_pong(in_stream).await?;
+        let end = SystemTime::now();
+
+        let peer_id = peer_hello.sender_peer_id.clone();
+
+        // When a peer initiates a connection, we discover a URL that can be used
+        // to connect to us.
+        if let Some(url) = peer_hello.receiver_url {
+            let mut state = self.replication_state.write().await;
+            state.discovered_urls.insert(url);
+        }
+
+        let session = Session {
+            peer_id: peer_id.clone(),
+            replication_state: self.replication_state.clone(),
+            broker_state: self.broker_state.clone(),
+            out_stream: out_stream.clone(),
+        };
+
+        // We now know who the peer is and our latency to them.
+        let mut state = self.replication_state.write().await;
+        let (tracker, option) = match state.peer_trackers.get_mut(&peer_id) {
+            Some(tracker) => {
+                match tracker.clone().session {
+                    Some(prev_session) => {
+                        // we already have a session with this peer; this is common since
+                        // both peers race to connect to each other at the same time.
+                        // But we only need to keep one session, so use the order of the
+                        // peer ids to agree on which session to keep.
+
+                        if peer_id < self.broker_state.id {
+                            // close the previous session, use the new one...
+                            _ = prev_session
+                                .out_stream
+                                .send(Err(Status::already_exists("session")))
+                                .await;
+                            tracker.session = Some(session.clone());
+
+                            (tracker, Some(session))
+                        } else {
+                            // use the previous session, close the new one...
+                            _ = session
+                                .out_stream
+                                .send(Err(Status::already_exists("session")))
+                                .await;
+                            (tracker, None)
+                        }
+                    }
+                    None => {
+                        tracker.session = Some(session.clone());
+                        (tracker, Some(session))
+                    }
+                }
+            }
+            None => {
+                let latency = end.duration_since(start).unwrap();
+                let peer_time = UNIX_EPOCH.add(Duration::from_millis(peer_pong.current_time));
+                let peer_time_adj = peer_time.add(latency.div_f32(2.0)); // adjust for round trip latency
+                let discovered_urls = peer_hello
+                    .sender_urls
+                    .iter()
+                    .map(String::to_owned)
+                    .collect();
+                let tracker = PeerTracker {
+                    peer_id: peer_id.clone(),
+                    url: None,
+                    discovered_urls,
+                    latency: latency.as_millis() as u32,
+                    clock_skew: ClockSkew::new(end, peer_time_adj),
+                    session: Some(session.clone()),
+                };
+
+                debug!(
+                    "peer {} clock skew: {}",
+                    peer_id.clone(),
+                    &tracker.clock_skew
+                );
+                state.peer_trackers.insert(peer_id.clone(), tracker);
+                let tracker = state.peer_trackers.get_mut(&peer_id).unwrap();
+                (tracker, Some(session))
+            }
+        };
+
+        // keep track of the URL we used to connect to the peer.
+        if peer_url.is_some() {
+            tracker.url.clone_from(&peer_url)
+        }
+
+        Ok(option)
+    }
+}
+
+#[tonic::async_trait]
+impl Replication for Broker {
+    type StreamStream = Pin<Box<dyn Stream<Item = Result<Packet, Status>> + Send>>;
+
+    // Accepts a connection from a peer and starts a replication session
+    async fn stream(
+        &self,
+        req: Request<Streaming<Packet>>,
+    ) -> Result<Response<Self::StreamStream>, Status> {
+        debug!("ReplicationServer::stream");
+
+        let mut in_stream = req.into_inner();
+        let (tx, rx) = mpsc::channel(1);
+
+        let broker = self.clone();
+        tokio::spawn(async move {
+            let mut sender = MessageSender::Server(tx);
+            match broker.handshake(&mut in_stream, &mut sender, None).await {
+                Ok(Some(mut session)) => {
+                    match session.process(&mut in_stream).await {
+                        Ok(_) => {
+                            debug!("server accepted stream ended");
+                        }
+                        Err(err) => {
+                            debug!("server accepted stream processing failed {:?}", err);
+                        }
+                    }
+                    session.close().await;
+                }
+                Ok(None) => {
+                    // duplicate session; the winning stream is kept by handshake()
+                }
+                Err(err) => {
+                    debug!("stream handshake failed {:?}", err);
+                }
+            }
+        });
+
+        Ok(Response::new(
+            Box::pin(ReceiverStream::new(rx)) as Self::StreamStream
+        ))
+    }
+}
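The handshake's latency and skew arithmetic deserves a worked example. The initiator measures the round trip, assumes the peer stamped its Pong mid-flight, and projects the peer's clock forward by half the round trip before comparing it with its own. A self-contained sketch with made-up timestamps (the real code uses latency.div_f32(2.0); integer halving is used here only to keep the numbers exact):

    use std::time::{Duration, UNIX_EPOCH};

    fn main() {
        // Our clock when the Hello went out and when the Pong came back:
        let start = UNIX_EPOCH + Duration::from_millis(1_000_000);
        let end = UNIX_EPOCH + Duration::from_millis(1_000_040); // 40ms round trip
        let latency = end.duration_since(start).unwrap();

        // The peer's clock as reported in its Pong, stamped mid-flight:
        let peer_time = UNIX_EPOCH + Duration::from_millis(1_000_090);

        // Project the peer's clock forward by half the round trip so it is
        // comparable with `end`, as handshake() does:
        let peer_time_adj = peer_time + latency / 2;

        // 1_000_090 + 20 - 1_000_040 = 70: the peer's clock is 70ms fast.
        let skew = peer_time_adj.duration_since(end).unwrap();
        assert_eq!(skew, Duration::from_millis(70));
    }
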
diff --git a/limitador/src/storage/distributed/mod.rs b/limitador/src/storage/distributed/mod.rs
index ff9aa923..2c319576 100644
--- a/limitador/src/storage/distributed/mod.rs
+++ b/limitador/src/storage/distributed/mod.rs
@@ -1,5 +1,5 @@
 use std::collections::hash_map::Entry;
-use std::collections::{BTreeMap, HashMap, HashSet};
+use std::collections::{HashMap, HashSet};
 use std::net::ToSocketAddrs;
 use std::ops::Deref;
 use std::sync::{Arc, RwLock};
@@ -7,24 +7,24 @@ use std::time::{Duration, SystemTime, UNIX_EPOCH};
 
 use moka::sync::Cache;
 use serde::{Deserialize, Serialize};
-use tokio::net::UdpSocket;
-use tokio::sync::mpsc;
-use tokio::sync::mpsc::Sender;
 
 use crate::counter::Counter;
 use crate::limit::{Limit, Namespace};
 use crate::storage::distributed::cr_counter_value::CrCounterValue;
+use crate::storage::distributed::grpc::v1::CounterUpdate;
+use crate::storage::distributed::grpc::Broker;
 use crate::storage::{Authorization, CounterStorage, StorageErr};
 
 mod cr_counter_value;
+mod grpc;
 
 type NamespacedLimitCounters<T> = HashMap<Namespace, HashMap<Limit, T>>;
 
 pub struct CrInMemoryStorage {
     identifier: String,
-    sender: Sender<CounterValueMessage>,
     limits_for_namespace: Arc<RwLock<NamespacedLimitCounters<CrCounterValue<String>>>>,
     qualified_counters: Arc<Cache<Counter, Arc<CrCounterValue<String>>>>,
+    broker: Broker,
 }
 
 impl CounterStorage for CrInMemoryStorage {
@@ -248,100 +248,44 @@ impl CounterStorage for CrInMemoryStorage {
     }
 }
 
+pub type LimitsMap = HashMap<Namespace, HashMap<Limit, CrCounterValue<String>>>;
+
 impl CrInMemoryStorage {
     pub fn new(
         identifier: String,
         cache_size: u64,
-        local: String,
-        broadcast: Option<String>,
+        listen_address: String,
+        peer_urls: Vec<String>,
     ) -> Self {
-        let (sender, mut rx) = mpsc::channel(1000);
-
-        let local = local.to_socket_addrs().unwrap().next().unwrap();
-        if let Some(remote) = broadcast.clone() {
-            tokio::spawn(async move {
-                let sock = UdpSocket::bind(local).await.unwrap();
-                sock.set_broadcast(true).unwrap();
-                sock.connect(remote).await.unwrap();
-                loop {
-                    let message: CounterValueMessage = rx.recv().await.unwrap();
-                    let buf = postcard::to_stdvec(&message).unwrap();
-                    match sock.send(&buf).await {
-                        Ok(len) => {
-                            if len != buf.len() {
-                                println!("Couldn't send complete message!");
-                            }
-                        }
-                        Err(err) => println!("Couldn't send update: {:?}", err),
-                    };
-                }
-            });
-        }
+        let listen_address = listen_address.to_socket_addrs().unwrap().next().unwrap();
+        let peer_urls = peer_urls.clone();
 
-        let limits_for_namespace = Arc::new(RwLock::new(HashMap::<
-            Namespace,
-            HashMap<Limit, CrCounterValue<String>>,
-        >::new()));
+        let limits_for_namespace = Arc::new(RwLock::new(LimitsMap::new()));
         let qualified_counters: Arc<Cache<Counter, Arc<CrCounterValue<String>>>> =
             Arc::new(Cache::new(cache_size));
 
+        let broker = grpc::Broker::new(
+            identifier.clone(),
+            listen_address,
+            peer_urls,
+            limits_for_namespace.clone(),
+            qualified_counters.clone(),
+        );
+
         {
-            let limits_for_namespace = limits_for_namespace.clone();
-            let qualified_counters = qualified_counters.clone();
-
-            if let Some(broadcast) = broadcast.clone() {
-                tokio::spawn(async move {
-                    let sock = UdpSocket::bind(broadcast).await.unwrap();
-                    sock.set_broadcast(true).unwrap();
-                    let mut buf = [0; 1024];
-                    loop {
-                        let (len, addr) = sock.recv_from(&mut buf).await.unwrap();
-                        if addr != local {
-                            match postcard::from_bytes::<CounterValueMessage>(&buf[..len]) {
-                                Ok(message) => {
-                                    let CounterValueMessage {
-                                        counter_key,
-                                        expiry,
-                                        values,
-                                    } = message;
-                                    let counter = <CounterKey as Into<Counter>>::into(counter_key);
-                                    if counter.is_qualified() {
-                                        if let Some(counter) = qualified_counters.get(&counter) {
-                                            counter.merge(
-                                                (UNIX_EPOCH + Duration::from_secs(expiry), values)
-                                                    .into(),
-                                            );
-                                        }
-                                    } else {
-                                        let counters = limits_for_namespace.read().unwrap();
-                                        let limits = counters.get(counter.namespace()).unwrap();
-                                        let value = limits.get(counter.limit()).unwrap();
-                                        value.merge(
-                                            (UNIX_EPOCH + Duration::from_secs(expiry), values)
-                                                .into(),
-                                        );
-                                    };
-                                }
-                                Err(err) => {
-                                    println!(
-                                        "Error from {} bytes: {:?} \n{:?}",
-                                        len,
-                                        err,
-                                        &buf[..len]
-                                    )
-                                }
-                            }
-                        }
-                    }
-                });
-            }
+            let broker = broker.clone();
+            tokio::spawn(async move {
+                broker.start().await;
+            });
         }
 
         Self {
             identifier,
-            sender,
             limits_for_namespace,
             qualified_counters,
+            broker,
         }
     }
@@ -395,27 +339,20 @@ impl CrInMemoryStorage {
         when: SystemTime,
     ) {
         counter.inc_at(delta, Duration::from_secs(key.seconds()), when);
-        let sender = self.sender.clone();
+
         let counter = counter.clone();
-        tokio::spawn(async move {
-            let (expiry, values) = counter.into_inner();
-            let message = CounterValueMessage {
-                counter_key: key.into(),
-                expiry: expiry.duration_since(UNIX_EPOCH).unwrap().as_secs(),
-                values,
-            };
-            sender.send(message).await
-        });
+        let (expiry, values) = counter.into_inner();
+        let key: CounterKey = key.into();
+        let key = postcard::to_stdvec(&key).unwrap();
+
+        self.broker.publish(CounterUpdate {
+            key,
+            values: values.into_iter().collect(),
+            expires_at: expiry.duration_since(UNIX_EPOCH).unwrap().as_secs(),
+        })
     }
 }
 
-#[derive(Debug, Serialize, Deserialize)]
-struct CounterValueMessage {
-    counter_key: CounterKey,
-    expiry: u64,
-    values: BTreeMap<String, u64>,
-}
-
 #[derive(Debug, Serialize, Deserialize)]
 struct CounterKey {
     namespace: Namespace,
diff --git a/limitador/tests/integration_tests.rs b/limitador/tests/integration_tests.rs
index 06c308c2..9a24663a 100644
--- a/limitador/tests/integration_tests.rs
+++ b/limitador/tests/integration_tests.rs
@@ -17,7 +17,7 @@ macro_rules! test_with_all_storage_impls {
             #[tokio::test]
             async fn [<$function _distributed_storage>]() {
                 let rate_limiter =
-                    RateLimiter::new_with_storage(Box::new(CrInMemoryStorage::new("test_node".to_owned(), 10_000, "127.0.0.1:19876".to_owned(), Some("127.0.0.255:19876".to_owned()))));
+                    RateLimiter::new_with_storage(Box::new(CrInMemoryStorage::new("test_node".to_owned(), 10_000, "127.0.0.1:19876".to_owned(), vec![])));
                 $function(&mut TestsLimiter::new_from_blocking_impl(rate_limiter)).await;
             }
@@ -72,12 +72,75 @@ macro_rules! test_with_all_storage_impls {
     };
 }
 
+#[cfg(feature = "distributed_storage")]
+async fn distributed_storage_factory(
+    count: usize,
+) -> Vec<crate::helpers::tests_limiter::TestsLimiter> {
+    use crate::helpers::tests_limiter::TestsLimiter;
+    use limitador::storage::distributed::CrInMemoryStorage;
+    use limitador::RateLimiter;
+
+    let addresses = (0..count)
+        .map(|i| format!("127.0.0.1:{}", 5200 + i))
+        .collect::<Vec<String>>();
+    return (0..count)
+        .map(|i| {
+            let node = format!("n{}", i);
+            let listen_address = addresses.get(i).unwrap().to_owned();
+            let peer_urls = addresses
+                .iter()
+                .map(|x| format!("http://{}", x))
+                .collect::<Vec<String>>();
+
+            TestsLimiter::new_from_blocking_impl(RateLimiter::new_with_storage(Box::new(
+                CrInMemoryStorage::new(node, 10_000, listen_address, peer_urls),
+            )))
+        })
+        .collect::<Vec<_>>();
+}
+
+macro_rules! test_with_distributed_storage_impls {
+    // This macro uses the "paste" crate to define the names of the generated
+    // test functions.
+    ($function:ident) => {
+        paste::item! {
+            #[cfg(feature = "distributed_storage")]
+            #[tokio::test]
+            async fn [<$function _distributed_storage>]() {
+                $function(crate::distributed_storage_factory).await;
+            }
+        }
+    };
+}
+
 mod helpers;
 
 #[cfg(test)]
 mod test {
     extern crate limitador;
 
+    #[allow(dead_code)]
+    async fn eventually<F>(
+        timeout: Duration,
+        tick: Duration,
+        condition: impl Fn() -> F,
+    ) -> Result<bool, Elapsed>
+    where
+        F: Future<Output = bool>,
+    {
+        tokio::time::timeout(timeout, async move {
+            let mut i = tokio::time::interval(tick);
+            loop {
+                if condition().await {
+                    return true;
+                }
+                i.tick().await;
+            }
+        })
+        .await
+    }
+
     // To be able to pass the tests without Redis
     cfg_if::cfg_if! {
         if #[cfg(feature = "redis_storage")] {
@@ -101,9 +164,11 @@ mod test {
     use limitador::storage::distributed::CrInMemoryStorage;
     use limitador::storage::in_memory::InMemoryStorage;
     use std::collections::{HashMap, HashSet};
+    use std::future::Future;
     use std::thread::sleep;
     use std::time::Duration;
     use tempfile::TempDir;
+    use tokio::time::error::Elapsed;
 
     test_with_all_storage_impls!(get_namespaces);
     test_with_all_storage_impls!(get_namespaces_returns_empty_when_there_arent_any);
@@ -140,6 +205,8 @@ mod test {
     test_with_all_storage_impls!(configure_with_updates_the_limits);
     test_with_all_storage_impls!(add_limit_only_adds_if_not_present);
 
+    test_with_distributed_storage_impls!(distributed_rate_limited);
+
     // All these functions need to use async/await. That's needed to support
     // both the sync and the async implementations of the rate limiter.
@@ -1107,4 +1174,62 @@ mod test {
         assert_eq!(known_limit.max_value(), 10);
         assert_eq!(known_limit.name(), None);
     }
+
+    #[allow(dead_code)]
+    async fn distributed_rate_limited<Fut>(create_distributed_limiters: fn(count: usize) -> Fut)
+    where
+        Fut: Future<Output = Vec<TestsLimiter>>,
+    {
+        let rate_limiters = create_distributed_limiters(2).await;
+        tokio::time::sleep(Duration::from_secs(1)).await;
+
+        let namespace = "test_namespace";
+        let max_hits = 3;
+        let limit = Limit::new(
+            namespace,
+            max_hits,
+            60,
+            vec!["req.method == 'GET'"],
+            vec!["app_id"],
+        );
+
+        for rate_limiter in rate_limiters.iter() {
+            rate_limiter.add_limit(&limit).await;
+        }
+
+        let mut values: HashMap<String, String> = HashMap::new();
+        values.insert("req.method".to_string(), "GET".to_string());
+        values.insert("app_id".to_string(), "test_app_id".to_string());
+
+        for i in 0..max_hits {
+            // Alternate between the two rate limiters
+            let rate_limiter = rate_limiters.get((i % 2) as usize).unwrap();
+            assert!(
+                !rate_limiter
+                    .is_rate_limited(namespace, &values, 1)
+                    .await
+                    .unwrap(),
+                "Must not be limited after {i}"
+            );
+            rate_limiter
+                .update_counters(namespace, &values, 1)
+                .await
+                .unwrap();
+        }
+
+        // eventually it should get rate limited...
+        assert!(eventually(
+            Duration::from_secs(5),
+            Duration::from_millis(100),
+            || async {
+                let rate_limiter = rate_limiters.first().unwrap();
+                rate_limiter
+                    .is_rate_limited(namespace, &values, 1)
+                    .await
+                    .unwrap()
+            }
+        )
+        .await
+        .unwrap());
+    }
 }
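The new test only compiles with the feature enabled; something like the following runs it (package and filter syntax assumed, not stated in the patch):

    cargo test -p limitador --features distributed_storage distributed_rate_limited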