diff --git a/proxy/src/scram.rs b/proxy/src/scram.rs index b2b9954a0003..f79de0704394 100644 --- a/proxy/src/scram.rs +++ b/proxy/src/scram.rs @@ -114,28 +114,21 @@ mod tests { } fn run_round_trip_test(client_password: &str) { - let secret = ServerSecret::build("pencil").unwrap(); - let mut exchange = Exchange::new(&secret, rand::random, None); - - let mut client = + let scram_secret = ServerSecret::build("pencil").unwrap(); + let sasl_client = ScramSha256::new(client_password.as_bytes(), ChannelBinding::unsupported()); - let client_first = std::str::from_utf8(client.message()).unwrap(); - exchange = match exchange.exchange(client_first).unwrap() { - Step::Continue(exchange, message) => { - client.update(message.as_bytes()).unwrap(); - exchange - } - Step::Success(_, _) => panic!("expected continue, got success"), - Step::Failure(f) => panic!("{f}"), - }; + let outcome = super::exchange( + &scram_secret, + sasl_client, + crate::config::TlsServerEndPoint::Undefined, + ) + .unwrap(); - let client_final = std::str::from_utf8(client.message()).unwrap(); - match exchange.exchange(client_final).unwrap() { - Step::Success(_, message) => client.finish(message.as_bytes()).unwrap(), - Step::Continue(_, _) => panic!("expected success, got continue"), - Step::Failure(f) => panic!("{f}"), - }; + match outcome { + crate::sasl::Outcome::Success(_) => {} + crate::sasl::Outcome::Failure(r) => panic!("{r}"), + } } #[test]