Skip to content

Commit

Permalink
feat(backend): add support for arbitrary provider requests with AnyRe…
Browse files Browse the repository at this point in the history
…quest (#32)

* feat(backend): add support for arbitrary provider requests with AnyRequest

* add in-memory provider test

* chore: clean up / add relevant comment

---------

Co-authored-by: Nisheeth Barthwal <[email protected]>
  • Loading branch information
dutterbutter and nbaztec authored Nov 27, 2024
1 parent 6501345 commit d90d227
Show file tree
Hide file tree
Showing 3 changed files with 128 additions and 0 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,4 @@ url = "2"
[dev-dependencies]
alloy-rpc-client = "0.6"
alloy-transport-http = "0.6"
tiny_http = "0.11"
124 changes: 124 additions & 0 deletions src/backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ use revm::{
};
use std::{
collections::VecDeque,
fmt,
future::IntoFuture,
marker::PhantomData,
path::Path,
Expand Down Expand Up @@ -65,13 +66,53 @@ type AddressData = AddressHashMap<AccountInfo>;
type StorageData = AddressHashMap<StorageInfo>;
type BlockHashData = HashMap<U256, B256>;

struct AnyRequestFuture<T, Err> {
sender: OneshotSender<Result<T, Err>>,
future: Pin<Box<dyn Future<Output = Result<T, Err>> + Send>>,
}

impl<T, Err> fmt::Debug for AnyRequestFuture<T, Err> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_tuple("AnyRequestFuture").field(&self.sender).finish()
}
}

trait WrappedAnyRequest: Unpin + Send + fmt::Debug {
fn poll_inner(&mut self, cx: &mut Context<'_>) -> Poll<()>;
}

/// @dev Implements `WrappedAnyRequest` for `AnyRequestFuture`.
///
/// - `poll_inner` is similar to `Future` polling but intentionally consumes the Future<Output=T>
/// and return Future<Output=()>
/// - This design avoids storing `Future<Output = T>` directly, as its type may not be known at
/// compile time.
/// - Instead, the result (`Result<T, Err>`) is sent via the `sender` channel, which enforces type
/// safety.
impl<T, Err> WrappedAnyRequest for AnyRequestFuture<T, Err>
where
T: fmt::Debug + Send + 'static,
Err: fmt::Debug + Send + 'static,
{
fn poll_inner(&mut self, cx: &mut Context<'_>) -> Poll<()> {
match self.future.poll_unpin(cx) {
Poll::Ready(result) => {
let _ = self.sender.send(result);
Poll::Ready(())
}
Poll::Pending => Poll::Pending,
}
}
}

/// Request variants that are executed by the provider
enum ProviderRequest<Err> {
Account(AccountFuture<Err>),
Storage(StorageFuture<Err>),
BlockHash(BlockHashFuture<Err>),
FullBlock(FullBlockFuture<Err>),
Transaction(TransactionFuture<Err>),
AnyRequest(Box<dyn WrappedAnyRequest>),
}

/// The Request type the Backend listens for
Expand All @@ -96,6 +137,8 @@ enum BackendRequest {
UpdateStorage(StorageData),
/// Update Block Hashes
UpdateBlockHash(BlockHashData),
/// Any other request
AnyRequest(Box<dyn WrappedAnyRequest>),
}

/// Handles an internal provider and listens for requests.
Expand Down Expand Up @@ -210,6 +253,9 @@ where
self.db.block_hashes().write().insert(block, hash);
}
}
BackendRequest::AnyRequest(fut) => {
self.pending_requests.push(ProviderRequest::AnyRequest(fut));
}
}
}

Expand Down Expand Up @@ -505,6 +551,11 @@ where
continue;
}
}
ProviderRequest::AnyRequest(fut) => {
if fut.poll_inner(cx).is_ready() {
continue;
}
}
}
// not ready, insert and poll again
pin.pending_requests.push(request);
Expand Down Expand Up @@ -757,6 +808,23 @@ impl SharedBackend {
}
}

/// Returns any arbitrary request on the provider
pub fn do_any_request<T, F>(&mut self, fut: F) -> DatabaseResult<T>
where
F: Future<Output = Result<T, eyre::Report>> + Send + 'static,
T: fmt::Debug + Send + 'static,
{
self.blocking_mode.run(|| {
let (sender, rx) = oneshot_channel::<Result<T, eyre::Report>>();
let req = BackendRequest::AnyRequest(Box::new(AnyRequestFuture {
sender,
future: Box::pin(fut),
}));
self.backend.unbounded_send(req)?;
rx.recv()?.map_err(|err| DatabaseError::AnyRequest(Arc::new(err)))
})
}

/// Flushes the DB to disk if caching is enabled
pub fn flush_cache(&self) {
self.cache.0.flush();
Expand Down Expand Up @@ -851,7 +919,9 @@ mod tests {
use alloy_provider::{ProviderBuilder, RootProvider};
use alloy_rpc_client::ClientBuilder;
use alloy_transport_http::{Client, Http};
use serde::Deserialize;
use std::{collections::BTreeSet, fs, path::PathBuf};
use tiny_http::{Response, Server};

pub fn get_http_provider(endpoint: &str) -> RootProvider<Http<Client>, AnyNetwork> {
ProviderBuilder::new()
Expand Down Expand Up @@ -1246,4 +1316,58 @@ mod tests {
// erase the temporary file
fs::remove_file("test-data/storage-tmp.json").unwrap();
}

#[tokio::test(flavor = "multi_thread")]
async fn shared_backend_any_request() {
let expected_response_bytes: Bytes = vec![0xff, 0xee].into();
let server = Server::http("0.0.0.0:0").expect("failed starting in-memory http server");
let endpoint = format!("http://{}", server.server_addr());

// Spin an in-memory server that responds to "foo_callCustomMethod" rpc call.
let expected_bytes_innner = expected_response_bytes.clone();
let server_handle = std::thread::spawn(move || {
#[derive(Debug, Deserialize)]
struct Request {
method: String,
}
let mut request = server.recv().unwrap();
let rpc_request: Request =
serde_json::from_reader(request.as_reader()).expect("failed parsing request");

match rpc_request.method.as_str() {
"foo_callCustomMethod" => request
.respond(Response::from_string(format!(
r#"{{"result": "{}"}}"#,
alloy_primitives::hex::encode_prefixed(expected_bytes_innner),
)))
.unwrap(),
_ => request
.respond(Response::from_string(r#"{"error": "invalid request"}"#))
.unwrap(),
};
});

let provider = get_http_provider(&endpoint);
let meta = BlockchainDbMeta {
cfg_env: Default::default(),
block_env: Default::default(),
hosts: BTreeSet::from([endpoint.to_string()]),
};

let db = BlockchainDb::new(meta, None);
let provider_inner = provider.clone();
let mut backend = SharedBackend::spawn_backend(Arc::new(provider), db.clone(), None).await;

let actual_response_bytes = backend
.do_any_request(async move {
let bytes: alloy_primitives::Bytes =
provider_inner.raw_request("foo_callCustomMethod".into(), vec!["0001"]).await?;
Ok(bytes)
})
.expect("failed performing any request");

assert_eq!(actual_response_bytes, expected_response_bytes);

server_handle.join().unwrap();
}
}
3 changes: 3 additions & 0 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ pub enum DatabaseError {
BlockNotFound(BlockId),
#[error("failed to get transaction {0}: {1}")]
GetTransaction(B256, Arc<eyre::Error>),
#[error("failed to process AnyRequest: {0}")]
AnyRequest(Arc<eyre::Error>),
}

impl DatabaseError {
Expand All @@ -41,6 +43,7 @@ impl DatabaseError {
Self::GetBlockHash(_, err) => Some(err),
Self::GetFullBlock(_, err) => Some(err),
Self::GetTransaction(_, err) => Some(err),
Self::AnyRequest(err) => Some(err),
// Enumerate explicitly to make sure errors are updated if a new one is added.
Self::MissingCode(_) | Self::Recv(_) | Self::Send(_) | Self::BlockNotFound(_) => None,
}
Expand Down

0 comments on commit d90d227

Please sign in to comment.