From 5cef969f453248dd8633b4e2aa20453db3877388 Mon Sep 17 00:00:00 2001 From: gustavodiasag Date: Thu, 20 Jun 2024 18:44:36 -0300 Subject: [PATCH] refac(balaner): define balancer endpoint and round-robin routing --- Cargo.lock | 1 + ctl/Cargo.toml | 1 + ctl/src/balancer/mod.rs | 128 +++++++++++++++++++++++++++++++--------- ctl/src/main.rs | 14 ++++- 4 files changed, 113 insertions(+), 31 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 13d1484..3b59a98 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -387,6 +387,7 @@ dependencies = [ "chrono", "clap", "eyre", + "hyper-util", "proto", "tokio", "tracing", diff --git a/ctl/Cargo.toml b/ctl/Cargo.toml index b969738..e336a9b 100644 --- a/ctl/Cargo.toml +++ b/ctl/Cargo.toml @@ -14,6 +14,7 @@ axum.workspace = true chrono.workspace = true clap.workspace = true eyre.workspace = true +hyper-util.workspace = true tokio.workspace = true tracing.workspace = true uuid.workspace = true diff --git a/ctl/src/balancer/mod.rs b/ctl/src/balancer/mod.rs index 7ec5fce..e2452d0 100644 --- a/ctl/src/balancer/mod.rs +++ b/ctl/src/balancer/mod.rs @@ -1,44 +1,116 @@ -#![allow(dead_code)] - use std::{ collections::HashMap, - net::SocketAddr + net::IpAddr, + str::FromStr as _, + sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, Mutex, + }, + time::Duration, }; -use axum::extract::Request; -use proto::common::instance::InstanceId; -use proto::common::service::ServiceId; +use axum::{ + body::Body, + extract::{Request, State}, + http::{ + uri::{Authority, Scheme}, + HeaderValue, StatusCode, Uri, + }, + response::IntoResponse, +}; +use hyper_util::{ + client::legacy::{connect::HttpConnector, Client}, + rt::TokioExecutor, +}; +use proto::{ + common::{instance::InstanceId, service::ServiceId}, + well_known::PROXY_INSTANCE_HEADER_NAME, +}; +use utils::http::{self, OptionExt as _, ResultExt as _}; -struct Balancer { - strategy: S, - addrs: HashMap> +pub struct InstanceBag { + pub instances: Vec<(InstanceId, IpAddr)>, + pub count: AtomicUsize, } -trait Strategy { - async fn get_server(&self, _req: &Request) -> (InstanceId, SocketAddr); +#[derive(Clone)] +pub struct Balancer { + pub addrs: Arc>>, } -impl Balancer -where - S: Strategy -{ - pub async fn run() { - todo!(); +impl Balancer { + pub fn new() -> Self { + Balancer { + addrs: Arc::new(Mutex::new(HashMap::default())), + } } - async fn next_server(&self, _req: &Request) -> (InstanceId, SocketAddr) { - todo!(); + pub fn next(&self, service: &ServiceId) -> (InstanceId, IpAddr) { + let map = self.addrs.lock().unwrap(); + let bag = map.get(service).unwrap(); + let count = bag.count.fetch_add(1, Ordering::Relaxed); + bag.instances[count % bag.instances.len()] } +} - pub async fn drop_instance(&self, _id: InstanceId) { - todo!(); - } - - pub async fn add_instance(&self, _id: InstanceId, _at: SocketAddr) { - todo!(); - } +#[derive(Clone)] +pub struct BalancerState { + pub balancer: Balancer, + pub client: Client, +} - pub async fn swap_instance(_old_id: InstanceId, _new_id: InstanceId, _new_at: SocketAddr) { - todo!(); +impl BalancerState { + #[must_use] + pub fn new() -> Self { + BalancerState { + balancer: Balancer::new(), + client: { + let mut connector = HttpConnector::new(); + connector.set_keepalive(Some(Duration::from_secs(60))); + connector.set_nodelay(true); + Client::builder(TokioExecutor::new()).build::<_, Body>(connector) + }, + } } } + +#[axum::debug_handler] +pub async fn proxy( + State(state): State, + mut req: Request, +) -> http::Result { + let service = extract_service_id(&mut req)?; + + let (instance, server) = state.balancer.next(&service); + + *req.uri_mut() = { + let uri = req.uri(); + let mut parts = uri.clone().into_parts(); + parts.authority = Authority::from_str(&format!("{server}")).ok(); + parts.scheme = Some(Scheme::HTTP); + Uri::from_parts(parts).unwrap() + }; + + req.headers_mut().insert( + PROXY_INSTANCE_HEADER_NAME, + HeaderValue::from_str(&format!("{instance}")).unwrap(), + ); + + state + .client + .request(req) + .await + .http_error(StatusCode::BAD_GATEWAY, "bad gateway") +} + +fn extract_service_id(req: &mut Request) -> http::Result { + let inner = req + .headers() + .get("host") + .unwrap() + .to_str() + .ok() + .and_then(|s| s.parse().ok()) + .or_http_error(StatusCode::BAD_REQUEST, "invalid service name")?; + Ok(ServiceId(inner)) +} diff --git a/ctl/src/main.rs b/ctl/src/main.rs index fb2a223..049d985 100644 --- a/ctl/src/main.rs +++ b/ctl/src/main.rs @@ -3,18 +3,19 @@ use std::{ sync::Arc, }; +use axum::handler::Handler; use clap::Parser; use proto::well_known::{CTL_BALANCER_PORT, CTL_HTTP_PORT}; use tokio::task::JoinSet; use tracing::info; use utils::server::mk_listener; -use crate::{args::CtlArgs, discovery::Discovery, http::HttpState}; +use crate::{args::CtlArgs, balancer::BalancerState, discovery::Discovery, http::HttpState}; mod args; +mod balancer; mod discovery; mod http; -mod balancer; const ANY_IP: IpAddr = IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)); @@ -25,7 +26,7 @@ async fn main() -> eyre::Result<()> { let args = Arc::new(CtlArgs::parse()); info!(?args, "started ctl"); - let _balancer_listener = mk_listener(ANY_IP, CTL_BALANCER_PORT).await?; + let balancer_listener = mk_listener(ANY_IP, CTL_BALANCER_PORT).await?; let http_listener = mk_listener(ANY_IP, CTL_HTTP_PORT).await?; let mut bag = JoinSet::new(); @@ -35,6 +36,13 @@ async fn main() -> eyre::Result<()> { discovery.run().await; }); + let balancer_state = BalancerState::new(); + bag.spawn(async move { + let app = balancer::proxy.with_state(balancer_state); + info!("balancer http listening at {ANY_IP}:{CTL_BALANCER_PORT}"); + axum::serve(balancer_listener, app).await.unwrap(); + }); + bag.spawn(async move { let state = HttpState { discovery: discovery_handle.clone(),