diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 8d17f4d6b..e82c709ae 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -3,9 +3,11 @@ name: CI
 on:
   pull_request:
     branches:
+      - neon
       - master
   push:
     branches:
+      - neon
       - master
 
 env:
@@ -17,7 +19,7 @@ jobs:
     name: rustfmt
     runs-on: ubuntu-latest
     steps:
-      - uses: actions/checkout@v2
+      - uses: actions/checkout@v4
       - uses: sfackler/actions/rustup@master
      - uses: sfackler/actions/rustfmt@master
 
@@ -25,7 +27,7 @@
     name: clippy
     runs-on: ubuntu-latest
     steps:
-      - uses: actions/checkout@v2
+      - uses: actions/checkout@v4
       - uses: sfackler/actions/rustup@master
       - run: echo "::set-output name=version::$(rustc --version)"
         id: rust-version
@@ -45,17 +47,17 @@
       with:
         path: target
         key: clippy-target-${{ runner.os }}-${{ steps.rust-version.outputs.version }}-${{ hashFiles('Cargo.lock') }}y
-      - run: cargo clippy --all --all-targets
+      - run: cargo clippy --workspace --all-targets
 
   test:
     name: test
     runs-on: ubuntu-latest
     steps:
-      - uses: actions/checkout@v2
+      - uses: actions/checkout@v4
       - run: docker compose up -d
       - uses: sfackler/actions/rustup@master
        with:
-          version: 1.62.0
+          version: 1.67.0
       - run: echo "::set-output name=version::$(rustc --version)"
         id: rust-version
       - uses: actions/cache@v1
diff --git a/docker/sql_setup.sh b/docker/sql_setup.sh
index 0315ac805..051a12000 100755
--- a/docker/sql_setup.sh
+++ b/docker/sql_setup.sh
@@ -64,6 +64,7 @@ port = 5433
 ssl = on
 ssl_cert_file = 'server.crt'
 ssl_key_file = 'server.key'
+wal_level = logical
 EOCONF
 
 cat > "$PGDATA/pg_hba.conf" <<-EOCONF
@@ -82,6 +83,7 @@ host all ssl_user ::0/0 reject
 # IPv4 local connections:
 host all postgres 0.0.0.0/0 trust
+host replication postgres 0.0.0.0/0 trust
 # IPv6 local connections:
 host all postgres ::0/0 trust
 # Unix socket connections:
diff --git a/postgres-derive-test/src/lib.rs b/postgres-derive-test/src/lib.rs
index d1478ac4c..f0534f32c 100644
--- a/postgres-derive-test/src/lib.rs
+++ b/postgres-derive-test/src/lib.rs
@@ -14,7 +14,7 @@ where
     T: PartialEq + FromSqlOwned + ToSql + Sync,
     S: fmt::Display,
 {
-    for &(ref val, ref repr) in checks.iter() {
+    for (val, repr) in checks.iter() {
         let stmt = conn
             .prepare(&format!("SELECT {}::{}", *repr, sql_type))
             .unwrap();
@@ -38,7 +38,7 @@ pub fn test_type_asymmetric<T, F, S, C>(
     S: fmt::Display,
     C: Fn(&T, &F) -> bool,
 {
-    for &(ref val, ref repr) in checks.iter() {
+    for (val, repr) in checks.iter() {
         let stmt = conn
             .prepare(&format!("SELECT {}::{}", *repr, sql_type))
             .unwrap();
diff --git a/postgres-protocol/Cargo.toml b/postgres-protocol/Cargo.toml
index 2a72cc60c..23cdc1b13 100644
--- a/postgres-protocol/Cargo.toml
+++ b/postgres-protocol/Cargo.toml
@@ -9,11 +9,12 @@ repository = "https://github.com/sfackler/rust-postgres"
 readme = "../README.md"
 
 [dependencies]
-base64 = "0.20"
+base64 = "0.22"
 byteorder = "1.0"
 bytes = "1.0"
 fallible-iterator = "0.2"
 hmac = "0.12"
+lazy_static = "1.4"
 md-5 = "0.10"
 memchr = "2.0"
 rand = "0.8"
diff --git a/postgres-protocol/src/authentication/sasl.rs b/postgres-protocol/src/authentication/sasl.rs
index ea2f55cad..41d0e41b0 100644
--- a/postgres-protocol/src/authentication/sasl.rs
+++ b/postgres-protocol/src/authentication/sasl.rs
@@ -96,14 +96,32 @@ impl ChannelBinding {
     }
 }
 
+/// A pair of keys for the SCRAM-SHA-256 mechanism.
+/// See RFC 5802 for details.
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub struct ScramKeys<const N: usize> {
+    /// Used by server to authenticate client.
+    pub client_key: [u8; N],
+    /// Used by client to verify server's signature.
+    pub server_key: [u8; N],
+}
+
+/// Password or keys which were derived from it.
+enum Credentials<const N: usize> {
+    /// A regular password as a vector of bytes.
+    Password(Vec<u8>),
+    /// A precomputed pair of keys.
+    Keys(Box<ScramKeys<N>>),
+}
+
 enum State {
     Update {
         nonce: String,
-        password: Vec<u8>,
+        password: Credentials<32>,
         channel_binding: ChannelBinding,
     },
     Finish {
-        salted_password: [u8; 32],
+        server_key: [u8; 32],
         auth_message: String,
     },
     Done,
@@ -129,30 +147,43 @@ pub struct ScramSha256 {
     state: State,
 }
 
+fn nonce() -> String {
+    // rand 0.5's ThreadRng is cryptographically secure
+    let mut rng = rand::thread_rng();
+    (0..NONCE_LENGTH)
+        .map(|_| {
+            let mut v = rng.gen_range(0x21u8..0x7e);
+            if v == 0x2c {
+                v = 0x7e
+            }
+            v as char
+        })
+        .collect()
+}
+
 impl ScramSha256 {
     /// Constructs a new instance which will use the provided password for authentication.
     pub fn new(password: &[u8], channel_binding: ChannelBinding) -> ScramSha256 {
-        // rand 0.5's ThreadRng is cryptographically secure
-        let mut rng = rand::thread_rng();
-        let nonce = (0..NONCE_LENGTH)
-            .map(|_| {
-                let mut v = rng.gen_range(0x21u8..0x7e);
-                if v == 0x2c {
-                    v = 0x7e
-                }
-                v as char
-            })
-            .collect::<String>();
+        let password = Credentials::Password(normalize(password));
+        ScramSha256::new_inner(password, channel_binding, nonce())
+    }
 
-        ScramSha256::new_inner(password, channel_binding, nonce)
+    /// Constructs a new instance which will use the provided key pair for authentication.
+    pub fn new_with_keys(keys: ScramKeys<32>, channel_binding: ChannelBinding) -> ScramSha256 {
+        let password = Credentials::Keys(keys.into());
+        ScramSha256::new_inner(password, channel_binding, nonce())
     }
 
-    fn new_inner(password: &[u8], channel_binding: ChannelBinding, nonce: String) -> ScramSha256 {
+    fn new_inner(
+        password: Credentials<32>,
+        channel_binding: ChannelBinding,
+        nonce: String,
+    ) -> ScramSha256 {
         ScramSha256 {
             message: format!("{}n=,r={}", channel_binding.gs2_header(), nonce),
             state: State::Update {
                 nonce,
-                password: normalize(password),
+                password,
                 channel_binding,
             },
         }
@@ -189,20 +220,32 @@ impl ScramSha256 {
             return Err(io::Error::new(io::ErrorKind::InvalidInput, "invalid nonce"));
         }
 
-        let salt = match base64::decode(parsed.salt) {
-            Ok(salt) => salt,
-            Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidInput, e)),
-        };
+        let (client_key, server_key) = match password {
+            Credentials::Password(password) => {
+                let salt = match base64::decode(parsed.salt) {
+                    Ok(salt) => salt,
+                    Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidInput, e)),
+                };
 
-        let salted_password = hi(&password, &salt, parsed.iteration_count);
+                let salted_password = hi(&password, &salt, parsed.iteration_count);
 
-        let mut hmac = Hmac::<Sha256>::new_from_slice(&salted_password)
-            .expect("HMAC is able to accept all key sizes");
-        hmac.update(b"Client Key");
-        let client_key = hmac.finalize().into_bytes();
+                let make_key = |name| {
+                    let mut hmac = Hmac::<Sha256>::new_from_slice(&salted_password)
+                        .expect("HMAC is able to accept all key sizes");
+                    hmac.update(name);
+
+                    let mut key = [0u8; 32];
+                    key.copy_from_slice(hmac.finalize().into_bytes().as_slice());
+                    key
+                };
+
+                (make_key(b"Client Key"), make_key(b"Server Key"))
+            }
+            Credentials::Keys(keys) => (keys.client_key, keys.server_key),
+        };
 
         let mut hash = Sha256::default();
-        hash.update(client_key.as_slice());
+        hash.update(client_key);
         let stored_key = hash.finalize_fixed();
 
         let mut cbind_input = vec![];
@@ -225,10 +268,10 @@ impl ScramSha256 {
             *proof ^= signature;
         }
 
-        write!(&mut self.message, ",p={}", base64::encode(&*client_proof)).unwrap();
+        write!(&mut self.message, ",p={}", base64::encode(client_proof)).unwrap();
 
         self.state = State::Finish {
-            salted_password,
+            server_key,
             auth_message,
         };
         Ok(())
@@ -239,11 +282,11 @@ impl ScramSha256 {
     /// This should be called when the backend sends an `AuthenticationSASLFinal` message.
     /// Authentication has only succeeded if this method returns `Ok(())`.
     pub fn finish(&mut self, message: &[u8]) -> io::Result<()> {
-        let (salted_password, auth_message) = match mem::replace(&mut self.state, State::Done) {
+        let (server_key, auth_message) = match mem::replace(&mut self.state, State::Done) {
             State::Finish {
-                salted_password,
+                server_key,
                 auth_message,
-            } => (salted_password, auth_message),
+            } => (server_key, auth_message),
             _ => return Err(io::Error::new(io::ErrorKind::Other, "invalid SCRAM state")),
         };
 
@@ -267,11 +310,6 @@ impl ScramSha256 {
             Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidInput, e)),
         };
 
-        let mut hmac = Hmac::<Sha256>::new_from_slice(&salted_password)
-            .expect("HMAC is able to accept all key sizes");
-        hmac.update(b"Server Key");
-        let server_key = hmac.finalize().into_bytes();
-
         let mut hmac = Hmac::<Sha256>::new_from_slice(&server_key)
             .expect("HMAC is able to accept all key sizes");
         hmac.update(auth_message.as_bytes());
@@ -351,7 +389,7 @@ impl<'a> Parser<'a> {
     }
 
     fn posit_number(&mut self) -> io::Result<u32> {
-        let n = self.take_while(|c| matches!(c, '0'..='9'))?;
+        let n = self.take_while(|c| c.is_ascii_digit())?;
         n.parse()
             .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))
     }
@@ -458,7 +496,7 @@ mod test {
         let server_final = "v=U+ppxD5XUKtradnv8e2MkeupiA8FU87Sg8CXzXHDAzw=";
 
         let mut scram = ScramSha256::new_inner(
-            password.as_bytes(),
+            Credentials::Password(normalize(password.as_bytes())),
             ChannelBinding::unsupported(),
             nonce.to_string(),
         );
diff --git a/postgres-protocol/src/lib.rs b/postgres-protocol/src/lib.rs
index 8b6ff508d..1f7aa7923 100644
--- a/postgres-protocol/src/lib.rs
+++ b/postgres-protocol/src/lib.rs
@@ -14,7 +14,9 @@
 use byteorder::{BigEndian, ByteOrder};
 use bytes::{BufMut, BytesMut};
+use lazy_static::lazy_static;
 use std::io;
+use std::time::{Duration, SystemTime, UNIX_EPOCH};
 
 pub mod authentication;
 pub mod escape;
@@ -28,6 +30,11 @@ pub type Oid = u32;
 
 /// A Postgres Log Sequence Number (LSN).
 pub type Lsn = u64;
 
+lazy_static! {
+    /// Postgres epoch is 2000-01-01T00:00:00Z
+    pub static ref PG_EPOCH: SystemTime = UNIX_EPOCH + Duration::from_secs(946_684_800);
+}
+
 /// An enum indicating if a value is `NULL` or not.
 pub enum IsNull {
     /// The value is `NULL`.
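Reviewer note (not part of the patch): a minimal sketch of driving the new key-based SCRAM entry point added above. The zeroed key material and `ChannelBinding::unsupported()` are placeholder assumptions; a real `ClientKey`/`ServerKey` pair would be derived from the password once and cached.

```rust
use postgres_protocol::authentication::sasl::{ChannelBinding, ScramKeys, ScramSha256};

fn scram_with_cached_keys() -> ScramSha256 {
    // Placeholder key material; a real pair would be computed from the salted
    // password (HMAC over "Client Key" / "Server Key") or fetched from storage.
    let keys: ScramKeys<32> = ScramKeys {
        client_key: [0u8; 32],
        server_key: [0u8; 32],
    };

    // No password needed: the keys stand in for it during the SCRAM exchange.
    ScramSha256::new_with_keys(keys, ChannelBinding::unsupported())
}
```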
diff --git a/postgres-protocol/src/message/backend.rs b/postgres-protocol/src/message/backend.rs
index 3f5374d64..b6883cc3c 100644
--- a/postgres-protocol/src/message/backend.rs
+++ b/postgres-protocol/src/message/backend.rs
@@ -8,9 +8,11 @@ use std::cmp;
 use std::io::{self, Read};
 use std::ops::Range;
 use std::str;
+use std::time::{Duration, SystemTime};
 
-use crate::Oid;
+use crate::{Lsn, Oid, PG_EPOCH};
 
+// top-level message tags
 pub const PARSE_COMPLETE_TAG: u8 = b'1';
 pub const BIND_COMPLETE_TAG: u8 = b'2';
 pub const CLOSE_COMPLETE_TAG: u8 = b'3';
@@ -22,6 +24,7 @@ pub const DATA_ROW_TAG: u8 = b'D';
 pub const ERROR_RESPONSE_TAG: u8 = b'E';
 pub const COPY_IN_RESPONSE_TAG: u8 = b'G';
 pub const COPY_OUT_RESPONSE_TAG: u8 = b'H';
+pub const COPY_BOTH_RESPONSE_TAG: u8 = b'W';
 pub const EMPTY_QUERY_RESPONSE_TAG: u8 = b'I';
 pub const BACKEND_KEY_DATA_TAG: u8 = b'K';
 pub const NO_DATA_TAG: u8 = b'n';
@@ -33,6 +36,33 @@ pub const PARAMETER_DESCRIPTION_TAG: u8 = b't';
 pub const ROW_DESCRIPTION_TAG: u8 = b'T';
 pub const READY_FOR_QUERY_TAG: u8 = b'Z';
 
+// replication message tags
+pub const XLOG_DATA_TAG: u8 = b'w';
+pub const PRIMARY_KEEPALIVE_TAG: u8 = b'k';
+
+// logical replication message tags
+const BEGIN_TAG: u8 = b'B';
+const COMMIT_TAG: u8 = b'C';
+const ORIGIN_TAG: u8 = b'O';
+const RELATION_TAG: u8 = b'R';
+const TYPE_TAG: u8 = b'Y';
+const INSERT_TAG: u8 = b'I';
+const UPDATE_TAG: u8 = b'U';
+const DELETE_TAG: u8 = b'D';
+const TRUNCATE_TAG: u8 = b'T';
+const TUPLE_NEW_TAG: u8 = b'N';
+const TUPLE_KEY_TAG: u8 = b'K';
+const TUPLE_OLD_TAG: u8 = b'O';
+const TUPLE_DATA_NULL_TAG: u8 = b'n';
+const TUPLE_DATA_TOAST_TAG: u8 = b'u';
+const TUPLE_DATA_TEXT_TAG: u8 = b't';
+
+// replica identity tags
+const REPLICA_IDENTITY_DEFAULT_TAG: u8 = b'd';
+const REPLICA_IDENTITY_NOTHING_TAG: u8 = b'n';
+const REPLICA_IDENTITY_FULL_TAG: u8 = b'f';
+const REPLICA_IDENTITY_INDEX_TAG: u8 = b'i';
+
 #[derive(Debug, Copy, Clone)]
 pub struct Header {
     tag: u8,
@@ -93,6 +123,7 @@ pub enum Message {
     CopyDone,
     CopyInResponse(CopyInResponseBody),
     CopyOutResponse(CopyOutResponseBody),
+    CopyBothResponse(CopyBothResponseBody),
     DataRow(DataRowBody),
     EmptyQueryResponse,
     ErrorResponse(ErrorResponseBody),
@@ -190,6 +221,16 @@ impl Message {
                     storage,
                 })
             }
+            COPY_BOTH_RESPONSE_TAG => {
+                let format = buf.read_u8()?;
+                let len = buf.read_u16::<BigEndian>()?;
+                let storage = buf.read_all();
+                Message::CopyBothResponse(CopyBothResponseBody {
+                    format,
+                    len,
+                    storage,
+                })
+            }
             EMPTY_QUERY_RESPONSE_TAG => Message::EmptyQueryResponse,
             BACKEND_KEY_DATA_TAG => {
                 let process_id = buf.read_i32::<BigEndian>()?;
@@ -278,6 +319,69 @@ impl Message {
         }
     }
 }
 
+/// An enum representing Postgres backend replication messages.
+#[non_exhaustive]
+#[derive(Debug)]
+pub enum ReplicationMessage<D> {
+    XLogData(XLogDataBody<D>),
+    PrimaryKeepAlive(PrimaryKeepAliveBody),
+}
+
+impl ReplicationMessage<Bytes> {
+    #[inline]
+    pub fn parse(buf: &Bytes) -> io::Result<Self> {
+        let mut buf = Buffer {
+            bytes: buf.clone(),
+            idx: 0,
+        };
+
+        let tag = buf.read_u8()?;
+
+        let replication_message = match tag {
+            XLOG_DATA_TAG => {
+                let wal_start = buf.read_u64::<BigEndian>()?;
+                let wal_end = buf.read_u64::<BigEndian>()?;
+                let ts = buf.read_i64::<BigEndian>()?;
+                let timestamp = if ts > 0 {
+                    *PG_EPOCH + Duration::from_micros(ts as u64)
+                } else {
+                    *PG_EPOCH - Duration::from_micros(-ts as u64)
+                };
+                let data = buf.read_all();
+                ReplicationMessage::XLogData(XLogDataBody {
+                    wal_start,
+                    wal_end,
+                    timestamp,
+                    data,
+                })
+            }
+            PRIMARY_KEEPALIVE_TAG => {
+                let wal_end = buf.read_u64::<BigEndian>()?;
+                let ts = buf.read_i64::<BigEndian>()?;
+                let timestamp = if ts > 0 {
+                    *PG_EPOCH + Duration::from_micros(ts as u64)
+                } else {
+                    *PG_EPOCH - Duration::from_micros(-ts as u64)
+                };
+                let reply = buf.read_u8()?;
+                ReplicationMessage::PrimaryKeepAlive(PrimaryKeepAliveBody {
+                    wal_end,
+                    timestamp,
+                    reply,
+                })
+            }
+            tag => {
+                return Err(io::Error::new(
+                    io::ErrorKind::InvalidInput,
+                    format!("unknown replication message tag `{}`", tag),
+                ));
+            }
+        };
+
+        Ok(replication_message)
+    }
+}
+
 struct Buffer {
     bytes: Bytes,
     idx: usize,
@@ -524,6 +628,27 @@ impl CopyOutResponseBody {
     }
 }
 
+pub struct CopyBothResponseBody {
+    storage: Bytes,
+    len: u16,
+    format: u8,
+}
+
+impl CopyBothResponseBody {
+    #[inline]
+    pub fn format(&self) -> u8 {
+        self.format
+    }
+
+    #[inline]
+    pub fn column_formats(&self) -> ColumnFormats<'_> {
+        ColumnFormats {
+            remaining: self.len,
+            buf: &self.storage,
+        }
+    }
+}
+
 #[derive(Debug)]
 pub struct DataRowBody {
     storage: Bytes,
@@ -582,7 +707,7 @@ impl<'a> FallibleIterator for DataRowRanges<'a> {
             ));
         }
         let base = self.len - self.buf.len();
-        self.buf = &self.buf[len as usize..];
+        self.buf = &self.buf[len..];
         Ok(Some(Some(base..base + len)))
     }
 }
@@ -777,6 +902,655 @@ impl RowDescriptionBody {
     }
 }
 
+#[derive(Debug)]
+pub struct XLogDataBody<D> {
+    wal_start: u64,
+    wal_end: u64,
+    timestamp: SystemTime,
+    data: D,
+}
+
+impl<D> XLogDataBody<D> {
+    #[inline]
+    pub fn wal_start(&self) -> u64 {
+        self.wal_start
+    }
+
+    #[inline]
+    pub fn wal_end(&self) -> u64 {
+        self.wal_end
+    }
+
+    #[inline]
+    pub fn timestamp(&self) -> SystemTime {
+        self.timestamp
+    }
+
+    #[inline]
+    pub fn data(&self) -> &D {
+        &self.data
+    }
+
+    #[inline]
+    pub fn into_data(self) -> D {
+        self.data
+    }
+
+    pub fn map_data<F, D2, E>(self, f: F) -> Result<XLogDataBody<D2>, E>
+    where
+        F: Fn(D) -> Result<D2, E>,
+    {
+        let data = f(self.data)?;
+        Ok(XLogDataBody {
+            wal_start: self.wal_start,
+            wal_end: self.wal_end,
+            timestamp: self.timestamp,
+            data,
+        })
+    }
+}
+
+#[derive(Debug)]
+pub struct PrimaryKeepAliveBody {
+    wal_end: u64,
+    timestamp: SystemTime,
+    reply: u8,
+}
+
+impl PrimaryKeepAliveBody {
+    #[inline]
+    pub fn wal_end(&self) -> u64 {
+        self.wal_end
+    }
+
+    #[inline]
+    pub fn timestamp(&self) -> SystemTime {
+        self.timestamp
+    }
+
+    #[inline]
+    pub fn reply(&self) -> u8 {
+        self.reply
+    }
+}
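Reviewer note (not part of the patch): a sketch of exercising the parser above with a hand-built `XLogData` frame; all field values are hypothetical.

```rust
use bytes::{BufMut, BytesMut};
use postgres_protocol::message::backend::ReplicationMessage;

fn parse_demo() -> std::io::Result<()> {
    // Frame layout per the parser: tag, wal_start, wal_end, timestamp, payload.
    let mut buf = BytesMut::new();
    buf.put_u8(b'w'); // XLOG_DATA_TAG
    buf.put_u64(1); // wal_start
    buf.put_u64(2); // wal_end
    buf.put_i64(0); // microseconds relative to PG_EPOCH (2000-01-01)
    buf.put_slice(b"payload");

    match ReplicationMessage::parse(&buf.freeze())? {
        ReplicationMessage::XLogData(body) => {
            // timestamp() has already been rebased onto SystemTime via PG_EPOCH.
            println!("{} bytes at {:?}", body.data().len(), body.timestamp());
        }
        ReplicationMessage::PrimaryKeepAlive(keepalive) => {
            // reply() == 1 means the server wants a standby status update now.
            println!("keepalive, reply requested: {}", keepalive.reply());
        }
        _ => {} // the enum is #[non_exhaustive]
    }
    Ok(())
}
```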
+#[non_exhaustive]
+/// A message of the logical replication stream
+#[derive(Debug)]
+pub enum LogicalReplicationMessage {
+    /// A BEGIN statement
+    Begin(BeginBody),
+    /// A COMMIT statement
+    Commit(CommitBody),
+    /// An Origin replication message
+    /// Note that there can be multiple Origin messages inside a single transaction.
+    Origin(OriginBody),
+    /// A Relation replication message
+    Relation(RelationBody),
+    /// A Type replication message
+    Type(TypeBody),
+    /// An INSERT statement
+    Insert(InsertBody),
+    /// An UPDATE statement
+    Update(UpdateBody),
+    /// A DELETE statement
+    Delete(DeleteBody),
+    /// A TRUNCATE statement
+    Truncate(TruncateBody),
+}
+
+impl LogicalReplicationMessage {
+    pub fn parse(buf: &Bytes) -> io::Result<Self> {
+        let mut buf = Buffer {
+            bytes: buf.clone(),
+            idx: 0,
+        };
+
+        let tag = buf.read_u8()?;
+
+        let logical_replication_message = match tag {
+            BEGIN_TAG => Self::Begin(BeginBody {
+                final_lsn: buf.read_u64::<BigEndian>()?,
+                timestamp: buf.read_i64::<BigEndian>()?,
+                xid: buf.read_u32::<BigEndian>()?,
+            }),
+            COMMIT_TAG => Self::Commit(CommitBody {
+                flags: buf.read_i8()?,
+                commit_lsn: buf.read_u64::<BigEndian>()?,
+                end_lsn: buf.read_u64::<BigEndian>()?,
+                timestamp: buf.read_i64::<BigEndian>()?,
+            }),
+            ORIGIN_TAG => Self::Origin(OriginBody {
+                commit_lsn: buf.read_u64::<BigEndian>()?,
+                name: buf.read_cstr()?,
+            }),
+            RELATION_TAG => {
+                let rel_id = buf.read_u32::<BigEndian>()?;
+                let namespace = buf.read_cstr()?;
+                let name = buf.read_cstr()?;
+                let replica_identity = match buf.read_u8()? {
+                    REPLICA_IDENTITY_DEFAULT_TAG => ReplicaIdentity::Default,
+                    REPLICA_IDENTITY_NOTHING_TAG => ReplicaIdentity::Nothing,
+                    REPLICA_IDENTITY_FULL_TAG => ReplicaIdentity::Full,
+                    REPLICA_IDENTITY_INDEX_TAG => ReplicaIdentity::Index,
+                    tag => {
+                        return Err(io::Error::new(
+                            io::ErrorKind::InvalidInput,
+                            format!("unknown replica identity tag `{}`", tag),
+                        ));
+                    }
+                };
+                let column_len = buf.read_i16::<BigEndian>()?;
+
+                let mut columns = Vec::with_capacity(column_len as usize);
+                for _ in 0..column_len {
+                    columns.push(Column::parse(&mut buf)?);
+                }
+
+                Self::Relation(RelationBody {
+                    rel_id,
+                    namespace,
+                    name,
+                    replica_identity,
+                    columns,
+                })
+            }
+            TYPE_TAG => Self::Type(TypeBody {
+                id: buf.read_u32::<BigEndian>()?,
+                namespace: buf.read_cstr()?,
+                name: buf.read_cstr()?,
+            }),
+            INSERT_TAG => {
+                let rel_id = buf.read_u32::<BigEndian>()?;
+                let tag = buf.read_u8()?;
+
+                let tuple = match tag {
+                    TUPLE_NEW_TAG => Tuple::parse(&mut buf)?,
+                    tag => {
+                        return Err(io::Error::new(
+                            io::ErrorKind::InvalidInput,
+                            format!("unexpected tuple tag `{}`", tag),
+                        ));
+                    }
+                };
+
+                Self::Insert(InsertBody { rel_id, tuple })
+            }
+            UPDATE_TAG => {
+                let rel_id = buf.read_u32::<BigEndian>()?;
+                let tag = buf.read_u8()?;
+
+                let mut key_tuple = None;
+                let mut old_tuple = None;
+
+                let new_tuple = match tag {
+                    TUPLE_NEW_TAG => Tuple::parse(&mut buf)?,
+                    TUPLE_OLD_TAG | TUPLE_KEY_TAG => {
+                        if tag == TUPLE_OLD_TAG {
+                            old_tuple = Some(Tuple::parse(&mut buf)?);
+                        } else {
+                            key_tuple = Some(Tuple::parse(&mut buf)?);
+                        }
+
+                        match buf.read_u8()? {
+                            TUPLE_NEW_TAG => Tuple::parse(&mut buf)?,
+                            tag => {
+                                return Err(io::Error::new(
+                                    io::ErrorKind::InvalidInput,
+                                    format!("unexpected tuple tag `{}`", tag),
+                                ));
+                            }
+                        }
+                    }
+                    tag => {
+                        return Err(io::Error::new(
+                            io::ErrorKind::InvalidInput,
+                            format!("unknown tuple tag `{}`", tag),
+                        ));
+                    }
+                };
+
+                Self::Update(UpdateBody {
+                    rel_id,
+                    key_tuple,
+                    old_tuple,
+                    new_tuple,
+                })
+            }
+            DELETE_TAG => {
+                let rel_id = buf.read_u32::<BigEndian>()?;
+                let tag = buf.read_u8()?;
+
+                let mut key_tuple = None;
+                let mut old_tuple = None;
+
+                match tag {
+                    TUPLE_OLD_TAG => old_tuple = Some(Tuple::parse(&mut buf)?),
+                    TUPLE_KEY_TAG => key_tuple = Some(Tuple::parse(&mut buf)?),
+                    tag => {
+                        return Err(io::Error::new(
+                            io::ErrorKind::InvalidInput,
+                            format!("unknown tuple tag `{}`", tag),
+                        ));
+                    }
+                }
+
+                Self::Delete(DeleteBody {
+                    rel_id,
+                    key_tuple,
+                    old_tuple,
+                })
+            }
+            TRUNCATE_TAG => {
+                let relation_len = buf.read_i32::<BigEndian>()?;
+                let options = buf.read_i8()?;
+
+                let mut rel_ids = Vec::with_capacity(relation_len as usize);
+                for _ in 0..relation_len {
+                    rel_ids.push(buf.read_u32::<BigEndian>()?);
+                }
+
+                Self::Truncate(TruncateBody { options, rel_ids })
+            }
+            tag => {
+                return Err(io::Error::new(
+                    io::ErrorKind::InvalidInput,
+                    format!("unknown replication message tag `{}`", tag),
+                ));
+            }
+        };
+
+        Ok(logical_replication_message)
+    }
+}
+
+/// A row as it appears in the replication stream
+#[derive(Debug)]
+pub struct Tuple(Vec<TupleData>);
+
+impl Tuple {
+    #[inline]
+    /// The tuple data of this tuple
+    pub fn tuple_data(&self) -> &[TupleData] {
+        &self.0
+    }
+}
+
+impl Tuple {
+    fn parse(buf: &mut Buffer) -> io::Result<Tuple> {
+        let col_len = buf.read_i16::<BigEndian>()?;
+        let mut tuple = Vec::with_capacity(col_len as usize);
+        for _ in 0..col_len {
+            tuple.push(TupleData::parse(buf)?);
+        }
+
+        Ok(Tuple(tuple))
+    }
+}
+
+/// A column as it appears in the replication stream
+#[derive(Debug)]
+pub struct Column {
+    flags: i8,
+    name: Bytes,
+    type_id: i32,
+    type_modifier: i32,
+}
+
+impl Column {
+    #[inline]
+    /// Flags for the column. Currently can be either 0 for no flags or 1 which marks the column as
+    /// part of the key.
+    pub fn flags(&self) -> i8 {
+        self.flags
+    }
+
+    #[inline]
+    /// Name of the column.
+    pub fn name(&self) -> io::Result<&str> {
+        get_str(&self.name)
+    }
+
+    #[inline]
+    /// ID of the column's data type.
+    pub fn type_id(&self) -> i32 {
+        self.type_id
+    }
+
+    #[inline]
+    /// Type modifier of the column (`atttypmod`).
+    pub fn type_modifier(&self) -> i32 {
+        self.type_modifier
+    }
+}
+
+impl Column {
+    fn parse(buf: &mut Buffer) -> io::Result<Self> {
+        Ok(Self {
+            flags: buf.read_i8()?,
+            name: buf.read_cstr()?,
+            type_id: buf.read_i32::<BigEndian>()?,
+            type_modifier: buf.read_i32::<BigEndian>()?,
+        })
+    }
+}
+
+/// The data of an individual column as it appears in the replication stream
+#[derive(Debug)]
+pub enum TupleData {
+    /// Represents a NULL value
+    Null,
+    /// Represents an unchanged TOASTed value (the actual value is not sent).
+    UnchangedToast,
+    /// Column data as text formatted value.
+    Text(Bytes),
+}
+
+impl TupleData {
+    fn parse(buf: &mut Buffer) -> io::Result<Self> {
+        let type_tag = buf.read_u8()?;
+
+        let tuple = match type_tag {
+            TUPLE_DATA_NULL_TAG => TupleData::Null,
+            TUPLE_DATA_TOAST_TAG => TupleData::UnchangedToast,
+            TUPLE_DATA_TEXT_TAG => {
+                let len = buf.read_i32::<BigEndian>()?;
+                let mut data = vec![0; len as usize];
+                buf.read_exact(&mut data)?;
+                TupleData::Text(data.into())
+            }
+            tag => {
+                return Err(io::Error::new(
+                    io::ErrorKind::InvalidInput,
+                    format!("unknown replication message tag `{}`", tag),
+                ));
+            }
+        };
+
+        Ok(tuple)
+    }
+}
+
+/// A BEGIN statement
+#[derive(Debug)]
+pub struct BeginBody {
+    final_lsn: u64,
+    timestamp: i64,
+    xid: u32,
+}
+
+impl BeginBody {
+    #[inline]
+    /// Gets the final lsn of the transaction
+    pub fn final_lsn(&self) -> Lsn {
+        self.final_lsn
+    }
+
+    #[inline]
+    /// Commit timestamp of the transaction. The value is in number of microseconds since PostgreSQL epoch (2000-01-01).
+    pub fn timestamp(&self) -> i64 {
+        self.timestamp
+    }
+
+    #[inline]
+    /// Xid of the transaction.
+    pub fn xid(&self) -> u32 {
+        self.xid
+    }
+}
+
+/// A COMMIT statement
+#[derive(Debug)]
+pub struct CommitBody {
+    flags: i8,
+    commit_lsn: u64,
+    end_lsn: u64,
+    timestamp: i64,
+}
+
+impl CommitBody {
+    #[inline]
+    /// The LSN of the commit.
+    pub fn commit_lsn(&self) -> Lsn {
+        self.commit_lsn
+    }
+
+    #[inline]
+    /// The end LSN of the transaction.
+    pub fn end_lsn(&self) -> Lsn {
+        self.end_lsn
+    }
+
+    #[inline]
+    /// Commit timestamp of the transaction. The value is in number of microseconds since PostgreSQL epoch (2000-01-01).
+    pub fn timestamp(&self) -> i64 {
+        self.timestamp
+    }
+
+    #[inline]
+    /// Flags; currently unused (will be 0).
+    pub fn flags(&self) -> i8 {
+        self.flags
+    }
+}
+
+/// An Origin replication message
+///
+/// Note that there can be multiple Origin messages inside a single transaction.
+#[derive(Debug)]
+pub struct OriginBody {
+    commit_lsn: u64,
+    name: Bytes,
+}
+
+impl OriginBody {
+    #[inline]
+    /// The LSN of the commit on the origin server.
+    pub fn commit_lsn(&self) -> Lsn {
+        self.commit_lsn
+    }
+
+    #[inline]
+    /// Name of the origin.
+    pub fn name(&self) -> io::Result<&str> {
+        get_str(&self.name)
+    }
+}
+
+/// Describes the REPLICA IDENTITY setting of a table
+#[derive(Debug)]
+pub enum ReplicaIdentity {
+    /// default selection for replica identity (primary key or nothing)
+    Default,
+    /// no replica identity is logged for this relation
+    Nothing,
+    /// all columns are logged as replica identity
+    Full,
+    /// An explicitly chosen candidate key's columns are used as replica identity.
+    /// Note this will still be set if the index has been dropped; in that case it
+    /// has the same meaning as 'd'.
+    Index,
+}
+
+/// A Relation replication message
+#[derive(Debug)]
+pub struct RelationBody {
+    rel_id: u32,
+    namespace: Bytes,
+    name: Bytes,
+    replica_identity: ReplicaIdentity,
+    columns: Vec<Column>,
+}
+
+impl RelationBody {
+    #[inline]
+    /// ID of the relation.
+    pub fn rel_id(&self) -> u32 {
+        self.rel_id
+    }
+
+    #[inline]
+    /// Namespace (empty string for pg_catalog).
+    pub fn namespace(&self) -> io::Result<&str> {
+        get_str(&self.namespace)
+    }
+
+    #[inline]
+    /// Relation name.
+    pub fn name(&self) -> io::Result<&str> {
+        get_str(&self.name)
+    }
+
+    #[inline]
+    /// Replica identity setting for the relation
+    pub fn replica_identity(&self) -> &ReplicaIdentity {
+        &self.replica_identity
+    }
+
+    #[inline]
+    /// The column definitions of this relation
+    pub fn columns(&self) -> &[Column] {
+        &self.columns
+    }
+}
+
+/// A Type replication message
+#[derive(Debug)]
+pub struct TypeBody {
+    id: u32,
+    namespace: Bytes,
+    name: Bytes,
+}
+
+impl TypeBody {
+    #[inline]
+    /// ID of the data type.
+    pub fn id(&self) -> Oid {
+        self.id
+    }
+
+    #[inline]
+    /// Namespace (empty string for pg_catalog).
+    pub fn namespace(&self) -> io::Result<&str> {
+        get_str(&self.namespace)
+    }
+
+    #[inline]
+    /// Name of the data type.
+    pub fn name(&self) -> io::Result<&str> {
+        get_str(&self.name)
+    }
+}
+
+/// An INSERT statement
+#[derive(Debug)]
+pub struct InsertBody {
+    rel_id: u32,
+    tuple: Tuple,
+}
+
+impl InsertBody {
+    #[inline]
+    /// ID of the relation corresponding to the ID in the relation message.
+    pub fn rel_id(&self) -> u32 {
+        self.rel_id
+    }
+
+    #[inline]
+    /// The inserted tuple
+    pub fn tuple(&self) -> &Tuple {
+        &self.tuple
+    }
+}
+
+/// An UPDATE statement
+#[derive(Debug)]
+pub struct UpdateBody {
+    rel_id: u32,
+    old_tuple: Option<Tuple>,
+    key_tuple: Option<Tuple>,
+    new_tuple: Tuple,
+}
+
+impl UpdateBody {
+    #[inline]
+    /// ID of the relation corresponding to the ID in the relation message.
+    pub fn rel_id(&self) -> u32 {
+        self.rel_id
+    }
+
+    #[inline]
+    /// This field is optional and is only present if the update changed data in any of the
+    /// column(s) that are part of the REPLICA IDENTITY index.
+    pub fn key_tuple(&self) -> Option<&Tuple> {
+        self.key_tuple.as_ref()
+    }
+
+    #[inline]
+    /// This field is optional and is only present if the table in which the update happened has
+    /// REPLICA IDENTITY set to FULL.
+    pub fn old_tuple(&self) -> Option<&Tuple> {
+        self.old_tuple.as_ref()
+    }
+
+    #[inline]
+    /// The new tuple
+    pub fn new_tuple(&self) -> &Tuple {
+        &self.new_tuple
+    }
+}
+
+/// A DELETE statement
+#[derive(Debug)]
+pub struct DeleteBody {
+    rel_id: u32,
+    old_tuple: Option<Tuple>,
+    key_tuple: Option<Tuple>,
+}
+
+impl DeleteBody {
+    #[inline]
+    /// ID of the relation corresponding to the ID in the relation message.
+    pub fn rel_id(&self) -> u32 {
+        self.rel_id
+    }
+
+    #[inline]
+    /// This field is present if the table in which the delete has happened uses an index as
+    /// REPLICA IDENTITY.
+    pub fn key_tuple(&self) -> Option<&Tuple> {
+        self.key_tuple.as_ref()
+    }
+
+    #[inline]
+    /// This field is present if the table in which the delete has happened has REPLICA IDENTITY
+    /// set to FULL.
+    pub fn old_tuple(&self) -> Option<&Tuple> {
+        self.old_tuple.as_ref()
+    }
+}
+
+/// A TRUNCATE statement
+#[derive(Debug)]
+pub struct TruncateBody {
+    options: i8,
+    rel_ids: Vec<u32>,
+}
+
+impl TruncateBody {
+    #[inline]
+    /// The IDs of the relations corresponding to the ID in the relation messages
+    pub fn rel_ids(&self) -> &[u32] {
+        &self.rel_ids
+    }
+
+    #[inline]
+    /// Option bits for TRUNCATE: 1 for CASCADE, 2 for RESTART IDENTITY
+    pub fn options(&self) -> i8 {
+        self.options
+    }
+}
+
 pub struct Fields<'a> {
     buf: &'a [u8],
     remaining: u16,
diff --git a/postgres-protocol/src/types/test.rs b/postgres-protocol/src/types/test.rs
index 6f1851fc2..3e33b08f0 100644
--- a/postgres-protocol/src/types/test.rs
+++ b/postgres-protocol/src/types/test.rs
@@ -174,7 +174,7 @@ fn ltree_str() {
     let mut query = vec![1u8];
     query.extend_from_slice("A.B.C".as_bytes());
 
-    assert!(matches!(ltree_from_sql(query.as_slice()), Ok(_)))
+    assert!(ltree_from_sql(query.as_slice()).is_ok())
 }
 
 #[test]
@@ -182,7 +182,7 @@ fn ltree_wrong_version() {
     let mut query = vec![2u8];
     query.extend_from_slice("A.B.C".as_bytes());
 
-    assert!(matches!(ltree_from_sql(query.as_slice()), Err(_)))
+    assert!(ltree_from_sql(query.as_slice()).is_err())
 }
 
 #[test]
@@ -202,7 +202,7 @@ fn lquery_str() {
     let mut query = vec![1u8];
     query.extend_from_slice("A.B.C".as_bytes());
 
-    assert!(matches!(lquery_from_sql(query.as_slice()), Ok(_)))
+    assert!(lquery_from_sql(query.as_slice()).is_ok())
 }
 
 #[test]
@@ -210,7 +210,7 @@ fn lquery_wrong_version() {
     let mut query = vec![2u8];
     query.extend_from_slice("A.B.C".as_bytes());
 
-    assert!(matches!(lquery_from_sql(query.as_slice()), Err(_)))
+    assert!(lquery_from_sql(query.as_slice()).is_err())
 }
 
 #[test]
@@ -230,7 +230,7 @@ fn ltxtquery_str() {
     let mut query = vec![1u8];
     query.extend_from_slice("a & b*".as_bytes());
 
-    assert!(matches!(ltree_from_sql(query.as_slice()), Ok(_)))
+    assert!(ltree_from_sql(query.as_slice()).is_ok())
 }
 
 #[test]
@@ -238,5 +238,5 @@ fn ltxtquery_wrong_version() {
     let mut query = vec![2u8];
     query.extend_from_slice("a & b*".as_bytes());
 
-    assert!(matches!(ltree_from_sql(query.as_slice()), Err(_)))
+    assert!(ltree_from_sql(query.as_slice()).is_err())
 }
diff --git a/postgres-types/Cargo.toml b/postgres-types/Cargo.toml
index 70f1ed54a..3f9245c73 100644
--- a/postgres-types/Cargo.toml
+++ b/postgres-types/Cargo.toml
@@ -16,7 +16,6 @@ array-impls = ["array-init"]
 with-bit-vec-0_6 = ["bit-vec-06"]
 with-cidr-0_2 = ["cidr-02"]
 with-chrono-0_4 = ["chrono-04"]
-with-eui48-0_4 = ["eui48-04"]
 with-eui48-1 = ["eui48-1"]
 with-geo-types-0_6 = ["geo-types-06"]
 with-geo-types-0_7 = ["geo-types-0_7"]
@@ -39,7 +38,6 @@ chrono-04 = { version = "0.4.16", package = "chrono", default-features = false, features = [
     "clock",
 ], optional = true }
 cidr-02 = { version = "0.2", package = "cidr", optional = true }
-eui48-04 = { version = "0.4", package = "eui48", optional = true }
 eui48-1 = { version = "1.0", package = "eui48", optional = true }
 geo-types-06 = { version = "0.6", package = "geo-types", optional = true }
 geo-types-0_7 = { version = "0.7", package = "geo-types", optional = true }
diff --git a/postgres-types/src/chrono_04.rs b/postgres-types/src/chrono_04.rs
index 0ec92437d..aef7549d0 100644
--- a/postgres-types/src/chrono_04.rs
+++ b/postgres-types/src/chrono_04.rs
@@ -40,7 +40,7 @@ impl ToSql for NaiveDateTime {
 impl<'a> FromSql<'a> for DateTime<Utc> {
     fn from_sql(type_: &Type, raw: &[u8]) -> Result<DateTime<Utc>, Box<dyn Error + Sync + Send>> {
         let naive = NaiveDateTime::from_sql(type_, raw)?;
-        Ok(DateTime::from_utc(naive, Utc))
+        Ok(DateTime::from_naive_utc_and_offset(naive, Utc))
     }
 
     accepts!(TIMESTAMPTZ);
diff --git a/postgres-types/src/eui48_04.rs b/postgres-types/src/eui48_04.rs
deleted file mode 100644
index 45df89a84..000000000
--- a/postgres-types/src/eui48_04.rs
+++ /dev/null
@@ -1,27 +0,0 @@
-use bytes::BytesMut;
-use eui48_04::MacAddress;
-use postgres_protocol::types;
-use std::error::Error;
-
-use crate::{FromSql, IsNull, ToSql, Type};
-
-impl<'a> FromSql<'a> for MacAddress {
-    fn from_sql(_: &Type, raw: &[u8]) -> Result<MacAddress, Box<dyn Error + Sync + Send>> {
-        let bytes = types::macaddr_from_sql(raw)?;
-        Ok(MacAddress::new(bytes))
-    }
-
-    accepts!(MACADDR);
-}
-
-impl ToSql for MacAddress {
-    fn to_sql(&self, _: &Type, w: &mut BytesMut) -> Result<IsNull, Box<dyn Error + Sync + Send>> {
-        let mut bytes = [0; 6];
-        bytes.copy_from_slice(self.as_bytes());
-        types::macaddr_to_sql(bytes, w);
-        Ok(IsNull::No)
-    }
-
-    accepts!(MACADDR);
-    to_sql_checked!();
-}
diff --git a/postgres-types/src/lib.rs b/postgres-types/src/lib.rs
index fa49d99eb..b10d298de 100644
--- a/postgres-types/src/lib.rs
+++ b/postgres-types/src/lib.rs
@@ -214,8 +214,6 @@ mod bit_vec_06;
 mod chrono_04;
 #[cfg(feature = "with-cidr-0_2")]
 mod cidr_02;
-#[cfg(feature = "with-eui48-0_4")]
-mod eui48_04;
 #[cfg(feature = "with-eui48-1")]
 mod eui48_1;
 #[cfg(feature = "with-geo-types-0_6")]
@@ -395,6 +393,22 @@ impl WrongType {
     }
 }
 
+/// An error indicating that an `as_text` conversion was attempted on a binary
+/// result.
+#[derive(Debug)]
+pub struct WrongFormat {}
+
+impl Error for WrongFormat {}
+
+impl fmt::Display for WrongFormat {
+    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
+        write!(
+            fmt,
+            "cannot read column as text while it is in binary format"
+        )
+    }
+}
+
 /// A trait for types that can be created from a Postgres value.
 ///
 /// # Types
@@ -846,7 +860,7 @@ pub trait ToSql: fmt::Debug {
 /// Supported Postgres message format types
 ///
 /// Using Text format in a message assumes a Postgres `SERVER_ENCODING` of `UTF8`
-#[derive(Clone, Copy, Debug)]
+#[derive(Clone, Copy, Debug, PartialEq)]
 pub enum Format {
     /// Text format (UTF-8)
     Text,
diff --git a/postgres/Cargo.toml b/postgres/Cargo.toml
index bd7c297f3..0e0cce541 100644
--- a/postgres/Cargo.toml
+++ b/postgres/Cargo.toml
@@ -24,7 +24,6 @@ circle-ci = { repository = "sfackler/rust-postgres" }
 array-impls = ["tokio-postgres/array-impls"]
 with-bit-vec-0_6 = ["tokio-postgres/with-bit-vec-0_6"]
 with-chrono-0_4 = ["tokio-postgres/with-chrono-0_4"]
-with-eui48-0_4 = ["tokio-postgres/with-eui48-0_4"]
 with-eui48-1 = ["tokio-postgres/with-eui48-1"]
 with-geo-types-0_6 = ["tokio-postgres/with-geo-types-0_6"]
 with-geo-types-0_7 = ["tokio-postgres/with-geo-types-0_7"]
diff --git a/postgres/src/client.rs b/postgres/src/client.rs
index 29cac840d..34b60537c 100644
--- a/postgres/src/client.rs
+++ b/postgres/src/client.rs
@@ -439,7 +439,9 @@ impl Client {
     /// functionality to safely embed that data in the request. Do not form statements via string concatenation and pass
     /// them to this method!
     pub fn batch_execute(&mut self, query: &str) -> Result<(), Error> {
-        self.connection.block_on(self.client.batch_execute(query))
+        self.connection
+            .block_on(self.client.batch_execute(query))
+            .map(|_| ())
     }
 
     /// Begins a new database transaction.
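Reviewer note (not part of the patch): a hedged sketch of consuming the logical replication types defined above, assuming each `pgoutput` frame has already been extracted into a `Bytes` buffer.

```rust
use bytes::Bytes;
use postgres_protocol::message::backend::{LogicalReplicationMessage, TupleData};

fn handle_frame(frame: &Bytes) -> std::io::Result<()> {
    match LogicalReplicationMessage::parse(frame)? {
        LogicalReplicationMessage::Begin(begin) => {
            println!("begin xid={} final_lsn={:X}", begin.xid(), begin.final_lsn());
        }
        LogicalReplicationMessage::Insert(insert) => {
            for column in insert.tuple().tuple_data() {
                match column {
                    TupleData::Null => println!("NULL"),
                    TupleData::UnchangedToast => println!("(unchanged TOAST value)"),
                    TupleData::Text(value) => println!("{}", String::from_utf8_lossy(value)),
                }
            }
        }
        LogicalReplicationMessage::Commit(commit) => {
            println!("commit end_lsn={:X}", commit.end_lsn());
        }
        _ => {} // Relation, Type, Update, Delete, Truncate elided here
    }
    Ok(())
}
```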
diff --git a/postgres/src/config.rs b/postgres/src/config.rs
index b541ec846..44e4bec3a 100644
--- a/postgres/src/config.rs
+++ b/postgres/src/config.rs
@@ -12,7 +12,9 @@ use std::sync::Arc;
 use std::time::Duration;
 use tokio::runtime;
 #[doc(inline)]
-pub use tokio_postgres::config::{ChannelBinding, Host, SslMode, TargetSessionAttrs};
+pub use tokio_postgres::config::{
+    AuthKeys, ChannelBinding, Host, ScramKeys, SslMode, TargetSessionAttrs,
+};
 use tokio_postgres::error::DbError;
 use tokio_postgres::tls::{MakeTlsConnect, TlsConnect};
 use tokio_postgres::{Error, Socket};
@@ -149,6 +151,20 @@ impl Config {
         self.config.get_password()
     }
 
+    /// Sets precomputed protocol-specific keys to authenticate with.
+    /// When set, this option will override `password`.
+    /// See [`AuthKeys`] for more information.
+    pub fn auth_keys(&mut self, keys: AuthKeys) -> &mut Config {
+        self.config.auth_keys(keys);
+        self
+    }
+
+    /// Gets precomputed protocol-specific keys to authenticate with,
+    /// if one has been configured with the `auth_keys` method.
+    pub fn get_auth_keys(&self) -> Option<AuthKeys> {
+        self.config.get_auth_keys()
+    }
+
     /// Sets the name of the database to connect to.
     ///
     /// Defaults to the user.
diff --git a/postgres/src/lib.rs b/postgres/src/lib.rs
index fbe85cbde..7bbe50465 100644
--- a/postgres/src/lib.rs
+++ b/postgres/src/lib.rs
@@ -55,7 +55,6 @@
 //! | ------- | ----------- | ------------------ | ------- |
 //! | `with-bit-vec-0_6` | Enable support for the `bit-vec` crate. | [bit-vec](https://crates.io/crates/bit-vec) 0.6 | no |
 //! | `with-chrono-0_4` | Enable support for the `chrono` crate. | [chrono](https://crates.io/crates/chrono) 0.4 | no |
-//! | `with-eui48-0_4` | Enable support for the 0.4 version of the `eui48` crate. | [eui48](https://crates.io/crates/eui48) 0.4 | no |
 //! | `with-eui48-1` | Enable support for the 1.0 version of the `eui48` crate. | [eui48](https://crates.io/crates/eui48) 1.0 | no |
 //! | `with-geo-types-0_6` | Enable support for the 0.6 version of the `geo-types` crate. | [geo-types](https://crates.io/crates/geo-types/0.6.0) 0.6 | no |
 //! | `with-geo-types-0_7` | Enable support for the 0.7 version of the `geo-types` crate. | [geo-types](https://crates.io/crates/geo-types/0.7.0) 0.7 | no |
diff --git a/postgres/src/transaction.rs b/postgres/src/transaction.rs
index 17c49c406..3d147cd01 100644
--- a/postgres/src/transaction.rs
+++ b/postgres/src/transaction.rs
@@ -35,6 +35,7 @@ impl<'a> Transaction<'a> {
     pub fn commit(mut self) -> Result<(), Error> {
         self.connection
             .block_on(self.transaction.take().unwrap().commit())
+            .map(|_| ())
     }
 
     /// Rolls the transaction back, discarding all changes made within it.
@@ -43,6 +44,7 @@
     pub fn rollback(mut self) -> Result<(), Error> {
         self.connection
             .block_on(self.transaction.take().unwrap().rollback())
+            .map(|_| ())
     }
 
     /// Like `Client::prepare`.
@@ -193,6 +195,7 @@
     pub fn batch_execute(&mut self, query: &str) -> Result<(), Error> {
         self.connection
             .block_on(self.transaction.as_ref().unwrap().batch_execute(query))
+            .map(|_| ())
     }
 
     /// Like `Client::cancel_token`.
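Reviewer note (not part of the patch): a sketch of the blocking-API surface added above. The zeroed keys are placeholders, and the host/user/dbname values are assumptions for illustration.

```rust
use postgres::config::{AuthKeys, ScramKeys};
use postgres::{Config, NoTls};

fn connect_with_keys() -> Result<postgres::Client, postgres::Error> {
    let mut config = Config::new();
    config
        .host("localhost")
        .user("postgres")
        .dbname("postgres")
        // Overrides any configured password; zeroed keys for illustration only.
        .auth_keys(AuthKeys::ScramSha256(ScramKeys {
            client_key: [0u8; 32],
            server_key: [0u8; 32],
        }));
    config.connect(NoTls)
}
```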
diff --git a/tokio-postgres/Cargo.toml b/tokio-postgres/Cargo.toml
index 68737f738..93b12cd73 100644
--- a/tokio-postgres/Cargo.toml
+++ b/tokio-postgres/Cargo.toml
@@ -30,7 +30,6 @@ runtime = ["tokio/net", "tokio/time"]
 array-impls = ["postgres-types/array-impls"]
 with-bit-vec-0_6 = ["postgres-types/with-bit-vec-0_6"]
 with-chrono-0_4 = ["postgres-types/with-chrono-0_4"]
-with-eui48-0_4 = ["postgres-types/with-eui48-0_4"]
 with-eui48-1 = ["postgres-types/with-eui48-1"]
 with-geo-types-0_6 = ["postgres-types/with-geo-types-0_6"]
 with-geo-types-0_7 = ["postgres-types/with-geo-types-0_7"]
@@ -55,7 +54,7 @@ pin-project-lite = "0.2"
 phf = "0.11"
 postgres-protocol = { version = "0.6.4", path = "../postgres-protocol" }
 postgres-types = { version = "0.2.4", path = "../postgres-types" }
-socket2 = "0.4"
+socket2 = { version = "0.5", features = ["all"] }
 tokio = { version = "1.0", features = ["io-util"] }
 tokio-util = { version = "0.7", features = ["codec"] }
@@ -73,7 +72,6 @@ tokio = { version = "1.0", features = [
 
 bit-vec-06 = { version = "0.6", package = "bit-vec" }
 chrono-04 = { version = "0.4", package = "chrono", default-features = false }
-eui48-04 = { version = "0.4", package = "eui48" }
 eui48-1 = { version = "1.0", package = "eui48" }
 geo-types-06 = { version = "0.6", package = "geo-types" }
 geo-types-07 = { version = "0.7", package = "geo-types" }
diff --git a/tokio-postgres/src/client.rs b/tokio-postgres/src/client.rs
index ad5aa2866..7c89e6069 100644
--- a/tokio-postgres/src/client.rs
+++ b/tokio-postgres/src/client.rs
@@ -3,6 +3,7 @@ use crate::codec::{BackendMessages, FrontendMessage};
 use crate::config::Host;
 use crate::config::SslMode;
 use crate::connection::{Request, RequestMessages};
+use crate::copy_both::CopyBothDuplex;
 use crate::copy_out::CopyOutStream;
 #[cfg(feature = "runtime")]
 use crate::keepalive::KeepaliveConfig;
@@ -15,8 +16,9 @@ use crate::types::{Oid, ToSql, Type};
 #[cfg(feature = "runtime")]
 use crate::Socket;
 use crate::{
-    copy_in, copy_out, prepare, query, simple_query, slice_iter, CancelToken, CopyInSink, Error,
-    Row, SimpleQueryMessage, Statement, ToStatement, Transaction, TransactionBuilder,
+    copy_both, copy_in, copy_out, prepare, query, simple_query, slice_iter, CancelToken,
+    CopyInSink, Error, ReadyForQueryStatus, Row, SimpleQueryMessage, Statement, ToStatement,
+    Transaction, TransactionBuilder,
 };
 use bytes::{Buf, BytesMut};
 use fallible_iterator::FallibleIterator;
@@ -193,6 +195,11 @@ impl Client {
         }
     }
 
+    /// Returns process_id.
+    pub fn get_process_id(&self) -> i32 {
+        self.process_id
+    }
+
     pub(crate) fn inner(&self) -> &Arc<InnerClient> {
         &self.inner
     }
@@ -372,6 +379,17 @@ impl Client {
         query::query(&self.inner, statement, params).await
     }
 
+    /// Pass text directly to the Postgres backend to allow it to sort out typing itself and
+    /// to save a roundtrip
+    pub async fn query_raw_txt<S, I>(&self, statement: &str, params: I) -> Result<RowStream, Error>
+    where
+        S: AsRef<str>,
+        I: IntoIterator<Item = Option<S>>,
+        I::IntoIter: ExactSizeIterator,
+    {
+        query::query_txt(&self.inner, statement, params).await
+    }
+
     /// Executes a statement, returning the number of rows modified.
     ///
     /// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list
@@ -439,6 +457,14 @@ impl Client {
         copy_in::copy_in(self.inner(), statement).await
     }
 
+    /// Executes a `COPY FROM STDIN` query, returning a sink used to write the copy data.
+    pub async fn copy_in_simple<U>(&self, query: &str) -> Result<CopyInSink<U>, Error>
+    where
+        U: Buf + 'static + Send,
+    {
+        copy_in::copy_in_simple(self.inner(), query).await
+    }
+
     /// Executes a `COPY TO STDOUT` statement, returning a stream of the resulting data.
     ///
     /// PostgreSQL does not support parameters in `COPY` statements, so this method does not take any.
@@ -454,6 +480,20 @@ impl Client {
         copy_out::copy_out(self.inner(), statement).await
     }
 
+    /// Executes a `COPY TO STDOUT` query, returning a stream of the resulting data.
+    pub async fn copy_out_simple(&self, query: &str) -> Result<CopyOutStream, Error> {
+        copy_out::copy_out_simple(self.inner(), query).await
+    }
+
+    /// Executes a CopyBoth query, returning a combined Stream+Sink type to read and write copy
+    /// data.
+    pub async fn copy_both_simple<T>(&self, query: &str) -> Result<CopyBothDuplex<T>, Error>
+    where
+        T: Buf + 'static + Send,
+    {
+        copy_both::copy_both_simple(self.inner(), query).await
+    }
+
     /// Executes a sequence of SQL statements using the simple query protocol, returning the resulting rows.
     ///
     /// Statements should be separated by semicolons. If an error occurs, execution of the sequence will stop at that
@@ -485,7 +525,7 @@ impl Client {
     /// Prepared statements should be used for any query which contains user-specified data, as they provide the
     /// functionality to safely embed that data in the request. Do not form statements via string concatenation and pass
     /// them to this method!
-    pub async fn batch_execute(&self, query: &str) -> Result<(), Error> {
+    pub async fn batch_execute(&self, query: &str) -> Result<ReadyForQueryStatus, Error> {
         simple_query::batch_execute(self.inner(), query).await
     }
 
@@ -587,6 +627,11 @@ impl Client {
         self.inner().clear_type_cache();
     }
 
+    /// Query for type information
+    pub async fn get_type(&self, oid: Oid) -> Result<Type, Error> {
+        crate::prepare::get_type(&self.inner, oid).await
+    }
+
     /// Determines if the connection to the server has already closed.
     ///
     /// In that case, all future queries will fail.
diff --git a/tokio-postgres/src/codec.rs b/tokio-postgres/src/codec.rs
index 9d078044b..23c371542 100644
--- a/tokio-postgres/src/codec.rs
+++ b/tokio-postgres/src/codec.rs
@@ -35,7 +35,9 @@ impl FallibleIterator for BackendMessages {
     }
 }
 
-pub struct PostgresCodec;
+pub struct PostgresCodec {
+    pub max_message_size: Option<usize>,
+}
 
 impl Encoder<FrontendMessage> for PostgresCodec {
     type Error = io::Error;
@@ -64,6 +66,15 @@ impl Decoder for PostgresCodec {
                 break;
             }
 
+            if let Some(max) = self.max_message_size {
+                if len > max {
+                    return Err(io::Error::new(
+                        io::ErrorKind::InvalidInput,
+                        "message too large",
+                    ));
+                }
+            }
+
             match header.tag() {
                 backend::NOTICE_RESPONSE_TAG
                 | backend::NOTIFICATION_RESPONSE_TAG
diff --git a/tokio-postgres/src/config.rs b/tokio-postgres/src/config.rs
index 5b364ec06..fdb5e6359 100644
--- a/tokio-postgres/src/config.rs
+++ b/tokio-postgres/src/config.rs
@@ -23,6 +23,8 @@ use std::time::Duration;
 use std::{error, fmt, iter, mem};
 use tokio::io::{AsyncRead, AsyncWrite};
 
+pub use postgres_protocol::authentication::sasl::ScramKeys;
+
 /// Properties required of a session.
 #[derive(Debug, Copy, Clone, PartialEq, Eq)]
 #[non_exhaustive]
@@ -57,6 +59,16 @@ pub enum ChannelBinding {
     Require,
 }
 
+/// Replication mode configuration.
+#[derive(Debug, Copy, Clone, PartialEq, Eq)]
+#[non_exhaustive]
+pub enum ReplicationMode {
+    /// Physical replication.
+    Physical,
+    /// Logical replication.
+    Logical,
+}
+
 /// A host specification.
 #[derive(Debug, Clone, PartialEq, Eq)]
 pub enum Host {
@@ -69,6 +81,13 @@
     Unix(PathBuf),
 }
 
+/// Precomputed keys which may override password during auth.
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum AuthKeys {
+    /// A `ClientKey` & `ServerKey` pair for `SCRAM-SHA-256`.
+    ScramSha256(ScramKeys<32>),
+}
+
 /// Connection configuration.
 ///
 /// Configuration can be parsed from libpq-style connection strings. These strings come in two formats:
@@ -153,6 +172,7 @@ pub struct Config {
     pub(crate) user: Option<String>,
     pub(crate) password: Option<Vec<u8>>,
+    pub(crate) auth_keys: Option<Box<AuthKeys>>,
     pub(crate) dbname: Option<String>,
     pub(crate) options: Option<String>,
     pub(crate) application_name: Option<String>,
@@ -164,6 +184,8 @@
     pub(crate) keepalive_config: KeepaliveConfig,
     pub(crate) target_session_attrs: TargetSessionAttrs,
     pub(crate) channel_binding: ChannelBinding,
+    pub(crate) replication_mode: Option<ReplicationMode>,
+    pub(crate) max_backend_message_size: Option<usize>,
 }
 
 impl Default for Config {
@@ -183,6 +205,7 @@ impl Config {
         Config {
             user: None,
             password: None,
+            auth_keys: None,
             dbname: None,
             options: None,
             application_name: None,
@@ -194,6 +217,8 @@
             keepalive_config,
             target_session_attrs: TargetSessionAttrs::Any,
             channel_binding: ChannelBinding::Prefer,
+            replication_mode: None,
+            max_backend_message_size: None,
         }
     }
 
@@ -226,6 +251,20 @@
         self.password.as_deref()
     }
 
+    /// Sets precomputed protocol-specific keys to authenticate with.
+    /// When set, this option will override `password`.
+    /// See [`AuthKeys`] for more information.
+    pub fn auth_keys(&mut self, keys: AuthKeys) -> &mut Config {
+        self.auth_keys = Some(Box::new(keys));
+        self
+    }
+
+    /// Gets precomputed protocol-specific keys to authenticate with,
+    /// if one has been configured with the `auth_keys` method.
+    pub fn get_auth_keys(&self) -> Option<AuthKeys> {
+        self.auth_keys.as_deref().copied()
+    }
+
     /// Sets the name of the database to connect to.
     ///
     /// Defaults to the user.
@@ -424,6 +463,28 @@
         self.channel_binding
     }
 
+    /// Set replication mode.
+    pub fn replication_mode(&mut self, replication_mode: ReplicationMode) -> &mut Config {
+        self.replication_mode = Some(replication_mode);
+        self
+    }
+
+    /// Get replication mode.
+    pub fn get_replication_mode(&self) -> Option<ReplicationMode> {
+        self.replication_mode
+    }
+
+    /// Set limit for backend messages size.
+    pub fn max_backend_message_size(&mut self, max_backend_message_size: usize) -> &mut Config {
+        self.max_backend_message_size = Some(max_backend_message_size);
+        self
+    }
+
+    /// Get limit for backend messages size.
+    pub fn get_max_backend_message_size(&self) -> Option<usize> {
+        self.max_backend_message_size
+    }
+
     fn param(&mut self, key: &str, value: &str) -> Result<(), Error> {
         match key {
             "user" => {
@@ -527,6 +588,25 @@
                 };
                 self.channel_binding(channel_binding);
             }
+            "replication" => {
+                let mode = match value {
+                    "off" => None,
+                    "true" => Some(ReplicationMode::Physical),
+                    "database" => Some(ReplicationMode::Logical),
+                    _ => return Err(Error::config_parse(Box::new(InvalidValue("replication")))),
+                };
+                if let Some(mode) = mode {
+                    self.replication_mode(mode);
+                }
+            }
+            "max_backend_message_size" => {
+                let limit = value.parse::<usize>().map_err(|_| {
+                    Error::config_parse(Box::new(InvalidValue("max_backend_message_size")))
+                })?;
+                if limit > 0 {
+                    self.max_backend_message_size(limit);
+                }
+            }
             key => {
                 return Err(Error::config_parse(Box::new(UnknownOption(
                     key.to_string(),
@@ -601,6 +681,7 @@ impl fmt::Debug for Config {
             .field("keepalives_retries", &self.keepalive_config.retries)
             .field("target_session_attrs", &self.target_session_attrs)
             .field("channel_binding", &self.channel_binding)
+            .field("replication", &self.replication_mode)
             .finish()
     }
 }
diff --git a/tokio-postgres/src/connect_raw.rs b/tokio-postgres/src/connect_raw.rs
index d97636221..0beead11f 100644
--- a/tokio-postgres/src/connect_raw.rs
+++ b/tokio-postgres/src/connect_raw.rs
@@ -1,5 +1,5 @@
 use crate::codec::{BackendMessage, BackendMessages, FrontendMessage, PostgresCodec};
-use crate::config::{self, Config};
+use crate::config::{self, AuthKeys, Config, ReplicationMode};
 use crate::connect_tls::connect_tls;
 use crate::maybe_tls_stream::MaybeTlsStream;
 use crate::tls::{TlsConnect, TlsStream};
@@ -90,7 +90,12 @@ where
     let stream = connect_tls(stream, config.ssl_mode, tls).await?;
 
     let mut stream = StartupStream {
-        inner: Framed::new(stream, PostgresCodec),
+        inner: Framed::new(
+            stream,
+            PostgresCodec {
+                max_message_size: config.max_backend_message_size,
+            },
+        ),
         buf: BackendMessages::empty(),
         delayed: VecDeque::new(),
     };
@@ -124,6 +129,12 @@ where
     if let Some(application_name) = &config.application_name {
         params.push(("application_name", &**application_name));
     }
+    if let Some(replication_mode) = &config.replication_mode {
+        match replication_mode {
+            ReplicationMode::Physical => params.push(("replication", "true")),
+            ReplicationMode::Logical => params.push(("replication", "database")),
+        }
+    }
 
     let mut buf = BytesMut::new();
     frontend::startup_message(params, &mut buf).map_err(Error::encode)?;
@@ -228,11 +239,6 @@ where
     S: AsyncRead + AsyncWrite + Unpin,
     T: TlsStream + Unpin,
 {
-    let password = config
-        .password
-        .as_ref()
-        .ok_or_else(|| Error::config("password missing".into()))?;
-
     let mut has_scram = false;
     let mut has_scram_plus = false;
     let mut mechanisms = body.mechanisms();
@@ -270,7 +276,13 @@ where
         can_skip_channel_binding(config)?;
     }
 
-    let mut scram = ScramSha256::new(password, channel_binding);
+    let mut scram = if let Some(AuthKeys::ScramSha256(keys)) = config.get_auth_keys() {
+        ScramSha256::new_with_keys(keys, channel_binding)
+    } else if let Some(password) = config.get_password() {
+        ScramSha256::new(password, channel_binding)
+    } else {
+        return Err(Error::config("password or auth keys missing".into()));
+    };
 
     let mut buf = BytesMut::new();
     frontend::sasl_initial_response(mechanism, scram.message(), &mut buf).map_err(Error::encode)?;
diff --git a/tokio-postgres/src/connection.rs b/tokio-postgres/src/connection.rs
index 30be4e834..8b8c7b6fa 100644
--- a/tokio-postgres/src/connection.rs
+++ b/tokio-postgres/src/connection.rs
@@ -1,4 +1,5 @@
 use crate::codec::{BackendMessage, BackendMessages, FrontendMessage, PostgresCodec};
+use crate::copy_both::CopyBothReceiver;
 use crate::copy_in::CopyInReceiver;
 use crate::error::DbError;
 use crate::maybe_tls_stream::MaybeTlsStream;
@@ -20,6 +21,7 @@ use tokio_util::codec::Framed;
 pub enum RequestMessages {
     Single(FrontendMessage),
     CopyIn(CopyInReceiver),
+    CopyBoth(CopyBothReceiver),
 }
 
 pub struct Request {
@@ -47,8 +49,10 @@ enum State {
 /// occurred, or because its associated `Client` has dropped and all outstanding work has completed.
 #[must_use = "futures do nothing unless polled"]
 pub struct Connection<S, T> {
-    stream: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
-    parameters: HashMap<String, String>,
+    /// HACK: we need this in the Neon Proxy.
+    pub stream: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
+    /// HACK: we need this in the Neon Proxy to forward params.
+    pub parameters: HashMap<String, String>,
     receiver: mpsc::UnboundedReceiver<Request>,
     pending_request: Option<RequestMessages>,
     pending_responses: VecDeque<BackendMessage>,
@@ -258,6 +262,24 @@ where
                         .map_err(Error::io)?;
                     self.pending_request = Some(RequestMessages::CopyIn(receiver));
                 }
+                RequestMessages::CopyBoth(mut receiver) => {
+                    let message = match receiver.poll_next_unpin(cx) {
+                        Poll::Ready(Some(message)) => message,
+                        Poll::Ready(None) => {
+                            trace!("poll_write: finished copy_both request");
+                            continue;
+                        }
+                        Poll::Pending => {
+                            trace!("poll_write: waiting on copy_both stream");
+                            self.pending_request = Some(RequestMessages::CopyBoth(receiver));
+                            return Ok(true);
+                        }
+                    };
+                    Pin::new(&mut self.stream)
+                        .start_send(message)
+                        .map_err(Error::io)?;
+                    self.pending_request = Some(RequestMessages::CopyBoth(receiver));
+                }
             }
         }
     }
diff --git a/tokio-postgres/src/copy_both.rs b/tokio-postgres/src/copy_both.rs
new file mode 100644
index 000000000..79a7be34a
--- /dev/null
+++ b/tokio-postgres/src/copy_both.rs
@@ -0,0 +1,248 @@
+use crate::client::{InnerClient, Responses};
+use crate::codec::FrontendMessage;
+use crate::connection::RequestMessages;
+use crate::{simple_query, Error};
+use bytes::{Buf, BufMut, Bytes, BytesMut};
+use futures_channel::mpsc;
+use futures_util::{future, ready, Sink, SinkExt, Stream, StreamExt};
+use log::debug;
+use pin_project_lite::pin_project;
+use postgres_protocol::message::backend::Message;
+use postgres_protocol::message::frontend;
+use postgres_protocol::message::frontend::CopyData;
+use std::marker::{PhantomData, PhantomPinned};
+use std::pin::Pin;
+use std::task::{Context, Poll};
+
+pub(crate) enum CopyBothMessage {
+    Message(FrontendMessage),
+    Done,
+}
+
+pub struct CopyBothReceiver {
+    receiver: mpsc::Receiver<CopyBothMessage>,
+    done: bool,
+}
+
+impl CopyBothReceiver {
+    pub(crate) fn new(receiver: mpsc::Receiver<CopyBothMessage>) -> CopyBothReceiver {
+        CopyBothReceiver {
+            receiver,
+            done: false,
+        }
+    }
+}
+
+impl Stream for CopyBothReceiver {
+    type Item = FrontendMessage;
+
+    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<FrontendMessage>> {
+        if self.done {
+            return Poll::Ready(None);
+        }
+
+        match ready!(self.receiver.poll_next_unpin(cx)) {
+            Some(CopyBothMessage::Message(message)) => Poll::Ready(Some(message)),
+            Some(CopyBothMessage::Done) => {
+                self.done = true;
+                let mut buf = BytesMut::new();
+                frontend::copy_done(&mut buf);
+                frontend::sync(&mut buf);
+                Poll::Ready(Some(FrontendMessage::Raw(buf.freeze())))
+            }
+            None => {
+                self.done = true;
+                let mut buf = BytesMut::new();
+                frontend::copy_fail("", &mut buf).unwrap();
+                frontend::sync(&mut buf);
+                Poll::Ready(Some(FrontendMessage::Raw(buf.freeze())))
+            }
+        }
+    }
+}
+
+enum SinkState {
+    Active,
+    Closing,
+    Reading,
+}
+
+pin_project! {
+    /// A sink for `COPY ... FROM STDIN` query data.
+    ///
+    /// The copy *must* be explicitly completed via the `Sink::close` or `finish` methods. If it is
+    /// not, the copy will be aborted.
+    pub struct CopyBothDuplex<T> {
+        #[pin]
+        sender: mpsc::Sender<CopyBothMessage>,
+        responses: Responses,
+        buf: BytesMut,
+        state: SinkState,
+        #[pin]
+        _p: PhantomPinned,
+        _p2: PhantomData<T>,
+    }
+}
+
+impl<T> CopyBothDuplex<T>
+where
+    T: Buf + 'static + Send,
+{
+    pub(crate) fn new(sender: mpsc::Sender<CopyBothMessage>, responses: Responses) -> Self {
+        Self {
+            sender,
+            responses,
+            buf: BytesMut::new(),
+            state: SinkState::Active,
+            _p: PhantomPinned,
+            _p2: PhantomData,
+        }
+    }
+
+    /// A poll-based version of `finish`.
+    pub fn poll_finish(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<u64, Error>> {
+        loop {
+            match self.state {
+                SinkState::Active => {
+                    ready!(self.as_mut().poll_flush(cx))?;
+                    let mut this = self.as_mut().project();
+                    ready!(this.sender.as_mut().poll_ready(cx)).map_err(|_| Error::closed())?;
+                    this.sender
+                        .start_send(CopyBothMessage::Done)
+                        .map_err(|_| Error::closed())?;
+                    *this.state = SinkState::Closing;
+                }
+                SinkState::Closing => {
+                    let this = self.as_mut().project();
+                    ready!(this.sender.poll_close(cx)).map_err(|_| Error::closed())?;
+                    *this.state = SinkState::Reading;
+                }
+                SinkState::Reading => {
+                    let this = self.as_mut().project();
+                    match ready!(this.responses.poll_next(cx))? {
+                        Message::CommandComplete(body) => {
+                            let rows = body
+                                .tag()
+                                .map_err(Error::parse)?
+                                .rsplit(' ')
+                                .next()
+                                .unwrap()
+                                .parse()
+                                .unwrap_or(0);
+                            return Poll::Ready(Ok(rows));
+                        }
+                        _ => return Poll::Ready(Err(Error::unexpected_message())),
+                    }
+                }
+            }
+        }
+    }
+
+    /// Completes the copy, returning the number of rows inserted.
+    ///
+    /// The `Sink::close` method is equivalent to `finish`, except that it does not return the
+    /// number of rows.
+    pub async fn finish(mut self: Pin<&mut Self>) -> Result<u64, Error> {
+        future::poll_fn(|cx| self.as_mut().poll_finish(cx)).await
+    }
+}
+
+impl<T> Stream for CopyBothDuplex<T> {
+    type Item = Result<Bytes, Error>;
+
+    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+        let this = self.project();
+
+        match ready!(this.responses.poll_next(cx)?) {
+            Message::CopyData(body) => Poll::Ready(Some(Ok(body.into_bytes()))),
+            Message::CopyDone => Poll::Ready(None),
+            _ => Poll::Ready(Some(Err(Error::unexpected_message()))),
+        }
+    }
+}
+
+impl<T> Sink<T> for CopyBothDuplex<T>
+where
+    T: Buf + 'static + Send,
+{
+    type Error = Error;
+
+    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
+        self.project()
+            .sender
+            .poll_ready(cx)
+            .map_err(|_| Error::closed())
+    }
+
+    fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Error> {
+        let this = self.project();
+
+        let data: Box<dyn Buf + Send> = if item.remaining() > 4096 {
+            if this.buf.is_empty() {
+                Box::new(item)
+            } else {
+                Box::new(this.buf.split().freeze().chain(item))
+            }
+        } else {
+            this.buf.put(item);
+            if this.buf.len() > 4096 {
+                Box::new(this.buf.split().freeze())
+            } else {
+                return Ok(());
+            }
+        };
+
+        let data = CopyData::new(data).map_err(Error::encode)?;
+        this.sender
+            .start_send(CopyBothMessage::Message(FrontendMessage::CopyData(data)))
+            .map_err(|_| Error::closed())
+    }
+
+    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
+        let mut this = self.project();
+
+        if !this.buf.is_empty() {
+            ready!(this.sender.as_mut().poll_ready(cx)).map_err(|_| Error::closed())?;
+            let data: Box<dyn Buf + Send> = Box::new(this.buf.split().freeze());
+            let data = CopyData::new(data).map_err(Error::encode)?;
+            this.sender
+                .as_mut()
+                .start_send(CopyBothMessage::Message(FrontendMessage::CopyData(data)))
+                .map_err(|_| Error::closed())?;
+        }
+
+        this.sender.poll_flush(cx).map_err(|_| Error::closed())
+    }
+
+    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
+        self.poll_finish(cx).map_ok(|_| ())
+    }
+}
+
+pub async fn copy_both_simple<T>(
+    client: &InnerClient,
+    query: &str,
+) -> Result<CopyBothDuplex<T>, Error>
+where
+    T: Buf + 'static + Send,
+{
+    debug!("executing copy both query {}", query);
+
+    let buf = simple_query::encode(client, query)?;
+
+    let (mut sender, receiver) = mpsc::channel(1);
+    let receiver = CopyBothReceiver::new(receiver);
+    let mut responses = client.send(RequestMessages::CopyBoth(receiver))?;
+
+    sender
+        .send(CopyBothMessage::Message(FrontendMessage::Raw(buf)))
+        .await
+        .map_err(|_| Error::closed())?;
+
+    match responses.next().await? {
+pub async fn copy_both_simple<T>(
+    client: &InnerClient,
+    query: &str,
+) -> Result<CopyBothDuplex<T>, Error>
+where
+    T: Buf + 'static + Send,
+{
+    debug!("executing copy both query {}", query);
+
+    let buf = simple_query::encode(client, query)?;
+
+    let (mut sender, receiver) = mpsc::channel(1);
+    let receiver = CopyBothReceiver::new(receiver);
+    let mut responses = client.send(RequestMessages::CopyBoth(receiver))?;
+
+    sender
+        .send(CopyBothMessage::Message(FrontendMessage::Raw(buf)))
+        .await
+        .map_err(|_| Error::closed())?;
+
+    match responses.next().await? {
+        Message::CopyBothResponse(_) => {}
+        _ => return Err(Error::unexpected_message()),
+    }
+
+    Ok(CopyBothDuplex::new(sender, responses))
+}
diff --git a/tokio-postgres/src/copy_in.rs b/tokio-postgres/src/copy_in.rs
index de1da933b..cd3c4f8db 100644
--- a/tokio-postgres/src/copy_in.rs
+++ b/tokio-postgres/src/copy_in.rs
@@ -1,8 +1,8 @@
 use crate::client::{InnerClient, Responses};
 use crate::codec::FrontendMessage;
 use crate::connection::RequestMessages;
-use crate::{query, slice_iter, Error, Statement};
-use bytes::{Buf, BufMut, BytesMut};
+use crate::{query, simple_query, slice_iter, Error, Statement};
+use bytes::{Buf, BufMut, Bytes, BytesMut};
 use futures_channel::mpsc;
 use futures_util::{future, ready, Sink, SinkExt, Stream, StreamExt};
 use log::debug;
@@ -194,14 +194,10 @@ where
     }
 }
 
-pub async fn copy_in<T>(client: &InnerClient, statement: Statement) -> Result<CopyInSink<T>, Error>
+async fn start<T>(client: &InnerClient, buf: Bytes, simple: bool) -> Result<CopyInSink<T>, Error>
 where
     T: Buf + 'static + Send,
 {
-    debug!("executing copy in statement {}", statement.name());
-
-    let buf = query::encode(client, &statement, slice_iter(&[]))?;
-
     let (mut sender, receiver) = mpsc::channel(1);
     let receiver = CopyInReceiver::new(receiver);
     let mut responses = client.send(RequestMessages::CopyIn(receiver))?;
@@ -211,9 +207,11 @@ where
         .await
         .map_err(|_| Error::closed())?;
 
-    match responses.next().await? {
-        Message::BindComplete => {}
-        _ => return Err(Error::unexpected_message()),
+    if !simple {
+        match responses.next().await? {
+            Message::BindComplete => {}
+            _ => return Err(Error::unexpected_message()),
+        }
     }
 
     match responses.next().await? {
@@ -230,3 +228,23 @@ where
         _p2: PhantomData,
     })
 }
+
+pub async fn copy_in<T>(client: &InnerClient, statement: Statement) -> Result<CopyInSink<T>, Error>
+where
+    T: Buf + 'static + Send,
+{
+    debug!("executing copy in statement {}", statement.name());
+
+    let buf = query::encode(client, &statement, slice_iter(&[]))?;
+    start(client, buf, false).await
+}
+
+pub async fn copy_in_simple<T>(client: &InnerClient, query: &str) -> Result<CopyInSink<T>, Error>
+where
+    T: Buf + 'static + Send,
+{
+    debug!("executing copy in query {}", query);
+
+    let buf = simple_query::encode(client, query)?;
+    start(client, buf, true).await
+}
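+
+// Note: the simple query protocol has no Bind step, so when `simple` is true the
+// BindComplete message that the extended protocol would produce is not expected.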
diff --git a/tokio-postgres/src/copy_out.rs b/tokio-postgres/src/copy_out.rs
index 1e6949252..981f9365e 100644
--- a/tokio-postgres/src/copy_out.rs
+++ b/tokio-postgres/src/copy_out.rs
@@ -1,7 +1,7 @@
 use crate::client::{InnerClient, Responses};
 use crate::codec::FrontendMessage;
 use crate::connection::RequestMessages;
-use crate::{query, slice_iter, Error, Statement};
+use crate::{query, simple_query, slice_iter, Error, Statement};
 use bytes::Bytes;
 use futures_util::{ready, Stream};
 use log::debug;
@@ -11,23 +11,36 @@ use std::marker::PhantomPinned;
 use std::pin::Pin;
 use std::task::{Context, Poll};
 
+pub async fn copy_out_simple(client: &InnerClient, query: &str) -> Result<CopyOutStream, Error> {
+    debug!("executing copy out query {}", query);
+
+    let buf = simple_query::encode(client, query)?;
+    let responses = start(client, buf, true).await?;
+    Ok(CopyOutStream {
+        responses,
+        _p: PhantomPinned,
+    })
+}
+
 pub async fn copy_out(client: &InnerClient, statement: Statement) -> Result<CopyOutStream, Error> {
     debug!("executing copy out statement {}", statement.name());
 
     let buf = query::encode(client, &statement, slice_iter(&[]))?;
-    let responses = start(client, buf).await?;
+    let responses = start(client, buf, false).await?;
     Ok(CopyOutStream {
         responses,
         _p: PhantomPinned,
     })
 }
 
-async fn start(client: &InnerClient, buf: Bytes) -> Result<Responses, Error> {
+async fn start(client: &InnerClient, buf: Bytes, simple: bool) -> Result<Responses, Error> {
     let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;
 
-    match responses.next().await? {
-        Message::BindComplete => {}
-        _ => return Err(Error::unexpected_message()),
+    if !simple {
+        match responses.next().await? {
+            Message::BindComplete => {}
+            _ => return Err(Error::unexpected_message()),
+        }
     }
 
     match responses.next().await? {
diff --git a/tokio-postgres/src/generic_client.rs b/tokio-postgres/src/generic_client.rs
index b2a907558..a1f0e73b7 100644
--- a/tokio-postgres/src/generic_client.rs
+++ b/tokio-postgres/src/generic_client.rs
@@ -2,6 +2,7 @@ use crate::query::RowStream;
 use crate::types::{BorrowToSql, ToSql, Type};
 use crate::{Client, Error, Row, Statement, ToStatement, Transaction};
 use async_trait::async_trait;
+use postgres_protocol::Oid;
 
 mod private {
     pub trait Sealed {}
@@ -56,6 +57,13 @@ pub trait GenericClient: private::Sealed {
         I: IntoIterator<Item = P> + Sync + Send,
         I::IntoIter: ExactSizeIterator;
 
+    /// Like `Client::query_raw_txt`.
+    async fn query_raw_txt<S, I>(&self, statement: &str, params: I) -> Result<RowStream, Error>
+    where
+        S: AsRef<str> + Sync + Send,
+        I: IntoIterator<Item = Option<S>> + Sync + Send,
+        I::IntoIter: ExactSizeIterator + Sync + Send;
+
     /// Like `Client::prepare`.
     async fn prepare(&self, query: &str) -> Result<Statement, Error>;
 
@@ -69,6 +77,9 @@ pub trait GenericClient: private::Sealed {
     /// Like `Client::transaction`.
     async fn transaction(&mut self) -> Result<Transaction<'_>, Error>;
 
+    /// Query for type information
+    async fn get_type(&self, oid: Oid) -> Result<Type, Error>;
+
     /// Returns a reference to the underlying `Client`.
     fn client(&self) -> &Client;
 }
@@ -133,6 +144,15 @@ impl GenericClient for Client {
         self.query_raw(statement, params).await
     }
 
+    async fn query_raw_txt<S, I>(&self, statement: &str, params: I) -> Result<RowStream, Error>
+    where
+        S: AsRef<str> + Sync + Send,
+        I: IntoIterator<Item = Option<S>> + Sync + Send,
+        I::IntoIter: ExactSizeIterator + Sync + Send,
+    {
+        self.query_raw_txt(statement, params).await
+    }
+
     async fn prepare(&self, query: &str) -> Result<Statement, Error> {
         self.prepare(query).await
     }
@@ -149,6 +169,11 @@ impl GenericClient for Client {
         self.transaction().await
     }
 
+    /// Query for type information
+    async fn get_type(&self, oid: Oid) -> Result<Type, Error> {
+        self.get_type(oid).await
+    }
+
     fn client(&self) -> &Client {
         self
     }
@@ -215,6 +240,15 @@ impl GenericClient for Transaction<'_> {
         self.query_raw(statement, params).await
     }
 
+    async fn query_raw_txt<S, I>(&self, statement: &str, params: I) -> Result<RowStream, Error>
+    where
+        S: AsRef<str> + Sync + Send,
+        I: IntoIterator<Item = Option<S>> + Sync + Send,
+        I::IntoIter: ExactSizeIterator + Sync + Send,
+    {
+        self.query_raw_txt(statement, params).await
+    }
+
     async fn prepare(&self, query: &str) -> Result<Statement, Error> {
         self.prepare(query).await
     }
@@ -232,6 +266,11 @@ impl GenericClient for Transaction<'_> {
         self.transaction().await
     }
 
+    /// Query for type information
+    async fn get_type(&self, oid: Oid) -> Result<Type, Error> {
+        self.client().get_type(oid).await
+    }
+
     fn client(&self) -> &Client {
         self.client()
     }
diff --git a/tokio-postgres/src/lib.rs b/tokio-postgres/src/lib.rs
index a9ecba4f1..86a8f7ec3 100644
--- a/tokio-postgres/src/lib.rs
+++ b/tokio-postgres/src/lib.rs
@@ -107,7 +107,6 @@
 //! | `array-impls` | Enables `ToSql` and `FromSql` trait impls for arrays | - | no |
 //! | `with-bit-vec-0_6` | Enable support for the `bit-vec` crate. | [bit-vec](https://crates.io/crates/bit-vec) 0.6 | no |
 //! | `with-chrono-0_4` | Enable support for the `chrono` crate. | [chrono](https://crates.io/crates/chrono) 0.4 | no |
-//! | `with-eui48-0_4` | Enable support for the 0.4 version of the `eui48` crate. | [eui48](https://crates.io/crates/eui48) 0.4 | no |
 //! | `with-eui48-1` | Enable support for the 1.0 version of the `eui48` crate. | [eui48](https://crates.io/crates/eui48) 1.0 | no |
 //! | `with-geo-types-0_6` | Enable support for the 0.6 version of the `geo-types` crate. | [geo-types](https://crates.io/crates/geo-types/0.6.0) 0.6 | no |
 //! | `with-geo-types-0_7` | Enable support for the 0.7 version of the `geo-types` crate. | [geo-types](https://crates.io/crates/geo-types/0.7.0) 0.7 | no |
@@ -119,10 +118,13 @@
 #![doc(html_root_url = "https://docs.rs/tokio-postgres/0.7")]
 #![warn(rust_2018_idioms, clippy::all, missing_docs)]
 
+use postgres_protocol::message::backend::ReadyForQueryBody;
+
 pub use crate::cancel_token::CancelToken;
 pub use crate::client::Client;
 pub use crate::config::Config;
 pub use crate::connection::Connection;
+pub use crate::copy_both::CopyBothDuplex;
 pub use crate::copy_in::CopyInSink;
 pub use crate::copy_out::CopyOutStream;
 use crate::error::DbError;
@@ -143,6 +145,31 @@ pub use crate::transaction::Transaction;
 pub use crate::transaction_builder::{IsolationLevel, TransactionBuilder};
 use crate::types::ToSql;
 
+/// After executing a query, the connection will be in one of these states
+#[derive(Clone, Copy, Debug, PartialEq)]
+#[repr(u8)]
+pub enum ReadyForQueryStatus {
+    /// Connection state is unknown
+    Unknown,
+    /// Connection is idle (no transactions)
+    Idle = b'I',
+    /// Connection is in a transaction block
+    Transaction = b'T',
+    /// Connection is in a failed transaction block
+    FailedTransaction = b'E',
+}
+
+impl From<ReadyForQueryBody> for ReadyForQueryStatus {
+    fn from(value: ReadyForQueryBody) -> Self {
+        match value.status() {
+            b'I' => Self::Idle,
+            b'T' => Self::Transaction,
+            b'E' => Self::FailedTransaction,
+            _ => Self::Unknown,
+        }
+    }
+}
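+
+// Usage sketch (names assumed; mirrors the `ready_for_query` test): both
+// `Client::batch_execute` and `RowStream::ready_status` surface this state, e.g.
+//
+//     let status = client.batch_execute("START TRANSACTION").await?;
+//     assert_eq!(status, ReadyForQueryStatus::Transaction);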
+
 pub mod binary_copy;
 mod bind;
 #[cfg(feature = "runtime")]
@@ -159,15 +186,17 @@ mod connect_raw;
 mod connect_socket;
 mod connect_tls;
 mod connection;
+mod copy_both;
 mod copy_in;
 mod copy_out;
 pub mod error;
 mod generic_client;
 mod keepalive;
-mod maybe_tls_stream;
+pub mod maybe_tls_stream;
 mod portal;
 mod prepare;
 mod query;
+pub mod replication;
 pub mod row;
 mod simple_query;
 #[cfg(feature = "runtime")]
diff --git a/tokio-postgres/src/maybe_tls_stream.rs b/tokio-postgres/src/maybe_tls_stream.rs
index 73b0c4721..9a7e24899 100644
--- a/tokio-postgres/src/maybe_tls_stream.rs
+++ b/tokio-postgres/src/maybe_tls_stream.rs
@@ -1,11 +1,17 @@
+//! MaybeTlsStream.
+//!
+//! Represents a stream that may or may not be encrypted with TLS.
 use crate::tls::{ChannelBinding, TlsStream};
 use std::io;
 use std::pin::Pin;
 use std::task::{Context, Poll};
 use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
 
+/// A stream that may or may not be encrypted with TLS.
 pub enum MaybeTlsStream<S, T> {
+    /// An unencrypted stream.
     Raw(S),
+    /// An encrypted stream.
     Tls(T),
 }
diff --git a/tokio-postgres/src/prepare.rs b/tokio-postgres/src/prepare.rs
index e3f09a7c2..0abb8e453 100644
--- a/tokio-postgres/src/prepare.rs
+++ b/tokio-postgres/src/prepare.rs
@@ -95,7 +95,7 @@ pub async fn prepare(
     let mut it = row_description.fields();
     while let Some(field) = it.next().map_err(Error::parse)? {
         let type_ = get_type(client, field.type_oid()).await?;
-        let column = Column::new(field.name().to_string(), type_);
+        let column = Column::new(field.name().to_string(), type_, field);
         columns.push(column);
     }
 }
@@ -126,7 +126,7 @@ fn encode(client: &InnerClient, name: &str, query: &str, types: &[Type]) -> Resu
     })
 }
 
-async fn get_type(client: &Arc<InnerClient>, oid: Oid) -> Result<Type, Error> {
+pub async fn get_type(client: &Arc<InnerClient>, oid: Oid) -> Result<Type, Error> {
     if let Some(type_) = Type::from_oid(oid) {
         return Ok(type_);
     }
diff --git a/tokio-postgres/src/query.rs b/tokio-postgres/src/query.rs
index 71db8769a..80c5f65d5 100644
--- a/tokio-postgres/src/query.rs
+++ b/tokio-postgres/src/query.rs
@@ -2,16 +2,19 @@ use crate::client::{InnerClient, Responses};
 use crate::codec::FrontendMessage;
 use crate::connection::RequestMessages;
 use crate::types::{BorrowToSql, IsNull};
-use crate::{Error, Portal, Row, Statement};
-use bytes::{Bytes, BytesMut};
+use crate::{Column, Error, Portal, ReadyForQueryStatus, Row, Statement};
+use bytes::{BufMut, Bytes, BytesMut};
+use fallible_iterator::FallibleIterator;
 use futures_util::{ready, Stream};
 use log::{debug, log_enabled, Level};
 use pin_project_lite::pin_project;
 use postgres_protocol::message::backend::Message;
 use postgres_protocol::message::frontend;
+use postgres_types::{Format, Type};
 use std::fmt;
 use std::marker::PhantomPinned;
 use std::pin::Pin;
+use std::sync::Arc;
 use std::task::{Context, Poll};
 
 struct BorrowToSqlParamsDebug<'a, T>(&'a [T]);
@@ -52,6 +55,110 @@ where
     Ok(RowStream {
         statement,
         responses,
+        command_tag: None,
+        status: ReadyForQueryStatus::Unknown,
+        output_format: Format::Binary,
+        _p: PhantomPinned,
+    })
+}
+
+pub async fn query_txt<S, I>(
+    client: &Arc<InnerClient>,
+    query: &str,
+    params: I,
+) -> Result<RowStream, Error>
+where
+    S: AsRef<str>,
+    I: IntoIterator<Item = Option<S>>,
+    I::IntoIter: ExactSizeIterator,
+{
+    let params = params.into_iter();
+
+    let buf = client.with_buf(|buf| {
+        frontend::parse(
+            "",                 // unnamed prepared statement
+            query,              // query to parse
+            std::iter::empty(), // give no type info
+            buf,
+        )
+        .map_err(Error::encode)?;
+        frontend::describe(b'S', "", buf).map_err(Error::encode)?;
+        // Bind: pass params as text, and request all results as text as well
+        match frontend::bind(
+            "",                 // empty string selects the unnamed portal
+            "",                 // unnamed prepared statement
+            std::iter::empty(), // all parameters use the default format (text)
+            params,
+            |param, buf| match param {
+                Some(param) => {
+                    buf.put_slice(param.as_ref().as_bytes());
+                    Ok(postgres_protocol::IsNull::No)
+                }
+                None => Ok(postgres_protocol::IsNull::Yes),
+            },
+            Some(0), // all text
+            buf,
+        ) {
+            Ok(()) => Ok(()),
+            Err(frontend::BindError::Conversion(e)) => Err(Error::to_sql(e, 0)),
+            Err(frontend::BindError::Serialization(e)) => Err(Error::encode(e)),
+        }?;
+
+        // Execute
+        frontend::execute("", 0, buf).map_err(Error::encode)?;
+        // Sync
+        frontend::sync(buf);
+
+        Ok(buf.split().freeze())
+    })?;
+
+    // now read the responses
+    let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;
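+
+    // Expected reply sequence for the messages issued above: ParseComplete,
+    // ParameterDescription, RowDescription or NoData, then BindComplete; the
+    // returned RowStream then consumes DataRow/CommandComplete/ReadyForQuery.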
+    match responses.next().await? {
+        Message::ParseComplete => {}
+        _ => return Err(Error::unexpected_message()),
+    }
+
+    let parameter_description = match responses.next().await? {
+        Message::ParameterDescription(body) => body,
+        _ => return Err(Error::unexpected_message()),
+    };
+
+    let row_description = match responses.next().await? {
+        Message::RowDescription(body) => Some(body),
+        Message::NoData => None,
+        _ => return Err(Error::unexpected_message()),
+    };
+
+    match responses.next().await? {
+        Message::BindComplete => {}
+        _ => return Err(Error::unexpected_message()),
+    }
+
+    let mut parameters = vec![];
+    let mut it = parameter_description.parameters();
+    while let Some(oid) = it.next().map_err(Error::parse)? {
+        let type_ = Type::from_oid(oid).unwrap_or(Type::UNKNOWN);
+        parameters.push(type_);
+    }
+
+    let mut columns = vec![];
+    if let Some(row_description) = row_description {
+        let mut it = row_description.fields();
+        while let Some(field) = it.next().map_err(Error::parse)? {
+            let type_ = Type::from_oid(field.type_oid()).unwrap_or(Type::UNKNOWN);
+            let column = Column::new(field.name().to_string(), type_, field);
+            columns.push(column);
+        }
+    }
+
+    Ok(RowStream {
+        statement: Statement::new_anonymous(parameters, columns),
+        responses,
+        command_tag: None,
+        status: ReadyForQueryStatus::Unknown,
+        output_format: Format::Text,
         _p: PhantomPinned,
     })
 }
@@ -72,6 +179,9 @@ pub async fn query_portal(
     Ok(RowStream {
         statement: portal.statement().clone(),
         responses,
+        command_tag: None,
+        status: ReadyForQueryStatus::Unknown,
+        output_format: Format::Binary,
         _p: PhantomPinned,
     })
 }
@@ -202,6 +312,9 @@ pin_project! {
     pub struct RowStream {
         statement: Statement,
         responses: Responses,
+        command_tag: Option<String>,
+        output_format: Format,
+        status: ReadyForQueryStatus,
         #[pin]
         _p: PhantomPinned,
     }
@@ -215,14 +328,45 @@ impl Stream for RowStream {
         loop {
             match ready!(this.responses.poll_next(cx)?) {
                 Message::DataRow(body) => {
-                    return Poll::Ready(Some(Ok(Row::new(this.statement.clone(), body)?)))
+                    return Poll::Ready(Some(Ok(Row::new(
+                        this.statement.clone(),
+                        body,
+                        *this.output_format,
+                    )?)))
+                }
+                Message::EmptyQueryResponse | Message::PortalSuspended => {}
+                Message::CommandComplete(body) => {
+                    if let Ok(tag) = body.tag() {
+                        *this.command_tag = Some(tag.to_string());
+                    }
+                }
+                Message::ReadyForQuery(status) => {
+                    *this.status = status.into();
+                    return Poll::Ready(None);
                 }
-                Message::EmptyQueryResponse
-                | Message::CommandComplete(_)
-                | Message::PortalSuspended => {}
-                Message::ReadyForQuery(_) => return Poll::Ready(None),
                 _ => return Poll::Ready(Some(Err(Error::unexpected_message()))),
             }
         }
     }
 }
+
+impl RowStream {
+    /// Returns information about the columns of data in the row.
+    pub fn columns(&self) -> &[Column] {
+        self.statement.columns()
+    }
+
+    /// Returns the command tag of this query.
+    ///
+    /// This is only available after the stream has been exhausted.
+    pub fn command_tag(&self) -> Option<String> {
+        self.command_tag.clone()
+    }
+
+    /// Returns whether the connection is ready for querying, along with the status of the connection.
+    ///
+    /// This might be available only after the stream has been exhausted.
+    pub fn ready_status(&self) -> ReadyForQueryStatus {
+        self.status
+    }
+}
diff --git a/tokio-postgres/src/replication.rs b/tokio-postgres/src/replication.rs
new file mode 100644
index 000000000..e7c845958
--- /dev/null
+++ b/tokio-postgres/src/replication.rs
@@ -0,0 +1,201 @@
+//! Utilities for working with the PostgreSQL replication copy both format.
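+//!
+//! A replication session starts with a `START_REPLICATION` simple query on a
+//! connection opened with `replication=database` (see `tests/test/replication.rs`);
+//! the resulting copy-both stream is then wrapped by the types below.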
+
+use crate::copy_both::CopyBothDuplex;
+use crate::Error;
+use bytes::{BufMut, Bytes, BytesMut};
+use futures_util::{ready, SinkExt, Stream};
+use pin_project_lite::pin_project;
+use postgres_protocol::message::backend::{LogicalReplicationMessage, ReplicationMessage};
+use postgres_protocol::PG_EPOCH;
+use postgres_types::PgLsn;
+use std::pin::Pin;
+use std::task::{Context, Poll};
+use std::time::SystemTime;
+
+const STANDBY_STATUS_UPDATE_TAG: u8 = b'r';
+const HOT_STANDBY_FEEDBACK_TAG: u8 = b'h';
+const ZENITH_STATUS_UPDATE_TAG_BYTE: u8 = b'z';
+
+pin_project! {
+    /// A type which deserializes the postgres replication protocol. This type can be used with
+    /// both physical and logical replication to get access to the byte content of each replication
+    /// message.
+    ///
+    /// The replication *must* be explicitly completed via the `finish` method.
+    pub struct ReplicationStream {
+        #[pin]
+        stream: CopyBothDuplex<Bytes>,
+    }
+}
+
+impl ReplicationStream {
+    /// Creates a new ReplicationStream that will wrap the underlying CopyBoth stream
+    pub fn new(stream: CopyBothDuplex<Bytes>) -> Self {
+        Self { stream }
+    }
+
+    /// Send zenith status update to server.
+    pub async fn zenith_status_update(
+        self: Pin<&mut Self>,
+        len: u64,
+        data: &[u8],
+    ) -> Result<(), Error> {
+        let mut this = self.project();
+
+        let mut buf = BytesMut::new();
+        buf.put_u8(ZENITH_STATUS_UPDATE_TAG_BYTE);
+        buf.put_u64(len);
+        buf.put_slice(data);
+
+        this.stream.send(buf.freeze()).await
+    }
+
+    /// Send standby update to server.
+    pub async fn standby_status_update(
+        self: Pin<&mut Self>,
+        write_lsn: PgLsn,
+        flush_lsn: PgLsn,
+        apply_lsn: PgLsn,
+        timestamp: SystemTime,
+        reply: u8,
+    ) -> Result<(), Error> {
+        let mut this = self.project();
+
+        let timestamp = match timestamp.duration_since(*PG_EPOCH) {
+            Ok(d) => d.as_micros() as i64,
+            Err(e) => -(e.duration().as_micros() as i64),
+        };
+
+        let mut buf = BytesMut::new();
+        buf.put_u8(STANDBY_STATUS_UPDATE_TAG);
+        buf.put_u64(write_lsn.into());
+        buf.put_u64(flush_lsn.into());
+        buf.put_u64(apply_lsn.into());
+        buf.put_i64(timestamp);
+        buf.put_u8(reply);
+
+        this.stream.send(buf.freeze()).await
+    }
+
+    /// Send hot standby feedback message to server.
+    pub async fn hot_standby_feedback(
+        self: Pin<&mut Self>,
+        timestamp: SystemTime,
+        global_xmin: u32,
+        global_xmin_epoch: u32,
+        catalog_xmin: u32,
+        catalog_xmin_epoch: u32,
+    ) -> Result<(), Error> {
+        let mut this = self.project();
+
+        let timestamp = match timestamp.duration_since(*PG_EPOCH) {
+            Ok(d) => d.as_micros() as i64,
+            Err(e) => -(e.duration().as_micros() as i64),
+        };
+
+        let mut buf = BytesMut::new();
+        buf.put_u8(HOT_STANDBY_FEEDBACK_TAG);
+        buf.put_i64(timestamp);
+        buf.put_u32(global_xmin);
+        buf.put_u32(global_xmin_epoch);
+        buf.put_u32(catalog_xmin);
+        buf.put_u32(catalog_xmin_epoch);
+
+        this.stream.send(buf.freeze()).await
+    }
+}
+
+impl Stream for ReplicationStream {
+    type Item = Result<ReplicationMessage<Bytes>, Error>;
+
+    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+        let this = self.project();
+
+        match ready!(this.stream.poll_next(cx)) {
+            Some(Ok(buf)) => {
+                Poll::Ready(Some(ReplicationMessage::parse(&buf).map_err(Error::parse)))
+            }
+            Some(Err(err)) => Poll::Ready(Some(Err(err))),
+            None => Poll::Ready(None),
+        }
+    }
+}
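+
+// Usage sketch for status updates (names assumed; mirrors tests/test/replication.rs):
+//
+//     stream.as_mut()
+//         .standby_status_update(lsn, lsn, lsn, SystemTime::now(), 1)
+//         .await?;
+//
+// A `reply` byte of 1 asks the server to respond promptly with a keepalive.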
+
+pin_project! {
+    /// A type which deserializes the postgres logical replication protocol. This type gives access
+    /// to a high level representation of the changes in transaction commit order.
+    ///
+    /// The replication *must* be explicitly completed via the `finish` method.
+    pub struct LogicalReplicationStream {
+        #[pin]
+        stream: ReplicationStream,
+    }
+}
+
+impl LogicalReplicationStream {
+    /// Creates a new LogicalReplicationStream that will wrap the underlying CopyBoth stream
+    pub fn new(stream: CopyBothDuplex<Bytes>) -> Self {
+        Self {
+            stream: ReplicationStream::new(stream),
+        }
+    }
+
+    /// Send standby update to server.
+    pub async fn standby_status_update(
+        self: Pin<&mut Self>,
+        write_lsn: PgLsn,
+        flush_lsn: PgLsn,
+        apply_lsn: PgLsn,
+        timestamp: SystemTime,
+        reply: u8,
+    ) -> Result<(), Error> {
+        let this = self.project();
+        this.stream
+            .standby_status_update(write_lsn, flush_lsn, apply_lsn, timestamp, reply)
+            .await
+    }
+
+    /// Send hot standby feedback message to server.
+    pub async fn hot_standby_feedback(
+        self: Pin<&mut Self>,
+        timestamp: SystemTime,
+        global_xmin: u32,
+        global_xmin_epoch: u32,
+        catalog_xmin: u32,
+        catalog_xmin_epoch: u32,
+    ) -> Result<(), Error> {
+        let this = self.project();
+        this.stream
+            .hot_standby_feedback(
+                timestamp,
+                global_xmin,
+                global_xmin_epoch,
+                catalog_xmin,
+                catalog_xmin_epoch,
+            )
+            .await
+    }
+}
+
+impl Stream for LogicalReplicationStream {
+    type Item = Result<ReplicationMessage<LogicalReplicationMessage>, Error>;
+
+    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+        let this = self.project();
+
+        match ready!(this.stream.poll_next(cx)) {
+            Some(Ok(ReplicationMessage::XLogData(body))) => {
+                let body = body
+                    .map_data(|buf| LogicalReplicationMessage::parse(&buf))
+                    .map_err(Error::parse)?;
+                Poll::Ready(Some(Ok(ReplicationMessage::XLogData(body))))
+            }
+            Some(Ok(ReplicationMessage::PrimaryKeepAlive(body))) => {
+                Poll::Ready(Some(Ok(ReplicationMessage::PrimaryKeepAlive(body))))
+            }
+            Some(Ok(_)) => Poll::Ready(Some(Err(Error::unexpected_message()))),
+            Some(Err(err)) => Poll::Ready(Some(Err(err))),
+            None => Poll::Ready(None),
+        }
+    }
+}
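+
+// Usage sketch (names assumed; tests/test/replication.rs shows the full flow):
+//
+//     let copy = client.copy_both_simple::<bytes::Bytes>(&query).await?;
+//     let stream = LogicalReplicationStream::new(copy);
+//     tokio::pin!(stream);
+//     while let Some(msg) = stream.next().await {
+//         // msg is XLogData(LogicalReplicationMessage) or PrimaryKeepAlive
+//     }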
diff --git a/tokio-postgres/src/row.rs b/tokio-postgres/src/row.rs
index db179b432..bbedf0bda 100644
--- a/tokio-postgres/src/row.rs
+++ b/tokio-postgres/src/row.rs
@@ -7,6 +7,7 @@ use crate::types::{FromSql, Type, WrongType};
 use crate::{Error, Statement};
 use fallible_iterator::FallibleIterator;
 use postgres_protocol::message::backend::DataRowBody;
+use postgres_types::{Format, WrongFormat};
 use std::fmt;
 use std::ops::Range;
 use std::str;
@@ -97,6 +98,7 @@ where
 /// A row of data returned from the database by a query.
 pub struct Row {
     statement: Statement,
+    output_format: Format,
     body: DataRowBody,
     ranges: Vec<Option<Range<usize>>>,
 }
@@ -110,12 +112,17 @@ impl fmt::Debug for Row {
 }
 
 impl Row {
-    pub(crate) fn new(statement: Statement, body: DataRowBody) -> Result<Row, Error> {
+    pub(crate) fn new(
+        statement: Statement,
+        body: DataRowBody,
+        output_format: Format,
+    ) -> Result<Row, Error> {
         let ranges = body.ranges().collect().map_err(Error::parse)?;
         Ok(Row {
             statement,
             body,
             ranges,
+            output_format,
         })
     }
 
@@ -184,9 +191,30 @@ impl Row {
     /// Get the raw bytes for the column at the given index.
     fn col_buffer(&self, idx: usize) -> Option<&[u8]> {
-        let range = self.ranges[idx].to_owned()?;
+        let range = self.ranges.get(idx)?.to_owned()?;
         Some(&self.body.buffer()[range])
     }
+
+    /// Interpret the column at the given index as text.
+    ///
+    /// Useful when using `query_raw_txt()`, which selects the text transfer mode.
+    pub fn as_text(&self, idx: usize) -> Result<Option<&str>, Error> {
+        if self.output_format == Format::Text {
+            match self.col_buffer(idx) {
+                Some(raw) => {
+                    FromSql::from_sql(&Type::TEXT, raw).map_err(|e| Error::from_sql(e, idx))
+                }
+                None => Ok(None),
+            }
+        } else {
+            Err(Error::from_sql(Box::new(WrongFormat {}), idx))
+        }
+    }
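+
+    // Example (sketch): after `query_raw_txt("SELECT 42", ...)` the row arrives in
+    // text format, so `row.as_text(0)?` yields `Some("42")`; calling it on a
+    // binary-format row returns an error wrapping `WrongFormat`.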
+
+    /// Row byte size
+    pub fn body_len(&self) -> usize {
+        self.body.buffer().len()
+    }
 }
 
 impl AsName for SimpleColumn {
diff --git a/tokio-postgres/src/simple_query.rs b/tokio-postgres/src/simple_query.rs
index 7c266e409..28e7010eb 100644
--- a/tokio-postgres/src/simple_query.rs
+++ b/tokio-postgres/src/simple_query.rs
@@ -1,7 +1,7 @@
 use crate::client::{InnerClient, Responses};
 use crate::codec::FrontendMessage;
 use crate::connection::RequestMessages;
-use crate::{Error, SimpleQueryMessage, SimpleQueryRow};
+use crate::{Error, ReadyForQueryStatus, SimpleQueryMessage, SimpleQueryRow};
 use bytes::Bytes;
 use fallible_iterator::FallibleIterator;
 use futures_util::{ready, Stream};
@@ -40,11 +40,15 @@ pub async fn simple_query(client: &InnerClient, query: &str) -> Result<SimpleQueryStream, Error>
     Ok(SimpleQueryStream {
         responses,
         columns: None,
+        status: ReadyForQueryStatus::Unknown,
         _p: PhantomPinned,
     })
 }
 
-pub async fn batch_execute(client: &InnerClient, query: &str) -> Result<(), Error> {
+pub async fn batch_execute(
+    client: &InnerClient,
+    query: &str,
+) -> Result<ReadyForQueryStatus, Error> {
     debug!("executing statement batch: {}", query);
 
     let buf = encode(client, query)?;
@@ -52,7 +56,7 @@ pub async fn batch_execute(client: &InnerClient, query: &str) -> Result<(), Erro
 
     loop {
         match responses.next().await? {
-            Message::ReadyForQuery(_) => return Ok(()),
+            Message::ReadyForQuery(status) => return Ok(status.into()),
             Message::CommandComplete(_)
             | Message::EmptyQueryResponse
            | Message::RowDescription(_)
@@ -62,7 +66,7 @@ pub async fn batch_execute(client: &InnerClient, query: &str) -> Result<(), Erro
     }
 }
 
-fn encode(client: &InnerClient, query: &str) -> Result<Bytes, Error> {
+pub(crate) fn encode(client: &InnerClient, query: &str) -> Result<Bytes, Error> {
     client.with_buf(|buf| {
         frontend::query(query, buf).map_err(Error::encode)?;
         Ok(buf.split().freeze())
@@ -74,11 +78,21 @@ pin_project! {
     pub struct SimpleQueryStream {
         responses: Responses,
         columns: Option<Arc<[SimpleColumn]>>,
+        status: ReadyForQueryStatus,
         #[pin]
         _p: PhantomPinned,
     }
 }
 
+impl SimpleQueryStream {
+    /// Returns whether the connection is ready for querying, along with the status of the connection.
+    ///
+    /// This might be available only after the stream has been exhausted.
+    pub fn ready_status(&self) -> ReadyForQueryStatus {
+        self.status
+    }
+}
+
 impl Stream for SimpleQueryStream {
     type Item = Result<SimpleQueryMessage, Error>;
 
@@ -117,7 +131,10 @@ impl Stream for SimpleQueryStream {
                 };
                 return Poll::Ready(Some(Ok(SimpleQueryMessage::Row(row))));
             }
-            Message::ReadyForQuery(_) => return Poll::Ready(None),
+            Message::ReadyForQuery(s) => {
+                *this.status = s.into();
+                return Poll::Ready(None);
+            }
             _ => return Poll::Ready(Some(Err(Error::unexpected_message()))),
         }
     }
diff --git a/tokio-postgres/src/statement.rs b/tokio-postgres/src/statement.rs
index 97561a8e4..b2f4c87cb 100644
--- a/tokio-postgres/src/statement.rs
+++ b/tokio-postgres/src/statement.rs
@@ -2,7 +2,10 @@ use crate::client::InnerClient;
 use crate::codec::FrontendMessage;
 use crate::connection::RequestMessages;
 use crate::types::Type;
-use postgres_protocol::message::frontend;
+use postgres_protocol::{
+    message::{backend::Field, frontend},
+    Oid,
+};
 use std::{
     fmt,
     sync::{Arc, Weak},
@@ -49,6 +52,15 @@ impl Statement {
         }))
     }
 
+    pub(crate) fn new_anonymous(params: Vec<Type>, columns: Vec<Column>) -> Statement {
+        Statement(Arc::new(StatementInner {
+            client: Weak::new(),
+            name: String::new(),
+            params,
+            columns,
+        }))
+    }
+
     pub(crate) fn name(&self) -> &str {
         &self.0.name
     }
@@ -68,11 +80,30 @@ impl Statement {
 pub struct Column {
     name: String,
     type_: Type,
+
+    // raw fields from RowDescription
+    table_oid: Oid,
+    column_id: i16,
+    format: i16,
+
+    // these would be better stored in self.type_, but that is a more radical refactoring
+    type_oid: Oid,
+    type_size: i16,
+    type_modifier: i32,
 }
 
 impl Column {
-    pub(crate) fn new(name: String, type_: Type) -> Column {
-        Column { name, type_ }
+    pub(crate) fn new(name: String, type_: Type, raw_field: Field<'_>) -> Column {
+        Column {
+            name,
+            type_,
+            table_oid: raw_field.table_oid(),
+            column_id: raw_field.column_id(),
+            format: raw_field.format(),
+            type_oid: raw_field.type_oid(),
+            type_size: raw_field.type_size(),
+            type_modifier: raw_field.type_modifier(),
+        }
+    }
 
     /// Returns the name of the column.
@@ -84,6 +115,36 @@ impl Column {
     pub fn type_(&self) -> &Type {
         &self.type_
     }
+
+    /// Returns the table OID of the column.
+    pub fn table_oid(&self) -> Oid {
+        self.table_oid
+    }
+
+    /// Returns the column ID of the column.
+    pub fn column_id(&self) -> i16 {
+        self.column_id
+    }
+
+    /// Returns the format of the column.
+    pub fn format(&self) -> i16 {
+        self.format
+    }
+
+    /// Returns the type OID of the column.
+    pub fn type_oid(&self) -> Oid {
+        self.type_oid
+    }
+
+    /// Returns the type size of the column.
+    pub fn type_size(&self) -> i16 {
+        self.type_size
+    }
+
+    /// Returns the type modifier of the column.
+    pub fn type_modifier(&self) -> i32 {
+        self.type_modifier
+    }
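+
+    // Note: per the protocol's RowDescription message, `format()` is 0 for text and
+    // 1 for binary; `query_raw_txt` requests text results, so its columns report 0
+    // (see the `column_extras` test).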
 }
 
 impl fmt::Debug for Column {
diff --git a/tokio-postgres/src/transaction.rs b/tokio-postgres/src/transaction.rs
index 96a324652..4e071172a 100644
--- a/tokio-postgres/src/transaction.rs
+++ b/tokio-postgres/src/transaction.rs
@@ -9,8 +9,8 @@ use crate::types::{BorrowToSql, ToSql, Type};
 #[cfg(feature = "runtime")]
 use crate::Socket;
 use crate::{
-    bind, query, slice_iter, CancelToken, Client, CopyInSink, Error, Portal, Row,
-    SimpleQueryMessage, Statement, ToStatement,
+    bind, query, slice_iter, CancelToken, Client, CopyInSink, Error, Portal, ReadyForQueryStatus,
+    Row, SimpleQueryMessage, Statement, ToStatement,
 };
 use bytes::Buf;
 use futures_util::TryStreamExt;
@@ -65,7 +65,7 @@ impl<'a> Transaction<'a> {
     }
 
     /// Consumes the transaction, committing all changes made within it.
-    pub async fn commit(mut self) -> Result<(), Error> {
+    pub async fn commit(mut self) -> Result<ReadyForQueryStatus, Error> {
         self.done = true;
         let query = if let Some(sp) = self.savepoint.as_ref() {
             format!("RELEASE {}", sp.name)
@@ -78,7 +78,7 @@ impl<'a> Transaction<'a> {
     /// Rolls the transaction back, discarding all changes made within it.
     ///
     /// This is equivalent to `Transaction`'s `Drop` implementation, but provides any error encountered to the caller.
-    pub async fn rollback(mut self) -> Result<(), Error> {
+    pub async fn rollback(mut self) -> Result<ReadyForQueryStatus, Error> {
         self.done = true;
         let query = if let Some(sp) = self.savepoint.as_ref() {
             format!("ROLLBACK TO {}", sp.name)
@@ -149,6 +149,16 @@ impl<'a> Transaction<'a> {
         self.client.query_raw(statement, params).await
     }
 
+    /// Like `Client::query_raw_txt`.
+    pub async fn query_raw_txt<S, I>(&self, statement: &str, params: I) -> Result<RowStream, Error>
+    where
+        S: AsRef<str>,
+        I: IntoIterator<Item = Option<S>>,
+        I::IntoIter: ExactSizeIterator,
+    {
+        self.client.query_raw_txt(statement, params).await
+    }
+
     /// Like `Client::execute`.
     pub async fn execute(
         &self,
@@ -250,7 +260,7 @@ impl<'a> Transaction<'a> {
     }
 
     /// Like `Client::batch_execute`.
-    pub async fn batch_execute(&self, query: &str) -> Result<(), Error> {
+    pub async fn batch_execute(&self, query: &str) -> Result<ReadyForQueryStatus, Error> {
         self.client.batch_execute(query).await
     }
 
diff --git a/tokio-postgres/tests/test/main.rs b/tokio-postgres/tests/test/main.rs
index 0ab4a7bab..e21e2379c 100644
--- a/tokio-postgres/tests/test/main.rs
+++ b/tokio-postgres/tests/test/main.rs
@@ -16,12 +16,15 @@ use tokio_postgres::error::SqlState;
 use tokio_postgres::tls::{NoTls, NoTlsStream};
 use tokio_postgres::types::{Kind, Type};
 use tokio_postgres::{
-    AsyncMessage, Client, Config, Connection, Error, IsolationLevel, SimpleQueryMessage,
+    AsyncMessage, Client, Config, Connection, Error, IsolationLevel, ReadyForQueryStatus,
+    SimpleQueryMessage,
 };
 
 mod binary_copy;
 mod parse;
 #[cfg(feature = "runtime")]
+mod replication;
+#[cfg(feature = "runtime")]
 mod runtime;
 mod types;
 
@@ -249,6 +252,161 @@ async fn custom_array() {
     }
 }
 
+#[tokio::test]
+async fn query_raw_txt() {
+    let client = connect("user=postgres").await;
+
+    let rows: Vec<Row> = client
+        .query_raw_txt("SELECT 55 * $1", [Some("42")])
+        .await
+        .unwrap()
+        .try_collect()
+        .await
+        .unwrap();
+
+    assert_eq!(rows.len(), 1);
+    let res: i32 = rows[0].as_text(0).unwrap().unwrap().parse::<i32>().unwrap();
+    assert_eq!(res, 55 * 42);
+
+    let rows: Vec<Row> = client
+        .query_raw_txt("SELECT $1", [Some("42")])
+        .await
+        .unwrap()
+        .try_collect()
+        .await
+        .unwrap();
+
+    assert_eq!(rows.len(), 1);
+    assert_eq!(rows[0].get::<_, &str>(0), "42");
+    assert!(rows[0].body_len() > 0);
+}
+
+#[tokio::test]
+async fn query_raw_txt_nulls() {
+    let client = connect("user=postgres").await;
+
+    let rows: Vec<Row> = client
+        .query_raw_txt(
+            "SELECT $1 as str, $2 as n, 'null' as str2, null as n2",
+            [Some("null"), None],
+        )
+        .await
+        .unwrap()
+        .try_collect()
+        .await
+        .unwrap();
+
+    assert_eq!(rows.len(), 1);
+
+    let res = rows[0].as_text(0).unwrap();
+    assert_eq!(res, Some("null"));
+
+    let res = rows[0].as_text(1).unwrap();
+    assert_eq!(res, None);
+
+    let res = rows[0].as_text(2).unwrap();
+    assert_eq!(res, Some("null"));
+
+    let res = rows[0].as_text(3).unwrap();
+    assert_eq!(res, None);
+}
+
+#[tokio::test]
+async fn limit_max_backend_message_size() {
+    let client = connect("user=postgres max_backend_message_size=10000").await;
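+    // max_backend_message_size caps how large a single backend message may be;
+    // the 20-byte result below fits, while the ~2 MB one should make the query fail.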
+    let small: Vec<Row> = client
+        .query_raw_txt("SELECT REPEAT('a', 20)", [] as [Option<&str>; 0])
+        .await
+        .unwrap()
+        .try_collect()
+        .await
+        .unwrap();
+
+    assert_eq!(small.len(), 1);
+    assert_eq!(small[0].as_text(0).unwrap().unwrap().len(), 20);
+
+    let large: Result<Vec<Row>, Error> = client
+        .query_raw_txt("SELECT REPEAT('a', 2000000)", [] as [Option<&str>; 0])
+        .await
+        .unwrap()
+        .try_collect()
+        .await;
+
+    assert!(large.is_err());
+}
+
+#[tokio::test]
+async fn command_tag() {
+    let client = connect("user=postgres").await;
+
+    let row_stream = client
+        .query_raw_txt("select unnest('{1,2,3}'::int[]);", [] as [Option<&str>; 0])
+        .await
+        .unwrap();
+
+    pin_mut!(row_stream);
+
+    let mut rows: Vec<Row> = Vec::new();
+    while let Some(row) = row_stream.next().await {
+        rows.push(row.unwrap());
+    }
+
+    assert_eq!(row_stream.command_tag(), Some("SELECT 3".to_string()));
+}
+
+#[tokio::test]
+async fn ready_for_query() {
+    let client = connect("user=postgres").await;
+
+    let row_stream = client
+        .query_raw_txt("START TRANSACTION", [] as [Option<&str>; 0])
+        .await
+        .unwrap();
+
+    pin_mut!(row_stream);
+    while row_stream.next().await.is_none() {}
+
+    assert_eq!(row_stream.ready_status(), ReadyForQueryStatus::Transaction);
+
+    let row_stream = client
+        .query_raw_txt("ROLLBACK", [] as [Option<&str>; 0])
+        .await
+        .unwrap();
+
+    pin_mut!(row_stream);
+    while row_stream.next().await.is_none() {}
+
+    assert_eq!(row_stream.ready_status(), ReadyForQueryStatus::Idle);
+}
+
+#[tokio::test]
+async fn column_extras() {
+    let client = connect("user=postgres").await;
+
+    let rows: Vec<Row> = client
+        .query_raw_txt(
+            "select relacl, relname from pg_class limit 1",
+            [] as [Option<&str>; 0],
+        )
+        .await
+        .unwrap()
+        .try_collect()
+        .await
+        .unwrap();
+
+    let column = rows[0].columns().get(1).unwrap();
+    assert_eq!(column.name(), "relname");
+    assert_eq!(column.type_(), &Type::NAME);
+
+    assert!(column.table_oid() > 0);
+    assert_eq!(column.column_id(), 2);
+    assert_eq!(column.format(), 0);
+
+    assert_eq!(column.type_oid(), 19);
+    assert_eq!(column.type_size(), 64);
+    assert_eq!(column.type_modifier(), -1);
+}
+
 #[tokio::test]
 async fn custom_composite() {
     let client = connect("user=postgres").await;
diff --git a/tokio-postgres/tests/test/replication.rs b/tokio-postgres/tests/test/replication.rs
new file mode 100644
index 000000000..c176a4104
--- /dev/null
+++ b/tokio-postgres/tests/test/replication.rs
@@ -0,0 +1,146 @@
+use futures_util::StreamExt;
+use std::time::SystemTime;
+
+use postgres_protocol::message::backend::LogicalReplicationMessage::{Begin, Commit, Insert};
+use postgres_protocol::message::backend::ReplicationMessage::*;
+use postgres_protocol::message::backend::TupleData;
+use postgres_types::PgLsn;
+use tokio_postgres::replication::LogicalReplicationStream;
+use tokio_postgres::NoTls;
+use tokio_postgres::SimpleQueryMessage::Row;
+
+#[tokio::test]
+async fn test_replication() {
+    // open a plain SQL connection
+    let conninfo = "host=127.0.0.1 port=5433 user=postgres replication=database";
+    let (client, connection) = tokio_postgres::connect(conninfo, NoTls).await.unwrap();
+    tokio::spawn(async move {
+        if let Err(e) = connection.await {
+            eprintln!("connection error: {}", e);
+        }
+    });
+
+    client
+        .simple_query("DROP TABLE IF EXISTS test_logical_replication")
+        .await
+        .unwrap();
+    client
+        .simple_query("CREATE TABLE test_logical_replication(i int)")
+        .await
+        .unwrap();
+    let res = client
+        .simple_query("SELECT 'test_logical_replication'::regclass::oid")
+        .await
+        .unwrap();
+    let rel_id: u32 = if let Row(row) = &res[0] {
+        row.get("oid").unwrap().parse().unwrap()
+    } else {
+        panic!("unexpected query message");
+    };
+
+    client
+        .simple_query("DROP PUBLICATION IF EXISTS test_pub")
+        .await
+        .unwrap();
+    client
+        .simple_query("CREATE PUBLICATION test_pub FOR ALL TABLES")
+        .await
+        .unwrap();
+
+    let slot = "test_logical_slot";
+
+    let query = format!(
+        r#"CREATE_REPLICATION_SLOT {:?} TEMPORARY LOGICAL "pgoutput""#,
+        slot
+    );
+    let slot_query = client.simple_query(&query).await.unwrap();
+    let lsn = if let Row(row) = &slot_query[0] {
+        row.get("consistent_point").unwrap()
+    } else {
+        panic!("unexpected query message");
+    };
+
+    // issue a query that will appear in the slot's stream since it happened after its creation
+    client
+        .simple_query("INSERT INTO test_logical_replication VALUES (42)")
+        .await
+        .unwrap();
+
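+    // pgoutput requires the protocol version and the publication(s) to stream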
.simple_query("DROP PUBLICATION IF EXISTS test_pub") + .await + .unwrap(); + client + .simple_query("CREATE PUBLICATION test_pub FOR ALL TABLES") + .await + .unwrap(); + + let slot = "test_logical_slot"; + + let query = format!( + r#"CREATE_REPLICATION_SLOT {:?} TEMPORARY LOGICAL "pgoutput""#, + slot + ); + let slot_query = client.simple_query(&query).await.unwrap(); + let lsn = if let Row(row) = &slot_query[0] { + row.get("consistent_point").unwrap() + } else { + panic!("unexpeced query message"); + }; + + // issue a query that will appear in the slot's stream since it happened after its creation + client + .simple_query("INSERT INTO test_logical_replication VALUES (42)") + .await + .unwrap(); + + let options = r#"("proto_version" '1', "publication_names" 'test_pub')"#; + let query = format!( + r#"START_REPLICATION SLOT {:?} LOGICAL {} {}"#, + slot, lsn, options + ); + let copy_stream = client + .copy_both_simple::(&query) + .await + .unwrap(); + + let stream = LogicalReplicationStream::new(copy_stream); + tokio::pin!(stream); + + // verify that we can observe the transaction in the replication stream + let begin = loop { + match stream.next().await { + Some(Ok(XLogData(body))) => { + if let Begin(begin) = body.into_data() { + break begin; + } + } + Some(Ok(_)) => (), + Some(Err(_)) => panic!("unexpected replication stream error"), + None => panic!("unexpected replication stream end"), + } + }; + let insert = loop { + match stream.next().await { + Some(Ok(XLogData(body))) => { + if let Insert(insert) = body.into_data() { + break insert; + } + } + Some(Ok(_)) => (), + Some(Err(_)) => panic!("unexpected replication stream error"), + None => panic!("unexpected replication stream end"), + } + }; + + let commit = loop { + match stream.next().await { + Some(Ok(XLogData(body))) => { + if let Commit(commit) = body.into_data() { + break commit; + } + } + Some(Ok(_)) => (), + Some(Err(_)) => panic!("unexpected replication stream error"), + None => panic!("unexpected replication stream end"), + } + }; + + assert_eq!(begin.final_lsn(), commit.commit_lsn()); + assert_eq!(insert.rel_id(), rel_id); + + let tuple_data = insert.tuple().tuple_data(); + assert_eq!(tuple_data.len(), 1); + assert!(matches!(tuple_data[0], TupleData::Text(_))); + if let TupleData::Text(data) = &tuple_data[0] { + assert_eq!(data, &b"42"[..]); + } + + // Send a standby status update and require a keep alive response + let lsn: PgLsn = lsn.parse().unwrap(); + stream + .as_mut() + .standby_status_update(lsn, lsn, lsn, SystemTime::now(), 1) + .await + .unwrap(); + loop { + match stream.next().await { + Some(Ok(PrimaryKeepAlive(_))) => break, + Some(Ok(_)) => (), + Some(Err(e)) => panic!("unexpected replication stream error: {}", e), + None => panic!("unexpected replication stream end"), + } + } +} diff --git a/tokio-postgres/tests/test/types/chrono_04.rs b/tokio-postgres/tests/test/types/chrono_04.rs index a8e9e5afa..2bfa475b9 100644 --- a/tokio-postgres/tests/test/types/chrono_04.rs +++ b/tokio-postgres/tests/test/types/chrono_04.rs @@ -1,4 +1,4 @@ -use chrono_04::{DateTime, NaiveDate, NaiveDateTime, NaiveTime, TimeZone, Utc}; +use chrono_04::{DateTime, NaiveDate, NaiveDateTime, NaiveTime, Utc}; use std::fmt; use tokio_postgres::types::{Date, FromSqlOwned, Timestamp}; use tokio_postgres::Client; @@ -53,10 +53,10 @@ async fn test_with_special_naive_date_time_params() { async fn test_date_time_params() { fn make_check(time: &str) -> (Option>, &str) { ( - Some( - Utc.datetime_from_str(time, "'%Y-%m-%d %H:%M:%S.%f'") - .unwrap(), - ), 
+            Some(DateTime::from_naive_utc_and_offset(
+                NaiveDateTime::parse_from_str(time, "'%Y-%m-%d %H:%M:%S.%f'").unwrap(),
+                Utc,
+            )),
             time,
         )
     }
@@ -76,10 +76,10 @@ async fn test_date_time_params() {
 async fn test_with_special_date_time_params() {
     fn make_check(time: &str) -> (Timestamp<DateTime<Utc>>, &str) {
         (
-            Timestamp::Value(
-                Utc.datetime_from_str(time, "'%Y-%m-%d %H:%M:%S.%f'")
-                    .unwrap(),
-            ),
+            Timestamp::Value(DateTime::from_naive_utc_and_offset(
+                NaiveDateTime::parse_from_str(time, "'%Y-%m-%d %H:%M:%S.%f'").unwrap(),
+                Utc,
+            )),
             time,
         )
     }
diff --git a/tokio-postgres/tests/test/types/eui48_04.rs b/tokio-postgres/tests/test/types/eui48_04.rs
deleted file mode 100644
index 074faa37e..000000000
--- a/tokio-postgres/tests/test/types/eui48_04.rs
+++ /dev/null
@@ -1,18 +0,0 @@
-use eui48_04::MacAddress;
-
-use crate::types::test_type;
-
-#[tokio::test]
-async fn test_eui48_params() {
-    test_type(
-        "MACADDR",
-        &[
-            (
-                Some(MacAddress::parse_str("12-34-56-AB-CD-EF").unwrap()),
-                "'12-34-56-ab-cd-ef'",
-            ),
-            (None, "NULL"),
-        ],
-    )
-    .await
-}
diff --git a/tokio-postgres/tests/test/types/mod.rs b/tokio-postgres/tests/test/types/mod.rs
index 452d149fe..166a43c7d 100644
--- a/tokio-postgres/tests/test/types/mod.rs
+++ b/tokio-postgres/tests/test/types/mod.rs
@@ -17,8 +17,6 @@ use bytes::BytesMut;
 mod bit_vec_06;
 #[cfg(feature = "with-chrono-0_4")]
 mod chrono_04;
-#[cfg(feature = "with-eui48-0_4")]
-mod eui48_04;
 #[cfg(feature = "with-eui48-1")]
 mod eui48_1;
 #[cfg(feature = "with-geo-types-0_6")]