diff --git a/Cargo.lock b/Cargo.lock
index b0c7aec6ae1c..9bff5e1eff4c 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -4372,6 +4372,7 @@ dependencies = [
  "hyper 1.2.0",
  "hyper-tungstenite",
  "hyper-util",
+ "indexmap 2.0.1",
  "ipnet",
  "itertools",
  "lasso",
diff --git a/Cargo.toml b/Cargo.toml
index a6d406dc2f6e..1ddadd2f3ce8 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -99,6 +99,7 @@ humantime = "2.1"
 humantime-serde = "1.1.1"
 hyper = "0.14"
 hyper-tungstenite = "0.13.0"
+indexmap = "2"
 inotify = "0.10.2"
 ipnet = "2.9.0"
 itertools = "0.10"
diff --git a/libs/metrics/src/lib.rs b/libs/metrics/src/lib.rs
index 8e0dbe6ce4f5..141d8a6d0198 100644
--- a/libs/metrics/src/lib.rs
+++ b/libs/metrics/src/lib.rs
@@ -480,6 +480,15 @@ impl<L: LabelGroupSet> CounterPairVec<L> {
         let id = self.vec.with_labels(labels);
         self.vec.remove_metric(id)
     }
+
+    pub fn sample(&self, labels: <L as LabelGroupSet>::Group<'_>) -> u64 {
+        let id = self.vec.with_labels(labels);
+        let metric = self.vec.get_metric(id);
+
+        let inc = metric.inc.count.load(std::sync::atomic::Ordering::Relaxed);
+        let dec = metric.dec.count.load(std::sync::atomic::Ordering::Relaxed);
+        inc.saturating_sub(dec)
+    }
 }
 
 impl<T, L> ::measured::metric::group::MetricGroup<T> for CounterPairVec<L>
diff --git a/proxy/Cargo.toml b/proxy/Cargo.toml
index 0e8d03906b2a..3002006aedac 100644
--- a/proxy/Cargo.toml
+++ b/proxy/Cargo.toml
@@ -40,6 +40,7 @@ hyper.workspace = true
 hyper1 = { package = "hyper", version = "1.2", features = ["server"] }
 hyper-util = { version = "0.1", features = ["server", "http1", "http2", "tokio"] }
 http-body-util = { version = "0.1" }
+indexmap.workspace = true
 ipnet.workspace = true
 itertools.workspace = true
 lasso = { workspace = true, features = ["multi-threaded"] }
diff --git a/proxy/src/bin/proxy.rs b/proxy/src/bin/proxy.rs
index 0956aae6c0a5..5399f13eddb9 100644
--- a/proxy/src/bin/proxy.rs
+++ b/proxy/src/bin/proxy.rs
@@ -27,6 +27,7 @@ use proxy::redis::cancellation_publisher::RedisPublisherClient;
 use proxy::redis::connection_with_credentials_provider::ConnectionWithCredentialsProvider;
 use proxy::redis::elasticache;
 use proxy::redis::notifications;
+use proxy::serverless::cancel_set::CancelSet;
 use proxy::serverless::GlobalConnPoolOptions;
 use proxy::usage_metrics;
 
@@ -243,6 +244,12 @@ struct SqlOverHttpArgs {
     /// increase memory used by the pool
     #[clap(long, default_value_t = 128)]
     sql_over_http_pool_shards: usize,
+
+    #[clap(long, default_value_t = 10000)]
+    sql_over_http_client_conn_threshold: u64,
+
+    #[clap(long, default_value_t = 64)]
+    sql_over_http_cancel_set_shards: usize,
 }
 
 #[tokio::main]
@@ -599,6 +606,8 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
             opt_in: args.sql_over_http.sql_over_http_pool_opt_in,
             max_total_conns: args.sql_over_http.sql_over_http_pool_max_total_conns,
         },
+        cancel_set: CancelSet::new(args.sql_over_http.sql_over_http_cancel_set_shards),
+        client_conn_threshold: args.sql_over_http.sql_over_http_client_conn_threshold,
     };
     let authentication_config = AuthenticationConfig {
         scram_protocol_timeout: args.scram_protocol_timeout,
diff --git a/proxy/src/config.rs b/proxy/src/config.rs
index e09040775641..b7ab2c00f90d 100644
--- a/proxy/src/config.rs
+++ b/proxy/src/config.rs
@@ -2,7 +2,7 @@ use crate::{
     auth::{self, backend::AuthRateLimiter},
     console::locks::ApiLocks,
     rate_limiter::RateBucketInfo,
-    serverless::GlobalConnPoolOptions,
+    serverless::{cancel_set::CancelSet, GlobalConnPoolOptions},
     Host,
 };
 use anyhow::{bail, ensure, Context, Ok};
@@ -56,6 +56,8 @@ pub struct TlsConfig {
 pub struct HttpConfig {
     pub request_timeout: tokio::time::Duration,
     pub pool_options: GlobalConnPoolOptions,
+    pub cancel_set: CancelSet,
+    pub client_conn_threshold: u64,
 }
 
 pub struct AuthenticationConfig {
diff --git a/proxy/src/serverless.rs b/proxy/src/serverless.rs
index 1a0d1f7b0e60..cbff51f2074d 100644
--- a/proxy/src/serverless.rs
+++ b/proxy/src/serverless.rs
@@ -3,6 +3,7 @@
 //! Handles both SQL over HTTP and SQL over Websockets.
 
 mod backend;
+pub mod cancel_set;
 mod conn_pool;
 mod http_util;
 mod json;
@@ -109,20 +110,37 @@ pub async fn task_main(
         let conn_id = uuid::Uuid::new_v4();
         let http_conn_span = tracing::info_span!("http_conn", ?conn_id);
 
-        connections.spawn(
-            connection_handler(
-                config,
-                backend.clone(),
-                connections.clone(),
-                cancellation_handler.clone(),
-                cancellation_token.clone(),
-                server.clone(),
-                tls_acceptor.clone(),
-                conn,
-                peer_addr,
-            )
-            .instrument(http_conn_span),
-        );
+        let n_connections = Metrics::get()
+            .proxy
+            .client_connections
+            .sample(crate::metrics::Protocol::Http);
+        tracing::trace!(?n_connections, threshold = ?config.http_config.client_conn_threshold, "check");
+        if n_connections > config.http_config.client_conn_threshold {
+            tracing::trace!("attempting to cancel a random connection");
+            if let Some(token) = config.http_config.cancel_set.take() {
+                tracing::debug!("cancelling a random connection");
+                token.cancel()
+            }
+        }
+
+        let conn_token = cancellation_token.child_token();
+        let conn = connection_handler(
+            config,
+            backend.clone(),
+            connections.clone(),
+            cancellation_handler.clone(),
+            conn_token.clone(),
+            server.clone(),
+            tls_acceptor.clone(),
+            conn,
+            peer_addr,
+        )
+        .instrument(http_conn_span);
+
+        connections.spawn(async move {
+            let _cancel_guard = config.http_config.cancel_set.insert(conn_id, conn_token);
+            conn.await
+        });
     }
 
     connections.wait().await;
@@ -243,6 +261,7 @@ async fn connection_handler(
     // On cancellation, trigger the HTTP connection handler to shut down.
     let res = match select(pin!(cancellation_token.cancelled()), pin!(conn)).await {
         Either::Left((_cancelled, mut conn)) => {
+            tracing::debug!(%peer_addr, "cancelling connection");
             conn.as_mut().graceful_shutdown();
             conn.await
         }
diff --git a/proxy/src/serverless/cancel_set.rs b/proxy/src/serverless/cancel_set.rs
new file mode 100644
index 000000000000..390df7f4f7af
--- /dev/null
+++ b/proxy/src/serverless/cancel_set.rs
@@ -0,0 +1,102 @@
+//! A set for cancelling random http connections
+
+use std::{
+    hash::{BuildHasher, BuildHasherDefault},
+    num::NonZeroUsize,
+    time::Duration,
+};
+
+use indexmap::IndexMap;
+use parking_lot::Mutex;
+use rand::{thread_rng, Rng};
+use rustc_hash::FxHasher;
+use tokio::time::Instant;
+use tokio_util::sync::CancellationToken;
+use uuid::Uuid;
+
+type Hasher = BuildHasherDefault<FxHasher>;
+
+pub struct CancelSet {
+    shards: Box<[Mutex<CancelShard>]>,
+    // keyed by random uuid, fxhasher is fine
+    hasher: Hasher,
+}
+
+pub struct CancelShard {
+    tokens: IndexMap<Uuid, (Instant, CancellationToken), Hasher>,
+}
+
+impl CancelSet {
+    pub fn new(shards: usize) -> Self {
+        CancelSet {
+            shards: (0..shards)
+                .map(|_| {
+                    Mutex::new(CancelShard {
+                        tokens: IndexMap::with_hasher(Hasher::default()),
+                    })
+                })
+                .collect(),
+            hasher: Hasher::default(),
+        }
+    }
+
+    pub fn take(&self) -> Option<CancellationToken> {
+        for _ in 0..4 {
+            if let Some(token) = self.take_raw(thread_rng().gen()) {
+                return Some(token);
+            }
+            tracing::trace!("failed to get cancel token");
+        }
+        None
+    }
+
+    pub fn take_raw(&self, rng: usize) -> Option<CancellationToken> {
+        NonZeroUsize::new(self.shards.len())
+            .and_then(|len| self.shards[rng % len].lock().take(rng / len))
+    }
+
+    pub fn insert(&self, id: uuid::Uuid, token: CancellationToken) -> CancelGuard<'_> {
+        let shard = NonZeroUsize::new(self.shards.len()).map(|len| {
+            let hash = self.hasher.hash_one(id) as usize;
+            let shard = &self.shards[hash % len];
+            shard.lock().insert(id, token);
+            shard
+        });
+        CancelGuard { shard, id }
+    }
+}
+
+impl CancelShard {
+    fn take(&mut self, rng: usize) -> Option<CancellationToken> {
+        NonZeroUsize::new(self.tokens.len()).and_then(|len| {
+            // 10 second grace period so we don't cancel new connections
+            if self.tokens.get_index(rng % len)?.1 .0.elapsed() < Duration::from_secs(10) {
+                return None;
+            }
+
+            let (_key, (_insert, token)) = self.tokens.swap_remove_index(rng % len)?;
+            Some(token)
+        })
+    }
+
+    fn remove(&mut self, id: uuid::Uuid) {
+        self.tokens.swap_remove(&id);
+    }
+
+    fn insert(&mut self, id: uuid::Uuid, token: CancellationToken) {
+        self.tokens.insert(id, (Instant::now(), token));
+    }
+}
+
+pub struct CancelGuard<'a> {
+    shard: Option<&'a Mutex<CancelShard>>,
+    id: Uuid,
+}
+
+impl Drop for CancelGuard<'_> {
+    fn drop(&mut self) {
+        if let Some(shard) = self.shard {
+            shard.lock().remove(self.id);
+        }
+    }
+}
diff --git a/proxy/src/serverless/conn_pool.rs b/proxy/src/serverless/conn_pool.rs
index 798e48850906..5fa253acf86d 100644
--- a/proxy/src/serverless/conn_pool.rs
+++ b/proxy/src/serverless/conn_pool.rs
@@ -716,7 +716,7 @@ impl<C: ClientInnerExt> Drop for Client<C> {
 mod tests {
     use std::{mem, sync::atomic::AtomicBool};
 
-    use crate::{BranchId, EndpointId, ProjectId};
+    use crate::{serverless::cancel_set::CancelSet, BranchId, EndpointId, ProjectId};
 
     use super::*;
 
@@ -767,6 +767,8 @@ mod tests {
                 max_total_conns: 3,
             },
             request_timeout: Duration::from_secs(1),
+            cancel_set: CancelSet::new(0),
+            client_conn_threshold: u64::MAX,
         }));
         let pool = GlobalConnPool::new(config);
         let conn_info = ConnInfo {
diff --git a/proxy/src/serverless/sql_over_http.rs b/proxy/src/serverless/sql_over_http.rs
index e856053a7e8c..5376bddfd314 100644
--- a/proxy/src/serverless/sql_over_http.rs
+++ b/proxy/src/serverless/sql_over_http.rs
@@ -424,8 +424,8 @@ pub enum SqlOverHttpCancel {
 impl ReportableError for SqlOverHttpCancel {
     fn get_error_kind(&self) -> ErrorKind {
         match self {
-            SqlOverHttpCancel::Postgres => ErrorKind::RateLimit,
-            SqlOverHttpCancel::Connect => ErrorKind::ServiceRateLimit,
+            SqlOverHttpCancel::Postgres => ErrorKind::ClientDisconnect,
+            SqlOverHttpCancel::Connect => ErrorKind::ClientDisconnect,
         }
     }
 }
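
Reviewer note (not part of the patch): a minimal sketch of the CancelSet semantics introduced above, written as a unit test that could sit at the bottom of cancel_set.rs. It assumes the proxy crate can use tokio's test-util feature (for start_paused time control) and uuid's v4 generation; the test name and module are illustrative only.

#[cfg(test)]
mod illustration {
    use super::*;

    #[tokio::test(start_paused = true)]
    async fn take_respects_the_grace_period() {
        // One shard is enough for the illustration; the proxy uses 64 by default.
        let set = CancelSet::new(1);
        let _guard = set.insert(Uuid::new_v4(), CancellationToken::new());

        // Within the 10 second grace period a freshly registered connection is
        // never chosen for cancellation.
        assert!(set.take().is_none());

        // Once the grace period has passed, take() may hand the token back so the
        // accept loop can cancel a random existing connection.
        tokio::time::advance(Duration::from_secs(11)).await;
        assert!(set.take().is_some());
    }
}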