diff --git a/Cargo.toml b/Cargo.toml index d0372ad..dd11e3f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -54,3 +54,4 @@ url = "2" [dev-dependencies] alloy-rpc-client = "0.6" alloy-transport-http = "0.6" +tiny_http = "0.11" diff --git a/src/backend.rs b/src/backend.rs index 0250e3b..b920d89 100644 --- a/src/backend.rs +++ b/src/backend.rs @@ -28,6 +28,7 @@ use revm::{ }; use std::{ collections::VecDeque, + fmt, future::IntoFuture, marker::PhantomData, path::Path, @@ -65,6 +66,45 @@ type AddressData = AddressHashMap; type StorageData = AddressHashMap; type BlockHashData = HashMap; +struct AnyRequestFuture { + sender: OneshotSender>, + future: Pin> + Send>>, +} + +impl fmt::Debug for AnyRequestFuture { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_tuple("AnyRequestFuture").field(&self.sender).finish() + } +} + +trait WrappedAnyRequest: Unpin + Send + fmt::Debug { + fn poll_inner(&mut self, cx: &mut Context<'_>) -> Poll<()>; +} + +/// @dev Implements `WrappedAnyRequest` for `AnyRequestFuture`. +/// +/// - `poll_inner` is similar to `Future` polling but intentionally consumes the Future +/// and return Future +/// - This design avoids storing `Future` directly, as its type may not be known at +/// compile time. +/// - Instead, the result (`Result`) is sent via the `sender` channel, which enforces type +/// safety. +impl WrappedAnyRequest for AnyRequestFuture +where + T: fmt::Debug + Send + 'static, + Err: fmt::Debug + Send + 'static, +{ + fn poll_inner(&mut self, cx: &mut Context<'_>) -> Poll<()> { + match self.future.poll_unpin(cx) { + Poll::Ready(result) => { + let _ = self.sender.send(result); + Poll::Ready(()) + } + Poll::Pending => Poll::Pending, + } + } +} + /// Request variants that are executed by the provider enum ProviderRequest { Account(AccountFuture), @@ -72,6 +112,7 @@ enum ProviderRequest { BlockHash(BlockHashFuture), FullBlock(FullBlockFuture), Transaction(TransactionFuture), + AnyRequest(Box), } /// The Request type the Backend listens for @@ -96,6 +137,8 @@ enum BackendRequest { UpdateStorage(StorageData), /// Update Block Hashes UpdateBlockHash(BlockHashData), + /// Any other request + AnyRequest(Box), } /// Handles an internal provider and listens for requests. @@ -210,6 +253,9 @@ where self.db.block_hashes().write().insert(block, hash); } } + BackendRequest::AnyRequest(fut) => { + self.pending_requests.push(ProviderRequest::AnyRequest(fut)); + } } } @@ -505,6 +551,11 @@ where continue; } } + ProviderRequest::AnyRequest(fut) => { + if fut.poll_inner(cx).is_ready() { + continue; + } + } } // not ready, insert and poll again pin.pending_requests.push(request); @@ -757,6 +808,23 @@ impl SharedBackend { } } + /// Returns any arbitrary request on the provider + pub fn do_any_request(&mut self, fut: F) -> DatabaseResult + where + F: Future> + Send + 'static, + T: fmt::Debug + Send + 'static, + { + self.blocking_mode.run(|| { + let (sender, rx) = oneshot_channel::>(); + let req = BackendRequest::AnyRequest(Box::new(AnyRequestFuture { + sender, + future: Box::pin(fut), + })); + self.backend.unbounded_send(req)?; + rx.recv()?.map_err(|err| DatabaseError::AnyRequest(Arc::new(err))) + }) + } + /// Flushes the DB to disk if caching is enabled pub fn flush_cache(&self) { self.cache.0.flush(); @@ -851,7 +919,9 @@ mod tests { use alloy_provider::{ProviderBuilder, RootProvider}; use alloy_rpc_client::ClientBuilder; use alloy_transport_http::{Client, Http}; + use serde::Deserialize; use std::{collections::BTreeSet, fs, path::PathBuf}; + use tiny_http::{Response, Server}; pub fn get_http_provider(endpoint: &str) -> RootProvider, AnyNetwork> { ProviderBuilder::new() @@ -1246,4 +1316,58 @@ mod tests { // erase the temporary file fs::remove_file("test-data/storage-tmp.json").unwrap(); } + + #[tokio::test(flavor = "multi_thread")] + async fn shared_backend_any_request() { + let expected_response_bytes: Bytes = vec![0xff, 0xee].into(); + let server = Server::http("0.0.0.0:0").expect("failed starting in-memory http server"); + let endpoint = format!("http://{}", server.server_addr()); + + // Spin an in-memory server that responds to "foo_callCustomMethod" rpc call. + let expected_bytes_innner = expected_response_bytes.clone(); + let server_handle = std::thread::spawn(move || { + #[derive(Debug, Deserialize)] + struct Request { + method: String, + } + let mut request = server.recv().unwrap(); + let rpc_request: Request = + serde_json::from_reader(request.as_reader()).expect("failed parsing request"); + + match rpc_request.method.as_str() { + "foo_callCustomMethod" => request + .respond(Response::from_string(format!( + r#"{{"result": "{}"}}"#, + alloy_primitives::hex::encode_prefixed(expected_bytes_innner), + ))) + .unwrap(), + _ => request + .respond(Response::from_string(r#"{"error": "invalid request"}"#)) + .unwrap(), + }; + }); + + let provider = get_http_provider(&endpoint); + let meta = BlockchainDbMeta { + cfg_env: Default::default(), + block_env: Default::default(), + hosts: BTreeSet::from([endpoint.to_string()]), + }; + + let db = BlockchainDb::new(meta, None); + let provider_inner = provider.clone(); + let mut backend = SharedBackend::spawn_backend(Arc::new(provider), db.clone(), None).await; + + let actual_response_bytes = backend + .do_any_request(async move { + let bytes: alloy_primitives::Bytes = + provider_inner.raw_request("foo_callCustomMethod".into(), vec!["0001"]).await?; + Ok(bytes) + }) + .expect("failed performing any request"); + + assert_eq!(actual_response_bytes, expected_response_bytes); + + server_handle.join().unwrap(); + } } diff --git a/src/error.rs b/src/error.rs index 9170b6e..691c16a 100644 --- a/src/error.rs +++ b/src/error.rs @@ -31,6 +31,8 @@ pub enum DatabaseError { BlockNotFound(BlockId), #[error("failed to get transaction {0}: {1}")] GetTransaction(B256, Arc), + #[error("failed to process AnyRequest: {0}")] + AnyRequest(Arc), } impl DatabaseError { @@ -41,6 +43,7 @@ impl DatabaseError { Self::GetBlockHash(_, err) => Some(err), Self::GetFullBlock(_, err) => Some(err), Self::GetTransaction(_, err) => Some(err), + Self::AnyRequest(err) => Some(err), // Enumerate explicitly to make sure errors are updated if a new one is added. Self::MissingCode(_) | Self::Recv(_) | Self::Send(_) | Self::BlockNotFound(_) => None, }