From 0da5564a67feee0c9f81052597f417601ffb69ac Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Tue, 3 Dec 2024 16:06:18 +0000 Subject: [PATCH] chore(proxy): enforce single host+port --- libs/proxy/tokio-postgres2/src/config.rs | 41 ++++----------- libs/proxy/tokio-postgres2/src/connect.rs | 38 ++++---------- proxy/src/auth/backend/console_redirect.rs | 8 +-- proxy/src/auth/backend/local.rs | 7 +-- proxy/src/compute.rs | 61 ++++++---------------- proxy/src/control_plane/client/mock.rs | 10 ++-- proxy/src/control_plane/client/neon.rs | 4 +- proxy/src/proxy/connect_compute.rs | 2 +- proxy/src/proxy/tests/mitm.rs | 4 +- proxy/src/proxy/tests/mod.rs | 14 ++--- proxy/src/serverless/backend.rs | 10 ++-- 11 files changed, 58 insertions(+), 141 deletions(-) diff --git a/libs/proxy/tokio-postgres2/src/config.rs b/libs/proxy/tokio-postgres2/src/config.rs index 5dad835c3bdd..fd10ef6f207d 100644 --- a/libs/proxy/tokio-postgres2/src/config.rs +++ b/libs/proxy/tokio-postgres2/src/config.rs @@ -146,6 +146,9 @@ pub enum AuthKeys { /// ``` #[derive(Clone, PartialEq, Eq)] pub struct Config { + pub(crate) host: Host, + pub(crate) port: u16, + pub(crate) user: Option, pub(crate) password: Option>, pub(crate) auth_keys: Option>, @@ -153,8 +156,6 @@ pub struct Config { pub(crate) options: Option, pub(crate) application_name: Option, pub(crate) ssl_mode: SslMode, - pub(crate) host: Vec, - pub(crate) port: Vec, pub(crate) connect_timeout: Option, pub(crate) target_session_attrs: TargetSessionAttrs, pub(crate) channel_binding: ChannelBinding, @@ -162,16 +163,12 @@ pub struct Config { pub(crate) max_backend_message_size: Option, } -impl Default for Config { - fn default() -> Config { - Config::new() - } -} - impl Config { /// Creates a new configuration. - pub fn new() -> Config { + pub fn new(host: String, port: u16) -> Config { Config { + host: Host::Tcp(host), + port, user: None, password: None, auth_keys: None, @@ -179,8 +176,6 @@ impl Config { options: None, application_name: None, ssl_mode: SslMode::Prefer, - host: vec![], - port: vec![], connect_timeout: None, target_session_attrs: TargetSessionAttrs::Any, channel_binding: ChannelBinding::Prefer, @@ -283,32 +278,14 @@ impl Config { self.ssl_mode } - /// Adds a host to the configuration. - /// - /// Multiple hosts can be specified by calling this method multiple times, and each will be tried in order. - pub fn host(&mut self, host: &str) -> &mut Config { - self.host.push(Host::Tcp(host.to_string())); - self - } - /// Gets the hosts that have been added to the configuration with `host`. - pub fn get_hosts(&self) -> &[Host] { + pub fn get_host(&self) -> &Host { &self.host } - /// Adds a port to the configuration. - /// - /// Multiple ports can be specified by calling this method multiple times. There must either be no ports, in which - /// case the default of 5432 is used, a single port, in which it is used for all hosts, or the same number of ports - /// as hosts. - pub fn port(&mut self, port: u16) -> &mut Config { - self.port.push(port); - self - } - /// Gets the ports that have been added to the configuration with `port`. - pub fn get_ports(&self) -> &[u16] { - &self.port + pub fn get_port(&self) -> u16 { + self.port } /// Sets the timeout applied to socket-level connection attempts. diff --git a/libs/proxy/tokio-postgres2/src/connect.rs b/libs/proxy/tokio-postgres2/src/connect.rs index 98067d91f942..75a58e6eacc9 100644 --- a/libs/proxy/tokio-postgres2/src/connect.rs +++ b/libs/proxy/tokio-postgres2/src/connect.rs @@ -19,38 +19,18 @@ pub async fn connect( where T: MakeTlsConnect, { - if config.host.is_empty() { - return Err(Error::config("host missing".into())); - } - - if config.port.len() > 1 && config.port.len() != config.host.len() { - return Err(Error::config("invalid number of ports".into())); - } - - let mut error = None; - for (i, host) in config.host.iter().enumerate() { - let port = config - .port - .get(i) - .or_else(|| config.port.first()) - .copied() - .unwrap_or(5432); - - let hostname = match host { - Host::Tcp(host) => host.as_str(), - }; + let hostname = match &config.host { + Host::Tcp(host) => host.as_str(), + }; - let tls = tls - .make_tls_connect(hostname) - .map_err(|e| Error::tls(e.into()))?; + let tls = tls + .make_tls_connect(hostname) + .map_err(|e| Error::tls(e.into()))?; - match connect_once(host, port, tls, config).await { - Ok((client, connection)) => return Ok((client, connection)), - Err(e) => error = Some(e), - } + match connect_once(&config.host, config.port, tls, config).await { + Ok((client, connection)) => Ok((client, connection)), + Err(e) => Err(e), } - - Err(error.unwrap()) } async fn connect_once( diff --git a/proxy/src/auth/backend/console_redirect.rs b/proxy/src/auth/backend/console_redirect.rs index 494564de05f0..6c2340c9ea74 100644 --- a/proxy/src/auth/backend/console_redirect.rs +++ b/proxy/src/auth/backend/console_redirect.rs @@ -151,12 +151,8 @@ async fn authenticate( // This config should be self-contained, because we won't // take username or dbname from client's startup message. - let mut config = compute::ConnCfg::new(); - config - .host(&db_info.host) - .port(db_info.port) - .dbname(&db_info.dbname) - .user(&db_info.user); + let mut config = compute::ConnCfg::new(db_info.host.to_string(), db_info.port); + config.dbname(&db_info.dbname).user(&db_info.user); ctx.set_dbname(db_info.dbname.into()); ctx.set_user(db_info.user.into()); diff --git a/proxy/src/auth/backend/local.rs b/proxy/src/auth/backend/local.rs index 32e0f536153d..d4273fb52167 100644 --- a/proxy/src/auth/backend/local.rs +++ b/proxy/src/auth/backend/local.rs @@ -29,12 +29,7 @@ impl LocalBackend { api: http::Endpoint::new(compute_ctl, http::new_client()), }, node_info: NodeInfo { - config: { - let mut cfg = ConnCfg::new(); - cfg.host(&postgres_addr.ip().to_string()); - cfg.port(postgres_addr.port()); - cfg - }, + config: ConnCfg::new(postgres_addr.ip().to_string(), postgres_addr.port()), // TODO(conrad): make this better reflect compute info rather than endpoint info. aux: MetricsAuxInfo { endpoint_id: EndpointIdTag::get_interner().get_or_intern("local"), diff --git a/proxy/src/compute.rs b/proxy/src/compute.rs index 06bc71c55988..ab0ff4b7950a 100644 --- a/proxy/src/compute.rs +++ b/proxy/src/compute.rs @@ -104,13 +104,13 @@ pub(crate) type ScramKeys = postgres_client::config::ScramKeys<32>; /// A config for establishing a connection to compute node. /// Eventually, `postgres_client` will be replaced with something better. /// Newtype allows us to implement methods on top of it. -#[derive(Clone, Default)] +#[derive(Clone)] pub(crate) struct ConnCfg(Box); /// Creation and initialization routines. impl ConnCfg { - pub(crate) fn new() -> Self { - Self::default() + pub(crate) fn new(host: String, port: u16) -> Self { + Self(Box::new(postgres_client::Config::new(host, port))) } /// Reuse password or auth keys from the other config. @@ -124,13 +124,9 @@ impl ConnCfg { } } - pub(crate) fn get_host(&self) -> Result { - match self.0.get_hosts() { - [postgres_client::config::Host::Tcp(s)] => Ok(s.into()), - // we should not have multiple address or unix addresses. - _ => Err(WakeComputeError::BadComputeAddress( - "invalid compute address".into(), - )), + pub(crate) fn get_host(&self) -> Host { + match self.0.get_host() { + postgres_client::config::Host::Tcp(s) => s.into(), } } @@ -227,43 +223,20 @@ impl ConnCfg { // We can't reuse connection establishing logic from `postgres_client` here, // because it has no means for extracting the underlying socket which we // require for our business. - let mut connection_error = None; - let ports = self.0.get_ports(); - let hosts = self.0.get_hosts(); - // the ports array is supposed to have 0 entries, 1 entry, or as many entries as in the hosts array - if ports.len() > 1 && ports.len() != hosts.len() { - return Err(io::Error::new( - io::ErrorKind::Other, - format!( - "bad compute config, \ - ports and hosts entries' count does not match: {:?}", - self.0 - ), - )); - } + let port = self.0.get_port(); + let host = self.0.get_host(); - for (i, host) in hosts.iter().enumerate() { - let port = ports.get(i).or_else(|| ports.first()).unwrap_or(&5432); - let host = match host { - Host::Tcp(host) => host.as_str(), - }; - - match connect_once(host, *port).await { - Ok((sockaddr, stream)) => return Ok((sockaddr, stream, host)), - Err(err) => { - // We can't throw an error here, as there might be more hosts to try. - warn!("couldn't connect to compute node at {host}:{port}: {err}"); - connection_error = Some(err); - } + let host = match host { + Host::Tcp(host) => host.as_str(), + }; + + match connect_once(host, port).await { + Ok((sockaddr, stream)) => Ok((sockaddr, stream, host)), + Err(err) => { + warn!("couldn't connect to compute node at {host}:{port}: {err}"); + Err(err) } } - - Err(connection_error.unwrap_or_else(|| { - io::Error::new( - io::ErrorKind::Other, - format!("bad compute config: {:?}", self.0), - ) - })) } } diff --git a/proxy/src/control_plane/client/mock.rs b/proxy/src/control_plane/client/mock.rs index 4d55f96ca198..eaf692ab279b 100644 --- a/proxy/src/control_plane/client/mock.rs +++ b/proxy/src/control_plane/client/mock.rs @@ -160,11 +160,11 @@ impl MockControlPlane { } async fn do_wake_compute(&self) -> Result { - let mut config = compute::ConnCfg::new(); - config - .host(self.endpoint.host_str().unwrap_or("localhost")) - .port(self.endpoint.port().unwrap_or(5432)) - .ssl_mode(postgres_client::config::SslMode::Disable); + let mut config = compute::ConnCfg::new( + self.endpoint.host_str().unwrap_or("localhost").to_owned(), + self.endpoint.port().unwrap_or(5432), + ); + config.ssl_mode(postgres_client::config::SslMode::Disable); let node = NodeInfo { config, diff --git a/proxy/src/control_plane/client/neon.rs b/proxy/src/control_plane/client/neon.rs index 5a78ec9d32aa..5c204ae1d700 100644 --- a/proxy/src/control_plane/client/neon.rs +++ b/proxy/src/control_plane/client/neon.rs @@ -241,8 +241,8 @@ impl NeonControlPlaneClient { // Don't set anything but host and port! This config will be cached. // We'll set username and such later using the startup message. // TODO: add more type safety (in progress). - let mut config = compute::ConnCfg::new(); - config.host(host).port(port).ssl_mode(SslMode::Disable); // TLS is not configured on compute nodes. + let mut config = compute::ConnCfg::new(host.to_owned(), port); + config.ssl_mode(SslMode::Disable); // TLS is not configured on compute nodes. let node = NodeInfo { config, diff --git a/proxy/src/proxy/connect_compute.rs b/proxy/src/proxy/connect_compute.rs index 2e759b0894a2..585dce7baeb8 100644 --- a/proxy/src/proxy/connect_compute.rs +++ b/proxy/src/proxy/connect_compute.rs @@ -86,7 +86,7 @@ impl ConnectMechanism for TcpMechanism<'_> { node_info: &control_plane::CachedNodeInfo, timeout: time::Duration, ) -> Result { - let host = node_info.config.get_host()?; + let host = node_info.config.get_host(); let permit = self.locks.get_permit(&host).await?; permit.release_result(node_info.connect(ctx, timeout).await) } diff --git a/proxy/src/proxy/tests/mitm.rs b/proxy/src/proxy/tests/mitm.rs index ef351f3b54b2..d72331c7bf78 100644 --- a/proxy/src/proxy/tests/mitm.rs +++ b/proxy/src/proxy/tests/mitm.rs @@ -158,7 +158,7 @@ async fn scram_auth_disable_channel_binding() -> anyhow::Result<()> { Scram::new("password").await?, )); - let _client_err = postgres_client::Config::new() + let _client_err = postgres_client::Config::new("test".to_owned(), 5432) .channel_binding(postgres_client::config::ChannelBinding::Disable) .user("user") .dbname("db") @@ -241,7 +241,7 @@ async fn connect_failure( Scram::new("password").await?, )); - let _client_err = postgres_client::Config::new() + let _client_err = postgres_client::Config::new("test".to_owned(), 5432) .channel_binding(channel_binding) .user("user") .dbname("db") diff --git a/proxy/src/proxy/tests/mod.rs b/proxy/src/proxy/tests/mod.rs index c8b742b3ff23..53345431e3cf 100644 --- a/proxy/src/proxy/tests/mod.rs +++ b/proxy/src/proxy/tests/mod.rs @@ -204,7 +204,7 @@ async fn handshake_tls_is_enforced_by_proxy() -> anyhow::Result<()> { let (_, server_config) = generate_tls_config("generic-project-name.localhost", "localhost")?; let proxy = tokio::spawn(dummy_proxy(client, Some(server_config), NoAuth)); - let client_err = postgres_client::Config::new() + let client_err = postgres_client::Config::new("test".to_owned(), 5432) .user("john_doe") .dbname("earth") .ssl_mode(SslMode::Disable) @@ -233,7 +233,7 @@ async fn handshake_tls() -> anyhow::Result<()> { generate_tls_config("generic-project-name.localhost", "localhost")?; let proxy = tokio::spawn(dummy_proxy(client, Some(server_config), NoAuth)); - let _conn = postgres_client::Config::new() + let _conn = postgres_client::Config::new("test".to_owned(), 5432) .user("john_doe") .dbname("earth") .ssl_mode(SslMode::Require) @@ -249,7 +249,7 @@ async fn handshake_raw() -> anyhow::Result<()> { let proxy = tokio::spawn(dummy_proxy(client, None, NoAuth)); - let _conn = postgres_client::Config::new() + let _conn = postgres_client::Config::new("test".to_owned(), 5432) .user("john_doe") .dbname("earth") .options("project=generic-project-name") @@ -296,7 +296,7 @@ async fn scram_auth_good(#[case] password: &str) -> anyhow::Result<()> { Scram::new(password).await?, )); - let _conn = postgres_client::Config::new() + let _conn = postgres_client::Config::new("test".to_owned(), 5432) .channel_binding(postgres_client::config::ChannelBinding::Require) .user("user") .dbname("db") @@ -320,7 +320,7 @@ async fn scram_auth_disable_channel_binding() -> anyhow::Result<()> { Scram::new("password").await?, )); - let _conn = postgres_client::Config::new() + let _conn = postgres_client::Config::new("test".to_owned(), 5432) .channel_binding(postgres_client::config::ChannelBinding::Disable) .user("user") .dbname("db") @@ -348,7 +348,7 @@ async fn scram_auth_mock() -> anyhow::Result<()> { .map(char::from) .collect(); - let _client_err = postgres_client::Config::new() + let _client_err = postgres_client::Config::new("test".to_owned(), 5432) .user("user") .dbname("db") .password(&password) // no password will match the mocked secret @@ -546,7 +546,7 @@ impl TestControlPlaneClient for TestConnectMechanism { fn helper_create_cached_node_info(cache: &'static NodeInfoCache) -> CachedNodeInfo { let node = NodeInfo { - config: compute::ConnCfg::new(), + config: compute::ConnCfg::new("test".to_owned(), 5432), aux: MetricsAuxInfo { endpoint_id: (&EndpointId::from("endpoint")).into(), project_id: (&ProjectId::from("project")).into(), diff --git a/proxy/src/serverless/backend.rs b/proxy/src/serverless/backend.rs index 8c7931907da5..55d2e47fd3f2 100644 --- a/proxy/src/serverless/backend.rs +++ b/proxy/src/serverless/backend.rs @@ -499,7 +499,7 @@ impl ConnectMechanism for TokioMechanism { node_info: &CachedNodeInfo, timeout: Duration, ) -> Result { - let host = node_info.config.get_host()?; + let host = node_info.config.get_host(); let permit = self.locks.get_permit(&host).await?; let mut config = (*node_info.config).clone(); @@ -549,16 +549,12 @@ impl ConnectMechanism for HyperMechanism { node_info: &CachedNodeInfo, timeout: Duration, ) -> Result { - let host = node_info.config.get_host()?; + let host = node_info.config.get_host(); let permit = self.locks.get_permit(&host).await?; let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute); - let port = *node_info.config.get_ports().first().ok_or_else(|| { - HttpConnError::WakeCompute(WakeComputeError::BadComputeAddress( - "local-proxy port missing on compute address".into(), - )) - })?; + let port = node_info.config.get_port(); let res = connect_http2(&host, port, timeout).await; drop(pause); let (client, connection) = permit.release_result(res)?;