Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support subscribing to events via ws #652

Merged
merged 10 commits into from
Nov 27, 2024
2 changes: 1 addition & 1 deletion crates/starknet-devnet-core/src/starknet/events.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ pub(crate) fn get_events(
/// * `address` - Optional. The address to filter the event by.
/// * `keys_filter` - Optional. The keys to filter the event by.
/// * `event` - The event to check if it applies to the filters.
fn check_if_filter_applies_for_event(
pub fn check_if_filter_applies_for_event(
address: &Option<ContractAddress>,
keys_filter: &Option<Vec<Vec<Felt>>>,
event: &Event,
Expand Down
13 changes: 12 additions & 1 deletion crates/starknet-devnet-core/src/starknet/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ mod add_l1_handler_transaction;
mod cheats;
pub(crate) mod defaulter;
mod estimations;
mod events;
pub mod events;
mod get_class_impls;
mod predeployed;
pub mod starknet_config;
Expand Down Expand Up @@ -1034,6 +1034,17 @@ impl Starknet {
.ok_or(Error::NoTransaction)
}

pub fn get_unlimited_events(
&self,
from_block: Option<BlockId>,
to_block: Option<BlockId>,
address: Option<ContractAddress>,
keys: Option<Vec<Vec<Felt>>>,
) -> DevnetResult<Vec<EmittedEvent>> {
events::get_events(self, from_block, to_block, address, keys, 0, None)
.map(|(emitted_events, _)| emitted_events)
}

pub fn get_events(
&self,
from_block: Option<BlockId>,
Expand Down
59 changes: 52 additions & 7 deletions crates/starknet-devnet-server/src/api/json_rpc/endpoints_ws.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ use starknet_types::starknet_api::block::{BlockNumber, BlockStatus};

use super::error::ApiError;
use super::models::{
BlockInput, PendingTransactionsSubscriptionInput, SubscriptionIdInput, TransactionBlockInput,
BlockInput, EventsSubscriptionInput, PendingTransactionsSubscriptionInput, SubscriptionIdInput,
TransactionBlockInput,
};
use super::{JsonRpcHandler, JsonRpcSubscriptionRequest};
use crate::rpc_core::request::Id;
Expand All @@ -33,7 +34,9 @@ impl JsonRpcHandler {
JsonRpcSubscriptionRequest::PendingTransactions(data) => {
self.subscribe_pending_txs(data, rpc_request_id, socket_id).await
}
JsonRpcSubscriptionRequest::Events => todo!(),
JsonRpcSubscriptionRequest::Events(data) => {
self.subscribe_events(data, rpc_request_id, socket_id).await
}
JsonRpcSubscriptionRequest::Unsubscribe(SubscriptionIdInput { subscription_id }) => {
let mut sockets = self.api.sockets.lock().await;
let socket_context = sockets.get_mut(&socket_id).ok_or(
Expand All @@ -42,15 +45,14 @@ impl JsonRpcHandler {
}),
)?;

socket_context.unsubscribe(rpc_request_id, subscription_id).await?;
Ok(())
socket_context.unsubscribe(rpc_request_id, subscription_id).await
}
}
}

/// Returns (starting block number, latest block number). Returns an error in case the starting
/// block does not exist or there are too many blocks.
async fn convert_to_block_number_range(
async fn get_validated_block_number_range(
&self,
mut starting_block_id: BlockId,
) -> Result<(u64, u64), ApiError> {
Expand Down Expand Up @@ -105,7 +107,7 @@ impl JsonRpcHandler {
};

let (query_block_number, latest_block_number) =
self.convert_to_block_number_range(block_id).await?;
self.get_validated_block_number_range(block_id).await?;

// perform the actual subscription
let mut sockets = self.api.sockets.lock().await;
Expand Down Expand Up @@ -233,7 +235,7 @@ impl JsonRpcHandler {
};

let (query_block_number, latest_block_number) =
self.convert_to_block_number_range(query_block_id).await?;
self.get_validated_block_number_range(query_block_id).await?;

// perform the actual subscription
let mut sockets = self.api.sockets.lock().await;
Expand Down Expand Up @@ -280,4 +282,47 @@ impl JsonRpcHandler {

Ok(())
}

async fn subscribe_events(
&self,
maybe_subscription_input: Option<EventsSubscriptionInput>,
rpc_request_id: Id,
socket_id: SocketId,
) -> Result<(), ApiError> {
let address = maybe_subscription_input
.as_ref()
.and_then(|subscription_input| subscription_input.from_address);

let starting_block_id = maybe_subscription_input
.as_ref()
.and_then(|subscription_input| subscription_input.block.as_ref())
.map(|b| b.0)
.unwrap_or(BlockId::Tag(BlockTag::Latest));

self.get_validated_block_number_range(starting_block_id).await?;

let keys_filter =
maybe_subscription_input.and_then(|subscription_input| subscription_input.keys);

let mut sockets = self.api.sockets.lock().await;
let socket_context = sockets.get_mut(&socket_id).ok_or(ApiError::StarknetDevnetError(
Error::UnexpectedInternalError { msg: format!("Unregistered socket ID: {socket_id}") },
))?;

let subscription = Subscription::Events { address, keys_filter: keys_filter.clone() };
let subscription_id = socket_context.subscribe(rpc_request_id, subscription).await;

let events = self.api.starknet.lock().await.get_unlimited_events(
Some(starting_block_id),
Some(BlockId::Tag(BlockTag::Latest)),
address,
keys_filter,
)?;

for event in events {
socket_context.notify(subscription_id, SubscriptionNotification::Event(event)).await;
}

Ok(())
}
}
19 changes: 15 additions & 4 deletions crates/starknet-devnet-server/src/api/json_rpc/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ use futures::stream::SplitSink;
use futures::{SinkExt, StreamExt};
use models::{
BlockAndClassHashInput, BlockAndContractAddressInput, BlockAndIndexInput, BlockInput,
CallInput, EstimateFeeInput, EventsInput, GetStorageInput, L1TransactionHashInput,
PendingTransactionsSubscriptionInput, SubscriptionIdInput, TransactionBlockInput,
TransactionHashInput, TransactionHashOutput,
CallInput, EstimateFeeInput, EventsInput, EventsSubscriptionInput, GetStorageInput,
L1TransactionHashInput, PendingTransactionsSubscriptionInput, SubscriptionIdInput,
TransactionBlockInput, TransactionHashInput, TransactionHashOutput,
};
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
Expand Down Expand Up @@ -311,6 +311,17 @@ impl JsonRpcHandler {
}),
));
}

let events = starknet.get_unlimited_events(
Some(BlockId::Tag(BlockTag::Latest)),
Some(BlockId::Tag(BlockTag::Latest)),
None,
None,
)?;

for event in events {
notifications.push(SubscriptionNotification::Event(event));
}
}

let sockets = self.api.sockets.lock().await;
Expand Down Expand Up @@ -757,7 +768,7 @@ pub enum JsonRpcSubscriptionRequest {
#[serde(rename = "starknet_subscribePendingTransactions", with = "optional_params")]
PendingTransactions(Option<PendingTransactionsSubscriptionInput>),
#[serde(rename = "starknet_subscribeEvents")]
Events,
Events(Option<EventsSubscriptionInput>),
#[serde(rename = "starknet_unsubscribe")]
Unsubscribe(SubscriptionIdInput),
}
Expand Down
12 changes: 11 additions & 1 deletion crates/starknet-devnet-server/src/api/json_rpc/models.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use serde::{Deserialize, Serialize};
use starknet_rs_core::types::{Hash256, TransactionExecutionStatus, TransactionFinalityStatus};
use starknet_rs_core::types::{
Felt, Hash256, TransactionExecutionStatus, TransactionFinalityStatus,
};
use starknet_types::contract_address::ContractAddress;
use starknet_types::felt::{BlockHash, ClassHash, TransactionHash};
use starknet_types::patricia_key::PatriciaKey;
Expand Down Expand Up @@ -205,6 +207,14 @@ pub struct PendingTransactionsSubscriptionInput {
pub sender_address: Option<Vec<ContractAddress>>,
}

#[derive(Deserialize, Clone, Debug)]
#[serde(deny_unknown_fields)]
pub struct EventsSubscriptionInput {
pub block: Option<BlockId>,
pub from_address: Option<ContractAddress>,
pub keys: Option<Vec<Vec<Felt>>>,
}

#[cfg(test)]
mod tests {
use starknet_rs_core::types::{BlockId as ImportedBlockId, BlockTag, Felt};
Expand Down
18 changes: 13 additions & 5 deletions crates/starknet-devnet-server/src/subscribe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@ use axum::extract::ws::{Message, WebSocket};
use futures::stream::SplitSink;
use futures::SinkExt;
use serde::{self, Serialize};
use starknet_rs_core::types::BlockTag;
use starknet_core::starknet::events::check_if_filter_applies_for_event;
use starknet_rs_core::types::{BlockTag, Felt};
use starknet_types::contract_address::ContractAddress;
use starknet_types::emitted_event::EmittedEvent;
use starknet_types::felt::TransactionHash;
use starknet_types::rpc::block::BlockHeader;
use starknet_types::rpc::transactions::{TransactionStatus, TransactionWithHash};
Expand Down Expand Up @@ -39,7 +41,7 @@ pub enum Subscription {
TransactionStatus { tag: BlockTag, transaction_hash: TransactionHash },
PendingTransactionsFull { address_filter: AddressFilter },
PendingTransactionsHash { address_filter: AddressFilter },
Events,
Events { address: Option<ContractAddress>, keys_filter: Option<Vec<Vec<Felt>>> },
}

impl Subscription {
Expand All @@ -51,7 +53,7 @@ impl Subscription {
| Subscription::PendingTransactionsHash { .. } => {
SubscriptionConfirmation::NewSubscription(id)
}
Subscription::Events => SubscriptionConfirmation::NewSubscription(id),
Subscription::Events { .. } => SubscriptionConfirmation::NewSubscription(id),
}
}

Expand Down Expand Up @@ -90,7 +92,11 @@ impl Subscription {
};
}
}
Subscription::Events => todo!(),
Subscription::Events { address, keys_filter } => {
if let SubscriptionNotification::Event(event) = notification {
return check_if_filter_applies_for_event(address, keys_filter, &event.into());
}
}
}

false
Expand Down Expand Up @@ -141,6 +147,7 @@ pub enum SubscriptionNotification {
NewHeads(Box<BlockHeader>),
TransactionStatus(NewTransactionStatus),
PendingTransaction(PendingTransactionNotification),
Event(EmittedEvent),
}

impl SubscriptionNotification {
Expand All @@ -152,7 +159,8 @@ impl SubscriptionNotification {
}
SubscriptionNotification::PendingTransaction(_) => {
"starknet_subscriptionPendingTransactions"
} // SubscriptionNotification::Events => "starknet_subscriptionEvents",
}
SubscriptionNotification::Event(_) => "starknet_subscriptionEvents",
}
}
}
Expand Down
10 changes: 10 additions & 0 deletions crates/starknet-devnet-types/src/rpc/emitted_event.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,13 @@ impl From<&blockifier::execution::call_info::OrderedEvent> for OrderedEvent {
}
}
}

impl From<&EmittedEvent> for Event {
fn from(emitted_event: &EmittedEvent) -> Self {
Self {
from_address: emitted_event.from_address,
keys: emitted_event.keys.clone(),
data: emitted_event.data.clone(),
}
}
}
13 changes: 13 additions & 0 deletions crates/starknet-devnet/tests/common/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -444,6 +444,19 @@ pub async fn receive_rpc_via_ws(
Ok(serde_json::from_str(&msg.into_text()?)?)
}

/// Extract `result` from the notification and assert general properties
pub async fn receive_notification(
ws: &mut WebSocketStream<MaybeTlsStream<TcpStream>>,
method: &str,
expected_subscription_id: i64,
) -> Result<serde_json::Value, anyhow::Error> {
let mut notification = receive_rpc_via_ws(ws).await?;
assert_eq!(notification["jsonrpc"], "2.0");
assert_eq!(notification["method"], method);
assert_eq!(notification["params"]["subscription_id"], expected_subscription_id);
Ok(notification["params"].take()["result"].take())
}

pub async fn assert_no_notifications(ws: &mut WebSocketStream<MaybeTlsStream<TcpStream>>) {
match receive_rpc_via_ws(ws).await {
Ok(resp) => panic!("Expected no notifications; found: {resp}"),
Expand Down
Loading