From ad091c6f75eeb0babeee677960d802373849727e Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Wed, 18 Dec 2024 09:29:42 +0000 Subject: [PATCH] chore(proxy): pre-load native tls certificates and propagate compute client config --- proxy/src/bin/local_proxy.rs | 27 +++++++-- proxy/src/bin/proxy.rs | 42 ++++++++++++-- proxy/src/cancellation.rs | 85 ++++++++++++++++++---------- proxy/src/compute.rs | 39 ++----------- proxy/src/config.rs | 8 ++- proxy/src/console_redirect_proxy.rs | 4 +- proxy/src/control_plane/mod.rs | 6 +- proxy/src/postgres_rustls/mod.rs | 8 +-- proxy/src/proxy/connect_compute.rs | 28 ++++----- proxy/src/proxy/mod.rs | 4 +- proxy/src/proxy/passthrough.rs | 13 ++++- proxy/src/proxy/tests/mod.rs | 88 ++++++++++++++--------------- proxy/src/redis/notifications.rs | 3 + proxy/src/serverless/backend.rs | 14 ++--- proxy/src/serverless/websocket.rs | 2 +- 15 files changed, 211 insertions(+), 160 deletions(-) diff --git a/proxy/src/bin/local_proxy.rs b/proxy/src/bin/local_proxy.rs index 56bbd9485011..38b42e15a3e9 100644 --- a/proxy/src/bin/local_proxy.rs +++ b/proxy/src/bin/local_proxy.rs @@ -13,7 +13,9 @@ use proxy::auth::backend::jwt::JwkCache; use proxy::auth::backend::local::{LocalBackend, JWKS_ROLE_MAP}; use proxy::auth::{self}; use proxy::cancellation::CancellationHandlerMain; -use proxy::config::{self, AuthenticationConfig, HttpConfig, ProxyConfig, RetryConfig}; +use proxy::config::{ + self, AuthenticationConfig, ComputeConfig, HttpConfig, ProxyConfig, RetryConfig, +}; use proxy::control_plane::locks::ApiLocks; use proxy::control_plane::messages::{EndpointJwksResponse, JwksSettings}; use proxy::http::health_server::AppMetrics; @@ -32,6 +34,8 @@ project_git_version!(GIT_VERSION); project_build_tag!(BUILD_TAG); use clap::Parser; +use rustls::crypto::ring; +use rustls::RootCertStore; use thiserror::Error; use tokio::net::TcpListener; use tokio::sync::Notify; @@ -209,6 +213,7 @@ async fn main() -> anyhow::Result<()> { http_listener, shutdown.clone(), Arc::new(CancellationHandlerMain::new( + &config.connect_to_compute, Arc::new(DashMap::new()), None, proxy::metrics::CancellationSource::Local, @@ -268,6 +273,22 @@ fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig max_response_size_bytes: args.sql_over_http.sql_over_http_max_response_size_bytes, }; + // local_proxy won't use TLS to talk to postgres. + let root_store = RootCertStore::empty(); + + let client_config = + rustls::ClientConfig::builder_with_provider(Arc::new(ring::default_provider())) + .with_safe_default_protocol_versions() + .expect("ring should support the default protocol versions") + .with_root_certificates(root_store) + .with_no_client_auth(); + + let compute_config = ComputeConfig { + retry: RetryConfig::parse(RetryConfig::CONNECT_TO_COMPUTE_DEFAULT_VALUES)?, + tls: Arc::new(client_config), + timeout: Duration::from_secs(2), + }; + Ok(Box::leak(Box::new(ProxyConfig { tls_config: None, metric_collection: None, @@ -289,9 +310,7 @@ fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig region: "local".into(), wake_compute_retry_config: RetryConfig::parse(RetryConfig::WAKE_COMPUTE_DEFAULT_VALUES)?, connect_compute_locks, - connect_to_compute_retry_config: RetryConfig::parse( - RetryConfig::CONNECT_TO_COMPUTE_DEFAULT_VALUES, - )?, + connect_to_compute: compute_config, }))) } diff --git a/proxy/src/bin/proxy.rs b/proxy/src/bin/proxy.rs index 3dcf9ca060c0..1dace2ec8f97 100644 --- a/proxy/src/bin/proxy.rs +++ b/proxy/src/bin/proxy.rs @@ -1,14 +1,15 @@ use std::net::SocketAddr; use std::pin::pin; use std::sync::Arc; +use std::time::Duration; -use anyhow::bail; +use anyhow::{bail, Context}; use futures::future::Either; use proxy::auth::backend::jwt::JwkCache; use proxy::auth::backend::{AuthRateLimiter, ConsoleRedirectBackend, MaybeOwned}; use proxy::cancellation::{CancelMap, CancellationHandler}; use proxy::config::{ - self, remote_storage_from_toml, AuthenticationConfig, CacheOptions, HttpConfig, + self, remote_storage_from_toml, AuthenticationConfig, CacheOptions, ComputeConfig, HttpConfig, ProjectInfoCacheOptions, ProxyConfig, ProxyProtocolV2, }; use proxy::context::parquet::ParquetUploadArgs; @@ -25,6 +26,7 @@ use proxy::serverless::cancel_set::CancelSet; use proxy::serverless::GlobalConnPoolOptions; use proxy::{auth, control_plane, http, serverless, usage_metrics}; use remote_storage::RemoteStorageConfig; +use rustls::crypto::ring; use tokio::net::TcpListener; use tokio::sync::Mutex; use tokio::task::JoinSet; @@ -397,6 +399,7 @@ async fn main() -> anyhow::Result<()> { let cancellation_handler = Arc::new(CancellationHandler::< Option>>, >::new( + &config.connect_to_compute, cancel_map.clone(), redis_publisher, proxy::metrics::CancellationSource::FromClient, @@ -492,6 +495,7 @@ async fn main() -> anyhow::Result<()> { let cache = api.caches.project_info.clone(); if let Some(client) = client1 { maintenance_tasks.spawn(notifications::task_main( + config, client, cache.clone(), cancel_map.clone(), @@ -500,6 +504,7 @@ async fn main() -> anyhow::Result<()> { } if let Some(client) = client2 { maintenance_tasks.spawn(notifications::task_main( + config, client, cache.clone(), cancel_map.clone(), @@ -632,6 +637,23 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> { console_redirect_confirmation_timeout: args.webauth_confirmation_timeout, }; + let root_store = load_certs() + .context("loading native tls certificates")? + .clone(); + + let client_config = + rustls::ClientConfig::builder_with_provider(Arc::new(ring::default_provider())) + .with_safe_default_protocol_versions() + .expect("ring should support the default protocol versions") + .with_root_certificates(root_store) + .with_no_client_auth(); + + let compute_config = ComputeConfig { + retry: config::RetryConfig::parse(&args.connect_to_compute_retry)?, + tls: Arc::new(client_config), + timeout: Duration::from_secs(2), + }; + let config = ProxyConfig { tls_config, metric_collection, @@ -642,9 +664,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> { region: args.region.clone(), wake_compute_retry_config: config::RetryConfig::parse(&args.wake_compute_retry)?, connect_compute_locks, - connect_to_compute_retry_config: config::RetryConfig::parse( - &args.connect_to_compute_retry, - )?, + connect_to_compute: compute_config, }; let config = Box::leak(Box::new(config)); @@ -654,6 +674,18 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> { Ok(config) } +pub(crate) fn load_certs() -> anyhow::Result> { + let der_certs = rustls_native_certs::load_native_certs(); + + if !der_certs.errors.is_empty() { + bail!("could not parse certificates: {:?}", der_certs.errors); + } + + let mut store = rustls::RootCertStore::empty(); + store.add_parsable_certificates(der_certs.certs); + Ok(Arc::new(store)) +} + /// auth::Backend is created at proxy startup, and lives forever. fn build_auth_backend( args: &ProxyCliArgs, diff --git a/proxy/src/cancellation.rs b/proxy/src/cancellation.rs index ebaea173ae48..e989f4bbd1e5 100644 --- a/proxy/src/cancellation.rs +++ b/proxy/src/cancellation.rs @@ -3,11 +3,9 @@ use std::sync::Arc; use dashmap::DashMap; use ipnet::{IpNet, Ipv4Net, Ipv6Net}; -use once_cell::sync::OnceCell; use postgres_client::tls::MakeTlsConnect; use postgres_client::CancelToken; use pq_proto::CancelKeyData; -use rustls::crypto::ring; use thiserror::Error; use tokio::net::TcpStream; use tokio::sync::Mutex; @@ -15,7 +13,7 @@ use tracing::{debug, info}; use uuid::Uuid; use crate::auth::{check_peer_addr_is_in_list, IpPattern}; -use crate::compute::load_certs; +use crate::config::ComputeConfig; use crate::error::ReportableError; use crate::ext::LockExt; use crate::metrics::{CancellationRequest, CancellationSource, Metrics}; @@ -35,6 +33,7 @@ type IpSubnetKey = IpNet; /// /// If `CancellationPublisher` is available, cancel request will be used to publish the cancellation key to other proxy instances. pub struct CancellationHandler

{ + compute_config: &'static ComputeConfig, map: CancelMap, client: P, /// This field used for the monitoring purposes. @@ -183,7 +182,7 @@ impl CancellationHandler

{ "cancelling query per user's request using key {key}, hostname {}, address: {}", cancel_closure.hostname, cancel_closure.socket_addr ); - cancel_closure.try_cancel_query().await + cancel_closure.try_cancel_query(self.compute_config).await } #[cfg(test)] @@ -198,8 +197,13 @@ impl CancellationHandler

{ } impl CancellationHandler<()> { - pub fn new(map: CancelMap, from: CancellationSource) -> Self { + pub fn new( + compute_config: &'static ComputeConfig, + map: CancelMap, + from: CancellationSource, + ) -> Self { Self { + compute_config, map, client: (), from, @@ -214,8 +218,14 @@ impl CancellationHandler<()> { } impl CancellationHandler>>> { - pub fn new(map: CancelMap, client: Option>>, from: CancellationSource) -> Self { + pub fn new( + compute_config: &'static ComputeConfig, + map: CancelMap, + client: Option>>, + from: CancellationSource, + ) -> Self { Self { + compute_config, map, client, from, @@ -229,8 +239,6 @@ impl CancellationHandler>>> { } } -static TLS_ROOTS: OnceCell> = OnceCell::new(); - /// This should've been a [`std::future::Future`], but /// it's impossible to name a type of an unboxed future /// (we'd need something like `#![feature(type_alias_impl_trait)]`). @@ -257,27 +265,13 @@ impl CancelClosure { } } /// Cancels the query running on user's compute node. - pub(crate) async fn try_cancel_query(self) -> Result<(), CancelError> { + pub(crate) async fn try_cancel_query( + self, + compute_config: &ComputeConfig, + ) -> Result<(), CancelError> { let socket = TcpStream::connect(self.socket_addr).await?; - let root_store = TLS_ROOTS - .get_or_try_init(load_certs) - .map_err(|_e| { - CancelError::IO(std::io::Error::new( - std::io::ErrorKind::Other, - "TLS root store initialization failed".to_string(), - )) - })? - .clone(); - - let client_config = - rustls::ClientConfig::builder_with_provider(Arc::new(ring::default_provider())) - .with_safe_default_protocol_versions() - .expect("ring should support the default protocol versions") - .with_root_certificates(root_store) - .with_no_client_auth(); - - let mut mk_tls = crate::postgres_rustls::MakeRustlsConnect::new(client_config); + let mut mk_tls = crate::postgres_rustls::MakeRustlsConnect::new(compute_config.tls.clone()); let tls = >::make_tls_connect( &mut mk_tls, &self.hostname, @@ -329,11 +323,41 @@ impl

Drop for Session

{ #[cfg(test)] #[expect(clippy::unwrap_used)] mod tests { + use std::time::Duration; + + use rustls::crypto::ring; + use rustls::RootCertStore; + use super::*; + use crate::config::RetryConfig; + + fn config() -> ComputeConfig { + let retry = RetryConfig { + base_delay: Duration::from_secs(1), + max_retries: 5, + backoff_factor: 2.0, + }; + + let root_store = RootCertStore::empty(); + + let client_config = + rustls::ClientConfig::builder_with_provider(Arc::new(ring::default_provider())) + .with_safe_default_protocol_versions() + .expect("ring should support the default protocol versions") + .with_root_certificates(root_store) + .with_no_client_auth(); + + ComputeConfig { + retry, + tls: Arc::new(client_config), + timeout: Duration::from_secs(2), + } + } #[tokio::test] async fn check_session_drop() -> anyhow::Result<()> { let cancellation_handler = Arc::new(CancellationHandler::<()>::new( + Box::leak(Box::new(config())), CancelMap::default(), CancellationSource::FromRedis, )); @@ -349,8 +373,11 @@ mod tests { #[tokio::test] async fn cancel_session_noop_regression() { - let handler = - CancellationHandler::<()>::new(CancelMap::default(), CancellationSource::Local); + let handler = CancellationHandler::<()>::new( + Box::leak(Box::new(config())), + CancelMap::default(), + CancellationSource::Local, + ); handler .cancel_session( CancelKeyData { diff --git a/proxy/src/compute.rs b/proxy/src/compute.rs index 8dc9b59e81c5..17588b9c3469 100644 --- a/proxy/src/compute.rs +++ b/proxy/src/compute.rs @@ -1,16 +1,13 @@ use std::io; use std::net::SocketAddr; -use std::sync::Arc; use std::time::Duration; use futures::{FutureExt, TryFutureExt}; use itertools::Itertools; -use once_cell::sync::OnceCell; use postgres_client::tls::MakeTlsConnect; use postgres_client::{CancelToken, RawConnection}; use postgres_protocol::message::backend::NoticeResponseBody; use pq_proto::StartupMessageParams; -use rustls::crypto::ring; use rustls::pki_types::InvalidDnsNameError; use thiserror::Error; use tokio::net::TcpStream; @@ -18,6 +15,7 @@ use tracing::{debug, error, info, warn}; use crate::auth::parse_endpoint_param; use crate::cancellation::CancelClosure; +use crate::config::ComputeConfig; use crate::context::RequestContext; use crate::control_plane::client::ApiLockError; use crate::control_plane::errors::WakeComputeError; @@ -40,9 +38,6 @@ pub(crate) enum ConnectionError { #[error("{COULD_NOT_CONNECT}: {0}")] CouldNotConnect(#[from] io::Error), - #[error("Couldn't load native TLS certificates: {0:?}")] - TlsCertificateError(Vec), - #[error("{COULD_NOT_CONNECT}: {0}")] TlsError(#[from] InvalidDnsNameError), @@ -89,7 +84,6 @@ impl ReportableError for ConnectionError { } ConnectionError::Postgres(_) => crate::error::ErrorKind::Compute, ConnectionError::CouldNotConnect(_) => crate::error::ErrorKind::Compute, - ConnectionError::TlsCertificateError(_) => crate::error::ErrorKind::Service, ConnectionError::TlsError(_) => crate::error::ErrorKind::Compute, ConnectionError::WakeComputeError(e) => e.get_error_kind(), ConnectionError::TooManyConnectionAttempts(e) => e.get_error_kind(), @@ -251,25 +245,13 @@ impl ConnCfg { &self, ctx: &RequestContext, aux: MetricsAuxInfo, - timeout: Duration, + config: &ComputeConfig, ) -> Result { let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute); - let (socket_addr, stream, host) = self.connect_raw(timeout).await?; + let (socket_addr, stream, host) = self.connect_raw(config.timeout).await?; drop(pause); - let root_store = TLS_ROOTS - .get_or_try_init(load_certs) - .map_err(ConnectionError::TlsCertificateError)? - .clone(); - - let client_config = - rustls::ClientConfig::builder_with_provider(Arc::new(ring::default_provider())) - .with_safe_default_protocol_versions() - .expect("ring should support the default protocol versions") - .with_root_certificates(root_store) - .with_no_client_auth(); - - let mut mk_tls = crate::postgres_rustls::MakeRustlsConnect::new(client_config); + let mut mk_tls = crate::postgres_rustls::MakeRustlsConnect::new(config.tls.clone()); let tls = >::make_tls_connect( &mut mk_tls, host, @@ -341,19 +323,6 @@ fn filtered_options(options: &str) -> Option { Some(options) } -pub(crate) fn load_certs() -> Result, Vec> { - let der_certs = rustls_native_certs::load_native_certs(); - - if !der_certs.errors.is_empty() { - return Err(der_certs.errors); - } - - let mut store = rustls::RootCertStore::empty(); - store.add_parsable_certificates(der_certs.certs); - Ok(Arc::new(store)) -} -static TLS_ROOTS: OnceCell> = OnceCell::new(); - #[cfg(test)] mod tests { use super::*; diff --git a/proxy/src/config.rs b/proxy/src/config.rs index 33d1d2e9e4a0..351ff6d12da1 100644 --- a/proxy/src/config.rs +++ b/proxy/src/config.rs @@ -32,7 +32,13 @@ pub struct ProxyConfig { pub handshake_timeout: Duration, pub wake_compute_retry_config: RetryConfig, pub connect_compute_locks: ApiLocks, - pub connect_to_compute_retry_config: RetryConfig, + pub connect_to_compute: ComputeConfig, +} + +pub struct ComputeConfig { + pub retry: RetryConfig, + pub tls: Arc, + pub timeout: Duration, } #[derive(Copy, Clone, Debug, ValueEnum, PartialEq)] diff --git a/proxy/src/console_redirect_proxy.rs b/proxy/src/console_redirect_proxy.rs index c477822e853c..25a549039ccc 100644 --- a/proxy/src/console_redirect_proxy.rs +++ b/proxy/src/console_redirect_proxy.rs @@ -115,7 +115,7 @@ pub async fn task_main( Ok(Some(p)) => { ctx.set_success(); let _disconnect = ctx.log_connect(); - match p.proxy_pass().await { + match p.proxy_pass(&config.connect_to_compute).await { Ok(()) => {} Err(ErrorSource::Client(e)) => { error!(?session_id, "per-client task finished with an IO error from the client: {e:#}"); @@ -216,7 +216,7 @@ pub(crate) async fn handle_client( }, &user_info, config.wake_compute_retry_config, - config.connect_to_compute_retry_config, + &config.connect_to_compute, ) .or_else(|e| stream.throw_error(e)) .await?; diff --git a/proxy/src/control_plane/mod.rs b/proxy/src/control_plane/mod.rs index 0ca1a6aae0eb..c65041df0e37 100644 --- a/proxy/src/control_plane/mod.rs +++ b/proxy/src/control_plane/mod.rs @@ -10,13 +10,13 @@ pub mod client; pub(crate) mod errors; use std::sync::Arc; -use std::time::Duration; use crate::auth::backend::jwt::AuthRule; use crate::auth::backend::{ComputeCredentialKeys, ComputeUserInfo}; use crate::auth::IpPattern; use crate::cache::project_info::ProjectInfoCacheImpl; use crate::cache::{Cached, TimedLru}; +use crate::config::ComputeConfig; use crate::context::RequestContext; use crate::control_plane::messages::{ControlPlaneErrorMessage, MetricsAuxInfo}; use crate::intern::ProjectIdInt; @@ -73,9 +73,9 @@ impl NodeInfo { pub(crate) async fn connect( &self, ctx: &RequestContext, - timeout: Duration, + config: &ComputeConfig, ) -> Result { - self.config.connect(ctx, self.aux.clone(), timeout).await + self.config.connect(ctx, self.aux.clone(), config).await } pub(crate) fn reuse_settings(&mut self, other: Self) { diff --git a/proxy/src/postgres_rustls/mod.rs b/proxy/src/postgres_rustls/mod.rs index 5ef20991c309..abf48d6f8261 100644 --- a/proxy/src/postgres_rustls/mod.rs +++ b/proxy/src/postgres_rustls/mod.rs @@ -126,16 +126,14 @@ mod private { /// That way you can connect to PostgreSQL using `rustls` as the TLS stack. #[derive(Clone)] pub struct MakeRustlsConnect { - config: Arc, + pub config: Arc, } impl MakeRustlsConnect { /// Creates a new `MakeRustlsConnect` from the provided `ClientConfig`. #[must_use] - pub fn new(config: ClientConfig) -> Self { - Self { - config: Arc::new(config), - } + pub fn new(config: Arc) -> Self { + Self { config } } } diff --git a/proxy/src/proxy/connect_compute.rs b/proxy/src/proxy/connect_compute.rs index 4a30d2398558..8a804948606b 100644 --- a/proxy/src/proxy/connect_compute.rs +++ b/proxy/src/proxy/connect_compute.rs @@ -6,7 +6,7 @@ use tracing::{debug, info, warn}; use super::retry::ShouldRetryWakeCompute; use crate::auth::backend::ComputeCredentialKeys; use crate::compute::{self, PostgresConnection, COULD_NOT_CONNECT}; -use crate::config::RetryConfig; +use crate::config::{ComputeConfig, RetryConfig}; use crate::context::RequestContext; use crate::control_plane::errors::WakeComputeError; use crate::control_plane::locks::ApiLocks; @@ -19,8 +19,6 @@ use crate::proxy::retry::{retry_after, should_retry, CouldRetry}; use crate::proxy::wake_compute::wake_compute; use crate::types::Host; -const CONNECT_TIMEOUT: time::Duration = time::Duration::from_secs(2); - /// If we couldn't connect, a cached connection info might be to blame /// (e.g. the compute node's address might've changed at the wrong time). /// Invalidate the cache entry (if any) to prevent subsequent errors. @@ -49,7 +47,7 @@ pub(crate) trait ConnectMechanism { &self, ctx: &RequestContext, node_info: &control_plane::CachedNodeInfo, - timeout: time::Duration, + config: &ComputeConfig, ) -> Result; fn update_connect_config(&self, conf: &mut compute::ConnCfg); @@ -86,11 +84,11 @@ impl ConnectMechanism for TcpMechanism<'_> { &self, ctx: &RequestContext, node_info: &control_plane::CachedNodeInfo, - timeout: time::Duration, + config: &ComputeConfig, ) -> Result { let host = node_info.config.get_host(); let permit = self.locks.get_permit(&host).await?; - permit.release_result(node_info.connect(ctx, timeout).await) + permit.release_result(node_info.connect(ctx, config).await) } fn update_connect_config(&self, config: &mut compute::ConnCfg) { @@ -105,7 +103,7 @@ pub(crate) async fn connect_to_compute Result where M::ConnectError: CouldRetry + ShouldRetryWakeCompute + std::fmt::Debug, @@ -119,10 +117,7 @@ where mechanism.update_connect_config(&mut node_info.config); // try once - let err = match mechanism - .connect_once(ctx, &node_info, CONNECT_TIMEOUT) - .await - { + let err = match mechanism.connect_once(ctx, &node_info, compute).await { Ok(res) => { ctx.success(); Metrics::get().proxy.retries_metric.observe( @@ -142,7 +137,7 @@ where let node_info = if !node_info.cached() || !err.should_retry_wake_compute() { // If we just recieved this from cplane and didn't get it from cache, we shouldn't retry. // Do not need to retrieve a new node_info, just return the old one. - if should_retry(&err, num_retries, connect_to_compute_retry_config) { + if should_retry(&err, num_retries, compute.retry) { Metrics::get().proxy.retries_metric.observe( RetriesMetricGroup { outcome: ConnectOutcome::Failed, @@ -172,10 +167,7 @@ where debug!("wake_compute success. attempting to connect"); num_retries = 1; loop { - match mechanism - .connect_once(ctx, &node_info, CONNECT_TIMEOUT) - .await - { + match mechanism.connect_once(ctx, &node_info, compute).await { Ok(res) => { ctx.success(); Metrics::get().proxy.retries_metric.observe( @@ -190,7 +182,7 @@ where return Ok(res); } Err(e) => { - if !should_retry(&e, num_retries, connect_to_compute_retry_config) { + if !should_retry(&e, num_retries, compute.retry) { // Don't log an error here, caller will print the error Metrics::get().proxy.retries_metric.observe( RetriesMetricGroup { @@ -206,7 +198,7 @@ where } }; - let wait_duration = retry_after(num_retries, connect_to_compute_retry_config); + let wait_duration = retry_after(num_retries, compute.retry); num_retries += 1; let pause = ctx.latency_timer_pause(crate::metrics::Waiting::RetryTimeout); diff --git a/proxy/src/proxy/mod.rs b/proxy/src/proxy/mod.rs index dbe174cab7d5..3926c56fecc5 100644 --- a/proxy/src/proxy/mod.rs +++ b/proxy/src/proxy/mod.rs @@ -152,7 +152,7 @@ pub async fn task_main( Ok(Some(p)) => { ctx.set_success(); let _disconnect = ctx.log_connect(); - match p.proxy_pass().await { + match p.proxy_pass(&config.connect_to_compute).await { Ok(()) => {} Err(ErrorSource::Client(e)) => { warn!(?session_id, "per-client task finished with an IO error from the client: {e:#}"); @@ -351,7 +351,7 @@ pub(crate) async fn handle_client( }, &user_info, config.wake_compute_retry_config, - config.connect_to_compute_retry_config, + &config.connect_to_compute, ) .or_else(|e| stream.throw_error(e)) .await?; diff --git a/proxy/src/proxy/passthrough.rs b/proxy/src/proxy/passthrough.rs index dcaa81e5cdda..a42f9aad39a7 100644 --- a/proxy/src/proxy/passthrough.rs +++ b/proxy/src/proxy/passthrough.rs @@ -5,6 +5,7 @@ use utils::measured_stream::MeasuredStream; use super::copy_bidirectional::ErrorSource; use crate::cancellation; use crate::compute::PostgresConnection; +use crate::config::ComputeConfig; use crate::control_plane::messages::MetricsAuxInfo; use crate::metrics::{Direction, Metrics, NumClientConnectionsGuard, NumConnectionRequestsGuard}; use crate::stream::Stream; @@ -67,9 +68,17 @@ pub(crate) struct ProxyPassthrough { } impl ProxyPassthrough { - pub(crate) async fn proxy_pass(self) -> Result<(), ErrorSource> { + pub(crate) async fn proxy_pass( + self, + compute_config: &ComputeConfig, + ) -> Result<(), ErrorSource> { let res = proxy_pass(self.client, self.compute.stream, self.aux).await; - if let Err(err) = self.compute.cancel_closure.try_cancel_query().await { + if let Err(err) = self + .compute + .cancel_closure + .try_cancel_query(compute_config) + .await + { tracing::warn!(session_id = ?self.session_id, ?err, "could not cancel the query in the database"); } res diff --git a/proxy/src/proxy/tests/mod.rs b/proxy/src/proxy/tests/mod.rs index 95c518fed9c2..5dc13982bd06 100644 --- a/proxy/src/proxy/tests/mod.rs +++ b/proxy/src/proxy/tests/mod.rs @@ -13,7 +13,7 @@ use postgres_client::tls::{MakeTlsConnect, NoTls}; use retry::{retry_after, ShouldRetryWakeCompute}; use rstest::rstest; use rustls::crypto::ring; -use rustls::pki_types; +use rustls::{pki_types, RootCertStore}; use tokio::io::DuplexStream; use super::connect_compute::ConnectMechanism; @@ -22,7 +22,7 @@ use super::*; use crate::auth::backend::{ ComputeCredentialKeys, ComputeCredentials, ComputeUserInfo, MaybeOwned, }; -use crate::config::{CertResolver, RetryConfig}; +use crate::config::{CertResolver, ComputeConfig, RetryConfig}; use crate::control_plane::client::{ControlPlaneClient, TestControlPlaneClient}; use crate::control_plane::messages::{ControlPlaneErrorMessage, Details, MetricsAuxInfo, Status}; use crate::control_plane::{ @@ -67,7 +67,7 @@ fn generate_certs( } struct ClientConfig<'a> { - config: rustls::ClientConfig, + config: Arc, hostname: &'a str, } @@ -120,6 +120,7 @@ fn generate_tls_config<'a>( store }) .with_no_client_auth(); + let config = Arc::new(config); ClientConfig { config, hostname } }; @@ -468,7 +469,7 @@ impl ConnectMechanism for TestConnectMechanism { &self, _ctx: &RequestContext, _node_info: &control_plane::CachedNodeInfo, - _timeout: std::time::Duration, + _config: &ComputeConfig, ) -> Result { let mut counter = self.counter.lock().unwrap(); let action = self.sequence[*counter]; @@ -576,6 +577,29 @@ fn helper_create_connect_info( user_info } +fn config() -> ComputeConfig { + let retry = RetryConfig { + base_delay: Duration::from_secs(1), + max_retries: 5, + backoff_factor: 2.0, + }; + + let root_store = RootCertStore::empty(); + + let client_config = + rustls::ClientConfig::builder_with_provider(Arc::new(ring::default_provider())) + .with_safe_default_protocol_versions() + .expect("ring should support the default protocol versions") + .with_root_certificates(root_store) + .with_no_client_auth(); + + ComputeConfig { + retry, + tls: Arc::new(client_config), + timeout: Duration::from_secs(2), + } +} + #[tokio::test] async fn connect_to_compute_success() { let _ = env_logger::try_init(); @@ -583,12 +607,8 @@ async fn connect_to_compute_success() { let ctx = RequestContext::test(); let mechanism = TestConnectMechanism::new(vec![Wake, Connect]); let user_info = helper_create_connect_info(&mechanism); - let config = RetryConfig { - base_delay: Duration::from_secs(1), - max_retries: 5, - backoff_factor: 2.0, - }; - connect_to_compute(&ctx, &mechanism, &user_info, config, config) + let config = config(); + connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config) .await .unwrap(); mechanism.verify(); @@ -601,12 +621,8 @@ async fn connect_to_compute_retry() { let ctx = RequestContext::test(); let mechanism = TestConnectMechanism::new(vec![Wake, Retry, Wake, Connect]); let user_info = helper_create_connect_info(&mechanism); - let config = RetryConfig { - base_delay: Duration::from_secs(1), - max_retries: 5, - backoff_factor: 2.0, - }; - connect_to_compute(&ctx, &mechanism, &user_info, config, config) + let config = config(); + connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config) .await .unwrap(); mechanism.verify(); @@ -620,12 +636,8 @@ async fn connect_to_compute_non_retry_1() { let ctx = RequestContext::test(); let mechanism = TestConnectMechanism::new(vec![Wake, Retry, Wake, Fail]); let user_info = helper_create_connect_info(&mechanism); - let config = RetryConfig { - base_delay: Duration::from_secs(1), - max_retries: 5, - backoff_factor: 2.0, - }; - connect_to_compute(&ctx, &mechanism, &user_info, config, config) + let config = config(); + connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config) .await .unwrap_err(); mechanism.verify(); @@ -639,12 +651,8 @@ async fn connect_to_compute_non_retry_2() { let ctx = RequestContext::test(); let mechanism = TestConnectMechanism::new(vec![Wake, Fail, Wake, Connect]); let user_info = helper_create_connect_info(&mechanism); - let config = RetryConfig { - base_delay: Duration::from_secs(1), - max_retries: 5, - backoff_factor: 2.0, - }; - connect_to_compute(&ctx, &mechanism, &user_info, config, config) + let config = config(); + connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config) .await .unwrap(); mechanism.verify(); @@ -665,17 +673,13 @@ async fn connect_to_compute_non_retry_3() { max_retries: 1, backoff_factor: 2.0, }; - let connect_to_compute_retry_config = RetryConfig { - base_delay: Duration::from_secs(1), - max_retries: 5, - backoff_factor: 2.0, - }; + let config = config(); connect_to_compute( &ctx, &mechanism, &user_info, wake_compute_retry_config, - connect_to_compute_retry_config, + &config, ) .await .unwrap_err(); @@ -690,12 +694,8 @@ async fn wake_retry() { let ctx = RequestContext::test(); let mechanism = TestConnectMechanism::new(vec![WakeRetry, Wake, Connect]); let user_info = helper_create_connect_info(&mechanism); - let config = RetryConfig { - base_delay: Duration::from_secs(1), - max_retries: 5, - backoff_factor: 2.0, - }; - connect_to_compute(&ctx, &mechanism, &user_info, config, config) + let config = config(); + connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config) .await .unwrap(); mechanism.verify(); @@ -709,12 +709,8 @@ async fn wake_non_retry() { let ctx = RequestContext::test(); let mechanism = TestConnectMechanism::new(vec![WakeRetry, WakeFail]); let user_info = helper_create_connect_info(&mechanism); - let config = RetryConfig { - base_delay: Duration::from_secs(1), - max_retries: 5, - backoff_factor: 2.0, - }; - connect_to_compute(&ctx, &mechanism, &user_info, config, config) + let config = config(); + connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config) .await .unwrap_err(); mechanism.verify(); diff --git a/proxy/src/redis/notifications.rs b/proxy/src/redis/notifications.rs index d18dfd246556..80b93b6c4fdb 100644 --- a/proxy/src/redis/notifications.rs +++ b/proxy/src/redis/notifications.rs @@ -12,6 +12,7 @@ use uuid::Uuid; use super::connection_with_credentials_provider::ConnectionWithCredentialsProvider; use crate::cache::project_info::ProjectInfoCache; use crate::cancellation::{CancelMap, CancellationHandler}; +use crate::config::ProxyConfig; use crate::intern::{ProjectIdInt, RoleNameInt}; use crate::metrics::{Metrics, RedisErrors, RedisEventsCount}; @@ -249,6 +250,7 @@ async fn handle_messages( /// Handle console's invalidation messages. #[tracing::instrument(name = "redis_notifications", skip_all)] pub async fn task_main( + config: &'static ProxyConfig, redis: ConnectionWithCredentialsProvider, cache: Arc, cancel_map: CancelMap, @@ -258,6 +260,7 @@ where C: ProjectInfoCache + Send + Sync + 'static, { let cancellation_handler = Arc::new(CancellationHandler::<()>::new( + &config.connect_to_compute, cancel_map, crate::metrics::CancellationSource::FromRedis, )); diff --git a/proxy/src/serverless/backend.rs b/proxy/src/serverless/backend.rs index 449d50b6e78b..b398c3ddd07b 100644 --- a/proxy/src/serverless/backend.rs +++ b/proxy/src/serverless/backend.rs @@ -22,7 +22,7 @@ use crate::compute; use crate::compute_ctl::{ ComputeCtlError, ExtensionInstallRequest, Privilege, SetRoleGrantsRequest, }; -use crate::config::ProxyConfig; +use crate::config::{ComputeConfig, ProxyConfig}; use crate::context::RequestContext; use crate::control_plane::client::ApiLockError; use crate::control_plane::errors::{GetAuthInfoError, WakeComputeError}; @@ -196,7 +196,7 @@ impl PoolingBackend { }, &backend, self.config.wake_compute_retry_config, - self.config.connect_to_compute_retry_config, + &self.config.connect_to_compute, ) .await } @@ -237,7 +237,7 @@ impl PoolingBackend { }, &backend, self.config.wake_compute_retry_config, - self.config.connect_to_compute_retry_config, + &self.config.connect_to_compute, ) .await } @@ -502,7 +502,7 @@ impl ConnectMechanism for TokioMechanism { &self, ctx: &RequestContext, node_info: &CachedNodeInfo, - timeout: Duration, + compute_config: &ComputeConfig, ) -> Result { let host = node_info.config.get_host(); let permit = self.locks.get_permit(&host).await?; @@ -511,7 +511,7 @@ impl ConnectMechanism for TokioMechanism { let config = config .user(&self.conn_info.user_info.user) .dbname(&self.conn_info.dbname) - .connect_timeout(timeout); + .connect_timeout(compute_config.timeout); let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute); let res = config.connect(postgres_client::NoTls).await; @@ -552,7 +552,7 @@ impl ConnectMechanism for HyperMechanism { &self, ctx: &RequestContext, node_info: &CachedNodeInfo, - timeout: Duration, + config: &ComputeConfig, ) -> Result { let host = node_info.config.get_host(); let permit = self.locks.get_permit(&host).await?; @@ -560,7 +560,7 @@ impl ConnectMechanism for HyperMechanism { let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute); let port = node_info.config.get_port(); - let res = connect_http2(&host, port, timeout).await; + let res = connect_http2(&host, port, config.timeout).await; drop(pause); let (client, connection) = permit.release_result(res)?; diff --git a/proxy/src/serverless/websocket.rs b/proxy/src/serverless/websocket.rs index 812fedaf0422..47326c11815d 100644 --- a/proxy/src/serverless/websocket.rs +++ b/proxy/src/serverless/websocket.rs @@ -168,7 +168,7 @@ pub(crate) async fn serve_websocket( Ok(Some(p)) => { ctx.set_success(); ctx.log_connect(); - match p.proxy_pass().await { + match p.proxy_pass(&config.connect_to_compute).await { Ok(()) => Ok(()), Err(ErrorSource::Client(err)) => Err(err).context("client"), Err(ErrorSource::Compute(err)) => Err(err).context("compute"),