Skip to content

Commit

Permalink
Start starknet_unsubscribe [skip ci]
Browse files Browse the repository at this point in the history
  • Loading branch information
FabijanC committed Oct 31, 2024
1 parent 72ea6d2 commit cc05e4f
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 21 deletions.
14 changes: 12 additions & 2 deletions crates/starknet-devnet-server/src/api/json_rpc/endpoints_ws.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use starknet_core::error::Error;
use starknet_rs_core::types::{BlockId, BlockTag};

use super::error::ApiError;
use super::models::BlockIdInput;
use super::models::{BlockIdInput, SubscriptionIdInput};
use super::{JsonRpcHandler, JsonRpcSubscriptionRequest};
use crate::rpc_core::request::Id;
use crate::subscribe::{SocketId, SubscriptionNotification};
Expand All @@ -22,7 +22,17 @@ impl JsonRpcHandler {
JsonRpcSubscriptionRequest::TransactionStatus => todo!(),
JsonRpcSubscriptionRequest::PendingTransactions => todo!(),
JsonRpcSubscriptionRequest::Events => todo!(),
JsonRpcSubscriptionRequest::Unsubscribe => todo!(),
JsonRpcSubscriptionRequest::Unsubscribe(SubscriptionIdInput { subscription_id }) => {
let mut sockets = self.api.sockets.lock().await;
let socket_context = sockets.get_mut(&socket_id).ok_or(
ApiError::StarknetDevnetError(Error::UnexpectedInternalError {
msg: format!("Missing socket ID: {socket_id}"),
}),
)?;

socket_context.unsubscribe(subscription_id).await;
Ok(())
}
}
}

Expand Down
10 changes: 5 additions & 5 deletions crates/starknet-devnet-server/src/api/json_rpc/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ use futures::stream::SplitSink;
use futures::{SinkExt, StreamExt};
use models::{
BlockAndClassHashInput, BlockAndContractAddressInput, BlockAndIndexInput, CallInput,
EstimateFeeInput, EventsInput, GetStorageInput, L1TransactionHashInput, TransactionHashInput,
TransactionHashOutput,
EstimateFeeInput, EventsInput, GetStorageInput, L1TransactionHashInput, SubscriptionIdInput,
TransactionHashInput, TransactionHashOutput,
};
use serde::{Deserialize, Serialize};
use serde_json::json;
Expand Down Expand Up @@ -714,7 +714,7 @@ pub enum JsonRpcSubscriptionRequest {
#[serde(rename = "starknet_subscribeEvents")]
Events,
#[serde(rename = "starknet_unsubscribe")]
Unsubscribe,
Unsubscribe(SubscriptionIdInput),
}

impl std::fmt::Display for JsonRpcRequest {
Expand Down Expand Up @@ -1423,8 +1423,8 @@ mod response_tests {
use crate::api::json_rpc::ToRpcResponseResult;

#[test]
fn serializing_starknet_response_empty_variant_has_to_produce_empty_json_object_when_converted_to_rpc_result()
{
fn serializing_starknet_response_empty_variant_has_to_produce_empty_json_object_when_converted_to_rpc_result(
) {
assert_eq!(
r#"{"result":{}}"#,
serde_json::to_string(
Expand Down
6 changes: 6 additions & 0 deletions crates/starknet-devnet-server/src/api/json_rpc/models.rs
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,12 @@ pub struct L1TransactionHashInput {
pub transaction_hash: Hash256,
}

#[derive(Serialize, Deserialize, Clone, Debug)]
#[serde(deny_unknown_fields)]
pub struct SubscriptionIdInput {
pub subscription_id: i64,
}

#[cfg(test)]
mod tests {
use starknet_rs_core::types::{BlockId as ImportedBlockId, BlockTag, Felt};
Expand Down
34 changes: 20 additions & 14 deletions crates/starknet-devnet-server/src/subscribe.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::sync::Arc;
use std::{collections::HashMap, sync::Arc};

use axum::extract::ws::{Message, WebSocket};
use futures::stream::SplitSink;
Expand All @@ -11,17 +11,16 @@ use crate::rpc_core::request::Id;

pub type SocketId = u64;

type SubscriptionId = i64;

#[derive(Debug)]
pub enum Subscription {
NewHeads(Id),
NewHeads,
TransactionStatus,
PendingTransactions,
Events,
Reorg,
}

type SubscriptionId = Id;

#[derive(Debug, Clone, Serialize)]
#[serde(untagged)]
pub enum SubscriptionConfirmation {
Expand Down Expand Up @@ -59,7 +58,7 @@ impl SubscriptionNotification {
#[derive(Debug, Clone)]
pub enum SubscriptionResponse {
Confirmation { rpc_request_id: Id, result: SubscriptionConfirmation },
Notification { subscription_id: Id, data: SubscriptionNotification },
Notification { subscription_id: SubscriptionId, data: SubscriptionNotification },
}

impl SubscriptionResponse {
Expand Down Expand Up @@ -90,12 +89,12 @@ impl SubscriptionResponse {
pub struct SocketContext {
/// The sender part of the socket's own channel
sender: Arc<Mutex<SplitSink<WebSocket, Message>>>,
subscriptions: Vec<Subscription>,
subscriptions: HashMap<SubscriptionId, Subscription>,
}

impl SocketContext {
pub fn from_sender(sender: Arc<Mutex<SplitSink<WebSocket, Message>>>) -> Self {
Self { sender, subscriptions: vec![] }
Self { sender, subscriptions: HashMap::new() }
}

async fn send(&self, subscription_response: SubscriptionResponse) {
Expand All @@ -107,28 +106,35 @@ impl SocketContext {
}

pub async fn subscribe(&mut self, rpc_request_id: Id) -> SubscriptionId {
let subscription_id = Id::Number(rand::random()); // TODO safe? negative?
self.subscriptions.push(Subscription::NewHeads(subscription_id.clone()));
let subscription_id = rand::random(); // TODO safe? negative?
self.subscriptions.insert(subscription_id, Subscription::NewHeads);

self.send(SubscriptionResponse::Confirmation {
rpc_request_id,
result: SubscriptionConfirmation::NewHeadsConfirmation(subscription_id.clone()),
result: SubscriptionConfirmation::NewHeadsConfirmation(subscription_id),
})
.await;

subscription_id
}

pub async fn unsubscribe(&mut self, subscription_id: SubscriptionId) {
match self.subscriptions.remove(&subscription_id) {
Some(_) => todo!("return true"),
None => todo!("return INVALID_SUBSCRIPTION_ID"),
}
}

pub async fn notify(&self, subscription_id: SubscriptionId, data: SubscriptionNotification) {
self.send(SubscriptionResponse::Notification { subscription_id, data }).await;
}

pub async fn notify_subscribers(&self, data: SubscriptionNotification) {
for subscription in self.subscriptions.iter() {
for (subscription_id, subscription) in self.subscriptions.iter() {
match subscription {
Subscription::NewHeads(subscription_id) => {
Subscription::NewHeads => {
if let SubscriptionNotification::NewHeadsNotification(_) = data {
self.notify(subscription_id.clone(), data.clone()).await;
self.notify(*subscription_id, data.clone()).await;
}
}
other => println!("DEBUG unsupported subscription: {other:?}"),
Expand Down

0 comments on commit cc05e4f

Please sign in to comment.