Skip to content

Commit

Permalink
custom args
Browse files Browse the repository at this point in the history
  • Loading branch information
k88hudson-cfa committed Dec 2, 2024
1 parent e8059c7 commit 043d77b
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 20 deletions.
14 changes: 12 additions & 2 deletions examples/runner/main.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,22 @@
use ixa::runner::run_with_args;
use clap::Args;
use ixa::runner::run_with_custom_args;

#[derive(Args, Debug)]
struct Extra {
#[arg(short, long)]
foo: bool,
}

fn main() {
// Try running this with `cargo run --example runner -- --seed 42`
run_with_args(|context, args| {
run_with_custom_args(|context, args, extra: Option<Extra>| {
context.add_plan(1.0, |_| {
println!("Hello, world!");
});
println!("{}", args.seed);
if let Some(extra) = extra {
println!("{}", extra.foo);
}
Ok(())
})
.unwrap();
Expand Down
84 changes: 66 additions & 18 deletions src/runner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,10 @@ use crate::error::IxaError;
use crate::global_properties::ContextGlobalPropertiesExt;
use crate::random::ContextRandomExt;
use crate::report::ContextReportExt;
use clap::Parser;
use clap::{Args, Command, FromArgMatches as _};

#[derive(Parser, Debug)]
#[command(version, about, long_about = None)]
pub struct Args {
#[derive(Args, Debug)]
pub struct BaseArgs {
/// Random seed
#[arg(short, long, default_value = "0")]
pub seed: u64,
Expand All @@ -23,18 +22,41 @@ pub struct Args {
pub output_dir: String,
}

#[derive(Args)]
pub struct PlaceholderCustom {}

#[allow(clippy::missing_errors_doc)]
pub fn run_with_args<F>(load: F) -> Result<(), Box<dyn std::error::Error>>
pub fn run_with_custom_args<A, F>(load: F) -> Result<(), Box<dyn std::error::Error>>
where
F: Fn(&mut Context, Args) -> Result<(), IxaError>,
A: Args,
F: Fn(&mut Context, BaseArgs, Option<A>) -> Result<(), IxaError>,
{
run_with_args_internal(Args::parse(), load)
let cli = Command::new("Ixa");
let cli = BaseArgs::augment_args(cli);
let cli = A::augment_args(cli);
let matches = cli.get_matches();

let base_args_matches = BaseArgs::from_arg_matches(&matches)?;
let custom_matches = A::from_arg_matches(&matches)?;

run_with_args_internal(base_args_matches, Some(custom_matches), load)
}

fn run_with_args_internal<F>(args: Args, load: F) -> Result<(), Box<dyn std::error::Error>>
#[allow(clippy::missing_errors_doc)]
pub fn run_with_args<F>(load: F) -> Result<(), Box<dyn std::error::Error>>
where
F: Fn(&mut Context, Args) -> Result<(), IxaError>,
F: Fn(&mut Context, BaseArgs, Option<PlaceholderCustom>) -> Result<(), IxaError>,
{
let cli = Command::new("Ixa");
let cli = BaseArgs::augment_args(cli);
let matches = cli.get_matches();

let base_args_matches = BaseArgs::from_arg_matches(&matches)?;

run_with_args_internal(base_args_matches, None, load)
}

fn setup_context(args: &BaseArgs) -> Result<Context, IxaError> {
// Instantiate a context
let mut context = Context::new();

Expand All @@ -53,9 +75,23 @@ where
}

context.init_random(args.seed);
Ok(context)
}

fn run_with_args_internal<A, F>(
args: BaseArgs,
custom_args: Option<A>,
load: F,
) -> Result<(), Box<dyn std::error::Error>>
where
A: Args,
F: Fn(&mut Context, BaseArgs, Option<A>) -> Result<(), IxaError>,
{
// Create a context
let mut context = setup_context(&args)?;

// Run the provided Fn
load(&mut context, args)?;
load(&mut context, args, custom_args)?;

// Execute the context
context.execute();
Expand All @@ -68,20 +104,32 @@ mod tests {
use crate::define_global_property;
use serde::Deserialize;

#[derive(Args, Debug)]
struct CustomArgs {
#[arg(short, long, default_value = "0")]
field: u32,
}

#[test]
fn test_run_with_custom_args() {
let result = run_with_custom_args(|_, _, _: Option<CustomArgs>| Ok(()));
assert!(result.is_ok());
}

#[test]
fn test_run_with_args_default() {
let result = run_with_args(|_, _| Ok(()));
fn test_run_with_args() {
let result = run_with_args(|_, _, _| Ok(()));
assert!(result.is_ok());
}

#[test]
fn test_run_with_random_seed() {
let test_args = Args {
let test_args = BaseArgs {
seed: 42,
config: String::new(),
output_dir: String::new(),
};
let result = run_with_args_internal(test_args, |ctx, _| {
let result = run_with_args_internal(test_args, None, |ctx, _, _: Option<CustomArgs>| {
assert_eq!(ctx.get_base_seed(), 42);
Ok(())
});
Expand All @@ -96,12 +144,12 @@ mod tests {

#[test]
fn test_run_with_config_path() {
let test_args = Args {
let test_args = BaseArgs {
seed: 42,
config: "tests/data/global_properties_runner.json".to_string(),
output_dir: String::new(),
};
let result = run_with_args_internal(test_args, |ctx, _| {
let result = run_with_args_internal(test_args, None, |ctx, _, _: Option<CustomArgs>| {
let p3 = ctx.get_global_property_value(RunnerProperty).unwrap();
assert_eq!(p3.field_int, 0);
Ok(())
Expand All @@ -111,12 +159,12 @@ mod tests {

#[test]
fn test_run_with_output_dir() {
let test_args = Args {
let test_args = BaseArgs {
seed: 42,
config: String::new(),
output_dir: "data".to_string(),
};
let result = run_with_args_internal(test_args, |ctx, _| {
let result = run_with_args_internal(test_args, None, |ctx, _, _: Option<CustomArgs>| {
let output_dir = &ctx.report_options().directory;
assert_eq!(output_dir, &PathBuf::from("data"));
Ok(())
Expand Down

0 comments on commit 043d77b

Please sign in to comment.