From 0c3364f7f0355fad0239235d8e94e0cfd059a71d Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Tue, 19 Mar 2024 14:53:43 +0000 Subject: [PATCH] make password hashing async aware --- postgres-protocol/Cargo.toml | 4 ++++ postgres-protocol/src/authentication/sasl.rs | 17 +++++++++++------ postgres-protocol/src/password/mod.rs | 11 +++++++---- postgres-protocol/src/password/test.rs | 6 +++--- tokio-postgres/src/connect_raw.rs | 1 + 5 files changed, 26 insertions(+), 13 deletions(-) diff --git a/postgres-protocol/Cargo.toml b/postgres-protocol/Cargo.toml index 38ce2048f..06c3582d3 100644 --- a/postgres-protocol/Cargo.toml +++ b/postgres-protocol/Cargo.toml @@ -20,3 +20,7 @@ memchr = "2.0" rand = "0.8" sha2 = "0.10" stringprep = "0.1" +tokio = { version = "1.0", features = ["rt"] } + +[dev-dependencies] +tokio = { version = "1.0", features = ["full"] } diff --git a/postgres-protocol/src/authentication/sasl.rs b/postgres-protocol/src/authentication/sasl.rs index 41d0e41b0..40f3c0317 100644 --- a/postgres-protocol/src/authentication/sasl.rs +++ b/postgres-protocol/src/authentication/sasl.rs @@ -9,6 +9,7 @@ use std::io; use std::iter; use std::mem; use std::str; +use tokio::task::yield_now; const NONCE_LENGTH: usize = 24; @@ -32,7 +33,7 @@ fn normalize(pass: &[u8]) -> Vec { } } -pub(crate) fn hi(str: &[u8], salt: &[u8], i: u32) -> [u8; 32] { +pub(crate) async fn hi(str: &[u8], salt: &[u8], i: u32) -> [u8; 32] { let mut hmac = Hmac::::new_from_slice(str).expect("HMAC is able to accept all key sizes"); hmac.update(salt); @@ -49,6 +50,10 @@ pub(crate) fn hi(str: &[u8], salt: &[u8], i: u32) -> [u8; 32] { for (hi, prev) in hi.iter_mut().zip(prev) { *hi ^= prev; } + // yield every ~1ms + if i % 4096 == 0 { + yield_now().await + } } hi.into() @@ -200,7 +205,7 @@ impl ScramSha256 { /// Updates the state machine with the response from the backend. /// /// This should be called when an `AuthenticationSASLContinue` message is received. - pub fn update(&mut self, message: &[u8]) -> io::Result<()> { + pub async fn update(&mut self, message: &[u8]) -> io::Result<()> { let (client_nonce, password, channel_binding) = match mem::replace(&mut self.state, State::Done) { State::Update { @@ -227,7 +232,7 @@ impl ScramSha256 { Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidInput, e)), }; - let salted_password = hi(&password, &salt, parsed.iteration_count); + let salted_password = hi(&password, &salt, parsed.iteration_count).await; let make_key = |name| { let mut hmac = Hmac::::new_from_slice(&salted_password) @@ -481,8 +486,8 @@ mod test { } // recorded auth exchange from psql - #[test] - fn exchange() { + #[tokio::test] + async fn exchange() { let password = "foobar"; let nonce = "9IZ2O01zb9IgiIZ1WJ/zgpJB"; @@ -502,7 +507,7 @@ mod test { ); assert_eq!(str::from_utf8(scram.message()).unwrap(), client_first); - scram.update(server_first.as_bytes()).unwrap(); + scram.update(server_first.as_bytes()).await.unwrap(); assert_eq!(str::from_utf8(scram.message()).unwrap(), client_final); scram.finish(server_final.as_bytes()).unwrap(); diff --git a/postgres-protocol/src/password/mod.rs b/postgres-protocol/src/password/mod.rs index a60687bbe..e669e80f3 100644 --- a/postgres-protocol/src/password/mod.rs +++ b/postgres-protocol/src/password/mod.rs @@ -24,16 +24,19 @@ const SCRAM_DEFAULT_SALT_LEN: usize = 16; /// /// The client may assume the returned string doesn't contain any /// special characters that would require escaping in an SQL command. -pub fn scram_sha_256(password: &[u8]) -> String { +pub async fn scram_sha_256(password: &[u8]) -> String { let mut salt: [u8; SCRAM_DEFAULT_SALT_LEN] = [0; SCRAM_DEFAULT_SALT_LEN]; let mut rng = rand::thread_rng(); rng.fill_bytes(&mut salt); - scram_sha_256_salt(password, salt) + scram_sha_256_salt(password, salt).await } // Internal implementation of scram_sha_256 with a caller-provided // salt. This is useful for testing. -pub(crate) fn scram_sha_256_salt(password: &[u8], salt: [u8; SCRAM_DEFAULT_SALT_LEN]) -> String { +pub(crate) async fn scram_sha_256_salt( + password: &[u8], + salt: [u8; SCRAM_DEFAULT_SALT_LEN], +) -> String { // Prepare the password, per [RFC // 4013](https://tools.ietf.org/html/rfc4013), if possible. // @@ -58,7 +61,7 @@ pub(crate) fn scram_sha_256_salt(password: &[u8], salt: [u8; SCRAM_DEFAULT_SALT_ }; // salt password - let salted_password = sasl::hi(&prepared, &salt, SCRAM_DEFAULT_ITERATIONS); + let salted_password = sasl::hi(&prepared, &salt, SCRAM_DEFAULT_ITERATIONS).await; // client key let mut hmac = Hmac::::new_from_slice(&salted_password) diff --git a/postgres-protocol/src/password/test.rs b/postgres-protocol/src/password/test.rs index 1432cb204..c9d340f09 100644 --- a/postgres-protocol/src/password/test.rs +++ b/postgres-protocol/src/password/test.rs @@ -1,11 +1,11 @@ use crate::password; -#[test] -fn test_encrypt_scram_sha_256() { +#[tokio::test] +async fn test_encrypt_scram_sha_256() { // Specify the salt to make the test deterministic. Any bytes will do. let salt: [u8; 16] = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]; assert_eq!( - password::scram_sha_256_salt(b"secret", salt), + password::scram_sha_256_salt(b"secret", salt).await, "SCRAM-SHA-256$4096:AQIDBAUGBwgJCgsMDQ4PEA==$8rrDg00OqaiWXJ7p+sCgHEIaBSHY89ZJl3mfIsf32oY=:05L1f+yZbiN8O0AnO40Og85NNRhvzTS57naKRWCcsIA=" ); } diff --git a/tokio-postgres/src/connect_raw.rs b/tokio-postgres/src/connect_raw.rs index 0beead11f..8e788984a 100644 --- a/tokio-postgres/src/connect_raw.rs +++ b/tokio-postgres/src/connect_raw.rs @@ -300,6 +300,7 @@ where scram .update(body.data()) + .await .map_err(|e| Error::authentication(e.into()))?; let mut buf = BytesMut::new();