Skip to content

Commit

Permalink
feat: adding server for active components
Browse files Browse the repository at this point in the history
commit-id:fa8f56cc
  • Loading branch information
lev-starkware committed Jul 28, 2024
1 parent 93de0bd commit 93f4f2a
Show file tree
Hide file tree
Showing 3 changed files with 312 additions and 7 deletions.
20 changes: 20 additions & 0 deletions crates/mempool_infra/src/component_server/definitions.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use async_trait::async_trait;
use tokio::sync::mpsc::Receiver;
use tracing::{error, info};

use crate::component_definitions::{ComponentRequestAndResponseSender, ComponentRequestHandler};
use crate::component_runner::ComponentStarter;

#[async_trait]
Expand All @@ -20,3 +22,21 @@ where
info!("ComponentServer::start() completed.");
true
}

pub async fn request_response_loop<Request, Response, Component>(
rx: &mut Receiver<ComponentRequestAndResponseSender<Request, Response>>,
component: &mut Component,
) where
Component: ComponentRequestHandler<Request, Response> + Send + Sync,
Request: Send + Sync,
Response: Send + Sync,
{
while let Some(request_and_res_tx) = rx.recv().await {
let request = request_and_res_tx.request;
let tx = request_and_res_tx.tx;

let res = component.handle_request(request).await;

tx.send(res).await.expect("Response connection should be open.");
}
}
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
use async_trait::async_trait;
use tokio::sync::mpsc::Receiver;
use tracing::error;

use super::definitions::{start_component, ComponentServerStarter};
use super::definitions::{request_response_loop, start_component, ComponentServerStarter};
use crate::component_definitions::{ComponentRequestAndResponseSender, ComponentRequestHandler};
use crate::component_runner::ComponentStarter;

Expand Down Expand Up @@ -137,14 +138,56 @@ where
{
async fn start(&mut self) {
if start_component(&mut self.component).await {
while let Some(request_and_res_tx) = self.rx.recv().await {
let request = request_and_res_tx.request;
let tx = request_and_res_tx.tx;
request_response_loop(&mut self.rx, &mut self.component).await;
}
}
}

pub struct LocalActiveComponentServer<Component, Request, Response>
where
Component: ComponentRequestHandler<Request, Response> + ComponentStarter + Clone + Send + Sync,
Request: Send + Sync,
Response: Send + Sync,
{
component: Component,
rx: Receiver<ComponentRequestAndResponseSender<Request, Response>>,
}

impl<Component, Request, Response> LocalActiveComponentServer<Component, Request, Response>
where
Component: ComponentRequestHandler<Request, Response> + ComponentStarter + Clone + Send + Sync,
Request: Send + Sync,
Response: Send + Sync,
{
pub fn new(
component: Component,
rx: Receiver<ComponentRequestAndResponseSender<Request, Response>>,
) -> Self {
Self { component, rx }
}
}

let res = self.component.handle_request(request).await;
#[async_trait]
impl<Component, Request, Response> ComponentServerStarter
for LocalActiveComponentServer<Component, Request, Response>
where
Component: ComponentRequestHandler<Request, Response> + ComponentStarter + Clone + Send + Sync,
Request: Send + Sync,
Response: Send + Sync,
{
async fn start(&mut self) {
let mut component = self.component.clone();
let component_future = async move { component.start().await };
let request_response_future = request_response_loop(&mut self.rx, &mut self.component);

tx.send(res).await.expect("Response connection should be open.");
tokio::select! {
_res = component_future => {
error!("Component stopped.");
}
}
_res = request_response_future => {
error!("Server stopped.");
}
};
error!("Server ended with unexpected Ok.");
}
}
242 changes: 242 additions & 0 deletions crates/mempool_infra/tests/active_component_server_client_test.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,242 @@
use std::future::pending;
use std::sync::Arc;

use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use starknet_mempool_infra::component_client::definitions::{ClientError, ClientResult};
use starknet_mempool_infra::component_client::local_component_client::LocalComponentClient;
use starknet_mempool_infra::component_definitions::{
ComponentRequestAndResponseSender, ComponentRequestHandler,
};
use starknet_mempool_infra::component_runner::{ComponentStartError, ComponentStarter};
use starknet_mempool_infra::component_server::definitions::ComponentServerStarter;
use starknet_mempool_infra::component_server::empty_component_server::EmptyServer;
use starknet_mempool_infra::component_server::local_component_server::LocalActiveComponentServer;
use tokio::sync::mpsc::{channel, Sender};
use tokio::sync::Mutex;
use tokio::task;
use tokio::time::{sleep, Duration};

type ValueC = i64;
type ResultC = ClientResult<ValueC>;

#[derive(Debug, Clone)]
struct ComponentC {
test_counter: Arc<Mutex<ValueC>>,
max_iterations: ValueC,
c_test_ended: Arc<Mutex<bool>>,
d_test_ended: Arc<Mutex<bool>>,
}

impl ComponentC {
pub fn new(init_value: ValueC, max_iterations: ValueC) -> Self {
Self {
test_counter: Arc::new(Mutex::new(init_value)),
max_iterations,
c_test_ended: Arc::new(Mutex::new(false)),
d_test_ended: Arc::new(Mutex::new(false)),
}
}

pub async fn c_get_test_counter(&self) -> ValueC {
*self.test_counter.lock().await
}

pub async fn c_increment_test_counter(&self) {
*self.test_counter.lock().await += 1;
}

pub async fn c_set_c_test_end(&self) {
*self.c_test_ended.lock().await = true;
}

pub async fn c_set_d_test_end(&self) {
*self.d_test_ended.lock().await = true;
}

pub async fn c_check_test_ended(&self) -> bool {
let c_test_ended = *self.c_test_ended.lock().await;
c_test_ended && *self.d_test_ended.lock().await
}
}

#[async_trait]
impl ComponentStarter for ComponentC {
async fn start(&mut self) -> Result<(), ComponentStartError> {
for _ in 0..self.max_iterations {
self.c_increment_test_counter().await;
}
let val = self.c_get_test_counter().await;
assert!(val >= self.max_iterations);
self.c_set_c_test_end().await;

// Mimicing real start function that should not return.
let () = pending().await;
Ok(())
}
}

#[derive(Serialize, Deserialize, Debug)]
pub enum ComponentCRequest {
CIncValue,
CGetValue,
CSetDTestEnd,
CTestEndCheck,
}

#[derive(Serialize, Deserialize, Debug)]
pub enum ComponentCResponse {
CIncValue,
CGetValue(ValueC),
CSetDTestEnd,
CTestEndCheck(bool),
}

#[async_trait]
trait ComponentCClientTrait: Send + Sync {
async fn c_inc_value(&self) -> ClientResult<()>;
async fn c_get_value(&self) -> ResultC;
async fn c_set_d_test_end(&self) -> ClientResult<()>;
async fn c_test_end_check(&self) -> ClientResult<bool>;
}

struct ComponentD {
c: Box<dyn ComponentCClientTrait>,
max_iterations: ValueC,
}

impl ComponentD {
pub fn new(c: Box<dyn ComponentCClientTrait>, max_iterations: ValueC) -> Self {
Self { c, max_iterations }
}

pub async fn d_increment_value(&self) {
self.c.c_inc_value().await.unwrap()
}

pub async fn d_get_value(&self) -> ValueC {
self.c.c_get_value().await.unwrap()
}

pub async fn d_send_test_end(&self) {
self.c.c_set_d_test_end().await.unwrap()
}
}

#[async_trait]
impl ComponentStarter for ComponentD {
async fn start(&mut self) -> Result<(), ComponentStartError> {
for _ in 0..self.max_iterations {
self.d_increment_value().await;
}
let val = self.d_get_value().await;
assert!(val >= self.max_iterations);
self.d_send_test_end().await;

// Mimicing real start function that should not return.
let () = pending().await;
Ok(())
}
}

#[async_trait]
impl ComponentCClientTrait for LocalComponentClient<ComponentCRequest, ComponentCResponse> {
async fn c_inc_value(&self) -> ClientResult<()> {
let res = self.send(ComponentCRequest::CIncValue).await;
match res {
ComponentCResponse::CIncValue => Ok(()),
_ => Err(ClientError::UnexpectedResponse),
}
}

async fn c_get_value(&self) -> ResultC {
let res = self.send(ComponentCRequest::CGetValue).await;
match res {
ComponentCResponse::CGetValue(value) => Ok(value),
_ => Err(ClientError::UnexpectedResponse),
}
}

async fn c_set_d_test_end(&self) -> ClientResult<()> {
let res = self.send(ComponentCRequest::CSetDTestEnd).await;
match res {
ComponentCResponse::CSetDTestEnd => Ok(()),
_ => Err(ClientError::UnexpectedResponse),
}
}

async fn c_test_end_check(&self) -> ClientResult<bool> {
let res = self.send(ComponentCRequest::CTestEndCheck).await;
match res {
ComponentCResponse::CTestEndCheck(value) => Ok(value),
_ => Err(ClientError::UnexpectedResponse),
}
}
}

#[async_trait]
impl ComponentRequestHandler<ComponentCRequest, ComponentCResponse> for ComponentC {
async fn handle_request(&mut self, request: ComponentCRequest) -> ComponentCResponse {
match request {
ComponentCRequest::CGetValue => {
ComponentCResponse::CGetValue(self.c_get_test_counter().await)
}
ComponentCRequest::CIncValue => {
self.c_increment_test_counter().await;
ComponentCResponse::CIncValue
}
ComponentCRequest::CSetDTestEnd => {
self.c_set_d_test_end().await;
ComponentCResponse::CSetDTestEnd
}
ComponentCRequest::CTestEndCheck => {
ComponentCResponse::CTestEndCheck(self.c_check_test_ended().await)
}
}
}
}

async fn wait_and_verify_response(
tx_c: Sender<ComponentRequestAndResponseSender<ComponentCRequest, ComponentCResponse>>,
expected_value: ValueC,
) {
let c_client = LocalComponentClient::new(tx_c);

let delay = Duration::from_millis(1);
let mut test_ended = false;

while !test_ended {
test_ended = c_client.c_test_end_check().await.unwrap();
sleep(delay).await; // Lower CPU usage.
}
assert_eq!(c_client.c_get_value().await.unwrap(), expected_value);
}

#[tokio::test]
async fn test_setup_c_d() {
let setup_value: ValueC = 0;
let max_iterations: ValueC = 1024;
let expected_value = max_iterations * 2;

let (tx_c, rx_c) =
channel::<ComponentRequestAndResponseSender<ComponentCRequest, ComponentCResponse>>(32);

let c_client = LocalComponentClient::new(tx_c.clone());

let component_c = ComponentC::new(setup_value, max_iterations);
let component_d = ComponentD::new(Box::new(c_client), max_iterations);

let mut component_c_server = LocalActiveComponentServer::new(component_c, rx_c);
let mut component_d_server = EmptyServer::new(component_d);

task::spawn(async move {
component_c_server.start().await;
});

task::spawn(async move {
component_d_server.start().await;
});

// Wait for the components to finish incrementing of the ComponentC::value and verify it.
wait_and_verify_response(tx_c.clone(), expected_value).await;
}

0 comments on commit 93f4f2a

Please sign in to comment.