From 32a6ba7f6d7773d0b9873aa116f4aab26b1bfb38 Mon Sep 17 00:00:00 2001
From: Marko Atanasievski
Date: Tue, 27 Aug 2024 17:11:32 +0200
Subject: [PATCH] fix: limit number of connections (#546)

* fix: limit number of connections

* fix: formatting

* fix: explicitly build reqwest client

* fix: nit

* fix: review
---
 zero_bin/leader/src/client.rs    |  2 +-
 zero_bin/rpc/src/jerigon.rs      |  6 +++--
 zero_bin/rpc/src/lib.rs          |  4 +--
 zero_bin/rpc/src/main.rs         |  5 ++--
 zero_bin/rpc/src/native/mod.rs   |  3 ++-
 zero_bin/rpc/src/native/state.rs |  6 +++--
 zero_bin/rpc/src/provider.rs     | 44 +++++++++++++++++++++++++-------
 zero_bin/rpc/src/retry.rs        | 24 ++++++++++++++---
 8 files changed, 71 insertions(+), 23 deletions(-)

diff --git a/zero_bin/leader/src/client.rs b/zero_bin/leader/src/client.rs
index 86451791f..4ef5ed3cd 100644
--- a/zero_bin/leader/src/client.rs
+++ b/zero_bin/leader/src/client.rs
@@ -45,7 +45,7 @@ pub(crate) async fn client_main(
         rpc_params.rpc_url.clone(),
         rpc_params.backoff,
         rpc_params.max_retries,
-    ),
+    )?,
     ));
     check_previous_proof_and_checkpoint(
         params.checkpoint_block_number,
diff --git a/zero_bin/rpc/src/jerigon.rs b/zero_bin/rpc/src/jerigon.rs
index 891421971..8fc4630eb 100644
--- a/zero_bin/rpc/src/jerigon.rs
+++ b/zero_bin/rpc/src/jerigon.rs
@@ -29,7 +29,8 @@ where
 {
     // Grab trace information
     let tx_results = cached_provider
-        .as_provider()
+        .get_provider()
+        .await?
         .raw_request::<_, Vec<ZeroTxResult>>(
             "debug_traceBlockByNumber".into(),
             (target_block_id, json!({"tracer": "zeroTracer"})),
@@ -39,7 +40,8 @@ where
         )
         .await?;

     // Grab block witness info (packed as combined trie pre-images)
     let block_witness = cached_provider
-        .as_provider()
+        .get_provider()
+        .await?
         .raw_request::<_, String>("eth_getWitness".into(), vec![target_block_id])
         .await?;
diff --git a/zero_bin/rpc/src/lib.rs b/zero_bin/rpc/src/lib.rs
index c0abb29da..053380239 100644
--- a/zero_bin/rpc/src/lib.rs
+++ b/zero_bin/rpc/src/lib.rs
@@ -154,7 +154,7 @@ where
     // We use that execution not to produce a new contract bytecode - instead, we
     // return hashes. To look at the code use `cast disassemble <bytecode>`.
     let bytes = cached_provider
-        .as_provider()
+        .get_provider().await?
         .raw_request::<_, Bytes>(
             "eth_call".into(),
             (json!({"input": "0x60005B60010180430340816020025280610101116300000002576120205FF3"}), target_block_number),
@@ -216,7 +216,7 @@ where
         .header
         .number
         .context("target block is missing field `number`")?;
-    let chain_id = cached_provider.as_provider().get_chain_id().await?;
+    let chain_id = cached_provider.get_provider().await?.get_chain_id().await?;
     let prev_hashes = fetch_previous_block_hashes(cached_provider, target_block_number).await?;

     let other_data = OtherBlockData {
diff --git a/zero_bin/rpc/src/main.rs b/zero_bin/rpc/src/main.rs
index ce3ea2e4f..b878b8cf1 100644
--- a/zero_bin/rpc/src/main.rs
+++ b/zero_bin/rpc/src/main.rs
@@ -117,7 +117,7 @@ impl Cli {
             self.config.rpc_url.clone(),
             self.config.backoff,
             self.config.max_retries,
-        )));
+        )?));

         match self.command {
             Command::Fetch {
@@ -141,7 +141,8 @@ impl Cli {
                 // Get transaction info.
                 match cached_provider
                     .clone()
-                    .as_provider()
+                    .get_provider()
+                    .await?
                     .get_transaction_by_hash(tx_hash)
                     .await?
                {
diff --git a/zero_bin/rpc/src/native/mod.rs b/zero_bin/rpc/src/native/mod.rs
index 1f61d7b26..95b38d22f 100644
--- a/zero_bin/rpc/src/native/mod.rs
+++ b/zero_bin/rpc/src/native/mod.rs
@@ -1,4 +1,5 @@
 use std::collections::HashMap;
+use std::ops::Deref;
 use std::sync::Arc;

 use alloy::{
@@ -53,7 +54,7 @@ where
         .await?;

     let (code_db, txn_info) =
-        txn::process_transactions(&block, cached_provider.as_provider()).await?;
+        txn::process_transactions(&block, cached_provider.get_provider().await?.deref()).await?;
     let trie_pre_images = state::process_state_witness(cached_provider, block, &txn_info).await?;

     Ok(BlockTrace {
diff --git a/zero_bin/rpc/src/native/state.rs b/zero_bin/rpc/src/native/state.rs
index b5017b394..e41a8839f 100644
--- a/zero_bin/rpc/src/native/state.rs
+++ b/zero_bin/rpc/src/native/state.rs
@@ -182,7 +182,8 @@ where
             let provider = provider.clone();
             async move {
                 let proof = provider
-                    .as_provider()
+                    .get_provider()
+                    .await?
                     .get_proof(address, keys.into_iter().collect())
                     .block_id((block_number - 1).into())
                     .await
@@ -196,7 +197,8 @@ where
             let provider = provider.clone();
             async move {
                 let proof = provider
-                    .as_provider()
+                    .get_provider()
+                    .await?
                     .get_proof(address, keys.into_iter().collect())
                     .block_id(block_number.into())
                     .await
diff --git a/zero_bin/rpc/src/provider.rs b/zero_bin/rpc/src/provider.rs
index f0c6a2691..05866168c 100644
--- a/zero_bin/rpc/src/provider.rs
+++ b/zero_bin/rpc/src/provider.rs
@@ -1,22 +1,48 @@
+use std::ops::{Deref, DerefMut};
 use std::sync::Arc;

 use alloy::primitives::BlockHash;
 use alloy::rpc::types::{Block, BlockId, BlockTransactionsKind};
 use alloy::{providers::Provider, transports::Transport};
 use anyhow::Context;
-use tokio::sync::Mutex;
+use tokio::sync::{Mutex, Semaphore, SemaphorePermit};

 const CACHE_SIZE: usize = 1024;
+const MAX_NUMBER_OF_PARALLEL_REQUESTS: usize = 64;

 /// Wrapper around alloy provider to cache blocks and other
 /// frequently used data.
 pub struct CachedProvider<ProviderT, TransportT> {
-    provider: ProviderT,
+    provider: Arc<ProviderT>,
+    // The `Alloy` provider uses the `Reqwest` HTTP client under the hood, and that client
+    // keeps an unbounded connection pool. We therefore have to limit the number of parallel
+    // connections ourselves, using a semaphore to count the RPC requests in flight through
+    // the CachedProvider at any moment.
+    semaphore: Arc<Semaphore>,
     blocks_by_number: Arc<Mutex<lru::LruCache<u64, Block>>>,
     blocks_by_hash: Arc<Mutex<lru::LruCache<BlockHash, u64>>>,
     _phantom: std::marker::PhantomData<TransportT>,
 }

+pub struct ProviderGuard<'a, ProviderT> {
+    provider: Arc<ProviderT>,
+    _permit: SemaphorePermit<'a>,
+}
+
+impl<'a, ProviderT> Deref for ProviderGuard<'a, ProviderT> {
+    type Target = Arc<ProviderT>;
+
+    fn deref(&self) -> &Self::Target {
+        &self.provider
+    }
+}
+
+impl<ProviderT> DerefMut for ProviderGuard<'_, ProviderT> {
+    fn deref_mut(&mut self) -> &mut Self::Target {
+        &mut self.provider
+    }
+}
+
 impl<ProviderT, TransportT> CachedProvider<ProviderT, TransportT>
 where
     ProviderT: Provider<TransportT>,
@@ -24,7 +50,8 @@
 {
     pub fn new(provider: ProviderT) -> Self {
         Self {
-            provider,
+            provider: provider.into(),
+            semaphore: Arc::new(Semaphore::new(MAX_NUMBER_OF_PARALLEL_REQUESTS)),
             blocks_by_number: Arc::new(Mutex::new(lru::LruCache::new(
                 std::num::NonZero::new(CACHE_SIZE).unwrap(),
             ))),
@@ -35,12 +62,11 @@ where
         }
     }

-    pub fn as_mut_provider(&mut self) -> &mut ProviderT {
-        &mut self.provider
-    }
-
-    pub const fn as_provider(&self) -> &ProviderT {
-        &self.provider
+    pub async fn get_provider(&self) -> Result<ProviderGuard<ProviderT>, anyhow::Error> {
+        Ok(ProviderGuard {
+            provider: self.provider.clone(),
+            _permit: self.semaphore.acquire().await?,
+        })
     }

     /// Retrieves block by number or hash, caching it if it's not already
diff --git a/zero_bin/rpc/src/retry.rs b/zero_bin/rpc/src/retry.rs
index 2fe81cb60..7e1a1160d 100644
--- a/zero_bin/rpc/src/retry.rs
+++ b/zero_bin/rpc/src/retry.rs
@@ -1,9 +1,11 @@
+use std::time::Duration;
 use std::{
     future::Future,
     pin::Pin,
     task::{Context, Poll},
 };

+use alloy::transports::http::reqwest;
 use alloy::{
     providers::{ProviderBuilder, RootProvider},
     rpc::{
@@ -14,6 +16,9 @@ use alloy::{
 };
 use tower::{retry::Policy, Layer, Service};

+const HTTP_CLIENT_CONNECTION_POOL_IDLE_TIMEOUT: u64 = 90;
+const HTTP_CLIENT_MAX_IDLE_CONNECTIONS_PER_HOST: usize = 64;
+
 #[derive(Debug)]
 pub struct RetryPolicy {
     backoff: tokio::time::Duration,
@@ -138,11 +143,22 @@ pub fn build_http_retry_provider(
     rpc_url: url::Url,
     backoff: u64,
     max_retries: u32,
-) -> RootProvider<RetryService<Http<Client>>> {
+) -> Result<RootProvider<RetryService<Http<Client>>>, anyhow::Error> {
     let retry_policy = RetryLayer::new(RetryPolicy::new(
-        tokio::time::Duration::from_millis(backoff),
+        Duration::from_millis(backoff),
         max_retries,
     ));
-    let client = ClientBuilder::default().layer(retry_policy).http(rpc_url);
-    ProviderBuilder::new().on_client(client)
+    let reqwest_client = reqwest::ClientBuilder::new()
+        .pool_max_idle_per_host(HTTP_CLIENT_MAX_IDLE_CONNECTIONS_PER_HOST)
+        .pool_idle_timeout(Duration::from_secs(
+            HTTP_CLIENT_CONNECTION_POOL_IDLE_TIMEOUT,
+        ))
+        .build()?;
+
+    let http = alloy::transports::http::Http::with_client(reqwest_client, rpc_url);
+    let is_local = http.guess_local();
+    let client = ClientBuilder::default()
+        .layer(retry_policy)
+        .transport(http, is_local);
+    Ok(ProviderBuilder::new().on_client(client))
 }
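
To see the concurrency cap in isolation: below is a minimal, self-contained sketch of the
semaphore-guard pattern that provider.rs adopts above. It assumes the tokio (rt + macros
features) and anyhow crates; Limited, ResourceGuard, and MAX_CONCURRENT are illustrative
names, not part of this codebase, and the real CachedProvider::get_provider() hands out an
Arc-wrapped alloy provider rather than a bare value.

use std::ops::Deref;
use std::sync::Arc;

use tokio::sync::{Semaphore, SemaphorePermit};

// Upper bound on concurrent users of the shared resource (stands in for
// MAX_NUMBER_OF_PARALLEL_REQUESTS in the patch).
const MAX_CONCURRENT: usize = 4;

// RAII guard: the permit is released when the guard is dropped, so a caller
// can never forget to give its slot back.
pub struct ResourceGuard<'a, T> {
    inner: Arc<T>,
    _permit: SemaphorePermit<'a>,
}

impl<T> Deref for ResourceGuard<'_, T> {
    type Target = Arc<T>;

    fn deref(&self) -> &Self::Target {
        &self.inner
    }
}

pub struct Limited<T> {
    inner: Arc<T>,
    semaphore: Arc<Semaphore>,
}

impl<T> Limited<T> {
    pub fn new(inner: T) -> Self {
        Self {
            inner: Arc::new(inner),
            semaphore: Arc::new(Semaphore::new(MAX_CONCURRENT)),
        }
    }

    // Waits until one of the MAX_CONCURRENT slots is free, then hands out a
    // guard, mirroring what CachedProvider::get_provider() does in the patch.
    pub async fn acquire(&self) -> anyhow::Result<ResourceGuard<'_, T>> {
        Ok(ResourceGuard {
            inner: self.inner.clone(),
            _permit: self.semaphore.acquire().await?,
        })
    }
}

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    let limited = Limited::new(String::from("shared provider"));
    let guard = limited.acquire().await?;
    println!("using: {}", **guard); // two derefs: guard -> Arc<String> -> String
    Ok(())
}

Because the permit is stored inside the guard, the slot is returned automatically when the
guard is dropped, including on early returns via `?`, so callers cannot leak capacity.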
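
The retry.rs hunk also replaces alloy's implicitly built HTTP client with an explicitly
configured reqwest client. Note that pool_max_idle_per_host and pool_idle_timeout only
govern how many idle connections are kept alive per host and for how long; they do not cap
in-flight requests, which is why the semaphore above is still needed. A standalone sketch
of that client configuration, reusing the patch's constant names and values (build_client
itself is a hypothetical helper):

use std::time::Duration;

const HTTP_CLIENT_CONNECTION_POOL_IDLE_TIMEOUT: u64 = 90;
const HTTP_CLIENT_MAX_IDLE_CONNECTIONS_PER_HOST: usize = 64;

fn build_client() -> reqwest::Result<reqwest::Client> {
    reqwest::ClientBuilder::new()
        // At most this many idle connections are kept alive per host; this does
        // not bound the number of simultaneous in-flight requests.
        .pool_max_idle_per_host(HTTP_CLIENT_MAX_IDLE_CONNECTIONS_PER_HOST)
        // Idle connections are dropped after this long.
        .pool_idle_timeout(Duration::from_secs(
            HTTP_CLIENT_CONNECTION_POOL_IDLE_TIMEOUT,
        ))
        .build()
}

fn main() -> reqwest::Result<()> {
    let _client = build_client()?;
    Ok(())
}

With both limits set to 64, a warmed-up pool can serve every request the semaphore admits
without re-establishing a connection.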