From 903651a287c6557b08a3a784ba3bde0eedb5dc24 Mon Sep 17 00:00:00 2001 From: Danial Mehrjerdi Date: Wed, 11 Dec 2024 11:46:53 +0100 Subject: [PATCH] Better stream handling for ws --- Cargo.lock | 1 + sdk/rust/Cargo.toml | 2 +- sdk/rust/simple-searcher/src/main.rs | 22 +++++--- sdk/rust/src/lib.rs | 80 ++++++++++++++++++++++------ 4 files changed, 80 insertions(+), 25 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 953f0d16..a7356e41 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -8076,6 +8076,7 @@ dependencies = [ "futures-core", "pin-project-lite", "tokio", + "tokio-util", ] [[package]] diff --git a/sdk/rust/Cargo.toml b/sdk/rust/Cargo.toml index 22ed2bd1..22cfa3e9 100644 --- a/sdk/rust/Cargo.toml +++ b/sdk/rust/Cargo.toml @@ -14,7 +14,7 @@ serde = { workspace = true } strum = { workspace = true } serde_json = { workspace = true } tokio-tungstenite = { version = "0.24.0", features = ["native-tls"] } -tokio-stream = { workspace = true } +tokio-stream = { workspace = true, features = ["sync"] } tokio = { workspace = true } futures-util = "0.3.31" ethers = { workspace = true } diff --git a/sdk/rust/simple-searcher/src/main.rs b/sdk/rust/simple-searcher/src/main.rs index f5a37a39..f343a541 100644 --- a/sdk/rust/simple-searcher/src/main.rs +++ b/sdk/rust/simple-searcher/src/main.rs @@ -21,10 +21,7 @@ use { WsClient, }, rand::Rng, - std::{ - collections::HashMap, - sync::Arc, - }, + std::collections::HashMap, time::{ Duration, OffsetDateTime, @@ -37,7 +34,7 @@ async fn random() -> U256 { U256::from(rng.gen::()) } -async fn handle_opportunity(ws_client: Arc, opportunity: Opportunity) -> Result<()> { +async fn handle_opportunity(ws_client: WsClient, opportunity: Opportunity) -> Result<()> { let bid = match opportunity { opportunity::Opportunity::Evm(opportunity) => { // Assess opportunity @@ -96,10 +93,10 @@ async fn main() -> Result<()> { println!("Opportunities: {:?}", opportunities.len()); - let ws_client = Arc::new(client.connect_websocket().await.map_err(|e| { + let ws_client = client.connect_websocket().await.map_err(|e| { eprintln!("Failed to connect websocket: {:?}", e); anyhow!("Failed to connect websocket") - })?); + })?; ws_client .chain_subscribe(vec![ChainId::DevelopmentEvm, ChainId::DevelopmentSvm]) @@ -109,9 +106,18 @@ async fn main() -> Result<()> { anyhow!("Failed to subscribe chains") })?; - let mut stream = ws_client.update_stream.write().await; + let mut stream = ws_client.get_update_stream(); let mut block_hash_map = HashMap::new(); while let Some(update) = stream.next().await { + let update = match update { + Ok(update) => update, + Err(e) => { + // The stream is fallen behind + eprintln!("The stream is fallen behind: {:?}", e); + continue; + } + }; + match update { ServerUpdateResponse::NewOpportunity { opportunity } => { println!("New opportunity: {:?}", opportunity); diff --git a/sdk/rust/src/lib.rs b/sdk/rust/src/lib.rs index 6021704d..9abb1503 100644 --- a/sdk/rust/src/lib.rs +++ b/sdk/rust/src/lib.rs @@ -17,6 +17,7 @@ use { }, futures_util::{ SinkExt, + Stream, StreamExt, }, reqwest::Response, @@ -27,6 +28,13 @@ use { }, std::{ collections::BTreeMap, + marker::PhantomData, + pin::Pin, + sync::Arc, + task::{ + Context, + Poll, + }, time::Duration, }, strum::{ @@ -36,13 +44,17 @@ use { tokio::{ net::TcpStream, sync::{ + broadcast, mpsc, oneshot, RwLock, }, time::sleep, }, - tokio_stream::wrappers::UnboundedReceiverStream, + tokio_stream::wrappers::{ + errors::BroadcastStreamRecvError, + BroadcastStream, + }, tokio_tungstenite::{ connect_async, tungstenite::Message, @@ -121,13 +133,18 @@ type WsRequest = ( oneshot::Sender, ); -pub struct WsClient { +pub struct WsClientInner { #[allow(dead_code)] ws: tokio::task::JoinHandle<()>, request_sender: mpsc::UnboundedSender, request_id: RwLock, - pub update_stream: RwLock>, + update_sender: broadcast::Sender, +} + +#[derive(Clone)] +pub struct WsClient { + inner: Arc, } #[derive(Deserialize)] @@ -137,11 +154,39 @@ enum MessageType { Update(api_types::ws::ServerUpdateResponse), } +pub struct WsClientUpdateStream<'a> { + stream: BroadcastStream, + _lifetime: PhantomData<&'a ()>, +} + +impl WsClientUpdateStream<'_> { + pub fn new(stream: BroadcastStream) -> Self { + Self { + stream, + _lifetime: PhantomData, + } + } +} + +// Implementing Stream trait +impl Stream for WsClientUpdateStream<'_> { + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let stream = &mut self.get_mut().stream; + stream.poll_next_unpin(cx) + } +} + impl WsClient { + pub fn get_update_stream(&self) -> WsClientUpdateStream { + WsClientUpdateStream::new(BroadcastStream::new(self.inner.update_sender.subscribe())) + } + async fn run( mut ws_stream: WebSocketStream>, mut request_receiver: mpsc::UnboundedReceiver, - update_sender: mpsc::UnboundedSender, + update_sender: broadcast::Sender, ) { let mut requests_map = BTreeMap::>::new(); loop { @@ -186,9 +231,8 @@ impl WsClient { response.id.and_then(|id| requests_map.remove(&id)).map(|sender| sender.send(response.result)); } MessageType::Update(update) => { - if update_sender.send(update).is_err() { - break; - } + _ = update_sender.send(update); + continue; } } } @@ -199,7 +243,10 @@ impl WsClient { requests_map.insert(request.id.clone(), response_sender); } } - None => break, + None => { + println!("Request receiver closed"); + break; + } } } } @@ -210,7 +257,7 @@ impl WsClient { &self, message: api_types::ws::ClientMessage, ) -> Result { - let mut write_gaurd = self.request_id.write().await; + let mut write_gaurd = self.inner.request_id.write().await; let request_id = write_gaurd.to_string(); *write_gaurd += 1; drop(write_gaurd); @@ -222,6 +269,7 @@ impl WsClient { let (response_sender, response_receiver) = oneshot::channel(); if self + .inner .request_sender .send((request, response_sender)) .is_err() @@ -354,15 +402,15 @@ impl Client { .map_err(|e| ClientError::SubscribeFailed(e.to_string()))?; let (request_sender, request_receiver) = mpsc::unbounded_channel(); - let (update_sender, update_receiver) = mpsc::unbounded_channel(); + let (update_sender, _) = broadcast::channel(1000); Ok(WsClient { - request_sender, - update_stream: RwLock::new(UnboundedReceiverStream::< - api_types::ws::ServerUpdateResponse, - >::new(update_receiver)), - request_id: RwLock::new(0), - ws: tokio::spawn(WsClient::run(ws_stream, request_receiver, update_sender)), + inner: Arc::new(WsClientInner { + request_sender, + update_sender: update_sender.clone(), + request_id: RwLock::new(0), + ws: tokio::spawn(WsClient::run(ws_stream, request_receiver, update_sender)), + }), }) }