Skip to content

Commit

Permalink
[proxy/authorize]: improve JWKS reliability (#9676)
Browse files Browse the repository at this point in the history
While setting up some tests, I noticed that we didn't support keycloak.
They make use of encryption JWKs as well as signature ones. Our current
jwks crate does not support parsing encryption keys which caused the
entire jwk set to fail to parse. Switching to lazy parsing fixes this.

Also while setting up tests, I couldn't use localhost jwks server as we
require HTTPS and we were using webpki so it was impossible to add a
custom CA. Enabling native roots addresses this possibility.

I saw some of our current e2e tests against our custom JWKS in s3 were
taking a while to fetch. I've added a timeout + retries to address this.
  • Loading branch information
conradludgate authored Nov 7, 2024
1 parent 75aa19a commit 82e3f0e
Show file tree
Hide file tree
Showing 7 changed files with 168 additions and 25 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion proxy/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ prometheus.workspace = true
rand.workspace = true
regex.workspace = true
remote_storage = { version = "0.1", path = "../libs/remote_storage/" }
reqwest.workspace = true
reqwest = { workspace = true, features = ["rustls-tls-native-roots"] }
reqwest-middleware = { workspace = true, features = ["json"] }
reqwest-retry.workspace = true
reqwest-tracing.workspace = true
Expand Down
162 changes: 149 additions & 13 deletions proxy/src/auth/backend/jwt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,19 @@ use arc_swap::ArcSwapOption;
use dashmap::DashMap;
use jose_jwk::crypto::KeyInfo;
use reqwest::{redirect, Client};
use reqwest_retry::policies::ExponentialBackoff;
use reqwest_retry::RetryTransientMiddleware;
use serde::de::Visitor;
use serde::{Deserialize, Deserializer};
use serde_json::value::RawValue;
use signature::Verifier;
use thiserror::Error;
use tokio::time::Instant;

use crate::auth::backend::ComputeCredentialKeys;
use crate::context::RequestMonitoring;
use crate::control_plane::errors::GetEndpointJwksError;
use crate::http::parse_json_body_with_limit;
use crate::http::read_body_with_limit;
use crate::intern::RoleNameInt;
use crate::types::{EndpointId, RoleName};

Expand All @@ -28,6 +31,10 @@ const MAX_RENEW: Duration = Duration::from_secs(3600);
const MAX_JWK_BODY_SIZE: usize = 64 * 1024;
const JWKS_USER_AGENT: &str = "neon-proxy";

const JWKS_CONNECT_TIMEOUT: Duration = Duration::from_secs(2);
const JWKS_FETCH_TIMEOUT: Duration = Duration::from_secs(5);
const JWKS_FETCH_RETRIES: u32 = 3;

/// How to get the JWT auth rules
pub(crate) trait FetchAuthRules: Clone + Send + Sync + 'static {
fn fetch_auth_rules(
Expand Down Expand Up @@ -55,7 +62,7 @@ pub(crate) struct AuthRule {
}

pub struct JwkCache {
client: reqwest::Client,
client: reqwest_middleware::ClientWithMiddleware,

map: DashMap<(EndpointId, RoleName), Arc<JwkCacheEntryLock>>,
}
Expand Down Expand Up @@ -117,6 +124,14 @@ impl Default for JwkCacheEntryLock {
}
}

#[derive(Deserialize)]
struct JwkSet<'a> {
/// we parse into raw-value because not all keys in a JWKS are ones
/// we can parse directly, so we parse them lazily.
#[serde(borrow)]
keys: Vec<&'a RawValue>,
}

impl JwkCacheEntryLock {
async fn acquire_permit<'a>(self: &'a Arc<Self>) -> JwkRenewalPermit<'a> {
JwkRenewalPermit::acquire_permit(self).await
Expand All @@ -130,7 +145,7 @@ impl JwkCacheEntryLock {
&self,
_permit: JwkRenewalPermit<'_>,
ctx: &RequestMonitoring,
client: &reqwest::Client,
client: &reqwest_middleware::ClientWithMiddleware,
endpoint: EndpointId,
auth_rules: &F,
) -> Result<Arc<JwkCacheEntry>, JwtError> {
Expand All @@ -154,22 +169,73 @@ impl JwkCacheEntryLock {
let req = client.get(rule.jwks_url.clone());
// TODO(conrad): eventually switch to using reqwest_middleware/`new_client_with_timeout`.
// TODO(conrad): We need to filter out URLs that point to local resources. Public internet only.
match req.send().await.and_then(|r| r.error_for_status()) {
match req.send().await.and_then(|r| {
r.error_for_status()
.map_err(reqwest_middleware::Error::Reqwest)
}) {
// todo: should we re-insert JWKs if we want to keep this JWKs URL?
// I expect these failures would be quite sparse.
Err(e) => tracing::warn!(url=?rule.jwks_url, error=?e, "could not fetch JWKs"),
Ok(r) => {
let resp: http::Response<reqwest::Body> = r.into();
match parse_json_body_with_limit::<jose_jwk::JwkSet>(
resp.into_body(),
MAX_JWK_BODY_SIZE,
)
.await

let bytes = match read_body_with_limit(resp.into_body(), MAX_JWK_BODY_SIZE)
.await
{
Ok(bytes) => bytes,
Err(e) => {
tracing::warn!(url=?rule.jwks_url, error=?e, "could not decode JWKs");
continue;
}
};

match serde_json::from_slice::<JwkSet>(&bytes) {
Err(e) => {
tracing::warn!(url=?rule.jwks_url, error=?e, "could not decode JWKs");
}
Ok(jwks) => {
// size_of::<&RawValue>() == 16
// size_of::<jose_jwk::Jwk>() == 288
// better to not pre-allocate this as it might be pretty large - especially if it has many
// keys we don't want or need.
// trivial 'attack': `{"keys":[` + repeat(`0`).take(30000).join(`,`) + `]}`
// this would consume 8MiB just like that!
let mut keys = vec![];
let mut failed = 0;
for key in jwks.keys {
match serde_json::from_str::<jose_jwk::Jwk>(key.get()) {
Ok(key) => {
// if `use` (called `cls` in rust) is specified to be something other than signing,
// we can skip storing it.
if key
.prm
.cls
.as_ref()
.is_some_and(|c| *c != jose_jwk::Class::Signing)
{
continue;
}

keys.push(key);
}
Err(e) => {
tracing::debug!(url=?rule.jwks_url, failed=?e, "could not decode JWK");
failed += 1;
}
}
}
keys.shrink_to_fit();

if failed > 0 {
tracing::warn!(url=?rule.jwks_url, failed, "could not decode JWKs");
}

if keys.is_empty() {
tracing::warn!(url=?rule.jwks_url, "no valid JWKs found inside the response body");
continue;
}

let jwks = jose_jwk::JwkSet { keys };
key_sets.insert(
rule.id,
KeySet {
Expand All @@ -179,7 +245,7 @@ impl JwkCacheEntryLock {
},
);
}
}
};
}
}
}
Expand All @@ -196,7 +262,7 @@ impl JwkCacheEntryLock {
async fn get_or_update_jwk_cache<F: FetchAuthRules>(
self: &Arc<Self>,
ctx: &RequestMonitoring,
client: &reqwest::Client,
client: &reqwest_middleware::ClientWithMiddleware,
endpoint: EndpointId,
fetch: &F,
) -> Result<Arc<JwkCacheEntry>, JwtError> {
Expand Down Expand Up @@ -250,7 +316,7 @@ impl JwkCacheEntryLock {
self: &Arc<Self>,
ctx: &RequestMonitoring,
jwt: &str,
client: &reqwest::Client,
client: &reqwest_middleware::ClientWithMiddleware,
endpoint: EndpointId,
role_name: &RoleName,
fetch: &F,
Expand Down Expand Up @@ -369,8 +435,19 @@ impl Default for JwkCache {
let client = Client::builder()
.user_agent(JWKS_USER_AGENT)
.redirect(redirect::Policy::none())
.tls_built_in_native_certs(true)
.connect_timeout(JWKS_CONNECT_TIMEOUT)
.timeout(JWKS_FETCH_TIMEOUT)
.build()
.expect("using &str and standard redirect::Policy");
.expect("client config should be valid");

// Retry up to 3 times with increasing intervals between attempts.
let retry_policy = ExponentialBackoff::builder().build_with_max_retries(JWKS_FETCH_RETRIES);

let client = reqwest_middleware::ClientBuilder::new(client)
.with(RetryTransientMiddleware::new_with_policy(retry_policy))
.build();

JwkCache {
client,
map: DashMap::default(),
Expand Down Expand Up @@ -1209,4 +1286,63 @@ X0n5X2/pBLJzxZc62ccvZYVnctBiFs6HbSnxpuMQCfkt/BcR/ttIepBQQIW86wHL
}
}
}

#[tokio::test]
async fn check_jwk_keycloak_regression() {
let (rs, valid_jwk) = new_rsa_jwk(RS1, "rs1".into());
let valid_jwk = serde_json::to_value(valid_jwk).unwrap();

// This is valid, but we cannot parse it as we have no support for encryption JWKs, only signature based ones.
// This is taken directly from keycloak.
let invalid_jwk = serde_json::json! {
{
"kid": "U-Jc9xRli84eNqRpYQoIPF-GNuRWV3ZvAIhziRW2sbQ",
"kty": "RSA",
"alg": "RSA-OAEP",
"use": "enc",
"n": "yypYWsEKmM_wWdcPnSGLSm5ytw1WG7P7EVkKSulcDRlrM6HWj3PR68YS8LySYM2D9Z-79oAdZGKhIfzutqL8rK1vS14zDuPpAM-RWY3JuQfm1O_-1DZM8-07PmVRegP5KPxsKblLf_My8ByH6sUOIa1p2rbe2q_b0dSTXYu1t0dW-cGL5VShc400YymvTwpc-5uYNsaVxZajnB7JP1OunOiuCJ48AuVp3PqsLzgoXqlXEB1ZZdch3xT3bxaTtNruGvG4xmLZY68O_T3yrwTCNH2h_jFdGPyXdyZToCMSMK2qSbytlfwfN55pT9Vv42Lz1YmoB7XRjI9aExKPc5AxFw",
"e": "AQAB",
"x5c": [
"MIICmzCCAYMCBgGS41E6azANBgkqhkiG9w0BAQsFADARMQ8wDQYDVQQDDAZtYXN0ZXIwHhcNMjQxMDMxMTYwMTQ0WhcNMzQxMDMxMTYwMzI0WjARMQ8wDQYDVQQDDAZtYXN0ZXIwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQDLKlhawQqYz/BZ1w+dIYtKbnK3DVYbs/sRWQpK6VwNGWszodaPc9HrxhLwvJJgzYP1n7v2gB1kYqEh/O62ovysrW9LXjMO4+kAz5FZjcm5B+bU7/7UNkzz7Ts+ZVF6A/ko/GwpuUt/8zLwHIfqxQ4hrWnatt7ar9vR1JNdi7W3R1b5wYvlVKFzjTRjKa9PClz7m5g2xpXFlqOcHsk/U66c6K4InjwC5Wnc+qwvOCheqVcQHVll1yHfFPdvFpO02u4a8bjGYtljrw79PfKvBMI0faH+MV0Y/Jd3JlOgIxIwrapJvK2V/B83nmlP1W/jYvPViagHtdGMj1oTEo9zkDEXAgMBAAEwDQYJKoZIhvcNAQELBQADggEBAECYX59+Q9v6c9sb6Q0/C6IgLWG2nVCgVE1YWwIzz+68WrhlmNCRuPjY94roB+tc2tdHbj+Nh3LMzJk7L1KCQoW1+LPK6A6E8W9ad0YPcuw8csV2pUA3+H56exQMH0fUAPQAU7tXWvnQ7otcpV1XA8afn/NTMTsnxi9mSkor8MLMYQ3aeRyh1+LAchHBthWiltqsSUqXrbJF59u5p0ghquuKcWR3TXsA7klGYBgGU5KAJifr9XT87rN0bOkGvbeWAgKvnQnjZwxdnLqTfp/pRY/PiJJHhgIBYPIA7STGnMPjmJ995i34zhnbnd8WHXJA3LxrIMqLW/l8eIdvtM1w8KI="
],
"x5t": "QhfzMMnuAfkReTgZ1HtrfyOeeZs",
"x5t#S256": "cmHDUdKgLiRCEN28D5FBy9IJLFmR7QWfm77SLhGTCTU"
}
};

let jwks = serde_json::json! {{ "keys": [invalid_jwk, valid_jwk ] }};
let jwks_addr = jwks_server(move |path| match path {
"/" => Some(serde_json::to_vec(&jwks).unwrap()),
_ => None,
})
.await;

let role_name = RoleName::from("anonymous");
let role = RoleNameInt::from(&role_name);

let rules = vec![AuthRule {
id: "foo".to_owned(),
jwks_url: format!("http://{jwks_addr}/").parse().unwrap(),
audience: None,
role_names: vec![role],
}];

let fetch = Fetch(rules);
let jwk_cache = JwkCache::default();

let endpoint = EndpointId::from("ep");

let token = new_rsa_jwt("rs1".into(), rs);

jwk_cache
.check_jwt(
&RequestMonitoring::test(),
endpoint.clone(),
&role_name,
&fetch,
&token,
)
.await
.unwrap();
}
}
22 changes: 15 additions & 7 deletions proxy/src/http/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ pub mod health_server;

use std::time::Duration;

use anyhow::bail;
use bytes::Bytes;
use http::Method;
use http_body_util::BodyExt;
Expand All @@ -16,7 +15,7 @@ use reqwest_middleware::RequestBuilder;
pub(crate) use reqwest_middleware::{ClientWithMiddleware, Error};
pub(crate) use reqwest_retry::policies::ExponentialBackoff;
pub(crate) use reqwest_retry::RetryTransientMiddleware;
use serde::de::DeserializeOwned;
use thiserror::Error;

use crate::metrics::{ConsoleRequest, Metrics};
use crate::url::ApiUrl;
Expand Down Expand Up @@ -122,31 +121,40 @@ impl Endpoint {
}
}

pub(crate) async fn parse_json_body_with_limit<D: DeserializeOwned>(
#[derive(Error, Debug)]
pub(crate) enum ReadBodyError {
#[error("Content length exceeds limit of {limit} bytes")]
BodyTooLarge { limit: usize },

#[error(transparent)]
Read(#[from] reqwest::Error),
}

pub(crate) async fn read_body_with_limit(
mut b: impl Body<Data = Bytes, Error = reqwest::Error> + Unpin,
limit: usize,
) -> anyhow::Result<D> {
) -> Result<Vec<u8>, ReadBodyError> {
// We could use `b.limited().collect().await.to_bytes()` here
// but this ends up being slightly more efficient as far as I can tell.

// check the lower bound of the size hint.
// in reqwest, this value is influenced by the Content-Length header.
let lower_bound = match usize::try_from(b.size_hint().lower()) {
Ok(bound) if bound <= limit => bound,
_ => bail!("Content length exceeds limit of {limit} bytes"),
_ => return Err(ReadBodyError::BodyTooLarge { limit }),
};
let mut bytes = Vec::with_capacity(lower_bound);

while let Some(frame) = b.frame().await.transpose()? {
if let Ok(data) = frame.into_data() {
if bytes.len() + data.len() > limit {
bail!("Content length exceeds limit of {limit} bytes")
return Err(ReadBodyError::BodyTooLarge { limit });
}
bytes.extend_from_slice(&data);
}
}

Ok(serde_json::from_slice::<D>(&bytes)?)
Ok(bytes)
}

#[cfg(test)]
Expand Down
3 changes: 1 addition & 2 deletions proxy/src/serverless/conn_pool_lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,7 @@ use super::http_conn_pool::ClientDataHttp;
use super::local_conn_pool::ClientDataLocal;
use crate::auth::backend::ComputeUserInfo;
use crate::context::RequestMonitoring;
use crate::control_plane::messages::ColdStartInfo;
use crate::control_plane::messages::MetricsAuxInfo;
use crate::control_plane::messages::{ColdStartInfo, MetricsAuxInfo};
use crate::metrics::{HttpEndpointPoolsGuard, Metrics};
use crate::types::{DbName, EndpointCacheKey, RoleName};
use crate::usage_metrics::{Ids, MetricCounter, USAGE_METRICS};
Expand Down
1 change: 0 additions & 1 deletion proxy/src/serverless/http_conn_pool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ use hyper::client::conn::http2;
use hyper_util::rt::{TokioExecutor, TokioIo};
use parking_lot::RwLock;
use rand::Rng;
use std::result::Result::Ok;
use tokio::net::TcpStream;
use tracing::{debug, error, info, info_span, Instrument};

Expand Down
2 changes: 1 addition & 1 deletion workspace_hack/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ rand = { version = "0.8", features = ["small_rng"] }
regex = { version = "1" }
regex-automata = { version = "0.4", default-features = false, features = ["dfa-onepass", "hybrid", "meta", "nfa-backtrack", "perf-inline", "perf-literal", "unicode"] }
regex-syntax = { version = "0.8" }
reqwest = { version = "0.12", default-features = false, features = ["blocking", "json", "rustls-tls", "stream"] }
reqwest = { version = "0.12", default-features = false, features = ["blocking", "json", "rustls-tls", "rustls-tls-native-roots", "stream"] }
rustls = { version = "0.23", default-features = false, features = ["logging", "ring", "std", "tls12"] }
scopeguard = { version = "1" }
serde = { version = "1", features = ["alloc", "derive"] }
Expand Down

1 comment on commit 82e3f0e

@github-actions
Copy link

@github-actions github-actions bot commented on 82e3f0e Nov 7, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

5480 tests run: 5248 passed, 0 failed, 232 skipped (full report)


Flaky tests (7)

Postgres 17

Postgres 14

Code coverage* (full report)

  • functions: 31.7% (7863 of 24802 functions)
  • lines: 49.4% (62211 of 125944 lines)

* collected from Rust tests only


The comment gets automatically updated with the latest test results
82e3f0e at 2024-11-07T20:39:41.929Z :recycle:

Please sign in to comment.