Skip to content

Commit

Permalink
require exactly 1 host and 1 port
Browse files Browse the repository at this point in the history
  • Loading branch information
conradludgate committed Dec 3, 2024
1 parent 6924248 commit df6a6b1
Show file tree
Hide file tree
Showing 14 changed files with 52 additions and 115 deletions.
13 changes: 3 additions & 10 deletions libs/proxy/postgres-protocol2/src/message/frontend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -255,10 +255,7 @@ pub fn ssl_request(buf: &mut BytesMut) {
}

#[inline]
pub fn startup_message(
parameters: &StartupMessageParams,
buf: &mut BytesMut,
) -> io::Result<()> {
pub fn startup_message(parameters: &StartupMessageParams, buf: &mut BytesMut) -> io::Result<()> {
write_body(buf, |buf| {
// postgres protocol version 3.0(196608) in bigger-endian
buf.put_i32(0x00_03_00_00);
Expand All @@ -275,18 +272,14 @@ pub struct StartupMessageParams {

impl StartupMessageParams {
/// Set parameter's value by its name.
pub fn insert(&mut self, name: &str, value: &str) -> Result<(), io::Error> {
pub fn insert(&mut self, name: &str, value: &str) {
if name.contains('\0') | value.contains('\0') {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"string contains embedded null",
));
panic!("startup parameter name or value contained a null")
}
self.params.put(name.as_bytes());
self.params.put_u8(0);
self.params.put(value.as_bytes());
self.params.put_u8(0);
Ok(())
}
}

Expand Down
44 changes: 9 additions & 35 deletions libs/proxy/tokio-postgres2/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,8 @@ pub struct Config {
pub(crate) password: Option<Vec<u8>>,
pub(crate) auth_keys: Option<Box<AuthKeys>>,
pub(crate) ssl_mode: SslMode,
pub(crate) host: Option<Host>,
pub(crate) port: Option<u16>,
pub(crate) host: Host,
pub(crate) port: u16,
pub(crate) connect_timeout: Option<Duration>,
pub(crate) channel_binding: ChannelBinding,
pub(crate) server_params: StartupMessageParams,
Expand All @@ -79,21 +79,15 @@ pub struct Config {
username: bool,
}

impl Default for Config {
fn default() -> Config {
Config::new()
}
}

impl Config {
/// Creates a new configuration.
pub fn new() -> Config {
pub fn new(host: String, port: u16) -> Config {
Config {
password: None,
auth_keys: None,
ssl_mode: SslMode::Prefer,
host: None,
port: None,
host: Host::Tcp(host),
port,
connect_timeout: None,
channel_binding: ChannelBinding::Prefer,
server_params: StartupMessageParams::default(),
Expand Down Expand Up @@ -165,9 +159,7 @@ impl Config {
self.username = true;
}

self.server_params
.insert(name, value)
.expect("name or value must not have null bytes");
self.server_params.insert(name, value);
self
}

Expand All @@ -184,32 +176,14 @@ impl Config {
self.ssl_mode
}

/// Adds a host to the configuration.
///
/// Multiple hosts can be specified by calling this method multiple times, and each will be tried in order.
pub fn host(&mut self, host: &str) -> &mut Config {
self.host = Some(Host::Tcp(host.to_string()));
self
}

/// Gets the hosts that have been added to the configuration with `host`.
pub fn get_hosts(&self) -> &Option<Host> {
pub fn get_host(&self) -> &Host {
&self.host
}

/// Adds a port to the configuration.
///
/// Multiple ports can be specified by calling this method multiple times. There must either be no ports, in which
/// case the default of 5432 is used, a single port, in which it is used for all hosts, or the same number of ports
/// as hosts.
pub fn port(&mut self, port: u16) -> &mut Config {
self.port = Some(port);
self
}

/// Gets the ports that have been added to the configuration with `port`.
pub fn get_ports(&self) -> &Option<u16> {
&self.port
pub fn get_port(&self) -> u16 {
self.port
}

/// Sets the timeout applied to socket-level connection attempts.
Expand Down
7 changes: 2 additions & 5 deletions libs/proxy/tokio-postgres2/src/connect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,8 @@ pub async fn connect<T>(
where
T: MakeTlsConnect<TcpStream>,
{
let Some(host) = &config.host else {
return Err(Error::config("host missing".into()));
};

let port = config.port.unwrap_or(5432);
let host = &config.host;
let port = config.port;

let hostname = match host {
Host::Tcp(host) => host.as_str(),
Expand Down
4 changes: 1 addition & 3 deletions libs/proxy/tokio-postgres2/src/connect_raw.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,9 +120,7 @@ where
T: AsyncRead + AsyncWrite + Unpin,
{
let mut params = config.server_params.clone();
params
.insert("client_encoding", "UTF8")
.expect("value does not contain null");
params.insert("client_encoding", "UTF8");

let mut buf = BytesMut::new();
frontend::startup_message(&params, &mut buf).map_err(Error::encode)?;
Expand Down
8 changes: 2 additions & 6 deletions proxy/src/auth/backend/console_redirect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,12 +151,8 @@ async fn authenticate(

// This config should be self-contained, because we won't
// take username or dbname from client's startup message.
let mut config = compute::ConnCfg::new();
config
.host(&db_info.host)
.port(db_info.port)
.dbname(&db_info.dbname)
.user(&db_info.user);
let mut config = compute::ConnCfg::new(db_info.host.to_string(), db_info.port);
config.dbname(&db_info.dbname).user(&db_info.user);

ctx.set_dbname(db_info.dbname.into());
ctx.set_user(db_info.user.into());
Expand Down
7 changes: 1 addition & 6 deletions proxy/src/auth/backend/local.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,7 @@ impl LocalBackend {
api: http::Endpoint::new(compute_ctl, http::new_client()),
},
node_info: NodeInfo {
config: {
let mut cfg = ConnCfg::new();
cfg.host(&postgres_addr.ip().to_string());
cfg.port(postgres_addr.port());
cfg
},
config: ConnCfg::new(postgres_addr.ip().to_string(), postgres_addr.port()),
// TODO(conrad): make this better reflect compute info rather than endpoint info.
aux: MetricsAuxInfo {
endpoint_id: EndpointIdTag::get_interner().get_or_intern("local"),
Expand Down
11 changes: 6 additions & 5 deletions proxy/src/cancellation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,12 @@ impl ReportableError for CancelError {
impl<P: CancellationPublisher> CancellationHandler<P> {
/// Run async action within an ephemeral session identified by [`CancelKeyData`].
pub(crate) fn get_session(self: Arc<Self>) -> Session<P> {
// HACK: We'd rather get the real backend_pid but postgres_client doesn't
// expose it and we don't want to do another roundtrip to query
// for it. The client will be able to notice that this is not the
// actual backend_pid, but backend_pid is not used for anything
// so it doesn't matter.
// we intentionally generate a random "backend pid" and "secret key" here.
// we use the corresponding u64 as an identifier for the
// actual endpoint+pid+secret for postgres/pgbouncer.
//
// if we forwarded the backend_pid from postgres to the client, there would be a lot
// of overlap between our computes as most pids are small (~100).
let key = loop {
let key = rand::random();

Expand Down
29 changes: 8 additions & 21 deletions proxy/src/compute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,13 +104,13 @@ pub(crate) type ScramKeys = postgres_client::config::ScramKeys<32>;
/// A config for establishing a connection to compute node.
/// Eventually, `postgres_client` will be replaced with something better.
/// Newtype allows us to implement methods on top of it.
#[derive(Clone, Default)]
#[derive(Clone)]
pub(crate) struct ConnCfg(Box<postgres_client::Config>);

/// Creation and initialization routines.
impl ConnCfg {
pub(crate) fn new() -> Self {
Self::default()
pub(crate) fn new(host: String, port: u16) -> Self {
Self(Box::new(postgres_client::Config::new(host, port)))
}

/// Reuse password or auth keys from the other config.
Expand All @@ -124,12 +124,9 @@ impl ConnCfg {
}
}

pub(crate) fn get_host(&self) -> Result<Host, WakeComputeError> {
match self.0.get_hosts() {
Some(postgres_client::config::Host::Tcp(s)) => Ok(s.into()),
None => Err(WakeComputeError::BadComputeAddress(
"invalid compute address".into(),
)),
pub(crate) fn get_host(&self) -> Host {
match self.0.get_host() {
postgres_client::config::Host::Tcp(s) => s.into(),
}
}

Expand Down Expand Up @@ -209,18 +206,8 @@ impl ConnCfg {
// We can't reuse connection establishing logic from `postgres_client` here,
// because it has no means for extracting the underlying socket which we
// require for our business.
let ports = self.0.get_ports();
let hosts = self.0.get_hosts();

let Some(host) = hosts else {
return Err(io::Error::new(
io::ErrorKind::Other,
format!("bad compute config: {:?}", self.0),
));
};

let port = ports.unwrap_or(5432);
let host = match host {
let port = self.0.get_port();
let host = match self.0.get_host() {
Host::Tcp(host) => host.as_str(),
};

Expand Down
10 changes: 5 additions & 5 deletions proxy/src/control_plane/client/mock.rs
Original file line number Diff line number Diff line change
Expand Up @@ -160,11 +160,11 @@ impl MockControlPlane {
}

async fn do_wake_compute(&self) -> Result<NodeInfo, WakeComputeError> {
let mut config = compute::ConnCfg::new();
config
.host(self.endpoint.host_str().unwrap_or("localhost"))
.port(self.endpoint.port().unwrap_or(5432))
.ssl_mode(postgres_client::config::SslMode::Disable);
let mut config = compute::ConnCfg::new(
self.endpoint.host_str().unwrap_or("localhost").to_owned(),
self.endpoint.port().unwrap_or(5432),
);
config.ssl_mode(postgres_client::config::SslMode::Disable);

let node = NodeInfo {
config,
Expand Down
4 changes: 2 additions & 2 deletions proxy/src/control_plane/client/neon.rs
Original file line number Diff line number Diff line change
Expand Up @@ -241,8 +241,8 @@ impl NeonControlPlaneClient {
// Don't set anything but host and port! This config will be cached.
// We'll set username and such later using the startup message.
// TODO: add more type safety (in progress).
let mut config = compute::ConnCfg::new();
config.host(host).port(port).ssl_mode(SslMode::Disable); // TLS is not configured on compute nodes.
let mut config = compute::ConnCfg::new(host.to_owned(), port);
config.ssl_mode(SslMode::Disable); // TLS is not configured on compute nodes.

let node = NodeInfo {
config,
Expand Down
2 changes: 1 addition & 1 deletion proxy/src/proxy/connect_compute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ impl ConnectMechanism for TcpMechanism<'_> {
node_info: &control_plane::CachedNodeInfo,
timeout: time::Duration,
) -> Result<PostgresConnection, Self::Error> {
let host = node_info.config.get_host()?;
let host = node_info.config.get_host();
let permit = self.locks.get_permit(&host).await?;
permit.release_result(node_info.connect(ctx, timeout).await)
}
Expand Down
4 changes: 2 additions & 2 deletions proxy/src/proxy/tests/mitm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ async fn scram_auth_disable_channel_binding() -> anyhow::Result<()> {
Scram::new("password").await?,
));

let _client_err = postgres_client::Config::new()
let _client_err = postgres_client::Config::new("test".to_owned(), 5432)
.channel_binding(postgres_client::config::ChannelBinding::Disable)
.user("user")
.dbname("db")
Expand Down Expand Up @@ -247,7 +247,7 @@ async fn connect_failure(
Scram::new("password").await?,
));

let _client_err = postgres_client::Config::new()
let _client_err = postgres_client::Config::new("test".to_owned(), 5432)
.channel_binding(channel_binding)
.user("user")
.dbname("db")
Expand Down
14 changes: 7 additions & 7 deletions proxy/src/proxy/tests/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ async fn handshake_tls_is_enforced_by_proxy() -> anyhow::Result<()> {
let (_, server_config) = generate_tls_config("generic-project-name.localhost", "localhost")?;
let proxy = tokio::spawn(dummy_proxy(client, Some(server_config), NoAuth));

let client_err = postgres_client::Config::new()
let client_err = postgres_client::Config::new("test".to_owned(), 5432)
.user("john_doe")
.dbname("earth")
.ssl_mode(SslMode::Disable)
Expand Down Expand Up @@ -233,7 +233,7 @@ async fn handshake_tls() -> anyhow::Result<()> {
generate_tls_config("generic-project-name.localhost", "localhost")?;
let proxy = tokio::spawn(dummy_proxy(client, Some(server_config), NoAuth));

let _conn = postgres_client::Config::new()
let _conn = postgres_client::Config::new("test".to_owned(), 5432)
.user("john_doe")
.dbname("earth")
.ssl_mode(SslMode::Require)
Expand All @@ -249,7 +249,7 @@ async fn handshake_raw() -> anyhow::Result<()> {

let proxy = tokio::spawn(dummy_proxy(client, None, NoAuth));

let _conn = postgres_client::Config::new()
let _conn = postgres_client::Config::new("test".to_owned(), 5432)
.user("john_doe")
.dbname("earth")
.set_param("options", "project=generic-project-name")
Expand Down Expand Up @@ -296,7 +296,7 @@ async fn scram_auth_good(#[case] password: &str) -> anyhow::Result<()> {
Scram::new(password).await?,
));

let _conn = postgres_client::Config::new()
let _conn = postgres_client::Config::new("test".to_owned(), 5432)
.channel_binding(postgres_client::config::ChannelBinding::Require)
.user("user")
.dbname("db")
Expand All @@ -320,7 +320,7 @@ async fn scram_auth_disable_channel_binding() -> anyhow::Result<()> {
Scram::new("password").await?,
));

let _conn = postgres_client::Config::new()
let _conn = postgres_client::Config::new("test".to_owned(), 5432)
.channel_binding(postgres_client::config::ChannelBinding::Disable)
.user("user")
.dbname("db")
Expand Down Expand Up @@ -348,7 +348,7 @@ async fn scram_auth_mock() -> anyhow::Result<()> {
.map(char::from)
.collect();

let _client_err = postgres_client::Config::new()
let _client_err = postgres_client::Config::new("test".to_owned(), 5432)
.user("user")
.dbname("db")
.password(&password) // no password will match the mocked secret
Expand Down Expand Up @@ -546,7 +546,7 @@ impl TestControlPlaneClient for TestConnectMechanism {

fn helper_create_cached_node_info(cache: &'static NodeInfoCache) -> CachedNodeInfo {
let node = NodeInfo {
config: compute::ConnCfg::new(),
config: compute::ConnCfg::new("localhost".to_owned(), 5432),
aux: MetricsAuxInfo {
endpoint_id: (&EndpointId::from("endpoint")).into(),
project_id: (&ProjectId::from("project")).into(),
Expand Down
10 changes: 3 additions & 7 deletions proxy/src/serverless/backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -498,7 +498,7 @@ impl ConnectMechanism for TokioMechanism {
node_info: &CachedNodeInfo,
timeout: Duration,
) -> Result<Self::Connection, Self::ConnectError> {
let host = node_info.config.get_host()?;
let host = node_info.config.get_host();
let permit = self.locks.get_permit(&host).await?;

let mut config = (*node_info.config).clone();
Expand Down Expand Up @@ -548,16 +548,12 @@ impl ConnectMechanism for HyperMechanism {
node_info: &CachedNodeInfo,
timeout: Duration,
) -> Result<Self::Connection, Self::ConnectError> {
let host = node_info.config.get_host()?;
let host = node_info.config.get_host();
let permit = self.locks.get_permit(&host).await?;

let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);

let port = node_info.config.get_ports().ok_or_else(|| {
HttpConnError::WakeCompute(WakeComputeError::BadComputeAddress(
"local-proxy port missing on compute address".into(),
))
})?;
let port = node_info.config.get_port();
let res = connect_http2(&host, port, timeout).await;
drop(pause);
let (client, connection) = permit.release_result(res)?;
Expand Down

0 comments on commit df6a6b1

Please sign in to comment.