From e22637aa1912931749aa58c964bee5d88b63207c Mon Sep 17 00:00:00 2001 From: Yingjie Shang Date: Wed, 9 Oct 2024 14:58:02 +0800 Subject: [PATCH] Instrument RustyVault with Prometheus (#76) * Instrument RustyVault with Prometheus * Network data stay zero all the time. Comment out temporarily, fix it later. * Add test cases to verify Prometheus instrumentation. * Since load_avg captured by sysinfo is not available Windows, skip it on Windows platform. --- Cargo.toml | 2 + src/cli/command/server.rs | 28 +- src/cli/config.rs | 6 + src/http/metrics.rs | 24 ++ src/http/mod.rs | 2 + src/lib.rs | 1 + src/metrics/http_metrics.rs | 192 +++++++++++++ src/metrics/manager.rs | 21 ++ src/metrics/middleware.rs | 123 +++++++++ src/metrics/mod.rs | 75 +++++ src/metrics/system_metrics.rs | 193 +++++++++++++ src/test_utils.rs | 498 ++++++++++++++++++++++++++-------- 12 files changed, 1053 insertions(+), 112 deletions(-) create mode 100644 src/http/metrics.rs create mode 100644 src/metrics/http_metrics.rs create mode 100644 src/metrics/manager.rs create mode 100644 src/metrics/middleware.rs create mode 100644 src/metrics/mod.rs create mode 100644 src/metrics/system_metrics.rs diff --git a/Cargo.toml b/Cargo.toml index 818914f..d6e5ac7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -65,6 +65,8 @@ dashmap = "5.5" tokio = { version = "1.40", features = ["rt-multi-thread", "macros"] } ctor = "0.2.8" better_default = "1.0.5" +prometheus-client = "0.22.3" +sysinfo = "0.31.4" # optional dependencies openssl = { version = "0.10.64", optional = true } diff --git a/src/cli/command/server.rs b/src/cli/command/server.rs index ddb8d3d..85d7630 100644 --- a/src/cli/command/server.rs +++ b/src/cli/command/server.rs @@ -7,7 +7,10 @@ use std::{ sync::{Arc, RwLock}, }; -use actix_web::{middleware, web, App, HttpResponse, HttpServer}; +use actix_web::{ + middleware::{self, from_fn}, + web, App, HttpResponse, HttpServer, +}; use anyhow::format_err; use clap::ArgMatches; use openssl::{ @@ -17,8 +20,9 @@ use openssl::{ use sysexits::ExitCode; use crate::{ - cli::config, core::Core, errors::RvError, http, storage, EXIT_CODE_INSUFFICIENT_PARAMS, - EXIT_CODE_LOAD_CONFIG_FAILURE, EXIT_CODE_OK, + cli::config, core::Core, errors::RvError, http, storage, + EXIT_CODE_INSUFFICIENT_PARAMS, EXIT_CODE_LOAD_CONFIG_FAILURE, EXIT_CODE_OK, + metrics::{manager::MetricsManager, middleware::metrics_midleware}, }; pub const WORK_DIR_PATH_DEFAULT: &str = "/tmp/rusty_vault"; @@ -109,7 +113,14 @@ pub fn main(config_path: &str) -> Result<(), RvError> { let barrier = storage::barrier_aes_gcm::AESGCMBarrier::new(Arc::clone(&backend)); - let core = Arc::new(RwLock::new(Core { physical: backend, barrier: Arc::new(barrier), ..Default::default() })); + let metrics_manager = Arc::new(RwLock::new(MetricsManager::new(config.collection_interval))); + let system_metrics = Arc::clone(&metrics_manager.read().unwrap().system_metrics); + + let core = Arc::new(RwLock::new(Core { + physical: backend, + barrier: Arc::new(barrier), + ..Default::default() + })); { let mut c = core.write()?; @@ -119,7 +130,9 @@ pub fn main(config_path: &str) -> Result<(), RvError> { let mut http_server = HttpServer::new(move || { App::new() .wrap(middleware::Logger::default()) + .wrap(from_fn(metrics_midleware)) .app_data(web::Data::new(Arc::clone(&core))) + .app_data(web::Data::new(Arc::clone(&metrics_manager))) .configure(http::init_service) .default_service(web::to(|| HttpResponse::NotFound())) }) @@ -182,7 +195,12 @@ pub fn main(config_path: &str) -> Result<(), RvError> { log::info!("rusty_vault server starts, waiting for request..."); - server.block_on(async { http_server.run().await })?; + server.block_on(async { + tokio::spawn(async { + system_metrics.start_collecting().await; + }); + http_server.run().await + })?; let _ = server.run(); Ok(()) diff --git a/src/cli/config.rs b/src/cli/config.rs index 2b14b22..54c3547 100644 --- a/src/cli/config.rs +++ b/src/cli/config.rs @@ -36,6 +36,12 @@ pub struct Config { pub daemon_user: String, #[serde(default)] pub daemon_group: String, + #[serde(default = "default_collection_interval")] + pub collection_interval: u64, +} + +fn default_collection_interval() -> u64 { + 15 } /// A struct that contains several configurable options for networking stuffs diff --git a/src/http/metrics.rs b/src/http/metrics.rs new file mode 100644 index 0000000..c556d66 --- /dev/null +++ b/src/http/metrics.rs @@ -0,0 +1,24 @@ +use std::sync::{Arc, RwLock}; + +use actix_web::{web, HttpResponse}; +use prometheus_client::encoding::text::encode; +use crate::metrics::manager::MetricsManager; + +pub async fn metrics_handler(metrics_manager: web::Data>>) -> HttpResponse { + let m = metrics_manager.read().unwrap(); + let registry = m.registry.lock().unwrap(); + + let mut buffer = String::new(); + if let Err(e) = encode(&mut buffer, ®istry) { + log::error!("Failed to encode metrics: {}", e); + return HttpResponse::InternalServerError().finish(); + } + + HttpResponse::Ok() + .content_type("text/plain; version=0.0.4") + .body(buffer) +} + +pub fn init_metrics_service(cfg: &mut web::ServiceConfig){ + cfg.service(web::resource("/metrics").route(web::get().to(metrics_handler))); +} diff --git a/src/http/mod.rs b/src/http/mod.rs index a7c3750..9f1a600 100644 --- a/src/http/mod.rs +++ b/src/http/mod.rs @@ -21,6 +21,7 @@ use crate::{core::Core, errors::RvError, logical::Request}; pub mod logical; pub mod sys; +pub mod metrics; pub const AUTH_COOKIE_NAME: &str = "token"; pub const AUTH_HEADER_NAME: &str = "X-RustyVault-Token"; @@ -109,6 +110,7 @@ pub fn request_on_connect_handler(conn: &dyn Any, ext: &mut Extensions) { pub fn init_service(cfg: &mut web::ServiceConfig) { sys::init_sys_service(cfg); logical::init_logical_service(cfg); + metrics::init_metrics_service(cfg); } impl ResponseError for RvError { diff --git a/src/lib.rs b/src/lib.rs index ddc6210..38e29e2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -40,6 +40,7 @@ pub mod schema; pub mod shamir; pub mod storage; pub mod utils; +pub mod metrics; #[cfg(test)] pub mod test_utils; diff --git a/src/metrics/http_metrics.rs b/src/metrics/http_metrics.rs new file mode 100644 index 0000000..e99345b --- /dev/null +++ b/src/metrics/http_metrics.rs @@ -0,0 +1,192 @@ +//! Define and implement HTTP metrics and corresponding methods. +use std::fmt::Write; + +use prometheus_client::encoding::{EncodeLabelSet, EncodeLabelValue, LabelValueEncoder}; +use prometheus_client::metrics::counter::Counter; +use prometheus_client::metrics::family::Family; +use prometheus_client::metrics::histogram::{linear_buckets, Histogram}; +use prometheus_client::registry::Registry; + +pub const HTTP_REQUEST_COUNT: &str = "http_request_count"; +pub const HTTP_REQUEST_COUNT_HELP: &str = "Number of HTTP requests received, labeled by method and status"; +pub const HTTP_REQUEST_DURATION_SECONDS: &str = "http_request_duration_seconds"; +pub const HTTP_REQUEST_DURATION_SECONDS_HELP: &str = "Duration of HTTP requests, labeled by method and status"; + +#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)] +pub enum MetricsMethod { + GET, + POST, + PUT, + DELETE, + LIST, + OTHER, +} + +impl EncodeLabelValue for MetricsMethod { + fn encode(&self, writer: &mut LabelValueEncoder<'_>) -> Result<(), std::fmt::Error> { + match self { + MetricsMethod::GET => writer.write_str("get"), + MetricsMethod::POST => writer.write_str("post"), + MetricsMethod::PUT => writer.write_str("put"), + MetricsMethod::DELETE => writer.write_str("delete"), + MetricsMethod::LIST => writer.write_str("list"), + MetricsMethod::OTHER => writer.write_str("other"), + } + } +} + +#[derive(Clone, Debug, Hash, PartialEq, Eq, EncodeLabelSet)] +pub struct HttpLabel { + pub path: String, + pub method: MetricsMethod, + pub status: u16, +} + +#[derive(Clone)] +pub struct HttpMetrics { + requests: Family, + histogram: Family, +} + +impl HttpMetrics { + pub fn new(registry: &mut Registry) -> Self { + let requests = Family::::default(); + let histogram = + Family::::new_with_constructor(|| Histogram::new(linear_buckets(0.1, 0.1, 10))); + + registry.register(HTTP_REQUEST_COUNT, HTTP_REQUEST_COUNT_HELP, requests.clone()); + + registry.register(HTTP_REQUEST_DURATION_SECONDS, HTTP_REQUEST_DURATION_SECONDS_HELP, histogram.clone()); + + Self { requests, histogram } + } + + pub fn increment_request_count(&self, label: &HttpLabel) { + self.requests.get_or_create(label).inc(); + } + + pub fn observe_duration(&self, label: &HttpLabel, duration: f64) { + self.histogram.get_or_create(label).observe(duration); + } +} + +#[cfg(test)] +mod tests { + use rand::Rng; + use regex::Regex; + use ureq::json; + + use crate::test_utils::TestHttpServer; + use std::collections::HashMap; + + const PATH: &str = "path"; + const METHOD: &str = "method"; + + const GET: &str = "GET"; + const LIST: &str = "LIST"; + const POST: &str = "POST"; + const PUT: &str = "PUT"; + const DELETE: &str = "DELETE"; + + fn parse_counter(raw: &str) -> HashMap> { + let lines: Vec<&str> = raw.split('\n').collect(); + let mut i = 0; + let mut counter_map: HashMap> = HashMap::new(); + let name_label_re = + Regex::new(r#"\bpath="(?P[^"]+)",method="(?P[^"]+)",status="(?P[^"]+)""#).unwrap(); + + while i < lines.len() { + let line = lines[i]; + if line.ends_with("counter") { + // move to next line, which is counter + i += 1; + let parts: Vec<&str> = lines[i].split("{").collect(); + let metric_name = parts[0]; + + // capture following counter lines + while lines[i].starts_with(metric_name) { + let parts: Vec<&str> = lines[i].split(" ").collect(); + let name_label = parts[0]; + let value: u32 = parts[1].parse().unwrap(); + + if let Some(caps) = name_label_re.captures(name_label) { + let path = caps[PATH].to_string(); + let method = caps[METHOD].to_string().to_uppercase(); + if let Some(req) = counter_map.get_mut(&path) { + req.insert(method, value); + } else { + let mut req: HashMap = HashMap::new(); + req.insert(method, value); + println!("path:{}", &path); + counter_map.insert(path, req); + } + } + + i += 1; + } + } + i += 1; + } + counter_map + } + + #[test] + fn test_http_request() { + let server = TestHttpServer::new_with_prometheus("test_http_request", false); + let root_token = &server.root_token; + + let path = ["v1/secret/password-0", "v1/secret/password-1", "v1/secret/password-2", "v1/secret"]; + let mock = [ + vec![(DELETE, 2)], + vec![(POST, 3), (GET, 5), (PUT, 7), (DELETE, 9)], + vec![(POST, 2), (GET, 8), (PUT, 12), (DELETE, 16)], + vec![(LIST, 1)], + ]; + let mut mock_map: HashMap<&str, Vec<(&str, u32)>> = HashMap::new(); + for (p, m) in path.iter().zip(mock.iter()) { + mock_map.insert(p, m.to_vec()); + } + + for (path, mock) in &mock_map { + for request in mock { + let method = request.0; + let count = request.1; + for _ in 0..count { + if method == "POST" || method == "PUT" { + let random_number: u32 = rand::thread_rng().gen_range(0..10000); + let data = json!({ + "password": random_number, + }) + .as_object() + .unwrap() + .clone(); + let (_, _) = server.request(method, path, Some(data), Some(&root_token), None).unwrap(); + } else { + let (_, _) = server.request(method, path, None, Some(&root_token), None).unwrap(); + } + } + } + } + + let (status, resp) = server.request_prometheus("GET", "metrics", None, Some(&root_token), None).unwrap(); + assert_eq!(status, 200); + + let counter_map = parse_counter(resp["metrics"].as_str().unwrap()); + println!("counter map len={}", counter_map.len()); + + for (path, mock) in &mock_map { + for mock_req in mock { + let method = mock_req.0; + let count = mock_req.1; + let path = format!("/{}", path); + assert!(counter_map.contains_key(&path)); + + let prom = counter_map.get(&path).unwrap(); + assert!(prom.contains_key(method)); + + let value = *prom.get(method).unwrap(); + assert_eq!(count, value); + } + } + } +} diff --git a/src/metrics/manager.rs b/src/metrics/manager.rs new file mode 100644 index 0000000..8be3e86 --- /dev/null +++ b/src/metrics/manager.rs @@ -0,0 +1,21 @@ +//! `MetricManager` holds the Prometheus registry and metrics. +use crate::metrics::http_metrics::HttpMetrics; +use crate::metrics::system_metrics::SystemMetrics; +use prometheus_client::registry::Registry; +use std::sync::{Arc, Mutex}; + +#[derive(Clone)] +pub struct MetricsManager { + pub registry: Arc>, + pub system_metrics: Arc, + pub http_metrics: Arc, +} + +impl MetricsManager { + pub fn new(collection_interval: u64) -> Self { + let registry = Arc::new(Mutex::new(Registry::default())); + let system_metrics = Arc::new(SystemMetrics::new(&mut registry.lock().unwrap(), collection_interval)); + let http_metrics = Arc::new(HttpMetrics::new(&mut registry.lock().unwrap())); + MetricsManager { registry, system_metrics, http_metrics } + } +} diff --git a/src/metrics/middleware.rs b/src/metrics/middleware.rs new file mode 100644 index 0000000..caa5f14 --- /dev/null +++ b/src/metrics/middleware.rs @@ -0,0 +1,123 @@ +//! Actix-web middleware function, captures and monitors HTTP requests. +//! +//! # Usage +//! The actix-web middleware function could be used as following: +//! +//! ```text +//! let mut http_server = HttpServer::new(move || { +//! App::new() +//! //skip +//! .wrap(from_fn(metrics_midleware)) +//! //skip +//! }) +//! ``` +use std::{ + sync::{Arc, RwLock}, + time::Instant, +}; + +use crate::metrics::http_metrics::HttpLabel; +use actix_web::{ + body::MessageBody, + dev::{ServiceRequest, ServiceResponse}, + http::Method, + middleware::Next, + web::Data, + Error, +}; + +use super::{http_metrics::MetricsMethod, manager::MetricsManager}; + +pub async fn metrics_midleware( + req: ServiceRequest, + next: Next, +) -> Result, Error> { + let start_time = Instant::now(); + let path = req.path().to_string(); + let method = match *req.method() { + Method::GET => MetricsMethod::GET, + _ if req.method().to_string() == "LIST" => MetricsMethod::LIST, + Method::POST => MetricsMethod::POST, + Method::PUT => MetricsMethod::PUT, + Method::DELETE => MetricsMethod::DELETE, + _ => MetricsMethod::OTHER, + }; + + let res = next.call(req).await?; + + let status = res.status().as_u16(); + let label = HttpLabel { path, method, status }; + if let Some(m) = res.request().app_data::>>>() { + let metrics_manager = m.read().unwrap(); + metrics_manager.http_metrics.increment_request_count(&label); + let duration = start_time.elapsed().as_secs_f64(); + metrics_manager.http_metrics.observe_duration(&label, duration); + } + + Ok(res) +} + +#[cfg(test)] +mod tests { + use crate::metrics::http_metrics::*; + use crate::metrics::system_metrics::*; + use crate::test_utils::TestHttpServer; + use std::collections::HashMap; + + static SYS_METRICS_MAP: &[(&str, &str)] = &[ + (CPU_USAGE_PERCENT, CPU_USAGE_PERCENT_HELP), + (TOTAL_MEMORY, TOTAL_MEMORY_HELP), + (USED_MEMORY, USED_MEMORY_HELP), + (FREE_MEMORY, FREE_MEMORY_HELP), + (TOTAL_DISK_SPACE, TOTAL_DISK_SPACE_HELP), + (TOTAL_DISK_AVAILABLE, TOTAL_DISK_AVAILABLE_HELP), + // (NETWORK_IN, NETWORK_IN_HELP), + // (NETWORK_OUT, NETWORK_OUT_HELP), + (LOAD_AVERAGE, LOAD_AVERAGE_HELP), + ]; + + static HTTP_METRICS_MAP: &[(&str, &str)] = &[ + (HTTP_REQUEST_COUNT, HTTP_REQUEST_COUNT_HELP), + (HTTP_REQUEST_DURATION_SECONDS, HTTP_REQUEST_DURATION_SECONDS_HELP), + ]; + + fn parse_metrics_name_help(raw: &str) -> HashMap { + let mut metrics_map = HashMap::new(); + for line in raw.split('\n') { + if line.starts_with("# HELP") { + let line = line.trim_end_matches("."); + // # PROPERTY METRIC_NAME METRIC_HELP + // # HELP cpu_usage_percent CPU usage percent. + let parts: Vec<&str> = line.split(" ").collect(); + let metric_name = parts[2].to_string(); + let metric_help = parts[3..].join(" "); + metrics_map.insert(metric_name, metric_help); + } + } + metrics_map + } + + #[test] + fn test_metrics_name_and_help_info() { + let sys_metrics_map: HashMap<&str, &str> = SYS_METRICS_MAP.iter().cloned().collect(); + let http_metrics_map: HashMap<&str, &str> = HTTP_METRICS_MAP.iter().cloned().collect(); + + let server = TestHttpServer::new_with_prometheus("test_metrics_name_and_help_info", false); + let root_token = &server.root_token; + let (status, resp) = server.request_prometheus("GET", "metrics", None, Some(&root_token), None).unwrap(); + assert_eq!(status, 200); + + let metrics_map = parse_metrics_name_help(resp["metrics"].as_str().unwrap()); + assert_eq!(sys_metrics_map.len() + http_metrics_map.len(), metrics_map.len()); + + for (metric_name, metric_help) in &metrics_map { + let name = metric_name.as_str(); + let help = metric_help.as_str(); + if sys_metrics_map.contains_key(name) { + assert_eq!(sys_metrics_map.get(name), Some(&help)); + } else if http_metrics_map.contains_key(name) { + assert_eq!(http_metrics_map.get(name), Some(&help)); + } + } + } +} diff --git a/src/metrics/mod.rs b/src/metrics/mod.rs new file mode 100644 index 0000000..6058905 --- /dev/null +++ b/src/metrics/mod.rs @@ -0,0 +1,75 @@ +//! The `rusty_vault::metrics` module instruments RustyVault with Prometheus, allowing it to capture performance metrics. +//! +//! # Methodology +//! +//! From a monitoring perspective, [Prometheus](https://prometheus.io/docs/practices/instrumentation/#the-three-types-of-services) categorizes services into three types: online services, offline processing, and batch jobs. As a modern key management system, RustyVault provides a set of RESTful APIs, so it is classified as an online service. +//! +//! In online service systems, the key metrics include the number of executed queries, error rates, and latency. In this project, the monitored content is divided into two parts: the target operating system and the target application service. +//! +//! Based on the [USE (Utilization, Saturation, and Errors) method](https://www.brendangregg.com/usemethod.html), system performance metrics such as CPU, memory, disk, network, and load are monitored. For the target service, the [RED (Rate, Errors, and Duration)](https://grafana.com/blog/2018/08/02/the-red-method-how-to-instrument-your-services/) method is used to monitor the number of requests, request outcomes, and the time taken to process each request. +//! +//! # Dependency +//! +//! This implementation utilizes the [prometheus-client](https://docs.rs/prometheus-client/latest/prometheus_client/) and [sysinfo](https://docs.rs/sysinfo/latest/sysinfo/) libraries to gather system performance data. +//! +//! # How to Create and Using New Metric +//! +//! 1. **Define and Implement Metrics** +//! +//! Define your metrics under `src/metrics/` and register them with the `Registry` like this: +//! +//! ```text +//! pub const HTTP_REQUEST_COUNT: &str = "http_request_count"; +//! pub const HTTP_REQUEST_COUNT_HELP: &str = "Number of HTTP requests received, labeled by method and status"; +//! +//! pub struct HttpMetrics { +//! requests: Family, +//! } +//! +//! impl HttpMetrics { +//! pub fn new(registry: &mut Registry) -> Self { +//! let requests = Family::::default(); +//! registry.register(HTTP_REQUEST_COUNT, HTTP_REQUEST_COUNT_HELP, requests.clone()); +//! Self { requests } +//! } +//! +//! pub fn increment_request_count(&self, label: &HttpLabel) { +//! self.requests.get_or_create(label).inc(); +//! } +//! } +//! ``` +//! +//! 2. **Add Metrics to `MetricsManager`** +//! +//! Register the metrics within the `MetricsManager` struct: +//! +//! ```text +//! pub struct MetricsManager { +//! pub registry: Arc>, +//! pub http_metrics: Arc, +//! // Other fields... +//! } +//! +//! impl MetricsManager { +//! pub fn new(collection_interval: u64) -> Self { +//! let registry = Arc::new(Mutex::new(Registry::default())); +//! let http_metrics = Arc::new(HttpMetrics::new(&mut registry.lock().unwrap())); +//! MetricsManager { registry, http_metrics } +//! } +//! } +//! ``` +//! +//! 3. **Update Metrics Based on Events** +//! +//! Invoke methods to update metrics where relevant events occur. In this example, retrieve `MetricsManager` from the `app_data` in the Actix Web application: +//! +//! ```text +//! if let Some(m) = res.request().app_data::>>>() { +//! let metrics_manager = m.read().unwrap(); +//! metrics_manager.http_metrics.increment_request_count(&label); +//! } +//! ``` +pub mod middleware; +pub mod manager; +pub mod system_metrics; +pub mod http_metrics; \ No newline at end of file diff --git a/src/metrics/system_metrics.rs b/src/metrics/system_metrics.rs new file mode 100644 index 0000000..ff4c53b --- /dev/null +++ b/src/metrics/system_metrics.rs @@ -0,0 +1,193 @@ +//! Define and implement operating system metrics, using [sysinfo](https://docs.rs/sysinfo/latest/sysinfo/) to capture. +use prometheus_client::metrics::gauge::Gauge; +use prometheus_client::registry::Registry; +use std::sync::{atomic::AtomicU64, Arc, Mutex}; +use sysinfo::{Disks, System}; +use tokio::time::{self, Duration}; + +pub const CPU_USAGE_PERCENT: &str = "cpu_usage_percent"; +pub const CPU_USAGE_PERCENT_HELP: &str = "CPU usage percent"; +pub const TOTAL_MEMORY: &str = "total_memory"; +pub const TOTAL_MEMORY_HELP: &str = "Total memory"; +pub const USED_MEMORY: &str = "used_memory"; +pub const USED_MEMORY_HELP: &str = "Used memory"; +pub const FREE_MEMORY: &str = "free_memory"; +pub const FREE_MEMORY_HELP: &str = "Free memory"; +pub const TOTAL_DISK_SPACE: &str = "total_disk_space"; +pub const TOTAL_DISK_SPACE_HELP: &str = "Total disk space"; +pub const TOTAL_DISK_AVAILABLE: &str = "total_disk_available"; +pub const TOTAL_DISK_AVAILABLE_HELP: &str = "Total disk available"; +// pub const NETWORK_IN: &str = "network_in"; +// pub const NETWORK_IN_HELP: &str = "Network in"; +// pub const NETWORK_OUT: &str = "network_out"; +// pub const NETWORK_OUT_HELP: &str = "Network out"; +pub const LOAD_AVERAGE: &str = "load_average"; +pub const LOAD_AVERAGE_HELP: &str = "System load average"; + +pub struct SystemMetrics { + system: Arc>, + collection_interval: u64, + cpu_usage: Gauge, + total_memory: Gauge, + used_memory: Gauge, + free_memory: Gauge, + total_disk_available: Gauge, + total_disk_space: Gauge, + // network_in: Gauge, + // network_out: Gauge, + load_avg: Gauge, +} + +impl SystemMetrics { + pub fn new(registry: &mut Registry, collection_interval: u64) -> Self { + let cpu_usage = Gauge::::default(); + + let total_memory = Gauge::::default(); + let used_memory = Gauge::::default(); + let free_memory = Gauge::::default(); + + let total_disk_space = Gauge::::default(); + let total_disk_available = Gauge::::default(); + + // let network_in = Gauge::::default(); + // let network_out = Gauge::::default(); + let load_avg = Gauge::::default(); + + registry.register(CPU_USAGE_PERCENT, CPU_USAGE_PERCENT_HELP, cpu_usage.clone()); + + registry.register(TOTAL_MEMORY, TOTAL_MEMORY_HELP, total_memory.clone()); + registry.register(USED_MEMORY, USED_MEMORY_HELP, used_memory.clone()); + registry.register(FREE_MEMORY, FREE_MEMORY_HELP, free_memory.clone()); + + registry.register(TOTAL_DISK_SPACE, TOTAL_DISK_SPACE_HELP, total_disk_space.clone()); + registry.register(TOTAL_DISK_AVAILABLE, TOTAL_DISK_AVAILABLE_HELP, total_disk_available.clone()); + + // registry.register(NETWORK_IN, NETWORK_IN_HELP, network_in.clone()); + // registry.register(NETWORK_OUT, NETWORK_OUT_HELP, network_out.clone()); + + registry.register(LOAD_AVERAGE, LOAD_AVERAGE_HELP, load_avg.clone()); + + let system = Arc::new(Mutex::new(System::new_all())); + + Self { + system, + collection_interval, + cpu_usage, + total_memory, + used_memory, + free_memory, + total_disk_available, + total_disk_space, + // network_in, + // network_out, + load_avg, + } + } + + pub async fn start_collecting(self: Arc) { + let mut interval = time::interval(Duration::from_secs(self.collection_interval)); + + loop { + interval.tick().await; + self.collect_metrics(); + } + } + + fn collect_metrics(&self) { + let mut sys = self.system.lock().unwrap(); + sys.refresh_all(); + + self.cpu_usage.set(sys.global_cpu_usage() as f64); + + self.total_memory.set(sys.total_memory() as f64); + self.used_memory.set(sys.used_memory() as f64); + self.free_memory.set(sys.free_memory() as f64); + + let mut total_available_space = 0; + let mut total_disk_space = 0; + + for disk in Disks::new_with_refreshed_list().list() { + total_available_space += disk.available_space(); + total_disk_space += disk.total_space(); + } + self.total_disk_available.set(total_available_space as f64); + self.total_disk_space.set(total_disk_space as f64); + + // let mut total_network_in = 0; + // let mut total_network_out = 0; + + // TODO: network data stays at zero all the time + // for (_, n) in Networks::new_with_refreshed_list().list() { + // total_network_in += n.received(); + // total_network_out += n.transmitted(); + // } + + // self.network_in.set(total_network_in as f64); + // self.network_out.set(total_network_out as f64); + + self.load_avg.set(System::load_average().one as f64); + } +} + +#[cfg(test)] +mod tests { + use crate::metrics::system_metrics::*; + use crate::test_utils::TestHttpServer; + use std::collections::HashMap; + use std::thread; + use std::time::Duration; + + static SYS_METRICS_MAP: &[(&str, &str)] = &[ + (CPU_USAGE_PERCENT, CPU_USAGE_PERCENT_HELP), + (TOTAL_MEMORY, TOTAL_MEMORY_HELP), + (USED_MEMORY, USED_MEMORY_HELP), + (FREE_MEMORY, FREE_MEMORY_HELP), + (TOTAL_DISK_SPACE, TOTAL_DISK_SPACE_HELP), + (TOTAL_DISK_AVAILABLE, TOTAL_DISK_AVAILABLE_HELP), + // (NETWORK_IN, NETWORK_IN_HELP), + // (NETWORK_OUT, NETWORK_OUT_HELP), + (LOAD_AVERAGE, LOAD_AVERAGE_HELP), + ]; + + fn parse_gauge(raw: &str) -> HashMap { + let mut gauge_map = HashMap::new(); + let lines: Vec<&str> = raw.split('\n').collect(); + let mut i = 0; + + while i < lines.len() { + let line = lines[i]; + if line.ends_with("gauge") { + let parts: Vec<&str> = lines[i + 1].split(" ").collect(); + // println!("in parse_gauge {}:{}", parts[0], parts[1]); + let metric_name = parts[0].to_string(); + let value: f64 = parts[1].parse().unwrap(); + gauge_map.insert(metric_name, value); + } + i += 1; + } + gauge_map + } + + #[test] + fn test_sys_metrics() { + let server = TestHttpServer::new_with_prometheus("test_sys_metrics", false); + let root_token = &server.root_token; + thread::sleep(Duration::from_secs(20)); + + let (status, resp) = server.request_prometheus("GET", "metrics", None, Some(&root_token), None).unwrap(); + assert_eq!(status, 200); + + let mut gauge_map = parse_gauge(resp["metrics"].as_str().unwrap()); + assert_eq!(SYS_METRICS_MAP.len(), gauge_map.len()); + + // load average is not available on Windows + if cfg!(target_os = "windows") { + gauge_map.remove("load_average"); + } + + for (metric, value) in gauge_map { + println!("{}:{}", metric, value); + assert!(value != 0.0); + } + } +} diff --git a/src/test_utils.rs b/src/test_utils.rs index fd76070..ecacca1 100644 --- a/src/test_utils.rs +++ b/src/test_utils.rs @@ -1,34 +1,30 @@ +use libc::c_int; use std::{ collections::HashMap, default::Default, - env, fs, thread, + env, fs, io::prelude::*, path::Path, - sync::{Arc, RwLock, Barrier}, + sync::{Arc, Barrier, RwLock}, + thread, time::{Duration, SystemTime, UNIX_EPOCH}, }; -use libc::c_int; +use actix_web::{ + dev::Server, + middleware::{self, from_fn}, + web, App, HttpResponse, HttpServer, +}; +use anyhow::format_err; use foreign_types::ForeignType; use humantime::parse_duration; use lazy_static::lazy_static; -use serde_json::{json, Map, Value}; -use actix_web::{middleware, web, dev::Server, App, HttpResponse, HttpServer}; -use ureq::AgentBuilder; -use rustls::{ - ClientConfig, - RootCertStore, - pki_types::{ - CertificateDer, - PrivateKeyDer, - } -}; -use tokio::sync::oneshot; -use anyhow::format_err; use openssl::{ - rsa::Rsa, - pkey::{PKey, Private}, + asn1::{Asn1Object, Asn1OctetString, Asn1Time}, hash::MessageDigest, + pkey::{PKey, Private}, + rsa::Rsa, + ssl::{SslAcceptor, SslFiletype, SslMethod, SslVerifyMode, SslVersion}, x509::{ extension::{ AuthorityKeyIdentifier, BasicConstraints, ExtendedKeyUsage, KeyUsage, SubjectAlternativeName, @@ -36,18 +32,24 @@ use openssl::{ }, X509Extension, X509NameBuilder, X509Ref, X509, }, - ssl::{SslAcceptor, SslFiletype, SslMethod, SslVersion, SslVerifyMode}, - asn1::{Asn1Time, Asn1Object, Asn1OctetString}, }; +use rustls::{ + pki_types::{CertificateDer, PrivateKeyDer}, + ClientConfig, RootCertStore, +}; +use serde_json::{json, Map, Value}; +use tokio::sync::oneshot; +use ureq::AgentBuilder; use crate::{ - http, core::{Core, SealConfig}, errors::RvError, + http, logical::{Operation, Request, Response}, + metrics::{manager::MetricsManager, middleware::metrics_midleware, system_metrics::SystemMetrics}, + rv_error_response, storage::{self, Backend}, utils::cert::Certificate, - rv_error_response, }; lazy_static! { @@ -97,9 +99,21 @@ impl TestHttpServer { let mut test_tls_config = None; if tls_enable { - (ca_cert_pem, ca_key_pem) = new_test_cert(true, true, true, "test-ca", None, None, None, None, None, None).unwrap(); - (server_cert_pem, server_key_pem) = new_test_cert(false, true, true, "localhost", Some("localhost"), Some("127.0.0.1"), - None, None, Some(ca_cert_pem.clone()), Some(ca_key_pem.clone())).unwrap(); + (ca_cert_pem, ca_key_pem) = + new_test_cert(true, true, true, "test-ca", None, None, None, None, None, None).unwrap(); + (server_cert_pem, server_key_pem) = new_test_cert( + false, + true, + true, + "localhost", + Some("localhost"), + Some("127.0.0.1"), + None, + None, + Some(ca_cert_pem.clone()), + Some(ca_key_pem.clone()), + ) + .unwrap(); let now = SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs(); let test_certs_dir = env::temp_dir().join(format!("{}/certs/{}-{}", *TEST_DIR, name, now).as_str()); @@ -119,10 +133,7 @@ impl TestHttpServer { let mut key_file = fs::File::create(&key_path).unwrap(); assert!(key_file.write_all(server_key_pem.as_bytes()).is_ok()); - test_tls_config = Some(TestTlsConfig { - cert_path, - key_path, - }); + test_tls_config = Some(TestTlsConfig { cert_path, key_path }); scheme = "https"; } @@ -151,6 +162,88 @@ impl TestHttpServer { } } + pub fn new_with_prometheus(name: &str, tls_enable: bool) -> Self { + let barrier = Arc::new(Barrier::new(2)); + let (stop_tx, stop_rx) = oneshot::channel(); + let (root_token, core) = test_rusty_vault_init(name); + + let mut scheme = "http"; + let mut ca_cert_pem = "".into(); + let mut ca_key_pem = "".into(); + let mut server_cert_pem = "".into(); + let mut server_key_pem = "".into(); + let mut test_tls_config = None; + + if tls_enable { + (ca_cert_pem, ca_key_pem) = + new_test_cert(true, true, true, "test-ca", None, None, None, None, None, None).unwrap(); + (server_cert_pem, server_key_pem) = new_test_cert( + false, + true, + true, + "localhost", + Some("localhost"), + Some("127.0.0.1"), + None, + None, + Some(ca_cert_pem.clone()), + Some(ca_key_pem.clone()), + ) + .unwrap(); + + let now = SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs(); + let test_certs_dir = env::temp_dir().join(format!("{}/certs/{}-{}", *TEST_DIR, name, now).as_str()); + let dir = test_certs_dir.to_string_lossy().into_owned(); + assert!(fs::create_dir_all(&test_certs_dir).is_ok()); + + let ca_path = format!("{}/ca.crt", dir); + let cert_path = format!("{}/server.crt", dir); + let key_path = format!("{}/key.pem", dir); + + let mut ca_file = fs::File::create(&ca_path).unwrap(); + assert!(ca_file.write_all(ca_cert_pem.as_bytes()).is_ok()); + + let mut cert_file = fs::File::create(&cert_path).unwrap(); + assert!(cert_file.write_all(server_cert_pem.as_bytes()).is_ok()); + + let mut key_file = fs::File::create(&key_path).unwrap(); + assert!(key_file.write_all(server_key_pem.as_bytes()).is_ok()); + + test_tls_config = Some(TestTlsConfig { cert_path, key_path }); + + scheme = "https"; + } + + let collection_interval: u64 = 15; + let metrics_manager = Arc::new(RwLock::new(MetricsManager::new(collection_interval))); + let system_metrics = Arc::clone(&metrics_manager.read().unwrap().system_metrics); + + let (server, listen_addr) = + new_test_http_server_with_prometheus(core.clone(), metrics_manager, test_tls_config).unwrap(); + let server_thread = + start_test_http_server_with_prometheus(server, Arc::clone(&barrier), stop_rx, system_metrics); + + barrier.wait(); + + let url_prefix = format!("{}://{}", scheme, listen_addr); + + Self { + name: name.to_string(), + core, + root_token, + tls_enable, + ca_cert_pem, + ca_key_pem, + server_cert_pem, + server_key_pem, + listen_addr, + url_prefix, + mount_path: "".into(), + stop_tx: Some(stop_tx), + thread: Some(server_thread), + } + } + pub fn mount(&mut self, path: &str, mtype: &str) -> Result<(u16, Value), RvError> { let data = json!({ "type": mtype, @@ -181,7 +274,12 @@ impl TestHttpServer { Ok((status, resp)) } - pub fn login(&self, path: &str, data: Option>, tls_client_auth: Option) -> Result<(u16, Value), RvError> { + pub fn login( + &self, + path: &str, + data: Option>, + tls_client_auth: Option, + ) -> Result<(u16, Value), RvError> { self.request("POST", path, data, None, tls_client_auth) } @@ -193,15 +291,32 @@ impl TestHttpServer { self.request("GET", path, None, token, None) } - pub fn write(&self, path: &str, data: Option>, token: Option<&str>) -> Result<(u16, Value), RvError> { + pub fn write( + &self, + path: &str, + data: Option>, + token: Option<&str>, + ) -> Result<(u16, Value), RvError> { self.request("POST", path, data, token, None) } - pub fn delete(&self, path: &str, data: Option>, token: Option<&str>) -> Result<(u16, Value), RvError> { + pub fn delete( + &self, + path: &str, + data: Option>, + token: Option<&str>, + ) -> Result<(u16, Value), RvError> { self.request("DELETE", path, data, token, None) } - pub fn request(&self, method: &str, path: &str, data: Option>, token: Option<&str>, tls_client_auth: Option) -> Result<(u16, Value), RvError> { + pub fn request( + &self, + method: &str, + path: &str, + data: Option>, + token: Option<&str>, + tls_client_auth: Option, + ) -> Result<(u16, Value), RvError> { let url = format!("{}/{}", self.url_prefix, path); println!("request url: {}, method: {}", url, method); let tk = token.unwrap_or(&self.root_token); @@ -222,18 +337,18 @@ impl TestHttpServer { Some((rustls_pemfile::Item::X509Certificate(cert), rest)) => { cert_pem = rest; client_certs.push(cert.into()); - }, + } None => break, _ => return Err(rv_error_response!("client cert format invalid")), } } - let client_key: PrivateKeyDer = match rustls_pemfile::read_one_from_slice(client_auth.key_pem.as_bytes())? { - Some((rustls_pemfile::Item::Pkcs1Key(key), _)) => PrivateKeyDer::Pkcs1(key), - Some((rustls_pemfile::Item::Pkcs8Key(key), _)) => PrivateKeyDer::Pkcs8(key), - _ => return Err(rv_error_response!("client key format invalid")), - }; - + let client_key: PrivateKeyDer = + match rustls_pemfile::read_one_from_slice(client_auth.key_pem.as_bytes())? { + Some((rustls_pemfile::Item::Pkcs1Key(key), _)) => PrivateKeyDer::Pkcs1(key), + Some((rustls_pemfile::Item::Pkcs8Key(key), _)) => PrivateKeyDer::Pkcs8(key), + _ => return Err(rv_error_response!("client key format invalid")), + }; tls_config = ClientConfig::builder() .with_root_certificates(ca_store) @@ -246,9 +361,7 @@ impl TestHttpServer { let mut root_store = RootCertStore::empty(); root_store.add(root_cert)?; - tls_config = ClientConfig::builder() - .with_root_certificates(root_store) - .with_no_client_auth(); + tls_config = ClientConfig::builder().with_root_certificates(root_store).with_no_client_auth(); } let agent = AgentBuilder::new() @@ -266,25 +379,114 @@ impl TestHttpServer { req = req.set("X-RustyVault-Token", tk); } - let response_result = if let Some(send_data) = data { - req.send_json(send_data) + let response_result = if let Some(send_data) = data { req.send_json(send_data) } else { req.call() }; + + match response_result { + Ok(response) => { + let status = response.status(); + if status == 204 { + return Ok((status, json!(""))); + } + let json: Value = response.into_json()?; + return Ok((status, json)); + } + Err(ureq::Error::Status(code, response)) => { + let json: Value = response.into_json()?; + return Ok((code, json)); + } + Err(e) => { + println!("Request failed: {}", e); + return Err(RvError::UreqError { source: e }); + } + } + } + + pub fn request_prometheus( + &self, + method: &str, + path: &str, + data: Option>, + token: Option<&str>, + tls_client_auth: Option, + ) -> Result<(u16, Value), RvError> { + let url = format!("{}/{}", self.url_prefix, path); + println!("request url: {}, method: {}", url, method); + let tk = token.unwrap_or(&self.root_token); + let mut req = if self.tls_enable { + // Create rustls ClientConfig + let tls_config; + if let Some(client_auth) = tls_client_auth { + let ca_pem = pem::parse(client_auth.ca_pem.as_bytes())?; + let ca_cert = CertificateDer::from_slice(ca_pem.contents()); + + let mut ca_store = RootCertStore::empty(); + ca_store.add(ca_cert)?; + + let mut client_certs = vec![]; + let mut cert_pem = client_auth.cert_pem.as_bytes(); + loop { + match rustls_pemfile::read_one_from_slice(cert_pem)? { + Some((rustls_pemfile::Item::X509Certificate(cert), rest)) => { + cert_pem = rest; + client_certs.push(cert.into()); + } + None => break, + _ => return Err(rv_error_response!("client cert format invalid")), + } + } + + let client_key: PrivateKeyDer = + match rustls_pemfile::read_one_from_slice(client_auth.key_pem.as_bytes())? { + Some((rustls_pemfile::Item::Pkcs1Key(key), _)) => PrivateKeyDer::Pkcs1(key), + Some((rustls_pemfile::Item::Pkcs8Key(key), _)) => PrivateKeyDer::Pkcs8(key), + _ => return Err(rv_error_response!("client key format invalid")), + }; + + tls_config = ClientConfig::builder() + .with_root_certificates(ca_store) + .with_client_auth_cert(client_certs, client_key)?; + } else { + let cert_pem = pem::parse(self.ca_cert_pem.as_bytes())?; + let root_cert = CertificateDer::from_slice(cert_pem.contents()); + + // Configure the root certificate + let mut root_store = RootCertStore::empty(); + root_store.add(root_cert)?; + + tls_config = ClientConfig::builder().with_root_certificates(root_store).with_no_client_auth(); + } + + let agent = AgentBuilder::new() + .timeout_connect(Duration::from_secs(10)) + .timeout(Duration::from_secs(30)) + .tls_config(Arc::new(tls_config)) + .build(); + agent.request(&method.to_uppercase(), &url) } else { - req.call() + ureq::request(&method.to_uppercase(), &url) }; + req = req.set("Accept", "application/json"); + if !path.ends_with("/login") { + req = req.set("X-RustyVault-Token", tk); + } + + let response_result = if let Some(send_data) = data { req.send_json(send_data) } else { req.call() }; + match response_result { Ok(response) => { let status = response.status(); if status == 204 { return Ok((status, json!(""))); } - let json: Value = response.into_json()?; - return Ok((status, json)) - }, + let text = response.into_string()?; + let wrapped_json = json!({"metrics":text}); + return Ok((status, wrapped_json)); + } Err(ureq::Error::Status(code, response)) => { let json: Value = response.into_json()?; - return Ok((code, json)) - }, + return Ok((code, json)); + } Err(e) => { println!("Request failed: {}", e); return Err(RvError::UreqError { source: e }); @@ -333,7 +535,7 @@ pub fn new_test_cert( uri_sans: Option<&str>, ttl: Option<&str>, ca_cert_pem: Option, - ca_key_pem: Option + ca_key_pem: Option, ) -> Result<(String, String), RvError> { let not_before = SystemTime::now(); let not_after = not_before + parse_duration(ttl.unwrap_or("5d"))?; @@ -346,13 +548,7 @@ pub fn new_test_cert( let subject = subject_name.build(); - let mut cert = Certificate { - not_before, - not_after, - subject, - is_ca, - ..Default::default() - }; + let mut cert = Certificate { not_before, not_after, subject, is_ca, ..Default::default() }; if let Some(dns) = dns_sans { cert.dns_sans = dns.split(',').map(|s| s.trim().to_string()).collect(); @@ -373,14 +569,11 @@ pub fn new_test_cert( let ca_cert = X509::from_pem(cert_pem.as_bytes())?; let ca_key = PKey::private_key_from_pem(key_pem.as_bytes())?; cert_to_x509(&cert, client_auth, server_auth, Some(&ca_cert), Some(&ca_key), &pkey)? - }, - _ => cert_to_x509(&cert, client_auth, server_auth, None, None, &pkey)? + } + _ => cert_to_x509(&cert, client_auth, server_auth, None, None, &pkey)?, }; - Ok(( - String::from_utf8(x509.to_pem()?)?, - String::from_utf8(pkey.private_key_to_pem_pkcs8()?)?, - )) + Ok((String::from_utf8(x509.to_pem()?)?, String::from_utf8(pkey.private_key_to_pem_pkcs8()?)?)) } pub fn new_test_cert_ext( @@ -393,7 +586,7 @@ pub fn new_test_cert_ext( uri_sans: Option<&str>, ttl: Option<&str>, ca_cert_pem: Option, - ca_key_pem: Option + ca_key_pem: Option, ) -> Result<(String, String), RvError> { let not_before = SystemTime::now(); let not_after = not_before + parse_duration(ttl.unwrap_or("5d"))?; @@ -407,35 +600,34 @@ pub fn new_test_cert_ext( let subject = subject_name.build(); - let extensions = vec![X509Extension::new_from_der( - &Asn1Object::from_str("2.1.1.1").unwrap(), - false, - &Asn1OctetString::new_from_bytes(b"A UTF8String Extension").unwrap(), - ).unwrap(), + let extensions = vec![ X509Extension::new_from_der( - &Asn1Object::from_str("2.1.1.2").unwrap(), - false, - &Asn1OctetString::new_from_bytes(b"A UTF8 Extension").unwrap(), - ).unwrap(), + &Asn1Object::from_str("2.1.1.1").unwrap(), + false, + &Asn1OctetString::new_from_bytes(b"A UTF8String Extension").unwrap(), + ) + .unwrap(), X509Extension::new_from_der( - &Asn1Object::from_str("2.1.1.3").unwrap(), - false, - &Asn1OctetString::new_from_bytes(b"An IA5 Extension").unwrap(), - ).unwrap(), + &Asn1Object::from_str("2.1.1.2").unwrap(), + false, + &Asn1OctetString::new_from_bytes(b"A UTF8 Extension").unwrap(), + ) + .unwrap(), X509Extension::new_from_der( - &Asn1Object::from_str("2.1.1.4").unwrap(), - false, - &Asn1OctetString::new_from_bytes(b"A Visible Extension").unwrap(), - ).unwrap()]; - - let mut cert = Certificate { - not_before, - not_after, - subject, - is_ca, - extensions, - ..Default::default() - }; + &Asn1Object::from_str("2.1.1.3").unwrap(), + false, + &Asn1OctetString::new_from_bytes(b"An IA5 Extension").unwrap(), + ) + .unwrap(), + X509Extension::new_from_der( + &Asn1Object::from_str("2.1.1.4").unwrap(), + false, + &Asn1OctetString::new_from_bytes(b"A Visible Extension").unwrap(), + ) + .unwrap(), + ]; + + let mut cert = Certificate { not_before, not_after, subject, is_ca, extensions, ..Default::default() }; if !is_ca { cert.email_sans = vec!["valid@example.com".into()]; @@ -460,14 +652,11 @@ pub fn new_test_cert_ext( let ca_cert = X509::from_pem(cert_pem.as_bytes())?; let ca_key = PKey::private_key_from_pem(key_pem.as_bytes())?; cert_to_x509(&cert, client_auth, server_auth, Some(&ca_cert), Some(&ca_key), &pkey)? - }, - _ => cert_to_x509(&cert, client_auth, server_auth, None, None, &pkey)? + } + _ => cert_to_x509(&cert, client_auth, server_auth, None, None, &pkey)?, }; - Ok(( - String::from_utf8(x509.to_pem()?)?, - String::from_utf8(pkey.private_key_to_pem_pkcs8()?)?, - )) + Ok((String::from_utf8(x509.to_pem()?)?, String::from_utf8(pkey.private_key_to_pem_pkcs8()?)?)) } pub fn cert_to_x509( @@ -531,7 +720,7 @@ pub fn cert_to_x509( builder.append_extension(BasicConstraints::new().critical().build()?)?; builder.append_extension( KeyUsage::new().critical().non_repudiation().digital_signature().key_encipherment().build()?, - )?; + )?; let mut ext = &mut ExtendedKeyUsage::new(); if client_auth { ext = ext.client_auth(); @@ -560,11 +749,7 @@ pub fn cert_to_x509( Ok(builder.build()) } -pub unsafe fn new_test_crl( - revoked_cert_pem: &str, - ca_cert_pem: &str, - ca_key_pem: &str, -) -> Result { +pub unsafe fn new_test_crl(revoked_cert_pem: &str, ca_cert_pem: &str, ca_key_pem: &str) -> Result { let revoked_cert = X509::from_pem(revoked_cert_pem.as_bytes())?; let ca_cert = X509::from_pem(ca_cert_pem.as_bytes())?; let ca_key = PKey::private_key_from_pem(ca_key_pem.as_bytes())?; @@ -671,11 +856,64 @@ pub fn test_rusty_vault_init(name: &str) -> (String, Arc>) { (root_token, c) } -pub fn new_test_http_server(core: Arc>, tls_config: Option) -> Result<(Server, String), RvError> { +pub fn new_test_http_server( + core: Arc>, + tls_config: Option, +) -> Result<(Server, String), RvError> { + let mut http_server = HttpServer::new(move || { + App::new() + .wrap(middleware::Logger::default()) + .app_data(web::Data::new(core.clone())) + .configure(http::init_service) + .default_service(web::to(|| HttpResponse::NotFound())) + }) + .on_connect(http::request_on_connect_handler); + + if let Some(tls) = tls_config { + let cert_file: &Path = Path::new(&tls.cert_path); + let key_file: &Path = Path::new(&tls.key_path); + + let mut builder = SslAcceptor::mozilla_intermediate(SslMethod::tls())?; + builder + .set_private_key_file(key_file, SslFiletype::PEM) + .map_err(|err| format_err!("unable to read proxy key {} - {}", key_file.display(), err))?; + builder + .set_certificate_chain_file(cert_file) + .map_err(|err| format_err!("unable to read proxy cert {} - {}", cert_file.display(), err))?; + builder.check_private_key()?; + + builder.set_min_proto_version(Some(SslVersion::TLS1_2))?; + builder.set_max_proto_version(Some(SslVersion::TLS1_3))?; + + builder.set_cipher_list( + "TLS_AES_128_GCM_SHA256:TLS_AES_256_GCM_SHA384:TLS_CHACHA20_POLY1305_SHA256:HIGH:!PSK:!SRP:!3DES", + )?; + + builder.set_verify_callback(SslVerifyMode::PEER, |_, _| true); + + http_server = http_server.bind_openssl("127.0.0.1:0", builder)?; + } else { + http_server = http_server.bind("127.0.0.1:0")?; + } + + let addr_info = http_server.addrs().first().unwrap().to_string(); + + println!("HTTP Server is running at {}", addr_info); + + Ok((http_server.run(), addr_info)) +} + +pub fn new_test_http_server_with_prometheus( + core: Arc>, + metrics_manager: Arc>, + tls_config: Option, +) -> Result<(Server, String), RvError> { let mut http_server = HttpServer::new(move || { App::new() .wrap(middleware::Logger::default()) + .wrap(from_fn(metrics_midleware)) .app_data(web::Data::new(core.clone())) + .app_data(web::Data::new(Arc::clone(&metrics_manager))) .configure(http::init_service) .default_service(web::to(|| HttpResponse::NotFound())) }) @@ -697,7 +935,9 @@ pub fn new_test_http_server(core: Arc>, tls_config: Option>, tls_config: Option, stop_rx: oneshot::Receiver<()>) -> thread::JoinHandle<()> { +pub fn start_test_http_server( + server: Server, + barrier: Arc, + stop_rx: oneshot::Receiver<()>, +) -> thread::JoinHandle<()> { let server_thread = thread::spawn(move || { let sys = actix_web::rt::System::new(); @@ -743,6 +987,46 @@ pub fn start_test_http_server(server: Server, barrier: Arc, stop_rx: on server_thread } +pub fn start_test_http_server_with_prometheus( + server: Server, + barrier: Arc, + stop_rx: oneshot::Receiver<()>, + system_metrics: Arc, +) -> thread::JoinHandle<()> { + let server_thread = thread::spawn(move || { + let sys = actix_web::rt::System::new(); + + let server_future = async { + server.await.unwrap(); + }; + + let stop_future = async { + stop_rx.await.ok(); + }; + + let system_metrics_fucture = async { + system_metrics.start_collecting().await; + }; + + barrier.wait(); + + let _ = sys.block_on(async { + tokio::select! { + _ = server_future => {}, + _ = system_metrics_fucture => {}, + _ = stop_future => { + actix_rt::System::current().stop(); + } + } + }); + + let _ = sys.run().unwrap(); + println!("HTTP Server has stopped."); + }); + + server_thread +} + pub fn test_list_api(core: &Core, token: &str, path: &str, is_ok: bool) -> Result, RvError> { let mut req = Request::new(path); req.operation = Operation::List;