diff --git a/src/client.rs b/src/client.rs index 1149bec0..45d3d859 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,5 +1,5 @@ use crate::config::{ClientConfig, ClientServiceConfig, Config, ServiceType, TransportType}; -use crate::config_watcher::ServiceChange; +use crate::config_watcher::{ClientServiceChange, ConfigChange}; use crate::helper::udp_connect; use crate::protocol::Hello::{self, *}; use crate::protocol::{ @@ -31,7 +31,7 @@ use crate::constants::{run_control_chan_backoff, UDP_BUFFER_SIZE, UDP_SENDQ_SIZE pub async fn run_client( config: Config, shutdown_rx: broadcast::Receiver, - service_rx: mpsc::Receiver, + update_rx: mpsc::Receiver, ) -> Result<()> { let config = config.client.ok_or_else(|| { anyhow!( @@ -42,13 +42,13 @@ pub async fn run_client( match config.transport.transport_type { TransportType::Tcp => { let mut client = Client::::from(config).await?; - client.run(shutdown_rx, service_rx).await + client.run(shutdown_rx, update_rx).await } TransportType::Tls => { #[cfg(feature = "tls")] { let mut client = Client::::from(config).await?; - client.run(shutdown_rx, service_rx).await + client.run(shutdown_rx, update_rx).await } #[cfg(not(feature = "tls"))] crate::helper::feature_not_compile("tls") @@ -57,7 +57,7 @@ pub async fn run_client( #[cfg(feature = "noise")] { let mut client = Client::::from(config).await?; - client.run(shutdown_rx, service_rx).await + client.run(shutdown_rx, update_rx).await } #[cfg(not(feature = "noise"))] crate::helper::feature_not_compile("noise") @@ -91,7 +91,7 @@ impl Client { async fn run( &mut self, mut shutdown_rx: broadcast::Receiver, - mut service_rx: mpsc::Receiver, + mut update_rx: mpsc::Receiver, ) -> Result<()> { for (name, config) in &self.config.services { // Create a control channel for each service defined @@ -116,24 +116,9 @@ impl Client { } break; }, - e = service_rx.recv() => { + e = update_rx.recv() => { if let Some(e) = e { - match e { - ServiceChange::ClientAdd(s)=> { - let name = s.name.clone(); - let handle = ControlChannelHandle::new( - s, - self.config.remote_addr.clone(), - self.transport.clone(), - self.config.heartbeat_timeout - ); - let _ = self.service_handles.insert(name, handle); - }, - ServiceChange::ClientDelete(s)=> { - let _ = self.service_handles.remove(&s); - }, - _ => () - } + self.handle_hot_reload(e).await; } } } @@ -146,6 +131,27 @@ impl Client { Ok(()) } + + async fn handle_hot_reload(&mut self, e: ConfigChange) { + match e { + ConfigChange::ClientChange(client_change) => match client_change { + ClientServiceChange::Add(cfg) => { + let name = cfg.name.clone(); + let handle = ControlChannelHandle::new( + cfg, + self.config.remote_addr.clone(), + self.transport.clone(), + self.config.heartbeat_timeout, + ); + let _ = self.service_handles.insert(name, handle); + } + ClientServiceChange::Delete(s) => { + let _ = self.service_handles.remove(&s); + } + }, + ignored => warn!("Ignored {:?} since running as a client", ignored), + } + } } struct RunDataChannelArgs { diff --git a/src/config_watcher.rs b/src/config_watcher.rs index 25423cff..993fdcce 100644 --- a/src/config_watcher.rs +++ b/src/config_watcher.rs @@ -14,36 +14,30 @@ use tracing::{error, info, instrument}; #[cfg(feature = "notify")] use notify::{EventKind, RecursiveMode, Watcher}; -#[derive(Debug, PartialEq, Eq)] +#[derive(Debug, PartialEq, Eq, Clone)] pub enum ConfigChange { General(Box), // Trigger a full restart - ServiceChange(ServiceChange), + ServerChange(ServerServiceChange), + ClientChange(ClientServiceChange), } -#[derive(Debug, PartialEq, Eq)] -pub enum ServiceChange { - ClientAdd(ClientServiceConfig), - ClientDelete(String), - ServerAdd(ServerServiceConfig), - ServerDelete(String), +#[derive(Debug, PartialEq, Eq, Clone)] +pub enum ClientServiceChange { + Add(ClientServiceConfig), + Delete(String), } -impl From for ServiceChange { - fn from(c: ClientServiceConfig) -> Self { - ServiceChange::ClientAdd(c) - } -} - -impl From for ServiceChange { - fn from(c: ServerServiceConfig) -> Self { - ServiceChange::ServerAdd(c) - } +#[derive(Debug, PartialEq, Eq, Clone)] +pub enum ServerServiceChange { + Add(ServerServiceConfig), + Delete(String), } trait InstanceConfig: Clone { - type ServiceConfig: Into + PartialEq + Clone; + type ServiceConfig: PartialEq + Eq + Clone; fn equal_without_service(&self, rhs: &Self) -> bool; - fn to_service_change_delete(s: String) -> ServiceChange; + fn service_delete_change(s: String) -> ConfigChange; + fn service_add_change(cfg: Self::ServiceConfig) -> ConfigChange; fn get_services(&self) -> &HashMap; } @@ -62,8 +56,11 @@ impl InstanceConfig for ServerConfig { left == right } - fn to_service_change_delete(s: String) -> ServiceChange { - ServiceChange::ServerDelete(s) + fn service_delete_change(s: String) -> ConfigChange { + ConfigChange::ServerChange(ServerServiceChange::Delete(s)) + } + fn service_add_change(cfg: Self::ServiceConfig) -> ConfigChange { + ConfigChange::ServerChange(ServerServiceChange::Add(cfg)) } fn get_services(&self) -> &HashMap { &self.services @@ -85,8 +82,11 @@ impl InstanceConfig for ClientConfig { left == right } - fn to_service_change_delete(s: String) -> ServiceChange { - ServiceChange::ClientDelete(s) + fn service_delete_change(s: String) -> ConfigChange { + ConfigChange::ClientChange(ClientServiceChange::Delete(s)) + } + fn service_add_change(cfg: Self::ServiceConfig) -> ConfigChange { + ConfigChange::ClientChange(ClientServiceChange::Add(cfg)) } fn get_services(&self) -> &HashMap { &self.services @@ -180,8 +180,9 @@ async fn config_watcher( } }; - for event in calculate_events(&old, &new) { - event_tx.send(event)?; + let events = calculate_events(&old, &new).into_iter().flatten(); + for event in events { + event_tx.send(event)?; } old = new; @@ -198,42 +199,40 @@ async fn config_watcher( Ok(()) } -fn calculate_events(old: &Config, new: &Config) -> Vec { +fn calculate_events(old: &Config, new: &Config) -> Option> { if old == new { - return vec![]; + return None; + } + + if (old.server.is_some() != new.server.is_some()) + || (old.client.is_some() != new.client.is_some()) + { + return Some(vec![ConfigChange::General(Box::new(new.clone()))]); } let mut ret = vec![]; if old.server != new.server { - if old.server.is_some() != new.server.is_some() { - return vec![ConfigChange::General(Box::new(new.clone()))]; - } else { - match calculate_instance_config_events( - old.server.as_ref().unwrap(), - new.server.as_ref().unwrap(), - ) { - Some(mut v) => ret.append(&mut v), - None => return vec![ConfigChange::General(Box::new(new.clone()))], - } + match calculate_instance_config_events( + old.server.as_ref().unwrap(), + new.server.as_ref().unwrap(), + ) { + Some(mut v) => ret.append(&mut v), + None => return Some(vec![ConfigChange::General(Box::new(new.clone()))]), } } if old.client != new.client { - if old.client.is_some() != new.client.is_some() { - return vec![ConfigChange::General(Box::new(new.clone()))]; - } else { - match calculate_instance_config_events( - old.client.as_ref().unwrap(), - new.client.as_ref().unwrap(), - ) { - Some(mut v) => ret.append(&mut v), - None => return vec![ConfigChange::General(Box::new(new.clone()))], - } + match calculate_instance_config_events( + old.client.as_ref().unwrap(), + new.client.as_ref().unwrap(), + ) { + Some(mut v) => ret.append(&mut v), + None => return Some(vec![ConfigChange::General(Box::new(new.clone()))]), } } - ret + Some(ret) } // None indicates a General change needed @@ -248,31 +247,17 @@ fn calculate_instance_config_events( let old = old.get_services(); let new = new.get_services(); - let mut v = vec![]; - v.append(&mut calculate_service_delete_events::(old, new)); - v.append(&mut calculate_service_add_events(old, new)); - - Some(v.into_iter().map(ConfigChange::ServiceChange).collect()) -} - -fn calculate_service_delete_events( - old: &HashMap, - new: &HashMap, -) -> Vec { - old.keys() + let deletions = old + .keys() .filter(|&name| new.get(name).is_none()) - .map(|x| T::to_service_change_delete(x.to_owned())) - .collect() -} + .map(|x| T::service_delete_change(x.to_owned())); -fn calculate_service_add_events>( - old: &HashMap, - new: &HashMap, -) -> Vec { - new.iter() + let addition = new + .iter() .filter(|(name, c)| old.get(*name) != Some(*c)) - .map(|(_, c)| c.clone().into()) - .collect() + .map(|(_, c)| T::service_add_change(c.clone())); + + Some(deletions.chain(addition).collect()) } #[cfg(test)] @@ -378,23 +363,23 @@ mod test { let mut expected = [ vec![ConfigChange::General(Box::new(tests[0].new.clone()))], vec![ConfigChange::General(Box::new(tests[1].new.clone()))], - vec![ConfigChange::ServiceChange(ServiceChange::ServerAdd( + vec![ConfigChange::ServerChange(ServerServiceChange::Add( Default::default(), ))], - vec![ConfigChange::ServiceChange(ServiceChange::ServerDelete( + vec![ConfigChange::ServerChange(ServerServiceChange::Delete( String::from("foo"), ))], vec![ - ConfigChange::ServiceChange(ServiceChange::ServerDelete(String::from("foo1"))), - ConfigChange::ServiceChange(ServiceChange::ServerAdd( + ConfigChange::ServerChange(ServerServiceChange::Delete(String::from("foo1"))), + ConfigChange::ServerChange(ServerServiceChange::Add( tests[4].new.server.as_ref().unwrap().services["bar1"].clone(), )), - ConfigChange::ServiceChange(ServiceChange::ClientDelete(String::from("foo1"))), - ConfigChange::ServiceChange(ServiceChange::ClientDelete(String::from("foo2"))), - ConfigChange::ServiceChange(ServiceChange::ClientAdd( + ConfigChange::ClientChange(ClientServiceChange::Delete(String::from("foo1"))), + ConfigChange::ClientChange(ClientServiceChange::Delete(String::from("foo2"))), + ConfigChange::ClientChange(ClientServiceChange::Add( tests[4].new.client.as_ref().unwrap().services["bar1"].clone(), )), - ConfigChange::ServiceChange(ServiceChange::ClientAdd( + ConfigChange::ClientChange(ClientServiceChange::Add( tests[4].new.client.as_ref().unwrap().services["bar2"].clone(), )), ], @@ -403,16 +388,18 @@ mod test { assert_eq!(tests.len(), expected.len()); for i in 0..tests.len() { - let mut actual = calculate_events(&tests[i].old, &tests[i].new); + let mut actual = calculate_events(&tests[i].old, &tests[i].new).unwrap(); let get_key = |x: &ConfigChange| -> String { match x { ConfigChange::General(_) => String::from("g"), - ConfigChange::ServiceChange(sc) => match sc { - ServiceChange::ClientAdd(c) => "c_add_".to_owned() + &c.name, - ServiceChange::ClientDelete(s) => "c_del_".to_owned() + s, - ServiceChange::ServerAdd(c) => "s_add_".to_owned() + &c.name, - ServiceChange::ServerDelete(s) => "s_del_".to_owned() + s, + ConfigChange::ServerChange(sc) => match sc { + ServerServiceChange::Add(c) => "s_add_".to_owned() + &c.name, + ServerServiceChange::Delete(s) => "s_del_".to_owned() + s, + }, + ConfigChange::ClientChange(sc) => match sc { + ClientServiceChange::Add(c) => "c_add_".to_owned() + &c.name, + ClientServiceChange::Delete(s) => "c_del_".to_owned() + s, }, } }; @@ -422,5 +409,20 @@ mod test { assert_eq!(actual, expected[i]); } + + // No changes + assert_eq!( + calculate_events( + &Config { + server: Default::default(), + client: None, + }, + &Config { + server: Default::default(), + client: None, + }, + ), + None + ); } } diff --git a/src/lib.rs b/src/lib.rs index c31da23b..7fb2fa6d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -10,7 +10,6 @@ mod transport; pub use cli::Cli; use cli::KeypairType; pub use config::Config; -use config_watcher::ServiceChange; pub use constants::UDP_BUFFER_SIZE; use anyhow::Result; @@ -76,7 +75,7 @@ pub async fn run(args: Cli, shutdown_rx: broadcast::Receiver) -> Result<() let (shutdown_tx, _) = broadcast::channel(1); // (The join handle of the last instance, The service update channel sender) - let mut last_instance: Option<(tokio::task::JoinHandle<_>, mpsc::Sender)> = None; + let mut last_instance: Option<(tokio::task::JoinHandle<_>, mpsc::Sender)> = None; while let Some(e) = cfg_watcher.event_rx.recv().await { match e { @@ -101,10 +100,10 @@ pub async fn run(args: Cli, shutdown_rx: broadcast::Receiver) -> Result<() service_update_tx, )); } - ConfigChange::ServiceChange(service_event) => { - info!("Service change detcted. {:?}", service_event); + ev => { + info!("Service change detected. {:?}", ev); if let Some((_, service_update_tx)) = &last_instance { - let _ = service_update_tx.send(service_event).await; + let _ = service_update_tx.send(ev).await; } } } @@ -119,7 +118,7 @@ async fn run_instance( config: Config, args: Cli, shutdown_rx: broadcast::Receiver, - service_update: mpsc::Receiver, + service_update: mpsc::Receiver, ) { let ret: Result<()> = match determine_run_mode(&config, &args) { RunMode::Undetermine => panic!("Cannot determine running as a server or a client"), diff --git a/src/server.rs b/src/server.rs index 4c086402..6ad91ee6 100644 --- a/src/server.rs +++ b/src/server.rs @@ -1,5 +1,5 @@ use crate::config::{Config, ServerConfig, ServerServiceConfig, ServiceType, TransportType}; -use crate::config_watcher::ServiceChange; +use crate::config_watcher::{ConfigChange, ServerServiceChange}; use crate::constants::{listen_backoff, UDP_BUFFER_SIZE}; use crate::helper::retry_notify_with_deadline; use crate::multi_map::MultiMap; @@ -40,7 +40,7 @@ const HANDSHAKE_TIMEOUT: u64 = 5; // Timeout for transport handshake pub async fn run_server( config: Config, shutdown_rx: broadcast::Receiver, - service_rx: mpsc::Receiver, + update_rx: mpsc::Receiver, ) -> Result<()> { let config = match config.server { Some(config) => config, @@ -52,13 +52,13 @@ pub async fn run_server( match config.transport.transport_type { TransportType::Tcp => { let mut server = Server::::from(config).await?; - server.run(shutdown_rx, service_rx).await?; + server.run(shutdown_rx, update_rx).await?; } TransportType::Tls => { #[cfg(feature = "tls")] { let mut server = Server::::from(config).await?; - server.run(shutdown_rx, service_rx).await?; + server.run(shutdown_rx, update_rx).await?; } #[cfg(not(feature = "tls"))] crate::helper::feature_not_compile("tls") @@ -67,7 +67,7 @@ pub async fn run_server( #[cfg(feature = "noise")] { let mut server = Server::::from(config).await?; - server.run(shutdown_rx, service_rx).await?; + server.run(shutdown_rx, update_rx).await?; } #[cfg(not(feature = "noise"))] crate::helper::feature_not_compile("noise") @@ -124,7 +124,7 @@ impl Server { pub async fn run( &mut self, mut shutdown_rx: broadcast::Receiver, - mut service_rx: mpsc::Receiver, + mut update_rx: mpsc::Receiver, ) -> Result<()> { // Listen at `server.bind_addr` let l = self @@ -198,7 +198,7 @@ impl Server { info!("Shuting down gracefully..."); break; }, - e = service_rx.recv() => { + e = update_rx.recv() => { if let Some(e) = e { self.handle_hot_reload(e).await; } @@ -211,24 +211,26 @@ impl Server { Ok(()) } - async fn handle_hot_reload(&mut self, e: ServiceChange) { + async fn handle_hot_reload(&mut self, e: ConfigChange) { match e { - ServiceChange::ServerAdd(s) => { - let hash = protocol::digest(s.name.as_bytes()); - let mut wg = self.services.write().await; - let _ = wg.insert(hash, s); - - let mut wg = self.control_channels.write().await; - let _ = wg.remove1(&hash); - } - ServiceChange::ServerDelete(s) => { - let hash = protocol::digest(s.as_bytes()); - let _ = self.services.write().await.remove(&hash); + ConfigChange::ServerChange(server_change) => match server_change { + ServerServiceChange::Add(cfg) => { + let hash = protocol::digest(cfg.name.as_bytes()); + let mut wg = self.services.write().await; + let _ = wg.insert(hash, cfg); + + let mut wg = self.control_channels.write().await; + let _ = wg.remove1(&hash); + } + ServerServiceChange::Delete(s) => { + let hash = protocol::digest(s.as_bytes()); + let _ = self.services.write().await.remove(&hash); - let mut wg = self.control_channels.write().await; - let _ = wg.remove1(&hash); - } - _ => (), + let mut wg = self.control_channels.write().await; + let _ = wg.remove1(&hash); + } + }, + ignored => warn!("Ignored {:?} since running as a server", ignored), } } }