From dbe467e827bb7a02ff8e836cb08c9dd00b56f397 Mon Sep 17 00:00:00 2001 From: Jonathan LEI Date: Mon, 9 Sep 2024 22:11:58 -0700 Subject: [PATCH] feat: jsonrpc batch requests (#653) --- README.md | 16 +- examples/batch.rs | 36 ++ examples/parse_jsonrpc_request.rs | 4 +- starknet-providers/src/any.rs | 19 +- starknet-providers/src/jsonrpc/mod.rs | 372 +++++++++++++----- .../src/jsonrpc/transports/http.rs | 71 +++- .../src/jsonrpc/transports/mod.rs | 13 +- starknet-providers/src/lib.rs | 2 +- starknet-providers/src/provider.rs | 140 ++++++- starknet-providers/src/sequencer/provider.rs | 16 +- starknet-providers/tests/jsonrpc.rs | 45 ++- 11 files changed, 622 insertions(+), 112 deletions(-) create mode 100644 examples/batch.rs diff --git a/README.md b/README.md index ba8e58a1..3371af44 100644 --- a/README.md +++ b/README.md @@ -92,19 +92,21 @@ Examples can be found in the [examples folder](./examples): 6. [Query the latest block number with JSON-RPC](./examples/jsonrpc.rs) -7. [Call a contract view function](./examples/erc20_balance.rs) +7. [Batched JSON-RPC requests](./examples/batch.rs) -8. [Deploy an Argent X account to a pre-funded address](./examples/deploy_argent_account.rs) +8. [Call a contract view function](./examples/erc20_balance.rs) -9. [Inspect public key with Ledger](./examples/ledger_public_key.rs) +9. [Deploy an Argent X account to a pre-funded address](./examples/deploy_argent_account.rs) -10. [Deploy an OpenZeppelin account with Ledger](./examples/deploy_account_with_ledger.rs) +10. [Inspect public key with Ledger](./examples/ledger_public_key.rs) -11. [Transfer ERC20 tokens with Ledger](./examples/transfer_with_ledger.rs) +11. [Deploy an OpenZeppelin account with Ledger](./examples/deploy_account_with_ledger.rs) -12. [Parsing a JSON-RPC request on the server side](./examples/parse_jsonrpc_request.rs) +12. [Transfer ERC20 tokens with Ledger](./examples/transfer_with_ledger.rs) -13. [Inspecting a erased provider-specific error type](./examples/downcast_provider_error.rs) +13. [Parsing a JSON-RPC request on the server side](./examples/parse_jsonrpc_request.rs) + +14. [Inspecting a erased provider-specific error type](./examples/downcast_provider_error.rs) ## License diff --git a/examples/batch.rs b/examples/batch.rs new file mode 100644 index 00000000..5dc1b78a --- /dev/null +++ b/examples/batch.rs @@ -0,0 +1,36 @@ +use starknet::providers::{ + jsonrpc::{HttpTransport, JsonRpcClient}, + Provider, ProviderRequestData, ProviderResponseData, Url, +}; +use starknet_core::types::{ + requests::{BlockNumberRequest, GetBlockTransactionCountRequest}, + BlockId, +}; + +#[tokio::main] +async fn main() { + let provider = JsonRpcClient::new(HttpTransport::new( + Url::parse("https://starknet-sepolia.public.blastapi.io/rpc/v0_7").unwrap(), + )); + + let responses = provider + .batch_requests([ + ProviderRequestData::BlockNumber(BlockNumberRequest), + ProviderRequestData::GetBlockTransactionCount(GetBlockTransactionCountRequest { + block_id: BlockId::Number(100), + }), + ]) + .await + .unwrap(); + + match (&responses[0], &responses[1]) { + ( + ProviderResponseData::BlockNumber(block_number), + ProviderResponseData::GetBlockTransactionCount(count), + ) => { + println!("The latest block is #{}", block_number); + println!("Block #100 has {} transactions", count); + } + _ => panic!("unexpected response type"), + } +} diff --git a/examples/parse_jsonrpc_request.rs b/examples/parse_jsonrpc_request.rs index d325cf2c..6a69dd2a 100644 --- a/examples/parse_jsonrpc_request.rs +++ b/examples/parse_jsonrpc_request.rs @@ -1,4 +1,4 @@ -use starknet_providers::jsonrpc::{JsonRpcRequest, JsonRpcRequestData}; +use starknet_providers::{jsonrpc::JsonRpcRequest, ProviderRequestData}; fn main() { // Let's pretend this is the raw request body coming from HTTP @@ -17,7 +17,7 @@ fn main() { println!("Request received: {:#?}", parsed_request); match parsed_request.data { - JsonRpcRequestData::GetBlockTransactionCount(req) => { + ProviderRequestData::GetBlockTransactionCount(req) => { println!( "starknet_getBlockTransactionCount request received for block: {:?}", req.block_id diff --git a/starknet-providers/src/any.rs b/starknet-providers/src/any.rs index 5a10711f..3bc2b0ad 100644 --- a/starknet-providers/src/any.rs +++ b/starknet-providers/src/any.rs @@ -12,7 +12,7 @@ use starknet_core::types::{ use crate::{ jsonrpc::{HttpTransport, JsonRpcClient}, - Provider, ProviderError, SequencerGatewayProvider, + Provider, ProviderError, ProviderRequestData, ProviderResponseData, SequencerGatewayProvider, }; /// A convenient Box-able type that implements the [Provider] trait. This can be useful when you @@ -665,4 +665,21 @@ impl Provider for AnyProvider { } } } + + async fn batch_requests( + &self, + requests: R, + ) -> Result, ProviderError> + where + R: AsRef<[ProviderRequestData]> + Send + Sync, + { + match self { + Self::JsonRpcHttp(inner) => { + as Provider>::batch_requests(inner, requests).await + } + Self::SequencerGateway(inner) => { + ::batch_requests(inner, requests).await + } + } + } } diff --git a/starknet-providers/src/jsonrpc/mod.rs b/starknet-providers/src/jsonrpc/mod.rs index 2d6013ed..1f91e070 100644 --- a/starknet-providers/src/jsonrpc/mod.rs +++ b/starknet-providers/src/jsonrpc/mod.rs @@ -19,7 +19,9 @@ use starknet_core::{ }, }; -use crate::{provider::ProviderImplError, Provider, ProviderError}; +use crate::{ + provider::ProviderImplError, Provider, ProviderError, ProviderRequestData, ProviderResponseData, +}; mod transports; pub use transports::{HttpTransport, HttpTransportError, JsonRpcTransport}; @@ -131,70 +133,7 @@ pub struct JsonRpcRequest { /// ID of the request. Useful for identifying responses in certain transports like `WebSocket`. pub id: u64, /// Data of the requeest. - pub data: JsonRpcRequestData, -} - -/// Typed request data for Starknet JSON-RPC requests. -#[derive(Debug, Clone)] -pub enum JsonRpcRequestData { - /// Request data for `starknet_specVersion`. - SpecVersion(SpecVersionRequest), - /// Request data for `starknet_getBlockWithTxHashes`. - GetBlockWithTxHashes(GetBlockWithTxHashesRequest), - /// Request data for `starknet_getBlockWithTxs`. - GetBlockWithTxs(GetBlockWithTxsRequest), - /// Request data for `starknet_getBlockWithReceipts`. - GetBlockWithReceipts(GetBlockWithReceiptsRequest), - /// Request data for `starknet_getStateUpdate`. - GetStateUpdate(GetStateUpdateRequest), - /// Request data for `starknet_getStorageAt`. - GetStorageAt(GetStorageAtRequest), - /// Request data for `starknet_getTransactionStatus`. - GetTransactionStatus(GetTransactionStatusRequest), - /// Request data for `starknet_getTransactionByHash`. - GetTransactionByHash(GetTransactionByHashRequest), - /// Request data for `starknet_getTransactionByBlockIdAndIndex`. - GetTransactionByBlockIdAndIndex(GetTransactionByBlockIdAndIndexRequest), - /// Request data for `starknet_getTransactionReceipt`. - GetTransactionReceipt(GetTransactionReceiptRequest), - /// Request data for `starknet_getClass`. - GetClass(GetClassRequest), - /// Request data for `starknet_getClassHashAt`. - GetClassHashAt(GetClassHashAtRequest), - /// Request data for `starknet_getClassAt`. - GetClassAt(GetClassAtRequest), - /// Request data for `starknet_getBlockTransactionCount`. - GetBlockTransactionCount(GetBlockTransactionCountRequest), - /// Request data for `starknet_call`. - Call(CallRequest), - /// Request data for `starknet_estimateFee`. - EstimateFee(EstimateFeeRequest), - /// Request data for `starknet_estimateMessageFee`. - EstimateMessageFee(EstimateMessageFeeRequest), - /// Request data for `starknet_blockNumber`. - BlockNumber(BlockNumberRequest), - /// Request data for `starknet_blockHashAndNumber`. - BlockHashAndNumber(BlockHashAndNumberRequest), - /// Request data for `starknet_chainId`. - ChainId(ChainIdRequest), - /// Request data for `starknet_syncing`. - Syncing(SyncingRequest), - /// Request data for `starknet_getEvents`. - GetEvents(GetEventsRequest), - /// Request data for `starknet_getNonce`. - GetNonce(GetNonceRequest), - /// Request data for `starknet_addInvokeTransaction`. - AddInvokeTransaction(AddInvokeTransactionRequest), - /// Request data for `starknet_addDeclareTransaction`. - AddDeclareTransaction(AddDeclareTransactionRequest), - /// Request data for `starknet_addDeployAccountTransaction`. - AddDeployAccountTransaction(AddDeployAccountTransactionRequest), - /// Request data for `starknet_traceTransaction`. - TraceTransaction(TraceTransactionRequest), - /// Request data for `starknet_simulateTransactions`. - SimulateTransactions(SimulateTransactionsRequest), - /// Request data for `starknet_traceBlockTransactions`. - TraceBlockTransactions(TraceBlockTransactionsRequest), + pub data: ProviderRequestData, } /// Errors from JSON-RPC client. @@ -212,7 +151,7 @@ pub enum JsonRpcClientError { } /// An unsuccessful response returned from the server. -#[derive(Debug, Deserialize)] +#[derive(Debug, Clone, Deserialize)] pub struct JsonRpcError { /// Error code. pub code: i64, @@ -224,7 +163,7 @@ pub struct JsonRpcError { } /// JSON-RPC response returned from a server. -#[derive(Debug, Deserialize)] +#[derive(Debug, Clone, Deserialize)] #[serde(untagged)] pub enum JsonRpcResponse { /// Successful response. @@ -303,6 +242,199 @@ where } } } + + async fn send_requests( + &self, + requests: R, + ) -> Result, ProviderError> + where + R: AsRef<[ProviderRequestData]> + Send + Sync, + { + let mut results = vec![]; + + let responses = self + .transport + .send_requests(requests.as_ref().to_vec()) + .await + .map_err(JsonRpcClientError::TransportError)?; + + for (request, response) in requests.as_ref().iter().zip(responses.into_iter()) { + match response { + JsonRpcResponse::Success { result, .. } => { + let result = match request { + ProviderRequestData::SpecVersion(_) => ProviderResponseData::SpecVersion( + String::deserialize(result) + .map_err(JsonRpcClientError::::JsonError)?, + ), + ProviderRequestData::GetBlockWithTxHashes(_) => { + ProviderResponseData::GetBlockWithTxHashes( + MaybePendingBlockWithTxHashes::deserialize(result) + .map_err(JsonRpcClientError::::JsonError)?, + ) + } + ProviderRequestData::GetBlockWithTxs(_) => { + ProviderResponseData::GetBlockWithTxs( + MaybePendingBlockWithTxs::deserialize(result) + .map_err(JsonRpcClientError::::JsonError)?, + ) + } + ProviderRequestData::GetBlockWithReceipts(_) => { + ProviderResponseData::GetBlockWithReceipts( + MaybePendingBlockWithReceipts::deserialize(result) + .map_err(JsonRpcClientError::::JsonError)?, + ) + } + ProviderRequestData::GetStateUpdate(_) => { + ProviderResponseData::GetStateUpdate( + MaybePendingStateUpdate::deserialize(result) + .map_err(JsonRpcClientError::::JsonError)?, + ) + } + ProviderRequestData::GetStorageAt(_) => ProviderResponseData::GetStorageAt( + Felt::deserialize(result) + .map_err(JsonRpcClientError::::JsonError)? + .0, + ), + ProviderRequestData::GetTransactionStatus(_) => { + ProviderResponseData::GetTransactionStatus( + TransactionStatus::deserialize(result) + .map_err(JsonRpcClientError::::JsonError)?, + ) + } + ProviderRequestData::GetTransactionByHash(_) => { + ProviderResponseData::GetTransactionByHash( + Transaction::deserialize(result) + .map_err(JsonRpcClientError::::JsonError)?, + ) + } + ProviderRequestData::GetTransactionByBlockIdAndIndex(_) => { + ProviderResponseData::GetTransactionByBlockIdAndIndex( + Transaction::deserialize(result) + .map_err(JsonRpcClientError::::JsonError)?, + ) + } + ProviderRequestData::GetTransactionReceipt(_) => { + ProviderResponseData::GetTransactionReceipt( + TransactionReceiptWithBlockInfo::deserialize(result) + .map_err(JsonRpcClientError::::JsonError)?, + ) + } + ProviderRequestData::GetClass(_) => ProviderResponseData::GetClass( + ContractClass::deserialize(result) + .map_err(JsonRpcClientError::::JsonError)?, + ), + ProviderRequestData::GetClassHashAt(_) => { + ProviderResponseData::GetClassHashAt( + Felt::deserialize(result) + .map_err(JsonRpcClientError::::JsonError)? + .0, + ) + } + ProviderRequestData::GetClassAt(_) => ProviderResponseData::GetClassAt( + ContractClass::deserialize(result) + .map_err(JsonRpcClientError::::JsonError)?, + ), + ProviderRequestData::GetBlockTransactionCount(_) => { + ProviderResponseData::GetBlockTransactionCount( + u64::deserialize(result) + .map_err(JsonRpcClientError::::JsonError)?, + ) + } + ProviderRequestData::Call(_) => ProviderResponseData::Call( + FeltArray::deserialize(result) + .map_err(JsonRpcClientError::::JsonError)? + .0, + ), + ProviderRequestData::EstimateFee(_) => ProviderResponseData::EstimateFee( + Vec::::deserialize(result) + .map_err(JsonRpcClientError::::JsonError)?, + ), + ProviderRequestData::EstimateMessageFee(_) => { + ProviderResponseData::EstimateMessageFee( + FeeEstimate::deserialize(result) + .map_err(JsonRpcClientError::::JsonError)?, + ) + } + ProviderRequestData::BlockNumber(_) => ProviderResponseData::BlockNumber( + u64::deserialize(result) + .map_err(JsonRpcClientError::::JsonError)?, + ), + ProviderRequestData::BlockHashAndNumber(_) => { + ProviderResponseData::BlockHashAndNumber( + BlockHashAndNumber::deserialize(result) + .map_err(JsonRpcClientError::::JsonError)?, + ) + } + ProviderRequestData::ChainId(_) => ProviderResponseData::ChainId( + Felt::deserialize(result) + .map_err(JsonRpcClientError::::JsonError)? + .0, + ), + ProviderRequestData::Syncing(_) => ProviderResponseData::Syncing( + SyncStatusType::deserialize(result) + .map_err(JsonRpcClientError::::JsonError)?, + ), + ProviderRequestData::GetEvents(_) => ProviderResponseData::GetEvents( + EventsPage::deserialize(result) + .map_err(JsonRpcClientError::::JsonError)?, + ), + ProviderRequestData::GetNonce(_) => ProviderResponseData::GetNonce( + Felt::deserialize(result) + .map_err(JsonRpcClientError::::JsonError)? + .0, + ), + ProviderRequestData::AddInvokeTransaction(_) => { + ProviderResponseData::AddInvokeTransaction( + InvokeTransactionResult::deserialize(result) + .map_err(JsonRpcClientError::::JsonError)?, + ) + } + ProviderRequestData::AddDeclareTransaction(_) => { + ProviderResponseData::AddDeclareTransaction( + DeclareTransactionResult::deserialize(result) + .map_err(JsonRpcClientError::::JsonError)?, + ) + } + ProviderRequestData::AddDeployAccountTransaction(_) => { + ProviderResponseData::AddDeployAccountTransaction( + DeployAccountTransactionResult::deserialize(result) + .map_err(JsonRpcClientError::::JsonError)?, + ) + } + ProviderRequestData::TraceTransaction(_) => { + ProviderResponseData::TraceTransaction( + TransactionTrace::deserialize(result) + .map_err(JsonRpcClientError::::JsonError)?, + ) + } + ProviderRequestData::SimulateTransactions(_) => { + ProviderResponseData::SimulateTransactions( + Vec::::deserialize(result) + .map_err(JsonRpcClientError::::JsonError)?, + ) + } + ProviderRequestData::TraceBlockTransactions(_) => { + ProviderResponseData::TraceBlockTransactions( + Vec::::deserialize(result) + .map_err(JsonRpcClientError::::JsonError)?, + ) + } + }; + + results.push(result); + } + // TODO: add context on index of request causing the error + JsonRpcResponse::Error { error, .. } => { + return Err(match TryInto::::try_into(&error) { + Ok(error) => ProviderError::StarknetError(error), + Err(_) => JsonRpcClientError::::JsonRpcError(error).into(), + }) + } + } + } + + Ok(results) + } } #[cfg_attr(not(target_arch = "wasm32"), async_trait)] @@ -801,6 +933,54 @@ where ) .await } + + async fn batch_requests( + &self, + requests: R, + ) -> Result, ProviderError> + where + R: AsRef<[ProviderRequestData]> + Send + Sync, + { + self.send_requests(requests).await + } +} + +impl ProviderRequestData { + const fn jsonrpc_method(&self) -> JsonRpcMethod { + match self { + Self::SpecVersion(_) => JsonRpcMethod::SpecVersion, + Self::GetBlockWithTxHashes(_) => JsonRpcMethod::GetBlockWithTxHashes, + Self::GetBlockWithTxs(_) => JsonRpcMethod::GetBlockWithTxs, + Self::GetBlockWithReceipts(_) => JsonRpcMethod::GetBlockWithReceipts, + Self::GetStateUpdate(_) => JsonRpcMethod::GetStateUpdate, + Self::GetStorageAt(_) => JsonRpcMethod::GetStorageAt, + Self::GetTransactionStatus(_) => JsonRpcMethod::GetTransactionStatus, + Self::GetTransactionByHash(_) => JsonRpcMethod::GetTransactionByHash, + Self::GetTransactionByBlockIdAndIndex(_) => { + JsonRpcMethod::GetTransactionByBlockIdAndIndex + } + Self::GetTransactionReceipt(_) => JsonRpcMethod::GetTransactionReceipt, + Self::GetClass(_) => JsonRpcMethod::GetClass, + Self::GetClassHashAt(_) => JsonRpcMethod::GetClassHashAt, + Self::GetClassAt(_) => JsonRpcMethod::GetClassAt, + Self::GetBlockTransactionCount(_) => JsonRpcMethod::GetBlockTransactionCount, + Self::Call(_) => JsonRpcMethod::Call, + Self::EstimateFee(_) => JsonRpcMethod::EstimateFee, + Self::EstimateMessageFee(_) => JsonRpcMethod::EstimateMessageFee, + Self::BlockNumber(_) => JsonRpcMethod::BlockNumber, + Self::BlockHashAndNumber(_) => JsonRpcMethod::BlockHashAndNumber, + Self::ChainId(_) => JsonRpcMethod::ChainId, + Self::Syncing(_) => JsonRpcMethod::Syncing, + Self::GetEvents(_) => JsonRpcMethod::GetEvents, + Self::GetNonce(_) => JsonRpcMethod::GetNonce, + Self::AddInvokeTransaction(_) => JsonRpcMethod::AddInvokeTransaction, + Self::AddDeclareTransaction(_) => JsonRpcMethod::AddDeclareTransaction, + Self::AddDeployAccountTransaction(_) => JsonRpcMethod::AddDeployAccountTransaction, + Self::TraceTransaction(_) => JsonRpcMethod::TraceTransaction, + Self::SimulateTransactions(_) => JsonRpcMethod::SimulateTransactions, + Self::TraceBlockTransactions(_) => JsonRpcMethod::TraceBlockTransactions, + } + } } impl<'de> Deserialize<'de> for JsonRpcRequest { @@ -820,128 +1000,128 @@ impl<'de> Deserialize<'de> for JsonRpcRequest { let raw_request = RawRequest::deserialize(deserializer)?; let request_data = match raw_request.method { - JsonRpcMethod::SpecVersion => JsonRpcRequestData::SpecVersion( + JsonRpcMethod::SpecVersion => ProviderRequestData::SpecVersion( serde_json::from_value::(raw_request.params) .map_err(error_mapper)?, ), - JsonRpcMethod::GetBlockWithTxHashes => JsonRpcRequestData::GetBlockWithTxHashes( + JsonRpcMethod::GetBlockWithTxHashes => ProviderRequestData::GetBlockWithTxHashes( serde_json::from_value::(raw_request.params) .map_err(error_mapper)?, ), - JsonRpcMethod::GetBlockWithTxs => JsonRpcRequestData::GetBlockWithTxs( + JsonRpcMethod::GetBlockWithTxs => ProviderRequestData::GetBlockWithTxs( serde_json::from_value::(raw_request.params) .map_err(error_mapper)?, ), - JsonRpcMethod::GetBlockWithReceipts => JsonRpcRequestData::GetBlockWithReceipts( + JsonRpcMethod::GetBlockWithReceipts => ProviderRequestData::GetBlockWithReceipts( serde_json::from_value::(raw_request.params) .map_err(error_mapper)?, ), - JsonRpcMethod::GetStateUpdate => JsonRpcRequestData::GetStateUpdate( + JsonRpcMethod::GetStateUpdate => ProviderRequestData::GetStateUpdate( serde_json::from_value::(raw_request.params) .map_err(error_mapper)?, ), - JsonRpcMethod::GetStorageAt => JsonRpcRequestData::GetStorageAt( + JsonRpcMethod::GetStorageAt => ProviderRequestData::GetStorageAt( serde_json::from_value::(raw_request.params) .map_err(error_mapper)?, ), - JsonRpcMethod::GetTransactionStatus => JsonRpcRequestData::GetTransactionStatus( + JsonRpcMethod::GetTransactionStatus => ProviderRequestData::GetTransactionStatus( serde_json::from_value::(raw_request.params) .map_err(error_mapper)?, ), - JsonRpcMethod::GetTransactionByHash => JsonRpcRequestData::GetTransactionByHash( + JsonRpcMethod::GetTransactionByHash => ProviderRequestData::GetTransactionByHash( serde_json::from_value::(raw_request.params) .map_err(error_mapper)?, ), JsonRpcMethod::GetTransactionByBlockIdAndIndex => { - JsonRpcRequestData::GetTransactionByBlockIdAndIndex( + ProviderRequestData::GetTransactionByBlockIdAndIndex( serde_json::from_value::( raw_request.params, ) .map_err(error_mapper)?, ) } - JsonRpcMethod::GetTransactionReceipt => JsonRpcRequestData::GetTransactionReceipt( + JsonRpcMethod::GetTransactionReceipt => ProviderRequestData::GetTransactionReceipt( serde_json::from_value::(raw_request.params) .map_err(error_mapper)?, ), - JsonRpcMethod::GetClass => JsonRpcRequestData::GetClass( + JsonRpcMethod::GetClass => ProviderRequestData::GetClass( serde_json::from_value::(raw_request.params) .map_err(error_mapper)?, ), - JsonRpcMethod::GetClassHashAt => JsonRpcRequestData::GetClassHashAt( + JsonRpcMethod::GetClassHashAt => ProviderRequestData::GetClassHashAt( serde_json::from_value::(raw_request.params) .map_err(error_mapper)?, ), - JsonRpcMethod::GetClassAt => JsonRpcRequestData::GetClassAt( + JsonRpcMethod::GetClassAt => ProviderRequestData::GetClassAt( serde_json::from_value::(raw_request.params) .map_err(error_mapper)?, ), JsonRpcMethod::GetBlockTransactionCount => { - JsonRpcRequestData::GetBlockTransactionCount( + ProviderRequestData::GetBlockTransactionCount( serde_json::from_value::(raw_request.params) .map_err(error_mapper)?, ) } - JsonRpcMethod::Call => JsonRpcRequestData::Call( + JsonRpcMethod::Call => ProviderRequestData::Call( serde_json::from_value::(raw_request.params).map_err(error_mapper)?, ), - JsonRpcMethod::EstimateFee => JsonRpcRequestData::EstimateFee( + JsonRpcMethod::EstimateFee => ProviderRequestData::EstimateFee( serde_json::from_value::(raw_request.params) .map_err(error_mapper)?, ), - JsonRpcMethod::EstimateMessageFee => JsonRpcRequestData::EstimateMessageFee( + JsonRpcMethod::EstimateMessageFee => ProviderRequestData::EstimateMessageFee( serde_json::from_value::(raw_request.params) .map_err(error_mapper)?, ), - JsonRpcMethod::BlockNumber => JsonRpcRequestData::BlockNumber( + JsonRpcMethod::BlockNumber => ProviderRequestData::BlockNumber( serde_json::from_value::(raw_request.params) .map_err(error_mapper)?, ), - JsonRpcMethod::BlockHashAndNumber => JsonRpcRequestData::BlockHashAndNumber( + JsonRpcMethod::BlockHashAndNumber => ProviderRequestData::BlockHashAndNumber( serde_json::from_value::(raw_request.params) .map_err(error_mapper)?, ), - JsonRpcMethod::ChainId => JsonRpcRequestData::ChainId( + JsonRpcMethod::ChainId => ProviderRequestData::ChainId( serde_json::from_value::(raw_request.params) .map_err(error_mapper)?, ), - JsonRpcMethod::Syncing => JsonRpcRequestData::Syncing( + JsonRpcMethod::Syncing => ProviderRequestData::Syncing( serde_json::from_value::(raw_request.params) .map_err(error_mapper)?, ), - JsonRpcMethod::GetEvents => JsonRpcRequestData::GetEvents( + JsonRpcMethod::GetEvents => ProviderRequestData::GetEvents( serde_json::from_value::(raw_request.params) .map_err(error_mapper)?, ), - JsonRpcMethod::GetNonce => JsonRpcRequestData::GetNonce( + JsonRpcMethod::GetNonce => ProviderRequestData::GetNonce( serde_json::from_value::(raw_request.params) .map_err(error_mapper)?, ), - JsonRpcMethod::AddInvokeTransaction => JsonRpcRequestData::AddInvokeTransaction( + JsonRpcMethod::AddInvokeTransaction => ProviderRequestData::AddInvokeTransaction( serde_json::from_value::(raw_request.params) .map_err(error_mapper)?, ), - JsonRpcMethod::AddDeclareTransaction => JsonRpcRequestData::AddDeclareTransaction( + JsonRpcMethod::AddDeclareTransaction => ProviderRequestData::AddDeclareTransaction( serde_json::from_value::(raw_request.params) .map_err(error_mapper)?, ), JsonRpcMethod::AddDeployAccountTransaction => { - JsonRpcRequestData::AddDeployAccountTransaction( + ProviderRequestData::AddDeployAccountTransaction( serde_json::from_value::( raw_request.params, ) .map_err(error_mapper)?, ) } - JsonRpcMethod::TraceTransaction => JsonRpcRequestData::TraceTransaction( + JsonRpcMethod::TraceTransaction => ProviderRequestData::TraceTransaction( serde_json::from_value::(raw_request.params) .map_err(error_mapper)?, ), - JsonRpcMethod::SimulateTransactions => JsonRpcRequestData::SimulateTransactions( + JsonRpcMethod::SimulateTransactions => ProviderRequestData::SimulateTransactions( serde_json::from_value::(raw_request.params) .map_err(error_mapper)?, ), - JsonRpcMethod::TraceBlockTransactions => JsonRpcRequestData::TraceBlockTransactions( + JsonRpcMethod::TraceBlockTransactions => ProviderRequestData::TraceBlockTransactions( serde_json::from_value::(raw_request.params) .map_err(error_mapper)?, ), diff --git a/starknet-providers/src/jsonrpc/transports/http.rs b/starknet-providers/src/jsonrpc/transports/http.rs index 4663b859..bf85ca0f 100644 --- a/starknet-providers/src/jsonrpc/transports/http.rs +++ b/starknet-providers/src/jsonrpc/transports/http.rs @@ -3,7 +3,10 @@ use log::trace; use reqwest::{Client, Url}; use serde::{de::DeserializeOwned, Serialize}; -use crate::jsonrpc::{transports::JsonRpcTransport, JsonRpcMethod, JsonRpcResponse}; +use crate::{ + jsonrpc::{transports::JsonRpcTransport, JsonRpcMethod, JsonRpcResponse}, + ProviderRequestData, +}; /// A [`JsonRpcTransport`] implementation that uses HTTP connections. #[derive(Debug)] @@ -21,6 +24,9 @@ pub enum HttpTransportError { Reqwest(reqwest::Error), /// JSON serialization/deserialization errors. Json(serde_json::Error), + /// Unexpected response ID. + #[error("unexpected response ID: {0}")] + UnexpectedResponseId(u64), } #[derive(Debug, Serialize)] @@ -110,4 +116,67 @@ impl JsonRpcTransport for HttpTransport { Ok(parsed_response) } + + async fn send_requests( + &self, + requests: R, + ) -> Result>, Self::Error> + where + R: AsRef<[ProviderRequestData]> + Send + Sync, + { + let request_bodies = requests + .as_ref() + .iter() + .enumerate() + .map(|(ind, request)| JsonRpcRequest { + id: ind as u64, + jsonrpc: "2.0", + method: request.jsonrpc_method(), + params: request, + }) + .collect::>(); + + let request_count = request_bodies.len(); + + let request_body = serde_json::to_string(&request_bodies).map_err(Self::Error::Json)?; + trace!("Sending request via JSON-RPC: {}", request_body); + + let mut request = self + .client + .post(self.url.clone()) + .body(request_body) + .header("Content-Type", "application/json"); + for (name, value) in &self.headers { + request = request.header(name, value); + } + + let response = request.send().await.map_err(Self::Error::Reqwest)?; + + let response_body = response.text().await.map_err(Self::Error::Reqwest)?; + trace!("Response from JSON-RPC: {}", response_body); + + let parsed_response: Vec> = + serde_json::from_str(&response_body).map_err(Self::Error::Json)?; + + let mut responses: Vec>> = vec![]; + responses.resize(request_bodies.len(), None); + + // Re-order the responses as servers do not maintain order. + for response_item in parsed_response { + let id = match &response_item { + JsonRpcResponse::Success { id, .. } | JsonRpcResponse::Error { id, .. } => { + *id as usize + } + }; + + if id >= request_count { + return Err(HttpTransportError::UnexpectedResponseId(id as u64)); + } + + responses[id] = Some(response_item); + } + + let responses = responses.into_iter().flatten().collect::>(); + Ok(responses) + } } diff --git a/starknet-providers/src/jsonrpc/transports/mod.rs b/starknet-providers/src/jsonrpc/transports/mod.rs index c7602c98..3c172c11 100644 --- a/starknet-providers/src/jsonrpc/transports/mod.rs +++ b/starknet-providers/src/jsonrpc/transports/mod.rs @@ -3,7 +3,10 @@ use auto_impl::auto_impl; use serde::{de::DeserializeOwned, Serialize}; use std::error::Error; -use crate::jsonrpc::{JsonRpcMethod, JsonRpcResponse}; +use crate::{ + jsonrpc::{JsonRpcMethod, JsonRpcResponse}, + ProviderRequestData, +}; mod http; pub use http::{HttpTransport, HttpTransportError}; @@ -26,4 +29,12 @@ pub trait JsonRpcTransport { where P: Serialize + Send + Sync, R: DeserializeOwned; + + /// Sends multiple JSON-RPC requests in parallel. + async fn send_requests( + &self, + requests: R, + ) -> Result>, Self::Error> + where + R: AsRef<[ProviderRequestData]> + Send + Sync; } diff --git a/starknet-providers/src/lib.rs b/starknet-providers/src/lib.rs index 0954c3d2..a8392dd3 100644 --- a/starknet-providers/src/lib.rs +++ b/starknet-providers/src/lib.rs @@ -7,7 +7,7 @@ #![deny(missing_docs)] mod provider; -pub use provider::{Provider, ProviderError}; +pub use provider::{Provider, ProviderError, ProviderRequestData, ProviderResponseData}; // Sequencer-related functionalities are mostly deprecated so we skip the docs. /// Module containing types related to the (now deprecated) sequencer gateway client. diff --git a/starknet-providers/src/provider.rs b/starknet-providers/src/provider.rs index 94d853c3..f513e52c 100644 --- a/starknet-providers/src/provider.rs +++ b/starknet-providers/src/provider.rs @@ -1,7 +1,8 @@ use async_trait::async_trait; use auto_impl::auto_impl; +use serde::Serialize; use starknet_core::types::{ - BlockHashAndNumber, BlockId, BroadcastedDeclareTransaction, + requests::*, BlockHashAndNumber, BlockId, BroadcastedDeclareTransaction, BroadcastedDeployAccountTransaction, BroadcastedInvokeTransaction, BroadcastedTransaction, ContractClass, DeclareTransactionResult, DeployAccountTransactionResult, EventFilter, EventsPage, FeeEstimate, Felt, FunctionCall, InvokeTransactionResult, @@ -260,6 +261,15 @@ pub trait Provider { where B: AsRef + Send + Sync; + /// Sends multiple requests in parallel. The function call fails if any of the requests fails. + /// Implementations must guarantee that responses follow the exact order as the requests. + async fn batch_requests( + &self, + requests: R, + ) -> Result, ProviderError> + where + R: AsRef<[ProviderRequestData]> + Send + Sync; + /// Same as [`estimate_fee`](fn.estimate_fee), but only with one estimate. async fn estimate_fee_single( &self, @@ -350,3 +360,131 @@ pub enum ProviderError { #[error("{0}")] Other(Box), } + +/// Typed request data for [`Provider`] requests. +#[derive(Debug, Clone, Serialize)] +#[serde(untagged)] +pub enum ProviderRequestData { + /// Request data for `starknet_specVersion`. + SpecVersion(SpecVersionRequest), + /// Request data for `starknet_getBlockWithTxHashes`. + GetBlockWithTxHashes(GetBlockWithTxHashesRequest), + /// Request data for `starknet_getBlockWithTxs`. + GetBlockWithTxs(GetBlockWithTxsRequest), + /// Request data for `starknet_getBlockWithReceipts`. + GetBlockWithReceipts(GetBlockWithReceiptsRequest), + /// Request data for `starknet_getStateUpdate`. + GetStateUpdate(GetStateUpdateRequest), + /// Request data for `starknet_getStorageAt`. + GetStorageAt(GetStorageAtRequest), + /// Request data for `starknet_getTransactionStatus`. + GetTransactionStatus(GetTransactionStatusRequest), + /// Request data for `starknet_getTransactionByHash`. + GetTransactionByHash(GetTransactionByHashRequest), + /// Request data for `starknet_getTransactionByBlockIdAndIndex`. + GetTransactionByBlockIdAndIndex(GetTransactionByBlockIdAndIndexRequest), + /// Request data for `starknet_getTransactionReceipt`. + GetTransactionReceipt(GetTransactionReceiptRequest), + /// Request data for `starknet_getClass`. + GetClass(GetClassRequest), + /// Request data for `starknet_getClassHashAt`. + GetClassHashAt(GetClassHashAtRequest), + /// Request data for `starknet_getClassAt`. + GetClassAt(GetClassAtRequest), + /// Request data for `starknet_getBlockTransactionCount`. + GetBlockTransactionCount(GetBlockTransactionCountRequest), + /// Request data for `starknet_call`. + Call(CallRequest), + /// Request data for `starknet_estimateFee`. + EstimateFee(EstimateFeeRequest), + /// Request data for `starknet_estimateMessageFee`. + EstimateMessageFee(EstimateMessageFeeRequest), + /// Request data for `starknet_blockNumber`. + BlockNumber(BlockNumberRequest), + /// Request data for `starknet_blockHashAndNumber`. + BlockHashAndNumber(BlockHashAndNumberRequest), + /// Request data for `starknet_chainId`. + ChainId(ChainIdRequest), + /// Request data for `starknet_syncing`. + Syncing(SyncingRequest), + /// Request data for `starknet_getEvents`. + GetEvents(GetEventsRequest), + /// Request data for `starknet_getNonce`. + GetNonce(GetNonceRequest), + /// Request data for `starknet_addInvokeTransaction`. + AddInvokeTransaction(AddInvokeTransactionRequest), + /// Request data for `starknet_addDeclareTransaction`. + AddDeclareTransaction(AddDeclareTransactionRequest), + /// Request data for `starknet_addDeployAccountTransaction`. + AddDeployAccountTransaction(AddDeployAccountTransactionRequest), + /// Request data for `starknet_traceTransaction`. + TraceTransaction(TraceTransactionRequest), + /// Request data for `starknet_simulateTransactions`. + SimulateTransactions(SimulateTransactionsRequest), + /// Request data for `starknet_traceBlockTransactions`. + TraceBlockTransactions(TraceBlockTransactionsRequest), +} + +/// Typed response data for [`Provider`] responses. +#[allow(clippy::large_enum_variant)] +#[derive(Debug, Clone)] +pub enum ProviderResponseData { + /// Response data for `starknet_specVersion`. + SpecVersion(String), + /// Response data for `starknet_getBlockWithTxHashes`. + GetBlockWithTxHashes(MaybePendingBlockWithTxHashes), + /// Response data for `starknet_getBlockWithTxs`. + GetBlockWithTxs(MaybePendingBlockWithTxs), + /// Response data for `starknet_getBlockWithReceipts`. + GetBlockWithReceipts(MaybePendingBlockWithReceipts), + /// Response data for `starknet_getStateUpdate`. + GetStateUpdate(MaybePendingStateUpdate), + /// Response data for `starknet_getStorageAt`. + GetStorageAt(Felt), + /// Response data for `starknet_getTransactionStatus`. + GetTransactionStatus(TransactionStatus), + /// Response data for `starknet_getTransactionByHash`. + GetTransactionByHash(Transaction), + /// Response data for `starknet_getTransactionByBlockIdAndIndex`. + GetTransactionByBlockIdAndIndex(Transaction), + /// Response data for `starknet_getTransactionReceipt`. + GetTransactionReceipt(TransactionReceiptWithBlockInfo), + /// Response data for `starknet_getClass`. + GetClass(ContractClass), + /// Response data for `starknet_getClassHashAt`. + GetClassHashAt(Felt), + /// Response data for `starknet_getClassAt`. + GetClassAt(ContractClass), + /// Response data for `starknet_getBlockTransactionCount`. + GetBlockTransactionCount(u64), + /// Response data for `starknet_call`. + Call(Vec), + /// Response data for `starknet_estimateFee`. + EstimateFee(Vec), + /// Response data for `starknet_estimateMessageFee`. + EstimateMessageFee(FeeEstimate), + /// Response data for `starknet_blockNumber`. + BlockNumber(u64), + /// Response data for `starknet_blockHashAndNumber`. + BlockHashAndNumber(BlockHashAndNumber), + /// Response data for `starknet_chainId`. + ChainId(Felt), + /// Response data for `starknet_syncing`. + Syncing(SyncStatusType), + /// Response data for `starknet_getEvents`. + GetEvents(EventsPage), + /// Response data for `starknet_getNonce`. + GetNonce(Felt), + /// Response data for `starknet_addInvokeTransaction`. + AddInvokeTransaction(InvokeTransactionResult), + /// Response data for `starknet_addDeclareTransaction`. + AddDeclareTransaction(DeclareTransactionResult), + /// Response data for `starknet_addDeployAccountTransaction`. + AddDeployAccountTransaction(DeployAccountTransactionResult), + /// Response data for `starknet_traceTransaction`. + TraceTransaction(TransactionTrace), + /// Response data for `starknet_simulateTransactions`. + SimulateTransactions(Vec), + /// Response data for `starknet_traceBlockTransactions`. + TraceBlockTransactions(Vec), +} diff --git a/starknet-providers/src/sequencer/provider.rs b/starknet-providers/src/sequencer/provider.rs index 313cda9c..b046838b 100644 --- a/starknet-providers/src/sequencer/provider.rs +++ b/starknet-providers/src/sequencer/provider.rs @@ -17,7 +17,7 @@ use starknet_core::types::{ use crate::{ provider::ProviderImplError, sequencer::{models::conversions::ConversionError, GatewayClientError}, - Provider, ProviderError, SequencerGatewayProvider, + Provider, ProviderError, ProviderRequestData, ProviderResponseData, SequencerGatewayProvider, }; use super::models::TransactionFinalityStatus; @@ -414,6 +414,20 @@ impl Provider for SequencerGatewayProvider { GatewayClientError::MethodNotSupported, ))) } + + async fn batch_requests( + &self, + requests: R, + ) -> Result, ProviderError> + where + R: AsRef<[ProviderRequestData]> + Send + Sync, + { + // Not implemented for now. It's technically possible to simulate this by running multiple + // requests in parallel. + Err(ProviderError::Other(Box::new( + GatewayClientError::MethodNotSupported, + ))) + } } impl ProviderImplError for GatewayClientError { diff --git a/starknet-providers/tests/jsonrpc.rs b/starknet-providers/tests/jsonrpc.rs index dac0910a..ba813fae 100644 --- a/starknet-providers/tests/jsonrpc.rs +++ b/starknet-providers/tests/jsonrpc.rs @@ -1,5 +1,6 @@ use starknet_core::{ types::{ + requests::{CallRequest, GetBlockTransactionCountRequest}, BlockId, BlockTag, BroadcastedInvokeTransaction, BroadcastedInvokeTransactionV1, BroadcastedTransaction, ContractClass, DeclareTransaction, DeployAccountTransaction, EthAddress, EventFilter, ExecuteInvocation, ExecutionResult, Felt, FunctionCall, @@ -12,7 +13,7 @@ use starknet_core::{ }; use starknet_providers::{ jsonrpc::{HttpTransport, JsonRpcClient}, - Provider, ProviderError, + Provider, ProviderError, ProviderRequestData, ProviderResponseData, }; use url::Url; @@ -873,6 +874,48 @@ async fn jsonrpc_trace_deploy_account() { } } +#[tokio::test] +async fn jsonrpc_batch() { + let rpc_client = create_jsonrpc_client(); + + let responses = rpc_client + .batch_requests([ + ProviderRequestData::GetBlockTransactionCount(GetBlockTransactionCountRequest { + block_id: BlockId::Number(20_000), + }), + ProviderRequestData::Call(CallRequest { + request: FunctionCall { + contract_address: Felt::from_hex( + "049d36570d4e46f48e99674bd3fcc84644ddd6b96f7c741b1562b82f9e004dc7", + ) + .unwrap(), + entry_point_selector: get_selector_from_name("balanceOf").unwrap(), + calldata: vec![Felt::from_hex( + "03f47d3911396b6d579fd7848cf576286ab6f96dda977915d6c7b10f3dd2315b", + ) + .unwrap()], + }, + block_id: BlockId::Tag(BlockTag::Latest), + }), + ]) + .await + .unwrap(); + + match &responses[0] { + ProviderResponseData::GetBlockTransactionCount(count) => { + assert_eq!(*count, 6); + } + _ => panic!("unexpected response type"), + } + + match &responses[1] { + ProviderResponseData::Call(eth_balance) => { + assert!(eth_balance[0] > Felt::ZERO); + } + _ => panic!("unexpected response type"), + } +} + // NOTE: `addXxxxTransaction` methods are harder to test here since they require signatures. These // are integration tests anyways, so we might as well just leave the job to th tests in // `starknet-accounts`.