diff --git a/Cargo.lock b/Cargo.lock index 5d12b966..12c5c1ca 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5361,6 +5361,7 @@ dependencies = [ "bincode 1.3.3", "hyper", "pretty_assertions", + "rstest", "serde", "thiserror", "tokio", diff --git a/crates/mempool_infra/Cargo.toml b/crates/mempool_infra/Cargo.toml index 3f4e30d7..c65396a3 100644 --- a/crates/mempool_infra/Cargo.toml +++ b/crates/mempool_infra/Cargo.toml @@ -15,6 +15,7 @@ workspace = true async-trait.workspace = true bincode.workspace = true hyper.workspace = true +rstest.workspace = true serde.workspace = true thiserror.workspace = true tokio.workspace = true diff --git a/crates/mempool_infra/tests/component_server_client_http_test.rs b/crates/mempool_infra/tests/component_server_client_http_test.rs index 665566e2..6d24144c 100644 --- a/crates/mempool_infra/tests/component_server_client_http_test.rs +++ b/crates/mempool_infra/tests/component_server_client_http_test.rs @@ -10,6 +10,7 @@ use common::{ }; use hyper::service::{make_service_fn, service_fn}; use hyper::{Body, Request, Response, Server, StatusCode}; +use rstest::rstest; use serde::Serialize; use starknet_mempool_infra::component_client::ComponentClientHttp; use starknet_mempool_infra::component_definitions::{ComponentRequestHandler, ServerError}; @@ -27,6 +28,7 @@ const B_PORT: u16 = 10001; const UNCONNECTED_SERVER_PORT: u16 = 10002; const FAULTY_SERVER_REQ_DESER_PORT: u16 = 10003; const FAULTY_SERVER_RES_DESER_PORT: u16 = 10004; +const MOCK_SERVER_ERROR: &str = "mock server error"; #[async_trait] impl ComponentAClientTrait for ComponentClientHttp { @@ -83,7 +85,7 @@ fn assert_error_contains_keywords(error: String, expected_error_contained_keywor } } -async fn spawn_faulty_server(ip: IpAddr, port: u16, body: T) +async fn create_client_and_faulty_server(port: u16, body: T) -> ComponentAClient where T: Serialize + Send + Sync + 'static + Clone, { @@ -98,7 +100,7 @@ where .unwrap()) } - let socket = SocketAddr::new(ip, port); + let socket = SocketAddr::new(LOCAL_IP, port); let make_svc = make_service_fn(|_conn| { let body = body.clone(); async move { Ok::<_, hyper::Error>(service_fn(move |req| handler(req, body.clone()))) } @@ -108,6 +110,8 @@ where // Ensure the server starts running. task::yield_now().await; + + ComponentAClient::new(LOCAL_IP, port) } #[tokio::test] @@ -151,28 +155,30 @@ async fn test_unconnected_server() { let client = ComponentAClient::new(LOCAL_IP, UNCONNECTED_SERVER_PORT); let expected_error_contained_keywords = vec!["Connection refused"]; - verify_error(client.clone(), expected_error_contained_keywords).await; -} - -#[tokio::test] -async fn test_faulty_server_request_deseralization_failure() { - let mock_server_error = "Mock server error"; - let returned_server_error = - ServerError::RequestDeserializationFailure(mock_server_error.to_string()); - - spawn_faulty_server(LOCAL_IP, FAULTY_SERVER_REQ_DESER_PORT, returned_server_error).await; - let client = ComponentAClient::new(LOCAL_IP, FAULTY_SERVER_REQ_DESER_PORT); - - let expected_error_contained_keywords = - vec![StatusCode::BAD_REQUEST.as_str(), mock_server_error]; verify_error(client, expected_error_contained_keywords).await; } +#[rstest] +#[case::request_deserialization_failure( + create_client_and_faulty_server( + FAULTY_SERVER_REQ_DESER_PORT, + ServerError::RequestDeserializationFailure(MOCK_SERVER_ERROR.to_string()) + ).await, + vec![ + StatusCode::BAD_REQUEST.as_str(), + "Could not deserialize client request", + MOCK_SERVER_ERROR + ], +)] +#[case::response_deserialization_failure( + create_client_and_faulty_server(FAULTY_SERVER_RES_DESER_PORT, "arbitrary data").await, + vec!["Could not deserialize server response"], + +)] #[tokio::test] -async fn test_faulty_server_response_deseralization_failure() { - spawn_faulty_server(LOCAL_IP, FAULTY_SERVER_RES_DESER_PORT, "arbitrary data").await; - let client = ComponentAClient::new(LOCAL_IP, FAULTY_SERVER_RES_DESER_PORT); - - let expected_error_contained_keywords = vec!["Could not deserialize server response"]; +async fn test_faulty_server( + #[case] client: ComponentAClient, + #[case] expected_error_contained_keywords: Vec<&str>, +) { verify_error(client, expected_error_contained_keywords).await; }