Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(proxy): add option to forward startup params #9979

Merged
merged 12 commits into from
Dec 4, 2024
4 changes: 2 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ bindgen = "0.70"
bit_field = "0.10.2"
bstr = "1.0"
byteorder = "1.4"
bytes = "1.0"
bytes = "1.9"
camino = "1.1.6"
cfg-if = "1.0.0"
chrono = { version = "0.4", default-features = false, features = ["clock"] }
Expand Down
2 changes: 1 addition & 1 deletion libs/pq_proto/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ impl StartupMessageParamsBuilder {

#[derive(Debug, Clone, Default)]
pub struct StartupMessageParams {
params: Bytes,
pub params: Bytes,
}

impl StartupMessageParams {
Expand Down
4 changes: 2 additions & 2 deletions libs/proxy/postgres-protocol2/src/authentication/sasl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ enum Credentials<const N: usize> {
/// A regular password as a vector of bytes.
Password(Vec<u8>),
/// A precomputed pair of keys.
Keys(Box<ScramKeys<N>>),
Keys(ScramKeys<N>),
}

enum State {
Expand Down Expand Up @@ -176,7 +176,7 @@ impl ScramSha256 {

/// Constructs a new instance which will use the provided key pair for authentication.
pub fn new_with_keys(keys: ScramKeys<32>, channel_binding: ChannelBinding) -> ScramSha256 {
let password = Credentials::Keys(keys.into());
let password = Credentials::Keys(keys);
ScramSha256::new_inner(password, channel_binding, nonce())
}

Expand Down
28 changes: 20 additions & 8 deletions libs/proxy/postgres-protocol2/src/message/frontend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -255,22 +255,34 @@ pub fn ssl_request(buf: &mut BytesMut) {
}

#[inline]
pub fn startup_message<'a, I>(parameters: I, buf: &mut BytesMut) -> io::Result<()>
where
I: IntoIterator<Item = (&'a str, &'a str)>,
{
pub fn startup_message(parameters: &StartupMessageParams, buf: &mut BytesMut) -> io::Result<()> {
write_body(buf, |buf| {
// postgres protocol version 3.0(196608) in bigger-endian
buf.put_i32(0x00_03_00_00);
for (key, value) in parameters {
write_cstr(key.as_bytes(), buf)?;
write_cstr(value.as_bytes(), buf)?;
}
buf.put_slice(&parameters.params);
buf.put_u8(0);
Ok(())
})
}

#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct StartupMessageParams {
pub params: BytesMut,
}

impl StartupMessageParams {
/// Set parameter's value by its name.
pub fn insert(&mut self, name: &str, value: &str) {
if name.contains('\0') || value.contains('\0') {
panic!("startup parameter name or value contained a null")
}
self.params.put_slice(name.as_bytes());
self.params.put_u8(0);
self.params.put_slice(value.as_bytes());
self.params.put_u8(0);
}
}

#[inline]
pub fn sync(buf: &mut BytesMut) {
buf.put_u8(b'S');
Expand Down
13 changes: 1 addition & 12 deletions libs/proxy/tokio-postgres2/src/codec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,7 @@ impl FallibleIterator for BackendMessages {
}
}

pub struct PostgresCodec {
pub max_message_size: Option<usize>,
}
pub struct PostgresCodec;

impl Encoder<FrontendMessage> for PostgresCodec {
type Error = io::Error;
Expand Down Expand Up @@ -66,15 +64,6 @@ impl Decoder for PostgresCodec {
break;
}

if let Some(max) = self.max_message_size {
if len > max {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"message too large",
));
}
}

match header.tag() {
backend::NOTICE_RESPONSE_TAG
| backend::NOTIFICATION_RESPONSE_TAG
Expand Down
Loading
Loading