diff --git a/Cargo.toml b/Cargo.toml index 5178027..57e6079 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,4 +1,5 @@ [workspace] +resolver = "2" members = [ "crates/config", @@ -15,3 +16,4 @@ config = { path = "crates/config" } storage = { path = "crates/storage" } routing-engine = { path = "crates/routing-engine" } account-aggregation = { path = "crates/account-aggregation" } + diff --git a/Dockerfile b/Dockerfile index 7b3570b..2349bac 100644 --- a/Dockerfile +++ b/Dockerfile @@ -8,4 +8,3 @@ RUN apt-get update RUN apt-get upgrade -y RUN apt-get install -y libssl-dev ca-certificates COPY --from=builder /usr/local/cargo/bin/reflux /app/reflux - diff --git a/README.md b/README.md index dd63647..fd5f27b 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,7 @@ ## reflux -Backend of solver which helps in seamless cross-chain asset consolidation. It aggregates user balances, automates routing, and suggests optimal transactions. +Backend of solver which helps in seamless cross-chain asset consolidation. It aggregates user balances, automates +routing, and suggests optimal transactions. #### Installation @@ -18,4 +19,4 @@ Once build is copleted, just run the server and test with the endpoints ### Dependencies graph -![image](./graph.png) +![image](./assets/dependency-graph.png) diff --git a/graph.png b/assets/dependency-graph.png similarity index 100% rename from graph.png rename to assets/dependency-graph.png diff --git a/bin/reflux/src/main.rs b/bin/reflux/src/main.rs index 2292eda..d1f8cb4 100644 --- a/bin/reflux/src/main.rs +++ b/bin/reflux/src/main.rs @@ -1,21 +1,21 @@ +use std::collections::HashMap; use std::sync::Arc; use std::time::Duration; use axum::http::Method; use clap::Parser; use log::{debug, error, info}; -use tokio::signal; -use tokio::sync::broadcast; use tower_http::cors::{Any, CorsLayer}; use account_aggregation::service::AccountAggregationService; use api::service_controller::ServiceController; use config::Config; -use routing_engine::engine::RoutingEngine; -use routing_engine::estimator::LinearRegressionEstimator; use routing_engine::{BungeeClient, CoingeckoClient, Indexer}; -use storage::mongodb_client::MongoDBClient; +use routing_engine::estimator::LinearRegressionEstimator; +use routing_engine::routing_engine::RoutingEngine; +use routing_engine::settlement_engine::{generate_erc20_instance_map, SettlementEngine}; use storage::{ControlFlow, MessageQueue, RedisClient}; +use storage::mongodb_client::MongoDBClient; #[derive(Parser, Debug)] struct Args { @@ -49,7 +49,7 @@ async fn main() { } // Load configuration from yaml - let config = Config::from_file(&args.config).expect("Failed to load config file"); + let config = Arc::new(Config::from_file(&args.config).expect("Failed to load config file")); if args.indexer { run_indexer(config).await; @@ -58,7 +58,7 @@ async fn main() { } } -async fn run_solver(config: Config) { +async fn run_solver(config: Arc) { info!("Starting Reflux Server"); let (app_host, app_port) = (config.server.host.clone(), config.server.port.clone()); @@ -88,16 +88,18 @@ async fn run_solver(config: Config) { config.chains.iter().map(|(_, chain)| chain.covalent_name.clone()).collect(); // Initialize account aggregation service for api - let account_service = AccountAggregationService::new( + let account_service = Arc::new(AccountAggregationService::new( user_db_provider.clone(), account_mapping_db_provider.clone(), networks, covalent_base_url, covalent_api_key, - ); + )); // Initialize routing engine let buckets = config.buckets.clone(); + let chain_configs = 
config.chains.clone(); + let token_configs = config.tokens.clone(); let redis_client = RedisClient::build(&config.infra.redis_url) .await .expect("Failed to instantiate redis client"); @@ -105,16 +107,33 @@ async fn run_solver(config: Config) { account_service.clone(), buckets, redis_client.clone(), - config.solver_config, + config.solver_config.clone(), + chain_configs, + token_configs, + )); + + // Initialize Settlement Engine and Dependencies + let erc20_instance_map = generate_erc20_instance_map(&config).unwrap(); + let bungee_client = BungeeClient::new(&config.bungee.base_url, &config.bungee.api_key) + .expect("Failed to Instantiate Bungee Client"); + let token_price_provider = CoingeckoClient::new( + config.coingecko.base_url.clone(), + config.coingecko.api_key.clone(), + redis_client.clone(), + Duration::from_secs(config.coingecko.expiry_sec), + ); + let settlement_engine = Arc::new(SettlementEngine::new( + Arc::clone(&config), + bungee_client, + token_price_provider, + erc20_instance_map, )); // Subscribe to cache update messages let cache_update_topic = config.indexer_config.indexer_update_topic.clone(); let routing_engine_clone = Arc::clone(&routing_engine); - let (shutdown_tx, mut shutdown_rx) = broadcast::channel(1); - - let cache_update_handle = tokio::spawn(async move { + tokio::task::spawn_blocking(move || { let redis_client = redis_client.clone(); if let Err(e) = redis_client.subscribe(&cache_update_topic, move |_msg| { info!("Received cache update notification"); @@ -126,19 +145,39 @@ async fn run_solver(config: Config) { }) { error!("Failed to subscribe to cache update topic: {}", e); } - - // Listen for shutdown signal - let _ = shutdown_rx.recv().await; }); + let token_chain_map: HashMap> = config + .tokens + .iter() + .map(|(token, token_config)| { + let chain_supported = token_config + .by_chain + .iter() + .map(|(chain_id, chain_config)| (*chain_id, chain_config.is_enabled)) + .collect(); + (token.clone(), chain_supported) + }) + .collect(); + // API service controller - let service_controller = ServiceController::new(account_service, routing_engine); + let chain_supported: Vec<(u32, String)> = + config.chains.iter().map(|(id, chain)| (*id, chain.name.clone())).collect(); + let token_supported: Vec = + config.tokens.iter().map(|(_, token_config)| token_config.symbol.clone()).collect(); + let service_controller = ServiceController::new( + account_service, + routing_engine, + settlement_engine, + token_chain_map, + chain_supported, + token_supported, + ); - let cors = CorsLayer::new().allow_origin(Any).allow_methods([ - Method::GET, - Method::POST, - Method::PATCH, - ]); + let cors = CorsLayer::new() + .allow_origin(Any) + .allow_methods([Method::GET, Method::POST, Method::PATCH]) + .allow_headers(Any); let app = service_controller.router().layer(cors); @@ -146,18 +185,12 @@ async fn run_solver(config: Config) { .await .expect("Failed to bind port"); - axum::serve(listener, app.into_make_service()) - .with_graceful_shutdown(shutdown_signal(shutdown_tx.clone())) - .await - .unwrap(); - - let _ = shutdown_tx.send(()); - let _ = cache_update_handle.abort(); + axum::serve(listener, app.into_make_service()).await.unwrap(); info!("Server stopped."); } -async fn run_indexer(config: Config) { +async fn run_indexer(config: Arc) { info!("Configuring Indexer"); let config = config; @@ -170,19 +203,19 @@ async fn run_indexer(config: Config) { .expect("Failed to Instantiate Bungee Client"); let token_price_provider = CoingeckoClient::new( - &config.coingecko.base_url, - 
&config.coingecko.api_key, - &redis_provider, + config.coingecko.base_url.clone(), + config.coingecko.api_key.clone(), + redis_provider.clone(), Duration::from_secs(config.coingecko.expiry_sec), ); let indexer: Indexer> = Indexer::new( - &config, - &bungee_client, - &redis_provider, - &redis_provider, - &token_price_provider, + config, + bungee_client, + redis_provider.clone(), + redis_provider.clone(), + token_price_provider, ); match indexer.run::().await { @@ -190,25 +223,3 @@ async fn run_indexer(config: Config) { Err(e) => error!("Indexer Job Failed: {}", e), }; } - -async fn shutdown_signal(shutdown_tx: broadcast::Sender<()>) { - let ctrl_c = async { - signal::ctrl_c().await.expect("Unable to handle ctrl+c"); - }; - #[cfg(unix)] - let terminate = async { - signal::unix::signal(signal::unix::SignalKind::terminate()) - .expect("Failed to install signal handler") - .recv() - .await; - }; - #[cfg(not(unix))] - let terminate = std::future::pending::<()>(); - tokio::select! { - _ = ctrl_c => {}, - _ = terminate => {}, - } - - info!("signal received, starting graceful shutdown"); - let _ = shutdown_tx.send(()); -} diff --git a/config.yaml.example b/config.yaml.example index 7705e42..2608e2b 100644 --- a/config.yaml.example +++ b/config.yaml.example @@ -3,10 +3,12 @@ chains: name: Ethereum covalent_name: eth-mainnet is_enabled: true + rpc_url: https://rpc.ankr.com/eth - id: 42161 name: Arbitrum is_enabled: true - covalent_name: bsc-mainnet + covalent_name: arbitrum-mainnet + rpc_url: https://arbitrum.llamarpc.com tokens: - symbol: USDC is_enabled: true @@ -16,6 +18,18 @@ tokens: is_enabled: true decimals: 6 address: '0xa0b86991c6218b36c1d19d4a2e9eb0ce3606eb48' + 42161: + is_enabled: true + decimals: 6 + address: '0xaf88d065e77c8cC2239327C5EDb3A432268e5831' + - symbol: USDT + is_enabled: true + coingecko_symbol: tether + by_chain: + 1: + is_enabled: true + decimals: 6 + address: '0xdac17f958d2ee523a2206206994597c13d831ec7' 42161: is_enabled: true decimals: 6 diff --git a/crates/account-aggregation/src/service.rs b/crates/account-aggregation/src/service.rs index 1aa1369..b7c47c1 100644 --- a/crates/account-aggregation/src/service.rs +++ b/crates/account-aggregation/src/service.rs @@ -1,17 +1,19 @@ +use futures::future::join_all; +use log::debug; use std::sync::Arc; -use thiserror::Error; use derive_more::Display; use mongodb::bson; use reqwest::Client as ReqwestClient; +use thiserror::Error; use uuid::Uuid; -use storage::mongodb_client::{DBError, MongoDBClient}; use storage::DBProvider; +use storage::mongodb_client::{DBError, MongoDBClient}; use crate::types::{ - Account, AddAccountPayload, ApiResponse, Balance, RegisterAccountPayload, User, - UserAccountMapping, UserAccountMappingQuery, UserQuery, + Account, AddAccountPayload, CovalentApiResponse, TokenWithBalance, RegisterAccountPayload, + User, UserAccountMapping, UserAccountMappingQuery, UserQuery, }; #[derive(Error, Debug)] @@ -36,10 +38,9 @@ pub enum AccountAggregationError { /// /// This service is responsible for managing user accounts and their balances /// It interacts with the user and account mapping databases to store and retrieve user account information - #[derive(Clone, Display, Debug)] #[display( - "AccountAggregationService {{ user_db_provider: {:?}, account_mapping_db_provider: {:?} }}", +"AccountAggregationService {{ user_db_provider: {:?}, account_mapping_db_provider: {:?} }}", user_db_provider, account_mapping_db_provider )] @@ -82,11 +83,13 @@ impl AccountAggregationService { let query = 
self.account_mapping_db_provider.to_document(&UserAccountMappingQuery { account })?; let user_mapping = - self.account_mapping_db_provider.read(&query).await?.ok_or_else(|| { - AccountAggregationError::CustomError("User mapping not found".to_string()) - })?; + self.account_mapping_db_provider.read(&query).await.unwrap_or(None); + if user_mapping.is_none() { + return Ok(None); + } Ok(Some( user_mapping + .unwrap() .get_str("user_id") .map_err(|e| AccountAggregationError::CustomError(e.to_string()))? .to_string(), @@ -106,18 +109,26 @@ impl AccountAggregationService { })?; let accounts = user .get_array("accounts") - .map_err(|e| AccountAggregationError::CustomError(e.to_string()))?; + .map_err(|e| AccountAggregationError::CustomError(e.to_string())); - let accounts: Vec = accounts + let accounts: Vec = accounts? .iter() .filter_map(|account| { let account = account.as_document()?; - let chain_id = account.get_str("chain_id").ok()?.to_string(); + // let chain_id = account.get_str("chain_id").ok()?.to_string(); + let address = account.get_str("address").ok()?.to_string(); let is_enabled = account.get_bool("is_enabled").ok()?; - let account_address = account.get_str("account_address").ok()?.to_string(); let account_type = account.get_str("account_type").ok()?.to_string(); - - Some(Account { chain_id, is_enabled, account_address, account_type }) + let tags = account + .get_array("tags") + .map(|tags| { + tags.iter() + .filter_map(|tag| tag.as_str().map(|tag| tag.to_string())) + .collect() + }) + .unwrap_or_default(); + + Some(Account { address, is_enabled, account_type, tags }) }) .collect(); @@ -129,33 +140,45 @@ impl AccountAggregationService { &self, account_payload: RegisterAccountPayload, ) -> Result<(), AccountAggregationError> { - let account = account_payload.account.to_lowercase(); - - if self.get_user_id(&account).await?.is_none() { - let user_id = Uuid::new_v4().to_string(); - let user_doc = self.user_db_provider.to_document(&User { - user_id: user_id.clone(), - accounts: vec![Account { - chain_id: account_payload.chain_id, - is_enabled: account_payload.is_enabled, - account_address: account.clone(), - account_type: account_payload.account_type, - }], - })?; + // Modify all accounts address to lowercase + let all_accounts: Vec<_> = account_payload + .accounts + .into_iter() + .map(|account| Account { + address: account.address.to_lowercase(), + account_type: account.account_type, + is_enabled: account.is_enabled, + tags: account.tags, + }) + .collect(); + + for account in all_accounts.iter() { + let user_id = self.get_user_id(&account.address).await; + match user_id { + Ok(Some(_)) => { + return Err(AccountAggregationError::CustomError( + "Account already mapped to a user".to_string(), + )); + } + Ok(None) => {} + Err(_e) => {} + } + } - self.user_db_provider.create(&user_doc).await?; + let user_id = Uuid::new_v4().to_string(); + let user = User { user_id: user_id.clone(), accounts: all_accounts.clone() }; + let user_doc = self.user_db_provider.to_document(&user)?; + self.user_db_provider.create(&user_doc).await?; + for account in all_accounts { let mapping_doc = self.account_mapping_db_provider.to_document(&UserAccountMapping { - account: account.clone(), user_id: user_id.clone(), + account: account.address.clone(), })?; self.account_mapping_db_provider.create(&mapping_doc).await?; - } else { - return Err(AccountAggregationError::CustomError( - "Account already mapped to a user".to_string(), - )); } + Ok(()) } @@ -164,51 +187,57 @@ impl AccountAggregationService { &self, 
account_payload: AddAccountPayload, ) -> Result<(), AccountAggregationError> { - let new_account = Account { - chain_id: account_payload.chain_id.clone(), - is_enabled: account_payload.is_enabled, - account_address: account_payload.account.to_lowercase(), - account_type: account_payload.account_type.clone(), - }; - - // Check if the account is already mapped to a user - if self.get_user_id(&new_account.account_address).await?.is_some() { - return Err(AccountAggregationError::CustomError( - "Account already mapped to a user".to_string(), - )); - } - // Fetch the user document let query_doc = self .user_db_provider - .to_document(&UserQuery { user_id: account_payload.user_id.clone() })?; - // Retrieve user document + .to_document(&UserQuery { user_id: account_payload.user_id.clone().unwrap() })?; let mut user_doc = self.user_db_provider.read(&query_doc).await?.ok_or_else(|| { AccountAggregationError::CustomError("User not found".to_string()) })?; - // Add the new account to the user's accounts array + // Convert all account addresses to lowercase + let mut new_accounts = vec![]; + for account in account_payload.account { + new_accounts.push(Account { + address: account.address.to_lowercase(), + account_type: account.account_type, + is_enabled: account.is_enabled, + tags: account.tags, + }); + } + + // Add the new accounts to the user's accounts array let accounts_array = user_doc.entry("accounts".to_owned()).or_insert_with(|| bson::Bson::Array(vec![])); if let bson::Bson::Array(accounts) = accounts_array { - accounts.push(bson::to_bson(&new_account)?); + for new_account in new_accounts.iter() { + if self.get_user_id(&new_account.address).await?.is_some() { + return Err(AccountAggregationError::CustomError( + "Account already mapped to a user".to_string(), + )); + } + accounts.push(bson::to_bson(new_account)?); + } } else { return Err(AccountAggregationError::CustomError( "Failed to update accounts array".to_string(), )); } - // Update the user document with the new account + // Update the user document with the new accounts self.user_db_provider.update(&query_doc, &user_doc).await?; - // Create a new mapping document for the account - let mapping_doc = self.account_mapping_db_provider.to_document(&UserAccountMapping { - account: new_account.account_address.clone(), - user_id: account_payload.user_id.clone(), - })?; - self.account_mapping_db_provider.create(&mapping_doc).await?; + // Create a new mapping document for each account + for new_account in new_accounts { + let mapping_doc = + self.account_mapping_db_provider.to_document(&UserAccountMapping { + account: new_account.address.clone(), + user_id: account_payload.user_id.clone().unwrap(), + })?; + self.account_mapping_db_provider.create(&mapping_doc).await?; + } Ok(()) } @@ -217,12 +246,12 @@ impl AccountAggregationService { pub async fn get_user_accounts_balance( &self, account: &String, - ) -> Result, AccountAggregationError> { + ) -> Result, AccountAggregationError> { let mut accounts: Vec = Vec::new(); let user_id = self.get_user_id(account).await.unwrap_or(None); if let Some(user_id) = user_id { let user_accounts = self.get_user_accounts(&user_id).await?.unwrap(); - accounts.extend(user_accounts.iter().map(|account| account.account_address.clone())); + accounts.extend(user_accounts.iter().map(|account| account.address.clone())); } else { accounts.push(account.clone()); } @@ -230,17 +259,42 @@ impl AccountAggregationService { let mut balances = Vec::new(); let networks = self.networks.clone(); - // todo: parallelize this - for user 
in accounts.iter() { - for network in networks.iter() { - let url = format!( - "{}/v1/{}/address/{}/balances_v2/?key={}", - self.covalent_base_url, network, user, self.covalent_api_key - ); - let response = self.client.get(&url).send().await?; - let api_response: ApiResponse = response.json().await?; - let user_balances = extract_balance_data(api_response)?; - balances.extend(user_balances); + // Prepare tasks for parallel execution + let tasks: Vec<_> = accounts + .iter() + .flat_map(|user| { + networks.iter().map(move |network| { + let url = format!( + "{}/v1/{}/address/{}/balances_v2/?key={}", + self.covalent_base_url, network, user, self.covalent_api_key + ); + let client = self.client.clone(); + async move { + let response = client.get(&url).send().await; + match response { + Ok(response) => { + let api_response: Result = + response.json().await; + match api_response { + Ok(api_response) => extract_balance_data(api_response, user), + Err(e) => Err(AccountAggregationError::ReqwestError(e)), + } + } + Err(e) => Err(AccountAggregationError::ReqwestError(e)), + } + } + }) + }) + .collect(); + + // Execute tasks concurrently + let results = join_all(tasks).await; + + // Collect results + for result in results { + match result { + Ok(user_balances) => balances.extend(user_balances), + Err(e) => debug!("Failed to fetch balance: {:?}", e), } } @@ -250,8 +304,9 @@ impl AccountAggregationService { /// Extract balance data from the API response fn extract_balance_data( - api_response: ApiResponse, -) -> Result, AccountAggregationError> { + api_response: CovalentApiResponse, + user: &String, +) -> Result, AccountAggregationError> { let chain_id = api_response.data.chain_id.to_string(); let results = api_response .data @@ -273,7 +328,8 @@ fn extract_balance_data( } else { let balance = balance_raw / 10f64.powf(item.contract_decimals.unwrap() as f64); - Some(Balance { + Some(TokenWithBalance { + address: user.clone(), token: token.clone(), token_address: item.contract_ticker_symbol.clone().unwrap(), chain_id: chain_id.clone().parse::().unwrap(), diff --git a/crates/account-aggregation/src/service_trait.rs b/crates/account-aggregation/src/service_trait.rs index e84e3f7..7073b65 100644 --- a/crates/account-aggregation/src/service_trait.rs +++ b/crates/account-aggregation/src/service_trait.rs @@ -4,7 +4,7 @@ use async_trait::async_trait; use storage::mongodb_client::MongoDBClient; -use crate::types::{Account, AddAccountPayload, Balance, RegisterAccountPayload}; +use crate::types::{Account, AddAccountPayload, RegisterAccountPayload, TokenWithBalance}; #[async_trait] pub trait AccountAggregationServiceTrait { @@ -21,5 +21,5 @@ pub trait AccountAggregationServiceTrait { account_payload: RegisterAccountPayload, ) -> Result<(), Box>; fn add_account(&self, account_payload: AddAccountPayload) -> Result<(), Box>; - fn get_user_accounts_balance(&self, account: &String) -> Vec; + fn get_user_accounts_balance(&self, account: &String) -> Vec; } diff --git a/crates/account-aggregation/src/types.rs b/crates/account-aggregation/src/types.rs index a88481d..08d4e6a 100644 --- a/crates/account-aggregation/src/types.rs +++ b/crates/account-aggregation/src/types.rs @@ -1,18 +1,18 @@ use serde::{Deserialize, Serialize}; #[derive(Deserialize, Serialize, Debug)] -pub struct ApiResponse { - pub data: ApiData, +pub struct CovalentApiResponse { + pub data: CovalentApiData, } #[derive(Deserialize, Serialize, Debug)] -pub struct ApiData { - pub items: Vec, +pub struct CovalentApiData { + pub items: Vec, pub chain_id: u32, } 
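// Context for the Covalent types above and below (descriptive only, inferred from
// this change set): the service consumes a small slice of the balances_v2 response,
// namely `data.chain_id` plus, for each entry in `data.items`, the ticker symbol,
// raw balance, USD quote and contract decimals that extract_balance_data reads.
// Field names are assumed to match the JSON keys directly, since no serde rename
// attributes appear in this diff.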
#[derive(Deserialize, Serialize, Debug)] -pub struct TokenData { +pub struct CovalentTokenData { pub contract_ticker_symbol: Option, pub balance: Option, pub quote: Option, @@ -20,7 +20,8 @@ pub struct TokenData { } #[derive(Deserialize, Serialize, Debug)] -pub struct Balance { +pub struct TokenWithBalance { + pub address: String, pub token: String, pub token_address: String, pub chain_id: u32, @@ -35,12 +36,12 @@ pub struct User { pub accounts: Vec, } -#[derive(Deserialize, Serialize, Debug)] +#[derive(Deserialize, Serialize, Debug, Clone)] pub struct Account { - pub chain_id: String, - pub is_enabled: bool, - pub account_address: String, + pub address: String, pub account_type: String, + pub is_enabled: bool, + pub tags: Vec, } #[derive(Deserialize, Serialize, Debug)] @@ -63,19 +64,13 @@ pub struct UserAccountMappingQuery { // Register Account Payload (same as Account) #[derive(Deserialize, Serialize, Debug)] pub struct RegisterAccountPayload { - pub account: String, - pub account_type: String, - pub chain_id: String, - pub is_enabled: bool, + pub accounts: Vec, } // Add Account Payload (need to add user_id) #[derive(Deserialize, Serialize, Debug)] pub struct AddAccountPayload { - pub user_id: String, - pub account: String, - pub account_type: String, - pub chain_id: String, - pub is_enabled: bool, + pub user_id: Option, + pub account: Vec } // Path Query Model diff --git a/crates/api/src/service_controller.rs b/crates/api/src/service_controller.rs index 6927af0..5372455 100644 --- a/crates/api/src/service_controller.rs +++ b/crates/api/src/service_controller.rs @@ -1,63 +1,112 @@ -use account_aggregation::{service::AccountAggregationService, types}; -use axum::{extract::Query, http::StatusCode, response::IntoResponse, routing::get, Json, Router}; -use routing_engine::engine::RoutingEngine; -use serde_json::json; +use std::collections::HashMap; use std::sync::Arc; -pub struct ServiceController { +use axum::{extract::Query, http::StatusCode, Json, response::IntoResponse, Router, routing::get}; +use serde_json::json; + +use account_aggregation::{service::AccountAggregationService, types}; +use routing_engine::routing_engine::RoutingEngine; +use routing_engine::settlement_engine::SettlementEngine; +use routing_engine::source::RouteSource; +use routing_engine::token_price::TokenPriceProvider; + +pub struct ServiceController { account_service: Arc, routing_engine: Arc, + settlement_engine: Arc>, + token_chain_map: HashMap>, + chain_supported: Vec<(u32, String)>, + token_supported: Vec, } -impl ServiceController { +impl + ServiceController +{ pub fn new( - account_service: AccountAggregationService, + account_service: Arc, routing_engine: Arc, + settlement_engine: Arc>, + token_chain_map: HashMap>, + chain_supported: Vec<(u32, String)>, + token_supported: Vec, ) -> Self { - Self { account_service: Arc::new(account_service), routing_engine } + Self { + account_service, + routing_engine, + settlement_engine, + token_chain_map, + chain_supported, + token_supported, + } } - pub fn router(self) -> Router { - let account_service = self.account_service.clone(); - let routing_engine = self.routing_engine.clone(); - + pub fn router(&self) -> Router { Router::new() - .route("/", get(ServiceController::status)) - .route("/api/health", get(ServiceController::status)) + .route("/", get(|| async { Self::status().await })) + .route("/api/health", get(|| async { Self::status().await })) .route( "/api/account", get({ - let account_service = account_service.clone(); + // todo: @ankurdubey521 should we use path 
instead of query here? + // move |Path(account): Path| async move { + // ServiceController::get_account(account_service, account).await + // } + let account_service = self.account_service.clone(); move |Query(query): Query| async move { - ServiceController::get_account(account_service.clone(), query).await + Self::get_account(account_service, query).await } }), ) .route( - "/api/register_account", + "/api/account", axum::routing::post({ - let account_service = account_service.clone(); + let account_service = self.account_service.clone(); move |Json(payload): Json| async move { - ServiceController::register_user_account(account_service.clone(), payload) - .await + Self::register_user_account(account_service, payload).await } }), ) .route( - "/api/add_account", - axum::routing::post({ - let account_service = account_service.clone(); + "/api/account", + axum::routing::patch({ + let account_service = self.account_service.clone(); move |Json(payload): Json| async move { - ServiceController::add_account(account_service.clone(), payload).await + Self::add_account(account_service, payload).await + } + }), + ) + .route( + "/api/config", + get({ + let chain_supported = self.chain_supported.clone(); + let token_supported = self.token_supported.clone(); + move || async move { Self::get_config(chain_supported, token_supported) } + }), + ) + .route( + "/api/balance", + get({ + let account_service = self.account_service.clone(); + move |Query(query): Query| async move { + Self::get_balance(account_service, query).await } }), ) .route( "/api/get_best_path", get({ - let routing_engine = routing_engine.clone(); + let routing_engine = self.routing_engine.clone(); + let settlement_engine = self.settlement_engine.clone(); + let token_chain_map = self.token_chain_map.clone(); + move |Query(query): Query| async move { - ServiceController::get_best_path(routing_engine.clone(), query).await + Self::get_best_path( + routing_engine, + settlement_engine, + token_chain_map, + query, + ) + .await } }), ) @@ -87,10 +136,7 @@ impl ServiceController { (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({ "error": err.to_string() }))) } }, - Ok(None) => ( - StatusCode::NOT_FOUND, - Json(json!({ "error": "User not found", "accounts": [query.account] })), - ), + Ok(None) => (StatusCode::NOT_FOUND, Json(json!({ "error": "User not found" }))), Err(err) => { (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({ "error": err.to_string() }))) } @@ -131,17 +177,32 @@ impl ServiceController { } } - /// Get best cost path for asset consolidation - pub async fn get_best_path( - routing_engine: Arc, - query: types::PathQuery, + /// Get all supported chains and tokens + pub fn get_config( + chain_supported: Vec<(u32, String)>, + token_supported: Vec, + ) -> impl IntoResponse { + let response = json!({ + "chains": chain_supported, + "tokens": token_supported + }); + (StatusCode::OK, Json(response)) + } + + /// Get user account balance + pub async fn get_balance( + account_service: Arc, + query: types::UserAccountMappingQuery, ) -> impl IntoResponse { - match routing_engine - .get_best_cost_path(&query.account, query.to_chain, &query.to_token, query.to_value) - .await - { - Ok(routes) => { - let response = json!({ "routes": routes }); + match account_service.get_user_accounts_balance(&query.account).await { + Ok(balances) => { + // for loop to add the balance in USD + let total_balance = + balances.iter().fold(0.0, |acc, balance| acc + balance.amount_in_usd); + let response = json!({ + "total_balance": total_balance, + "balances": balances + }); 
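// Illustrative response layout only (values made up): { "total_balance": 12.5, "balances": [...] },
// where total_balance is the fold of amount_in_usd over every TokenWithBalance returned by
// get_user_accounts_balance, and balances is that list serialized as-is.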
(StatusCode::OK, Json(response)) } Err(err) => { @@ -150,4 +211,52 @@ impl ServiceController { } } } + + /// Get best cost path for asset consolidation + pub async fn get_best_path( + routing_engine: Arc, + settlement_engine: Arc>, + token_chain_map: HashMap>, + query: types::PathQuery, + ) -> impl IntoResponse { + // Check for the supported chain and token + match token_chain_map.get(&query.to_token) { + Some(chain_supported) => match chain_supported.get(&query.to_chain) { + Some(supported) => { + if !supported { + let response = json!({ "error": "Token not supported on chain" }); + return (StatusCode::BAD_REQUEST, Json(response)); + } + } + None => { + let response = json!({ "error": "Chain not supported for token" }); + return (StatusCode::BAD_REQUEST, Json(response)); + } + }, + None => { + let response = json!({ "error": "Token not supported" }); + return (StatusCode::BAD_REQUEST, Json(response)); + } + } + + let routes_result = routing_engine + .get_best_cost_paths(&query.account, query.to_chain, &query.to_token, query.to_value) + .await; + + if let Err(err) = routes_result { + let response = json!({ "error": err.to_string() }); + return (StatusCode::INTERNAL_SERVER_ERROR, Json(response)); + } + + let transactions_result = + settlement_engine.generate_transactions(routes_result.unwrap()).await; + + if let Err(err) = transactions_result { + let response = json!({ "error": err.to_string() }); + return (StatusCode::INTERNAL_SERVER_ERROR, Json(response)); + } + + let response = json!({ "routes": transactions_result.unwrap() }); + (StatusCode::OK, Json(response)) + } } diff --git a/crates/config/src/config.rs b/crates/config/src/config.rs index 7d2dafb..929eb6c 100644 --- a/crates/config/src/config.rs +++ b/crates/config/src/config.rs @@ -1,36 +1,37 @@ use std::collections::HashMap; use std::hash::{DefaultHasher, Hash, Hasher}; use std::ops::Deref; +use std::sync::Arc; use derive_more::{Display, From, Into}; use serde::Deserialize; -use serde_valid::yaml::FromYamlStr; use serde_valid::{UniqueItemsError, Validate, ValidateUniqueItems}; +use serde_valid::yaml::FromYamlStr; // Config Type #[derive(Debug)] pub struct Config { // A bucket is defined for a pair of (source chain, source token) and (destination chain, destination token) // in which the estimation algorithm will be applied. - pub buckets: Vec, + pub buckets: Vec>, // List of all chains and their configurations. - pub chains: HashMap, + pub chains: HashMap>, // List of all tokens and their configurations. 
- pub tokens: HashMap, + pub tokens: HashMap>, // Bungee API configuration - pub bungee: BungeeConfig, + pub bungee: Arc, // CoinGecko API configuration - pub coingecko: CoinGeckoConfig, + pub coingecko: Arc, // Covalent API configuration - pub covalent: CovalentConfig, + pub covalent: Arc, // Infra Dependencies - pub infra: InfraConfig, + pub infra: Arc, // API Server Configuration - pub server: ServerConfig, + pub server: Arc, // Configuration for the indexer - pub indexer_config: IndexerConfig, + pub indexer_config: Arc, // Configuration for the solver - pub solver_config: SolverConfig, + pub solver_config: Arc, } impl Config { @@ -43,14 +44,14 @@ impl Config { let raw_config = RawConfig::from_yaml_str(s)?; let mut chains = HashMap::new(); for chain in raw_config.chains.0 { - chains.insert(chain.id, chain); + chains.insert(chain.id, Arc::new(chain)); } let mut tokens = HashMap::new(); fn verify_chain( chain_id: u32, - chains: &HashMap, + chains: &HashMap>, ) -> Result<(), ConfigError> { if let Some(chain) = chains.get(&chain_id) { if !chain.is_enabled { @@ -65,7 +66,7 @@ impl Config { fn verify_token( token_symbol: &str, chain_id: u32, - tokens: &HashMap, + tokens: &HashMap>, ) -> Result<(), ConfigError> { if let Some(token) = tokens.get(token_symbol) { if !token.is_enabled { @@ -99,7 +100,7 @@ impl Config { } } - tokens.insert(token.symbol.clone(), token); + tokens.insert(token.symbol.clone(), Arc::new(token)); } // Validate chains and tokens in the bucket configuration @@ -121,14 +122,14 @@ impl Config { Ok(Config { chains, tokens, - buckets: raw_config.buckets.0, - covalent: raw_config.covalent, - bungee: raw_config.bungee, - coingecko: raw_config.coingecko, - infra: raw_config.infra, - server: raw_config.server, - indexer_config: raw_config.indexer_config, - solver_config: raw_config.solver_config, + buckets: raw_config.buckets.0.into_iter().map(Arc::new).collect(), + covalent: Arc::new(raw_config.covalent), + bungee: Arc::new(raw_config.bungee), + coingecko: Arc::new(raw_config.coingecko), + infra: Arc::new(raw_config.infra), + server: Arc::new(raw_config.server), + indexer_config: Arc::new(raw_config.indexer_config), + solver_config: Arc::new(raw_config.solver_config), }) } } @@ -248,7 +249,7 @@ pub struct BucketConfig { // The destination token #[validate(min_length = 1)] pub to_token: String, - // Whether the bucket should only index routes that support smart contracts or just EOAs + // Whether the bucket should only index routes that support smart blockchain or just EOAs pub is_smart_contract_deposit_supported: bool, // Lower bound of the token amount to be transferred from the source chain to the destination chain #[validate(minimum = 1.0)] @@ -298,7 +299,7 @@ impl PartialEq for BucketConfig { impl Eq for BucketConfig {} -#[derive(Debug, Deserialize, Validate)] +#[derive(Debug, Deserialize, Validate, Clone)] pub struct ChainConfig { // The chain id #[validate(minimum = 1)] @@ -311,9 +312,14 @@ pub struct ChainConfig { // The name of the chain in Covalent API #[validate(min_length = 1)] pub covalent_name: String, + // The RPC URL of the chain + #[validate( + pattern = r"https?:\/\/(www\.)?[-a-zA-Z0-9@:%._\+~#=]{1,256}\.[a-zA-Z0-9()]{1,6}\b([-a-zA-Z0-9()@:%_\+.~#?&//=]*)" + )] + pub rpc_url: String, } -#[derive(Debug, Deserialize, Validate)] +#[derive(Debug, Deserialize, Validate, Clone)] pub struct TokenConfig { // The token symbol #[validate(min_length = 1)] @@ -328,8 +334,8 @@ pub struct TokenConfig { pub by_chain: TokenConfigByChainConfigs, } -#[derive(Debug, Deserialize, 
Validate, Into, From)] -pub struct TokenConfigByChainConfigs(HashMap); +#[derive(Debug, Deserialize, Validate, Into, From, Clone)] +pub struct TokenConfigByChainConfigs(pub HashMap); impl ValidateUniqueItems for TokenConfigByChainConfigs { fn validate_unique_items(&self) -> Result<(), UniqueItemsError> { @@ -345,7 +351,7 @@ impl Deref for TokenConfigByChainConfigs { } } -#[derive(Debug, Deserialize, Validate)] +#[derive(Debug, Deserialize, Validate, Clone)] pub struct ChainSpecificTokenConfig { // The number of decimals the token has #[validate(minimum = 1)] @@ -433,7 +439,7 @@ pub struct IndexerConfig { pub points_per_bucket: u64, } -#[derive(Debug, Deserialize, Validate)] +#[derive(Debug, Deserialize, Validate, Clone)] pub struct SolverConfig { #[validate(minimum = 1.0)] pub x_value: f64, @@ -463,10 +469,12 @@ chains: name: Ethereum is_enabled: true covalent_name: eth-mainnet + rpc_url: 'https://mainnet.infura.io/v3/1234567890' - id: 1 name: Ethereum is_enabled: true covalent_name: eth-mainnet + rpc_url: 'https://mainnet.infura.io/v3/1234567890' tokens: buckets: bungee: diff --git a/crates/routing-engine/Cargo.toml b/crates/routing-engine/Cargo.toml index 0e7b475..67dbf6b 100644 --- a/crates/routing-engine/Cargo.toml +++ b/crates/routing-engine/Cargo.toml @@ -20,6 +20,10 @@ log = "0.4.21" account-aggregation = { workspace = true } storage = { workspace = true } config = { workspace = true } +alloy = { version = "0.1.4", features = ["full"] } [dev-dependencies] mockall = "0.12.1" + +[lib] +doctest = false diff --git a/crates/routing-engine/src/blockchain/erc20.rs b/crates/routing-engine/src/blockchain/erc20.rs new file mode 100644 index 0000000..41ce94c --- /dev/null +++ b/crates/routing-engine/src/blockchain/erc20.rs @@ -0,0 +1,91 @@ +use alloy::sol; + +#[cfg(not(doctest))] +sol! { + // SPDX-License-Identifier: MIT + pragma solidity ^0.8.20; + + // node_modules/@openzeppelin/blockchain/token/ERC20/IERC20.sol + + // OpenZeppelin Contracts (last updated v5.0.0) (token/ERC20/IERC20.sol) + + /** + * @dev Interface of the ERC20 standard as defined in the EIP. + */ + #[sol(rpc)] + interface IERC20 { + /** + * @dev Emitted when `value` tokens are moved from one account (`from`) to + * another (`to`). + * + * Note that `value` may be zero. + */ + event Transfer(address indexed from, address indexed to, uint256 value); + + /** + * @dev Emitted when the allowance of a `spender` for an `owner` is set by + * a call to {approve}. `value` is the new allowance. + */ + event Approval(address indexed owner, address indexed spender, uint256 value); + + /** + * @dev Returns the value of tokens in existence. + */ + function totalSupply() external view returns (uint256); + + /** + * @dev Returns the value of tokens owned by `account`. + */ + function balanceOf(address account) external view returns (uint256); + + /** + * @dev Moves a `value` amount of tokens from the caller's account to `to`. + * + * Returns a boolean value indicating whether the operation succeeded. + * + * Emits a {Transfer} event. + */ + function transfer(address to, uint256 value) external returns (bool); + + /** + * @dev Returns the remaining number of tokens that `spender` will be + * allowed to spend on behalf of `owner` through {transferFrom}. This is + * zero by default. + * + * This value changes when {approve} or {transferFrom} are called. 
+ */ + function allowance(address owner, address spender) external view returns (uint256 allowance); + + /** + * @dev Sets a `value` amount of tokens as the allowance of `spender` over the + * caller's tokens. + * + * Returns a boolean value indicating whether the operation succeeded. + * + * IMPORTANT: Beware that changing an allowance with this method brings the risk + * that someone may use both the old and the new allowance by unfortunate + * transaction ordering. One possible solution to mitigate this race + * condition is to first reduce the spender's allowance to 0 and set the + * desired value afterwards: + * https://github.com/ethereum/EIPs/issues/20#issuecomment-263524729 + * + * Emits an {Approval} event. + */ + function approve(address spender, uint256 value) external returns (bool); + + /** + * @dev Moves a `value` amount of tokens from `from` to `to` using the + * allowance mechanism. `value` is then deducted from the caller's + * allowance. + * + * Returns a boolean value indicating whether the operation succeeded. + * + * Emits a {Transfer} event. + */ + function transferFrom(address from, address to, uint256 value) external returns (bool); + } + + // node_modules/@openzeppelin/blockchain/interfaces/IERC20.sol + + // OpenZeppelin Contracts (last updated v5.0.0) (interfaces/IERC20.sol) +} diff --git a/crates/routing-engine/src/blockchain/mod.rs b/crates/routing-engine/src/blockchain/mod.rs new file mode 100644 index 0000000..d2df2a2 --- /dev/null +++ b/crates/routing-engine/src/blockchain/mod.rs @@ -0,0 +1,7 @@ +use alloy::providers::RootProvider; +use alloy::transports::http::Http; +use reqwest::Client; + +pub type ERC20Contract = erc20::IERC20::IERC20Instance, RootProvider>>; + +pub mod erc20; diff --git a/crates/routing-engine/src/indexer.rs b/crates/routing-engine/src/indexer.rs index b30fe23..57277bd 100644 --- a/crates/routing-engine/src/indexer.rs +++ b/crates/routing-engine/src/indexer.rs @@ -1,5 +1,6 @@ use std::collections::HashMap; use std::hash::{DefaultHasher, Hash, Hasher}; +use std::sync::Arc; use futures::stream::StreamExt; use log::{error, info}; @@ -13,38 +14,36 @@ const SOURCE_FETCH_PER_BUCKET_RATE_LIMIT: usize = 10; const BUCKET_PROCESSING_RATE_LIMIT: usize = 5; pub struct Indexer< - 'config, Source: source::RouteSource, ModelStore: storage::KeyValueStore, Producer: storage::MessageQueue, TokenPriceProvider: token_price::TokenPriceProvider, > { - config: &'config config::Config, - source: &'config Source, - model_store: &'config ModelStore, - message_producer: &'config Producer, - token_price_provider: &'config TokenPriceProvider, + config: Arc, + source: Source, + model_store: ModelStore, + message_producer: Producer, + token_price_provider: TokenPriceProvider, } impl< - 'config, RouteSource: source::RouteSource, ModelStore: storage::KeyValueStore, Producer: storage::MessageQueue, TokenPriceProvider: token_price::TokenPriceProvider, - > Indexer<'config, RouteSource, ModelStore, Producer, TokenPriceProvider> + > Indexer { pub fn new( - config: &'config config::Config, - source: &'config RouteSource, - model_store: &'config ModelStore, - message_producer: &'config Producer, - token_price_provider: &'config TokenPriceProvider, + config: Arc, + source: RouteSource, + model_store: ModelStore, + message_producer: Producer, + token_price_provider: TokenPriceProvider, ) -> Self { Indexer { config, source, model_store, message_producer, token_price_provider } } - fn generate_bucket_observation_points(&self, bucket: &BucketConfig) -> Vec { + fn 
generate_bucket_observation_points(&self, bucket: &Arc) -> Vec { let points_per_bucket = self.config.indexer_config.points_per_bucket; (0..points_per_bucket) .into_iter() @@ -58,9 +57,9 @@ impl< async fn build_estimator<'est_de, Estimator: estimator::Estimator<'est_de, f64, f64>>( &self, - bucket: &'config BucketConfig, + bucket: &Arc, cost_type: &CostType, - ) -> Result> { + ) -> Result> { let bucket_id = bucket.get_hash(); info!("Building estimator for bucket: {:?} with ID: {}", bucket, bucket_id); @@ -85,7 +84,7 @@ impl< let from_token_amount_in_wei = token_price::utils::get_token_amount_from_value_in_usd( &self.config, - self.token_price_provider, + &self.token_price_provider, &bucket.from_token, bucket.from_chain_id, &input_value_in_usd, @@ -94,11 +93,17 @@ impl< .map_err(|err| IndexerErrors::TokenPriceProviderError(err))?; // Get the fee in usd from source - let route = Route::build(bucket, self.config) + let route = Route::build_from_bucket(bucket, &self.config) .map_err(|err| IndexerErrors::RouteBuildError(err))?; - let fee_in_usd = self + let (_, fee_in_usd) = self .source - .fetch_least_route_cost_in_usd(&route, from_token_amount_in_wei, cost_type) + .fetch_least_cost_route_and_cost_in_usd( + &route, + &from_token_amount_in_wei, + None, + None, + cost_type, + ) .await .map_err(|err| IndexerErrors::RouteSourceError(err))?; @@ -156,13 +161,13 @@ impl< if data_points.is_empty() { error!("BucketID-{}: No data points were built", bucket_id); - return Err(BuildEstimatorError::NoDataPoints(bucket)); + return Err(BuildEstimatorError::NoDataPoints(bucket.clone())); } // Build the Estimator info!("BucketID-{}:All data points fetched, building estimator...", bucket_id); let estimator = Estimator::build(data_points) - .map_err(|e| BuildEstimatorError::EstimatorBuildError(bucket, e))?; + .map_err(|e| BuildEstimatorError::EstimatorBuildError(bucket.clone(), e))?; Ok(estimator) } @@ -202,28 +207,27 @@ impl< pub async fn run<'est_de, Estimator: estimator::Estimator<'est_de, f64, f64>>( &self, ) -> Result< - HashMap<&'config BucketConfig, Estimator>, + HashMap<&BucketConfig, Estimator>, IndexerErrors<'est_de, TokenPriceProvider, RouteSource, ModelStore, Producer, Estimator>, > { info!("Running Indexer"); // Build Estimators - let (estimators, failed_estimators): (Vec<_>, Vec<_>) = futures::stream::iter( - self.config.buckets.iter(), - ) - .map(|bucket: &_| async { - // Build the Estimator - let estimator = self.build_estimator(bucket, &CostType::Fee).await?; - - Ok::<(&BucketConfig, Estimator), BuildEstimatorError<'config, 'est_de, Estimator>>(( - bucket, estimator, - )) - }) - .buffer_unordered(BUCKET_PROCESSING_RATE_LIMIT) - .collect::>() - .await - .into_iter() - .partition(|r| r.is_ok()); + let (estimators, failed_estimators): (Vec<_>, Vec<_>) = + futures::stream::iter(self.config.buckets.iter()) + .map(|bucket: &_| async { + // Build the Estimator + let estimator = self.build_estimator(bucket, &CostType::Fee).await?; + + Ok::<(&BucketConfig, Estimator), BuildEstimatorError<'est_de, Estimator>>(( + bucket, estimator, + )) + }) + .buffer_unordered(BUCKET_PROCESSING_RATE_LIMIT) + .collect::>() + .await + .into_iter() + .partition(|r| r.is_ok()); let estimator_map: HashMap<&BucketConfig, Estimator> = estimators.into_iter().map(|r| r.unwrap()).collect(); @@ -288,25 +292,28 @@ pub enum IndexerErrors< } #[derive(Debug, Error)] -pub enum BuildEstimatorError<'config, 'est_de, Estimator: estimator::Estimator<'est_de, f64, f64>> { +pub enum BuildEstimatorError<'est_de, Estimator: 
estimator::Estimator<'est_de, f64, f64>> { #[error("No data points found while building estimator for {:?}", _0)] - NoDataPoints(&'config BucketConfig), + NoDataPoints(Arc), #[error("Estimator build error: {} for bucket {:?}", _1, _0)] - EstimatorBuildError(&'config BucketConfig, Estimator::Error), + EstimatorBuildError(Arc, Estimator::Error), } #[cfg(test)] mod tests { + use std::collections::HashMap; use std::env; use std::fmt::Error; + use std::sync::Arc; use std::time::Duration; + use async_trait::async_trait; use derive_more::Display; use thiserror::Error; use config::{Config, get_sample_config}; - use storage::{ControlFlow, KeyValueStore, MessageQueue, Msg}; + use storage::{ControlFlow, KeyValueStore, MessageQueue, Msg, RedisClientError}; use crate::{BungeeClient, CostType}; use crate::estimator::{Estimator, LinearRegressionEstimator}; @@ -318,6 +325,8 @@ mod tests { #[derive(Debug)] struct ModelStoreStub; + + #[async_trait] impl KeyValueStore for ModelStoreStub { type Error = Err; @@ -330,16 +339,26 @@ mod tests { } async fn set(&self, _: &String, _: &String, _: Duration) -> Result<(), Self::Error> { - Ok(()) + todo!() } async fn set_multiple(&self, _: &Vec<(String, String)>) -> Result<(), Self::Error> { - Ok(()) + todo!() + } + + async fn get_all_keys(&self) -> Result, RedisClientError> { + todo!() + } + + async fn get_all_key_values(&self) -> Result, RedisClientError> { + todo!() } } #[derive(Debug)] struct ProducerStub; + + #[async_trait] impl MessageQueue for ProducerStub { type Error = Err; @@ -358,6 +377,8 @@ mod tests { #[derive(Debug)] struct TokenPriceProviderStub; + + #[async_trait] impl TokenPriceProvider for TokenPriceProviderStub { type Error = Error; @@ -369,7 +390,7 @@ mod tests { fn setup<'a>() -> (Config, BungeeClient, ModelStoreStub, ProducerStub, TokenPriceProviderStub) { let mut config = get_sample_config(); config.buckets = vec![ - config::BucketConfig { + Arc::new(config::BucketConfig { from_chain_id: 1, to_chain_id: 42161, from_token: "USDC".to_string(), @@ -377,8 +398,8 @@ mod tests { is_smart_contract_deposit_supported: false, token_amount_from_usd: 10.0, token_amount_to_usd: 100.0, - }, - config::BucketConfig { + }), + Arc::new(config::BucketConfig { from_chain_id: 1, to_chain_id: 42161, from_token: "USDC".to_string(), @@ -386,10 +407,13 @@ mod tests { is_smart_contract_deposit_supported: false, token_amount_from_usd: 100.0, token_amount_to_usd: 1000.0, - }, + }), ]; - config.bungee.api_key = env::var("BUNGEE_API_KEY").unwrap(); + config.bungee = Arc::new(config::BungeeConfig { + base_url: config.bungee.base_url.clone(), + api_key: env::var("BUNGEE_API_KEY").unwrap(), + }); let bungee_client = BungeeClient::new(&config.bungee.base_url, &config.bungee.api_key).unwrap(); @@ -402,19 +426,14 @@ mod tests { #[tokio::test] async fn test_build_estimator() { - let ( - config, - bungee_client, - mut model_store, - mut message_producer, - mut token_price_provider, - ) = setup(); + let (config, bungee_client, model_store, message_producer, token_price_provider) = setup(); + let config = Arc::new(config); let indexer = Indexer::new( - &config, - &bungee_client, - &mut model_store, - &mut message_producer, - &mut token_price_provider, + Arc::clone(&config), + bungee_client, + model_store, + message_producer, + token_price_provider, ); let estimator = indexer.build_estimator(&config.buckets[0], &CostType::Fee).await; diff --git a/crates/routing-engine/src/lib.rs b/crates/routing-engine/src/lib.rs index 960010c..9316496 100644 --- a/crates/routing-engine/src/lib.rs +++ 
b/crates/routing-engine/src/lib.rs @@ -1,17 +1,24 @@ +use std::sync::Arc; + +pub use alloy::providers::Provider; +pub use alloy::transports::Transport; use derive_more::Display; use thiserror::Error; -use config::config::{BucketConfig, ChainConfig, Config, TokenConfig}; +use config::{ChainConfig, TokenConfig}; +use config::config::{BucketConfig, Config}; pub use indexer::Indexer; pub use source::bungee::BungeeClient; pub use token_price::CoingeckoClient; -pub mod engine; +pub mod routing_engine; pub mod token_price; +pub mod blockchain; pub mod estimator; pub mod indexer; -mod source; +pub mod settlement_engine; +pub mod source; #[derive(Debug, Error, Display)] pub enum CostType { @@ -20,44 +27,62 @@ pub enum CostType { } #[derive(Debug)] -pub struct Route<'a> { - from_chain: &'a ChainConfig, - to_chain: &'a ChainConfig, - from_token: &'a TokenConfig, - to_token: &'a TokenConfig, +pub struct Route { + from_chain: Arc, + to_chain: Arc, + from_token: Arc, + to_token: Arc, is_smart_contract_deposit: bool, } -impl<'a> Route<'a> { - pub fn build(bucket: &'a BucketConfig, config: &'a Config) -> Result, RouteError> { - let from_chain = config.chains.get(&bucket.from_chain_id); +impl Route { + pub fn build( + config: &Config, + from_chain_id: &u32, + to_chain_id: &u32, + from_token_id: &String, + to_token_id: &String, + is_smart_contract_deposit: bool, + ) -> Result { + let from_chain = config.chains.get(from_chain_id); if from_chain.is_none() { - return Err(RouteError::ChainNotFoundError(bucket.from_chain_id)); + return Err(RouteError::ChainNotFoundError(*from_chain_id)); } - let to_chain = config.chains.get(&bucket.to_chain_id); + let to_chain = config.chains.get(to_chain_id); if to_chain.is_none() { - return Err(RouteError::ChainNotFoundError(bucket.to_chain_id)); + return Err(RouteError::ChainNotFoundError(*to_chain_id)); } - let from_token = config.tokens.get(&bucket.from_token); + let from_token = config.tokens.get(from_token_id); if from_token.is_none() { - return Err(RouteError::TokenNotFoundError(bucket.from_token.clone())); + return Err(RouteError::TokenNotFoundError(from_token_id.clone())); } - let to_token = config.tokens.get(&bucket.to_token); + let to_token = config.tokens.get(to_token_id); if to_token.is_none() { - return Err(RouteError::TokenNotFoundError(bucket.to_token.clone())); + return Err(RouteError::TokenNotFoundError(to_token_id.clone())); } Ok(Route { - from_chain: from_chain.unwrap(), - to_chain: to_chain.unwrap(), - from_token: from_token.unwrap(), - to_token: to_token.unwrap(), - is_smart_contract_deposit: bucket.is_smart_contract_deposit_supported, + from_chain: Arc::clone(from_chain.unwrap()), + to_chain: Arc::clone(to_chain.unwrap()), + from_token: Arc::clone(from_token.unwrap()), + to_token: Arc::clone(to_token.unwrap()), + is_smart_contract_deposit, }) } + + pub fn build_from_bucket(bucket: &BucketConfig, config: &Config) -> Result { + Self::build( + config, + &bucket.from_chain_id, + &bucket.to_chain_id, + &bucket.from_token, + &bucket.to_token, + bucket.is_smart_contract_deposit_supported, + ) + } } #[derive(Debug, Error)] @@ -68,3 +93,81 @@ pub enum RouteError { #[error("Token not found while building route: {}", _0)] TokenNotFoundError(String), } + +#[derive(Debug)] +pub struct BridgeResult { + route: Route, + source_amount_in_usd: f64, + from_address: String, + to_address: String, +} + +impl BridgeResult { + pub fn build( + config: &Config, + from_chain_id: &u32, + to_chain_id: &u32, + from_token_id: &String, + to_token_id: &String, + 
is_smart_contract_deposit: bool, + source_amount_in_usd: f64, + from_address: String, + to_address: String, + ) -> Result { + Ok(BridgeResult { + route: Route::build( + config, + from_chain_id, + to_chain_id, + from_token_id, + to_token_id, + is_smart_contract_deposit, + )?, + source_amount_in_usd, + from_address, + to_address, + }) + } +} + +#[cfg(test)] +mod test { + use config::get_sample_config; + + fn assert_is_send(_: impl Send) {} + + #[test] + fn test_route_must_be_send() { + let config = get_sample_config(); + let route = super::Route::build( + &config, + &1, + &42161, + &"USDC".to_string(), + &"USDT".to_string(), + false, + ) + .unwrap(); + + assert_is_send(route); + } + + #[test] + fn test_bridge_result_must_be_send() { + let config = get_sample_config(); + let bridge_result = super::BridgeResult::build( + &config, + &1, + &42161, + &"USDC".to_string(), + &"USDT".to_string(), + false, + 100.0, + "0x123".to_string(), + "0x456".to_string(), + ) + .unwrap(); + + assert_is_send(bridge_result); + } +} diff --git a/crates/routing-engine/src/engine.rs b/crates/routing-engine/src/routing_engine.rs similarity index 54% rename from crates/routing-engine/src/engine.rs rename to crates/routing-engine/src/routing_engine.rs index 72ab416..08ae098 100644 --- a/crates/routing-engine/src/engine.rs +++ b/crates/routing-engine/src/routing_engine.rs @@ -1,34 +1,19 @@ use std::collections::HashMap; use std::sync::Arc; -use derive_more::Display; use futures::stream::{self, StreamExt}; use log::{debug, error, info}; -use serde::{Deserialize, Serialize}; use thiserror::Error; use tokio::sync::RwLock; -use account_aggregation::service::AccountAggregationService; -use account_aggregation::types::Balance; -use config::{config::BucketConfig, SolverConfig}; -use storage::{RedisClient, RedisClientError}; - -use crate::estimator::{Estimator, LinearRegressionEstimator}; - -#[derive(Serialize, Deserialize, Debug, Display, PartialEq, Clone)] -#[display( - "Route: from_chain: {}, to_chain: {}, token: {}, amount: {}", - from_chain, - to_chain, - token, - amount -)] -pub struct Route { - pub from_chain: u32, - pub to_chain: u32, - pub token: String, - pub amount: f64, -} +use account_aggregation::{service::AccountAggregationService, types::TokenWithBalance}; +use config::{ChainConfig, config::BucketConfig, SolverConfig, TokenConfig}; +use storage::{KeyValueStore, RedisClient, RedisClientError}; + +use crate::{ + BridgeResult, + estimator::{Estimator, LinearRegressionEstimator}, Route, +}; /// (from_chain, to_chain, from_token, to_token) #[derive(Debug)] @@ -53,25 +38,38 @@ pub enum RoutingEngineError { /// This struct is responsible for calculating the best cost path for a user #[derive(Debug)] pub struct RoutingEngine { - buckets: Vec, - aas_client: AccountAggregationService, + buckets: Vec>, + aas_client: Arc, cache: Arc>>, // (hash(bucket), hash(estimator_value) redis_client: RedisClient, - estimates: SolverConfig, + estimates: Arc, + chain_configs: HashMap>, + token_configs: HashMap>, } impl RoutingEngine { pub fn new( - aas_client: AccountAggregationService, - buckets: Vec, + aas_client: Arc, + buckets: Vec>, redis_client: RedisClient, - solver_config: SolverConfig, + solver_config: Arc, + chain_configs: HashMap>, + token_configs: HashMap>, ) -> Self { let cache = Arc::new(RwLock::new(HashMap::new())); - Self { aas_client, cache, buckets, redis_client, estimates: solver_config } + Self { + aas_client, + cache, + buckets, + redis_client, + estimates: solver_config, + chain_configs, + token_configs, + } } + /// 
Refresh the cache from Redis pub async fn refresh_cache(&self) { match self.redis_client.get_all_key_values().await { Ok(kv_pairs) => { @@ -90,15 +88,15 @@ impl RoutingEngine { } } - /// Get the best cost path for a user + /// Get the best cost path for a user. /// This function will get the user balances from the aas and then calculate the best cost path for the user - pub async fn get_best_cost_path( + pub async fn get_best_cost_paths( &self, account: &str, to_chain: u32, to_token: &str, to_value: f64, - ) -> Result, RoutingEngineError> { + ) -> Result, RoutingEngineError> { debug!( "Getting best cost path for user: {}, to_chain: {}, to_token: {}, to_value: {}", account, to_chain, to_token, to_value @@ -107,43 +105,91 @@ impl RoutingEngine { debug!("User balances: {:?}", user_balances); // todo: for account aggregation, transfer same chain same asset first - let direct_assets: Vec<_> = - user_balances.iter().filter(|balance| balance.token == to_token).collect(); + let (direct_assets, non_direct_assets): (Vec<_>, _) = + user_balances.into_iter().partition(|balance| balance.token == to_token); debug!("Direct assets: {:?}", direct_assets); + debug!("Non-direct assets: {:?}", non_direct_assets); + + let (mut selected_routes, total_amount_needed, mut total_cost) = self + .generate_optimal_routes(direct_assets, to_chain, to_token, to_value, account) + .await?; + + // Handle swap/bridge for remaining amount if needed (non-direct assets) + if total_amount_needed > 0.0 { + let (swap_routes, _, swap_total_cost) = self + .generate_optimal_routes( + non_direct_assets, + to_chain, + to_token, + total_amount_needed, + account, + ) + .await?; + + selected_routes.extend(swap_routes); + total_cost += swap_total_cost; + } - // Sort direct assets by A^x / C^y, here x=2 and y=1 + debug!("Selected assets: {:?}", selected_routes); + info!( + "Total cost for user: {} on chain {} to token {} is {}", + account, to_chain, to_token, total_cost + ); + + Ok(selected_routes) + } + + async fn generate_optimal_routes( + &self, + assets: Vec, + to_chain: u32, + to_token: &str, + to_value_usd: f64, + to_address: &str, + ) -> Result<(Vec, f64, f64), RoutingEngineError> { + // Sort direct assets by Balance^x / Fee_Cost^y, here x=2 and y=1 let x = self.estimates.x_value; let y = self.estimates.y_value; - let mut sorted_assets: Vec<(&Balance, f64)> = stream::iter(direct_assets.into_iter()) - .then(|balance| async move { - let fee_cost = self - .get_cached_data( - balance.amount_in_usd, - PathQuery( - balance.chain_id, - to_chain, - balance.token.to_string(), - to_token.to_string(), - ), - ) - .await - .unwrap_or_default(); - (balance, fee_cost) - }) - .collect() - .await; + let mut assets_sorted_by_bridging_cost: Vec<(TokenWithBalance, f64)> = + stream::iter(assets.into_iter()) + .then(|balance| async move { + let fee_cost = self + .estimate_bridging_cost( + balance.amount_in_usd, + PathQuery( + balance.chain_id, + to_chain, + balance.token.to_string(), + to_token.to_string(), + ), + ) + .await; + (balance, fee_cost) + }) + .collect::>() + .await + .into_iter() + .filter_map(|(balance, cost)| match cost { + Ok(cost) => Some((balance, cost)), + Err(e) => { + error!("Failed to estimate bridging cost for balance {:?}: {}", balance, e); + None + } + }) + .collect(); - sorted_assets.sort_by(|a, b| { + // Greedily select bridging routes that + assets_sorted_by_bridging_cost.sort_by(|a, b| { let cost_a = (a.0.amount.powf(x)) / (a.1.powf(y)); let cost_b = (b.0.amount.powf(x)) / (b.1.powf(y)); cost_a.partial_cmp(&cost_b).unwrap() }); 
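// Worked example of the ordering above, assuming the sample values x = 2 and y = 1
// mentioned in the comment (x and y are read from SolverConfig's x_value / y_value):
// an asset worth $100 with an estimated $2 bridging fee scores 100^2 / 2 = 5_000,
// while $40 with a $0.10 fee scores 40^2 / 0.10 = 16_000. partial_cmp sorts
// ascending, so the $100 asset is drained before the $40 one in the loop below.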
let mut total_cost = 0.0; - let mut total_amount_needed = to_value; - let mut selected_assets: Vec = Vec::new(); + let mut total_amount_needed = to_value_usd; + let mut selected_routes: Vec = Vec::new(); - for (balance, fee) in sorted_assets { + for (balance, fee) in assets_sorted_by_bridging_cost { if total_amount_needed <= 0.0 { break; } @@ -155,98 +201,40 @@ impl RoutingEngine { total_amount_needed -= amount_to_take; total_cost += fee; - selected_assets.push(Route { - from_chain: balance.chain_id, + selected_routes.push(self.build_bridging_route( + balance.chain_id, to_chain, - token: balance.token.clone(), - amount: amount_to_take, - }); + &balance.token, + to_token, + amount_to_take, + false, + &balance.address, + &to_address, + )?); } - // Handle swap/bridge for remaining amount if needed (non direct assets) - if total_amount_needed > 0.0 { - let swap_assets: Vec<&Balance> = - user_balances.iter().filter(|balance| balance.token != to_token).collect(); - let mut sorted_assets: Vec<(&Balance, f64)> = stream::iter(swap_assets.into_iter()) - .then(|balance| async move { - let fee_cost = self - .get_cached_data( - balance.amount_in_usd, - PathQuery( - balance.chain_id, - to_chain, - balance.token.clone(), - to_token.to_string(), - ), - ) - .await - .unwrap_or_default(); - (balance, fee_cost) - }) - .collect() - .await; - - sorted_assets.sort_by(|a, b| { - let cost_a = (a.0.amount.powf(x)) / (a.1.powf(y)); - let cost_b = (b.0.amount.powf(x)) / (b.1.powf(y)); - cost_a.partial_cmp(&cost_b).unwrap() - }); - - for (balance, fee_cost) in sorted_assets { - if total_amount_needed <= 0.0 { - break; - } - - let amount_to_take = if balance.amount_in_usd >= total_amount_needed { - total_amount_needed - } else { - balance.amount_in_usd - }; - - total_amount_needed -= amount_to_take; - total_cost += fee_cost; - - selected_assets.push(Route { - from_chain: balance.chain_id, - to_chain, - token: balance.token.clone(), - amount: amount_to_take, - }); - } - } - - debug!("Selected assets: {:?}", selected_assets); - info!( - "Total cost for user: {} on chain {} to token {} is {}", - account, to_chain, to_token, total_cost - ); - - Ok(selected_assets) + Ok((selected_routes, total_amount_needed, total_cost)) } - async fn get_cached_data( + async fn estimate_bridging_cost( &self, - target_amount: f64, + target_amount_in_usd: f64, path: PathQuery, ) -> Result { - let mut buckets_array: Vec = self + // TODO: Maintain sorted list cache in cache, binary search + let bucket = self .buckets - .clone() - .into_iter() - .filter(|bucket| { - bucket.from_chain_id == path.0 + .iter() + .find(|&bucket| { + let matches_path = bucket.from_chain_id == path.0 && bucket.to_chain_id == path.1 && bucket.from_token == path.2 - && bucket.to_token == path.3 - }) - .collect(); - buckets_array.sort(); + && bucket.to_token == path.3; - let bucket = buckets_array - .iter() - .find(|window| { - target_amount >= window.token_amount_from_usd - && target_amount <= window.token_amount_to_usd + let matches_amount = target_amount_in_usd >= bucket.token_amount_from_usd + && target_amount_in_usd <= bucket.token_amount_to_usd; + + matches_path && matches_amount }) .ok_or_else(|| { RoutingEngineError::CacheError("No matching bucket found".to_string()) @@ -260,19 +248,65 @@ impl RoutingEngine { .ok_or_else(|| RoutingEngineError::CacheError("No cached value found".to_string()))?; let estimator: LinearRegressionEstimator = serde_json::from_str(value)?; - Ok(estimator.estimate(target_amount)) + Ok(estimator.estimate(target_amount_in_usd)) } /// Get 
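For orientation, the bucket lookup in `estimate_bridging_cost` matches on the exact (from chain, to chain, from token, to token) path and on the USD amount falling inside the bucket's range; the estimator cached under that bucket's key then produces the fee. A self-contained sketch, with `Bucket` and `find_bucket` as illustrative stand-ins for the real `BucketConfig` handling:

```rust
// Illustrative bucket matching, mirroring `estimate_bridging_cost`.
struct Bucket {
    from_chain_id: u32,
    to_chain_id: u32,
    from_token: String,
    to_token: String,
    token_amount_from_usd: f64,
    token_amount_to_usd: f64,
}

fn find_bucket<'a>(
    buckets: &'a [Bucket],
    from_chain: u32,
    to_chain: u32,
    from_token: &str,
    to_token: &str,
    amount_usd: f64,
) -> Option<&'a Bucket> {
    buckets.iter().find(|b| {
        b.from_chain_id == from_chain
            && b.to_chain_id == to_chain
            && b.from_token == from_token
            && b.to_token == to_token
            && amount_usd >= b.token_amount_from_usd
            && amount_usd <= b.token_amount_to_usd
    })
}
```

As the TODO above notes, keeping the buckets pre-sorted in the cache would allow a binary search instead of the linear `find`.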
user balance from account aggregation service async fn get_user_balance_from_agg_service( &self, account: &str, - ) -> Result, RoutingEngineError> { - // Note: aas should always return vec of balances - self.aas_client + ) -> Result, RoutingEngineError> { + let balance = self + .aas_client .get_user_accounts_balance(&account.to_string()) .await - .map_err(|e| RoutingEngineError::UserBalanceFetchError(e.to_string())) + .map_err(|e| RoutingEngineError::UserBalanceFetchError(e.to_string()))?; + + let balance: Vec<_> = balance + .into_iter() + .filter(|balance| { + self.chain_configs.contains_key(&balance.chain_id) + && self.token_configs.contains_key(&balance.token) + }) + .collect(); + + debug!("User balance: {:?}", balance); + Ok(balance) + } + + fn build_bridging_route( + &self, + from_chain_id: u32, + to_chain_id: u32, + from_token_id: &str, + to_token_id: &str, + token_amount_in_usd: f64, + is_smart_contract_deposit: bool, + from_address: &str, + to_address: &str, + ) -> Result { + let from_chain = Arc::clone(self.chain_configs.get(&from_chain_id).ok_or_else(|| { + RoutingEngineError::CacheError(format!( + "Chain config not found for ID {}", + from_chain_id + )) + })?); + let to_chain = Arc::clone(self.chain_configs.get(&to_chain_id).ok_or_else(|| { + RoutingEngineError::CacheError(format!("Chain config not found for ID {}", to_chain_id)) + })?); + let from_token = Arc::clone(self.token_configs.get(from_token_id).ok_or_else(|| { + RoutingEngineError::CacheError(format!("Token config not found for {}", from_token_id)) + })?); + let to_token = Arc::clone(self.token_configs.get(to_token_id).ok_or_else(|| { + RoutingEngineError::CacheError(format!("Token config not found for {}", to_token_id)) + })?); + + Ok(BridgeResult { + route: Route { from_chain, to_chain, from_token, to_token, is_smart_contract_deposit }, + source_amount_in_usd: token_amount_in_usd, + from_address: from_address.to_string(), + to_address: to_address.to_string(), + }) } } @@ -285,21 +319,21 @@ mod tests { use tokio::sync::RwLock; use account_aggregation::service::AccountAggregationService; - use config::{BucketConfig, SolverConfig}; + use config::{BucketConfig, ChainConfig, SolverConfig, TokenConfig, TokenConfigByChainConfigs}; use storage::mongodb_client::MongoDBClient; - use crate::engine::PathQuery; - use crate::estimator::Estimator; use crate::{ - engine::{RoutingEngine, RoutingEngineError}, estimator::{DataPoint, LinearRegressionEstimator}, + routing_engine::{RoutingEngine, RoutingEngineError}, }; + use crate::estimator::Estimator; + use crate::routing_engine::PathQuery; #[tokio::test] async fn test_get_cached_data() -> Result<(), RoutingEngineError> { // Create dummy buckets let buckets = vec![ - BucketConfig { + Arc::new(BucketConfig { from_chain_id: 1, to_chain_id: 2, from_token: "USDC".to_string(), @@ -307,8 +341,8 @@ mod tests { is_smart_contract_deposit_supported: false, token_amount_from_usd: 1.0, token_amount_to_usd: 10.0, - }, - BucketConfig { + }), + Arc::new(BucketConfig { from_chain_id: 1, to_chain_id: 2, from_token: "USDC".to_string(), @@ -316,7 +350,7 @@ mod tests { is_smart_contract_deposit_supported: false, token_amount_from_usd: 10.0, token_amount_to_usd: 100.0, - }, + }), ]; // Create a dummy estimator and serialize it @@ -343,25 +377,26 @@ mod tests { .await .unwrap(); - let aas_client = AccountAggregationService::new( + let aas_client = Arc::new(AccountAggregationService::new( user_db_provider.clone(), user_db_provider.clone(), vec!["eth-mainnet".to_string()], 
"https://api.covalent.com".to_string(), "my-api".to_string(), - ); + )); let redis_client = storage::RedisClient::build(&"redis://localhost:6379".to_string()).await.unwrap(); - let estimates = SolverConfig { - x_value: 2.0, - y_value: 1.0, - }; + let estimates = Arc::new(SolverConfig { x_value: 2.0, y_value: 1.0 }); + let chain_configs = HashMap::new(); + let token_configs = HashMap::new(); let routing_engine = RoutingEngine { aas_client, buckets, cache: Arc::new(RwLock::new(cache)), redis_client, estimates, + chain_configs, + token_configs, }; // Define the target amount and path query @@ -369,7 +404,7 @@ mod tests { let path_query = PathQuery(1, 2, "USDC".to_string(), "ETH".to_string()); // Call get_cached_data and assert the result - let result = routing_engine.get_cached_data(target_amount, path_query).await?; + let result = routing_engine.estimate_bridging_cost(target_amount, path_query).await?; assert!(result > 0.0); assert_eq!(result, dummy_estimator.estimate(target_amount)); Ok(()) @@ -391,16 +426,16 @@ mod tests { ) .await .unwrap(); - let aas_client = AccountAggregationService::new( + let aas_client = Arc::new(AccountAggregationService::new( user_db_provider.clone(), user_db_provider.clone(), vec!["bsc-mainnet".to_string()], "https://api.covalenthq.com".to_string(), api_key, - ); + )); let buckets = vec![ - BucketConfig { + Arc::new(BucketConfig { from_chain_id: 56, to_chain_id: 2, from_token: "USDT".to_string(), @@ -408,8 +443,8 @@ mod tests { is_smart_contract_deposit_supported: false, token_amount_from_usd: 0.0, token_amount_to_usd: 5.0, - }, - BucketConfig { + }), + Arc::new(BucketConfig { from_chain_id: 56, to_chain_id: 2, from_token: "USDT".to_string(), @@ -417,7 +452,7 @@ mod tests { is_smart_contract_deposit_supported: false, token_amount_from_usd: 5.0, token_amount_to_usd: 100.0, - }, + }), ]; // Create a dummy estimator and serialize it let dummy_estimator = LinearRegressionEstimator::build(vec![ @@ -436,22 +471,50 @@ mod tests { let redis_client = storage::RedisClient::build(&"redis://localhost:6379".to_string()).await.unwrap(); - let estimates = SolverConfig { - x_value: 2.0, - y_value: 1.0, - }; + let estimates = Arc::new(SolverConfig { x_value: 2.0, y_value: 1.0 }); + let chain_config1 = Arc::new(ChainConfig { + id: 56, + name: "bsc-mainnet".to_string(), + is_enabled: true, + covalent_name: "bsc-mainnet".to_string(), + rpc_url: "https://bsc-dataseed.binance.org".to_string(), + }); + let chain_config2 = Arc::new(ChainConfig { + id: 2, + name: "eth-mainnet".to_string(), + is_enabled: true, + covalent_name: "ethereum".to_string(), + rpc_url: "https://mainnet.infura.io/v3/".to_string(), + }); + let mut chain_configs = HashMap::new(); + chain_configs.insert(56, chain_config1); + chain_configs.insert(2, chain_config2); + + let token_config = Arc::new(TokenConfig { + symbol: "USDT".to_string(), + coingecko_symbol: "USDT".to_string(), + is_enabled: true, + by_chain: TokenConfigByChainConfigs(HashMap::new()), + }); + let mut token_configs = HashMap::new(); + token_configs.insert("USDT".to_string(), token_config); + let routing_engine = RoutingEngine { aas_client, buckets, cache: Arc::new(RwLock::new(cache)), redis_client, estimates, + chain_configs, + token_configs, }; // should have USDT in bsc-mainnet > $0.5 let dummy_user_address = "0x00000ebe3fa7cb71aE471547C836E0cE0AE758c2"; - let result = routing_engine.get_best_cost_path(dummy_user_address, 2, "USDT", 0.5).await?; + let result = routing_engine.get_best_cost_paths(dummy_user_address, 2, "USDT", 0.5).await?; 
assert_eq!(result.len(), 1); + assert!(result[0].source_amount_in_usd >= 0.5); + assert!(result[0].from_address == dummy_user_address); Ok(()) } } diff --git a/crates/routing-engine/src/settlement_engine.rs b/crates/routing-engine/src/settlement_engine.rs new file mode 100644 index 0000000..88eb17f --- /dev/null +++ b/crates/routing-engine/src/settlement_engine.rs @@ -0,0 +1,759 @@ +use std::collections::HashMap; +use std::sync::Arc; + +use alloy::hex::FromHexError; +use alloy::providers::{ProviderBuilder, RootProvider}; +use alloy::transports::http::Http; +use futures::StreamExt; +use log::{error, info}; +use reqwest::{Client, Url}; +use ruint::Uint; +use serde::Serialize; +use thiserror::Error; + +use config::Config; + +use crate::{blockchain, BridgeResult}; +use crate::blockchain::erc20::IERC20::IERC20Instance; +use crate::source::{EthereumTransaction, RequiredApprovalDetails, RouteSource}; +use crate::token_price::TokenPriceProvider; +use crate::token_price::utils::{Errors, get_token_amount_from_value_in_usd}; + +pub struct SettlementEngine { + source: Source, + config: Arc, + price_provider: PriceProvider, + // (chain_id, token_address) -> contract + erc20_instance_map: HashMap<(u32, String), blockchain::ERC20Contract>, +} + +#[derive(Debug, PartialEq, Serialize)] +pub enum TransactionType { + Approval, + Bungee, +} + +#[derive(Debug, Serialize)] +pub struct TransactionWithType { + transaction: EthereumTransaction, + transaction_type: TransactionType, +} + +const GENERATE_TRANSACTIONS_CONCURRENCY: usize = 10; + +impl + SettlementEngine +{ + pub fn new( + config: Arc, + source: Source, + price_provider: PriceProvider, + erc20_instance_map: HashMap<(u32, String), blockchain::ERC20Contract>, + ) -> Self { + SettlementEngine { source, config, price_provider, erc20_instance_map } + } + + pub async fn generate_transactions( + &self, + routes: Vec, + ) -> Result, SettlementEngineErrors> { + info!("Generating transactions for routes: {:?}", routes); + + let (results, failed): ( + Vec< + Result< + (Vec, Vec), + SettlementEngineErrors, + >, + >, + _, + ) = futures::stream::iter(routes.into_iter()) + .map(|route| async move { + info!("Generating transactions for route: {:?}", route.route); + + let token_amount = get_token_amount_from_value_in_usd( + &self.config, + &self.price_provider, + &route.route.from_token.symbol, + route.route.from_chain.id, + &route.source_amount_in_usd, + ) + .await + .map_err(|err| SettlementEngineErrors::GetTokenAmountFromValueInUsdError(err))?; + + info!("Token amount: {:?} for route {:?}", token_amount, route); + + let (ethereum_transactions, required_approval_details) = self + .source + .generate_route_transactions( + &route.route, + &token_amount, + &route.from_address, + &route.to_address, + ) + .await + .map_err(|err| SettlementEngineErrors::GenerateTransactionsError(err))?; + + info!("Generated transactions: {:?} for route {:?}", ethereum_transactions, route); + + Ok::<_, SettlementEngineErrors<_, _>>(( + ethereum_transactions, + required_approval_details, + )) + }) + .buffer_unordered(GENERATE_TRANSACTIONS_CONCURRENCY) + .collect::>() + .await + .into_iter() + .partition(Result::is_ok); + + let failed: Vec<_> = failed.into_iter().map(Result::unwrap_err).collect(); + if !failed.is_empty() { + error!("Failed to generate transactions: {:?}", failed); + } + + if results.is_empty() { + error!("No transactions generated"); + return Err(SettlementEngineErrors::NoTransactionsGenerated); + } + + let (bridge_transactions, required_approval_details): (Vec>, Vec>) = + 
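The fan-out in `generate_transactions` above follows the usual `futures` stream pattern: map every route to a future, keep at most `GENERATE_TRANSACTIONS_CONCURRENCY` (10) of them in flight, then split successes from failures. A condensed, self-contained version of that pattern, with `process` standing in for the per-route work (illustrative names only):

```rust
use futures::StreamExt;

// Stand-in for the per-route transaction generation.
async fn process(i: u32) -> Result<u32, String> {
    if i % 2 == 0 { Ok(i * 10) } else { Err(format!("failed on {i}")) }
}

async fn run_all(items: Vec<u32>, concurrency: usize) -> (Vec<u32>, Vec<String>) {
    // Run up to `concurrency` futures at a time, collect all results,
    // then partition them into Ok and Err halves.
    let (ok, err): (Vec<_>, Vec<_>) = futures::stream::iter(items)
        .map(|i| async move { process(i).await })
        .buffer_unordered(concurrency)
        .collect::<Vec<_>>()
        .await
        .into_iter()
        .partition(Result::is_ok);

    (
        ok.into_iter().map(Result::unwrap).collect(),
        err.into_iter().map(Result::unwrap_err).collect(),
    )
}
```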
results.into_iter().map(Result::unwrap).unzip(); + + let bridge_transactions: Vec<_> = bridge_transactions + .into_iter() + .flatten() + .map(|t| TransactionWithType { + transaction: t, + transaction_type: TransactionType::Bungee, + }) + .collect(); + + let required_approval_details: Vec<_> = + required_approval_details.into_iter().flatten().collect(); + let required_approval_transactions = + self.generate_transactions_for_approvals(&required_approval_details).await?; + + info!("Generated transactions: {:?}", bridge_transactions); + info!("Required approvals: {:?}", required_approval_details); + + let final_transactions = vec![required_approval_transactions, bridge_transactions] + .into_iter() + .flatten() + .collect::>(); + + info!("Final Transactions: {:?}", final_transactions); + + Ok(final_transactions) + } + + async fn generate_transaction_for_approval( + &self, + required_approval_details: &RequiredApprovalDetails, + ) -> Result, SettlementEngineErrors> { + info!("Generating transaction for approval: {:?}", required_approval_details); + + let token_instance = self.erc20_instance_map.get(&( + required_approval_details.chain_id, + required_approval_details.token_address.clone(), + )); + + if token_instance.is_none() { + error!( + "ERC20 Utils not found for chain_id: {} and token_address: {}", + required_approval_details.chain_id, required_approval_details.token_address + ); + return Err(SettlementEngineErrors::ERC20UtilsNotFound( + required_approval_details.chain_id.clone(), + required_approval_details.token_address.clone(), + )); + } + let token_instance = token_instance.unwrap(); + + let owner = (&required_approval_details.owner) + .parse() + .map_err(SettlementEngineErrors::InvalidAddressError)?; + let spender = (&required_approval_details.target) + .parse() + .map_err(SettlementEngineErrors::InvalidAddressError)?; + + let current_approval = token_instance.allowance(owner, spender).call().await?.allowance; + + info!( + "Current approval: {} on chain against requirement: {:?}", + current_approval, required_approval_details + ); + + if current_approval >= required_approval_details.amount { + info!("Sufficient Approval already exists for: {:?}", required_approval_details); + return Ok(None); + } + + let required_approval = required_approval_details.amount - current_approval; + info!( + "Required Approval: {:?} against requirement: {:?}", + required_approval, required_approval_details + ); + + let calldata = token_instance.approve(spender, required_approval).calldata().to_string(); + + Ok(Some(TransactionWithType { + transaction: EthereumTransaction { + from: required_approval_details.owner.clone(), + to: token_instance.address().to_string(), + value: Uint::ZERO, + calldata, + }, + transaction_type: TransactionType::Approval, + })) + } + + async fn generate_transactions_for_approvals( + &self, + approvals: &Vec, + ) -> Result, SettlementEngineErrors> { + info!("Generating transactions for approvals: {:?}", approvals); + + // Group the approvals and combine them based on chain_id, token_address, spender and target + let mut approvals_grouped = + HashMap::<(u32, &String, &String, &String), Vec<&RequiredApprovalDetails>>::new(); + for approval in approvals { + let key = + (approval.chain_id, &approval.token_address, &approval.owner, &approval.target); + let arr = approvals_grouped.get_mut(&key); + if arr.is_none() { + approvals_grouped.insert(key, vec![&approval]); + } else { + arr.unwrap().push(&approval); + } + } + + // Merge the approvals with the same key + let merged_approvals: Vec = 
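The allowance check in `generate_transaction_for_approval` above only emits an approval when the current on-chain allowance is below the requirement, and it approves the shortfall rather than the full amount (which is what the tests further down assert). A minimal sketch of that decision, with `needs_approval` as an illustrative helper:

```rust
use ruint::aliases::U256;

/// Returns the amount to approve, or None if the existing allowance suffices.
fn needs_approval(current_allowance: U256, required: U256) -> Option<U256> {
    if current_allowance >= required {
        None
    } else {
        Some(required - current_allowance)
    }
}

fn main() {
    assert_eq!(needs_approval(U256::from(100), U256::from(100)), None);
    assert_eq!(needs_approval(U256::from(100), U256::from(150)), Some(U256::from(50)));
}
```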
approvals_grouped + .into_iter() + .map(|(_, approvals)| { + // If there's only one approval in this group, return it + if approvals.len() == 1 { + return approvals[0].clone(); + } + + let mut amount = + approvals.iter().map(|approval| approval.amount).reduce(|a, b| (a + b)); + + if amount.is_none() { + error!( + "Failed to merge approvals due to error in amount reduction: {:?}", + approvals + ); + + // Set 0 approval if there's an error + amount = Some(Uint::ZERO); + } + + let amount = amount.unwrap(); + + RequiredApprovalDetails { + chain_id: approvals[0].chain_id, + token_address: approvals[0].token_address.clone(), + owner: approvals[0].owner.clone(), + target: approvals[0].target.clone(), + amount, + } + }) + .collect(); + + // Generate Transactions for the merged approvals + let (approval_transactions, failed): (Vec<_>, _) = + futures::stream::iter(merged_approvals.into_iter()) + .map(|approval| async move { + Ok::<_, SettlementEngineErrors<_, _>>( + self.generate_transaction_for_approval(&approval).await?, + ) + }) + .buffer_unordered(GENERATE_TRANSACTIONS_CONCURRENCY) + .collect::>() + .await + .into_iter() + .partition(Result::is_ok); + + if !failed.is_empty() { + error!("Failed to generate approval transactions: {:?}", failed); + } + + if approval_transactions.is_empty() { + info!("No Approval Transactions Required"); + return Ok(Vec::new()); + } + + Ok(approval_transactions + .into_iter() + .map(Result::unwrap) + .filter(Option::is_some) + .map(Option::unwrap) + .collect()) + } +} + +#[derive(Error, Debug)] +pub enum SettlementEngineErrors { + #[error("Error generating transactions: {0}")] + GenerateTransactionsError(Source::GenerateRouteTransactionsError), + + #[error("Error getting token amount from value in USD: {0}")] + GetTokenAmountFromValueInUsdError(Errors), + + #[error("No transactions generated")] + NoTransactionsGenerated, + + #[error("ERC20 Utils not found for chain_id: {0} and token_address: {1}")] + ERC20UtilsNotFound(u32, String), + + #[error("Error parsing address: {0}")] + InvalidAddressError(FromHexError), + + #[error("Error while calling ERC20 Contract: {0}")] + AlloyError(#[from] alloy::contract::Error), +} + +pub fn generate_erc20_instance_map( + config: &Config, +) -> Result, Vec> +{ + let (result, failed): (Vec<_>, _) = config + .tokens + .iter() + .flat_map(|(_, token)| { + token + .by_chain + .iter() + .map( + |(chain_id, chain_specific_config)| -> Result< + ((u32, String), IERC20Instance, RootProvider>>), + GenerateERC20InstanceMapErrors, + > { + let rpc_url = &config + .chains + .iter() + .find(|(&id, _)| id == *chain_id) + .ok_or(GenerateERC20InstanceMapErrors::ChainNoFound(*chain_id))? 
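The grouping and merging above collapse approvals that share (chain_id, token_address, owner, target) into a single entry whose amount is the sum of the group. A compact equivalent using the `HashMap` entry API; `Approval` mirrors `RequiredApprovalDetails` for illustration only:

```rust
use std::collections::hash_map::Entry;
use std::collections::HashMap;

use ruint::aliases::U256;

#[derive(Clone, Debug)]
struct Approval {
    chain_id: u32,
    token_address: String,
    owner: String,
    target: String,
    amount: U256,
}

fn merge(approvals: Vec<Approval>) -> Vec<Approval> {
    let mut grouped: HashMap<(u32, String, String, String), Approval> = HashMap::new();
    for a in approvals {
        let key = (a.chain_id, a.token_address.clone(), a.owner.clone(), a.target.clone());
        match grouped.entry(key) {
            // Same (chain, token, owner, target): add the amounts together.
            Entry::Occupied(mut e) => {
                let existing = e.get_mut();
                existing.amount = existing.amount + a.amount;
            }
            Entry::Vacant(e) => {
                e.insert(a);
            }
        }
    }
    grouped.into_values().collect()
}
```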
+ .1 + .rpc_url; + + let provider = + ProviderBuilder::new().on_http(Url::parse(rpc_url.as_str()).map_err( + |_| GenerateERC20InstanceMapErrors::InvalidRPCUrl(rpc_url.clone()), + )?); + + let token_address = + chain_specific_config.address.clone().parse().map_err(|err| { + GenerateERC20InstanceMapErrors::InvalidAddressError( + chain_specific_config.address.clone(), + err, + ) + })?; + + let token = blockchain::ERC20Contract::new(token_address, provider); + + Ok(((*chain_id, chain_specific_config.address.clone()), token)) + }, + ) + .collect::>() + }) + .partition(Result::is_ok); + + if failed.len() != 0 { + return Err(failed.into_iter().map(Result::unwrap_err).collect()); + } + + Ok(result.into_iter().map(Result::unwrap).collect()) +} + +#[derive(Error, Debug)] +pub enum GenerateERC20InstanceMapErrors { + #[error("Chain not found for id: {0}")] + ChainNoFound(u32), + + #[error("Error parsing RPC URL: {0}")] + InvalidRPCUrl(String), + + #[error("Error parsing address: {0}")] + InvalidAddressError(String, FromHexError), +} + +#[cfg(test)] +mod tests { + use std::collections::HashMap; + use std::env; + use std::sync::{Arc, Mutex}; + use std::time::Duration; + + use alloy::primitives::U256; + use async_trait::async_trait; + use derive_more::Display; + use thiserror::Error; + + use config::{Config, get_sample_config}; + use storage::{KeyValueStore, RedisClientError}; + + use crate::{BridgeResult, BungeeClient, CoingeckoClient}; + use crate::settlement_engine::{ + generate_erc20_instance_map, SettlementEngine, SettlementEngineErrors, TransactionType, + }; + use crate::source::{EthereumTransaction, RequiredApprovalDetails}; + + #[derive(Error, Debug, Display)] + struct Err; + + #[derive(Default, Debug)] + struct KVStore { + map: Mutex>, + } + + #[async_trait] + impl KeyValueStore for KVStore { + type Error = Err; + + async fn get(&self, k: &String) -> Result { + match self.map.lock().unwrap().get(k) { + Some(v) => Ok(v.clone()), + None => Result::Err(Err), + } + } + + async fn get_multiple(&self, _: &Vec) -> Result, Self::Error> { + unimplemented!() + } + + async fn set(&self, k: &String, v: &String, _: Duration) -> Result<(), Self::Error> { + self.map + .lock() + .unwrap() + .insert((*k.clone()).parse().unwrap(), (*v.clone()).parse().unwrap()); + Ok(()) + } + + async fn set_multiple(&self, _: &Vec<(String, String)>) -> Result<(), Self::Error> { + unimplemented!() + } + + async fn get_all_keys(&self) -> Result, RedisClientError> { + unimplemented!() + } + + async fn get_all_key_values(&self) -> Result, RedisClientError> { + unimplemented!() + } + } + + fn setup_config<'a>() -> Config { + get_sample_config() + } + + fn setup(config: &Arc) -> SettlementEngine> { + let bungee_client = BungeeClient::new( + &"https://api.socket.tech/v2".to_string(), + &env::var("BUNGEE_API_KEY").unwrap().to_string(), + ) + .unwrap(); + + let client = CoingeckoClient::new( + config.coingecko.base_url.clone(), + env::var("COINGECKO_API_KEY").unwrap(), + KVStore::default(), + Duration::from_secs(config.coingecko.expiry_sec), + ); + + let erc20_instance_map = generate_erc20_instance_map(&config).unwrap(); + + let settlement_engine = + SettlementEngine::new(Arc::clone(config), bungee_client, client, erc20_instance_map); + + return settlement_engine; + } + + const TEST_OWNER_WALLET: &str = "0xe0E67a6F478D7ED604Cf528bDE6C3f5aB034b59D"; + const TOKEN_ADDRESS_USDC_42161: &str = "0xaf88d065e77c8cC2239327C5EDb3A432268e5831"; + // Target has currently 0 approval for token on arbitrum mainnet for USDC token + const 
TARGET_NO_APPROVAL_BY_OWNER_ON_42161_FOR_USDC: &str = + "0x22f966A213288B29bB1F650a923E8f70dAd2515A"; + // Target has 100 approval for token on arbitrum mainnet for USDC token + const TARGET_100_USDC_APPROVAL_BY_OWNER_ON_42161_FOR_USDC: &str = + "0xE4ec34b790e2fCabF37Cc9938A34327ddEadDc78"; + + #[tokio::test] + async fn test_should_generate_required_approval_transaction() { + let config = Arc::new(setup_config()); + let engine = setup(&config); + + let required_approval_data = RequiredApprovalDetails { + chain_id: 42161, + token_address: TOKEN_ADDRESS_USDC_42161.to_string(), + owner: TEST_OWNER_WALLET.to_string(), + target: TARGET_NO_APPROVAL_BY_OWNER_ON_42161_FOR_USDC.to_string(), + amount: U256::from(100), + }; + + let transaction = engine + .generate_transaction_for_approval(&required_approval_data) + .await + .unwrap() + .unwrap(); + + assert_eq!(transaction.transaction_type, TransactionType::Approval); + assert_eq!(transaction.transaction.to, TOKEN_ADDRESS_USDC_42161); + assert_eq!(transaction.transaction.value, U256::ZERO); + assert_eq!( + transaction.transaction.calldata, + engine + .erc20_instance_map + .get(&(required_approval_data.chain_id, required_approval_data.token_address)) + .expect("ERC20 Utils not found") + .approve( + TARGET_NO_APPROVAL_BY_OWNER_ON_42161_FOR_USDC + .to_string() + .parse() + .expect("Invalid address"), + U256::from(100) + ) + .calldata() + .to_string() + ) + } + + #[tokio::test] + async fn test_should_take_existing_approval_into_consideration_while_building_approval_data() { + let config = Arc::new(setup_config()); + let engine = setup(&config); + + let required_approval_data = RequiredApprovalDetails { + chain_id: 42161, + token_address: TOKEN_ADDRESS_USDC_42161.to_string(), + owner: TEST_OWNER_WALLET.to_string(), + target: TARGET_100_USDC_APPROVAL_BY_OWNER_ON_42161_FOR_USDC.to_string(), + amount: U256::from(150), + }; + + let transaction = engine + .generate_transaction_for_approval(&required_approval_data) + .await + .unwrap() + .unwrap(); + + assert_eq!(transaction.transaction_type, TransactionType::Approval); + assert_eq!(transaction.transaction.to, TOKEN_ADDRESS_USDC_42161); + assert_eq!(transaction.transaction.value, U256::ZERO); + assert_eq!( + transaction.transaction.calldata, + engine + .erc20_instance_map + .get(&(required_approval_data.chain_id, required_approval_data.token_address)) + .expect("ERC20 Utils not found") + .approve( + TARGET_100_USDC_APPROVAL_BY_OWNER_ON_42161_FOR_USDC + .to_string() + .parse() + .expect("Invalid address"), + U256::from(50) + ) + .calldata() + .to_string() + ) + } + + #[tokio::test] + async fn test_should_generate_approvals_for_multiple_required_transactions() { + let config = Arc::new(setup_config()); + let engine = setup(&config); + + let required_approval_datas = vec![ + RequiredApprovalDetails { + chain_id: 42161, + token_address: TOKEN_ADDRESS_USDC_42161.to_string(), + owner: TEST_OWNER_WALLET.to_string(), + target: TARGET_NO_APPROVAL_BY_OWNER_ON_42161_FOR_USDC.to_string(), + amount: U256::from(100), + }, + RequiredApprovalDetails { + chain_id: 42161, + token_address: TOKEN_ADDRESS_USDC_42161.to_string(), + owner: TEST_OWNER_WALLET.to_string(), + target: TARGET_100_USDC_APPROVAL_BY_OWNER_ON_42161_FOR_USDC.to_string(), + amount: U256::from(150), + }, + ]; + + let mut transactions = + engine.generate_transactions_for_approvals(&required_approval_datas).await.unwrap(); + transactions.sort_by(|a, b| a.transaction.calldata.cmp(&b.transaction.calldata)); + + assert_eq!(transactions[0].transaction_type, 
TransactionType::Approval); + assert_eq!(transactions[0].transaction.to, TOKEN_ADDRESS_USDC_42161); + assert_eq!(transactions[0].transaction.value, U256::ZERO); + assert_eq!( + transactions[0].transaction.calldata, + engine + .erc20_instance_map + .get(&( + required_approval_datas[0].chain_id, + required_approval_datas[0].token_address.clone() + )) + .expect("ERC20 Utils not found") + .approve( + TARGET_NO_APPROVAL_BY_OWNER_ON_42161_FOR_USDC + .to_string() + .parse() + .expect("Invalid address"), + U256::from(100) + ) + .calldata() + .to_string() + ); + + assert_eq!(transactions[1].transaction_type, TransactionType::Approval); + assert_eq!(transactions[1].transaction.to, TOKEN_ADDRESS_USDC_42161); + assert_eq!(transactions[1].transaction.value, U256::ZERO); + assert_eq!( + transactions[1].transaction.calldata, + engine + .erc20_instance_map + .get(&( + required_approval_datas[1].chain_id, + required_approval_datas[1].token_address.clone() + )) + .expect("ERC20 Utils not found") + .approve( + TARGET_100_USDC_APPROVAL_BY_OWNER_ON_42161_FOR_USDC + .to_string() + .parse() + .expect("Invalid address"), + U256::from(50) + ) + .calldata() + .to_string() + ); + } + + #[tokio::test] + async fn test_should_merge_approvals_from_same_owner_to_same_spender_for_same_token_and_chain() + { + let config = Arc::new(setup_config()); + let engine = setup(&config); + + let required_approval_datas = vec![ + RequiredApprovalDetails { + chain_id: 42161, + token_address: TOKEN_ADDRESS_USDC_42161.to_string(), + owner: TEST_OWNER_WALLET.to_string(), + target: TARGET_NO_APPROVAL_BY_OWNER_ON_42161_FOR_USDC.to_string(), + amount: U256::from(100), + }, + RequiredApprovalDetails { + chain_id: 42161, + token_address: TOKEN_ADDRESS_USDC_42161.to_string(), + owner: TEST_OWNER_WALLET.to_string(), + target: TARGET_NO_APPROVAL_BY_OWNER_ON_42161_FOR_USDC.to_string(), + amount: U256::from(150), + }, + ]; + + let transactions = + engine.generate_transactions_for_approvals(&required_approval_datas).await.unwrap(); + + assert_eq!(transactions.len(), 1); + assert_eq!(transactions[0].transaction_type, TransactionType::Approval); + assert_eq!(transactions[0].transaction.to, TOKEN_ADDRESS_USDC_42161); + assert_eq!(transactions[0].transaction.value, U256::ZERO); + assert_eq!( + transactions[0].transaction.calldata, + engine + .erc20_instance_map + .get(&( + required_approval_datas[0].chain_id, + required_approval_datas[0].token_address.clone() + )) + .expect("ERC20 Utils not found") + .approve( + TARGET_NO_APPROVAL_BY_OWNER_ON_42161_FOR_USDC + .to_string() + .parse() + .expect("Invalid address"), + U256::from(250) + ) + .calldata() + .to_string() + ); + } + + #[tokio::test] + async fn test_should_generate_transactions_for_bridging_routes() { + let config = Arc::new(setup_config()); + let engine = setup(&config); + + let bridge_result = BridgeResult::build( + &config, + &1, + &42161, + &"USDC".to_string(), + &"USDC".to_string(), + false, + 100.0, + TEST_OWNER_WALLET.to_string(), + TEST_OWNER_WALLET.to_string(), + ) + .expect("Failed to build bridge result"); + + let transactions = engine + .generate_transactions(vec![bridge_result]) + .await + .expect("Failed to generate transactions"); + + assert_eq!(transactions.len(), 2); + assert_eq!(transactions[0].transaction_type, TransactionType::Approval); + assert_eq!(transactions[1].transaction_type, TransactionType::Bungee); + } + + #[tokio::test] + async fn test_should_generate_transactions_for_swaps() { + let config = Arc::new(setup_config()); + let engine = setup(&config); + + let 
bridge_result = BridgeResult::build( + &config, + &1, + &1, + &"USDC".to_string(), + &"USDT".to_string(), + false, + 100.0, + TEST_OWNER_WALLET.to_string(), + TEST_OWNER_WALLET.to_string(), + ) + .expect("Failed to build bridge result"); + + let transactions = engine + .generate_transactions(vec![bridge_result]) + .await + .expect("Failed to generate transactions"); + + assert_eq!(transactions.len(), 2); + assert_eq!(transactions[0].transaction_type, TransactionType::Approval); + assert_eq!(transactions[1].transaction_type, TransactionType::Bungee); + } + + fn assert_is_send() {} + + #[test] + fn test_custom_types_must_be_send() { + assert_is_send::>>(); + assert_is_send::(); + assert_is_send::(); + assert_is_send::>>(); + assert_is_send::(); + assert_is_send::< + Result< + (Vec, Vec), + SettlementEngineErrors>, + >, + >() + } +} diff --git a/crates/routing-engine/src/source/bungee/mod.rs b/crates/routing-engine/src/source/bungee/mod.rs index 29bc309..0b73a20 100644 --- a/crates/routing-engine/src/source/bungee/mod.rs +++ b/crates/routing-engine/src/source/bungee/mod.rs @@ -1,14 +1,17 @@ -use derive_more::Display; +use std::str::FromStr; + +use async_trait::async_trait; use log::{error, info}; use reqwest; use reqwest::header; use ruint::aliases::U256; +use ruint::Uint; use thiserror::Error; use types::*; -use crate::source::{Calldata, RouteSource}; use crate::{CostType, Route}; +use crate::source::{EthereumTransaction, RequiredApprovalDetails, RouteSource}; mod types; @@ -19,10 +22,7 @@ pub struct BungeeClient { } impl BungeeClient { - pub fn new<'config>( - base_url: &'config String, - api_key: &'config String, - ) -> Result { + pub fn new(base_url: &String, api_key: &String) -> Result { let mut headers = header::HeaderMap::new(); headers.insert("API-KEY", header::HeaderValue::from_str(api_key)?); @@ -45,6 +45,34 @@ impl BungeeClient { serde_json::from_str(&raw_text) .map_err(|err| BungeeClientError::DeserializationError(raw_text, err)) } + + async fn build_tx( + &self, + params: BuildTxRequest, + ) -> Result, BungeeClientError> { + let response = + self.client.post(self.base_url.to_owned() + "/build-tx").json(¶ms).send().await?; + let raw_text = response.text().await?; + + serde_json::from_str(&raw_text) + .map_err(|err| BungeeClientError::DeserializationError(raw_text, err)) + } +} + +// Errors +#[derive(Debug, Error)] +pub enum BungeeClientError { + #[error("Error while deserializing response: {0}")] + DeserializationError(String, serde_json::Error), + + #[error("Error while Serializing Body: {0:?} with error {1:?}")] + BodySerializationError(BuildTxRequest, serde_json::Error), + + #[error("Error while making request: Request error: {}", _0)] + RequestError(#[from] reqwest::Error), + + #[error("No route returned by Bungee API")] + NoRouteError, } const ADDRESS_ZERO: &'static str = "0x0000000000000000000000000000000000000000"; @@ -67,20 +95,34 @@ pub enum BungeeFetchRouteCostError { EstimationTypeNotImplementedError(#[from] CostType), } -#[derive(Error, Debug, Display)] -pub struct GenerateRouteCalldataError; +#[derive(Error, Debug)] +pub enum GenerateRouteTransactionsError { + #[error("Error while fetching least route and cost in USD: {0}")] + FetchRouteCostError(#[from] BungeeFetchRouteCostError), + + #[error("Error while calling build transaction api: {0}")] + BungeeClientError(#[from] BungeeClientError), + #[error("Error while parsing U256: {0}")] + InvalidU256Error(String), +} + +#[async_trait] impl RouteSource for BungeeClient { type FetchRouteCostError = 
BungeeFetchRouteCostError; - type GenerateRouteCalldataError = GenerateRouteCalldataError; + type GenerateRouteTransactionsError = GenerateRouteTransactionsError; + + type BaseRouteType = serde_json::Value; - async fn fetch_least_route_cost_in_usd( + async fn fetch_least_cost_route_and_cost_in_usd( &self, - route: &Route<'_>, - from_token_amount: U256, + route: &Route, + from_token_amount: &U256, + sender_address: Option<&String>, + recipient_address: Option<&String>, estimation_type: &CostType, - ) -> Result { + ) -> Result<(Self::BaseRouteType, f64), Self::FetchRouteCostError> { info!("Fetching least route cost in USD for route {:?} with token amount {} and estimation type {}", route, from_token_amount, estimation_type); // Build GetQuoteRequest @@ -111,8 +153,8 @@ impl RouteSource for BungeeClient { to_chain_id: route.to_chain.id, to_token_address: to_token.address.clone(), from_amount: from_token_amount.to_string(), - user_address: ADDRESS_ZERO.to_string(), - recipient: ADDRESS_ZERO.to_string(), + user_address: sender_address.unwrap_or(&ADDRESS_ZERO.to_string()).clone(), + recipient: recipient_address.unwrap_or(&ADDRESS_ZERO.to_string()).clone(), unique_routes_per_bridge: false, is_contract_call: route.is_smart_contract_deposit, }; @@ -123,36 +165,123 @@ impl RouteSource for BungeeClient { return Err(BungeeFetchRouteCostError::FailureIndicatedInResponseError()); } + let get_route_fee = |route: &serde_json::Value| { + let total_gas_fees_in_usd = route.get("totalGasFeesInUsd")?.as_f64()?; + let input_value_in_usd = route.get("inputValueInUsd")?.as_f64()?; + let output_value_in_usd = route.get("outputValueInUsd")?.as_f64()?; + + match estimation_type { + CostType::Fee => { + Some(total_gas_fees_in_usd + input_value_in_usd - output_value_in_usd) + } + } + }; + // Find the minimum cost across all routes - let route_costs_in_usd: Vec = response + let (route_costs_in_usd, failed): (Vec<(serde_json::Value, Option)>, _) = response .result .routes - .iter() - .map(|route| match estimation_type { - CostType::Fee => Some( - route.total_gas_fees_in_usd + route.input_value_in_usd? 
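The `get_route_fee` closure above treats a quote's fee as `totalGasFeesInUsd + inputValueInUsd - outputValueInUsd`, and drops routes where any of the three fields is missing or non-numeric. A standalone restatement with the same field keys; `route_fee` is an illustrative helper, not part of this diff:

```rust
use serde_json::{json, Value};

// Fee for CostType::Fee: gas paid plus the value lost between input and output.
fn route_fee(route: &Value) -> Option<f64> {
    let gas = route.get("totalGasFeesInUsd")?.as_f64()?;
    let input = route.get("inputValueInUsd")?.as_f64()?;
    let output = route.get("outputValueInUsd")?.as_f64()?;
    Some(gas + input - output)
}

fn main() {
    let route = json!({
        "totalGasFeesInUsd": 1.5,
        "inputValueInUsd": 100.0,
        "outputValueInUsd": 99.0
    });
    // 1.5 + (100.0 - 99.0) = 2.5
    assert_eq!(route_fee(&route), Some(2.5));
    assert_eq!(route_fee(&json!({})), None);
}
```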
- - route.output_value_in_usd?, - ), + .into_iter() + .map(|route| { + let fee = get_route_fee(&route); + (route, fee) }) - .filter(|cost| cost.is_some()) - .map(|cost| cost.unwrap()) - .collect(); + .partition(|(_, r)| r.is_some()); + + let route_costs_in_usd: Vec<(serde_json::Value, f64)> = + route_costs_in_usd.into_iter().map(|(route, r)| (route, r.unwrap())).collect(); + + if failed.len() > 0 { + let failed: Vec = + failed.into_iter().map(|(route, _)| route).collect(); + error!("Error while calculating route costs in usd: {:?}", failed); + } if route_costs_in_usd.len() == 0 { error!("No valid routes returned by Bungee API for route {:?}", route); return Err(BungeeFetchRouteCostError::NoValidRouteError()); } - info!("Route costs in USD: {:?} for route {:?}", route_costs_in_usd, route); + info!("Route costs in USD: {:?}", route_costs_in_usd.len()); + + let min_cost_route_and_fee = + route_costs_in_usd.into_iter().min_by(|a, b| a.1.total_cmp(&b.1)); - Ok(route_costs_in_usd.into_iter().min_by(|a, b| a.total_cmp(b)).unwrap()) + return if let Some(min_cost_route_and_fee) = min_cost_route_and_fee { + Ok(min_cost_route_and_fee) + } else { + error!("No valid routes returned after applying min function"); + Err(BungeeFetchRouteCostError::NoValidRouteError()) + }; } - async fn generate_route_calldata( + async fn generate_route_transactions( &self, - _: &Route<'_>, - ) -> Result { - todo!() + route: &Route, + amount: &U256, + sender_address: &String, + recipient_address: &String, + ) -> Result< + (Vec, Vec), + Self::GenerateRouteTransactionsError, + > { + info!( + "Generating cheapest route transactions for route {:?} with amount {}", + route, amount + ); + + let (bungee_route, _) = self + .fetch_least_cost_route_and_cost_in_usd( + route, + amount, + Some(sender_address), + Some(recipient_address), + &CostType::Fee, + ) + .await?; + + info!("Retrieved bungee route {:?} for route {:?}", bungee_route, route); + + let tx = self.build_tx(BuildTxRequest { route: bungee_route }).await?; + if !tx.success { + error!("Failure indicated in bungee response"); + + return Err(GenerateRouteTransactionsError::FetchRouteCostError( + BungeeFetchRouteCostError::FailureIndicatedInResponseError(), + )); + } + + let tx = tx.result; + info!("Returned transaction from bungee {:?}", tx); + + let transactions = vec![EthereumTransaction { + from: sender_address.clone(), + to: tx.tx_target, + value: Uint::from_str(&tx.value).map_err(|err| { + error!("Error while parsing tx data: {}", err); + GenerateRouteTransactionsError::InvalidU256Error(tx.value) + })?, + calldata: tx.tx_data, + }]; + + info!("Generated transactions {:?}", transactions); + + let approvals = vec![RequiredApprovalDetails { + chain_id: tx.chain_id, + token_address: tx.approval_data.approval_token_address, + owner: tx.approval_data.owner, + target: tx.approval_data.allowance_target, + amount: Uint::from_str(&tx.approval_data.minimum_approval_amount).map_err(|err| { + error!("Error while parsing approval data: {}", err); + GenerateRouteTransactionsError::InvalidU256Error( + tx.approval_data.minimum_approval_amount, + ) + })?, + }]; + + info!("Generated approvals {:?}", approvals); + + Ok((transactions, approvals)) } } @@ -162,12 +291,12 @@ mod tests { use ruint::Uint; - use config::get_sample_config; use config::Config; + use config::get_sample_config; + use crate::{BungeeClient, CostType, Route}; use crate::source::bungee::types::GetQuoteRequest; use crate::source::RouteSource; - use crate::{BungeeClient, CostType, Route}; fn setup() -> (Config, 
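`generate_route_transactions` re-quotes the cheapest route with the real sender/recipient, posts it to `/build-tx`, and then parses the string-encoded `value` and `minimumApprovalAmount` into `U256`, surfacing parse failures as `InvalidU256Error`. A minimal sketch of that parsing step only; `parse_amount` is a hypothetical helper, assuming ruint's `FromStr` as the diff itself does:

```rust
use std::str::FromStr;

use ruint::aliases::U256;

// Parse a string-encoded amount from the build-tx response.
fn parse_amount(raw: &str) -> Result<U256, String> {
    U256::from_str(raw).map_err(|err| format!("invalid U256 {raw:?}: {err}"))
}

fn main() {
    assert_eq!(parse_amount("100000000"), Ok(U256::from(100000000u64)));
    assert!(parse_amount("not-a-number").is_err());
}
```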
BungeeClient) { let config = get_sample_config(); @@ -208,18 +337,41 @@ mod tests { async fn test_fetch_least_cost_route() { let (config, client) = setup(); - let route = Route { - from_chain: &config.chains.get(&1).unwrap(), - to_chain: &config.chains.get(&42161).unwrap(), - from_token: &config.tokens.get(&"USDC".to_string()).unwrap(), - to_token: &config.tokens.get(&"USDC".to_string()).unwrap(), - is_smart_contract_deposit: false, - }; - let least_route_cost = client - .fetch_least_route_cost_in_usd(&route, Uint::from(100000000), &CostType::Fee) + let route = + Route::build(&config, &1, &42161, &"USDC".to_string(), &"USDC".to_string(), false) + .unwrap(); + let (_, least_route_cost) = client + .fetch_least_cost_route_and_cost_in_usd( + &route, + &Uint::from(100000000), + None, + None, + &CostType::Fee, + ) .await .unwrap(); assert_eq!(least_route_cost > 0.0, true); } + + #[tokio::test] + async fn test_generate_route_transactions() { + let (config, client) = setup(); + + let route = + Route::build(&config, &1, &42161, &"USDC".to_string(), &"USDC".to_string(), false) + .unwrap(); + + let address = "0x90f05C1E52FAfB4577A4f5F869b804318d56A1ee".to_string(); + + let token_amount = Uint::from(100000000); + let result = + client.generate_route_transactions(&route, &token_amount, &address, &address).await; + + assert!(result.is_ok()); + + let (transactions, approvals) = result.unwrap(); + assert_eq!(transactions.len(), 1); + assert_eq!(approvals.len(), 1); + } } diff --git a/crates/routing-engine/src/source/bungee/types.rs b/crates/routing-engine/src/source/bungee/types.rs index df70da4..e4a71ca 100644 --- a/crates/routing-engine/src/source/bungee/types.rs +++ b/crates/routing-engine/src/source/bungee/types.rs @@ -1,5 +1,4 @@ use serde::{Deserialize, Serialize}; -use thiserror::Error; #[derive(Deserialize, Debug)] pub struct BungeeResponse { @@ -38,7 +37,7 @@ pub struct BungeeTokenResponse { #[derive(Deserialize, Debug)] #[serde(rename_all = "camelCase")] pub struct GetQuoteResponse { - pub routes: Vec, + pub routes: Vec, pub from_chain_id: Option, pub from_asset: Option, pub to_chain_id: Option, @@ -46,37 +45,6 @@ pub struct GetQuoteResponse { pub refuel: Option, } -#[derive(Deserialize, Serialize, Debug, Clone, Default)] -#[serde(rename_all = "camelCase")] -pub struct GetQuoteResponseRoute { - pub route_id: String, - pub is_only_swap_route: bool, - pub from_amount: String, - pub to_amount: String, - #[serde(skip_serializing)] - pub used_bridge_names: Vec, - pub total_user_tx: u32, - pub total_gas_fees_in_usd: f64, - pub recipient: String, - pub sender: String, - pub received_value_in_usd: Option, - pub input_value_in_usd: Option, - pub output_value_in_usd: Option, - pub service_time: u32, - pub max_service_time: u32, - #[serde(skip_serializing)] - pub integrator_fee: GetQuoteResponseRouteIntegratorFee, - pub t2b_receiver_address: Option, -} - -#[derive(Deserialize, Serialize, Debug, Clone, Default)] -#[serde(rename_all = "camelCase")] -pub struct GetQuoteResponseRouteIntegratorFee { - pub fee_taker_address: Option, - pub amount: Option, - pub asset: BungeeTokenResponse, -} - #[derive(Deserialize, Debug)] #[serde(rename_all = "camelCase")] pub struct GetQuoteResponseRefuel { @@ -100,15 +68,31 @@ pub struct GetQuoteResponseRefuelGasFee { pub gas_amount: String, } -// Errors -#[derive(Debug, Error)] -pub enum BungeeClientError { - #[error("Error while deserializing response: {0}")] - DeserializationError(String, serde_json::Error), +// POST /build-tx +#[derive(Serialize, Deserialize, Debug)] 
+#[serde(rename_all = "camelCase")] +pub struct BuildTxRequest { + pub(crate) route: serde_json::Value, +} - #[error("Error while making request: Request error: {}", _0)] - RequestError(#[from] reqwest::Error), +#[derive(Serialize, Deserialize, Debug)] +#[serde(rename_all = "camelCase")] +pub struct BuildTxResponse { + pub user_tx_type: String, + pub tx_target: String, + pub chain_id: u32, + pub tx_data: String, + pub tx_type: String, + pub value: String, + pub total_user_tx: Option, + pub approval_data: BuildTxResponseApprovalData, +} - #[error("No route returned by Bungee API")] - NoRouteError, +#[derive(Serialize, Deserialize, Debug)] +#[serde(rename_all = "camelCase")] +pub struct BuildTxResponseApprovalData { + pub minimum_approval_amount: String, + pub approval_token_address: String, + pub allowance_target: String, + pub owner: String, } diff --git a/crates/routing-engine/src/source/mod.rs b/crates/routing-engine/src/source/mod.rs index 0df6090..2bfeead 100644 --- a/crates/routing-engine/src/source/mod.rs +++ b/crates/routing-engine/src/source/mod.rs @@ -1,27 +1,54 @@ use std::error::Error; use std::fmt::Debug; +use async_trait::async_trait; use ruint::aliases::U256; +use serde::Serialize; use crate::{CostType, Route}; pub mod bungee; -type Calldata = String; +#[derive(Debug, Serialize)] +pub struct EthereumTransaction { + pub from: String, + pub to: String, + pub value: U256, + pub calldata: String, +} + +#[derive(Debug, Clone)] +pub struct RequiredApprovalDetails { + pub chain_id: u32, + pub token_address: String, + pub owner: String, + pub target: String, + pub amount: U256, +} -pub trait RouteSource: Debug { - type FetchRouteCostError: Debug + Error; - type GenerateRouteCalldataError: Debug + Error; +#[async_trait] +pub trait RouteSource: Debug + Send + Sync { + type FetchRouteCostError: Debug + Error + Send + Sync; + type GenerateRouteTransactionsError: Debug + Error + Send + Sync; + type BaseRouteType: Debug + Send + Sync; - fn fetch_least_route_cost_in_usd( + async fn fetch_least_cost_route_and_cost_in_usd( &self, route: &Route, - from_token_amount: U256, + from_token_amount: &U256, + sender_address: Option<&String>, + recipient_address: Option<&String>, estimation_type: &CostType, - ) -> impl futures::Future>; + ) -> Result<(Self::BaseRouteType, f64), Self::FetchRouteCostError>; - fn generate_route_calldata( + async fn generate_route_transactions( &self, route: &Route, - ) -> impl futures::Future>; + amount: &U256, + sender_address: &String, + recipient_address: &String, + ) -> Result< + (Vec, Vec), + Self::GenerateRouteTransactionsError, + >; } diff --git a/crates/routing-engine/src/token_price/coingecko.rs b/crates/routing-engine/src/token_price/coingecko.rs index 23f4779..896c93b 100644 --- a/crates/routing-engine/src/token_price/coingecko.rs +++ b/crates/routing-engine/src/token_price/coingecko.rs @@ -2,6 +2,7 @@ use std::fmt::Debug; use std::num::ParseFloatError; use std::time::Duration; +use async_trait::async_trait; use derive_more::Display; use log::{error, info}; use reqwest::{header, StatusCode}; @@ -14,24 +15,24 @@ use crate::token_price::coingecko::CoingeckoClientError::RequestFailed; use crate::token_price::TokenPriceProvider; #[derive(Debug)] -pub struct CoingeckoClient<'config, KVStore: KeyValueStore> { - base_url: &'config String, +pub struct CoingeckoClient { + base_url: String, client: reqwest::Client, - cache: &'config KVStore, + cache: KVStore, key_expiry: Duration, } -impl<'config, KVStore: KeyValueStore> CoingeckoClient<'config, KVStore> { +impl 
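The new Bungee request/response types all use `#[serde(rename_all = "camelCase")]`, so the snake_case Rust fields line up with the camelCase keys of the API payload. A quick, self-contained illustration with a hypothetical `TxSummary` carrying only a subset of those fields:

```rust
use serde::Deserialize;

#[derive(Deserialize, Debug)]
#[serde(rename_all = "camelCase")]
struct TxSummary {
    tx_target: String, // deserialized from "txTarget"
    chain_id: u32,     // deserialized from "chainId"
    tx_data: String,   // deserialized from "txData"
    value: String,
}

fn main() {
    let raw = r#"{
        "txTarget": "0x0000000000000000000000000000000000000000",
        "chainId": 42161,
        "txData": "0x",
        "value": "0"
    }"#;
    let parsed: TxSummary = serde_json::from_str(raw).unwrap();
    assert_eq!(parsed.chain_id, 42161);
}
```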
CoingeckoClient { pub fn new( - base_url: &'config String, - api_key: &'config String, - cache: &'config KVStore, + base_url: String, + api_key: String, + cache: KVStore, key_expiry: Duration, - ) -> CoingeckoClient<'config, KVStore> { + ) -> CoingeckoClient { let mut headers = header::HeaderMap::new(); headers.insert( "x-cg-pro-api-key", - header::HeaderValue::from_str(api_key) + header::HeaderValue::from_str(&api_key) .expect("Error while building header value Invalid CoinGecko API Key"), ); @@ -70,7 +71,8 @@ impl<'config, KVStore: KeyValueStore> CoingeckoClient<'config, KVStore> { } } -impl<'config, KVStore: KeyValueStore> TokenPriceProvider for CoingeckoClient<'config, KVStore> { +#[async_trait] +impl TokenPriceProvider for CoingeckoClient { type Error = CoingeckoClientError; async fn get_token_price(&self, token_symbol: &String) -> Result { @@ -134,17 +136,18 @@ struct CoinsIdResponseMarketDataCurrentPrice { #[cfg(test)] mod tests { - use std::cell::RefCell; use std::collections::HashMap; use std::env; use std::fmt::Debug; + use std::sync::Mutex; use std::time::Duration; + use async_trait::async_trait; use derive_more::Display; use thiserror::Error; use config::{Config, get_sample_config}; - use storage::KeyValueStore; + use storage::{KeyValueStore, RedisClientError}; use crate::CoingeckoClient; use crate::token_price::TokenPriceProvider; @@ -154,32 +157,42 @@ mod tests { #[derive(Default, Debug)] struct KVStore { - map: RefCell>, + map: Mutex>, } + #[async_trait] impl KeyValueStore for KVStore { type Error = Err; async fn get(&self, k: &String) -> Result { - match self.map.borrow().get(k) { + match self.map.lock().unwrap().get(k) { Some(v) => Ok(v.clone()), None => Result::Err(Err), } } async fn get_multiple(&self, _: &Vec) -> Result, Self::Error> { - todo!() + unimplemented!() } async fn set(&self, k: &String, v: &String, _: Duration) -> Result<(), Self::Error> { self.map - .borrow_mut() + .lock() + .unwrap() .insert((*k.clone()).parse().unwrap(), (*v.clone()).parse().unwrap()); Ok(()) } async fn set_multiple(&self, _: &Vec<(String, String)>) -> Result<(), Self::Error> { - todo!() + unimplemented!() + } + + async fn get_all_keys(&self) -> Result, RedisClientError> { + unimplemented!() + } + + async fn get_all_key_values(&self) -> Result, RedisClientError> { + unimplemented!() } } @@ -196,9 +209,9 @@ mod tests { let store = KVStore::default(); let client = CoingeckoClient::new( - &config.coingecko.base_url, - &api_key, - &store, + config.coingecko.base_url.clone(), + api_key, + store, Duration::from_secs(config.coingecko.expiry_sec), ); let price = client.get_fresh_token_price(&"usd-coin".to_string()).await.unwrap(); @@ -215,21 +228,21 @@ mod tests { let store = KVStore::default(); let client = CoingeckoClient::new( - &config.coingecko.base_url, - &api_key, - &store, + config.coingecko.base_url.clone(), + api_key, + store, Duration::from_secs(config.coingecko.expiry_sec), ); let price = client.get_token_price(&"usd-coin".to_string()).await.unwrap(); assert!(price > 0.0); let key = "usd-coin_price".to_string(); - assert_eq!(store.get(&key).await.unwrap().parse::().unwrap(), price); + assert_eq!(client.cache.get(&key).await.unwrap().parse::().unwrap(), price); let price2 = client.get_token_price(&"usd-coin".to_string()).await.unwrap(); assert_eq!(price, price2); - store.set(&key, &"1.1".to_string(), Duration::from_secs(10)).await.unwrap(); + client.cache.set(&key, &"1.1".to_string(), Duration::from_secs(10)).await.unwrap(); let price = 
client.get_token_price(&"usd-coin".to_string()).await.unwrap(); assert_eq!(price, 1.1); diff --git a/crates/routing-engine/src/token_price/mod.rs b/crates/routing-engine/src/token_price/mod.rs index f72b224..8cc774c 100644 --- a/crates/routing-engine/src/token_price/mod.rs +++ b/crates/routing-engine/src/token_price/mod.rs @@ -1,17 +1,16 @@ use std::error::Error; use std::fmt::Debug; +use async_trait::async_trait; + pub use coingecko::CoingeckoClient; mod coingecko; pub mod utils; -pub trait TokenPriceProvider: Debug { - type Error: Error + Debug; +#[async_trait] +pub trait TokenPriceProvider: Debug + Send + Sync { + type Error: Error + Debug + Send + Sync; - fn get_token_price( - &self, - token_symbol: &String, - ) -> impl futures::Future>; + async fn get_token_price(&self, token_symbol: &String) -> Result; } - diff --git a/crates/routing-engine/src/token_price/utils.rs b/crates/routing-engine/src/token_price/utils.rs index 8deb9db..03e9131 100644 --- a/crates/routing-engine/src/token_price/utils.rs +++ b/crates/routing-engine/src/token_price/utils.rs @@ -1,4 +1,5 @@ -use derive_more::Display; +use std::fmt::Debug; + use ruint; use ruint::aliases::U256; use ruint::Uint; @@ -6,24 +7,21 @@ use thiserror::Error; use crate::token_price::TokenPriceProvider; -pub async fn get_token_amount_from_value_in_usd<'config, T: TokenPriceProvider>( - config: &'config config::Config, - token_price_provider: &'config T, - token_symbol: &'config String, +pub async fn get_token_amount_from_value_in_usd( + config: &config::Config, + token_price_provider: &T, + token_symbol: &String, chain_id: u32, value_in_usd: &f64, ) -> Result> { + let token_price = get_token_price(config, token_price_provider, token_symbol).await?; + let token_config = config.tokens.get(token_symbol); if token_config.is_none() { return Err(Errors::TokenConfigurationNotFound(token_symbol.clone())); } let token_config = token_config.unwrap(); - let token_price = token_price_provider - .get_token_price(&token_config.coingecko_symbol) - .await - .map_err(Errors::::TokenPriceProviderError)?; - let token_config_by_chain = token_config.by_chain.get(&chain_id); if token_config_by_chain.is_none() { return Err(Errors::TokenConfigurationNotFoundForChain(token_symbol.clone(), chain_id)); @@ -38,9 +36,28 @@ pub async fn get_token_amount_from_value_in_usd<'config, T: TokenPriceProvider>( Ok(token_amount_in_wei) } +pub async fn get_token_price( + config: &config::Config, + token_price_provider: &T, + token_symbol: &String, +) -> Result> { + let token_config = config.tokens.get(token_symbol); + if token_config.is_none() { + return Err(Errors::TokenConfigurationNotFound(token_symbol.clone())); + } + let token_config = token_config.unwrap(); + + let token_price = token_price_provider + .get_token_price(&token_config.coingecko_symbol) + .await + .map_err(Errors::::TokenPriceProviderError)?; + + return Ok(token_price); +} + #[derive(Debug, Error)] -pub enum Errors { - #[error("Token price provider error: {}", _0)] +pub enum Errors { + #[error("Token price provider error: {:?}", _0)] TokenPriceProviderError(#[from] T), #[error("Could not find token configuration for {}", _0)] @@ -54,6 +71,7 @@ pub enum Errors { mod tests { use std::fmt::Error; + use async_trait::async_trait; use ruint::Uint; use config::{Config, get_sample_config}; @@ -67,6 +85,7 @@ mod tests { #[derive(Debug)] struct TokenPriceProviderStub; + #[async_trait] impl TokenPriceProvider for TokenPriceProviderStub { type Error = Error; diff --git a/crates/storage/Cargo.toml 
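`get_token_amount_from_value_in_usd` now fetches the price first (via the extracted `get_token_price`) and then converts the USD value into the token's smallest units; the conversion itself sits in unchanged context outside this hunk. For orientation only, the usual shape of such a conversion looks roughly like the following — an assumed, illustrative helper, not the code in this diff:

```rust
use ruint::aliases::U256;

// Assumed sketch: USD value -> token units, scaled by the token's decimals.
fn usd_to_token_units(value_in_usd: f64, token_price_usd: f64, decimals: u8) -> U256 {
    let tokens = value_in_usd / token_price_usd;
    let units = tokens * 10f64.powi(decimals as i32);
    U256::from(units as u128)
}

fn main() {
    // $100 of a $1.00 token with 6 decimals -> 100_000_000 units.
    assert_eq!(usd_to_token_units(100.0, 1.0, 6), U256::from(100_000_000u64));
}
```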
b/crates/storage/Cargo.toml index ee2fa9b..9752993 100644 --- a/crates/storage/Cargo.toml +++ b/crates/storage/Cargo.toml @@ -21,4 +21,3 @@ log = "0.4.21" [dev-dependencies] serial_test = "3.1.1" uuid = "1.8.0" - diff --git a/crates/storage/src/lib.rs b/crates/storage/src/lib.rs index f5b26b3..0aea2b7 100644 --- a/crates/storage/src/lib.rs +++ b/crates/storage/src/lib.rs @@ -1,9 +1,10 @@ +use std::collections::HashMap; use std::error::Error; use std::fmt::Debug; -use std::future; use std::time::Duration; pub use ::redis::{ControlFlow, Msg}; +use async_trait::async_trait; use mongodb::bson::Document; pub use redis_client::{RedisClient, RedisClientError}; @@ -12,37 +13,28 @@ pub mod mongodb_client; mod redis_client; -pub trait KeyValueStore: Debug { - type Error: Error + Debug; +#[async_trait] +pub trait KeyValueStore: Debug + Send + Sync { + type Error: Error + Debug + Send + Sync; - fn get(&self, k: &String) -> impl future::Future>; + async fn get(&self, k: &String) -> Result; - fn get_multiple( - &self, - k: &Vec, - ) -> impl future::Future, Self::Error>>; + async fn get_multiple(&self, k: &Vec) -> Result, Self::Error>; - fn set( - &self, - k: &String, - v: &String, - expiry: Duration, - ) -> impl future::Future>; + async fn set(&self, k: &String, v: &String, expiry: Duration) -> Result<(), Self::Error>; - fn set_multiple( - &self, - kv: &Vec<(String, String)>, - ) -> impl future::Future>; + async fn set_multiple(&self, kv: &Vec<(String, String)>) -> Result<(), Self::Error>; + + async fn get_all_keys(&self) -> Result, RedisClientError>; + + async fn get_all_key_values(&self) -> Result, RedisClientError>; } -pub trait MessageQueue: Debug { - type Error: Error + Debug; +#[async_trait] +pub trait MessageQueue: Debug + Send + Sync { + type Error: Error + Debug + Send + Sync; - fn publish( - &self, - topic: &str, - message: &str, - ) -> impl future::Future>; + async fn publish(&self, topic: &str, message: &str) -> Result<(), Self::Error>; fn subscribe( &self, @@ -51,21 +43,15 @@ pub trait MessageQueue: Debug { ) -> Result<(), Self::Error>; } -pub trait DBProvider: Debug { - type Error: Error + Debug; +#[async_trait] +pub trait DBProvider: Debug + Send + Sync { + type Error: Error + Debug + Send + Sync; - fn create(&self, item: &Document) -> impl future::Future>; + async fn create(&self, item: &Document) -> Result<(), Self::Error>; - fn read( - &self, - query: &Document, - ) -> impl future::Future, Self::Error>>; + async fn read(&self, query: &Document) -> Result, Self::Error>; - fn update( - &self, - query: &Document, - update: &Document, - ) -> impl future::Future>; + async fn update(&self, query: &Document, update: &Document) -> Result<(), Self::Error>; - fn delete(&self, query: &Document) -> impl future::Future>; + async fn delete(&self, query: &Document) -> Result<(), Self::Error>; } diff --git a/crates/storage/src/mongodb_client.rs b/crates/storage/src/mongodb_client.rs index 1b7bcbe..0d03768 100644 --- a/crates/storage/src/mongodb_client.rs +++ b/crates/storage/src/mongodb_client.rs @@ -1,3 +1,4 @@ +use async_trait::async_trait; use derive_more::Display; use mongodb::{ bson::{self, doc, Document}, @@ -65,6 +66,7 @@ impl MongoDBClient { } } +#[async_trait] impl DBProvider for MongoDBClient { type Error = DBError; diff --git a/crates/storage/src/redis_client.rs b/crates/storage/src/redis_client.rs index fde605e..4b77c48 100644 --- a/crates/storage/src/redis_client.rs +++ b/crates/storage/src/redis_client.rs @@ -1,11 +1,14 @@ -use crate::{KeyValueStore, MessageQueue}; -use log::info; -use 
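The storage traits below drop the hand-rolled `impl Future` return types in favour of `#[async_trait]` with `Send + Sync` bounds, which is what lets implementations sit behind generics or trait objects and be awaited across `tokio::spawn` boundaries. A minimal standalone illustration of that pattern, with a hypothetical `Quoter` trait and assuming the `async-trait` and `tokio` crates:

```rust
use async_trait::async_trait;

#[async_trait]
trait Quoter: Send + Sync {
    async fn quote_usd(&self, token: &str) -> Result<f64, String>;
}

struct FixedQuoter(f64);

#[async_trait]
impl Quoter for FixedQuoter {
    async fn quote_usd(&self, _token: &str) -> Result<f64, String> {
        Ok(self.0)
    }
}

#[tokio::main]
async fn main() {
    let quoter = FixedQuoter(1.0);
    // The boxed future produced by #[async_trait] is Send, so it can be
    // driven from a spawned task.
    let handle = tokio::spawn(async move { quoter.quote_usd("USDC").await });
    assert_eq!(handle.await.unwrap(), Ok(1.0));
}
```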
redis::RedisError; -use redis::{self, aio, AsyncCommands, ControlFlow, Msg, PubSubCommands}; use std::collections::HashMap; use std::time::Duration; + +use async_trait::async_trait; +use log::info; +use redis::{self, aio, AsyncCommands, ControlFlow, Msg, PubSubCommands}; +use redis::RedisError; use thiserror::Error; +use crate::{KeyValueStore, MessageQueue}; + #[derive(Debug, Clone)] pub struct RedisClient { client: redis::Client, @@ -18,22 +21,9 @@ impl RedisClient { let connection = client.get_multiplexed_async_connection().await?; Ok(RedisClient { client, connection }) } - - pub async fn get_all_keys(&self) -> Result, RedisClientError> { - info!("Fetching all keys"); - let keys: Vec = self.connection.clone().keys("*").await?; - Ok(keys) - } - - pub async fn get_all_key_values(&self) -> Result, RedisClientError> { - info!("Fetching all key-value pairs"); - let keys = self.get_all_keys().await?; - let values: Vec = self.connection.clone().mget(&keys).await?; - let kv_pairs = keys.into_iter().zip(values.into_iter()).collect(); - Ok(kv_pairs) - } } +#[async_trait] impl KeyValueStore for RedisClient { type Error = RedisClientError; @@ -61,8 +51,23 @@ impl KeyValueStore for RedisClient { info!("Setting keys: {:?}", kv); self.connection.clone().mset(kv).await.map_err(RedisClientError::RedisLibraryError) } + + async fn get_all_keys(&self) -> Result, RedisClientError> { + info!("Fetching all keys"); + let keys: Vec = self.connection.clone().keys("*").await?; + Ok(keys) + } + + async fn get_all_key_values(&self) -> Result, RedisClientError> { + info!("Fetching all key-value pairs"); + let keys = self.get_all_keys().await?; + let values: Vec = self.connection.clone().mget(&keys).await?; + let kv_pairs = keys.into_iter().zip(values.into_iter()).collect(); + Ok(kv_pairs) + } } +#[async_trait] impl MessageQueue for RedisClient { type Error = RedisClientError; diff --git a/docker-compose.yml b/docker-compose.yml index ca61c49..b3c6029 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -44,4 +44,3 @@ services: volumes: redis_data: mongo_data: -