From e8059c7b516a94d3f9c175c1afd8276df803dbaf Mon Sep 17 00:00:00 2001 From: k88hudson-cfa Date: Sun, 1 Dec 2024 19:20:25 -0800 Subject: [PATCH] Add run_with_args --- Cargo.toml | 1 + examples/runner/main.rs | 13 +++ src/global_properties.rs | 32 +++--- src/lib.rs | 1 + src/random.rs | 8 ++ src/runner.rs | 126 +++++++++++++++++++++++ tests/data/global_properties_runner.json | 5 + 7 files changed, 172 insertions(+), 14 deletions(-) create mode 100644 examples/runner/main.rs create mode 100644 src/runner.rs create mode 100644 tests/data/global_properties_runner.json diff --git a/Cargo.toml b/Cargo.toml index 2769f8ae..f065fcb6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,6 +23,7 @@ seq-macro = "0.3.5" paste = "1.0.15" ctor = "0.2.8" once_cell = "1.20.2" +clap = { version = "4.5.21", features = ["derive"] } [dev-dependencies] rand_distr = "0.4.3" diff --git a/examples/runner/main.rs b/examples/runner/main.rs new file mode 100644 index 00000000..60e10ba8 --- /dev/null +++ b/examples/runner/main.rs @@ -0,0 +1,13 @@ +use ixa::runner::run_with_args; + +fn main() { + // Try running this with `cargo run --example runner -- --seed 42` + run_with_args(|context, args| { + context.add_plan(1.0, |_| { + println!("Hello, world!"); + }); + println!("{}", args.seed); + Ok(()) + }) + .unwrap(); +} diff --git a/src/global_properties.rs b/src/global_properties.rs index 6148d02a..4ba259ef 100644 --- a/src/global_properties.rs +++ b/src/global_properties.rs @@ -49,20 +49,24 @@ where for<'de> ::Value: serde::Deserialize<'de>, { let properties = GLOBAL_PROPERTIES.lock().unwrap(); - properties.borrow_mut().insert( - name.to_string(), - Arc::new( - |context: &mut Context, name, value| -> Result<(), IxaError> { - let val: T::Value = serde_json::from_value(value)?; - T::validate(&val)?; - if context.get_global_property_value(T::new()).is_some() { - return Err(IxaError::IxaError(format!("Duplicate property {name}"))); - } - context.set_global_property_value(T::new(), val); - Ok(()) - }, - ), - ); + // We should not define duplicate properties + assert!(properties + .borrow_mut() + .insert( + name.to_string(), + Arc::new( + |context: &mut Context, name, value| -> Result<(), IxaError> { + let val: T::Value = serde_json::from_value(value)?; + T::validate(&val)?; + if context.get_global_property_value(T::new()).is_some() { + return Err(IxaError::IxaError(format!("Duplicate property {name}"))); + } + context.set_global_property_value(T::new(), val); + Ok(()) + }, + ), + ) + .is_none()); } #[allow(clippy::missing_panics_doc)] diff --git a/src/lib.rs b/src/lib.rs index 79cd64d0..3da9a5e5 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -34,3 +34,4 @@ pub mod people; pub mod plan; pub mod random; pub mod report; +pub mod runner; diff --git a/src/random.rs b/src/random.rs index 0f600554..61e9e0ef 100644 --- a/src/random.rs +++ b/src/random.rs @@ -94,6 +94,8 @@ fn get_rng(context: &Context) -> RefMut { pub trait ContextRandomExt { fn init_random(&mut self, base_seed: u64); + fn get_base_seed(&self) -> u64; + /// Gets a random sample from the random number generator associated with the given /// `RngId` by applying the specified sampler function. If the Rng has not been used /// before, one will be created with the base seed you defined in `set_base_random_seed`. @@ -154,6 +156,12 @@ impl ContextRandomExt for Context { rng_map.clear(); } + fn get_base_seed(&self) -> u64 { + self.get_data_container(RngPlugin) + .expect("You must initialize the random number generator with a base seed") + .base_seed + } + fn sample( &self, _rng_id: R, diff --git a/src/runner.rs b/src/runner.rs new file mode 100644 index 00000000..37cff45d --- /dev/null +++ b/src/runner.rs @@ -0,0 +1,126 @@ +use std::path::{Path, PathBuf}; + +use crate::context::Context; +use crate::error::IxaError; +use crate::global_properties::ContextGlobalPropertiesExt; +use crate::random::ContextRandomExt; +use crate::report::ContextReportExt; +use clap::Parser; + +#[derive(Parser, Debug)] +#[command(version, about, long_about = None)] +pub struct Args { + /// Random seed + #[arg(short, long, default_value = "0")] + pub seed: u64, + + /// Optional path for a global properties config file + #[arg(short, long, default_value = "")] + pub config: String, + + /// Optional path for report output + #[arg(short, long, default_value = "")] + pub output_dir: String, +} + +#[allow(clippy::missing_errors_doc)] +pub fn run_with_args(load: F) -> Result<(), Box> +where + F: Fn(&mut Context, Args) -> Result<(), IxaError>, +{ + run_with_args_internal(Args::parse(), load) +} + +fn run_with_args_internal(args: Args, load: F) -> Result<(), Box> +where + F: Fn(&mut Context, Args) -> Result<(), IxaError>, +{ + // Instantiate a context + let mut context = Context::new(); + + // Optionally set global properties from a file + if !args.config.is_empty() { + println!("Loading global properties from: {}", args.config); + let config_path = Path::new(&args.config); + context.load_global_properties(config_path)?; + } + + // Optionally set output dir for reports + if !args.output_dir.is_empty() { + let output_dir = PathBuf::from(&args.output_dir); + let report_config = context.report_options(); + report_config.directory(output_dir); + } + + context.init_random(args.seed); + + // Run the provided Fn + load(&mut context, args)?; + + // Execute the context + context.execute(); + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::define_global_property; + use serde::Deserialize; + + #[test] + fn test_run_with_args_default() { + let result = run_with_args(|_, _| Ok(())); + assert!(result.is_ok()); + } + + #[test] + fn test_run_with_random_seed() { + let test_args = Args { + seed: 42, + config: String::new(), + output_dir: String::new(), + }; + let result = run_with_args_internal(test_args, |ctx, _| { + assert_eq!(ctx.get_base_seed(), 42); + Ok(()) + }); + assert!(result.is_ok()); + } + + #[derive(Deserialize)] + pub struct RunnerPropertyType { + field_int: u32, + } + define_global_property!(RunnerProperty, RunnerPropertyType); + + #[test] + fn test_run_with_config_path() { + let test_args = Args { + seed: 42, + config: "tests/data/global_properties_runner.json".to_string(), + output_dir: String::new(), + }; + let result = run_with_args_internal(test_args, |ctx, _| { + let p3 = ctx.get_global_property_value(RunnerProperty).unwrap(); + assert_eq!(p3.field_int, 0); + Ok(()) + }); + assert!(result.is_ok()); + } + + #[test] + fn test_run_with_output_dir() { + let test_args = Args { + seed: 42, + config: String::new(), + output_dir: "data".to_string(), + }; + let result = run_with_args_internal(test_args, |ctx, _| { + let output_dir = &ctx.report_options().directory; + assert_eq!(output_dir, &PathBuf::from("data")); + Ok(()) + }); + assert!(result.is_ok()); + } +} diff --git a/tests/data/global_properties_runner.json b/tests/data/global_properties_runner.json new file mode 100644 index 00000000..7974088f --- /dev/null +++ b/tests/data/global_properties_runner.json @@ -0,0 +1,5 @@ +{ + "ixa.RunnerProperty": { + "field_int": 0 + } +}