diff --git a/libs/proxy/postgres-protocol2/src/message/frontend.rs b/libs/proxy/postgres-protocol2/src/message/frontend.rs index 9481e8186dd48..d49b4addf2f8e 100644 --- a/libs/proxy/postgres-protocol2/src/message/frontend.rs +++ b/libs/proxy/postgres-protocol2/src/message/frontend.rs @@ -255,10 +255,7 @@ pub fn ssl_request(buf: &mut BytesMut) { } #[inline] -pub fn startup_message( - parameters: &StartupMessageParams, - buf: &mut BytesMut, -) -> io::Result<()> { +pub fn startup_message(parameters: &StartupMessageParams, buf: &mut BytesMut) -> io::Result<()> { write_body(buf, |buf| { // postgres protocol version 3.0(196608) in bigger-endian buf.put_i32(0x00_03_00_00); @@ -275,18 +272,14 @@ pub struct StartupMessageParams { impl StartupMessageParams { /// Set parameter's value by its name. - pub fn insert(&mut self, name: &str, value: &str) -> Result<(), io::Error> { + pub fn insert(&mut self, name: &str, value: &str) { if name.contains('\0') | value.contains('\0') { - return Err(io::Error::new( - io::ErrorKind::InvalidInput, - "string contains embedded null", - )); + panic!("startup parameter name or value contained a null") } self.params.put(name.as_bytes()); self.params.put_u8(0); self.params.put(value.as_bytes()); self.params.put_u8(0); - Ok(()) } } diff --git a/libs/proxy/tokio-postgres2/src/config.rs b/libs/proxy/tokio-postgres2/src/config.rs index c820b3a3138fa..9003f24df87c4 100644 --- a/libs/proxy/tokio-postgres2/src/config.rs +++ b/libs/proxy/tokio-postgres2/src/config.rs @@ -69,8 +69,8 @@ pub struct Config { pub(crate) password: Option>, pub(crate) auth_keys: Option>, pub(crate) ssl_mode: SslMode, - pub(crate) host: Option, - pub(crate) port: Option, + pub(crate) host: Host, + pub(crate) port: u16, pub(crate) connect_timeout: Option, pub(crate) channel_binding: ChannelBinding, pub(crate) server_params: StartupMessageParams, @@ -79,21 +79,15 @@ pub struct Config { username: bool, } -impl Default for Config { - fn default() -> Config { - Config::new() - } -} - impl Config { /// Creates a new configuration. - pub fn new() -> Config { + pub fn new(host: String, port: u16) -> Config { Config { password: None, auth_keys: None, ssl_mode: SslMode::Prefer, - host: None, - port: None, + host: Host::Tcp(host), + port, connect_timeout: None, channel_binding: ChannelBinding::Prefer, server_params: StartupMessageParams::default(), @@ -165,9 +159,7 @@ impl Config { self.username = true; } - self.server_params - .insert(name, value) - .expect("name or value must not have null bytes"); + self.server_params.insert(name, value); self } @@ -184,32 +176,14 @@ impl Config { self.ssl_mode } - /// Adds a host to the configuration. - /// - /// Multiple hosts can be specified by calling this method multiple times, and each will be tried in order. - pub fn host(&mut self, host: &str) -> &mut Config { - self.host = Some(Host::Tcp(host.to_string())); - self - } - /// Gets the hosts that have been added to the configuration with `host`. - pub fn get_hosts(&self) -> &Option { + pub fn get_host(&self) -> &Host { &self.host } - /// Adds a port to the configuration. - /// - /// Multiple ports can be specified by calling this method multiple times. There must either be no ports, in which - /// case the default of 5432 is used, a single port, in which it is used for all hosts, or the same number of ports - /// as hosts. - pub fn port(&mut self, port: u16) -> &mut Config { - self.port = Some(port); - self - } - /// Gets the ports that have been added to the configuration with `port`. - pub fn get_ports(&self) -> &Option { - &self.port + pub fn get_port(&self) -> u16 { + self.port } /// Sets the timeout applied to socket-level connection attempts. diff --git a/libs/proxy/tokio-postgres2/src/connect.rs b/libs/proxy/tokio-postgres2/src/connect.rs index f18ba365a2b93..53216bfc4f9f0 100644 --- a/libs/proxy/tokio-postgres2/src/connect.rs +++ b/libs/proxy/tokio-postgres2/src/connect.rs @@ -16,11 +16,8 @@ pub async fn connect( where T: MakeTlsConnect, { - let Some(host) = &config.host else { - return Err(Error::config("host missing".into())); - }; - - let port = config.port.unwrap_or(5432); + let host = &config.host; + let port = config.port; let hostname = match host { Host::Tcp(host) => host.as_str(), diff --git a/libs/proxy/tokio-postgres2/src/connect_raw.rs b/libs/proxy/tokio-postgres2/src/connect_raw.rs index 7649775a88bca..9410d8138b698 100644 --- a/libs/proxy/tokio-postgres2/src/connect_raw.rs +++ b/libs/proxy/tokio-postgres2/src/connect_raw.rs @@ -120,9 +120,7 @@ where T: AsyncRead + AsyncWrite + Unpin, { let mut params = config.server_params.clone(); - params - .insert("client_encoding", "UTF8") - .expect("value does not contain null"); + params.insert("client_encoding", "UTF8"); let mut buf = BytesMut::new(); frontend::startup_message(¶ms, &mut buf).map_err(Error::encode)?; diff --git a/proxy/src/auth/backend/console_redirect.rs b/proxy/src/auth/backend/console_redirect.rs index 494564de05f09..6c2340c9ea748 100644 --- a/proxy/src/auth/backend/console_redirect.rs +++ b/proxy/src/auth/backend/console_redirect.rs @@ -151,12 +151,8 @@ async fn authenticate( // This config should be self-contained, because we won't // take username or dbname from client's startup message. - let mut config = compute::ConnCfg::new(); - config - .host(&db_info.host) - .port(db_info.port) - .dbname(&db_info.dbname) - .user(&db_info.user); + let mut config = compute::ConnCfg::new(db_info.host.to_string(), db_info.port); + config.dbname(&db_info.dbname).user(&db_info.user); ctx.set_dbname(db_info.dbname.into()); ctx.set_user(db_info.user.into()); diff --git a/proxy/src/auth/backend/local.rs b/proxy/src/auth/backend/local.rs index 32e0f536153df..d4273fb521670 100644 --- a/proxy/src/auth/backend/local.rs +++ b/proxy/src/auth/backend/local.rs @@ -29,12 +29,7 @@ impl LocalBackend { api: http::Endpoint::new(compute_ctl, http::new_client()), }, node_info: NodeInfo { - config: { - let mut cfg = ConnCfg::new(); - cfg.host(&postgres_addr.ip().to_string()); - cfg.port(postgres_addr.port()); - cfg - }, + config: ConnCfg::new(postgres_addr.ip().to_string(), postgres_addr.port()), // TODO(conrad): make this better reflect compute info rather than endpoint info. aux: MetricsAuxInfo { endpoint_id: EndpointIdTag::get_interner().get_or_intern("local"), diff --git a/proxy/src/cancellation.rs b/proxy/src/cancellation.rs index bcb0ef40bd74d..7bc5587a25358 100644 --- a/proxy/src/cancellation.rs +++ b/proxy/src/cancellation.rs @@ -70,11 +70,12 @@ impl ReportableError for CancelError { impl CancellationHandler

{ /// Run async action within an ephemeral session identified by [`CancelKeyData`]. pub(crate) fn get_session(self: Arc) -> Session

{ - // HACK: We'd rather get the real backend_pid but postgres_client doesn't - // expose it and we don't want to do another roundtrip to query - // for it. The client will be able to notice that this is not the - // actual backend_pid, but backend_pid is not used for anything - // so it doesn't matter. + // we intentionally generate a random "backend pid" and "secret key" here. + // we use the corresponding u64 as an identifier for the + // actual endpoint+pid+secret for postgres/pgbouncer. + // + // if we forwarded the backend_pid from postgres to the client, there would be a lot + // of overlap between our computes as most pids are small (~100). let key = loop { let key = rand::random(); diff --git a/proxy/src/compute.rs b/proxy/src/compute.rs index ffdce637ab99c..fbeeeda3afb22 100644 --- a/proxy/src/compute.rs +++ b/proxy/src/compute.rs @@ -104,13 +104,13 @@ pub(crate) type ScramKeys = postgres_client::config::ScramKeys<32>; /// A config for establishing a connection to compute node. /// Eventually, `postgres_client` will be replaced with something better. /// Newtype allows us to implement methods on top of it. -#[derive(Clone, Default)] +#[derive(Clone)] pub(crate) struct ConnCfg(Box); /// Creation and initialization routines. impl ConnCfg { - pub(crate) fn new() -> Self { - Self::default() + pub(crate) fn new(host: String, port: u16) -> Self { + Self(Box::new(postgres_client::Config::new(host, port))) } /// Reuse password or auth keys from the other config. @@ -124,12 +124,9 @@ impl ConnCfg { } } - pub(crate) fn get_host(&self) -> Result { - match self.0.get_hosts() { - Some(postgres_client::config::Host::Tcp(s)) => Ok(s.into()), - None => Err(WakeComputeError::BadComputeAddress( - "invalid compute address".into(), - )), + pub(crate) fn get_host(&self) -> Host { + match self.0.get_host() { + postgres_client::config::Host::Tcp(s) => s.into(), } } @@ -209,18 +206,8 @@ impl ConnCfg { // We can't reuse connection establishing logic from `postgres_client` here, // because it has no means for extracting the underlying socket which we // require for our business. - let ports = self.0.get_ports(); - let hosts = self.0.get_hosts(); - - let Some(host) = hosts else { - return Err(io::Error::new( - io::ErrorKind::Other, - format!("bad compute config: {:?}", self.0), - )); - }; - - let port = ports.unwrap_or(5432); - let host = match host { + let port = self.0.get_port(); + let host = match self.0.get_host() { Host::Tcp(host) => host.as_str(), }; diff --git a/proxy/src/control_plane/client/mock.rs b/proxy/src/control_plane/client/mock.rs index 4d55f96ca198b..eaf692ab279bc 100644 --- a/proxy/src/control_plane/client/mock.rs +++ b/proxy/src/control_plane/client/mock.rs @@ -160,11 +160,11 @@ impl MockControlPlane { } async fn do_wake_compute(&self) -> Result { - let mut config = compute::ConnCfg::new(); - config - .host(self.endpoint.host_str().unwrap_or("localhost")) - .port(self.endpoint.port().unwrap_or(5432)) - .ssl_mode(postgres_client::config::SslMode::Disable); + let mut config = compute::ConnCfg::new( + self.endpoint.host_str().unwrap_or("localhost").to_owned(), + self.endpoint.port().unwrap_or(5432), + ); + config.ssl_mode(postgres_client::config::SslMode::Disable); let node = NodeInfo { config, diff --git a/proxy/src/control_plane/client/neon.rs b/proxy/src/control_plane/client/neon.rs index 5a78ec9d32aa8..5c204ae1d7003 100644 --- a/proxy/src/control_plane/client/neon.rs +++ b/proxy/src/control_plane/client/neon.rs @@ -241,8 +241,8 @@ impl NeonControlPlaneClient { // Don't set anything but host and port! This config will be cached. // We'll set username and such later using the startup message. // TODO: add more type safety (in progress). - let mut config = compute::ConnCfg::new(); - config.host(host).port(port).ssl_mode(SslMode::Disable); // TLS is not configured on compute nodes. + let mut config = compute::ConnCfg::new(host.to_owned(), port); + config.ssl_mode(SslMode::Disable); // TLS is not configured on compute nodes. let node = NodeInfo { config, diff --git a/proxy/src/proxy/connect_compute.rs b/proxy/src/proxy/connect_compute.rs index 2e759b0894a2c..585dce7baeb8c 100644 --- a/proxy/src/proxy/connect_compute.rs +++ b/proxy/src/proxy/connect_compute.rs @@ -86,7 +86,7 @@ impl ConnectMechanism for TcpMechanism<'_> { node_info: &control_plane::CachedNodeInfo, timeout: time::Duration, ) -> Result { - let host = node_info.config.get_host()?; + let host = node_info.config.get_host(); let permit = self.locks.get_permit(&host).await?; permit.release_result(node_info.connect(ctx, timeout).await) } diff --git a/proxy/src/proxy/tests/mitm.rs b/proxy/src/proxy/tests/mitm.rs index 4d38eb3270807..59c9ac27b838f 100644 --- a/proxy/src/proxy/tests/mitm.rs +++ b/proxy/src/proxy/tests/mitm.rs @@ -164,7 +164,7 @@ async fn scram_auth_disable_channel_binding() -> anyhow::Result<()> { Scram::new("password").await?, )); - let _client_err = postgres_client::Config::new() + let _client_err = postgres_client::Config::new("test".to_owned(), 5432) .channel_binding(postgres_client::config::ChannelBinding::Disable) .user("user") .dbname("db") @@ -247,7 +247,7 @@ async fn connect_failure( Scram::new("password").await?, )); - let _client_err = postgres_client::Config::new() + let _client_err = postgres_client::Config::new("test".to_owned(), 5432) .channel_binding(channel_binding) .user("user") .dbname("db") diff --git a/proxy/src/proxy/tests/mod.rs b/proxy/src/proxy/tests/mod.rs index 4d24eaef68b3f..9f6c8431b4621 100644 --- a/proxy/src/proxy/tests/mod.rs +++ b/proxy/src/proxy/tests/mod.rs @@ -204,7 +204,7 @@ async fn handshake_tls_is_enforced_by_proxy() -> anyhow::Result<()> { let (_, server_config) = generate_tls_config("generic-project-name.localhost", "localhost")?; let proxy = tokio::spawn(dummy_proxy(client, Some(server_config), NoAuth)); - let client_err = postgres_client::Config::new() + let client_err = postgres_client::Config::new("test".to_owned(), 5432) .user("john_doe") .dbname("earth") .ssl_mode(SslMode::Disable) @@ -233,7 +233,7 @@ async fn handshake_tls() -> anyhow::Result<()> { generate_tls_config("generic-project-name.localhost", "localhost")?; let proxy = tokio::spawn(dummy_proxy(client, Some(server_config), NoAuth)); - let _conn = postgres_client::Config::new() + let _conn = postgres_client::Config::new("test".to_owned(), 5432) .user("john_doe") .dbname("earth") .ssl_mode(SslMode::Require) @@ -249,7 +249,7 @@ async fn handshake_raw() -> anyhow::Result<()> { let proxy = tokio::spawn(dummy_proxy(client, None, NoAuth)); - let _conn = postgres_client::Config::new() + let _conn = postgres_client::Config::new("test".to_owned(), 5432) .user("john_doe") .dbname("earth") .set_param("options", "project=generic-project-name") @@ -296,7 +296,7 @@ async fn scram_auth_good(#[case] password: &str) -> anyhow::Result<()> { Scram::new(password).await?, )); - let _conn = postgres_client::Config::new() + let _conn = postgres_client::Config::new("test".to_owned(), 5432) .channel_binding(postgres_client::config::ChannelBinding::Require) .user("user") .dbname("db") @@ -320,7 +320,7 @@ async fn scram_auth_disable_channel_binding() -> anyhow::Result<()> { Scram::new("password").await?, )); - let _conn = postgres_client::Config::new() + let _conn = postgres_client::Config::new("test".to_owned(), 5432) .channel_binding(postgres_client::config::ChannelBinding::Disable) .user("user") .dbname("db") @@ -348,7 +348,7 @@ async fn scram_auth_mock() -> anyhow::Result<()> { .map(char::from) .collect(); - let _client_err = postgres_client::Config::new() + let _client_err = postgres_client::Config::new("test".to_owned(), 5432) .user("user") .dbname("db") .password(&password) // no password will match the mocked secret @@ -546,7 +546,7 @@ impl TestControlPlaneClient for TestConnectMechanism { fn helper_create_cached_node_info(cache: &'static NodeInfoCache) -> CachedNodeInfo { let node = NodeInfo { - config: compute::ConnCfg::new(), + config: compute::ConnCfg::new("localhost".to_owned(), 5432), aux: MetricsAuxInfo { endpoint_id: (&EndpointId::from("endpoint")).into(), project_id: (&ProjectId::from("project")).into(), diff --git a/proxy/src/serverless/backend.rs b/proxy/src/serverless/backend.rs index 91c131e912537..980383973a256 100644 --- a/proxy/src/serverless/backend.rs +++ b/proxy/src/serverless/backend.rs @@ -498,7 +498,7 @@ impl ConnectMechanism for TokioMechanism { node_info: &CachedNodeInfo, timeout: Duration, ) -> Result { - let host = node_info.config.get_host()?; + let host = node_info.config.get_host(); let permit = self.locks.get_permit(&host).await?; let mut config = (*node_info.config).clone(); @@ -548,16 +548,12 @@ impl ConnectMechanism for HyperMechanism { node_info: &CachedNodeInfo, timeout: Duration, ) -> Result { - let host = node_info.config.get_host()?; + let host = node_info.config.get_host(); let permit = self.locks.get_permit(&host).await?; let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute); - let port = node_info.config.get_ports().ok_or_else(|| { - HttpConnError::WakeCompute(WakeComputeError::BadComputeAddress( - "local-proxy port missing on compute address".into(), - )) - })?; + let port = node_info.config.get_port(); let res = connect_http2(&host, port, timeout).await; drop(pause); let (client, connection) = permit.release_result(res)?;