From 806ba18585e649e046eb9b48d8e357d785876d47 Mon Sep 17 00:00:00 2001 From: amanraj1608 Date: Thu, 4 Jul 2024 05:21:57 +0400 Subject: [PATCH] feat: add config api fix: update api interface fix: remove graceful shhutdown --- bin/reflux/src/main.rs | 64 ++---- crates/account-aggregation/src/service.rs | 201 +++++++++++------- .../account-aggregation/src/service_trait.rs | 4 +- crates/account-aggregation/src/types.rs | 32 ++- crates/api/src/service_controller.rs | 94 ++++++-- crates/routing-engine/src/engine.rs | 83 ++++---- .../routing-engine/src/source/bungee/mod.rs | 1 + 7 files changed, 284 insertions(+), 195 deletions(-) diff --git a/bin/reflux/src/main.rs b/bin/reflux/src/main.rs index e9659ef..bdb25d8 100644 --- a/bin/reflux/src/main.rs +++ b/bin/reflux/src/main.rs @@ -5,8 +5,6 @@ use std::time::Duration; use axum::http::Method; use clap::Parser; use log::{debug, error, info}; -use tokio::signal; -use tokio::sync::broadcast; use tower_http::cors::{Any, CorsLayer}; use account_aggregation::service::AccountAggregationService; @@ -117,9 +115,7 @@ async fn run_solver(config: Config) { let cache_update_topic = config.indexer_config.indexer_update_topic.clone(); let routing_engine_clone = Arc::clone(&routing_engine); - let (shutdown_tx, mut shutdown_rx) = broadcast::channel(1); - - let cache_update_handle = tokio::spawn(async move { + tokio::task::spawn_blocking(move || { let redis_client = redis_client.clone(); if let Err(e) = redis_client.subscribe(&cache_update_topic, move |_msg| { info!("Received cache update notification"); @@ -131,12 +127,9 @@ async fn run_solver(config: Config) { }) { error!("Failed to subscribe to cache update topic: {}", e); } - - // Listen for shutdown signal - let _ = shutdown_rx.recv().await; }); - let token_chain_supported: HashMap> = config + let token_chain_map: HashMap> = config .tokens .iter() .map(|(token, token_config)| { @@ -150,14 +143,22 @@ async fn run_solver(config: Config) { .collect(); // API service controller - let service_controller = - ServiceController::new(account_service, routing_engine, token_chain_supported); + let chain_supported: Vec<(u32, String)> = + config.chains.iter().map(|(id, chain)| (*id, chain.name.clone())).collect(); + let token_supported: Vec = + config.tokens.iter().map(|(_, token_config)| token_config.symbol.clone()).collect(); + let service_controller = ServiceController::new( + account_service, + routing_engine, + token_chain_map, + chain_supported, + token_supported, + ); - let cors = CorsLayer::new().allow_origin(Any).allow_methods([ - Method::GET, - Method::POST, - Method::PATCH, - ]); + let cors = CorsLayer::new() + .allow_origin(Any) + .allow_methods([Method::GET, Method::POST, Method::PATCH]) + .allow_headers(Any); let app = service_controller.router().layer(cors); @@ -165,14 +166,7 @@ async fn run_solver(config: Config) { .await .expect("Failed to bind port"); - // todo: fix the graceful shutdown - axum::serve(listener, app.into_make_service()) - // .with_graceful_shutdown(shutdown_signal(shutdown_tx.clone())) - .await - .unwrap(); - - let _ = shutdown_tx.send(()); - let _ = cache_update_handle.abort(); + axum::serve(listener, app.into_make_service()).await.unwrap(); info!("Server stopped."); } @@ -210,25 +204,3 @@ async fn run_indexer(config: Config) { Err(e) => error!("Indexer Job Failed: {}", e), }; } - -async fn shutdown_signal(shutdown_tx: broadcast::Sender<()>) { - let ctrl_c = async { - signal::ctrl_c().await.expect("Unable to handle ctrl+c"); - }; - #[cfg(unix)] - let terminate = async { - signal::unix::signal(signal::unix::SignalKind::terminate()) - .expect("Failed to install signal handler") - .recv() - .await; - }; - #[cfg(not(unix))] - let terminate = std::future::pending::<()>(); - tokio::select! { - _ = ctrl_c => {}, - _ = terminate => {}, - } - - info!("signal received, starting graceful shutdown"); - let _ = shutdown_tx.send(()); -} diff --git a/crates/account-aggregation/src/service.rs b/crates/account-aggregation/src/service.rs index 1aa1369..68f9be2 100644 --- a/crates/account-aggregation/src/service.rs +++ b/crates/account-aggregation/src/service.rs @@ -1,3 +1,5 @@ +use futures::future::join_all; +use log::debug; use std::sync::Arc; use thiserror::Error; @@ -10,8 +12,8 @@ use storage::mongodb_client::{DBError, MongoDBClient}; use storage::DBProvider; use crate::types::{ - Account, AddAccountPayload, ApiResponse, Balance, RegisterAccountPayload, User, - UserAccountMapping, UserAccountMappingQuery, UserQuery, + Account, AddAccountPayload, CovalentApiResponse, ExtractedBalance, RegisterAccountPayload, + User, UserAccountMapping, UserAccountMappingQuery, UserQuery, }; #[derive(Error, Debug)] @@ -36,10 +38,9 @@ pub enum AccountAggregationError { /// /// This service is responsible for managing user accounts and their balances /// It interacts with the user and account mapping databases to store and retrieve user account information - #[derive(Clone, Display, Debug)] #[display( - "AccountAggregationService {{ user_db_provider: {:?}, account_mapping_db_provider: {:?} }}", +"AccountAggregationService {{ user_db_provider: {:?}, account_mapping_db_provider: {:?} }}", user_db_provider, account_mapping_db_provider )] @@ -82,11 +83,13 @@ impl AccountAggregationService { let query = self.account_mapping_db_provider.to_document(&UserAccountMappingQuery { account })?; let user_mapping = - self.account_mapping_db_provider.read(&query).await?.ok_or_else(|| { - AccountAggregationError::CustomError("User mapping not found".to_string()) - })?; + self.account_mapping_db_provider.read(&query).await.unwrap_or(None); + if user_mapping.is_none() { + return Ok(None); + } Ok(Some( user_mapping + .unwrap() .get_str("user_id") .map_err(|e| AccountAggregationError::CustomError(e.to_string()))? .to_string(), @@ -106,18 +109,26 @@ impl AccountAggregationService { })?; let accounts = user .get_array("accounts") - .map_err(|e| AccountAggregationError::CustomError(e.to_string()))?; + .map_err(|e| AccountAggregationError::CustomError(e.to_string())); - let accounts: Vec = accounts + let accounts: Vec = accounts? .iter() .filter_map(|account| { let account = account.as_document()?; - let chain_id = account.get_str("chain_id").ok()?.to_string(); + // let chain_id = account.get_str("chain_id").ok()?.to_string(); + let address = account.get_str("address").ok()?.to_string(); let is_enabled = account.get_bool("is_enabled").ok()?; - let account_address = account.get_str("account_address").ok()?.to_string(); let account_type = account.get_str("account_type").ok()?.to_string(); - - Some(Account { chain_id, is_enabled, account_address, account_type }) + let tags = account + .get_array("tags") + .map(|tags| { + tags.iter() + .filter_map(|tag| tag.as_str().map(|tag| tag.to_string())) + .collect() + }) + .unwrap_or_default(); + + Some(Account { address, is_enabled, account_type, tags }) }) .collect(); @@ -129,33 +140,45 @@ impl AccountAggregationService { &self, account_payload: RegisterAccountPayload, ) -> Result<(), AccountAggregationError> { - let account = account_payload.account.to_lowercase(); - - if self.get_user_id(&account).await?.is_none() { - let user_id = Uuid::new_v4().to_string(); - let user_doc = self.user_db_provider.to_document(&User { - user_id: user_id.clone(), - accounts: vec![Account { - chain_id: account_payload.chain_id, - is_enabled: account_payload.is_enabled, - account_address: account.clone(), - account_type: account_payload.account_type, - }], - })?; + // Modify all accounts address to lowercase + let all_accounts: Vec<_> = account_payload + .accounts + .into_iter() + .map(|account| Account { + address: account.address.to_lowercase(), + account_type: account.account_type, + is_enabled: account.is_enabled, + tags: account.tags, + }) + .collect(); + + for account in all_accounts.iter() { + let user_id = self.get_user_id(&account.address).await; + match user_id { + Ok(Some(_)) => { + return Err(AccountAggregationError::CustomError( + "Account already mapped to a user".to_string(), + )); + } + Ok(None) => {} + Err(_e) => {} + } + } - self.user_db_provider.create(&user_doc).await?; + let user_id = Uuid::new_v4().to_string(); + let user = User { user_id: user_id.clone(), accounts: all_accounts.clone() }; + let user_doc = self.user_db_provider.to_document(&user)?; + self.user_db_provider.create(&user_doc).await?; + for account in all_accounts { let mapping_doc = self.account_mapping_db_provider.to_document(&UserAccountMapping { - account: account.clone(), user_id: user_id.clone(), + account: account.address.clone(), })?; self.account_mapping_db_provider.create(&mapping_doc).await?; - } else { - return Err(AccountAggregationError::CustomError( - "Account already mapped to a user".to_string(), - )); } + Ok(()) } @@ -164,51 +187,57 @@ impl AccountAggregationService { &self, account_payload: AddAccountPayload, ) -> Result<(), AccountAggregationError> { - let new_account = Account { - chain_id: account_payload.chain_id.clone(), - is_enabled: account_payload.is_enabled, - account_address: account_payload.account.to_lowercase(), - account_type: account_payload.account_type.clone(), - }; - - // Check if the account is already mapped to a user - if self.get_user_id(&new_account.account_address).await?.is_some() { - return Err(AccountAggregationError::CustomError( - "Account already mapped to a user".to_string(), - )); - } - // Fetch the user document let query_doc = self .user_db_provider - .to_document(&UserQuery { user_id: account_payload.user_id.clone() })?; - // Retrieve user document + .to_document(&UserQuery { user_id: account_payload.user_id.clone().unwrap() })?; let mut user_doc = self.user_db_provider.read(&query_doc).await?.ok_or_else(|| { AccountAggregationError::CustomError("User not found".to_string()) })?; - // Add the new account to the user's accounts array + // Convert all account addresses to lowercase + let mut new_accounts = vec![]; + for account in account_payload.account { + new_accounts.push(Account { + address: account.address.to_lowercase(), + account_type: account.account_type, + is_enabled: account.is_enabled, + tags: account.tags, + }); + } + + // Add the new accounts to the user's accounts array let accounts_array = user_doc.entry("accounts".to_owned()).or_insert_with(|| bson::Bson::Array(vec![])); if let bson::Bson::Array(accounts) = accounts_array { - accounts.push(bson::to_bson(&new_account)?); + for new_account in new_accounts.iter() { + if self.get_user_id(&new_account.address).await?.is_some() { + return Err(AccountAggregationError::CustomError( + "Account already mapped to a user".to_string(), + )); + } + accounts.push(bson::to_bson(new_account)?); + } } else { return Err(AccountAggregationError::CustomError( "Failed to update accounts array".to_string(), )); } - // Update the user document with the new account + // Update the user document with the new accounts self.user_db_provider.update(&query_doc, &user_doc).await?; - // Create a new mapping document for the account - let mapping_doc = self.account_mapping_db_provider.to_document(&UserAccountMapping { - account: new_account.account_address.clone(), - user_id: account_payload.user_id.clone(), - })?; - self.account_mapping_db_provider.create(&mapping_doc).await?; + // Create a new mapping document for each account + for new_account in new_accounts { + let mapping_doc = + self.account_mapping_db_provider.to_document(&UserAccountMapping { + account: new_account.address.clone(), + user_id: account_payload.user_id.clone().unwrap(), + })?; + self.account_mapping_db_provider.create(&mapping_doc).await?; + } Ok(()) } @@ -217,12 +246,12 @@ impl AccountAggregationService { pub async fn get_user_accounts_balance( &self, account: &String, - ) -> Result, AccountAggregationError> { + ) -> Result, AccountAggregationError> { let mut accounts: Vec = Vec::new(); let user_id = self.get_user_id(account).await.unwrap_or(None); if let Some(user_id) = user_id { let user_accounts = self.get_user_accounts(&user_id).await?.unwrap(); - accounts.extend(user_accounts.iter().map(|account| account.account_address.clone())); + accounts.extend(user_accounts.iter().map(|account| account.address.clone())); } else { accounts.push(account.clone()); } @@ -230,17 +259,43 @@ impl AccountAggregationService { let mut balances = Vec::new(); let networks = self.networks.clone(); - // todo: parallelize this - for user in accounts.iter() { - for network in networks.iter() { - let url = format!( - "{}/v1/{}/address/{}/balances_v2/?key={}", - self.covalent_base_url, network, user, self.covalent_api_key - ); - let response = self.client.get(&url).send().await?; - let api_response: ApiResponse = response.json().await?; - let user_balances = extract_balance_data(api_response)?; - balances.extend(user_balances); + // Prepare tasks for parallel execution + let tasks: Vec<_> = accounts + .iter() + .flat_map(|user| { + networks.iter().map(move |network| { + let url = format!( + "{}/v1/{}/address/{}/balances_v2/?key={}", + self.covalent_base_url, network, user, self.covalent_api_key + ); + debug!("Fetching balance from: {}", url); + let client = self.client.clone(); + async move { + let response = client.get(&url).send().await; + match response { + Ok(response) => { + let api_response: Result = + response.json().await; + match api_response { + Ok(api_response) => extract_balance_data(api_response), + Err(e) => Err(AccountAggregationError::ReqwestError(e)), + } + } + Err(e) => Err(AccountAggregationError::ReqwestError(e)), + } + } + }) + }) + .collect(); + + // Execute tasks concurrently + let results = join_all(tasks).await; + + // Collect results + for result in results { + match result { + Ok(user_balances) => balances.extend(user_balances), + Err(e) => debug!("Failed to fetch balance: {:?}", e), } } @@ -250,8 +305,8 @@ impl AccountAggregationService { /// Extract balance data from the API response fn extract_balance_data( - api_response: ApiResponse, -) -> Result, AccountAggregationError> { + api_response: CovalentApiResponse, +) -> Result, AccountAggregationError> { let chain_id = api_response.data.chain_id.to_string(); let results = api_response .data @@ -273,7 +328,7 @@ fn extract_balance_data( } else { let balance = balance_raw / 10f64.powf(item.contract_decimals.unwrap() as f64); - Some(Balance { + Some(ExtractedBalance { token: token.clone(), token_address: item.contract_ticker_symbol.clone().unwrap(), chain_id: chain_id.clone().parse::().unwrap(), diff --git a/crates/account-aggregation/src/service_trait.rs b/crates/account-aggregation/src/service_trait.rs index e84e3f7..199c840 100644 --- a/crates/account-aggregation/src/service_trait.rs +++ b/crates/account-aggregation/src/service_trait.rs @@ -4,7 +4,7 @@ use async_trait::async_trait; use storage::mongodb_client::MongoDBClient; -use crate::types::{Account, AddAccountPayload, Balance, RegisterAccountPayload}; +use crate::types::{Account, AddAccountPayload, ExtractedBalance, RegisterAccountPayload}; #[async_trait] pub trait AccountAggregationServiceTrait { @@ -21,5 +21,5 @@ pub trait AccountAggregationServiceTrait { account_payload: RegisterAccountPayload, ) -> Result<(), Box>; fn add_account(&self, account_payload: AddAccountPayload) -> Result<(), Box>; - fn get_user_accounts_balance(&self, account: &String) -> Vec; + fn get_user_accounts_balance(&self, account: &String) -> Vec; } diff --git a/crates/account-aggregation/src/types.rs b/crates/account-aggregation/src/types.rs index a88481d..da6871e 100644 --- a/crates/account-aggregation/src/types.rs +++ b/crates/account-aggregation/src/types.rs @@ -1,18 +1,18 @@ use serde::{Deserialize, Serialize}; #[derive(Deserialize, Serialize, Debug)] -pub struct ApiResponse { - pub data: ApiData, +pub struct CovalentApiResponse { + pub data: CovalentApiData, } #[derive(Deserialize, Serialize, Debug)] -pub struct ApiData { - pub items: Vec, +pub struct CovalentApiData { + pub items: Vec, pub chain_id: u32, } #[derive(Deserialize, Serialize, Debug)] -pub struct TokenData { +pub struct CovalentTokenData { pub contract_ticker_symbol: Option, pub balance: Option, pub quote: Option, @@ -20,7 +20,7 @@ pub struct TokenData { } #[derive(Deserialize, Serialize, Debug)] -pub struct Balance { +pub struct ExtractedBalance { pub token: String, pub token_address: String, pub chain_id: u32, @@ -35,12 +35,12 @@ pub struct User { pub accounts: Vec, } -#[derive(Deserialize, Serialize, Debug)] +#[derive(Deserialize, Serialize, Debug, Clone)] pub struct Account { - pub chain_id: String, - pub is_enabled: bool, - pub account_address: String, + pub address: String, pub account_type: String, + pub is_enabled: bool, + pub tags: Vec, } #[derive(Deserialize, Serialize, Debug)] @@ -63,19 +63,13 @@ pub struct UserAccountMappingQuery { // Register Account Payload (same as Account) #[derive(Deserialize, Serialize, Debug)] pub struct RegisterAccountPayload { - pub account: String, - pub account_type: String, - pub chain_id: String, - pub is_enabled: bool, + pub accounts: Vec, } // Add Account Payload (need to add user_id) #[derive(Deserialize, Serialize, Debug)] pub struct AddAccountPayload { - pub user_id: String, - pub account: String, - pub account_type: String, - pub chain_id: String, - pub is_enabled: bool, + pub user_id: Option, + pub account: Vec } // Path Query Model diff --git a/crates/api/src/service_controller.rs b/crates/api/src/service_controller.rs index 6ff09d2..fe31ba5 100644 --- a/crates/api/src/service_controller.rs +++ b/crates/api/src/service_controller.rs @@ -7,16 +7,26 @@ use std::{collections::HashMap, sync::Arc}; pub struct ServiceController { account_service: Arc, routing_engine: Arc, - token_supported: HashMap>, + token_chain_map: HashMap>, + chain_supported: Vec<(u32, String)>, + token_supported: Vec, } impl ServiceController { pub fn new( account_service: AccountAggregationService, routing_engine: Arc, - token_supported: HashMap>, + token_chain_map: HashMap>, + chain_supported: Vec<(u32, String)>, + token_supported: Vec, ) -> Self { - Self { account_service: Arc::new(account_service), routing_engine, token_supported } + Self { + account_service: Arc::new(account_service), + routing_engine, + token_chain_map, + chain_supported, + token_supported, + } } pub fn router(&self) -> Router { @@ -26,6 +36,10 @@ impl ServiceController { .route( "/api/account", get({ + // todo: @ankurdubey521 should we use path instead of query here? + // move |Path(account): Path| async move { + // ServiceController::get_account(account_service, account).await + // } let account_service = self.account_service.clone(); move |Query(query): Query| async move { ServiceController::get_account(account_service, query).await @@ -33,7 +47,7 @@ impl ServiceController { }), ) .route( - "/api/register_account", + "/api/account", axum::routing::post({ let account_service = self.account_service.clone(); move |Json(payload): Json| async move { @@ -42,21 +56,40 @@ impl ServiceController { }), ) .route( - "/api/add_account", - axum::routing::post({ + "/api/account", + axum::routing::patch({ let account_service = self.account_service.clone(); move |Json(payload): Json| async move { ServiceController::add_account(account_service, payload).await } }), ) + .route( + "/api/config", + get({ + let chain_supported = self.chain_supported.clone(); + let token_supported = self.token_supported.clone(); + move || async move { + ServiceController::get_config(chain_supported, token_supported) + } + }), + ) + .route( + "/api/balance", + get({ + let account_service = self.account_service.clone(); + move |Query(query): Query| async move { + ServiceController::get_balance(account_service, query).await + } + }), + ) .route( "/api/get_best_path", get({ let routing_engine = self.routing_engine.clone(); - let token_supported = self.token_supported.clone(); + let token_chain_map = self.token_chain_map.clone(); move |Query(query): Query| async move { - ServiceController::get_best_path(routing_engine, token_supported, query) + ServiceController::get_best_path(routing_engine, token_chain_map, query) .await } }), @@ -87,10 +120,7 @@ impl ServiceController { (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({ "error": err.to_string() }))) } }, - Ok(None) => ( - StatusCode::NOT_FOUND, - Json(json!({ "error": "User not found", "accounts": [query.account] })), - ), + Ok(None) => (StatusCode::NOT_FOUND, Json(json!({ "error": "User not found" }))), Err(err) => { (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({ "error": err.to_string() }))) } @@ -131,14 +161,50 @@ impl ServiceController { } } + /// Get all supported chains and tokens + pub fn get_config( + chain_supported: Vec<(u32, String)>, + token_supported: Vec, + ) -> impl IntoResponse { + let response = json!({ + "chains": chain_supported, + "tokens": token_supported + }); + (StatusCode::OK, Json(response)) + } + + /// Get user account balance + pub async fn get_balance( + account_service: Arc, + query: types::UserAccountMappingQuery, + ) -> impl IntoResponse { + println!("get_balance {:?}", query.account); + match account_service.get_user_accounts_balance(&query.account).await { + Ok(balances) => { + // for loop to add the balance in USD + let total_balance = + balances.iter().fold(0.0, |acc, balance| acc + balance.amount_in_usd); + let response = json!({ + "total_balance": total_balance, + "balances": balances + }); + (StatusCode::OK, Json(response)) + } + Err(err) => { + let response = json!({ "error": err.to_string() }); + (StatusCode::INTERNAL_SERVER_ERROR, Json(response)) + } + } + } + /// Get best cost path for asset consolidation pub async fn get_best_path( routing_engine: Arc, - token_supported: HashMap>, + token_chain_map: HashMap>, query: types::PathQuery, ) -> impl IntoResponse { // Check for the supported chain and token - match token_supported.get(&query.to_token) { + match token_chain_map.get(&query.to_token) { Some(chain_supported) => match chain_supported.get(&query.to_chain) { Some(supported) => { if !supported { diff --git a/crates/routing-engine/src/engine.rs b/crates/routing-engine/src/engine.rs index 3d7b35e..4f635cb 100644 --- a/crates/routing-engine/src/engine.rs +++ b/crates/routing-engine/src/engine.rs @@ -6,8 +6,7 @@ use log::{debug, error, info}; use thiserror::Error; use tokio::sync::RwLock; -use account_aggregation::service::AccountAggregationService; -use account_aggregation::types::Balance; +use account_aggregation::{service::AccountAggregationService, types::ExtractedBalance}; use config::{config::BucketConfig, ChainConfig, SolverConfig, TokenConfig}; use storage::{RedisClient, RedisClientError}; @@ -112,24 +111,25 @@ impl RoutingEngine { // Sort direct assets by A^x / C^y, here x=2 and y=1 let x = self.estimates.x_value; let y = self.estimates.y_value; - let mut sorted_assets: Vec<(&Balance, f64)> = stream::iter(direct_assets.into_iter()) - .then(|balance| async move { - let fee_cost = self - .get_cached_data( - balance.amount_in_usd, - PathQuery( - balance.chain_id, - to_chain, - balance.token.to_string(), - to_token.to_string(), - ), - ) - .await - .unwrap_or_default(); - (balance, fee_cost) - }) - .collect() - .await; + let mut sorted_assets: Vec<(&ExtractedBalance, f64)> = + stream::iter(direct_assets.into_iter()) + .then(|balance| async move { + let fee_cost = self + .get_cached_data( + balance.amount_in_usd, + PathQuery( + balance.chain_id, + to_chain, + balance.token.to_string(), + to_token.to_string(), + ), + ) + .await + .unwrap_or_default(); + (balance, fee_cost) + }) + .collect() + .await; sorted_assets.sort_by(|a, b| { let cost_a = (a.0.amount.powf(x)) / (a.1.powf(y)); @@ -187,26 +187,27 @@ impl RoutingEngine { // Handle swap/bridge for remaining amount if needed (non direct assets) if total_amount_needed > 0.0 { - let swap_assets: Vec<&Balance> = + let swap_assets: Vec<&ExtractedBalance> = user_balances.iter().filter(|balance| balance.token != to_token).collect(); - let mut sorted_assets: Vec<(&Balance, f64)> = stream::iter(swap_assets.into_iter()) - .then(|balance| async move { - let fee_cost = self - .get_cached_data( - balance.amount_in_usd, - PathQuery( - balance.chain_id, - to_chain, - balance.token.clone(), - to_token.to_string(), - ), - ) - .await - .unwrap_or_default(); - (balance, fee_cost) - }) - .collect() - .await; + let mut sorted_assets: Vec<(&ExtractedBalance, f64)> = + stream::iter(swap_assets.into_iter()) + .then(|balance| async move { + let fee_cost = self + .get_cached_data( + balance.amount_in_usd, + PathQuery( + balance.chain_id, + to_chain, + balance.token.clone(), + to_token.to_string(), + ), + ) + .await + .unwrap_or_default(); + (balance, fee_cost) + }) + .collect() + .await; sorted_assets.sort_by(|a, b| { let cost_a = (a.0.amount.powf(x)) / (a.1.powf(y)); @@ -316,14 +317,14 @@ impl RoutingEngine { async fn get_user_balance_from_agg_service( &self, account: &str, - ) -> Result, RoutingEngineError> { + ) -> Result, RoutingEngineError> { let balance = self .aas_client .get_user_accounts_balance(&account.to_string()) .await .map_err(|e| RoutingEngineError::UserBalanceFetchError(e.to_string()))?; - let balance = balance + let balance: Vec = balance .into_iter() .filter(|balance| { self.chain_configs.contains_key(&balance.chain_id) diff --git a/crates/routing-engine/src/source/bungee/mod.rs b/crates/routing-engine/src/source/bungee/mod.rs index 9c41413..08ae08f 100644 --- a/crates/routing-engine/src/source/bungee/mod.rs +++ b/crates/routing-engine/src/source/bungee/mod.rs @@ -367,6 +367,7 @@ mod tests { to_chain: &config.chains.get(&42161).unwrap(), from_token: &config.tokens.get(&"USDC".to_string()).unwrap(), to_token: &config.tokens.get(&"USDC".to_string()).unwrap(), + amount_in_usd: 100000000.0, is_smart_contract_deposit: false, };