From 68b86372a7c90748561542e0898e8f5e1d4967fd Mon Sep 17 00:00:00 2001
From: amanraj1608
Date: Tue, 25 Jun 2024 15:39:10 +0400
Subject: [PATCH] feat: pub-sub cache update logic

---
 bin/reflux/src/main.rs               |  42 ++++++---
 crates/api/src/service_controller.rs |  10 +--
 crates/routing-engine/src/engine.rs  | 123 +++++++++++++++++++--------
 crates/storage/src/redis.rs          |  61 ++++++++++---
 4 files changed, 171 insertions(+), 65 deletions(-)

diff --git a/bin/reflux/src/main.rs b/bin/reflux/src/main.rs
index e69bc98..3ab1724 100644
--- a/bin/reflux/src/main.rs
+++ b/bin/reflux/src/main.rs
@@ -2,7 +2,7 @@ use std::sync::Arc;
 use std::time::Duration;
 
 use axum::http::Method;
-use log::{error, info};
+use log::info;
 use tokio;
 use tokio::signal;
 use tokio_cron_scheduler::{Job, JobScheduler};
@@ -11,11 +11,10 @@ use tower_http::cors::{Any, CorsLayer};
 use account_aggregation::service::AccountAggregationService;
 use api::service_controller::ServiceController;
 use config::Config;
-use routing_engine::{BungeeClient, CoingeckoClient, Indexer};
 use routing_engine::engine::RoutingEngine;
-use routing_engine::estimator::LinearRegressionEstimator;
+use routing_engine::{BungeeClient, CoingeckoClient, Indexer};
 use storage::mongodb_provider::MongoDBProvider;
-use storage::RedisClient;
+use storage::{ControlFlow, MessageQueue, RedisClient};
 
 #[tokio::main]
 async fn main() {
@@ -29,6 +28,8 @@ async fn main() {
     } else {
         run_server(config).await;
     }
+
+    info!("Exiting Reflux");
 }
 
 async fn run_server(config: Config) {
@@ -70,8 +71,29 @@ async fn run_server(config: Config) {
     );
 
     // Initialize routing engine
-    let buckets = config.buckets;
-    let routing_engine = RoutingEngine::new(account_service.clone(), buckets);
+    let buckets = config.buckets.clone();
+    let redis_client = RedisClient::build(&config.infra.redis_url)
+        .await
+        .expect("Failed to instantiate redis client");
+    let routing_engine =
+        Arc::new(RoutingEngine::new(account_service.clone(), buckets, redis_client.clone()));
+
+    // Subscribe to cache update messages
+    let cache_update_topic = config.indexer_config.indexer_update_topic.clone();
+    let routing_engine_clone = Arc::clone(&routing_engine);
+    tokio::spawn(async move {
+        let redis_client = redis_client.clone();
+        redis_client
+            .subscribe(&cache_update_topic, move |_msg| {
+                info!("Received cache update notification");
+                let routing_engine_clone = Arc::clone(&routing_engine_clone);
+                tokio::spawn(async move {
+                    routing_engine_clone.refresh_cache().await;
+                });
+                ControlFlow::<()>::Continue
+            })
+            .unwrap();
+    });
 
     // API service controller
     let service_controller = ServiceController::new(account_service, routing_engine);
@@ -143,10 +165,10 @@ async fn run_indexer(config: Config) {
                 &token_price_provider,
             );
 
-            match indexer.run::<LinearRegressionEstimator>().await {
-                Ok(_) => info!("Indexer Job Completed"),
-                Err(e) => error!("Indexer Job Failed: {}", e),
-            };
+            // match indexer.run::<LinearRegressionEstimator>().await {
+            //     Ok(_) => info!("Indexer Job Completed"),
+            //     Err(e) => error!("Indexer Job Failed: {}", e),
+            // };
 
             let next_tick = l.next_tick_for_job(uuid).await;
             match next_tick {
diff --git a/crates/api/src/service_controller.rs b/crates/api/src/service_controller.rs
index 2d057fe..369f927 100644
--- a/crates/api/src/service_controller.rs
+++ b/crates/api/src/service_controller.rs
@@ -10,11 +10,11 @@ pub struct ServiceController {
 }
 
 impl ServiceController {
-    pub fn new(account_service: AccountAggregationService, routing_engine: RoutingEngine) -> Self {
-        Self {
-            account_service: Arc::new(account_service),
-            routing_engine: Arc::new(routing_engine),
-        }
+    pub fn new(
+        account_service: AccountAggregationService,
+        routing_engine: Arc<RoutingEngine>,
+    ) -> Self {
+        Self { account_service: Arc::new(account_service), routing_engine }
     }
 
     pub fn router(self) -> Router {
diff --git a/crates/routing-engine/src/engine.rs b/crates/routing-engine/src/engine.rs
index 170b13f..c36ce08 100644
--- a/crates/routing-engine/src/engine.rs
+++ b/crates/routing-engine/src/engine.rs
@@ -2,11 +2,15 @@ use account_aggregation::service::AccountAggregationService;
 use account_aggregation::types::Balance;
 use config::config::BucketConfig;
 use derive_more::Display;
+use futures::stream::{self, StreamExt};
+use log::{error, info};
 use serde::{Deserialize, Serialize};
 use std::borrow::Borrow;
 use std::collections::HashMap;
 use std::hash::{DefaultHasher, Hash, Hasher};
 use std::sync::Arc;
+use storage::RedisClient;
+use tokio::sync::RwLock;
 
 use crate::estimator::{Estimator, LinearRegressionEstimator};
 
@@ -29,17 +33,42 @@ pub struct Route {
 #[derive(Debug)]
 struct PathQuery(u32, u32, String, String);
 
+/// Routing Engine
+/// This struct is responsible for calculating the best cost path for a user
+#[derive(Debug, Clone)]
 pub struct RoutingEngine {
     buckets: Vec<BucketConfig>,
     aas_client: AccountAggregationService,
-    cache: Arc<HashMap<String, String>>, // (hash(bucket), hash(estimator_value)
+    cache: Arc<RwLock<HashMap<String, String>>>, // (hash(bucket), hash(estimator_value)
+    redis_client: RedisClient,
 }
 
 impl RoutingEngine {
-    pub fn new(aas_client: AccountAggregationService, buckets: Vec<BucketConfig>) -> Self {
-        let cache = Arc::new(HashMap::new());
+    pub fn new(
+        aas_client: AccountAggregationService,
+        buckets: Vec<BucketConfig>,
+        redis_client: RedisClient,
+    ) -> Self {
+        let cache = Arc::new(RwLock::new(HashMap::new()));
+
+        Self { aas_client, cache, buckets, redis_client }
+    }
 
-        Self { aas_client, cache, buckets }
+    pub async fn refresh_cache(&self) {
+        match self.redis_client.get_all_key_values().await {
+            Ok(kv_pairs) => {
+                info!("Refreshing cache from Redis.");
+                let mut cache = self.cache.write().await;
+                cache.clear();
+                for (key, value) in kv_pairs.iter() {
+                    cache.insert(key.clone(), value.clone());
+                }
+                info!("Cache refreshed with latest data from Redis.");
+            }
+            Err(e) => {
+                error!("Failed to refresh cache from Redis: {}", e);
+            }
+        }
     }
 
     /// Get the best cost path for a user
@@ -66,21 +95,23 @@
         // Sort direct assets by A^x / C^y, here x=2 and y=1
         let x = 2.0;
         let y = 1.0;
-        let mut sorted_assets: Vec<(&&Balance, f64)> = direct_assets
-            .iter()
-            .map(|balance| {
-                let fee_cost = self.get_cached_data(
-                    balance.amount_in_usd, // todo: edit
-                    PathQuery(
-                        balance.chain_id,
-                        to_chain,
-                        balance.token.to_string(),
-                        to_token.to_string(),
-                    ),
-                );
+        let mut sorted_assets: Vec<(&&Balance, f64)> = stream::iter(direct_assets.iter())
+            .then(|balance| async move {
+                let fee_cost = self
+                    .get_cached_data(
+                        balance.amount_in_usd,
+                        PathQuery(
+                            balance.chain_id,
+                            to_chain,
+                            balance.token.to_string(),
+                            to_token.to_string(),
+                        ),
+                    )
+                    .await;
                 (balance, fee_cost)
             })
-            .collect();
+            .collect()
+            .await;
 
         sorted_assets.sort_by(|a, b| {
             let cost_a = (a.0.amount.powf(x)) / (a.1.powf(y));
@@ -116,21 +147,23 @@
         if total_amount_needed > 0.0 {
             let swap_assets: Vec<&Balance> =
                 user_balances.iter().filter(|balance| balance.token != to_token).collect();
-            let mut sorted_assets: Vec<(&&Balance, f64)> = swap_assets
-                .iter()
-                .map(|balance| {
-                    let fee_cost = self.get_cached_data(
-                        balance.amount_in_usd,
-                        PathQuery(
-                            balance.chain_id,
-                            to_chain,
-                            balance.token.clone(),
-                            to_token.to_string(),
-                        ),
-                    );
+            let mut sorted_assets: Vec<(&&Balance, f64)> = stream::iter(swap_assets.iter())
+                .then(|balance| async move {
+                    let fee_cost = self
+                        .get_cached_data(
+                            balance.amount_in_usd,
+                            PathQuery(
+                                balance.chain_id,
+                                to_chain,
+                                balance.token.clone(),
+                                to_token.to_string(),
+                            ),
+                        )
+                        .await;
                     (balance, fee_cost)
                 })
-                .collect();
+                .collect()
+                .await;
 
             sorted_assets.sort_by(|a, b| {
                 let cost_a = (a.0.amount.powf(x)) / (a.1.powf(y));
@@ -167,8 +200,7 @@
         selected_assets
     }
 
-    fn get_cached_data(&self, target_amount: f64, path: PathQuery) -> f64 {
-        // filter the bucket of (chain, token) and sort with token_amount_from_usd
+    async fn get_cached_data(&self, target_amount: f64, path: PathQuery) -> f64 {
         let mut buckets_array: Vec<BucketConfig> = self
             .buckets
             .clone()
@@ -190,11 +222,12 @@
             })
             .unwrap();
 
-        // todo: should throw error if not found in bucket range or unwrap or with last bucket
         let mut s = DefaultHasher::new();
         bucket.hash(&mut s);
         let key = s.finish().to_string();
-        let value = self.cache.get(&key).unwrap();
+
+        let cache = self.cache.read().await;
+        let value = cache.get(&key).unwrap();
 
         let estimator: Result<LinearRegressionEstimator, serde_json::Error> =
            serde_json::from_str(value);
@@ -227,6 +260,7 @@ mod tests {
         hash::{DefaultHasher, Hash, Hasher},
     };
     use storage::mongodb_provider::MongoDBProvider;
+    use tokio::sync::RwLock;
 
     #[tokio::test]
     async fn test_get_cached_data() {
@@ -285,14 +319,21 @@
             "https://api.covalent.com".to_string(),
             "my-api".to_string(),
         );
-        let routing_engine = RoutingEngine { aas_client, buckets, cache: Arc::new(cache) };
+        let redis_client =
+            storage::RedisClient::build(&"redis://localhost:6379".to_string()).await.unwrap();
+        let routing_engine = RoutingEngine {
+            aas_client,
+            buckets,
+            cache: Arc::new(RwLock::new(cache)),
+            redis_client,
+        };
 
         // Define the target amount and path query
         let target_amount = 5.0;
         let path_query = PathQuery(1, 2, "USDC".to_string(), "ETH".to_string());
 
         // Call get_cached_data and assert the result
-        let result = routing_engine.get_cached_data(target_amount, path_query);
+        let result = routing_engine.get_cached_data(target_amount, path_query).await;
         assert!(result > 0.0);
         assert_eq!(result, dummy_estimator.estimate(target_amount));
     }
@@ -359,7 +400,15 @@
         let mut cache = HashMap::new();
         cache.insert(key, serialized_estimator.clone());
         cache.insert(key2, serialized_estimator);
-        let routing_engine = RoutingEngine { aas_client, buckets, cache: Arc::new(cache) };
+
+        let redis_client =
+            storage::RedisClient::build(&"redis://localhost:6379".to_string()).await.unwrap();
+        let routing_engine = RoutingEngine {
+            aas_client,
+            buckets,
+            cache: Arc::new(RwLock::new(cache)),
+            redis_client,
+        };
 
         // should have USDT in bsc-mainnet > $0.5
         let dummy_user_address = "0x00000ebe3fa7cb71aE471547C836E0cE0AE758c2";
diff --git a/crates/storage/src/redis.rs b/crates/storage/src/redis.rs
index 2caa66f..d404ff7 100644
--- a/crates/storage/src/redis.rs
+++ b/crates/storage/src/redis.rs
@@ -1,16 +1,12 @@
-use std::time::Duration;
-
+use crate::{KeyValueStore, MessageQueue};
 use log::info;
-use redis;
-use redis::{aio, AsyncCommands, Commands, ControlFlow, Msg, PubSubCommands};
 use redis::RedisError;
+use redis::{self, aio, AsyncCommands, ControlFlow, Msg, PubSubCommands};
+use std::collections::HashMap;
+use std::time::Duration;
 use thiserror::Error;
 
-use config;
-
-use crate::{KeyValueStore, MessageQueue};
-
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 pub struct RedisClient {
     client: redis::Client,
     connection: aio::MultiplexedConnection,
@@ -22,12 +18,25 @@ impl RedisClient {
         let connection = client.get_multiplexed_async_connection().await?;
         Ok(RedisClient { client, connection })
     }
+
+    pub async fn get_all_keys(&self) -> Result<Vec<String>, RedisClientError> {
+        info!("Fetching all keys");
+        let keys: Vec<String> = self.connection.clone().keys("*").await?;
+        Ok(keys)
+    }
+
+    pub async fn get_all_key_values(&self) -> Result<HashMap<String, String>, RedisClientError> {
+        info!("Fetching all key-value pairs");
+        let keys = self.get_all_keys().await?;
+        let values: Vec<String> = self.connection.clone().mget(&keys).await?;
+        let kv_pairs = keys.into_iter().zip(values.into_iter()).collect();
+        Ok(kv_pairs)
+    }
 }
 
 impl KeyValueStore for RedisClient {
     type Error = RedisClientError;
 
-    // Todo: This should return an option
     async fn get(&self, k: &String) -> Result<String, Self::Error> {
         info!("Getting key: {}", k);
         self.connection.clone().get(k).await.map_err(RedisClientError::RedisLibraryError)
@@ -90,7 +99,7 @@ pub enum RedisClientError {
 mod tests {
     use std::sync::mpsc::channel;
 
-    use tokio;
+    use tokio::{self};
 
     use super::*;
 
@@ -133,9 +142,35 @@
         assert_eq!(values, vec!["test_value1".to_string(), "test_value2".to_string()]);
     }
 
+    #[tokio::test]
+    async fn test_get_all_key_values() {
+        let client = setup().await;
+
+        // Set some keys
+        client
+            .set(&"key1".to_string(), &"value1".to_string(), Duration::from_secs(60))
+            .await
+            .unwrap();
+        client
+            .set(&"key2".to_string(), &"value2".to_string(), Duration::from_secs(60))
+            .await
+            .unwrap();
+        client
+            .set(&"key3".to_string(), &"value3".to_string(), Duration::from_secs(60))
+            .await
+            .unwrap();
+
+        // Fetch all key-values
+        let key_values = client.get_all_key_values().await.unwrap();
+
+        assert_eq!(key_values.get("key1").unwrap(), "value1");
+        assert_eq!(key_values.get("key2").unwrap(), "value2");
+        assert_eq!(key_values.get("key3").unwrap(), "value3");
+    }
+
     #[tokio::test]
     async fn test_pub_sub() {
-        let (tx, mut rx) = channel::<String>();
+        let (tx, rx) = channel::<String>();
         let client = setup().await;
 
         tokio::task::spawn_blocking(move || {
@@ -149,7 +184,7 @@
                 .unwrap();
         });
 
-        let mut client = setup().await;
+        let client = setup().await;
         client.publish("TOPIC", "HELLO").await.unwrap();
 
         loop {