Skip to content

Commit

Permalink
feat: add error handling
Browse files Browse the repository at this point in the history
  • Loading branch information
AmanRaj1608 committed Jun 27, 2024
1 parent d937d7c commit 69da0df
Show file tree
Hide file tree
Showing 6 changed files with 220 additions and 122 deletions.
44 changes: 28 additions & 16 deletions bin/reflux/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,18 @@ use std::time::Duration;
use axum::http::Method;
use clap::Parser;
use log::{debug, error, info};
use tokio;
use tokio::signal;
use tokio::sync::broadcast;
use tower_http::cors::{Any, CorsLayer};

use account_aggregation::service::AccountAggregationService;
use api::service_controller::ServiceController;
use config::Config;
use routing_engine::{BungeeClient, CoingeckoClient, Indexer};
use routing_engine::engine::RoutingEngine;
use routing_engine::estimator::LinearRegressionEstimator;
use storage::{ControlFlow, MessageQueue, RedisClient};
use routing_engine::{BungeeClient, CoingeckoClient, Indexer};
use storage::mongodb_client::MongoDBClient;
use storage::{ControlFlow, MessageQueue, RedisClient};

#[derive(Parser, Debug)]
struct Args {
Expand Down Expand Up @@ -107,18 +107,24 @@ async fn run_solver(config: Config) {
// Subscribe to cache update messages
let cache_update_topic = config.indexer_config.indexer_update_topic.clone();
let routing_engine_clone = Arc::clone(&routing_engine);
tokio::spawn(async move {

let (shutdown_tx, mut shutdown_rx) = broadcast::channel(1);

let cache_update_handle = tokio::spawn(async move {
let redis_client = redis_client.clone();
redis_client
.subscribe(&cache_update_topic, move |_msg| {
info!("Received cache update notification");
let routing_engine_clone = Arc::clone(&routing_engine_clone);
tokio::spawn(async move {
routing_engine_clone.refresh_cache().await;
});
ControlFlow::<()>::Continue
})
.unwrap();
if let Err(e) = redis_client.subscribe(&cache_update_topic, move |_msg| {
info!("Received cache update notification");
let routing_engine_clone = Arc::clone(&routing_engine_clone);
tokio::spawn(async move {
routing_engine_clone.refresh_cache().await;
});
ControlFlow::<()>::Continue
}) {
error!("Failed to subscribe to cache update topic: {}", e);
}

// Listen for shutdown signal
let _ = shutdown_rx.recv().await;
});

// API service controller
Expand All @@ -135,11 +141,15 @@ async fn run_solver(config: Config) {
let listener = tokio::net::TcpListener::bind(format!("{}:{}", app_host, app_port))
.await
.expect("Failed to bind port");

axum::serve(listener, app.into_make_service())
.with_graceful_shutdown(shutdown_signal())
.with_graceful_shutdown(shutdown_signal(shutdown_tx.clone()))
.await
.unwrap();

let _ = shutdown_tx.send(());
let _ = cache_update_handle.abort();

info!("Server stopped.");
}

Expand Down Expand Up @@ -177,7 +187,7 @@ async fn run_indexer(config: Config) {
};
}

async fn shutdown_signal() {
async fn shutdown_signal(shutdown_tx: broadcast::Sender<()>) {
let ctrl_c = async {
signal::ctrl_c().await.expect("Unable to handle ctrl+c");
};
Expand All @@ -194,5 +204,7 @@ async fn shutdown_signal() {
_ = ctrl_c => {},
_ = terminate => {},
}

info!("signal received, starting graceful shutdown");
let _ = shutdown_tx.send(());
}
1 change: 1 addition & 0 deletions crates/account-aggregation/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ serde = { version = "1.0.203", features = ["derive"] }
futures = "0.3.30"
log = "0.4.21"
derive_more = { version = "1.0.0-beta.6", features = ["from", "into", "display"] }
thiserror = "1.0.61"

# workspace dependencies
storage = { workspace = true }
Expand Down
167 changes: 104 additions & 63 deletions crates/account-aggregation/src/service.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,38 @@
use std::error::Error;
use log::debug;
use std::sync::Arc;
use thiserror::Error;

use derive_more::Display;
use mongodb::bson;
use reqwest::Client as ReqwestClient;
use uuid::Uuid;

use storage::mongodb_client::{DBError, MongoDBClient};
use storage::DBProvider;
use storage::mongodb_client::MongoDBClient;

use crate::types::{
Account, AddAccountPayload, ApiResponse, Balance, RegisterAccountPayload, User,
UserAccountMapping, UserAccountMappingQuery, UserQuery,
};

#[derive(Error, Debug)]
pub enum AccountAggregationError {
#[error("Database error: {0}")]
DatabaseError(#[from] DBError),

#[error("Reqwest error: {0}")]
ReqwestError(#[from] reqwest::Error),

#[error("Serialization error: {0}")]
SerializationError(#[from] bson::ser::Error),

#[error("Deserialization error: {0}")]
DeserializationError(#[from] bson::de::Error),

#[error("Custom error: {0}")]
CustomError(String),
}

/// Account Aggregation Service
///
/// This service is responsible for managing user accounts and their balances
Expand Down Expand Up @@ -56,23 +75,39 @@ impl AccountAggregationService {
}

/// Get the user_id associated with an account
pub async fn get_user_id(&self, account: &String) -> Option<String> {
pub async fn get_user_id(
&self,
account: &String,
) -> Result<Option<String>, AccountAggregationError> {
let account = account.to_lowercase();
let query = self
.account_mapping_db_provider
.to_document(&UserAccountMappingQuery { account })
.ok()?;
let user_mapping = self.account_mapping_db_provider.read(&query).await.ok()??;
Some(user_mapping.get_str("user_id").ok()?.to_string())
let query =
self.account_mapping_db_provider.to_document(&UserAccountMappingQuery { account })?;
let user_mapping =
self.account_mapping_db_provider.read(&query).await?.ok_or_else(|| {
AccountAggregationError::CustomError("User mapping not found".to_string())
})?;
Ok(Some(
user_mapping
.get_str("user_id")
.map_err(|e| AccountAggregationError::CustomError(e.to_string()))?
.to_string(),
))
}

/// Get the accounts associated with a user_id
pub async fn get_user_accounts(&self, user_id: &String) -> Option<Vec<Account>> {
let query =
self.user_db_provider.to_document(&UserQuery { user_id: user_id.clone() }).ok()?;

let user = self.user_db_provider.read(&query).await.ok()??;
let accounts = user.get_array("accounts").ok()?;
pub async fn get_user_accounts(
&self,
user_id: &String,
) -> Result<Option<Vec<Account>>, AccountAggregationError> {
let query = self.user_db_provider.to_document(&UserQuery { user_id: user_id.clone() })?;

let user =
self.user_db_provider.read(&query).await?.ok_or_else(|| {
AccountAggregationError::CustomError("User not found".to_string())
})?;
let accounts = user
.get_array("accounts")
.map_err(|e| AccountAggregationError::CustomError(e.to_string()))?;

let accounts: Vec<Account> = accounts
.iter()
Expand All @@ -87,43 +122,40 @@ impl AccountAggregationService {
})
.collect();

Some(accounts)
Ok(Some(accounts))
}

/// Register a new user account
pub async fn register_user_account(
&self,
account_payload: RegisterAccountPayload,
) -> Result<(), Box<dyn Error>> {
) -> Result<(), AccountAggregationError> {
let account = account_payload.account.to_lowercase();

if self.get_user_id(&account).await.is_none() {
if self.get_user_id(&account).await?.is_none() {
let user_id = Uuid::new_v4().to_string();
let user_doc = self
.user_db_provider
.to_document(&User {
user_id: user_id.clone(),
accounts: vec![Account {
chain_id: account_payload.chain_id,
is_enabled: account_payload.is_enabled,
account_address: account.clone(),
account_type: account_payload.account_type,
}],
})
.unwrap();
let user_doc = self.user_db_provider.to_document(&User {
user_id: user_id.clone(),
accounts: vec![Account {
chain_id: account_payload.chain_id,
is_enabled: account_payload.is_enabled,
account_address: account.clone(),
account_type: account_payload.account_type,
}],
})?;

self.user_db_provider.create(&user_doc).await?;

let mapping_doc = self
.account_mapping_db_provider
.to_document(&UserAccountMapping {
let mapping_doc =
self.account_mapping_db_provider.to_document(&UserAccountMapping {
account: account.clone(),
user_id: user_id.clone(),
})
.unwrap();
})?;
self.account_mapping_db_provider.create(&mapping_doc).await?;
} else {
return Err("Account already mapped to a user".into());
return Err(AccountAggregationError::CustomError(
"Account already mapped to a user".to_string(),
));
}
Ok(())
}
Expand All @@ -132,7 +164,7 @@ impl AccountAggregationService {
pub async fn add_account(
&self,
account_payload: AddAccountPayload,
) -> Result<(), Box<dyn Error>> {
) -> Result<(), AccountAggregationError> {
let new_account = Account {
chain_id: account_payload.chain_id.clone(),
is_enabled: account_payload.is_enabled,
Expand All @@ -141,17 +173,21 @@ impl AccountAggregationService {
};

// Check if the account is already mapped to a user
if self.get_user_id(&new_account.account_address).await.is_some() {
return Err("Account already mapped to a user".into());
if self.get_user_id(&new_account.account_address).await?.is_some() {
return Err(AccountAggregationError::CustomError(
"Account already mapped to a user".to_string(),
));
}

// Fetch the user document
let query_doc = self
.user_db_provider
.to_document(&UserQuery { user_id: account_payload.user_id.clone() })
.unwrap();
.to_document(&UserQuery { user_id: account_payload.user_id.clone() })?;
// Retrieve user document
let mut user_doc = self.user_db_provider.read(&query_doc).await?.ok_or("User not found")?;
let mut user_doc =
self.user_db_provider.read(&query_doc).await?.ok_or_else(|| {
AccountAggregationError::CustomError("User not found".to_string())
})?;

// Add the new account to the user's accounts array
let accounts_array =
Expand All @@ -160,40 +196,41 @@ impl AccountAggregationService {
if let bson::Bson::Array(accounts) = accounts_array {
accounts.push(bson::to_bson(&new_account)?);
} else {
return Err("Failed to update accounts array".into());
return Err(AccountAggregationError::CustomError(
"Failed to update accounts array".to_string(),
));
}

// Update the user document with the new account
self.user_db_provider.update(&query_doc, &user_doc).await?;

// Create a new mapping document for the account
let mapping_doc = self
.account_mapping_db_provider
.to_document(&UserAccountMapping {
account: new_account.account_address.clone(),
user_id: account_payload.user_id.clone(),
})
.unwrap();
let mapping_doc = self.account_mapping_db_provider.to_document(&UserAccountMapping {
account: new_account.account_address.clone(),
user_id: account_payload.user_id.clone(),
})?;
self.account_mapping_db_provider.create(&mapping_doc).await?;

Ok(())
}

/// Get the balance of a user's accounts
pub async fn get_user_accounts_balance(&self, account: &String) -> Vec<Balance> {
// find if the account is mapped to a user
let user_id = self.get_user_id(account).await;
pub async fn get_user_accounts_balance(
&self,
account: &String,
) -> Result<Vec<Balance>, AccountAggregationError> {
let mut accounts: Vec<String> = Vec::new();
if user_id.is_some() {
let user_id = user_id.unwrap();
let user_accounts = self.get_user_accounts(&user_id).await.unwrap();
let user_id = self.get_user_id(account).await.unwrap_or(None);
if let Some(user_id) = user_id {
let user_accounts = self.get_user_accounts(&user_id).await?.unwrap();
accounts.extend(user_accounts.iter().map(|account| account.account_address.clone()));
} else {
accounts.push(account.clone());
}

let mut balances = Vec::new();
let networks = self.networks.clone();
debug!("Networks: {:?}", networks);

// todo: parallelize this
for user in accounts.iter() {
Expand All @@ -202,19 +239,23 @@ impl AccountAggregationService {
"{}/v1/{}/address/{}/balances_v2/?key={}",
self.covalent_base_url, network, user, self.covalent_api_key
);
let response = self.client.get(&url).send().await.unwrap();
let api_response: ApiResponse = response.json().await.unwrap();
let user_balances = extract_balance_data(api_response).unwrap();
debug!("Requesting: {}", url);
let response = self.client.get(&url).send().await?;
let api_response: ApiResponse = response.json().await?;
let user_balances = extract_balance_data(api_response)?;
balances.extend(user_balances);
}
}
println!("{:?}", balances);

balances
Ok(balances)
}
}

/// Extract balance data from the API response
fn extract_balance_data(api_response: ApiResponse) -> Option<Vec<Balance>> {
fn extract_balance_data(
api_response: ApiResponse,
) -> Result<Vec<Balance>, AccountAggregationError> {
let chain_id = api_response.data.chain_id.to_string();
let results = api_response
.data
Expand All @@ -227,7 +268,7 @@ fn extract_balance_data(api_response: ApiResponse) -> Option<Vec<Balance>> {
.clone()
.unwrap_or("0".to_string())
.parse::<f64>()
.map_err(Box::<dyn Error>::from)
.map_err(|e| AccountAggregationError::CustomError(e.to_string()))
.ok()?;
let quote = item.quote;

Expand All @@ -247,5 +288,5 @@ fn extract_balance_data(api_response: ApiResponse) -> Option<Vec<Balance>> {
})
.collect();

Some(results)
Ok(results)
}
Loading

0 comments on commit 69da0df

Please sign in to comment.