diff --git a/ctl/src/balancer/mod.rs b/ctl/src/balancer/mod.rs index 66d06ff..8d9d68c 100644 --- a/ctl/src/balancer/mod.rs +++ b/ctl/src/balancer/mod.rs @@ -28,21 +28,34 @@ use proto::{ }; use utils::http::{self, OptionExt as _, ResultExt as _}; +#[derive(Default)] pub struct InstanceBag { pub instances: Vec<(InstanceId, IpAddr)>, pub count: AtomicUsize, } #[derive(Clone)] -pub struct Balancer { +pub struct BalancerState { pub addrs: Arc>>, + pub client: Client, } -impl Balancer { - pub fn new() -> Self { - Balancer { - addrs: Arc::new(Mutex::new(HashMap::default())), - } +impl BalancerState { + #[must_use] + pub fn new() -> (Self, BalancerHandle) { + let map = Arc::new(Mutex::new(HashMap::default())); + ( + BalancerState { + addrs: map.clone(), + client: { + let mut connector = HttpConnector::new(); + connector.set_keepalive(Some(Duration::from_secs(60))); + connector.set_nodelay(true); + Client::builder(TokioExecutor::new()).build::<_, Body>(connector) + }, + }, + BalancerHandle { addrs: map }, + ) } pub fn next(&self, service: &ServiceId) -> (InstanceId, IpAddr) { @@ -53,35 +66,37 @@ impl Balancer { } } -#[derive(Clone)] -pub struct BalancerState { - pub balancer: Balancer, - pub client: Client, +pub struct BalancerHandle { + pub addrs: Arc>>, } -impl BalancerState { - #[must_use] - pub fn new() -> Self { - BalancerState { - balancer: Balancer::new(), - client: { - let mut connector = HttpConnector::new(); - connector.set_keepalive(Some(Duration::from_secs(60))); - connector.set_nodelay(true); - Client::builder(TokioExecutor::new()).build::<_, Body>(connector) - }, - } +impl BalancerHandle { + #[allow(dead_code)] + pub fn add_instance(&mut self, id: ServiceId, at: (InstanceId, IpAddr)) { + let mut map = self.addrs.lock().unwrap(); + let bag = map.entry(id).or_default(); + bag.instances.push(at); + } + + #[allow(dead_code)] + pub fn drop_instance(&mut self, id: &ServiceId, at: (InstanceId, IpAddr)) { + let mut map = self.addrs.lock().unwrap(); + let Some(bag) = map.get_mut(id) else { + return; + }; + bag.instances + .retain(|(inst, addr)| inst == &at.0 && addr == &at.1); } } pub async fn proxy( ConnectInfo(addr): ConnectInfo, - State(state): State, + State(balancer): State, mut req: Request, ) -> http::Result { let service = extract_service_id(&mut req)?; - let (instance, server_addr) = state.balancer.next(&service); + let (instance, server_addr) = balancer.next(&service); *req.uri_mut() = { let uri = req.uri(); @@ -100,7 +115,7 @@ pub async fn proxy( HeaderValue::from_str(&addr.ip().to_string()).unwrap(), ); - state + balancer .client .request(req) .await diff --git a/ctl/src/main.rs b/ctl/src/main.rs index 915bf95..6bd0962 100644 --- a/ctl/src/main.rs +++ b/ctl/src/main.rs @@ -45,10 +45,10 @@ async fn main() -> eyre::Result<()> { worker_mgr.run().await; }); - let balancer_state = BalancerState::new(); + let (balancer, _balancer_handle) = BalancerState::new(); bag.spawn(async move { let app = balancer::proxy - .with_state(balancer_state) + .with_state(balancer) .into_make_service_with_connect_info::(); info!("balancer http listening at {ANY_IP}:{CTL_BALANCER_PORT}"); axum::serve(balancer_listener, app).await.unwrap();