From 82e3f0ecba8542cd5d1a95bf6d938aacbc073905 Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Thu, 7 Nov 2024 16:24:38 +0000 Subject: [PATCH] [proxy/authorize]: improve JWKS reliability (#9676) While setting up some tests, I noticed that we didn't support Keycloak. It makes use of encryption JWKs as well as signature ones. Our current JWKS crate does not support parsing encryption keys, which caused the entire JWK set to fail to parse. Switching to lazy parsing fixes this. Also while setting up tests, I couldn't use a localhost JWKS server: we require HTTPS and were using webpki roots, so it was impossible to add a custom CA. Enabling native roots makes this possible. I saw that some of our current e2e tests against our custom JWKS in S3 were taking a while to fetch. I've added a timeout + retries to address this. --- Cargo.lock | 1 + proxy/Cargo.toml | 2 +- proxy/src/auth/backend/jwt.rs | 162 +++++++++++++++++++++++-- proxy/src/http/mod.rs | 22 ++-- proxy/src/serverless/conn_pool_lib.rs | 3 +- proxy/src/serverless/http_conn_pool.rs | 1 - workspace_hack/Cargo.toml | 2 +- 7 files changed, 168 insertions(+), 25 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 7d18f44aec6e..00d58be2d5d9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4743,6 +4743,7 @@ dependencies = [ "percent-encoding", "pin-project-lite", "rustls 0.22.4", + "rustls-native-certs 0.7.0", "rustls-pemfile 2.1.1", "rustls-pki-types", "serde", diff --git a/proxy/Cargo.toml b/proxy/Cargo.toml index efd336dbea2e..1665d6361a1d 100644 --- a/proxy/Cargo.toml +++ b/proxy/Cargo.toml @@ -60,7 +60,7 @@ prometheus.workspace = true rand.workspace = true regex.workspace = true remote_storage = { version = "0.1", path = "../libs/remote_storage/" } -reqwest.workspace = true +reqwest = { workspace = true, features = ["rustls-tls-native-roots"] } reqwest-middleware = { workspace = true, features = ["json"] } reqwest-retry.workspace = true reqwest-tracing.workspace = true diff --git a/proxy/src/auth/backend/jwt.rs
b/proxy/src/auth/backend/jwt.rs index 83c3617612fa..bfc674139bf2 100644 --- a/proxy/src/auth/backend/jwt.rs +++ b/proxy/src/auth/backend/jwt.rs @@ -7,8 +7,11 @@ use arc_swap::ArcSwapOption; use dashmap::DashMap; use jose_jwk::crypto::KeyInfo; use reqwest::{redirect, Client}; +use reqwest_retry::policies::ExponentialBackoff; +use reqwest_retry::RetryTransientMiddleware; use serde::de::Visitor; use serde::{Deserialize, Deserializer}; +use serde_json::value::RawValue; use signature::Verifier; use thiserror::Error; use tokio::time::Instant; @@ -16,7 +19,7 @@ use tokio::time::Instant; use crate::auth::backend::ComputeCredentialKeys; use crate::context::RequestMonitoring; use crate::control_plane::errors::GetEndpointJwksError; -use crate::http::parse_json_body_with_limit; +use crate::http::read_body_with_limit; use crate::intern::RoleNameInt; use crate::types::{EndpointId, RoleName}; @@ -28,6 +31,10 @@ const MAX_RENEW: Duration = Duration::from_secs(3600); const MAX_JWK_BODY_SIZE: usize = 64 * 1024; const JWKS_USER_AGENT: &str = "neon-proxy"; +const JWKS_CONNECT_TIMEOUT: Duration = Duration::from_secs(2); +const JWKS_FETCH_TIMEOUT: Duration = Duration::from_secs(5); +const JWKS_FETCH_RETRIES: u32 = 3; + /// How to get the JWT auth rules pub(crate) trait FetchAuthRules: Clone + Send + Sync + 'static { fn fetch_auth_rules( @@ -55,7 +62,7 @@ pub(crate) struct AuthRule { } pub struct JwkCache { - client: reqwest::Client, + client: reqwest_middleware::ClientWithMiddleware, map: DashMap<(EndpointId, RoleName), Arc>, } @@ -117,6 +124,14 @@ impl Default for JwkCacheEntryLock { } } +#[derive(Deserialize)] +struct JwkSet<'a> { + /// we parse into raw-value because not all keys in a JWKS are ones + /// we can parse directly, so we parse them lazily. 
+ #[serde(borrow)] + keys: Vec<&'a RawValue>, +} + impl JwkCacheEntryLock { async fn acquire_permit<'a>(self: &'a Arc) -> JwkRenewalPermit<'a> { JwkRenewalPermit::acquire_permit(self).await @@ -130,7 +145,7 @@ impl JwkCacheEntryLock { &self, _permit: JwkRenewalPermit<'_>, ctx: &RequestMonitoring, - client: &reqwest::Client, + client: &reqwest_middleware::ClientWithMiddleware, endpoint: EndpointId, auth_rules: &F, ) -> Result, JwtError> { @@ -154,22 +169,73 @@ impl JwkCacheEntryLock { let req = client.get(rule.jwks_url.clone()); // TODO(conrad): eventually switch to using reqwest_middleware/`new_client_with_timeout`. // TODO(conrad): We need to filter out URLs that point to local resources. Public internet only. - match req.send().await.and_then(|r| r.error_for_status()) { + match req.send().await.and_then(|r| { + r.error_for_status() + .map_err(reqwest_middleware::Error::Reqwest) + }) { // todo: should we re-insert JWKs if we want to keep this JWKs URL? // I expect these failures would be quite sparse. Err(e) => tracing::warn!(url=?rule.jwks_url, error=?e, "could not fetch JWKs"), Ok(r) => { let resp: http::Response = r.into(); - match parse_json_body_with_limit::( - resp.into_body(), - MAX_JWK_BODY_SIZE, - ) - .await + + let bytes = match read_body_with_limit(resp.into_body(), MAX_JWK_BODY_SIZE) + .await { + Ok(bytes) => bytes, + Err(e) => { + tracing::warn!(url=?rule.jwks_url, error=?e, "could not decode JWKs"); + continue; + } + }; + + match serde_json::from_slice::(&bytes) { Err(e) => { tracing::warn!(url=?rule.jwks_url, error=?e, "could not decode JWKs"); } Ok(jwks) => { + // size_of::<&RawValue>() == 16 + // size_of::() == 288 + // better to not pre-allocate this as it might be pretty large - especially if it has many + // keys we don't want or need. + // trivial 'attack': `{"keys":[` + repeat(`0`).take(30000).join(`,`) + `]}` + // this would consume 8MiB just like that! 
+ let mut keys = vec![]; + let mut failed = 0; + for key in jwks.keys { + match serde_json::from_str::(key.get()) { + Ok(key) => { + // if `use` (called `cls` in rust) is specified to be something other than signing, + // we can skip storing it. + if key + .prm + .cls + .as_ref() + .is_some_and(|c| *c != jose_jwk::Class::Signing) + { + continue; + } + + keys.push(key); + } + Err(e) => { + tracing::debug!(url=?rule.jwks_url, failed=?e, "could not decode JWK"); + failed += 1; + } + } + } + keys.shrink_to_fit(); + + if failed > 0 { + tracing::warn!(url=?rule.jwks_url, failed, "could not decode JWKs"); + } + + if keys.is_empty() { + tracing::warn!(url=?rule.jwks_url, "no valid JWKs found inside the response body"); + continue; + } + + let jwks = jose_jwk::JwkSet { keys }; key_sets.insert( rule.id, KeySet { @@ -179,7 +245,7 @@ impl JwkCacheEntryLock { }, ); } - } + }; } } } @@ -196,7 +262,7 @@ impl JwkCacheEntryLock { async fn get_or_update_jwk_cache( self: &Arc, ctx: &RequestMonitoring, - client: &reqwest::Client, + client: &reqwest_middleware::ClientWithMiddleware, endpoint: EndpointId, fetch: &F, ) -> Result, JwtError> { @@ -250,7 +316,7 @@ impl JwkCacheEntryLock { self: &Arc, ctx: &RequestMonitoring, jwt: &str, - client: &reqwest::Client, + client: &reqwest_middleware::ClientWithMiddleware, endpoint: EndpointId, role_name: &RoleName, fetch: &F, @@ -369,8 +435,19 @@ impl Default for JwkCache { let client = Client::builder() .user_agent(JWKS_USER_AGENT) .redirect(redirect::Policy::none()) + .tls_built_in_native_certs(true) + .connect_timeout(JWKS_CONNECT_TIMEOUT) + .timeout(JWKS_FETCH_TIMEOUT) .build() - .expect("using &str and standard redirect::Policy"); + .expect("client config should be valid"); + + // Retry up to 3 times with increasing intervals between attempts. 
+ let retry_policy = ExponentialBackoff::builder().build_with_max_retries(JWKS_FETCH_RETRIES); + + let client = reqwest_middleware::ClientBuilder::new(client) + .with(RetryTransientMiddleware::new_with_policy(retry_policy)) + .build(); + JwkCache { client, map: DashMap::default(), @@ -1209,4 +1286,63 @@ X0n5X2/pBLJzxZc62ccvZYVnctBiFs6HbSnxpuMQCfkt/BcR/ttIepBQQIW86wHL } } } + + #[tokio::test] + async fn check_jwk_keycloak_regression() { + let (rs, valid_jwk) = new_rsa_jwk(RS1, "rs1".into()); + let valid_jwk = serde_json::to_value(valid_jwk).unwrap(); + + // This is valid, but we cannot parse it as we have no support for encryption JWKs, only signature based ones. + // This is taken directly from keycloak. + let invalid_jwk = serde_json::json! { + { + "kid": "U-Jc9xRli84eNqRpYQoIPF-GNuRWV3ZvAIhziRW2sbQ", + "kty": "RSA", + "alg": "RSA-OAEP", + "use": "enc", + "n": "yypYWsEKmM_wWdcPnSGLSm5ytw1WG7P7EVkKSulcDRlrM6HWj3PR68YS8LySYM2D9Z-79oAdZGKhIfzutqL8rK1vS14zDuPpAM-RWY3JuQfm1O_-1DZM8-07PmVRegP5KPxsKblLf_My8ByH6sUOIa1p2rbe2q_b0dSTXYu1t0dW-cGL5VShc400YymvTwpc-5uYNsaVxZajnB7JP1OunOiuCJ48AuVp3PqsLzgoXqlXEB1ZZdch3xT3bxaTtNruGvG4xmLZY68O_T3yrwTCNH2h_jFdGPyXdyZToCMSMK2qSbytlfwfN55pT9Vv42Lz1YmoB7XRjI9aExKPc5AxFw", + "e": "AQAB", + "x5c": [ + 
"MIICmzCCAYMCBgGS41E6azANBgkqhkiG9w0BAQsFADARMQ8wDQYDVQQDDAZtYXN0ZXIwHhcNMjQxMDMxMTYwMTQ0WhcNMzQxMDMxMTYwMzI0WjARMQ8wDQYDVQQDDAZtYXN0ZXIwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQDLKlhawQqYz/BZ1w+dIYtKbnK3DVYbs/sRWQpK6VwNGWszodaPc9HrxhLwvJJgzYP1n7v2gB1kYqEh/O62ovysrW9LXjMO4+kAz5FZjcm5B+bU7/7UNkzz7Ts+ZVF6A/ko/GwpuUt/8zLwHIfqxQ4hrWnatt7ar9vR1JNdi7W3R1b5wYvlVKFzjTRjKa9PClz7m5g2xpXFlqOcHsk/U66c6K4InjwC5Wnc+qwvOCheqVcQHVll1yHfFPdvFpO02u4a8bjGYtljrw79PfKvBMI0faH+MV0Y/Jd3JlOgIxIwrapJvK2V/B83nmlP1W/jYvPViagHtdGMj1oTEo9zkDEXAgMBAAEwDQYJKoZIhvcNAQELBQADggEBAECYX59+Q9v6c9sb6Q0/C6IgLWG2nVCgVE1YWwIzz+68WrhlmNCRuPjY94roB+tc2tdHbj+Nh3LMzJk7L1KCQoW1+LPK6A6E8W9ad0YPcuw8csV2pUA3+H56exQMH0fUAPQAU7tXWvnQ7otcpV1XA8afn/NTMTsnxi9mSkor8MLMYQ3aeRyh1+LAchHBthWiltqsSUqXrbJF59u5p0ghquuKcWR3TXsA7klGYBgGU5KAJifr9XT87rN0bOkGvbeWAgKvnQnjZwxdnLqTfp/pRY/PiJJHhgIBYPIA7STGnMPjmJ995i34zhnbnd8WHXJA3LxrIMqLW/l8eIdvtM1w8KI=" + ], + "x5t": "QhfzMMnuAfkReTgZ1HtrfyOeeZs", + "x5t#S256": "cmHDUdKgLiRCEN28D5FBy9IJLFmR7QWfm77SLhGTCTU" + } + }; + + let jwks = serde_json::json! 
{{ "keys": [invalid_jwk, valid_jwk ] }}; + let jwks_addr = jwks_server(move |path| match path { + "/" => Some(serde_json::to_vec(&jwks).unwrap()), + _ => None, + }) + .await; + + let role_name = RoleName::from("anonymous"); + let role = RoleNameInt::from(&role_name); + + let rules = vec![AuthRule { + id: "foo".to_owned(), + jwks_url: format!("http://{jwks_addr}/").parse().unwrap(), + audience: None, + role_names: vec![role], + }]; + + let fetch = Fetch(rules); + let jwk_cache = JwkCache::default(); + + let endpoint = EndpointId::from("ep"); + + let token = new_rsa_jwt("rs1".into(), rs); + + jwk_cache + .check_jwt( + &RequestMonitoring::test(), + endpoint.clone(), + &role_name, + &fetch, + &token, + ) + .await + .unwrap(); + } } diff --git a/proxy/src/http/mod.rs b/proxy/src/http/mod.rs index f1b632e70406..b1642cedb301 100644 --- a/proxy/src/http/mod.rs +++ b/proxy/src/http/mod.rs @@ -6,7 +6,6 @@ pub mod health_server; use std::time::Duration; -use anyhow::bail; use bytes::Bytes; use http::Method; use http_body_util::BodyExt; @@ -16,7 +15,7 @@ use reqwest_middleware::RequestBuilder; pub(crate) use reqwest_middleware::{ClientWithMiddleware, Error}; pub(crate) use reqwest_retry::policies::ExponentialBackoff; pub(crate) use reqwest_retry::RetryTransientMiddleware; -use serde::de::DeserializeOwned; +use thiserror::Error; use crate::metrics::{ConsoleRequest, Metrics}; use crate::url::ApiUrl; @@ -122,10 +121,19 @@ impl Endpoint { } } -pub(crate) async fn parse_json_body_with_limit( +#[derive(Error, Debug)] +pub(crate) enum ReadBodyError { + #[error("Content length exceeds limit of {limit} bytes")] + BodyTooLarge { limit: usize }, + + #[error(transparent)] + Read(#[from] reqwest::Error), +} + +pub(crate) async fn read_body_with_limit( mut b: impl Body + Unpin, limit: usize, -) -> anyhow::Result { +) -> Result, ReadBodyError> { // We could use `b.limited().collect().await.to_bytes()` here // but this ends up being slightly more efficient as far as I can tell. 
@@ -133,20 +141,20 @@ pub(crate) async fn parse_json_body_with_limit( // in reqwest, this value is influenced by the Content-Length header. let lower_bound = match usize::try_from(b.size_hint().lower()) { Ok(bound) if bound <= limit => bound, - _ => bail!("Content length exceeds limit of {limit} bytes"), + _ => return Err(ReadBodyError::BodyTooLarge { limit }), }; let mut bytes = Vec::with_capacity(lower_bound); while let Some(frame) = b.frame().await.transpose()? { if let Ok(data) = frame.into_data() { if bytes.len() + data.len() > limit { - bail!("Content length exceeds limit of {limit} bytes") + return Err(ReadBodyError::BodyTooLarge { limit }); } bytes.extend_from_slice(&data); } } - Ok(serde_json::from_slice::(&bytes)?) + Ok(bytes) } #[cfg(test)] diff --git a/proxy/src/serverless/conn_pool_lib.rs b/proxy/src/serverless/conn_pool_lib.rs index 00a8ac47681d..61c39c32c942 100644 --- a/proxy/src/serverless/conn_pool_lib.rs +++ b/proxy/src/serverless/conn_pool_lib.rs @@ -16,8 +16,7 @@ use super::http_conn_pool::ClientDataHttp; use super::local_conn_pool::ClientDataLocal; use crate::auth::backend::ComputeUserInfo; use crate::context::RequestMonitoring; -use crate::control_plane::messages::ColdStartInfo; -use crate::control_plane::messages::MetricsAuxInfo; +use crate::control_plane::messages::{ColdStartInfo, MetricsAuxInfo}; use crate::metrics::{HttpEndpointPoolsGuard, Metrics}; use crate::types::{DbName, EndpointCacheKey, RoleName}; use crate::usage_metrics::{Ids, MetricCounter, USAGE_METRICS}; diff --git a/proxy/src/serverless/http_conn_pool.rs b/proxy/src/serverless/http_conn_pool.rs index 56be70abeca3..a1d4473b0146 100644 --- a/proxy/src/serverless/http_conn_pool.rs +++ b/proxy/src/serverless/http_conn_pool.rs @@ -7,7 +7,6 @@ use hyper::client::conn::http2; use hyper_util::rt::{TokioExecutor, TokioIo}; use parking_lot::RwLock; use rand::Rng; -use std::result::Result::Ok; use tokio::net::TcpStream; use tracing::{debug, error, info, info_span, Instrument}; diff 
--git a/workspace_hack/Cargo.toml b/workspace_hack/Cargo.toml index 02deecd385d8..ae4018a8849c 100644 --- a/workspace_hack/Cargo.toml +++ b/workspace_hack/Cargo.toml @@ -64,7 +64,7 @@ rand = { version = "0.8", features = ["small_rng"] } regex = { version = "1" } regex-automata = { version = "0.4", default-features = false, features = ["dfa-onepass", "hybrid", "meta", "nfa-backtrack", "perf-inline", "perf-literal", "unicode"] } regex-syntax = { version = "0.8" } -reqwest = { version = "0.12", default-features = false, features = ["blocking", "json", "rustls-tls", "stream"] } +reqwest = { version = "0.12", default-features = false, features = ["blocking", "json", "rustls-tls", "rustls-tls-native-roots", "stream"] } rustls = { version = "0.23", default-features = false, features = ["logging", "ring", "std", "tls12"] } scopeguard = { version = "1" } serde = { version = "1", features = ["alloc", "derive"] }