diff --git a/bin/reflux/src/main.rs b/bin/reflux/src/main.rs index 83bb0e2..a3dbd5d 100644 --- a/bin/reflux/src/main.rs +++ b/bin/reflux/src/main.rs @@ -4,18 +4,18 @@ use std::time::Duration; use axum::http::Method; use clap::Parser; use log::{debug, error, info}; -use tokio; use tokio::signal; +use tokio::sync::broadcast; use tower_http::cors::{Any, CorsLayer}; use account_aggregation::service::AccountAggregationService; use api::service_controller::ServiceController; use config::Config; -use routing_engine::{BungeeClient, CoingeckoClient, Indexer}; use routing_engine::engine::RoutingEngine; use routing_engine::estimator::LinearRegressionEstimator; -use storage::{ControlFlow, MessageQueue, RedisClient}; +use routing_engine::{BungeeClient, CoingeckoClient, Indexer}; use storage::mongodb_client::MongoDBClient; +use storage::{ControlFlow, MessageQueue, RedisClient}; #[derive(Parser, Debug)] struct Args { @@ -107,18 +107,24 @@ async fn run_solver(config: Config) { // Subscribe to cache update messages let cache_update_topic = config.indexer_config.indexer_update_topic.clone(); let routing_engine_clone = Arc::clone(&routing_engine); - tokio::spawn(async move { + + let (shutdown_tx, mut shutdown_rx) = broadcast::channel(1); + + let cache_update_handle = tokio::spawn(async move { let redis_client = redis_client.clone(); - redis_client - .subscribe(&cache_update_topic, move |_msg| { - info!("Received cache update notification"); - let routing_engine_clone = Arc::clone(&routing_engine_clone); - tokio::spawn(async move { - routing_engine_clone.refresh_cache().await; - }); - ControlFlow::<()>::Continue - }) - .unwrap(); + if let Err(e) = redis_client.subscribe(&cache_update_topic, move |_msg| { + info!("Received cache update notification"); + let routing_engine_clone = Arc::clone(&routing_engine_clone); + tokio::spawn(async move { + routing_engine_clone.refresh_cache().await; + }); + ControlFlow::<()>::Continue + }) { + error!("Failed to subscribe to cache update topic: {}", e); + } + + // Listen for shutdown signal + let _ = shutdown_rx.recv().await; }); // API service controller @@ -135,11 +141,15 @@ async fn run_solver(config: Config) { let listener = tokio::net::TcpListener::bind(format!("{}:{}", app_host, app_port)) .await .expect("Failed to bind port"); + axum::serve(listener, app.into_make_service()) - .with_graceful_shutdown(shutdown_signal()) + .with_graceful_shutdown(shutdown_signal(shutdown_tx.clone())) .await .unwrap(); + let _ = shutdown_tx.send(()); + let _ = cache_update_handle.abort(); + info!("Server stopped."); } @@ -177,7 +187,7 @@ async fn run_indexer(config: Config) { }; } -async fn shutdown_signal() { +async fn shutdown_signal(shutdown_tx: broadcast::Sender<()>) { let ctrl_c = async { signal::ctrl_c().await.expect("Unable to handle ctrl+c"); }; @@ -194,5 +204,7 @@ async fn shutdown_signal() { _ = ctrl_c => {}, _ = terminate => {}, } + info!("signal received, starting graceful shutdown"); + let _ = shutdown_tx.send(()); } diff --git a/crates/account-aggregation/Cargo.toml b/crates/account-aggregation/Cargo.toml index c1513c1..1469eca 100644 --- a/crates/account-aggregation/Cargo.toml +++ b/crates/account-aggregation/Cargo.toml @@ -12,6 +12,7 @@ serde = { version = "1.0.203", features = ["derive"] } futures = "0.3.30" log = "0.4.21" derive_more = { version = "1.0.0-beta.6", features = ["from", "into", "display"] } +thiserror = "1.0.61" # workspace dependencies storage = { workspace = true } diff --git a/crates/account-aggregation/src/service.rs b/crates/account-aggregation/src/service.rs index 318754a..c777911 100644 --- a/crates/account-aggregation/src/service.rs +++ b/crates/account-aggregation/src/service.rs @@ -1,19 +1,38 @@ -use std::error::Error; +use log::debug; use std::sync::Arc; +use thiserror::Error; use derive_more::Display; use mongodb::bson; use reqwest::Client as ReqwestClient; use uuid::Uuid; +use storage::mongodb_client::{DBError, MongoDBClient}; use storage::DBProvider; -use storage::mongodb_client::MongoDBClient; use crate::types::{ Account, AddAccountPayload, ApiResponse, Balance, RegisterAccountPayload, User, UserAccountMapping, UserAccountMappingQuery, UserQuery, }; +#[derive(Error, Debug)] +pub enum AccountAggregationError { + #[error("Database error: {0}")] + DatabaseError(#[from] DBError), + + #[error("Reqwest error: {0}")] + ReqwestError(#[from] reqwest::Error), + + #[error("Serialization error: {0}")] + SerializationError(#[from] bson::ser::Error), + + #[error("Deserialization error: {0}")] + DeserializationError(#[from] bson::de::Error), + + #[error("Custom error: {0}")] + CustomError(String), +} + /// Account Aggregation Service /// /// This service is responsible for managing user accounts and their balances @@ -56,23 +75,39 @@ impl AccountAggregationService { } /// Get the user_id associated with an account - pub async fn get_user_id(&self, account: &String) -> Option { + pub async fn get_user_id( + &self, + account: &String, + ) -> Result, AccountAggregationError> { let account = account.to_lowercase(); - let query = self - .account_mapping_db_provider - .to_document(&UserAccountMappingQuery { account }) - .ok()?; - let user_mapping = self.account_mapping_db_provider.read(&query).await.ok()??; - Some(user_mapping.get_str("user_id").ok()?.to_string()) + let query = + self.account_mapping_db_provider.to_document(&UserAccountMappingQuery { account })?; + let user_mapping = + self.account_mapping_db_provider.read(&query).await?.ok_or_else(|| { + AccountAggregationError::CustomError("User mapping not found".to_string()) + })?; + Ok(Some( + user_mapping + .get_str("user_id") + .map_err(|e| AccountAggregationError::CustomError(e.to_string()))? + .to_string(), + )) } /// Get the accounts associated with a user_id - pub async fn get_user_accounts(&self, user_id: &String) -> Option> { - let query = - self.user_db_provider.to_document(&UserQuery { user_id: user_id.clone() }).ok()?; - - let user = self.user_db_provider.read(&query).await.ok()??; - let accounts = user.get_array("accounts").ok()?; + pub async fn get_user_accounts( + &self, + user_id: &String, + ) -> Result>, AccountAggregationError> { + let query = self.user_db_provider.to_document(&UserQuery { user_id: user_id.clone() })?; + + let user = + self.user_db_provider.read(&query).await?.ok_or_else(|| { + AccountAggregationError::CustomError("User not found".to_string()) + })?; + let accounts = user + .get_array("accounts") + .map_err(|e| AccountAggregationError::CustomError(e.to_string()))?; let accounts: Vec = accounts .iter() @@ -87,43 +122,40 @@ impl AccountAggregationService { }) .collect(); - Some(accounts) + Ok(Some(accounts)) } /// Register a new user account pub async fn register_user_account( &self, account_payload: RegisterAccountPayload, - ) -> Result<(), Box> { + ) -> Result<(), AccountAggregationError> { let account = account_payload.account.to_lowercase(); - if self.get_user_id(&account).await.is_none() { + if self.get_user_id(&account).await?.is_none() { let user_id = Uuid::new_v4().to_string(); - let user_doc = self - .user_db_provider - .to_document(&User { - user_id: user_id.clone(), - accounts: vec![Account { - chain_id: account_payload.chain_id, - is_enabled: account_payload.is_enabled, - account_address: account.clone(), - account_type: account_payload.account_type, - }], - }) - .unwrap(); + let user_doc = self.user_db_provider.to_document(&User { + user_id: user_id.clone(), + accounts: vec![Account { + chain_id: account_payload.chain_id, + is_enabled: account_payload.is_enabled, + account_address: account.clone(), + account_type: account_payload.account_type, + }], + })?; self.user_db_provider.create(&user_doc).await?; - let mapping_doc = self - .account_mapping_db_provider - .to_document(&UserAccountMapping { + let mapping_doc = + self.account_mapping_db_provider.to_document(&UserAccountMapping { account: account.clone(), user_id: user_id.clone(), - }) - .unwrap(); + })?; self.account_mapping_db_provider.create(&mapping_doc).await?; } else { - return Err("Account already mapped to a user".into()); + return Err(AccountAggregationError::CustomError( + "Account already mapped to a user".to_string(), + )); } Ok(()) } @@ -132,7 +164,7 @@ impl AccountAggregationService { pub async fn add_account( &self, account_payload: AddAccountPayload, - ) -> Result<(), Box> { + ) -> Result<(), AccountAggregationError> { let new_account = Account { chain_id: account_payload.chain_id.clone(), is_enabled: account_payload.is_enabled, @@ -141,17 +173,21 @@ impl AccountAggregationService { }; // Check if the account is already mapped to a user - if self.get_user_id(&new_account.account_address).await.is_some() { - return Err("Account already mapped to a user".into()); + if self.get_user_id(&new_account.account_address).await?.is_some() { + return Err(AccountAggregationError::CustomError( + "Account already mapped to a user".to_string(), + )); } // Fetch the user document let query_doc = self .user_db_provider - .to_document(&UserQuery { user_id: account_payload.user_id.clone() }) - .unwrap(); + .to_document(&UserQuery { user_id: account_payload.user_id.clone() })?; // Retrieve user document - let mut user_doc = self.user_db_provider.read(&query_doc).await?.ok_or("User not found")?; + let mut user_doc = + self.user_db_provider.read(&query_doc).await?.ok_or_else(|| { + AccountAggregationError::CustomError("User not found".to_string()) + })?; // Add the new account to the user's accounts array let accounts_array = @@ -160,33 +196,33 @@ impl AccountAggregationService { if let bson::Bson::Array(accounts) = accounts_array { accounts.push(bson::to_bson(&new_account)?); } else { - return Err("Failed to update accounts array".into()); + return Err(AccountAggregationError::CustomError( + "Failed to update accounts array".to_string(), + )); } // Update the user document with the new account self.user_db_provider.update(&query_doc, &user_doc).await?; // Create a new mapping document for the account - let mapping_doc = self - .account_mapping_db_provider - .to_document(&UserAccountMapping { - account: new_account.account_address.clone(), - user_id: account_payload.user_id.clone(), - }) - .unwrap(); + let mapping_doc = self.account_mapping_db_provider.to_document(&UserAccountMapping { + account: new_account.account_address.clone(), + user_id: account_payload.user_id.clone(), + })?; self.account_mapping_db_provider.create(&mapping_doc).await?; Ok(()) } /// Get the balance of a user's accounts - pub async fn get_user_accounts_balance(&self, account: &String) -> Vec { - // find if the account is mapped to a user - let user_id = self.get_user_id(account).await; + pub async fn get_user_accounts_balance( + &self, + account: &String, + ) -> Result, AccountAggregationError> { let mut accounts: Vec = Vec::new(); - if user_id.is_some() { - let user_id = user_id.unwrap(); - let user_accounts = self.get_user_accounts(&user_id).await.unwrap(); + let user_id = self.get_user_id(account).await.unwrap_or(None); + if let Some(user_id) = user_id { + let user_accounts = self.get_user_accounts(&user_id).await?.unwrap(); accounts.extend(user_accounts.iter().map(|account| account.account_address.clone())); } else { accounts.push(account.clone()); @@ -194,6 +230,7 @@ impl AccountAggregationService { let mut balances = Vec::new(); let networks = self.networks.clone(); + debug!("Networks: {:?}", networks); // todo: parallelize this for user in accounts.iter() { @@ -202,19 +239,23 @@ impl AccountAggregationService { "{}/v1/{}/address/{}/balances_v2/?key={}", self.covalent_base_url, network, user, self.covalent_api_key ); - let response = self.client.get(&url).send().await.unwrap(); - let api_response: ApiResponse = response.json().await.unwrap(); - let user_balances = extract_balance_data(api_response).unwrap(); + debug!("Requesting: {}", url); + let response = self.client.get(&url).send().await?; + let api_response: ApiResponse = response.json().await?; + let user_balances = extract_balance_data(api_response)?; balances.extend(user_balances); } } + println!("{:?}", balances); - balances + Ok(balances) } } /// Extract balance data from the API response -fn extract_balance_data(api_response: ApiResponse) -> Option> { +fn extract_balance_data( + api_response: ApiResponse, +) -> Result, AccountAggregationError> { let chain_id = api_response.data.chain_id.to_string(); let results = api_response .data @@ -227,7 +268,7 @@ fn extract_balance_data(api_response: ApiResponse) -> Option> { .clone() .unwrap_or("0".to_string()) .parse::() - .map_err(Box::::from) + .map_err(|e| AccountAggregationError::CustomError(e.to_string())) .ok()?; let quote = item.quote; @@ -247,5 +288,5 @@ fn extract_balance_data(api_response: ApiResponse) -> Option> { }) .collect(); - Some(results) + Ok(results) } diff --git a/crates/api/src/service_controller.rs b/crates/api/src/service_controller.rs index 3d01f43..6927af0 100644 --- a/crates/api/src/service_controller.rs +++ b/crates/api/src/service_controller.rs @@ -77,17 +77,24 @@ impl ServiceController { account_service: Arc, query: types::UserAccountMappingQuery, ) -> impl IntoResponse { - let user_id = account_service.get_user_id(&query.account).await; - - let response = match user_id { - Some(user_id) => { - let accounts = account_service.get_user_accounts(&user_id).await; - json!({ "user_id": user_id, "accounts": accounts }) + match account_service.get_user_id(&query.account).await { + Ok(Some(user_id)) => match account_service.get_user_accounts(&user_id).await { + Ok(Some(accounts)) => { + (StatusCode::OK, Json(json!({ "user_id": user_id, "accounts": accounts }))) + } + Ok(None) => (StatusCode::NOT_FOUND, Json(json!({ "error": "Accounts not found" }))), + Err(err) => { + (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({ "error": err.to_string() }))) + } + }, + Ok(None) => ( + StatusCode::NOT_FOUND, + Json(json!({ "error": "User not found", "accounts": [query.account] })), + ), + Err(err) => { + (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({ "error": err.to_string() }))) } - None => json!({ "error": "User not found", "accounts": [query.account] }), - }; - - (StatusCode::OK, Json(response)) + } } /// Register user account @@ -129,12 +136,18 @@ impl ServiceController { routing_engine: Arc, query: types::PathQuery, ) -> impl IntoResponse { - // todo: decide if we want to find with user_id in db or generic account also - let routes = routing_engine + match routing_engine .get_best_cost_path(&query.account, query.to_chain, &query.to_token, query.to_value) - .await; - let response = json!({ "routes": routes }); - - (StatusCode::OK, Json(response)) + .await + { + Ok(routes) => { + let response = json!({ "routes": routes }); + (StatusCode::OK, Json(response)) + } + Err(err) => { + let response = json!({ "error": err.to_string() }); + (StatusCode::INTERNAL_SERVER_ERROR, Json(response)) + } + } } } diff --git a/crates/routing-engine/src/engine.rs b/crates/routing-engine/src/engine.rs index 5c5f325..bdf3eff 100644 --- a/crates/routing-engine/src/engine.rs +++ b/crates/routing-engine/src/engine.rs @@ -1,4 +1,3 @@ -use std::borrow::Borrow; use std::collections::HashMap; use std::sync::Arc; @@ -6,12 +5,13 @@ use derive_more::Display; use futures::stream::{self, StreamExt}; use log::{debug, error, info}; use serde::{Deserialize, Serialize}; +use thiserror::Error; use tokio::sync::RwLock; use account_aggregation::service::AccountAggregationService; use account_aggregation::types::Balance; use config::config::BucketConfig; -use storage::RedisClient; +use storage::{RedisClient, RedisClientError}; use crate::estimator::{Estimator, LinearRegressionEstimator}; @@ -34,6 +34,21 @@ pub struct Route { #[derive(Debug)] struct PathQuery(u32, u32, String, String); +#[derive(Error, Debug)] +pub enum RoutingEngineError { + #[error("Redis error: {0}")] + RedisError(#[from] RedisClientError), + + #[error("Estimator error: {0}")] + EstimatorError(#[from] serde_json::Error), + + #[error("Cache error: {0}")] + CacheError(String), + + #[error("User balance fetch error: {0}")] + UserBalanceFetchError(String), +} + /// Routing Engine /// This struct is responsible for calculating the best cost path for a user #[derive(Debug, Clone)] @@ -65,6 +80,7 @@ impl RoutingEngine { cache.insert(key.clone(), value.clone()); } info!("Cache refreshed with latest data from Redis."); + debug!("Cache: {:?}", cache); } Err(e) => { error!("Failed to refresh cache from Redis: {}", e); @@ -80,12 +96,12 @@ impl RoutingEngine { to_chain: u32, to_token: &str, to_value: f64, - ) -> Vec { + ) -> Result, RoutingEngineError> { debug!( "Getting best cost path for user: {}, to_chain: {}, to_token: {}, to_value: {}", account, to_chain, to_token, to_value ); - let user_balances = self.get_user_balance_from_agg_service(&account).await; + let user_balances = self.get_user_balance_from_agg_service(&account).await?; debug!("User balances: {:?}", user_balances); // todo: for account aggregation, transfer same chain same asset first @@ -108,7 +124,8 @@ impl RoutingEngine { to_token.to_string(), ), ) - .await; + .await + .unwrap_or_default(); (balance, fee_cost) }) .collect() @@ -160,7 +177,8 @@ impl RoutingEngine { to_token.to_string(), ), ) - .await; + .await + .unwrap_or_default(); (balance, fee_cost) }) .collect() @@ -201,10 +219,14 @@ impl RoutingEngine { account, to_chain, to_token, total_cost ); - selected_assets + Ok(selected_assets) } - async fn get_cached_data(&self, target_amount: f64, path: PathQuery) -> f64 { + async fn get_cached_data( + &self, + target_amount: f64, + path: PathQuery, + ) -> Result { let mut buckets_array: Vec = self .buckets .clone() @@ -224,24 +246,31 @@ impl RoutingEngine { target_amount >= window.token_amount_from_usd && target_amount <= window.token_amount_to_usd }) - .unwrap(); + .ok_or_else(|| { + RoutingEngineError::CacheError("No matching bucket found".to_string()) + })?; let key = bucket.get_hash().to_string(); let cache = self.cache.read().await; - let value = cache.get(&key).unwrap(); - let estimator: Result = - serde_json::from_str(value); - - let cost = estimator.unwrap().borrow().estimate(target_amount); + let value = cache + .get(&key) + .ok_or_else(|| RoutingEngineError::CacheError("No cached value found".to_string()))?; + let estimator: LinearRegressionEstimator = serde_json::from_str(value)?; - cost + Ok(estimator.estimate(target_amount)) } /// Get user balance from account aggregation service - async fn get_user_balance_from_agg_service(&self, account: &str) -> Vec { + async fn get_user_balance_from_agg_service( + &self, + account: &str, + ) -> Result, RoutingEngineError> { // Note: aas should always return vec of balances - self.aas_client.get_user_accounts_balance(&account.to_string()).await + self.aas_client + .get_user_accounts_balance(&account.to_string()) + .await + .map_err(|e| RoutingEngineError::UserBalanceFetchError(e.to_string())) } } @@ -257,15 +286,15 @@ mod tests { use config::BucketConfig; use storage::mongodb_client::MongoDBClient; + use crate::engine::PathQuery; + use crate::estimator::Estimator; use crate::{ - engine::RoutingEngine, + engine::{RoutingEngine, RoutingEngineError}, estimator::{DataPoint, LinearRegressionEstimator}, }; - use crate::engine::PathQuery; - use crate::estimator::Estimator; #[tokio::test] - async fn test_get_cached_data() { + async fn test_get_cached_data() -> Result<(), RoutingEngineError> { // Create dummy buckets let buckets = vec![ BucketConfig { @@ -295,7 +324,7 @@ mod tests { DataPoint { x: 2.0, y: 2.0 }, ]) .unwrap(); - let serialized_estimator = serde_json::to_string(&dummy_estimator).unwrap(); + let serialized_estimator = serde_json::to_string(&dummy_estimator)?; // Create a cache with a dummy bucket let key = buckets[0].get_hash().to_string(); @@ -333,13 +362,14 @@ mod tests { let path_query = PathQuery(1, 2, "USDC".to_string(), "ETH".to_string()); // Call get_cached_data and assert the result - let result = routing_engine.get_cached_data(target_amount, path_query).await; + let result = routing_engine.get_cached_data(target_amount, path_query).await?; assert!(result > 0.0); assert_eq!(result, dummy_estimator.estimate(target_amount)); + Ok(()) } #[tokio::test] - async fn test_get_best_cost_path() { + async fn test_get_best_cost_path() -> Result<(), RoutingEngineError> { let api_key = env::var("COVALENT_API_KEY"); if api_key.is_err() { panic!("COVALENT_API_KEY is not set"); @@ -389,7 +419,7 @@ mod tests { DataPoint { x: 2.0, y: 2.0 }, ]) .unwrap(); - let serialized_estimator = serde_json::to_string(&dummy_estimator).unwrap(); + let serialized_estimator = serde_json::to_string(&dummy_estimator)?; // Create a cache with a dummy bucket let key1 = buckets[0].get_hash().to_string(); let key2 = buckets[1].get_hash().to_string(); @@ -408,7 +438,8 @@ mod tests { // should have USDT in bsc-mainnet > $0.5 let dummy_user_address = "0x00000ebe3fa7cb71aE471547C836E0cE0AE758c2"; - let result = routing_engine.get_best_cost_path(dummy_user_address, 2, "USDT", 0.5).await; + let result = routing_engine.get_best_cost_path(dummy_user_address, 2, "USDT", 0.5).await?; assert_eq!(result.len(), 1); + Ok(()) } } diff --git a/crates/storage/src/lib.rs b/crates/storage/src/lib.rs index f0a65ae..f5b26b3 100644 --- a/crates/storage/src/lib.rs +++ b/crates/storage/src/lib.rs @@ -6,7 +6,7 @@ use std::time::Duration; pub use ::redis::{ControlFlow, Msg}; use mongodb::bson::Document; -pub use redis_client::RedisClient; +pub use redis_client::{RedisClient, RedisClientError}; pub mod mongodb_client;