feat: pub-sub cache update logic #7

Merged
merged 1 commit on Jun 25, 2024
42 changes: 32 additions & 10 deletions bin/reflux/src/main.rs
@@ -2,7 +2,7 @@ use std::sync::Arc;
use std::time::Duration;

use axum::http::Method;
use log::{error, info};
use log::info;
use tokio;
use tokio::signal;
use tokio_cron_scheduler::{Job, JobScheduler};
@@ -11,11 +11,10 @@ use tower_http::cors::{Any, CorsLayer};
use account_aggregation::service::AccountAggregationService;
use api::service_controller::ServiceController;
use config::Config;
use routing_engine::{BungeeClient, CoingeckoClient, Indexer};
use routing_engine::engine::RoutingEngine;
use routing_engine::estimator::LinearRegressionEstimator;
use routing_engine::{BungeeClient, CoingeckoClient, Indexer};
use storage::mongodb_provider::MongoDBProvider;
use storage::RedisClient;
use storage::{ControlFlow, MessageQueue, RedisClient};

#[tokio::main]
async fn main() {
@@ -29,6 +28,8 @@ async fn main() {
} else {
run_server(config).await;
}

info!("Exiting Reflux");
}

async fn run_server(config: Config) {
@@ -70,8 +71,29 @@
);

// Initialize routing engine
let buckets = config.buckets;
let routing_engine = RoutingEngine::new(account_service.clone(), buckets);
let buckets = config.buckets.clone();
let redis_client = RedisClient::build(&config.infra.redis_url)
.await
.expect("Failed to instantiate redis client");
let routing_engine =
Arc::new(RoutingEngine::new(account_service.clone(), buckets, redis_client.clone()));

// Subscribe to cache update messages
let cache_update_topic = config.indexer_config.indexer_update_topic.clone();
let routing_engine_clone = Arc::clone(&routing_engine);
tokio::spawn(async move {
let redis_client = redis_client.clone();
redis_client
.subscribe(&cache_update_topic, move |_msg| {
info!("Received cache update notification");
let routing_engine_clone = Arc::clone(&routing_engine_clone);
tokio::spawn(async move {
routing_engine_clone.refresh_cache().await;
});
ControlFlow::<()>::Continue
})
.unwrap();
});

// API service controller
let service_controller = ServiceController::new(account_service, routing_engine);
@@ -143,10 +165,10 @@ async fn run_indexer(config: Config) {
&token_price_provider,
);

match indexer.run::<LinearRegressionEstimator>().await {
Ok(_) => info!("Indexer Job Completed"),
Err(e) => error!("Indexer Job Failed: {}", e),
};
// match indexer.run::<LinearRegressionEstimator>().await {
// Ok(_) => info!("Indexer Job Completed"),
// Err(e) => error!("Indexer Job Failed: {}", e),
// };
Comment on lines +168 to +171
Collaborator: can be removed


let next_tick = l.next_tick_for_job(uuid).await;
match next_tick {
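The main.rs diff above only wires up the subscriber side of the pub-sub flow: the server listens on `config.indexer_config.indexer_update_topic` and calls `refresh_cache()` whenever a message arrives. For context, here is a minimal sketch of what the publishing side could look like using the `redis` crate's async API directly (assuming an async feature such as `tokio-comp` is enabled); the function name `notify_cache_update` and the payload string are illustrative and are not part of this PR.

```rust
// Hypothetical publisher (not shown in this diff): after new estimator values
// are written to Redis, publish a notification on the update topic so the
// server's subscriber calls refresh_cache().
use redis::AsyncCommands;

async fn notify_cache_update(redis_url: &str, topic: &str) -> redis::RedisResult<()> {
    let client = redis::Client::open(redis_url)?;
    let mut conn = client.get_multiplexed_async_connection().await?;
    // The payload is irrelevant to the subscriber; it only reacts to the message arriving.
    let _receivers: i64 = conn.publish(topic, "cache-updated").await?;
    Ok(())
}
```

Presumably the real publish happens in the indexer/storage crates when fresh estimator values are written; only the topic name is visible in this file's changes.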
10 changes: 5 additions & 5 deletions crates/api/src/service_controller.rs
@@ -10,11 +10,11 @@ pub struct ServiceController {
}

impl ServiceController {
pub fn new(account_service: AccountAggregationService, routing_engine: RoutingEngine) -> Self {
Self {
account_service: Arc::new(account_service),
routing_engine: Arc::new(routing_engine),
}
pub fn new(
account_service: AccountAggregationService,
routing_engine: Arc<RoutingEngine>,
) -> Self {
Self { account_service: Arc::new(account_service), routing_engine }
}

pub fn router(self) -> Router {
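`ServiceController::new` now accepts an `Arc<RoutingEngine>` instead of an owned `RoutingEngine`, because the engine has to be shared with the Redis subscriber task spawned in main.rs. A short sketch of the resulting call pattern, mirroring the main.rs changes above:

```rust
// The same Arc<RoutingEngine> serves HTTP requests and receives cache refreshes.
let routing_engine =
    Arc::new(RoutingEngine::new(account_service.clone(), buckets, redis_client.clone()));

// Cloned handle handed to the Redis pub-sub subscriber task.
let routing_engine_clone = Arc::clone(&routing_engine);

// The controller keeps its own handle; new() no longer wraps the engine itself.
let service_controller = ServiceController::new(account_service, routing_engine);
```

Wrapping the engine in `Arc` at the call site, rather than inside `new`, is what lets both the controller and the background task observe the same cache.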
123 changes: 86 additions & 37 deletions crates/routing-engine/src/engine.rs
@@ -2,11 +2,15 @@ use account_aggregation::service::AccountAggregationService;
use account_aggregation::types::Balance;
use config::config::BucketConfig;
use derive_more::Display;
use futures::stream::{self, StreamExt};
use log::{error, info};
use serde::{Deserialize, Serialize};
use std::borrow::Borrow;
use std::collections::HashMap;
use std::hash::{DefaultHasher, Hash, Hasher};
use std::sync::Arc;
use storage::RedisClient;
use tokio::sync::RwLock;

use crate::estimator::{Estimator, LinearRegressionEstimator};

@@ -29,17 +33,42 @@ pub struct Route {
#[derive(Debug)]
struct PathQuery(u32, u32, String, String);

/// Routing Engine
/// This struct is responsible for calculating the best cost path for a user
#[derive(Debug, Clone)]
pub struct RoutingEngine {
buckets: Vec<BucketConfig>,
aas_client: AccountAggregationService,
cache: Arc<HashMap<String, String>>, // (hash(bucket), hash(estimator_value)
cache: Arc<RwLock<HashMap<String, String>>>, // (hash(bucket), hash(estimator_value)
redis_client: RedisClient,
}

impl RoutingEngine {
pub fn new(aas_client: AccountAggregationService, buckets: Vec<BucketConfig>) -> Self {
let cache = Arc::new(HashMap::new());
pub fn new(
aas_client: AccountAggregationService,
buckets: Vec<BucketConfig>,
redis_client: RedisClient,
) -> Self {
let cache = Arc::new(RwLock::new(HashMap::new()));

Self { aas_client, cache, buckets, redis_client }
}

Self { aas_client, cache, buckets }
pub async fn refresh_cache(&self) {
match self.redis_client.get_all_key_values().await {
Ok(kv_pairs) => {
info!("Refreshing cache from Redis.");
let mut cache = self.cache.write().await;
cache.clear();
for (key, value) in kv_pairs.iter() {
cache.insert(key.clone(), value.clone());
}
info!("Cache refreshed with latest data from Redis.");
}
Err(e) => {
error!("Failed to refresh cache from Redis: {}", e);
}
}
}

/// Get the best cost path for a user
@@ -66,21 +95,23 @@ impl RoutingEngine {
// Sort direct assets by A^x / C^y, here x=2 and y=1
let x = 2.0;
let y = 1.0;
let mut sorted_assets: Vec<(&&Balance, f64)> = direct_assets
.iter()
.map(|balance| {
let fee_cost = self.get_cached_data(
balance.amount_in_usd, // todo: edit
PathQuery(
balance.chain_id,
to_chain,
balance.token.to_string(),
to_token.to_string(),
),
);
let mut sorted_assets: Vec<(&&Balance, f64)> = stream::iter(direct_assets.iter())
.then(|balance| async move {
let fee_cost = self
.get_cached_data(
balance.amount_in_usd,
PathQuery(
balance.chain_id,
to_chain,
balance.token.to_string(),
to_token.to_string(),
),
)
.await;
(balance, fee_cost)
})
.collect();
.collect()
.await;

sorted_assets.sort_by(|a, b| {
let cost_a = (a.0.amount.powf(x)) / (a.1.powf(y));
@@ -116,21 +147,23 @@
if total_amount_needed > 0.0 {
let swap_assets: Vec<&Balance> =
user_balances.iter().filter(|balance| balance.token != to_token).collect();
let mut sorted_assets: Vec<(&&Balance, f64)> = swap_assets
.iter()
.map(|balance| {
let fee_cost = self.get_cached_data(
balance.amount_in_usd,
PathQuery(
balance.chain_id,
to_chain,
balance.token.clone(),
to_token.to_string(),
),
);
let mut sorted_assets: Vec<(&&Balance, f64)> = stream::iter(swap_assets.iter())
.then(|balance| async move {
let fee_cost = self
.get_cached_data(
balance.amount_in_usd,
PathQuery(
balance.chain_id,
to_chain,
balance.token.clone(),
to_token.to_string(),
),
)
.await;
(balance, fee_cost)
})
.collect();
.collect()
.await;

sorted_assets.sort_by(|a, b| {
let cost_a = (a.0.amount.powf(x)) / (a.1.powf(y));
@@ -167,8 +200,7 @@
selected_assets
}

fn get_cached_data(&self, target_amount: f64, path: PathQuery) -> f64 {
// filter the bucket of (chain, token) and sort with token_amount_from_usd
async fn get_cached_data(&self, target_amount: f64, path: PathQuery) -> f64 {
let mut buckets_array: Vec<BucketConfig> = self
.buckets
.clone()
@@ -190,11 +222,12 @@
})
.unwrap();

// todo: should throw error if not found in bucket range or unwrap or with last bucket
let mut s = DefaultHasher::new();
bucket.hash(&mut s);
let key = s.finish().to_string();
let value = self.cache.get(&key).unwrap();

let cache = self.cache.read().await;
let value = cache.get(&key).unwrap();
let estimator: Result<LinearRegressionEstimator, serde_json::Error> =
serde_json::from_str(value);

@@ -227,6 +260,7 @@ mod tests {
hash::{DefaultHasher, Hash, Hasher},
};
use storage::mongodb_provider::MongoDBProvider;
use tokio::sync::RwLock;

#[tokio::test]
async fn test_get_cached_data() {
@@ -285,14 +319,21 @@
"https://api.covalent.com".to_string(),
"my-api".to_string(),
);
let routing_engine = RoutingEngine { aas_client, buckets, cache: Arc::new(cache) };
let redis_client =
storage::RedisClient::build(&"redis://localhost:6379".to_string()).await.unwrap();
let routing_engine = RoutingEngine {
aas_client,
buckets,
cache: Arc::new(RwLock::new(cache)),
redis_client,
};

// Define the target amount and path query
let target_amount = 5.0;
let path_query = PathQuery(1, 2, "USDC".to_string(), "ETH".to_string());

// Call get_cached_data and assert the result
let result = routing_engine.get_cached_data(target_amount, path_query);
let result = routing_engine.get_cached_data(target_amount, path_query).await;
assert!(result > 0.0);
assert_eq!(result, dummy_estimator.estimate(target_amount));
}
@@ -359,7 +400,15 @@
let mut cache = HashMap::new();
cache.insert(key, serialized_estimator.clone());
cache.insert(key2, serialized_estimator);
let routing_engine = RoutingEngine { aas_client, buckets, cache: Arc::new(cache) };

let redis_client =
storage::RedisClient::build(&"redis://localhost:6379".to_string()).await.unwrap();
let routing_engine = RoutingEngine {
aas_client,
buckets,
cache: Arc::new(RwLock::new(cache)),
redis_client,
};

// should have USDT in bsc-mainnet > $0.5
let dummy_user_address = "0x00000ebe3fa7cb71aE471547C836E0cE0AE758c2";
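The core change in engine.rs is that the estimator cache moved from `Arc<HashMap<String, String>>` to `Arc<RwLock<HashMap<String, String>>>`, so the pub-sub subscriber can repopulate it (`refresh_cache` takes the write lock) while request handlers keep reading it (`get_cached_data` takes a read lock). A minimal, self-contained sketch of that pattern with tokio's `RwLock`; the key and value strings are placeholders, not the crate's real cache contents.

```rust
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;

#[tokio::main]
async fn main() {
    // Shared cache, as in RoutingEngine: Arc for sharing, RwLock for interior mutability.
    let cache: Arc<RwLock<HashMap<String, String>>> = Arc::new(RwLock::new(HashMap::new()));

    // Writer side, analogous to refresh_cache(): take the write lock, clear, repopulate.
    let writer = Arc::clone(&cache);
    tokio::spawn(async move {
        let mut guard = writer.write().await;
        guard.clear();
        guard.insert("bucket-hash".to_string(), "estimator-json".to_string());
    })
    .await
    .unwrap();

    // Reader side, analogous to get_cached_data(): take a read lock and look up a key.
    // Many readers may hold the lock at once; they only wait while a refresh is writing.
    let guard = cache.read().await;
    println!("{:?}", guard.get("bucket-hash"));
}
```

An `RwLock` rather than a `Mutex` fits this access pattern: reads vastly outnumber the occasional full refresh, and multiple readers can hold the lock concurrently.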