diff --git a/examples/runner/main.rs b/examples/runner/main.rs index 60e10ba..106be5d 100644 --- a/examples/runner/main.rs +++ b/examples/runner/main.rs @@ -1,12 +1,22 @@ -use ixa::runner::run_with_args; +use clap::Args; +use ixa::runner::run_with_custom_args; + +#[derive(Args, Debug)] +struct Extra { + #[arg(short, long)] + foo: bool, +} fn main() { // Try running this with `cargo run --example runner -- --seed 42` - run_with_args(|context, args| { + run_with_custom_args(|context, args, extra: Option| { context.add_plan(1.0, |_| { println!("Hello, world!"); }); println!("{}", args.seed); + if let Some(extra) = extra { + println!("{}", extra.foo); + } Ok(()) }) .unwrap(); diff --git a/src/runner.rs b/src/runner.rs index 37cff45..6795e89 100644 --- a/src/runner.rs +++ b/src/runner.rs @@ -5,11 +5,10 @@ use crate::error::IxaError; use crate::global_properties::ContextGlobalPropertiesExt; use crate::random::ContextRandomExt; use crate::report::ContextReportExt; -use clap::Parser; +use clap::{Args, Command, FromArgMatches as _}; -#[derive(Parser, Debug)] -#[command(version, about, long_about = None)] -pub struct Args { +#[derive(Args, Debug)] +pub struct BaseArgs { /// Random seed #[arg(short, long, default_value = "0")] pub seed: u64, @@ -23,18 +22,41 @@ pub struct Args { pub output_dir: String, } +#[derive(Args)] +pub struct PlaceholderCustom {} + #[allow(clippy::missing_errors_doc)] -pub fn run_with_args(load: F) -> Result<(), Box> +pub fn run_with_custom_args(load: F) -> Result<(), Box> where - F: Fn(&mut Context, Args) -> Result<(), IxaError>, + A: Args, + F: Fn(&mut Context, BaseArgs, Option) -> Result<(), IxaError>, { - run_with_args_internal(Args::parse(), load) + let cli = Command::new("Ixa"); + let cli = BaseArgs::augment_args(cli); + let cli = A::augment_args(cli); + let matches = cli.get_matches(); + + let base_args_matches = BaseArgs::from_arg_matches(&matches)?; + let custom_matches = A::from_arg_matches(&matches)?; + + run_with_args_internal(base_args_matches, Some(custom_matches), load) } -fn run_with_args_internal(args: Args, load: F) -> Result<(), Box> +#[allow(clippy::missing_errors_doc)] +pub fn run_with_args(load: F) -> Result<(), Box> where - F: Fn(&mut Context, Args) -> Result<(), IxaError>, + F: Fn(&mut Context, BaseArgs, Option) -> Result<(), IxaError>, { + let cli = Command::new("Ixa"); + let cli = BaseArgs::augment_args(cli); + let matches = cli.get_matches(); + + let base_args_matches = BaseArgs::from_arg_matches(&matches)?; + + run_with_args_internal(base_args_matches, None, load) +} + +fn setup_context(args: &BaseArgs) -> Result { // Instantiate a context let mut context = Context::new(); @@ -53,9 +75,23 @@ where } context.init_random(args.seed); + Ok(context) +} + +fn run_with_args_internal( + args: BaseArgs, + custom_args: Option, + load: F, +) -> Result<(), Box> +where + A: Args, + F: Fn(&mut Context, BaseArgs, Option) -> Result<(), IxaError>, +{ + // Create a context + let mut context = setup_context(&args)?; // Run the provided Fn - load(&mut context, args)?; + load(&mut context, args, custom_args)?; // Execute the context context.execute(); @@ -68,20 +104,32 @@ mod tests { use crate::define_global_property; use serde::Deserialize; + #[derive(Args, Debug)] + struct CustomArgs { + #[arg(short, long, default_value = "0")] + field: u32, + } + + #[test] + fn test_run_with_custom_args() { + let result = run_with_custom_args(|_, _, _: Option| Ok(())); + assert!(result.is_ok()); + } + #[test] - fn test_run_with_args_default() { - let result = run_with_args(|_, _| Ok(())); + fn test_run_with_args() { + let result = run_with_args(|_, _, _| Ok(())); assert!(result.is_ok()); } #[test] fn test_run_with_random_seed() { - let test_args = Args { + let test_args = BaseArgs { seed: 42, config: String::new(), output_dir: String::new(), }; - let result = run_with_args_internal(test_args, |ctx, _| { + let result = run_with_args_internal(test_args, None, |ctx, _, _: Option| { assert_eq!(ctx.get_base_seed(), 42); Ok(()) }); @@ -96,12 +144,12 @@ mod tests { #[test] fn test_run_with_config_path() { - let test_args = Args { + let test_args = BaseArgs { seed: 42, config: "tests/data/global_properties_runner.json".to_string(), output_dir: String::new(), }; - let result = run_with_args_internal(test_args, |ctx, _| { + let result = run_with_args_internal(test_args, None, |ctx, _, _: Option| { let p3 = ctx.get_global_property_value(RunnerProperty).unwrap(); assert_eq!(p3.field_int, 0); Ok(()) @@ -111,12 +159,12 @@ mod tests { #[test] fn test_run_with_output_dir() { - let test_args = Args { + let test_args = BaseArgs { seed: 42, config: String::new(), output_dir: "data".to_string(), }; - let result = run_with_args_internal(test_args, |ctx, _| { + let result = run_with_args_internal(test_args, None, |ctx, _, _: Option| { let output_dir = &ctx.report_options().directory; assert_eq!(output_dir, &PathBuf::from("data")); Ok(())