From 75434060a51861559e082d9b1359ae974b2cb2ab Mon Sep 17 00:00:00 2001
From: Conrad Ludgate
Date: Wed, 9 Oct 2024 18:24:10 +0100
Subject: [PATCH] local_proxy: integrate with pg_session_jwt extension (#9086)

---
 Cargo.lock                              |  18 +
 proxy/Cargo.toml                        |   4 +-
 proxy/src/auth/backend/jwt.rs           |  12 +-
 proxy/src/auth/backend/mod.rs           |   2 +
 proxy/src/control_plane/provider/mod.rs |   2 +-
 proxy/src/serverless/backend.rs         |  97 ++++-
 proxy/src/serverless/local_conn_pool.rs | 544 ++++++++++++++++++++++++
 proxy/src/serverless/mod.rs             |   3 +
 proxy/src/serverless/sql_over_http.rs   |  82 +++-
 workspace_hack/Cargo.toml               |   4 +-
 10 files changed, 741 insertions(+), 27 deletions(-)
 create mode 100644 proxy/src/serverless/local_conn_pool.rs

diff --git a/Cargo.lock b/Cargo.lock
index 865fb3388960..5edf5cf7b4d5 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1820,6 +1820,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "b5e6043086bf7973472e0c7dff2142ea0b680d30e18d9cc40f267efbf222bd47"
 dependencies = [
  "base16ct 0.2.0",
+ "base64ct",
  "crypto-bigint 0.5.5",
  "digest",
  "ff 0.13.0",
@@ -1829,6 +1830,8 @@ dependencies = [
  "pkcs8 0.10.2",
  "rand_core 0.6.4",
  "sec1 0.7.3",
+ "serde_json",
+ "serdect",
  "subtle",
  "zeroize",
 ]
@@ -4037,6 +4040,8 @@ dependencies = [
  "bytes",
  "fallible-iterator",
  "postgres-protocol",
+ "serde",
+ "serde_json",
 ]
 
 [[package]]
@@ -5256,6 +5261,7 @@ dependencies = [
 "der 0.7.8",
 "generic-array",
 "pkcs8 0.10.2",
+ "serdect",
 "subtle",
 "zeroize",
 ]
@@ -5510,6 +5516,16 @@ dependencies = [
  "syn 2.0.52",
 ]
 
+[[package]]
+name = "serdect"
+version = "0.2.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "a84f14a19e9a014bb9f4512488d9829a68e04ecabffb0f9904cd1ace94598177"
+dependencies = [
+ "base16ct 0.2.0",
+ "serde",
+]
+
 [[package]]
 name = "sha1"
 version = "0.10.5"
@@ -7302,6 +7318,7 @@ dependencies = [
  "num-traits",
  "once_cell",
  "parquet",
+ "postgres-types",
  "prettyplease",
  "proc-macro2",
  "prost",
@@ -7326,6 +7343,7 @@ dependencies = [
  "time",
  "time-macros",
  "tokio",
+ "tokio-postgres",
  "tokio-stream",
  "tokio-util",
  "toml_edit",
diff --git a/proxy/Cargo.toml b/proxy/Cargo.toml
index c995ac36607e..963fb94a7de9 100644
--- a/proxy/Cargo.toml
+++ b/proxy/Cargo.toml
@@ -77,7 +77,7 @@ subtle.workspace = true
 thiserror.workspace = true
 tikv-jemallocator.workspace = true
 tikv-jemalloc-ctl = { workspace = true, features = ["use_std"] }
-tokio-postgres.workspace = true
+tokio-postgres = { workspace = true, features = ["with-serde_json-1"] }
 tokio-postgres-rustls.workspace = true
 tokio-rustls.workspace = true
 tokio-util.workspace = true
@@ -101,7 +101,7 @@ jose-jwa = "0.1.2"
 jose-jwk = { version = "0.1.2", features = ["p256", "p384", "rsa"] }
 signature = "2"
 ecdsa = "0.16"
-p256 = "0.13"
+p256 = { version = "0.13", features = ["jwk"] }
 rsa = "0.9"
 
 workspace_hack.workspace = true
diff --git a/proxy/src/auth/backend/jwt.rs b/proxy/src/auth/backend/jwt.rs
index 0c66fe5381f1..17ab7eda2245 100644
--- a/proxy/src/auth/backend/jwt.rs
+++ b/proxy/src/auth/backend/jwt.rs
@@ -17,6 +17,8 @@ use crate::{
     RoleName,
 };
 
+use super::ComputeCredentialKeys;
+
 // TODO(conrad): make these configurable.
 const CLOCK_SKEW_LEEWAY: Duration = Duration::from_secs(30);
 const MIN_RENEW: Duration = Duration::from_secs(30);
@@ -241,7 +243,7 @@ impl JwkCacheEntryLock {
         endpoint: EndpointId,
         role_name: &RoleName,
         fetch: &F,
-    ) -> Result<(), anyhow::Error> {
+    ) -> Result<ComputeCredentialKeys, anyhow::Error> {
         // JWT compact form is defined to be
         // <B64(Header)> || . || <B64(Payload)> || . || <B64(Signature)>
         // where Signature = alg(<B64(Header)> || . || <B64(Payload)>);
@@ -300,9 +302,9 @@ impl JwkCacheEntryLock {
             key => bail!("unsupported key type {key:?}"),
         };
 
-        let payload = base64::decode_config(payload, base64::URL_SAFE_NO_PAD)
+        let payloadb = base64::decode_config(payload, base64::URL_SAFE_NO_PAD)
             .context("Provided authentication token is not a valid JWT encoding")?;
-        let payload = serde_json::from_slice::<JwtPayload<'_>>(&payload)
+        let payload = serde_json::from_slice::<JwtPayload<'_>>(&payloadb)
             .context("Provided authentication token is not a valid JWT encoding")?;
 
         tracing::debug!(?payload, "JWT signature valid with claims");
@@ -327,7 +329,7 @@ impl JwkCacheEntryLock {
             );
         }
 
-        Ok(())
+        Ok(ComputeCredentialKeys::JwtPayload(payloadb))
     }
 }
 
@@ -339,7 +341,7 @@ impl JwkCache {
         role_name: &RoleName,
         fetch: &F,
         jwt: &str,
-    ) -> Result<(), anyhow::Error> {
+    ) -> Result<ComputeCredentialKeys, anyhow::Error> {
         // try with just a read lock first
         let key = (endpoint.clone(), role_name.clone());
         let entry = self.map.get(&key).as_deref().map(Arc::clone);
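Note: `check_jwt` now returns the decoded payload bytes instead of `()`, so the caller can hand the verified claims to local_proxy for re-signing. As a quick illustration of the compact form described in the comment above, here is a minimal sketch in Rust of the header/payload/signature split; `split_jwt` is a hypothetical helper, and the real code path additionally verifies the signature against the cached JWKs and validates the claims.

    // Sketch: splitting the JWT compact serialization named in the comment above.
    // Standalone illustration only, not the crate's API.
    fn split_jwt(jwt: &str) -> Option<(&str, &str, &str)> {
        let mut parts = jwt.splitn(3, '.');
        let header = parts.next()?; // B64(Header)
        let payload = parts.next()?; // B64(Payload)
        let signature = parts.next()?; // B64(Signature) over "header.payload"
        Some((header, payload, signature))
    }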
diff --git a/proxy/src/auth/backend/mod.rs b/proxy/src/auth/backend/mod.rs
index 78766193ad02..c9aa5b7e618d 100644
--- a/proxy/src/auth/backend/mod.rs
+++ b/proxy/src/auth/backend/mod.rs
@@ -175,10 +175,12 @@ impl ComputeUserInfo {
     }
 }
 
+#[cfg_attr(test, derive(Debug))]
 pub(crate) enum ComputeCredentialKeys {
     #[cfg(any(test, feature = "testing"))]
     Password(Vec<u8>),
     AuthKeys(AuthKeys),
+    JwtPayload(Vec<u8>),
     None,
 }
 
diff --git a/proxy/src/control_plane/provider/mod.rs b/proxy/src/control_plane/provider/mod.rs
index 566841535ede..01d93dee43f4 100644
--- a/proxy/src/control_plane/provider/mod.rs
+++ b/proxy/src/control_plane/provider/mod.rs
@@ -309,7 +309,7 @@ impl NodeInfo {
             #[cfg(any(test, feature = "testing"))]
             ComputeCredentialKeys::Password(password) => self.config.password(password),
             ComputeCredentialKeys::AuthKeys(auth_keys) => self.config.auth_keys(*auth_keys),
-            ComputeCredentialKeys::None => &mut self.config,
+            ComputeCredentialKeys::JwtPayload(_) | ComputeCredentialKeys::None => &mut self.config,
         };
     }
 }
diff --git a/proxy/src/serverless/backend.rs b/proxy/src/serverless/backend.rs
index 764b97fb7be5..8a8f38d18169 100644
--- a/proxy/src/serverless/backend.rs
+++ b/proxy/src/serverless/backend.rs
@@ -3,10 +3,12 @@ use std::{io, sync::Arc, time::Duration};
 use async_trait::async_trait;
 use hyper_util::rt::{TokioExecutor, TokioIo, TokioTimer};
 use tokio::net::{lookup_host, TcpStream};
-use tracing::{field::display, info};
+use tokio_postgres::types::ToSql;
+use tracing::{debug, field::display, info};
 
 use crate::{
     auth::{
+        self,
         backend::{local::StaticAuthRules, ComputeCredentials, ComputeUserInfo},
         check_peer_addr_is_in_list, AuthError,
     },
@@ -32,10 +34,12 @@ use crate::{
 use super::{
     conn_pool::{poll_client, Client, ConnInfo, GlobalConnPool},
     http_conn_pool::{self, poll_http2_client},
+    local_conn_pool::{self, LocalClient, LocalConnPool},
 };
 
 pub(crate) struct PoolingBackend {
     pub(crate) http_conn_pool: Arc<http_conn_pool::GlobalConnPool>,
+    pub(crate) local_pool: Arc<LocalConnPool<tokio_postgres::Client>>,
     pub(crate) pool: Arc<GlobalConnPool<tokio_postgres::Client>>,
     pub(crate) config: &'static ProxyConfig,
     pub(crate) endpoint_rate_limiter: Arc<EndpointRateLimiter>,
@@ -112,7 +116,7 @@ impl PoolingBackend {
         config: &AuthenticationConfig,
         user_info: &ComputeUserInfo,
         jwt: String,
-    ) -> Result<(), AuthError> {
+    ) -> Result<ComputeCredentials, AuthError> {
         match &self.config.auth_backend {
             crate::auth::Backend::ControlPlane(console, ()) => {
                 config
@@ -127,13 +131,16 @@ impl PoolingBackend {
                     .await
                     .map_err(|e| AuthError::auth_failed(e.to_string()))?;
 
-                Ok(())
+                Ok(ComputeCredentials {
+                    info: user_info.clone(),
+                    keys: crate::auth::backend::ComputeCredentialKeys::None,
+                })
             }
             crate::auth::Backend::ConsoleRedirect(_, ()) => Err(AuthError::auth_failed(
                 "JWT login over web auth proxy is not supported",
             )),
             crate::auth::Backend::Local(_) => {
-                config
+                let keys = config
                     .jwks_cache
                     .check_jwt(
                         ctx,
@@ -145,8 +152,10 @@ impl PoolingBackend {
                     .await
                     .map_err(|e| AuthError::auth_failed(e.to_string()))?;
 
-                // todo: rewrite JWT signature with key shared somehow between local proxy and postgres
-                Ok(())
+                Ok(ComputeCredentials {
+                    info: user_info.clone(),
+                    keys,
+                })
             }
         }
     }
@@ -231,6 +240,77 @@ impl PoolingBackend {
         )
         .await
     }
+
+    /// Connect to postgres over localhost.
+    ///
+    /// We expect postgres to be started here, so we won't do any retries.
+    ///
+    /// # Panics
+    ///
+    /// Panics if called with a non-local_proxy backend.
+    #[tracing::instrument(fields(pid = tracing::field::Empty), skip_all)]
+    pub(crate) async fn connect_to_local_postgres(
+        &self,
+        ctx: &RequestMonitoring,
+        conn_info: ConnInfo,
+    ) -> Result<LocalClient<tokio_postgres::Client>, HttpConnError> {
+        if let Some(client) = self.local_pool.get(ctx, &conn_info)? {
+            return Ok(client);
+        }
+
+        let conn_id = uuid::Uuid::new_v4();
+        tracing::Span::current().record("conn_id", display(conn_id));
+        info!(%conn_id, "local_pool: opening a new connection '{conn_info}'");
+
+        let mut node_info = match &self.config.auth_backend {
+            auth::Backend::ControlPlane(_, ()) | auth::Backend::ConsoleRedirect(_, ()) => {
+                unreachable!("only local_proxy can connect to local postgres")
+            }
+            auth::Backend::Local(local) => local.node_info.clone(),
+        };
+
+        let config = node_info
+            .config
+            .user(&conn_info.user_info.user)
+            .dbname(&conn_info.dbname);
+
+        let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
+        let (client, connection) = config.connect(tokio_postgres::NoTls).await?;
+        drop(pause);
+
+        tracing::Span::current().record("pid", tracing::field::display(client.get_process_id()));
+
+        let handle = local_conn_pool::poll_client(
+            self.local_pool.clone(),
+            ctx,
+            conn_info,
+            client,
+            connection,
+            conn_id,
+            node_info.aux.clone(),
+        );
+
+        let kid = handle.get_client().get_process_id() as i64;
+        let jwk = p256::PublicKey::from(handle.key().verifying_key()).to_jwk();
+
+        debug!(kid, ?jwk, "setting up backend session state");
+
+        // initiates the auth session
+        handle
+            .get_client()
+            .query(
+                "select auth.init($1, $2);",
+                &[
+                    &kid as &(dyn ToSql + Sync),
+                    &tokio_postgres::types::Json(jwk),
+                ],
+            )
+            .await?;
+
+        info!(?kid, "backend session state init");
+
+        Ok(handle)
+    }
 }
 
 #[derive(Debug, thiserror::Error)]
@@ -241,6 +321,8 @@ pub(crate) enum HttpConnError {
     PostgresConnectionError(#[from] tokio_postgres::Error),
     #[error("could not connection to local-proxy in compute")]
     LocalProxyConnectionError(#[from] LocalProxyConnError),
+    #[error("could not parse JWT payload")]
+    JwtPayloadError(serde_json::Error),
 
     #[error("could not get auth info")]
     GetAuthInfo(#[from] GetAuthInfoError),
@@ -266,6 +348,7 @@ impl ReportableError for HttpConnError {
             HttpConnError::ConnectionClosedAbruptly(_) => ErrorKind::Compute,
             HttpConnError::PostgresConnectionError(p) => p.get_error_kind(),
             HttpConnError::LocalProxyConnectionError(_) => ErrorKind::Compute,
+            HttpConnError::JwtPayloadError(_) => ErrorKind::User,
             HttpConnError::GetAuthInfo(a) => a.get_error_kind(),
             HttpConnError::AuthError(a) => a.get_error_kind(),
             HttpConnError::WakeCompute(w) => w.get_error_kind(),
@@ -280,6 +363,7 @@ impl UserFacingError for HttpConnError {
             HttpConnError::ConnectionClosedAbruptly(_) => self.to_string(),
             HttpConnError::PostgresConnectionError(p) => p.to_string(),
             HttpConnError::LocalProxyConnectionError(p) => p.to_string(),
+            HttpConnError::JwtPayloadError(p) => p.to_string(),
             HttpConnError::GetAuthInfo(c) => c.to_string_client(),
             HttpConnError::AuthError(c) => c.to_string_client(),
             HttpConnError::WakeCompute(c) => c.to_string_client(),
@@ -296,6 +380,7 @@ impl CouldRetry for HttpConnError {
             HttpConnError::PostgresConnectionError(e) => e.could_retry(),
             HttpConnError::LocalProxyConnectionError(e) => e.could_retry(),
             HttpConnError::ConnectionClosedAbruptly(_) => false,
+            HttpConnError::JwtPayloadError(_) => false,
             HttpConnError::GetAuthInfo(_) => false,
             HttpConnError::AuthError(_) => false,
             HttpConnError::WakeCompute(_) => false,
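Note: each new local connection above gets its own ephemeral P-256 signing key, and `connect_to_local_postgres` registers the public half with the pg_session_jwt extension through `select auth.init($1, $2)`, using the backend process id as the key id. A hedged sketch of deriving those two arguments, assuming the p256 crate with the "jwk" feature enabled in proxy/Cargo.toml above (`session_init_args` is a hypothetical helper):

    use p256::ecdsa::SigningKey;

    // Sketch: the two arguments passed to `select auth.init($1, $2)`.
    // kid is the backend process id; the JWK is the public half of the
    // per-connection P-256 key.
    fn session_init_args(key: &SigningKey, backend_pid: i32) -> (i64, String) {
        let jwk = p256::PublicKey::from(key.verifying_key()).to_jwk_string();
        (i64::from(backend_pid), jwk)
    }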
diff --git a/proxy/src/serverless/local_conn_pool.rs b/proxy/src/serverless/local_conn_pool.rs
new file mode 100644
index 000000000000..1dde5952e103
--- /dev/null
+++ b/proxy/src/serverless/local_conn_pool.rs
@@ -0,0 +1,544 @@
+use futures::{future::poll_fn, Future};
+use jose_jwk::jose_b64::base64ct::{Base64UrlUnpadded, Encoding};
+use p256::ecdsa::{Signature, SigningKey};
+use parking_lot::RwLock;
+use rand::rngs::OsRng;
+use serde_json::Value;
+use signature::Signer;
+use std::task::{ready, Poll};
+use std::{collections::HashMap, pin::pin, sync::Arc, sync::Weak, time::Duration};
+use tokio::time::Instant;
+use tokio_postgres::tls::NoTlsStream;
+use tokio_postgres::types::ToSql;
+use tokio_postgres::{AsyncMessage, ReadyForQueryStatus, Socket};
+use tokio_util::sync::CancellationToken;
+use typed_json::json;
+
+use crate::control_plane::messages::{ColdStartInfo, MetricsAuxInfo};
+use crate::metrics::Metrics;
+use crate::usage_metrics::{Ids, MetricCounter, USAGE_METRICS};
+use crate::{context::RequestMonitoring, DbName, RoleName};
+
+use tracing::{debug, error, warn, Span};
+use tracing::{info, info_span, Instrument};
+
+use super::backend::HttpConnError;
+use super::conn_pool::{ClientInnerExt, ConnInfo};
+
+struct ConnPoolEntry<C: ClientInnerExt> {
+    conn: ClientInner<C>,
+    _last_access: std::time::Instant,
+}
+
+// /// key id for the pg_session_jwt state
+// static PG_SESSION_JWT_KID: AtomicU64 = AtomicU64::new(1);
+
+// Per-endpoint connection pool, (dbname, username) -> DbUserConnPool
+// Number of open connections is limited by the `max_conns_per_endpoint`.
+pub(crate) struct EndpointConnPool<C: ClientInnerExt> {
+    pools: HashMap<(DbName, RoleName), DbUserConnPool<C>>,
+    total_conns: usize,
+    max_conns: usize,
+    global_pool_size_max_conns: usize,
+}
+
+impl<C: ClientInnerExt> EndpointConnPool<C> {
+    fn get_conn_entry(&mut self, db_user: (DbName, RoleName)) -> Option<ConnPoolEntry<C>> {
+        let Self {
+            pools, total_conns, ..
+        } = self;
+        pools
+            .get_mut(&db_user)
+            .and_then(|pool_entries| pool_entries.get_conn_entry(total_conns))
+    }
+
+    fn remove_client(&mut self, db_user: (DbName, RoleName), conn_id: uuid::Uuid) -> bool {
+        let Self {
+            pools, total_conns, ..
+        } = self;
+        if let Some(pool) = pools.get_mut(&db_user) {
+            let old_len = pool.conns.len();
+            pool.conns.retain(|conn| conn.conn.conn_id != conn_id);
+            let new_len = pool.conns.len();
+            let removed = old_len - new_len;
+            if removed > 0 {
+                Metrics::get()
+                    .proxy
+                    .http_pool_opened_connections
+                    .get_metric()
+                    .dec_by(removed as i64);
+            }
+            *total_conns -= removed;
+            removed > 0
+        } else {
+            false
+        }
+    }
+
+    fn put(pool: &RwLock<Self>, conn_info: &ConnInfo, client: ClientInner<C>) {
+        let conn_id = client.conn_id;
+
+        if client.is_closed() {
+            info!(%conn_id, "local_pool: throwing away connection '{conn_info}' because connection is closed");
+            return;
+        }
+        let global_max_conn = pool.read().global_pool_size_max_conns;
+        if pool.read().total_conns >= global_max_conn {
+            info!(%conn_id, "local_pool: throwing away connection '{conn_info}' because pool is full");
+            return;
+        }
+
+        // return connection to the pool
+        let mut returned = false;
+        let mut per_db_size = 0;
+        let total_conns = {
+            let mut pool = pool.write();
+
+            if pool.total_conns < pool.max_conns {
+                let pool_entries = pool.pools.entry(conn_info.db_and_user()).or_default();
+                pool_entries.conns.push(ConnPoolEntry {
+                    conn: client,
+                    _last_access: std::time::Instant::now(),
+                });
+
+                returned = true;
+                per_db_size = pool_entries.conns.len();
+
+                pool.total_conns += 1;
+                Metrics::get()
+                    .proxy
+                    .http_pool_opened_connections
+                    .get_metric()
+                    .inc();
+            }
+
+            pool.total_conns
+        };
+
+        // do logging outside of the mutex
+        if returned {
+            info!(%conn_id, "local_pool: returning connection '{conn_info}' back to the pool, total_conns={total_conns}, for this (db, user)={per_db_size}");
+        } else {
+            info!(%conn_id, "local_pool: throwing away connection '{conn_info}' because pool is full, total_conns={total_conns}");
+        }
+    }
+}
+
+impl<C: ClientInnerExt> Drop for EndpointConnPool<C> {
+    fn drop(&mut self) {
+        if self.total_conns > 0 {
+            Metrics::get()
+                .proxy
+                .http_pool_opened_connections
+                .get_metric()
+                .dec_by(self.total_conns as i64);
+        }
+    }
+}
+
+pub(crate) struct DbUserConnPool<C: ClientInnerExt> {
+    conns: Vec<ConnPoolEntry<C>>,
+}
+
+impl<C: ClientInnerExt> Default for DbUserConnPool<C> {
+    fn default() -> Self {
+        Self { conns: Vec::new() }
+    }
+}
+
+impl<C: ClientInnerExt> DbUserConnPool<C> {
+    fn clear_closed_clients(&mut self, conns: &mut usize) -> usize {
+        let old_len = self.conns.len();
+
+        self.conns.retain(|conn| !conn.conn.is_closed());
+
+        let new_len = self.conns.len();
+        let removed = old_len - new_len;
+        *conns -= removed;
+        removed
+    }
+
+    fn get_conn_entry(&mut self, conns: &mut usize) -> Option<ConnPoolEntry<C>> {
+        let mut removed = self.clear_closed_clients(conns);
+        let conn = self.conns.pop();
+        if conn.is_some() {
+            *conns -= 1;
+            removed += 1;
+        }
+        Metrics::get()
+            .proxy
+            .http_pool_opened_connections
+            .get_metric()
+            .dec_by(removed as i64);
+        conn
+    }
+}
+
+pub(crate) struct LocalConnPool<C: ClientInnerExt> {
+    global_pool: RwLock<EndpointConnPool<C>>,
+
+    config: &'static crate::config::HttpConfig,
+}
+
+impl<C: ClientInnerExt> LocalConnPool<C> {
+    pub(crate) fn new(config: &'static crate::config::HttpConfig) -> Arc<Self> {
+        Arc::new(Self {
+            global_pool: RwLock::new(EndpointConnPool {
+                pools: HashMap::new(),
+                total_conns: 0,
+                max_conns: config.pool_options.max_conns_per_endpoint,
+                global_pool_size_max_conns: config.pool_options.max_total_conns,
+            }),
+            config,
+        })
+    }
+
+    pub(crate) fn get_idle_timeout(&self) -> Duration {
+        self.config.pool_options.idle_timeout
+    }
+
+    // pub(crate) fn shutdown(&self) {
+    //     let mut pool = self.global_pool.write();
+    //     pool.pools.clear();
+    //     pool.total_conns = 0;
+    // }
+
+    pub(crate) fn get(
+        self: &Arc<Self>,
+        ctx: &RequestMonitoring,
+        conn_info: &ConnInfo,
+    ) -> Result<Option<LocalClient<C>>, HttpConnError> {
+        let mut client: Option<ClientInner<C>> = None;
+        if let Some(entry) = self
+            .global_pool
+            .write()
+            .get_conn_entry(conn_info.db_and_user())
+        {
+            client = Some(entry.conn);
+        }
+
+        // ok return cached connection if found and establish a new one otherwise
+        if let Some(client) = client {
+            if client.is_closed() {
+                info!("local_pool: cached connection '{conn_info}' is closed, opening a new one");
+                return Ok(None);
+            }
+            tracing::Span::current().record("conn_id", tracing::field::display(client.conn_id));
+            tracing::Span::current().record(
+                "pid",
+                tracing::field::display(client.inner.get_process_id()),
+            );
+            info!(
+                cold_start_info = ColdStartInfo::HttpPoolHit.as_str(),
+                "local_pool: reusing connection '{conn_info}'"
+            );
+            client.session.send(ctx.session_id())?;
+            ctx.set_cold_start_info(ColdStartInfo::HttpPoolHit);
+            ctx.success();
+            return Ok(Some(LocalClient::new(
+                client,
+                conn_info.clone(),
+                Arc::downgrade(self),
+            )));
+        }
+        Ok(None)
+    }
+}
+
+pub(crate) fn poll_client(
+    global_pool: Arc<LocalConnPool<tokio_postgres::Client>>,
+    ctx: &RequestMonitoring,
+    conn_info: ConnInfo,
+    client: tokio_postgres::Client,
+    mut connection: tokio_postgres::Connection<Socket, NoTlsStream>,
+    conn_id: uuid::Uuid,
+    aux: MetricsAuxInfo,
+) -> LocalClient<tokio_postgres::Client> {
+    let conn_gauge = Metrics::get().proxy.db_connections.guard(ctx.protocol());
+    let mut session_id = ctx.session_id();
+    let (tx, mut rx) = tokio::sync::watch::channel(session_id);
+
+    let span = info_span!(parent: None, "connection", %conn_id);
+    let cold_start_info = ctx.cold_start_info();
+    span.in_scope(|| {
+        info!(cold_start_info = cold_start_info.as_str(), %conn_info, %session_id, "new connection");
+    });
+    let pool = Arc::downgrade(&global_pool);
+    let pool_clone = pool.clone();
+
+    let db_user = conn_info.db_and_user();
+    let idle = global_pool.get_idle_timeout();
+    let cancel = CancellationToken::new();
+    let cancelled = cancel.clone().cancelled_owned();
+
+    tokio::spawn(
+        async move {
+            let _conn_gauge = conn_gauge;
+            let mut idle_timeout = pin!(tokio::time::sleep(idle));
+            let mut cancelled = pin!(cancelled);
+
+            poll_fn(move |cx| {
+                if cancelled.as_mut().poll(cx).is_ready() {
+                    info!("connection dropped");
+                    return Poll::Ready(())
+                }
+
+                match rx.has_changed() {
+                    Ok(true) => {
+                        session_id = *rx.borrow_and_update();
+                        info!(%session_id, "changed session");
+                        idle_timeout.as_mut().reset(Instant::now() + idle);
+                    }
+                    Err(_) => {
+                        info!("connection dropped");
+                        return Poll::Ready(())
+                    }
+                    _ => {}
+                }
+
+                // 5 minute idle connection timeout
+                if idle_timeout.as_mut().poll(cx).is_ready() {
+                    idle_timeout.as_mut().reset(Instant::now() + idle);
+                    info!("connection idle");
+                    if let Some(pool) = pool.clone().upgrade() {
+                        // remove client from pool - should close the connection if it's idle.
+                        // does nothing if the client is currently checked-out and in-use
+                        if pool.global_pool.write().remove_client(db_user.clone(), conn_id) {
+                            info!("idle connection removed");
+                        }
+                    }
+                }
+
+                loop {
+                    let message = ready!(connection.poll_message(cx));
+
+                    match message {
+                        Some(Ok(AsyncMessage::Notice(notice))) => {
+                            info!(%session_id, "notice: {}", notice);
+                        }
+                        Some(Ok(AsyncMessage::Notification(notif))) => {
+                            warn!(%session_id, pid = notif.process_id(), channel = notif.channel(), "notification received");
+                        }
+                        Some(Ok(_)) => {
+                            warn!(%session_id, "unknown message");
+                        }
+                        Some(Err(e)) => {
+                            error!(%session_id, "connection error: {}", e);
+                            break
+                        }
+                        None => {
+                            info!("connection closed");
+                            break
+                        }
+                    }
+                }
+
+                // remove from connection pool
+                if let Some(pool) = pool.clone().upgrade() {
+                    if pool.global_pool.write().remove_client(db_user.clone(), conn_id) {
+                        info!("closed connection removed");
+                    }
+                }
+
+                Poll::Ready(())
+            }).await;
+
+        }
+        .instrument(span));
+
+    let key = SigningKey::random(&mut OsRng);
+
+    let inner = ClientInner {
+        inner: client,
+        session: tx,
+        cancel,
+        aux,
+        conn_id,
+        key,
+        jti: 0,
+    };
+    LocalClient::new(inner, conn_info, pool_clone)
+}
+
+struct ClientInner<C: ClientInnerExt> {
+    inner: C,
+    session: tokio::sync::watch::Sender<uuid::Uuid>,
+    cancel: CancellationToken,
+    aux: MetricsAuxInfo,
+    conn_id: uuid::Uuid,
+
+    // needed for pg_session_jwt state
+    key: SigningKey,
+    jti: u64,
+}
+
+impl<C: ClientInnerExt> Drop for ClientInner<C> {
+    fn drop(&mut self) {
+        // on client drop, tell the conn to shut down
+        self.cancel.cancel();
+    }
+}
+
+impl<C: ClientInnerExt> ClientInner<C> {
+    pub(crate) fn is_closed(&self) -> bool {
+        self.inner.is_closed()
+    }
+}
+
+impl<C: ClientInnerExt> LocalClient<C> {
+    pub(crate) fn metrics(&self) -> Arc<MetricCounter> {
+        let aux = &self.inner.as_ref().unwrap().aux;
+        USAGE_METRICS.register(Ids {
+            endpoint_id: aux.endpoint_id,
+            branch_id: aux.branch_id,
+        })
+    }
+}
+
+pub(crate) struct LocalClient<C: ClientInnerExt> {
+    span: Span,
+    inner: Option<ClientInner<C>>,
+    conn_info: ConnInfo,
+    pool: Weak<LocalConnPool<C>>,
+}
+
+pub(crate) struct Discard<'a, C: ClientInnerExt> {
+    conn_info: &'a ConnInfo,
+    pool: &'a mut Weak<LocalConnPool<C>>,
+}
+
+impl<C: ClientInnerExt> LocalClient<C> {
+    pub(self) fn new(
+        inner: ClientInner<C>,
+        conn_info: ConnInfo,
+        pool: Weak<LocalConnPool<C>>,
+    ) -> Self {
+        Self {
+            inner: Some(inner),
+            span: Span::current(),
+            conn_info,
+            pool,
+        }
+    }
+    pub(crate) fn inner(&mut self) -> (&mut C, Discard<'_, C>) {
+        let Self {
+            inner,
+            pool,
+            conn_info,
+            span: _,
+        } = self;
+        let inner = inner.as_mut().expect("client inner should not be removed");
+        (&mut inner.inner, Discard { conn_info, pool })
+    }
+    pub(crate) fn key(&self) -> &SigningKey {
+        let inner = &self
+            .inner
+            .as_ref()
+            .expect("client inner should not be removed");
+        &inner.key
+    }
+}
+
+impl LocalClient<tokio_postgres::Client> {
+    pub(crate) async fn set_jwt_session(&mut self, payload: &[u8]) -> Result<(), HttpConnError> {
+        let inner = self
+            .inner
+            .as_mut()
+            .expect("client inner should not be removed");
+        inner.jti += 1;
+
+        let kid = inner.inner.get_process_id();
+        let header = json!({"kid":kid}).to_string();
+
+        let mut payload = serde_json::from_slice::<serde_json::Map<String, Value>>(payload)
+            .map_err(HttpConnError::JwtPayloadError)?;
+        payload.insert("jti".to_string(), Value::Number(inner.jti.into()));
+        let payload = Value::Object(payload).to_string();
+
+        debug!(
+            kid,
+            jti = inner.jti,
+            ?header,
+            ?payload,
+            "signing new ephemeral JWT"
+        );
+
+        let token = sign_jwt(&inner.key, header, payload);
+
+        // initiates the auth session
+        inner.inner.simple_query("discard all").await?;
+        inner
+            .inner
+            .query(
+                "select auth.jwt_session_init($1)",
+                &[&token as &(dyn ToSql + Sync)],
+            )
+            .await?;
+
+        info!(kid, jti = inner.jti, "user session state init");
+
+        Ok(())
+    }
+}
+
+fn sign_jwt(sk: &SigningKey, header: String, payload: String) -> String {
+    let header = Base64UrlUnpadded::encode_string(header.as_bytes());
+    let payload = Base64UrlUnpadded::encode_string(payload.as_bytes());
+
+    let message = format!("{header}.{payload}");
+    let sig: Signature = sk.sign(message.as_bytes());
+    let base64_sig = Base64UrlUnpadded::encode_string(&sig.to_bytes());
+    format!("{message}.{base64_sig}")
+}
+
+impl<C: ClientInnerExt> Discard<'_, C> {
+    pub(crate) fn check_idle(&mut self, status: ReadyForQueryStatus) {
+        let conn_info = &self.conn_info;
+        if status != ReadyForQueryStatus::Idle && std::mem::take(self.pool).strong_count() > 0 {
+            info!(
+                "local_pool: throwing away connection '{conn_info}' because connection is not idle"
+            );
+        }
+    }
+    pub(crate) fn discard(&mut self) {
+        let conn_info = &self.conn_info;
+        if std::mem::take(self.pool).strong_count() > 0 {
+            info!("local_pool: throwing away connection '{conn_info}' because connection is potentially in a broken state");
+        }
+    }
+}
+
+impl<C: ClientInnerExt> LocalClient<C> {
+    pub fn get_client(&self) -> &C {
+        &self
+            .inner
+            .as_ref()
+            .expect("client inner should not be removed")
+            .inner
+    }
+
+    fn do_drop(&mut self) -> Option<impl FnOnce()> {
+        let conn_info = self.conn_info.clone();
+        let client = self
+            .inner
+            .take()
+            .expect("client inner should not be removed");
+        if let Some(conn_pool) = std::mem::take(&mut self.pool).upgrade() {
+            let current_span = self.span.clone();
+            // return connection to the pool
+            return Some(move || {
+                let _span = current_span.enter();
+                EndpointConnPool::put(&conn_pool.global_pool, &conn_info, client);
+            });
+        }
+        None
+    }
+}
+
+impl<C: ClientInnerExt> Drop for LocalClient<C> {
+    fn drop(&mut self) {
+        if let Some(drop) = self.do_drop() {
+            tokio::task::spawn_blocking(drop);
+        }
+    }
+}
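Note: `sign_jwt` above emits an ES256 token in JWT compact form, signed with the per-connection key, so only this local_proxy instance can mint tokens the extension will accept for the session. A sketch of the matching verification side, assuming the same p256 and base64ct APIs used by the file above (`verify_es256` is a hypothetical test-side helper):

    use jose_jwk::jose_b64::base64ct::{Base64UrlUnpadded, Encoding};
    use p256::ecdsa::{signature::Verifier, Signature, VerifyingKey};

    // Sketch: checking a token produced by a sign_jwt-style helper.
    // ES256 signs the raw "b64(header).b64(payload)" bytes; the final
    // dot-separated segment is the signature.
    fn verify_es256(vk: &VerifyingKey, token: &str) -> bool {
        let Some((message, sig_b64)) = token.rsplit_once('.') else {
            return false;
        };
        let Ok(sig_bytes) = Base64UrlUnpadded::decode_vec(sig_b64) else {
            return false;
        };
        let Ok(sig) = Signature::from_slice(&sig_bytes) else {
            return false;
        };
        vk.verify(message.as_bytes(), &sig).is_ok()
    }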
diff --git a/proxy/src/serverless/mod.rs b/proxy/src/serverless/mod.rs
index 5987776c2824..9be6b592bd65 100644
--- a/proxy/src/serverless/mod.rs
+++ b/proxy/src/serverless/mod.rs
@@ -8,6 +8,7 @@ mod conn_pool;
 mod http_conn_pool;
 mod http_util;
 mod json;
+mod local_conn_pool;
 mod sql_over_http;
 mod websocket;
 
@@ -63,6 +64,7 @@ pub async fn task_main(
         info!("websocket server has shut down");
     }
 
+    let local_pool = local_conn_pool::LocalConnPool::new(&config.http_config);
     let conn_pool = conn_pool::GlobalConnPool::new(&config.http_config);
     {
         let conn_pool = Arc::clone(&conn_pool);
@@ -105,6 +107,7 @@ pub async fn task_main(
 
     let backend = Arc::new(PoolingBackend {
         http_conn_pool: Arc::clone(&http_conn_pool),
+        local_pool,
        pool: Arc::clone(&conn_pool),
         config,
         endpoint_rate_limiter: Arc::clone(&endpoint_rate_limiter),
diff --git a/proxy/src/serverless/sql_over_http.rs b/proxy/src/serverless/sql_over_http.rs
index 34c19157e638..f7c3b26917d7 100644
--- a/proxy/src/serverless/sql_over_http.rs
+++ b/proxy/src/serverless/sql_over_http.rs
@@ -40,7 +40,7 @@ use url::Url;
 use urlencoding;
 use utils::http::error::ApiError;
 
-use crate::auth::backend::ComputeCredentials;
+use crate::auth::backend::ComputeCredentialKeys;
 use crate::auth::backend::ComputeUserInfo;
 use crate::auth::endpoint_sni;
 use crate::auth::ComputeUserInfoParseError;
@@ -56,20 +56,22 @@ use crate::metrics::Metrics;
 use crate::proxy::run_until_cancelled;
 use crate::proxy::NeonOptions;
 use crate::serverless::backend::HttpConnError;
+use crate::usage_metrics::MetricCounter;
 use crate::usage_metrics::MetricCounterRecorder;
 use crate::DbName;
 use crate::RoleName;
 
 use super::backend::LocalProxyConnError;
 use super::backend::PoolingBackend;
+use super::conn_pool;
 use super::conn_pool::AuthData;
-use super::conn_pool::Client;
 use super::conn_pool::ConnInfo;
 use super::conn_pool::ConnInfoWithAuth;
 use super::http_util::json_response;
 use super::json::json_to_pg_text;
 use super::json::pg_text_row_to_json;
 use super::json::JsonConversionError;
+use super::local_conn_pool;
 
 #[derive(serde::Deserialize)]
 #[serde(rename_all = "camelCase")]
@@ -620,6 +622,9 @@ async fn handle_db_inner(
 
     let authenticate_and_connect = Box::pin(
         async {
+            let is_local_proxy =
+                matches!(backend.config.auth_backend, crate::auth::Backend::Local(_));
+
             let keys = match auth {
                 AuthData::Password(pw) => {
                     backend
@@ -639,18 +644,24 @@ async fn handle_db_inner(
                         &conn_info.user_info,
                         jwt,
                     )
-                    .await?;
+                    .await?
+                }
+            };
 
-                    ComputeCredentials {
-                        info: conn_info.user_info.clone(),
-                        keys: crate::auth::backend::ComputeCredentialKeys::None,
-                    }
+            let client = match keys.keys {
+                ComputeCredentialKeys::JwtPayload(payload) if is_local_proxy => {
+                    let mut client = backend.connect_to_local_postgres(ctx, conn_info).await?;
+                    client.set_jwt_session(&payload).await?;
+                    Client::Local(client)
+                }
+                _ => {
+                    let client = backend
+                        .connect_to_compute(ctx, conn_info, keys, !allow_pool)
+                        .await?;
+                    Client::Remote(client)
                 }
             };
 
-            let client = backend
-                .connect_to_compute(ctx, conn_info, keys, !allow_pool)
-                .await?;
             // not strictly necessary to mark success here,
             // but it's just insurance for if we forget it somewhere else
             ctx.success();
@@ -791,7 +802,7 @@ impl QueryData {
         self,
         config: &'static ProxyConfig,
         cancel: CancellationToken,
-        client: &mut Client<tokio_postgres::Client>,
+        client: &mut Client,
         parsed_headers: HttpHeaders,
     ) -> Result<String, HttpConnError> {
         let (inner, mut discard) = client.inner();
@@ -865,7 +876,7 @@ impl BatchQueryData {
         self,
         config: &'static ProxyConfig,
         cancel: CancellationToken,
-        client: &mut Client<tokio_postgres::Client>,
+        client: &mut Client,
         parsed_headers: HttpHeaders,
     ) -> Result<String, HttpConnError> {
         info!("starting transaction");
@@ -1058,3 +1069,50 @@ async fn query_to_json(
 
     Ok((ready, results))
 }
+
+enum Client {
+    Remote(conn_pool::Client<tokio_postgres::Client>),
+    Local(local_conn_pool::LocalClient<tokio_postgres::Client>),
+}
+
+enum Discard<'a> {
+    Remote(conn_pool::Discard<'a, tokio_postgres::Client>),
+    Local(local_conn_pool::Discard<'a, tokio_postgres::Client>),
+}
+
+impl Client {
+    fn metrics(&self) -> Arc<MetricCounter> {
+        match self {
+            Client::Remote(client) => client.metrics(),
+            Client::Local(local_client) => local_client.metrics(),
+        }
+    }
+
+    fn inner(&mut self) -> (&mut tokio_postgres::Client, Discard<'_>) {
+        match self {
+            Client::Remote(client) => {
+                let (c, d) = client.inner();
+                (c, Discard::Remote(d))
+            }
+            Client::Local(local_client) => {
+                let (c, d) = local_client.inner();
+                (c, Discard::Local(d))
+            }
+        }
+    }
+}
+
+impl Discard<'_> {
+    fn check_idle(&mut self, status: ReadyForQueryStatus) {
+        match self {
+            Discard::Remote(discard) => discard.check_idle(status),
+            Discard::Local(discard) => discard.check_idle(status),
+        }
+    }
+    fn discard(&mut self) {
+        match self {
+            Discard::Remote(discard) => discard.discard(),
+            Discard::Local(discard) => discard.discard(),
+        }
+    }
+}
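Note: on the local path above, `handle_db_inner` forwards the verified JWT payload to `set_jwt_session`, which stamps an incrementing `jti` claim into the claims object before re-signing it for postgres (see local_conn_pool.rs above). That stamping step in isolation, as a sketch (`stamp_jti` is a hypothetical helper):

    use serde_json::{Map, Value};

    // Sketch: inserting a monotonic `jti` into the verified JWT payload
    // before it is re-signed for postgres, mirroring set_jwt_session.
    fn stamp_jti(payload: &[u8], jti: u64) -> serde_json::Result<String> {
        let mut claims: Map<String, Value> = serde_json::from_slice(payload)?;
        claims.insert("jti".to_string(), Value::Number(jti.into()));
        Ok(Value::Object(claims).to_string())
    }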
"https://github.com/neondatabase/rust-postgres.git", rev = "20031d7a9ee1addeae6e0968e3899ae6bf01cee2", default-features = false, features = ["with-serde_json-1"] } prost = { version = "0.13", features = ["prost-derive"] } rand = { version = "0.8", features = ["small_rng"] } regex = { version = "1" } @@ -66,7 +67,7 @@ regex-syntax = { version = "0.8" } reqwest = { version = "0.12", default-features = false, features = ["blocking", "json", "rustls-tls", "stream"] } scopeguard = { version = "1" } serde = { version = "1", features = ["alloc", "derive"] } -serde_json = { version = "1", features = ["raw_value"] } +serde_json = { version = "1", features = ["alloc", "raw_value"] } sha2 = { version = "0.10", features = ["asm", "oid"] } signature = { version = "2", default-features = false, features = ["digest", "rand_core", "std"] } smallvec = { version = "1", default-features = false, features = ["const_new", "write"] } @@ -76,6 +77,7 @@ sync_wrapper = { version = "0.1", default-features = false, features = ["futures tikv-jemalloc-sys = { version = "0.5" } time = { version = "0.3", features = ["macros", "serde-well-known"] } tokio = { version = "1", features = ["fs", "io-std", "io-util", "macros", "net", "process", "rt-multi-thread", "signal", "test-util"] } +tokio-postgres = { git = "https://github.com/neondatabase/rust-postgres.git", rev = "20031d7a9ee1addeae6e0968e3899ae6bf01cee2", features = ["with-serde_json-1"] } tokio-stream = { version = "0.1", features = ["net"] } tokio-util = { version = "0.7", features = ["codec", "compat", "io", "rt"] } toml_edit = { version = "0.22", features = ["serde"] }