diff --git a/src/bin/roughenough-server.rs b/src/bin/roughenough-server.rs index a6d0cf5..15a3d9e 100644 --- a/src/bin/roughenough-server.rs +++ b/src/bin/roughenough-server.rs @@ -71,12 +71,11 @@ fn set_ctrlc_handler() { fn display_config(server: &Server, cfg: &dyn ServerConfig) { info!("Processing thread : {}", server.thread_name()); + info!("Number of workers : {}", cfg.num_workers()); info!("Long-term public key : {}", server.get_public_key()); info!("Max response batch size : {}", cfg.batch_size()); - info!( - "Status updates every : {} seconds", - cfg.status_interval().as_secs() - ); + info!("Status updates every : {} seconds", cfg.status_interval().as_secs()); + info!( "Server listening on : {}:{}", cfg.interface(), @@ -104,6 +103,7 @@ fn display_config(server: &Server, cfg: &dyn ServerConfig) { } else { info!("Deliberate response errors : disabled"); } + } pub fn main() { @@ -131,18 +131,18 @@ pub fn main() { Ok(cfg) => Arc::new(Mutex::new(cfg)), }; - let sock_addr = config.lock().unwrap().udp_socket_addr().expect("udp sock addr"); let socket = { + let sock_addr = config.lock().unwrap().udp_socket_addr().expect("udp sock addr"); let sock = UdpSocket::bind(&sock_addr).expect("failed to bind to socket"); Arc::new(sock) }; set_ctrlc_handler(); - // TODO(stuart) pull TCP healthcheck out of worker threads + // TODO(stuart) move TCP healthcheck out of worker threads as it currently conflicts let mut thread_handles = Vec::new(); - for i in 0 .. 4 { + for i in 0 .. config.lock().unwrap().num_workers() { let cfg = config.clone(); let sock = socket.try_clone().unwrap(); let thrd = thread::Builder::new() diff --git a/src/config/environment.rs b/src/config/environment.rs index 908e447..fed17fd 100644 --- a/src/config/environment.rs +++ b/src/config/environment.rs @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::env; +use std::{env, thread}; use std::time::Duration; use data_encoding::{Encoding, HEXLOWER_PERMISSIVE}; @@ -39,6 +39,7 @@ const HEX: Encoding = HEXLOWER_PERMISSIVE; /// health_check_port | `ROUGHENOUGH_HEALTH_CHECK_PORT` /// client_stats | `ROUGHENOUGH_CLIENT_STATS` /// fault_percentage | `ROUGHENOUGH_FAULT_PERCENTAGE` +/// num_workers | `ROUGHENOUGH_NUM_WORKERS` /// pub struct EnvironmentConfig { port: u16, @@ -50,6 +51,7 @@ pub struct EnvironmentConfig { health_check_port: Option, client_stats: bool, fault_percentage: u8, + num_workers: usize, } const ROUGHENOUGH_PORT: &str = "ROUGHENOUGH_PORT"; @@ -61,6 +63,7 @@ const ROUGHENOUGH_KMS_PROTECTION: &str = "ROUGHENOUGH_KMS_PROTECTION"; const ROUGHENOUGH_HEALTH_CHECK_PORT: &str = "ROUGHENOUGH_HEALTH_CHECK_PORT"; const ROUGHENOUGH_CLIENT_STATS: &str = "ROUGHENOUGH_CLIENT_STATS"; const ROUGHENOUGH_FAULT_PERCENTAGE: &str = "ROUGHENOUGH_FAULT_PERCENTAGE"; +const ROUGHENOUGH_NUM_WORKERS: &str = "ROUGHENOUGH_NUM_WORKERS:"; impl EnvironmentConfig { pub fn new() -> Result { @@ -74,6 +77,7 @@ impl EnvironmentConfig { health_check_port: None, client_stats: false, fault_percentage: 0, + num_workers: thread::available_parallelism().unwrap().get(), }; if let Ok(port) = env::var(ROUGHENOUGH_PORT) { @@ -132,6 +136,12 @@ impl EnvironmentConfig { .unwrap_or_else(|_| panic!("invalid fault_percentage: {}", fault_percentage)); }; + if let Ok(num_workers) = env::var(ROUGHENOUGH_NUM_WORKERS) { + cfg.num_workers = num_workers + .parse() + .unwrap_or_else(|_| panic!("invalid num_workers: {}", num_workers)); + }; + Ok(cfg) } } @@ -172,4 +182,8 @@ impl ServerConfig for EnvironmentConfig { fn fault_percentage(&self) -> u8 { self.fault_percentage } + + fn num_workers(&self) -> usize { + self.num_workers + } } diff --git a/src/config/file.rs b/src/config/file.rs index d0f388b..5e31902 100644 --- a/src/config/file.rs +++ b/src/config/file.rs @@ -14,6 +14,7 @@ use std::fs::File; use std::io::Read; +use std::thread; use std::time::Duration; use data_encoding::{Encoding, HEXLOWER_PERMISSIVE}; @@ -48,6 +49,7 @@ pub struct FileConfig { health_check_port: Option, client_stats: bool, fault_percentage: u8, + num_workers: usize, } impl FileConfig { @@ -80,6 +82,7 @@ impl FileConfig { health_check_port: None, client_stats: false, fault_percentage: 0, + num_workers: thread::available_parallelism().unwrap().get(), }; for (key, value) in cfg[0].as_hash().unwrap() { @@ -116,6 +119,10 @@ impl FileConfig { let val = value.as_i64().unwrap() as u8; config.fault_percentage = val; } + "num_workers" => { + let val = value.as_i64().unwrap() as usize; + config.num_workers = val; + } unknown => { return Err(Error::InvalidConfiguration(format!( "unknown config key: {}", @@ -165,4 +172,8 @@ impl ServerConfig for FileConfig { fn fault_percentage(&self) -> u8 { self.fault_percentage } + + fn num_workers(&self) -> usize { + self.num_workers + } } diff --git a/src/config/memory.rs b/src/config/memory.rs index 88815fb..272e92e 100644 --- a/src/config/memory.rs +++ b/src/config/memory.rs @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::thread; use std::time::Duration; use data_encoding::{Encoding, HEXLOWER_PERMISSIVE}; @@ -35,6 +36,7 @@ pub struct MemoryConfig { pub health_check_port: Option, pub client_stats: bool, pub fault_percentage: u8, + pub num_workers: usize, } impl MemoryConfig { @@ -50,6 +52,7 @@ impl MemoryConfig { health_check_port: None, client_stats: false, fault_percentage: 0, + num_workers: thread::available_parallelism().unwrap().get(), } } } @@ -90,4 +93,8 @@ impl ServerConfig for MemoryConfig { fn fault_percentage(&self) -> u8 { self.fault_percentage } + + fn num_workers(&self) -> usize { + self.num_workers + } } diff --git a/src/config/mod.rs b/src/config/mod.rs index a72664d..a4c0916 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -60,6 +60,7 @@ pub const DEFAULT_STATUS_INTERVAL: Duration = Duration::from_secs(600); /// `kms_protection` | `ROUGHENOUGH_KMS_PROTECTION` | Optional | If compiled with KMS support, the ID of the KMS key used to protect the long-term identity. /// `client_stats` | `ROUGHENOUGH_CLIENT_STATS` | Optional | A value of `on` or `yes` will enable tracking of per-client request statistics that will be output each time server status is logged. Default is `off` (disabled). /// `fault_percentage` | `ROUGHENOUGH_FAULT_PERCENTAGE` | Optional | Likelihood (as a percentage) that the server will intentionally return an invalid client response. An integer range from `0` (disabled, all responses valid) to `50` (50% of responses will be invalid). Default is `0` (disabled). +/// `num_workers` | `ROUGHENOUGH_NUM_WORKERS` | Optional | Number of worker threads created to process requests. Defaults to `thread::available_parallelism()` /// /// Implementations of this trait obtain a valid configuration from different back-end /// sources. See: @@ -109,6 +110,10 @@ pub trait ServerConfig : Send { /// for background and rationale. fn fault_percentage(&self) -> u8; + /// [Optional] The number of worker threads to start. Defaults to the value returned by + /// Rust's `thread::available_parallelism()`. + fn num_workers(&self) -> usize; + /// Convenience function to create a `SocketAddr` from the provided `interface` and `port` fn udp_socket_addr(&self) -> Result { let addr = format!("{}:{}", self.interface(), self.port()); @@ -190,6 +195,11 @@ pub fn is_valid_config(cfg: &dyn ServerConfig) -> bool { is_valid = false; } + if cfg.num_workers() == 0 { + error!("num_workers must be > 0"); + is_valid = false; + } + if is_valid { if let Err(e) = cfg.udp_socket_addr() { error!(