From 8ae632131abc2bea8f05b8ede8e8e8b99197d6de Mon Sep 17 00:00:00 2001 From: Kate Hudson <145493147+k88hudson-cfa@users.noreply.github.com> Date: Thu, 12 Dec 2024 16:44:01 -0500 Subject: [PATCH] Add runner module for arg parsing and setup (#114) --- Cargo.toml | 6 + examples/runner/main.rs | 23 +++ src/global_properties.rs | 30 +-- src/lib.rs | 2 + src/runner.rs | 224 +++++++++++++++++++++++ tests/bin/runner_test_custom_args.rs | 18 ++ tests/data/global_properties_runner.json | 5 + 7 files changed, 295 insertions(+), 13 deletions(-) create mode 100644 examples/runner/main.rs create mode 100644 src/runner.rs create mode 100644 tests/bin/runner_test_custom_args.rs create mode 100644 tests/data/global_properties_runner.json diff --git a/Cargo.toml b/Cargo.toml index 24b3d367..b6cd2751 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,8 +21,14 @@ seq-macro = "0.3.5" paste = "1.0.15" ctor = "0.2.8" once_cell = "1.20.2" +clap = { version = "4.5.21", features = ["derive"] } +assert_cmd = "2.0.16" [dev-dependencies] rand_distr = "0.4.3" tempfile = "3.3" ordered-float = "4.3.0" + +[[bin]] +name = "runner_test_custom_args" +path = "tests/bin/runner_test_custom_args.rs" diff --git a/examples/runner/main.rs b/examples/runner/main.rs new file mode 100644 index 00000000..ce08e162 --- /dev/null +++ b/examples/runner/main.rs @@ -0,0 +1,23 @@ +use clap::Args; +use ixa::runner::run_with_custom_args; + +#[derive(Args, Debug)] +struct Extra { + #[arg(short, long)] + foo: bool, +} + +fn main() { + // Try running this with `cargo run --example runner -- --seed 42` + run_with_custom_args(|context, args, extra: Option| { + context.add_plan(1.0, |_| { + println!("Hello, world!"); + }); + println!("{}", args.random_seed); + if let Some(extra) = extra { + println!("{}", extra.foo); + } + Ok(()) + }) + .unwrap(); +} diff --git a/src/global_properties.rs b/src/global_properties.rs index b23923e7..e16c3f1e 100644 --- a/src/global_properties.rs +++ b/src/global_properties.rs @@ -49,19 +49,23 @@ where for<'de> ::Value: serde::Deserialize<'de>, { let properties = GLOBAL_PROPERTIES.lock().unwrap(); - properties.borrow_mut().insert( - name.to_string(), - Arc::new( - |context: &mut Context, name, value| -> Result<(), IxaError> { - let val: T::Value = serde_json::from_value(value)?; - if context.get_global_property_value(T::new()).is_some() { - return Err(IxaError::IxaError(format!("Duplicate property {name}"))); - } - context.set_global_property_value(T::new(), val)?; - Ok(()) - }, - ), - ); + assert!(properties + .borrow_mut() + .insert( + name.to_string(), + Arc::new( + |context: &mut Context, name, value| -> Result<(), IxaError> { + let val: T::Value = serde_json::from_value(value)?; + T::validate(&val)?; + if context.get_global_property_value(T::new()).is_some() { + return Err(IxaError::IxaError(format!("Duplicate property {name}"))); + } + context.set_global_property_value(T::new(), val)?; + Ok(()) + }, + ), + ) + .is_none()); } #[allow(clippy::missing_panics_doc)] diff --git a/src/lib.rs b/src/lib.rs index cfdaa3b2..6f9663cf 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -49,3 +49,5 @@ pub use random::{ContextRandomExt, RngId}; pub mod report; pub use report::{ConfigReportOptions, ContextReportExt, Report}; +pub mod runner; +pub use runner::{run_with_args, run_with_custom_args, BaseArgs}; diff --git a/src/runner.rs b/src/runner.rs new file mode 100644 index 00000000..57360fe3 --- /dev/null +++ b/src/runner.rs @@ -0,0 +1,224 @@ +use std::path::{Path, PathBuf}; + +use crate::context::Context; +use crate::error::IxaError; +use crate::global_properties::ContextGlobalPropertiesExt; +use crate::random::ContextRandomExt; +use crate::report::ContextReportExt; +use clap::{Args, Command, FromArgMatches as _}; + +/// Default cli arguments for ixa runner +#[derive(Args, Debug)] +pub struct BaseArgs { + /// Random seed + #[arg(short, long, default_value = "0")] + pub random_seed: u64, + + /// Optional path for a global properties config file + #[arg(short, long, default_value = "")] + pub config: String, + + /// Optional path for report output + #[arg(short, long, default_value = "")] + pub output_dir: String, +} + +#[derive(Args)] +pub struct PlaceholderCustom {} + +fn create_ixa_cli() -> Command { + let cli = Command::new("ixa"); + BaseArgs::augment_args(cli) +} + +/// Runs a simulation with custom cli arguments. +/// +/// This function allows you to define custom arguments and a setup function +/// +/// # Parameters +/// - `setup_fn`: A function that takes a mutable reference to a `Context`, a `BaseArgs` struct, +/// a Option where A is the custom cli arguments struct +/// +/// # Errors +/// Returns an error if argument parsing or the setup function fails +#[allow(clippy::missing_errors_doc)] +pub fn run_with_custom_args(setup_fn: F) -> Result<(), Box> +where + A: Args, + F: Fn(&mut Context, BaseArgs, Option) -> Result<(), IxaError>, +{ + let mut cli = create_ixa_cli(); + cli = A::augment_args(cli); + let matches = cli.get_matches(); + + let base_args_matches = BaseArgs::from_arg_matches(&matches)?; + let custom_matches = A::from_arg_matches(&matches)?; + run_with_args_internal(base_args_matches, Some(custom_matches), setup_fn) +} + +/// Runs a simulation with default cli arguments +/// +/// This function parses command line arguments allows you to define a setup function +/// +/// # Parameters +/// - `setup_fn`: A function that takes a mutable reference to a `Context`and `BaseArgs` struct +/// +/// # Errors +/// Returns an error if argument parsing or the setup function fails +#[allow(clippy::missing_errors_doc)] +pub fn run_with_args(setup_fn: F) -> Result<(), Box> +where + F: Fn(&mut Context, BaseArgs, Option) -> Result<(), IxaError>, +{ + let cli = create_ixa_cli(); + let matches = cli.get_matches(); + + let base_args_matches = BaseArgs::from_arg_matches(&matches)?; + run_with_args_internal(base_args_matches, None, setup_fn) +} + +fn run_with_args_internal( + args: BaseArgs, + custom_args: Option, + setup_fn: F, +) -> Result<(), Box> +where + F: Fn(&mut Context, BaseArgs, Option) -> Result<(), IxaError>, +{ + // Instantiate a context + let mut context = Context::new(); + + // Optionally set global properties from a file + if !args.config.is_empty() { + println!("Loading global properties from: {}", args.config); + let config_path = Path::new(&args.config); + context.load_global_properties(config_path)?; + } + + // Optionally set output dir for reports + if !args.output_dir.is_empty() { + let output_dir = PathBuf::from(&args.output_dir); + let report_config = context.report_options(); + report_config.directory(output_dir); + } + + context.init_random(args.random_seed); + + // Run the provided Fn + setup_fn(&mut context, args, custom_args)?; + + // Execute the context + context.execute(); + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{define_global_property, define_rng}; + use assert_cmd::Command; + use serde::Deserialize; + + #[derive(Args, Debug)] + struct CustomArgs { + #[arg(short, long, default_value = "0")] + field: u32, + } + + #[test] + fn test_run_with_custom_args() { + let result = run_with_custom_args(|_, _, _: Option| Ok(())); + assert!(result.is_ok()); + } + + #[test] + fn test_cli_invocation_with_custom_args() { + // Note this target is defined in the bin section of Cargo.toml + // and the entry point is in tests/bin/runner_test_custom_args + Command::cargo_bin("runner_test_custom_args") + .unwrap() + .args(["--field", "42"]) + .assert() + .success() + .stdout("42\n"); + } + + #[test] + fn test_run_with_args() { + let result = run_with_args(|_, _, _| Ok(())); + assert!(result.is_ok()); + } + + #[test] + fn test_run_with_random_seed() { + let test_args = BaseArgs { + random_seed: 42, + config: String::new(), + output_dir: String::new(), + }; + + // Use a comparison context to verify the random seed was set + let mut compare_ctx = Context::new(); + compare_ctx.init_random(42); + define_rng!(TestRng); + let result = run_with_args_internal(test_args, None, |ctx, _, _: Option<()>| { + assert_eq!( + ctx.sample_range(TestRng, 0..100), + compare_ctx.sample_range(TestRng, 0..100) + ); + Ok(()) + }); + assert!(result.is_ok()); + } + + #[derive(Deserialize)] + pub struct RunnerPropertyType { + field_int: u32, + } + define_global_property!(RunnerProperty, RunnerPropertyType); + + #[test] + fn test_run_with_config_path() { + let test_args = BaseArgs { + random_seed: 42, + config: "tests/data/global_properties_runner.json".to_string(), + output_dir: String::new(), + }; + let result = run_with_args_internal(test_args, None, |ctx, _, _: Option<()>| { + let p3 = ctx.get_global_property_value(RunnerProperty).unwrap(); + assert_eq!(p3.field_int, 0); + Ok(()) + }); + assert!(result.is_ok()); + } + + #[test] + fn test_run_with_output_dir() { + let test_args = BaseArgs { + random_seed: 42, + config: String::new(), + output_dir: "data".to_string(), + }; + let result = run_with_args_internal(test_args, None, |ctx, _, _: Option<()>| { + let output_dir = &ctx.report_options().directory; + assert_eq!(output_dir, &PathBuf::from("data")); + Ok(()) + }); + assert!(result.is_ok()); + } + + #[test] + fn test_run_with_custom() { + let test_args = BaseArgs { + random_seed: 42, + config: String::new(), + output_dir: String::new(), + }; + let custom = CustomArgs { field: 42 }; + let result = run_with_args_internal(test_args, Some(custom), |_, _, c| { + assert_eq!(c.unwrap().field, 42); + Ok(()) + }); + assert!(result.is_ok()); + } +} diff --git a/tests/bin/runner_test_custom_args.rs b/tests/bin/runner_test_custom_args.rs new file mode 100644 index 00000000..2870bd6c --- /dev/null +++ b/tests/bin/runner_test_custom_args.rs @@ -0,0 +1,18 @@ +use clap::Args; +use ixa::runner::run_with_custom_args; + +#[derive(Args, Debug)] +struct Extra { + #[arg(short, long)] + field: u32, +} + +fn main() { + run_with_custom_args(|_context, _args, extra: Option| { + if let Some(extra) = extra { + println!("{}", extra.field); + } + Ok(()) + }) + .unwrap(); +} diff --git a/tests/data/global_properties_runner.json b/tests/data/global_properties_runner.json new file mode 100644 index 00000000..7974088f --- /dev/null +++ b/tests/data/global_properties_runner.json @@ -0,0 +1,5 @@ +{ + "ixa.RunnerProperty": { + "field_int": 0 + } +}