Skip to content

Commit

Permalink
proxy: refactor auth backends (#9271)
Browse files Browse the repository at this point in the history
preliminary for #9270 

The auth::Backend didn't need to be in the mega ProxyConfig object, so I
split it off and passed it manually in the few places it was necessary.

I've also refined some of the uses of config I saw while doing this
small refactor.

I've also followed the trend and make the console redirect backend it's
own struct, same as LocalBackend and ControlPlaneBackend.
  • Loading branch information
conradludgate authored Oct 11, 2024
1 parent 5ef805e commit ab5bbb4
Show file tree
Hide file tree
Showing 10 changed files with 184 additions and 157 deletions.
25 changes: 24 additions & 1 deletion proxy/src/auth/backend/console_redirect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ pub(crate) enum WebAuthError {
Io(#[from] std::io::Error),
}

pub struct ConsoleRedirectBackend {
console_uri: reqwest::Url,
}

impl UserFacingError for WebAuthError {
fn to_string_client(&self) -> String {
"Internal error".to_string()
Expand Down Expand Up @@ -57,7 +61,26 @@ pub(crate) fn new_psql_session_id() -> String {
hex::encode(rand::random::<[u8; 8]>())
}

pub(super) async fn authenticate(
impl ConsoleRedirectBackend {
pub fn new(console_uri: reqwest::Url) -> Self {
Self { console_uri }
}

pub(super) fn url(&self) -> &reqwest::Url {
&self.console_uri
}

pub(crate) async fn authenticate(
&self,
ctx: &RequestMonitoring,
auth_config: &'static AuthenticationConfig,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
) -> auth::Result<NodeInfo> {
authenticate(ctx, auth_config, &self.console_uri, client).await
}
}

async fn authenticate(
ctx: &RequestMonitoring,
auth_config: &'static AuthenticationConfig,
link_uri: &reqwest::Url,
Expand Down
19 changes: 9 additions & 10 deletions proxy/src/auth/backend/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use std::net::IpAddr;
use std::sync::Arc;
use std::time::Duration;

pub use console_redirect::ConsoleRedirectBackend;
pub(crate) use console_redirect::WebAuthError;
use ipnet::{Ipv4Net, Ipv6Net};
use local::LocalBackend;
Expand Down Expand Up @@ -36,7 +37,7 @@ use crate::{
provider::{CachedAllowedIps, CachedNodeInfo},
Api,
},
stream, url,
stream,
};
use crate::{scram, EndpointCacheKey, EndpointId, RoleName};

Expand Down Expand Up @@ -69,7 +70,7 @@ pub enum Backend<'a, T, D> {
/// Cloud API (V2).
ControlPlane(MaybeOwned<'a, ControlPlaneBackend>, T),
/// Authentication via a web browser.
ConsoleRedirect(MaybeOwned<'a, url::ApiUrl>, D),
ConsoleRedirect(MaybeOwned<'a, ConsoleRedirectBackend>, D),
/// Local proxy uses configured auth credentials and does not wake compute
Local(MaybeOwned<'a, LocalBackend>),
}
Expand Down Expand Up @@ -106,9 +107,9 @@ impl std::fmt::Display for Backend<'_, (), ()> {
#[cfg(test)]
ControlPlaneBackend::Test(_) => fmt.debug_tuple("ControlPlane::Test").finish(),
},
Self::ConsoleRedirect(url, ()) => fmt
Self::ConsoleRedirect(backend, ()) => fmt
.debug_tuple("ConsoleRedirect")
.field(&url.as_str())
.field(&backend.url().as_str())
.finish(),
Self::Local(_) => fmt.debug_tuple("Local").finish(),
}
Expand Down Expand Up @@ -241,7 +242,6 @@ impl AuthenticationConfig {
pub(crate) fn check_rate_limit(
&self,
ctx: &RequestMonitoring,
config: &AuthenticationConfig,
secret: AuthSecret,
endpoint: &EndpointId,
is_cleartext: bool,
Expand All @@ -265,7 +265,7 @@ impl AuthenticationConfig {
let limit_not_exceeded = self.rate_limiter.check(
(
endpoint_int,
MaskedIp::new(ctx.peer_addr(), config.rate_limit_ip_subnet),
MaskedIp::new(ctx.peer_addr(), self.rate_limit_ip_subnet),
),
password_weight,
);
Expand Down Expand Up @@ -339,7 +339,6 @@ async fn auth_quirks(
let secret = if let Some(secret) = secret {
config.check_rate_limit(
ctx,
config,
secret,
&info.endpoint,
unauthenticated_password.is_some() || allow_cleartext,
Expand Down Expand Up @@ -456,12 +455,12 @@ impl<'a> Backend<'a, ComputeUserInfoMaybeEndpoint, &()> {
Backend::ControlPlane(api, credentials)
}
// NOTE: this auth backend doesn't use client credentials.
Self::ConsoleRedirect(url, ()) => {
Self::ConsoleRedirect(backend, ()) => {
info!("performing web authentication");

let info = console_redirect::authenticate(ctx, config, &url, client).await?;
let info = backend.authenticate(ctx, config, client).await?;

Backend::ConsoleRedirect(url, info)
Backend::ConsoleRedirect(backend, info)
}
Self::Local(_) => {
return Err(auth::AuthError::bad_auth_method("invalid for local proxy"))
Expand Down
25 changes: 19 additions & 6 deletions proxy/src/bin/local_proxy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,12 @@ use compute_api::spec::LocalProxySpec;
use dashmap::DashMap;
use futures::future::Either;
use proxy::{
auth::backend::{
jwt::JwkCache,
local::{LocalBackend, JWKS_ROLE_MAP},
auth::{
self,
backend::{
jwt::JwkCache,
local::{LocalBackend, JWKS_ROLE_MAP},
},
},
cancellation::CancellationHandlerMain,
config::{self, AuthenticationConfig, HttpConfig, ProxyConfig, RetryConfig},
Expand Down Expand Up @@ -132,6 +135,7 @@ async fn main() -> anyhow::Result<()> {

let args = LocalProxyCliArgs::parse();
let config = build_config(&args)?;
let auth_backend = build_auth_backend(&args)?;

// before we bind to any ports, write the process ID to a file
// so that compute-ctl can find our process later
Expand Down Expand Up @@ -193,6 +197,7 @@ async fn main() -> anyhow::Result<()> {

let task = serverless::task_main(
config,
auth_backend,
http_listener,
shutdown.clone(),
Arc::new(CancellationHandlerMain::new(
Expand Down Expand Up @@ -257,9 +262,6 @@ fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig

Ok(Box::leak(Box::new(ProxyConfig {
tls_config: None,
auth_backend: proxy::auth::Backend::Local(proxy::auth::backend::MaybeOwned::Owned(
LocalBackend::new(args.compute),
)),
metric_collection: None,
allow_self_signed_compute: false,
http_config,
Expand All @@ -286,6 +288,17 @@ fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig
})))
}

/// auth::Backend is created at proxy startup, and lives forever.
fn build_auth_backend(
args: &LocalProxyCliArgs,
) -> anyhow::Result<&'static auth::Backend<'static, (), ()>> {
let auth_backend = proxy::auth::Backend::Local(proxy::auth::backend::MaybeOwned::Owned(
LocalBackend::new(args.compute),
));

Ok(Box::leak(Box::new(auth_backend)))
}

async fn refresh_config_loop(path: Utf8PathBuf, rx: Arc<Notify>) {
loop {
rx.notified().await;
Expand Down
150 changes: 80 additions & 70 deletions proxy/src/bin/proxy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use futures::future::Either;
use proxy::auth;
use proxy::auth::backend::jwt::JwkCache;
use proxy::auth::backend::AuthRateLimiter;
use proxy::auth::backend::ConsoleRedirectBackend;
use proxy::auth::backend::MaybeOwned;
use proxy::cancellation::CancelMap;
use proxy::cancellation::CancellationHandler;
Expand Down Expand Up @@ -311,8 +312,9 @@ async fn main() -> anyhow::Result<()> {

let args = ProxyCliArgs::parse();
let config = build_config(&args)?;
let auth_backend = build_auth_backend(&args)?;

info!("Authentication backend: {}", config.auth_backend);
info!("Authentication backend: {}", auth_backend);
info!("Using region: {}", args.aws_region);

let region_provider =
Expand Down Expand Up @@ -462,6 +464,7 @@ async fn main() -> anyhow::Result<()> {
if let Some(proxy_listener) = proxy_listener {
client_tasks.spawn(proxy::proxy::task_main(
config,
auth_backend,
proxy_listener,
cancellation_token.clone(),
cancellation_handler.clone(),
Expand All @@ -472,6 +475,7 @@ async fn main() -> anyhow::Result<()> {
if let Some(serverless_listener) = serverless_listener {
client_tasks.spawn(serverless::task_main(
config,
auth_backend,
serverless_listener,
cancellation_token.clone(),
cancellation_handler.clone(),
Expand Down Expand Up @@ -506,7 +510,7 @@ async fn main() -> anyhow::Result<()> {
));
}

if let auth::Backend::ControlPlane(api, _) = &config.auth_backend {
if let auth::Backend::ControlPlane(api, _) = auth_backend {
if let proxy::control_plane::provider::ControlPlaneBackend::Management(api) = &**api {
match (redis_notifications_client, regional_redis_client.clone()) {
(None, None) => {}
Expand Down Expand Up @@ -610,73 +614,6 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
bail!("dynamic rate limiter should be disabled");
}

let auth_backend = match &args.auth_backend {
AuthBackendType::Console => {
let wake_compute_cache_config: CacheOptions = args.wake_compute_cache.parse()?;
let project_info_cache_config: ProjectInfoCacheOptions =
args.project_info_cache.parse()?;
let endpoint_cache_config: config::EndpointCacheConfig =
args.endpoint_cache_config.parse()?;

info!("Using NodeInfoCache (wake_compute) with options={wake_compute_cache_config:?}");
info!(
"Using AllowedIpsCache (wake_compute) with options={project_info_cache_config:?}"
);
info!("Using EndpointCacheConfig with options={endpoint_cache_config:?}");
let caches = Box::leak(Box::new(control_plane::caches::ApiCaches::new(
wake_compute_cache_config,
project_info_cache_config,
endpoint_cache_config,
)));

let config::ConcurrencyLockOptions {
shards,
limiter,
epoch,
timeout,
} = args.wake_compute_lock.parse()?;
info!(?limiter, shards, ?epoch, "Using NodeLocks (wake_compute)");
let locks = Box::leak(Box::new(control_plane::locks::ApiLocks::new(
"wake_compute_lock",
limiter,
shards,
timeout,
epoch,
&Metrics::get().wake_compute_lock,
)?));
tokio::spawn(locks.garbage_collect_worker());

let url = args.auth_endpoint.parse()?;
let endpoint = http::Endpoint::new(url, http::new_client());

let mut wake_compute_rps_limit = args.wake_compute_limit.clone();
RateBucketInfo::validate(&mut wake_compute_rps_limit)?;
let wake_compute_endpoint_rate_limiter =
Arc::new(WakeComputeRateLimiter::new(wake_compute_rps_limit));
let api = control_plane::provider::neon::Api::new(
endpoint,
caches,
locks,
wake_compute_endpoint_rate_limiter,
);
let api = control_plane::provider::ControlPlaneBackend::Management(api);
auth::Backend::ControlPlane(MaybeOwned::Owned(api), ())
}

AuthBackendType::Web => {
let url = args.uri.parse()?;
auth::Backend::ConsoleRedirect(MaybeOwned::Owned(url), ())
}

#[cfg(feature = "testing")]
AuthBackendType::Postgres => {
let url = args.auth_endpoint.parse()?;
let api = control_plane::provider::mock::Api::new(url, !args.is_private_access_proxy);
let api = control_plane::provider::ControlPlaneBackend::PostgresMock(api);
auth::Backend::ControlPlane(MaybeOwned::Owned(api), ())
}
};

let config::ConcurrencyLockOptions {
shards,
limiter,
Expand Down Expand Up @@ -728,7 +665,6 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {

let config = Box::leak(Box::new(ProxyConfig {
tls_config,
auth_backend,
metric_collection,
allow_self_signed_compute: args.allow_self_signed_compute,
http_config,
Expand All @@ -748,6 +684,80 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
Ok(config)
}

/// auth::Backend is created at proxy startup, and lives forever.
fn build_auth_backend(
args: &ProxyCliArgs,
) -> anyhow::Result<&'static auth::Backend<'static, (), ()>> {
let auth_backend = match &args.auth_backend {
AuthBackendType::Console => {
let wake_compute_cache_config: CacheOptions = args.wake_compute_cache.parse()?;
let project_info_cache_config: ProjectInfoCacheOptions =
args.project_info_cache.parse()?;
let endpoint_cache_config: config::EndpointCacheConfig =
args.endpoint_cache_config.parse()?;

info!("Using NodeInfoCache (wake_compute) with options={wake_compute_cache_config:?}");
info!(
"Using AllowedIpsCache (wake_compute) with options={project_info_cache_config:?}"
);
info!("Using EndpointCacheConfig with options={endpoint_cache_config:?}");
let caches = Box::leak(Box::new(control_plane::caches::ApiCaches::new(
wake_compute_cache_config,
project_info_cache_config,
endpoint_cache_config,
)));

let config::ConcurrencyLockOptions {
shards,
limiter,
epoch,
timeout,
} = args.wake_compute_lock.parse()?;
info!(?limiter, shards, ?epoch, "Using NodeLocks (wake_compute)");
let locks = Box::leak(Box::new(control_plane::locks::ApiLocks::new(
"wake_compute_lock",
limiter,
shards,
timeout,
epoch,
&Metrics::get().wake_compute_lock,
)?));
tokio::spawn(locks.garbage_collect_worker());

let url = args.auth_endpoint.parse()?;
let endpoint = http::Endpoint::new(url, http::new_client());

let mut wake_compute_rps_limit = args.wake_compute_limit.clone();
RateBucketInfo::validate(&mut wake_compute_rps_limit)?;
let wake_compute_endpoint_rate_limiter =
Arc::new(WakeComputeRateLimiter::new(wake_compute_rps_limit));
let api = control_plane::provider::neon::Api::new(
endpoint,
caches,
locks,
wake_compute_endpoint_rate_limiter,
);
let api = control_plane::provider::ControlPlaneBackend::Management(api);
auth::Backend::ControlPlane(MaybeOwned::Owned(api), ())
}

AuthBackendType::Web => {
let url = args.uri.parse()?;
auth::Backend::ConsoleRedirect(MaybeOwned::Owned(ConsoleRedirectBackend::new(url)), ())
}

#[cfg(feature = "testing")]
AuthBackendType::Postgres => {
let url = args.auth_endpoint.parse()?;
let api = control_plane::provider::mock::Api::new(url, !args.is_private_access_proxy);
let api = control_plane::provider::ControlPlaneBackend::PostgresMock(api);
auth::Backend::ControlPlane(MaybeOwned::Owned(api), ())
}
};

Ok(Box::leak(Box::new(auth_backend)))
}

#[cfg(test)]
mod tests {
use std::time::Duration;
Expand Down
Loading

1 comment on commit ab5bbb4

@github-actions
Copy link

@github-actions github-actions bot commented on ab5bbb4 Oct 11, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

5269 tests run: 5052 passed, 0 failed, 217 skipped (full report)


Flaky tests (4)

Postgres 16

Postgres 15

Postgres 14

Code coverage* (full report)

  • functions: 31.4% (7546 of 24022 functions)
  • lines: 49.3% (60358 of 122549 lines)

* collected from Rust tests only


The comment gets automatically updated with the latest test results
ab5bbb4 at 2024-10-14T12:10:54.566Z :recycle:

Please sign in to comment.