Skip to content

Commit

Permalink
feat: pub-sub cache update logic
Browse files Browse the repository at this point in the history
  • Loading branch information
AmanRaj1608 committed Jun 25, 2024
1 parent f8d1f6f commit 68b8637
Show file tree
Hide file tree
Showing 4 changed files with 171 additions and 65 deletions.
42 changes: 32 additions & 10 deletions bin/reflux/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::sync::Arc;
use std::time::Duration;

use axum::http::Method;
use log::{error, info};
use log::info;
use tokio;
use tokio::signal;
use tokio_cron_scheduler::{Job, JobScheduler};
Expand All @@ -11,11 +11,10 @@ use tower_http::cors::{Any, CorsLayer};
use account_aggregation::service::AccountAggregationService;
use api::service_controller::ServiceController;
use config::Config;
use routing_engine::{BungeeClient, CoingeckoClient, Indexer};
use routing_engine::engine::RoutingEngine;
use routing_engine::estimator::LinearRegressionEstimator;
use routing_engine::{BungeeClient, CoingeckoClient, Indexer};
use storage::mongodb_provider::MongoDBProvider;
use storage::RedisClient;
use storage::{ControlFlow, MessageQueue, RedisClient};

#[tokio::main]
async fn main() {
Expand All @@ -29,6 +28,8 @@ async fn main() {
} else {
run_server(config).await;
}

info!("Exiting Reflux");
}

async fn run_server(config: Config) {
Expand Down Expand Up @@ -70,8 +71,29 @@ async fn run_server(config: Config) {
);

// Initialize routing engine
let buckets = config.buckets;
let routing_engine = RoutingEngine::new(account_service.clone(), buckets);
let buckets = config.buckets.clone();
let redis_client = RedisClient::build(&config.infra.redis_url)
.await
.expect("Failed to instantiate redis client");
let routing_engine =
Arc::new(RoutingEngine::new(account_service.clone(), buckets, redis_client.clone()));

// Subscribe to cache update messages
let cache_update_topic = config.indexer_config.indexer_update_topic.clone();
let routing_engine_clone = Arc::clone(&routing_engine);
tokio::spawn(async move {
let redis_client = redis_client.clone();
redis_client
.subscribe(&cache_update_topic, move |_msg| {
info!("Received cache update notification");
let routing_engine_clone = Arc::clone(&routing_engine_clone);
tokio::spawn(async move {
routing_engine_clone.refresh_cache().await;
});
ControlFlow::<()>::Continue
})
.unwrap();
});

// API service controller
let service_controller = ServiceController::new(account_service, routing_engine);
Expand Down Expand Up @@ -143,10 +165,10 @@ async fn run_indexer(config: Config) {
&token_price_provider,
);

match indexer.run::<LinearRegressionEstimator>().await {
Ok(_) => info!("Indexer Job Completed"),
Err(e) => error!("Indexer Job Failed: {}", e),
};
// match indexer.run::<LinearRegressionEstimator>().await {
// Ok(_) => info!("Indexer Job Completed"),
// Err(e) => error!("Indexer Job Failed: {}", e),
// };

let next_tick = l.next_tick_for_job(uuid).await;
match next_tick {
Expand Down
10 changes: 5 additions & 5 deletions crates/api/src/service_controller.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@ pub struct ServiceController {
}

impl ServiceController {
pub fn new(account_service: AccountAggregationService, routing_engine: RoutingEngine) -> Self {
Self {
account_service: Arc::new(account_service),
routing_engine: Arc::new(routing_engine),
}
pub fn new(
account_service: AccountAggregationService,
routing_engine: Arc<RoutingEngine>,
) -> Self {
Self { account_service: Arc::new(account_service), routing_engine }
}

pub fn router(self) -> Router {
Expand Down
123 changes: 86 additions & 37 deletions crates/routing-engine/src/engine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,15 @@ use account_aggregation::service::AccountAggregationService;
use account_aggregation::types::Balance;
use config::config::BucketConfig;
use derive_more::Display;
use futures::stream::{self, StreamExt};
use log::{error, info};
use serde::{Deserialize, Serialize};
use std::borrow::Borrow;
use std::collections::HashMap;
use std::hash::{DefaultHasher, Hash, Hasher};
use std::sync::Arc;
use storage::RedisClient;
use tokio::sync::RwLock;

use crate::estimator::{Estimator, LinearRegressionEstimator};

Expand All @@ -29,17 +33,42 @@ pub struct Route {
#[derive(Debug)]
struct PathQuery(u32, u32, String, String);

/// Routing Engine
/// This struct is responsible for calculating the best cost path for a user
#[derive(Debug, Clone)]
pub struct RoutingEngine {
buckets: Vec<BucketConfig>,
aas_client: AccountAggregationService,
cache: Arc<HashMap<String, String>>, // (hash(bucket), hash(estimator_value)
cache: Arc<RwLock<HashMap<String, String>>>, // (hash(bucket), hash(estimator_value)
redis_client: RedisClient,
}

impl RoutingEngine {
pub fn new(aas_client: AccountAggregationService, buckets: Vec<BucketConfig>) -> Self {
let cache = Arc::new(HashMap::new());
pub fn new(
aas_client: AccountAggregationService,
buckets: Vec<BucketConfig>,
redis_client: RedisClient,
) -> Self {
let cache = Arc::new(RwLock::new(HashMap::new()));

Self { aas_client, cache, buckets, redis_client }
}

Self { aas_client, cache, buckets }
pub async fn refresh_cache(&self) {
match self.redis_client.get_all_key_values().await {
Ok(kv_pairs) => {
info!("Refreshing cache from Redis.");
let mut cache = self.cache.write().await;
cache.clear();
for (key, value) in kv_pairs.iter() {
cache.insert(key.clone(), value.clone());
}
info!("Cache refreshed with latest data from Redis.");
}
Err(e) => {
error!("Failed to refresh cache from Redis: {}", e);
}
}
}

/// Get the best cost path for a user
Expand All @@ -66,21 +95,23 @@ impl RoutingEngine {
// Sort direct assets by A^x / C^y, here x=2 and y=1
let x = 2.0;
let y = 1.0;
let mut sorted_assets: Vec<(&&Balance, f64)> = direct_assets
.iter()
.map(|balance| {
let fee_cost = self.get_cached_data(
balance.amount_in_usd, // todo: edit
PathQuery(
balance.chain_id,
to_chain,
balance.token.to_string(),
to_token.to_string(),
),
);
let mut sorted_assets: Vec<(&&Balance, f64)> = stream::iter(direct_assets.iter())
.then(|balance| async move {
let fee_cost = self
.get_cached_data(
balance.amount_in_usd,
PathQuery(
balance.chain_id,
to_chain,
balance.token.to_string(),
to_token.to_string(),
),
)
.await;
(balance, fee_cost)
})
.collect();
.collect()
.await;

sorted_assets.sort_by(|a, b| {
let cost_a = (a.0.amount.powf(x)) / (a.1.powf(y));
Expand Down Expand Up @@ -116,21 +147,23 @@ impl RoutingEngine {
if total_amount_needed > 0.0 {
let swap_assets: Vec<&Balance> =
user_balances.iter().filter(|balance| balance.token != to_token).collect();
let mut sorted_assets: Vec<(&&Balance, f64)> = swap_assets
.iter()
.map(|balance| {
let fee_cost = self.get_cached_data(
balance.amount_in_usd,
PathQuery(
balance.chain_id,
to_chain,
balance.token.clone(),
to_token.to_string(),
),
);
let mut sorted_assets: Vec<(&&Balance, f64)> = stream::iter(swap_assets.iter())
.then(|balance| async move {
let fee_cost = self
.get_cached_data(
balance.amount_in_usd,
PathQuery(
balance.chain_id,
to_chain,
balance.token.clone(),
to_token.to_string(),
),
)
.await;
(balance, fee_cost)
})
.collect();
.collect()
.await;

sorted_assets.sort_by(|a, b| {
let cost_a = (a.0.amount.powf(x)) / (a.1.powf(y));
Expand Down Expand Up @@ -167,8 +200,7 @@ impl RoutingEngine {
selected_assets
}

fn get_cached_data(&self, target_amount: f64, path: PathQuery) -> f64 {
// filter the bucket of (chain, token) and sort with token_amount_from_usd
async fn get_cached_data(&self, target_amount: f64, path: PathQuery) -> f64 {
let mut buckets_array: Vec<BucketConfig> = self
.buckets
.clone()
Expand All @@ -190,11 +222,12 @@ impl RoutingEngine {
})
.unwrap();

// todo: should throw error if not found in bucket range or unwrap or with last bucket
let mut s = DefaultHasher::new();
bucket.hash(&mut s);
let key = s.finish().to_string();
let value = self.cache.get(&key).unwrap();

let cache = self.cache.read().await;
let value = cache.get(&key).unwrap();
let estimator: Result<LinearRegressionEstimator, serde_json::Error> =
serde_json::from_str(value);

Expand Down Expand Up @@ -227,6 +260,7 @@ mod tests {
hash::{DefaultHasher, Hash, Hasher},
};
use storage::mongodb_provider::MongoDBProvider;
use tokio::sync::RwLock;

#[tokio::test]
async fn test_get_cached_data() {
Expand Down Expand Up @@ -285,14 +319,21 @@ mod tests {
"https://api.covalent.com".to_string(),
"my-api".to_string(),
);
let routing_engine = RoutingEngine { aas_client, buckets, cache: Arc::new(cache) };
let redis_client =
storage::RedisClient::build(&"redis://localhost:6379".to_string()).await.unwrap();
let routing_engine = RoutingEngine {
aas_client,
buckets,
cache: Arc::new(RwLock::new(cache)),
redis_client,
};

// Define the target amount and path query
let target_amount = 5.0;
let path_query = PathQuery(1, 2, "USDC".to_string(), "ETH".to_string());

// Call get_cached_data and assert the result
let result = routing_engine.get_cached_data(target_amount, path_query);
let result = routing_engine.get_cached_data(target_amount, path_query).await;
assert!(result > 0.0);
assert_eq!(result, dummy_estimator.estimate(target_amount));
}
Expand Down Expand Up @@ -359,7 +400,15 @@ mod tests {
let mut cache = HashMap::new();
cache.insert(key, serialized_estimator.clone());
cache.insert(key2, serialized_estimator);
let routing_engine = RoutingEngine { aas_client, buckets, cache: Arc::new(cache) };

let redis_client =
storage::RedisClient::build(&"redis://localhost:6379".to_string()).await.unwrap();
let routing_engine = RoutingEngine {
aas_client,
buckets,
cache: Arc::new(RwLock::new(cache)),
redis_client,
};

// should have USDT in bsc-mainnet > $0.5
let dummy_user_address = "0x00000ebe3fa7cb71aE471547C836E0cE0AE758c2";
Expand Down
Loading

0 comments on commit 68b8637

Please sign in to comment.