From 26074b1c378fb600eaeaa3c6c3e452f3a959c77f Mon Sep 17 00:00:00 2001 From: Yair <92672946+yair-starkware@users.noreply.github.com> Date: Sun, 14 Jul 2024 16:14:23 +0300 Subject: [PATCH] test(gateway): test RpcStateReader (#449) --- Cargo.lock | 26 +++ Cargo.toml | 1 + crates/gateway/Cargo.toml | 1 + crates/gateway/src/lib.rs | 2 + crates/gateway/src/rpc_objects.rs | 10 +- crates/gateway/src/rpc_state_reader_test.rs | 205 ++++++++++++++++++++ 6 files changed, 240 insertions(+), 5 deletions(-) create mode 100644 crates/gateway/src/rpc_state_reader_test.rs diff --git a/Cargo.lock b/Cargo.lock index 728b4205a..7b936828d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3478,6 +3478,25 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "mockito" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d2f6e023aa5bdf392aa06c78e4a4e6d498baab5138d0c993503350ebbc37bf1e" +dependencies = [ + "assert-json-diff", + "colored", + "futures-core", + "hyper", + "log", + "rand", + "regex", + "serde_json", + "serde_urlencoded", + "similar", + "tokio", +] + [[package]] name = "native-tls" version = "0.2.12" @@ -5089,6 +5108,12 @@ dependencies = [ "rand_core", ] +[[package]] +name = "similar" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fa42c91313f1d05da9b26f267f931cf178d4aba455b4c4622dd7355eb80c6640" + [[package]] name = "simple_asn1" version = "0.6.2" @@ -5363,6 +5388,7 @@ dependencies = [ "cairo-vm", "hyper", "mempool_test_utils", + "mockito", "num-bigint", "num-traits 0.2.19", "papyrus_config", diff --git a/Cargo.toml b/Cargo.toml index e28380da9..6c0a7f633 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -49,6 +49,7 @@ hyper = { version = "0.14", features = ["client", "server", "http1", "http2", "t indexmap = "2.1.0" itertools = "0.13.0" lazy_static = "1.4.0" +mockito = "1.4.0" num-traits = "0.2" num-bigint = { version = "0.4.5", default-features = false } # TODO(YaelD, 28/5/2024): The special Papyrus version is needed in order to be aligned with the diff --git a/crates/gateway/Cargo.toml b/crates/gateway/Cargo.toml index a202e8427..dd3392146 100644 --- a/crates/gateway/Cargo.toml +++ b/crates/gateway/Cargo.toml @@ -37,6 +37,7 @@ validator.workspace = true [dev-dependencies] assert_matches.workspace = true +mockito.workspace = true num-bigint.workspace = true pretty_assertions.workspace = true rstest.workspace = true diff --git a/crates/gateway/src/lib.rs b/crates/gateway/src/lib.rs index ec8052af8..bb9ff7c55 100644 --- a/crates/gateway/src/lib.rs +++ b/crates/gateway/src/lib.rs @@ -6,6 +6,8 @@ pub mod errors; pub mod gateway; mod rpc_objects; mod rpc_state_reader; +#[cfg(test)] +mod rpc_state_reader_test; mod state_reader; #[cfg(test)] mod state_reader_test_utils; diff --git a/crates/gateway/src/rpc_objects.rs b/crates/gateway/src/rpc_objects.rs index f6295ab21..15adf23c1 100644 --- a/crates/gateway/src/rpc_objects.rs +++ b/crates/gateway/src/rpc_objects.rs @@ -58,13 +58,13 @@ pub struct GetBlockWithTxHashesParams { pub block_id: BlockId, } -#[derive(Debug, Deserialize, Serialize)] +#[derive(Debug, Default, Deserialize, Serialize)] pub struct ResourcePrice { pub price_in_wei: GasPrice, pub price_in_fri: GasPrice, } -#[derive(Debug, Deserialize, Serialize)] +#[derive(Debug, Default, Deserialize, Serialize)] pub struct BlockHeader { pub block_hash: BlockHash, pub parent_hash: BlockHash, @@ -107,20 +107,20 @@ pub enum RpcResponse { Error(RpcErrorResponse), } -#[derive(Serialize, Deserialize, Debug)] +#[derive(Serialize, Deserialize, Debug, Default)] pub struct RpcSuccessResponse { pub jsonrpc: Option, pub result: Value, pub id: u32, } -#[derive(Serialize, Deserialize, Debug)] +#[derive(Serialize, Deserialize, Debug, Default)] pub struct RpcErrorResponse { pub jsonrpc: Option, pub error: RpcSpecError, } -#[derive(Serialize, Deserialize, Debug)] +#[derive(Serialize, Deserialize, Debug, Default)] pub struct RpcSpecError { pub code: u16, pub message: String, diff --git a/crates/gateway/src/rpc_state_reader_test.rs b/crates/gateway/src/rpc_state_reader_test.rs new file mode 100644 index 000000000..3596d484d --- /dev/null +++ b/crates/gateway/src/rpc_state_reader_test.rs @@ -0,0 +1,205 @@ +use blockifier::execution::contract_class::ContractClass; +use blockifier::state::state_api::StateReader; +use cairo_lang_starknet_classes::casm_contract_class::CasmContractClass; +use papyrus_rpc::CompiledContractClass; +use serde::Serialize; +use serde_json::json; +use starknet_api::block::{BlockNumber, GasPrice}; +use starknet_api::core::{ClassHash, ContractAddress, Nonce, PatriciaKey}; +use starknet_api::{class_hash, contract_address, felt, patricia_key}; + +use crate::config::RpcStateReaderConfig; +use crate::rpc_objects::{ + BlockHeader, BlockId, GetBlockWithTxHashesParams, GetClassHashAtParams, + GetCompiledContractClassParams, GetNonceParams, GetStorageAtParams, ResourcePrice, RpcResponse, + RpcSuccessResponse, +}; +use crate::rpc_state_reader::RpcStateReader; +use crate::state_reader::MempoolStateReader; + +async fn run_rpc_server() -> mockito::ServerGuard { + mockito::Server::new_async().await +} + +fn mock_rpc_interaction( + server: &mut mockito::ServerGuard, + json_rpc_version: &str, + method: &str, + params: impl Serialize, + expected_response: &RpcResponse, +) -> mockito::Mock { + let request_body = json!({ + "jsonrpc": json_rpc_version, + "id": 0, + "method": method, + "params": json!(params), + }); + server + .mock("POST", "/") + .match_header("Content-Type", "application/json") + .match_body(mockito::Matcher::Json(request_body)) + .with_status(201) + .with_body(serde_json::to_string(expected_response).unwrap()) + .create() +} + +#[tokio::test] +async fn test_get_block_info() { + let mut server = run_rpc_server().await; + let config = RpcStateReaderConfig { url: server.url(), ..Default::default() }; + + let expected_result = BlockNumber(100); + + let mock = mock_rpc_interaction( + &mut server, + &config.json_rpc_version, + "starknet_getBlockWithTxHashes", + GetBlockWithTxHashesParams { block_id: BlockId::Latest }, + &RpcResponse::Success(RpcSuccessResponse { + result: serde_json::to_value(BlockHeader { + block_number: expected_result, + // GasPrice must be non-zero. + l1_gas_price: ResourcePrice { + price_in_wei: GasPrice(1), + price_in_fri: GasPrice(1), + }, + l1_data_gas_price: ResourcePrice { + price_in_wei: GasPrice(1), + price_in_fri: GasPrice(1), + }, + ..Default::default() + }) + .unwrap(), + ..Default::default() + }), + ); + + let client = RpcStateReader::from_latest(&config); + let result = + tokio::task::spawn_blocking(move || client.get_block_info()).await.unwrap().unwrap(); + // TODO(yair): Add partial_eq for BlockInfo and assert_eq the whole BlockInfo. + assert_eq!(result.block_number, expected_result); + mock.assert_async().await; +} + +#[tokio::test] +async fn test_get_storage_at() { + let mut server = run_rpc_server().await; + let config = RpcStateReaderConfig { url: server.url(), ..Default::default() }; + + let expected_result = felt!("0x999"); + + let mock = mock_rpc_interaction( + &mut server, + &config.json_rpc_version, + "starknet_getStorageAt", + GetStorageAtParams { + block_id: BlockId::Latest, + contract_address: contract_address!("0x1"), + key: starknet_api::state::StorageKey::from(0u32), + }, + &RpcResponse::Success(RpcSuccessResponse { + result: serde_json::to_value(expected_result).unwrap(), + ..Default::default() + }), + ); + + let client = RpcStateReader::from_latest(&config); + let result = tokio::task::spawn_blocking(move || { + client.get_storage_at(contract_address!("0x1"), starknet_api::state::StorageKey::from(0u32)) + }) + .await + .unwrap() + .unwrap(); + assert_eq!(result, expected_result); + mock.assert_async().await; +} + +#[tokio::test] +async fn test_get_nonce_at() { + let mut server = run_rpc_server().await; + let config = RpcStateReaderConfig { url: server.url(), ..Default::default() }; + + let expected_result = Nonce(felt!("0x999")); + + let mock = mock_rpc_interaction( + &mut server, + &config.json_rpc_version, + "starknet_getNonce", + GetNonceParams { block_id: BlockId::Latest, contract_address: contract_address!("0x1") }, + &RpcResponse::Success(RpcSuccessResponse { + result: serde_json::to_value(expected_result).unwrap(), + ..Default::default() + }), + ); + + let client = RpcStateReader::from_latest(&config); + let result = tokio::task::spawn_blocking(move || client.get_nonce_at(contract_address!("0x1"))) + .await + .unwrap() + .unwrap(); + assert_eq!(result, expected_result); + mock.assert_async().await; +} + +#[tokio::test] +async fn test_get_compiled_contract_class() { + let mut server = run_rpc_server().await; + let config = RpcStateReaderConfig { url: server.url(), ..Default::default() }; + + let expected_result = CasmContractClass::default(); + + let mock = mock_rpc_interaction( + &mut server, + &config.json_rpc_version, + "starknet_getCompiledContractClass", + GetCompiledContractClassParams { + block_id: BlockId::Latest, + class_hash: class_hash!("0x1"), + }, + &RpcResponse::Success(RpcSuccessResponse { + result: serde_json::to_value(CompiledContractClass::V1(expected_result)).unwrap(), + ..Default::default() + }), + ); + + let client = RpcStateReader::from_latest(&config); + let result = + tokio::task::spawn_blocking(move || client.get_compiled_contract_class(class_hash!("0x1"))) + .await + .unwrap() + .unwrap(); + assert_eq!(result, ContractClass::V1(CasmContractClass::default().try_into().unwrap())); + mock.assert_async().await; +} + +#[tokio::test] +async fn test_get_class_hash_at() { + let mut server = run_rpc_server().await; + let config = RpcStateReaderConfig { url: server.url(), ..Default::default() }; + + let expected_result = class_hash!("0x999"); + + let mock = mock_rpc_interaction( + &mut server, + &config.json_rpc_version, + "starknet_getClassHashAt", + GetClassHashAtParams { + block_id: BlockId::Latest, + contract_address: contract_address!("0x1"), + }, + &RpcResponse::Success(RpcSuccessResponse { + result: serde_json::to_value(expected_result).unwrap(), + ..Default::default() + }), + ); + + let client = RpcStateReader::from_latest(&config); + let result = + tokio::task::spawn_blocking(move || client.get_class_hash_at(contract_address!("0x1"))) + .await + .unwrap() + .unwrap(); + assert_eq!(result, expected_result); + mock.assert_async().await; +}