diff --git a/auction-server/Cargo.lock b/auction-server/Cargo.lock index a814317c..103dc091 100644 --- a/auction-server/Cargo.lock +++ b/auction-server/Cargo.lock @@ -170,6 +170,7 @@ dependencies = [ "axum-macros", "axum-streams", "clap", + "dashmap", "ethers", "futures", "serde", @@ -739,6 +740,19 @@ dependencies = [ "cipher", ] +[[package]] +name = "dashmap" +version = "5.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "978747c1d849a7d2ee5e8adc0159961c48fb7e5db2f06af6723b80123bb53856" +dependencies = [ + "cfg-if", + "hashbrown", + "lock_api", + "once_cell", + "parking_lot_core", +] + [[package]] name = "data-encoding" version = "2.5.0" diff --git a/auction-server/Cargo.toml b/auction-server/Cargo.toml index ffd304f8..3aa89106 100644 --- a/auction-server/Cargo.toml +++ b/auction-server/Cargo.toml @@ -25,3 +25,4 @@ utoipa-swagger-ui = { version = "3.1.4", features = ["axum"] } serde_yaml = "0.9.25" ethers = "2.0.10" axum-macros = "0.4.0" +dashmap = { version = "5.4.0" } diff --git a/auction-server/src/api.rs b/auction-server/src/api.rs index b3a5f26d..48c79217 100644 --- a/auction-server/src/api.rs +++ b/auction-server/src/api.rs @@ -7,7 +7,7 @@ use { }, liquidation::{ OpportunityBid, - OpportunityParamsWithId, + OpportunityParamsWithMetadata, }, }, auction::run_submission_loop, @@ -44,6 +44,7 @@ use { Router, }, clap::crate_version, + dashmap::DashMap, ethers::{ providers::{ Http, @@ -63,6 +64,7 @@ use { sync::{ atomic::{ AtomicBool, + AtomicUsize, Ordering, }, Arc, @@ -92,6 +94,7 @@ async fn root() -> String { mod bid; pub(crate) mod liquidation; +pub(crate) mod ws; pub enum RestError { /// The request contained invalid parameters @@ -177,12 +180,12 @@ pub async fn start_server(run_options: RunOptions) -> Result<()> { schemas(OpportunityParamsV1), schemas(OpportunityBid), schemas(OpportunityParams), - schemas(OpportunityParamsWithId), + schemas(OpportunityParamsWithMetadata), schemas(TokenQty), schemas(BidResult), schemas(ErrorBodyResponse), responses(ErrorBodyResponse), - responses(OpportunityParamsWithId), + responses(OpportunityParamsWithMetadata), responses(BidResult) ), tags( @@ -235,6 +238,10 @@ pub async fn start_server(run_options: RunOptions) -> Result<()> { chains: chain_store?, liquidation_store: LiquidationStore::default(), per_operator: wallet, + ws: ws::WsState { + subscriber_counter: AtomicUsize::new(0), + subscribers: DashMap::new(), + }, }); let server_store = store.clone(); @@ -258,6 +265,7 @@ pub async fn start_server(run_options: RunOptions) -> Result<()> { "/v1/liquidation/opportunities/:opportunity_id/bids", post(liquidation::post_bid), ) + .route("/v1/ws", get(ws::ws_route_handler)) .route("/live", get(live)) .layer(CorsLayer::permissive()) .with_state(server_store); diff --git a/auction-server/src/api/liquidation.rs b/auction-server/src/api/liquidation.rs index f8945e29..b3e0fe1e 100644 --- a/auction-server/src/api/liquidation.rs +++ b/auction-server/src/api/liquidation.rs @@ -5,6 +5,10 @@ use { handle_bid, BidResult, }, + ws::{ + notify_updates, + UpdateEvent::NewOpportunity, + }, ErrorBodyResponse, RestError, }, @@ -60,19 +64,31 @@ use { /// Similar to OpportunityParams, but with the opportunity id included. #[derive(Serialize, Deserialize, ToSchema, Clone, ToResponse)] -pub struct OpportunityParamsWithId { +pub struct OpportunityParamsWithMetadata { /// The opportunity unique id #[schema(example = "f47ac10b-58cc-4372-a567-0e02b2c3d479", value_type=String)] opportunity_id: Uuid, + /// Creation time of the opportunity + #[schema(example = "1700000000")] + creation_time: UnixTimestamp, /// opportunity data #[serde(flatten)] params: OpportunityParams, } -impl Into for LiquidationOpportunity { - fn into(self) -> OpportunityParamsWithId { - OpportunityParamsWithId { +impl OpportunityParamsWithMetadata { + pub fn get_chain_id(&self) -> &ChainId { + match &self.params { + OpportunityParams::V1(params) => ¶ms.chain_id, + } + } +} + +impl Into for LiquidationOpportunity { + fn into(self) -> OpportunityParamsWithMetadata { + OpportunityParamsWithMetadata { opportunity_id: self.id, + creation_time: self.creation_time, params: self.params, } } @@ -90,7 +106,7 @@ impl Into for LiquidationOpportunity { pub async fn post_opportunity( State(store): State>, Json(versioned_params): Json, -) -> Result, RestError> { +) -> Result, RestError> { let params = match versioned_params.clone() { OpportunityParams::V1(params) => params, }; @@ -127,22 +143,22 @@ pub async fn post_opportunity( } } - opportunities_existing.push(opportunity); + opportunities_existing.push(opportunity.clone()); } else { - write_lock.insert(params.permission_key.clone(), vec![opportunity]); + write_lock.insert(params.permission_key.clone(), vec![opportunity.clone()]); } + notify_updates(&store.ws, NewOpportunity(opportunity.clone().into())).await; + tracing::debug!("number of permission keys: {}", write_lock.len()); tracing::debug!( "number of opportunities for key: {}", write_lock[¶ms.permission_key].len() ); - Ok(OpportunityParamsWithId { - opportunity_id: id, - params: versioned_params, - } - .into()) + let opportunity_with_metadata: OpportunityParamsWithMetadata = opportunity.into(); + + Ok(opportunity_with_metadata.into()) } @@ -162,8 +178,8 @@ params(ChainIdQueryParams))] pub async fn get_opportunities( State(store): State>, query_params: Query, -) -> Result>, RestError> { - let opportunities: Vec = store +) -> Result>, RestError> { + let opportunities: Vec = store .liquidation_store .opportunities .read() @@ -177,7 +193,7 @@ pub async fn get_opportunities( .clone() .into() }) - .filter(|params_with_id: &OpportunityParamsWithId| { + .filter(|params_with_id: &OpportunityParamsWithMetadata| { let params = match ¶ms_with_id.params { OpportunityParams::V1(params) => params, }; diff --git a/auction-server/src/api/ws.rs b/auction-server/src/api/ws.rs new file mode 100644 index 00000000..b1f57e77 --- /dev/null +++ b/auction-server/src/api/ws.rs @@ -0,0 +1,331 @@ +use { + crate::{ + api::{ + liquidation::OpportunityParamsWithMetadata, + SHOULD_EXIT, + }, + config::ChainId, + state::{ + LiquidationOpportunity, + Store, + }, + }, + anyhow::{ + anyhow, + Result, + }, + axum::{ + extract::{ + ws::{ + Message, + WebSocket, + }, + State, + WebSocketUpgrade, + }, + http::HeaderMap, + response::IntoResponse, + }, + dashmap::DashMap, + ethers::types::Chain, + futures::{ + future::join_all, + stream::{ + SplitSink, + SplitStream, + }, + SinkExt, + StreamExt, + }, + serde::{ + Deserialize, + Serialize, + }, + std::{ + collections::HashSet, + sync::{ + atomic::{ + AtomicUsize, + Ordering, + }, + Arc, + }, + time::Duration, + }, + tokio::sync::mpsc, +}; + +pub struct WsState { + pub subscriber_counter: AtomicUsize, + pub subscribers: DashMap>, +} + +#[derive(Deserialize, Debug, Clone)] +#[serde(tag = "type")] +enum ClientMessage { + #[serde(rename = "subscribe")] + Subscribe { chain_ids: Vec }, + #[serde(rename = "unsubscribe")] + Unsubscribe { chain_ids: Vec }, +} + +#[derive(Serialize, Clone)] +#[serde(tag = "type")] +enum ServerMessage { + #[serde(rename = "response")] + Response(ServerResponseMessage), + #[serde(rename = "new_opportunity")] + NewOpportunity { + opportunity: OpportunityParamsWithMetadata, + }, +} + +#[derive(Serialize, Debug, Clone)] +#[serde(tag = "status")] +enum ServerResponseMessage { + #[serde(rename = "success")] + Success, + #[serde(rename = "error")] + Err { error: String }, +} + +pub async fn ws_route_handler( + ws: WebSocketUpgrade, + State(store): State>, +) -> impl IntoResponse { + ws.on_upgrade(move |socket| websocket_handler(socket, store)) +} + +async fn websocket_handler(stream: WebSocket, state: Arc) { + let ws_state = &state.ws; + let id = ws_state.subscriber_counter.fetch_add(1, Ordering::SeqCst); + let (notify_sender, notify_receiver) = mpsc::channel(NOTIFICATIONS_CHAN_LEN); + let (sender, receiver) = stream.split(); + ws_state.subscribers.insert(id, notify_sender); + let mut subscriber = Subscriber::new(id, state, notify_receiver, receiver, sender); + subscriber.run().await; +} + +#[derive(Clone)] +pub enum UpdateEvent { + NewOpportunity(OpportunityParamsWithMetadata), +} + +pub type SubscriberId = usize; + +/// Subscriber is an actor that handles a single websocket connection. +/// It listens to the store for updates and sends them to the client. +pub struct Subscriber { + id: SubscriberId, + closed: bool, + store: Arc, + notify_receiver: mpsc::Receiver, + receiver: SplitStream, + sender: SplitSink, + chain_ids: HashSet, + ping_interval: tokio::time::Interval, + exit_check_interval: tokio::time::Interval, + responded_to_ping: bool, +} + +const PING_INTERVAL_DURATION: Duration = Duration::from_secs(30); +const NOTIFICATIONS_CHAN_LEN: usize = 1000; + +impl Subscriber { + pub fn new( + id: SubscriberId, + store: Arc, + notify_receiver: mpsc::Receiver, + receiver: SplitStream, + sender: SplitSink, + ) -> Self { + Self { + id, + closed: false, + store, + notify_receiver, + receiver, + sender, + chain_ids: HashSet::new(), + ping_interval: tokio::time::interval(PING_INTERVAL_DURATION), + exit_check_interval: tokio::time::interval(Duration::from_secs(5)), + responded_to_ping: true, // We start with true so we don't close the connection immediately + } + } + + #[tracing::instrument(skip(self))] + pub async fn run(&mut self) { + while !self.closed { + if let Err(e) = self.handle_next().await { + tracing::debug!(subscriber = self.id, error = ?e, "Error Handling Subscriber Message."); + break; + } + } + } + + async fn handle_next(&mut self) -> Result<()> { + tokio::select! { + maybe_update_event = self.notify_receiver.recv() => { + match maybe_update_event { + Some(event) => self.handle_update(event).await, + None => Err(anyhow!("Update channel closed. This should never happen. Closing connection.")) + } + }, + maybe_message_or_err = self.receiver.next() => { + self.handle_client_message( + maybe_message_or_err.ok_or(anyhow!("Client channel is closed"))?? + ).await + }, + _ = self.ping_interval.tick() => { + if !self.responded_to_ping { + return Err(anyhow!("Subscriber did not respond to ping. Closing connection.")); + } + self.responded_to_ping = false; + self.sender.send(Message::Ping(vec![])).await?; + Ok(()) + }, + _ = self.exit_check_interval.tick() => { + if SHOULD_EXIT.load(Ordering::Acquire) { + self.sender.close().await?; + self.closed = true; + return Err(anyhow!("Application is shutting down. Closing connection.")); + } + Ok(()) + } + } + } + + async fn handle_update(&mut self, event: UpdateEvent) -> Result<()> { + match event.clone() { + UpdateEvent::NewOpportunity(opportunity) => { + if !self.chain_ids.contains(opportunity.get_chain_id()) { + // Irrelevant update + return Ok(()); + } + let message = + serde_json::to_string(&ServerMessage::NewOpportunity { opportunity })?; + self.sender.send(message.into()).await?; + } + } + + Ok(()) + } + + #[tracing::instrument(skip(self, message))] + async fn handle_client_message(&mut self, message: Message) -> Result<()> { + let maybe_client_message = match message { + Message::Close(_) => { + // Closing the connection. We don't remove it from the subscribers + // list, instead when the Subscriber struct is dropped the channel + // to subscribers list will be closed and it will eventually get + // removed. + tracing::trace!(id = self.id, "Subscriber Closed Connection."); + + // Send the close message to gracefully shut down the connection + // Otherwise the client might get an abnormal Websocket closure + // error. + self.sender.close().await?; + self.closed = true; + return Ok(()); + } + Message::Text(text) => serde_json::from_str::(&text), + Message::Binary(data) => serde_json::from_slice::(&data), + Message::Ping(_) => { + // Axum will send Pong automatically + return Ok(()); + } + Message::Pong(_) => { + self.responded_to_ping = true; + return Ok(()); + } + }; + + match maybe_client_message { + Err(e) => { + self.sender + .send( + serde_json::to_string(&ServerMessage::Response( + ServerResponseMessage::Err { + error: e.to_string(), + }, + ))? + .into(), + ) + .await?; + return Ok(()); + } + + Ok(ClientMessage::Subscribe { chain_ids }) => { + let available_chain_ids: Vec<&ChainId> = self.store.chains.keys().collect(); + + let not_found_chain_ids: Vec<&ChainId> = chain_ids + .iter() + .filter(|chain_id| !available_chain_ids.contains(chain_id)) + .collect(); + + // If there is a single chain id that is not found, we don't subscribe to any of the + // asked correct chain ids and return an error to be more explicit and clear. + if !not_found_chain_ids.is_empty() { + self.sender + .send( + serde_json::to_string(&ServerMessage::Response( + ServerResponseMessage::Err { + error: format!( + "Chain id(s) with id(s) {:?} not found", + not_found_chain_ids + ), + }, + ))? + .into(), + ) + .await?; + return Ok(()); + } else { + self.chain_ids.extend(chain_ids.into_iter()); + } + } + Ok(ClientMessage::Unsubscribe { chain_ids }) => { + self.chain_ids + .retain(|chain_id| !chain_ids.contains(chain_id)); + } + } + + + self.sender + .send( + serde_json::to_string(&ServerMessage::Response(ServerResponseMessage::Success))? + .into(), + ) + .await?; + + Ok(()) + } +} + + +pub async fn notify_updates(ws_state: &WsState, event: UpdateEvent) { + let closed_subscribers: Vec> = + join_all(ws_state.subscribers.iter_mut().map(|subscriber| { + let event = event.clone(); + async move { + match subscriber.send(event).await { + Ok(_) => None, + Err(_) => { + // An error here indicates the channel is closed (which may happen either when the + // client has sent Message::Close or some other abrupt disconnection). We remove + // subscribers only when send fails so we can handle closure only once when we are + // able to see send() fail. + Some(*subscriber.key()) + } + } + } + })) + .await; + + // Remove closed_subscribers from ws_state + closed_subscribers.into_iter().for_each(|id| { + if let Some(id) = id { + ws_state.subscribers.remove(&id); + } + }); +} diff --git a/auction-server/src/state.rs b/auction-server/src/state.rs index 90ca2a20..475b1153 100644 --- a/auction-server/src/state.rs +++ b/auction-server/src/state.rs @@ -1,7 +1,10 @@ use { - crate::config::{ - ChainId, - EthereumConfig, + crate::{ + api::ws::WsState, + config::{ + ChainId, + EthereumConfig, + }, }, ethers::{ providers::{ @@ -119,4 +122,5 @@ pub struct Store { pub chains: HashMap, pub liquidation_store: LiquidationStore, pub per_operator: LocalWallet, + pub ws: WsState, }