Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Proxy: send cancel notifications to all instances #6719

Merged
merged 14 commits into from
Feb 13, 2024
7 changes: 4 additions & 3 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ futures-core = "0.3"
futures-util = "0.3"
git-version = "0.3"
hashbrown = "0.13"
hashlink = "0.8.1"
hashlink = "0.8.4"
hdrhistogram = "7.5.2"
hex = "0.4"
hex-literal = "0.4"
Expand Down
1 change: 1 addition & 0 deletions libs/pq_proto/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,6 @@ rand.workspace = true
tokio.workspace = true
tracing.workspace = true
thiserror.workspace = true
serde.workspace = true

workspace_hack.workspace = true
3 changes: 2 additions & 1 deletion libs/pq_proto/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ pub mod framed;

use byteorder::{BigEndian, ReadBytesExt};
use bytes::{Buf, BufMut, Bytes, BytesMut};
use serde::{Deserialize, Serialize};
use std::{borrow::Cow, collections::HashMap, fmt, io, str};

// re-export for use in utils pageserver_feedback.rs
Expand Down Expand Up @@ -123,7 +124,7 @@ impl StartupMessageParams {
}
}

#[derive(Debug, Hash, PartialEq, Eq, Clone, Copy)]
#[derive(Debug, Hash, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
pub struct CancelKeyData {
pub backend_pid: i32,
pub cancel_key: i32,
Expand Down
32 changes: 31 additions & 1 deletion proxy/src/bin/proxy.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use futures::future::Either;
use proxy::auth;
use proxy::auth::backend::MaybeOwned;
use proxy::cancellation::CancelMap;
use proxy::cancellation::CancellationHandler;
use proxy::config::AuthenticationConfig;
use proxy::config::CacheOptions;
use proxy::config::HttpConfig;
Expand All @@ -12,6 +14,7 @@ use proxy::rate_limiter::EndpointRateLimiter;
use proxy::rate_limiter::RateBucketInfo;
use proxy::rate_limiter::RateLimiterConfig;
use proxy::redis::notifications;
use proxy::redis::publisher::RedisPublisherClient;
use proxy::serverless::GlobalConnPoolOptions;
use proxy::usage_metrics;

Expand All @@ -22,6 +25,7 @@ use std::net::SocketAddr;
use std::pin::pin;
use std::sync::Arc;
use tokio::net::TcpListener;
use tokio::sync::Mutex;
use tokio::task::JoinSet;
use tokio_util::sync::CancellationToken;
use tracing::info;
Expand Down Expand Up @@ -129,6 +133,9 @@ struct ProxyCliArgs {
/// Can be given multiple times for different bucket sizes.
#[clap(long, default_values_t = RateBucketInfo::DEFAULT_SET)]
endpoint_rps_limit: Vec<RateBucketInfo>,
/// Redis rate limiter max number of requests per second.
#[clap(long, default_values_t = RateBucketInfo::DEFAULT_SET)]
redis_rps_limit: Vec<RateBucketInfo>,
/// Initial limit for dynamic rate limiter. Makes sense only if `rate_limit_algorithm` is *not* `None`.
#[clap(long, default_value_t = 100)]
initial_limit: usize,
Expand Down Expand Up @@ -225,6 +232,19 @@ async fn main() -> anyhow::Result<()> {
let cancellation_token = CancellationToken::new();

let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new(&config.endpoint_rps_limit));
let cancel_map = CancelMap::default();
let redis_publisher = match &args.redis_notifications {
Some(url) => Some(Arc::new(Mutex::new(RedisPublisherClient::new(
url,
args.region.clone(),
&config.redis_rps_limit,
)?))),
None => None,
};
let cancellation_handler = Arc::new(CancellationHandler::new(
cancel_map.clone(),
redis_publisher,
));

// client facing tasks. these will exit on error or on cancellation
// cancellation returns Ok(())
Expand All @@ -234,6 +254,7 @@ async fn main() -> anyhow::Result<()> {
proxy_listener,
cancellation_token.clone(),
endpoint_rate_limiter.clone(),
cancellation_handler.clone(),
));

// TODO: rename the argument to something like serverless.
Expand All @@ -248,6 +269,7 @@ async fn main() -> anyhow::Result<()> {
serverless_listener,
cancellation_token.clone(),
endpoint_rate_limiter.clone(),
cancellation_handler.clone(),
));
}

Expand All @@ -271,7 +293,12 @@ async fn main() -> anyhow::Result<()> {
let cache = api.caches.project_info.clone();
if let Some(url) = args.redis_notifications {
info!("Starting redis notifications listener ({url})");
maintenance_tasks.spawn(notifications::task_main(url.to_owned(), cache.clone()));
maintenance_tasks.spawn(notifications::task_main(
url.to_owned(),
cache.clone(),
cancel_map.clone(),
args.region.clone(),
));
}
maintenance_tasks.spawn(async move { cache.clone().gc_worker().await });
}
Expand Down Expand Up @@ -403,6 +430,8 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {

let mut endpoint_rps_limit = args.endpoint_rps_limit.clone();
RateBucketInfo::validate(&mut endpoint_rps_limit)?;
let mut redis_rps_limit = args.redis_rps_limit.clone();
RateBucketInfo::validate(&mut redis_rps_limit)?;

let config = Box::leak(Box::new(ProxyConfig {
tls_config,
Expand All @@ -414,6 +443,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
require_client_ip: args.require_client_ip,
disable_ip_check_for_http: args.disable_ip_check_for_http,
endpoint_rps_limit,
redis_rps_limit,
handshake_timeout: args.handshake_timeout,
// TODO: add this argument
region: args.region.clone(),
Expand Down
109 changes: 91 additions & 18 deletions proxy/src/cancellation.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,28 @@
use async_trait::async_trait;
use dashmap::DashMap;
use pq_proto::CancelKeyData;
use std::{net::SocketAddr, sync::Arc};
use thiserror::Error;
use tokio::net::TcpStream;
use tokio::sync::Mutex;
use tokio_postgres::{CancelToken, NoTls};
use tracing::info;
use uuid::Uuid;

use crate::error::ReportableError;
use crate::{
error::ReportableError, metrics::NUM_CANCELLATION_REQUESTS,
redis::publisher::RedisPublisherClient,
};

pub type CancelMap = Arc<DashMap<CancelKeyData, Option<CancelClosure>>>;

/// Enables serving `CancelRequest`s.
#[derive(Default)]
pub struct CancelMap(DashMap<CancelKeyData, Option<CancelClosure>>);
///
/// If there is a `RedisPublisherClient` available, it will be used to publish the cancellation key to other proxy instances.
pub struct CancellationHandler {
map: CancelMap,
redis_client: Option<Arc<Mutex<RedisPublisherClient>>>,
}

#[derive(Debug, Error)]
pub enum CancelError {
Expand All @@ -32,15 +44,43 @@ impl ReportableError for CancelError {
}
}

impl CancelMap {
impl CancellationHandler {
pub fn new(map: CancelMap, redis_client: Option<Arc<Mutex<RedisPublisherClient>>>) -> Self {
Self { map, redis_client }
}
/// Cancel a running query for the corresponding connection.
pub async fn cancel_session(&self, key: CancelKeyData) -> Result<(), CancelError> {
pub async fn cancel_session(
&self,
key: CancelKeyData,
session_id: Uuid,
) -> Result<(), CancelError> {
let from = "from_client";
// NB: we should immediately release the lock after cloning the token.
let Some(cancel_closure) = self.0.get(&key).and_then(|x| x.clone()) else {
let Some(cancel_closure) = self.map.get(&key).and_then(|x| x.clone()) else {
tracing::warn!("query cancellation key not found: {key}");
if let Some(redis_client) = &self.redis_client {
NUM_CANCELLATION_REQUESTS
.with_label_values(&[from, "not_found"])
.inc();
info!("publishing cancellation key to Redis");
match redis_client.lock().await.try_publish(key, session_id).await {
Ok(()) => {
info!("cancellation key successfuly published to Redis");
}
Err(e) => {
tracing::error!("failed to publish a message: {e}");
return Err(CancelError::IO(std::io::Error::new(
std::io::ErrorKind::Other,
e.to_string(),
)));
}
}
}
return Ok(());
};

NUM_CANCELLATION_REQUESTS
.with_label_values(&[from, "found"])
.inc();
info!("cancelling query per user's request using key {key}");
cancel_closure.try_cancel_query().await
}
Expand All @@ -57,7 +97,7 @@ impl CancelMap {

// Random key collisions are unlikely to happen here, but they're still possible,
// which is why we have to take care not to rewrite an existing key.
match self.0.entry(key) {
match self.map.entry(key) {
dashmap::mapref::entry::Entry::Occupied(_) => continue,
dashmap::mapref::entry::Entry::Vacant(e) => {
e.insert(None);
Expand All @@ -69,18 +109,46 @@ impl CancelMap {
info!("registered new query cancellation key {key}");
Session {
key,
cancel_map: self,
cancellation_handler: self,
}
}

#[cfg(test)]
fn contains(&self, session: &Session) -> bool {
self.0.contains_key(&session.key)
self.map.contains_key(&session.key)
}

#[cfg(test)]
fn is_empty(&self) -> bool {
self.0.is_empty()
self.map.is_empty()
}
}

#[async_trait]
pub trait NotificationsCancellationHandler {
async fn cancel_session_no_publish(&self, key: CancelKeyData) -> Result<(), CancelError>;
}

#[async_trait]
impl NotificationsCancellationHandler for CancellationHandler {
async fn cancel_session_no_publish(&self, key: CancelKeyData) -> Result<(), CancelError> {
let from = "from_redis";
let cancel_closure = self.map.get(&key).and_then(|x| x.clone());
match cancel_closure {
Some(cancel_closure) => {
NUM_CANCELLATION_REQUESTS
.with_label_values(&[from, "found"])
.inc();
cancel_closure.try_cancel_query().await
}
None => {
NUM_CANCELLATION_REQUESTS
.with_label_values(&[from, "not_found"])
.inc();
tracing::warn!("query cancellation key not found: {key}");
Ok(())
}
}
}
}

Expand Down Expand Up @@ -115,23 +183,25 @@ pub struct Session {
/// The user-facing key identifying this session.
key: CancelKeyData,
/// The [`CancelMap`] this session belongs to.
cancel_map: Arc<CancelMap>,
cancellation_handler: Arc<CancellationHandler>,
}

impl Session {
/// Store the cancel token for the given session.
/// This enables query cancellation in `crate::proxy::prepare_client_connection`.
pub fn enable_query_cancellation(&self, cancel_closure: CancelClosure) -> CancelKeyData {
info!("enabling query cancellation for this session");
self.cancel_map.0.insert(self.key, Some(cancel_closure));
self.cancellation_handler
.map
.insert(self.key, Some(cancel_closure));

self.key
}
}

impl Drop for Session {
fn drop(&mut self) {
self.cancel_map.0.remove(&self.key);
self.cancellation_handler.map.remove(&self.key);
info!("dropped query cancellation key {}", &self.key);
}
}
Expand All @@ -142,13 +212,16 @@ mod tests {

#[tokio::test]
async fn check_session_drop() -> anyhow::Result<()> {
let cancel_map: Arc<CancelMap> = Default::default();
let cancellation_handler = Arc::new(CancellationHandler {
map: CancelMap::default(),
redis_client: None,
});

let session = cancel_map.clone().get_session();
assert!(cancel_map.contains(&session));
let session = cancellation_handler.clone().get_session();
assert!(cancellation_handler.contains(&session));
drop(session);
// Check that the session has been dropped.
assert!(cancel_map.is_empty());
assert!(cancellation_handler.is_empty());

Ok(())
}
Expand Down
1 change: 1 addition & 0 deletions proxy/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ pub struct ProxyConfig {
pub require_client_ip: bool,
pub disable_ip_check_for_http: bool,
pub endpoint_rps_limit: Vec<RateBucketInfo>,
pub redis_rps_limit: Vec<RateBucketInfo>,
pub region: String,
pub handshake_timeout: Duration,
}
Expand Down
9 changes: 9 additions & 0 deletions proxy/src/metrics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,15 @@ pub static NUM_OPEN_CLIENTS_IN_HTTP_POOL: Lazy<IntGauge> = Lazy::new(|| {
.unwrap()
});

pub static NUM_CANCELLATION_REQUESTS: Lazy<IntCounterVec> = Lazy::new(|| {
register_int_counter_vec!(
"proxy_cancellation_requests_total",
"Number of cancellation requests (per found/not_found).",
&["source", "kind"],
)
.unwrap()
});

#[derive(Clone)]
pub struct LatencyTimer {
// time since the stopwatch was started
Expand Down
Loading
Loading