Skip to content

Commit

Permalink
Add run_with_args
Browse files Browse the repository at this point in the history
  • Loading branch information
k88hudson-cfa committed Dec 2, 2024
1 parent 1d9f8e3 commit e8059c7
Show file tree
Hide file tree
Showing 7 changed files with 172 additions and 14 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ seq-macro = "0.3.5"
paste = "1.0.15"
ctor = "0.2.8"
once_cell = "1.20.2"
clap = { version = "4.5.21", features = ["derive"] }

[dev-dependencies]
rand_distr = "0.4.3"
Expand Down
13 changes: 13 additions & 0 deletions examples/runner/main.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
use ixa::runner::run_with_args;

fn main() {
// Try running this with `cargo run --example runner -- --seed 42`
run_with_args(|context, args| {
context.add_plan(1.0, |_| {
println!("Hello, world!");
});
println!("{}", args.seed);
Ok(())
})
.unwrap();
}
32 changes: 18 additions & 14 deletions src/global_properties.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,20 +49,24 @@ where
for<'de> <T as GlobalProperty>::Value: serde::Deserialize<'de>,
{
let properties = GLOBAL_PROPERTIES.lock().unwrap();
properties.borrow_mut().insert(
name.to_string(),
Arc::new(
|context: &mut Context, name, value| -> Result<(), IxaError> {
let val: T::Value = serde_json::from_value(value)?;
T::validate(&val)?;
if context.get_global_property_value(T::new()).is_some() {
return Err(IxaError::IxaError(format!("Duplicate property {name}")));
}
context.set_global_property_value(T::new(), val);
Ok(())
},
),
);
// We should not define duplicate properties
assert!(properties
.borrow_mut()
.insert(
name.to_string(),
Arc::new(
|context: &mut Context, name, value| -> Result<(), IxaError> {
let val: T::Value = serde_json::from_value(value)?;
T::validate(&val)?;
if context.get_global_property_value(T::new()).is_some() {
return Err(IxaError::IxaError(format!("Duplicate property {name}")));
}
context.set_global_property_value(T::new(), val);
Ok(())
},
),
)
.is_none());
}

#[allow(clippy::missing_panics_doc)]
Expand Down
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,4 @@ pub mod people;
pub mod plan;
pub mod random;
pub mod report;
pub mod runner;
8 changes: 8 additions & 0 deletions src/random.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,8 @@ fn get_rng<R: RngId + 'static>(context: &Context) -> RefMut<R::RngType> {
pub trait ContextRandomExt {
fn init_random(&mut self, base_seed: u64);

fn get_base_seed(&self) -> u64;

/// Gets a random sample from the random number generator associated with the given
/// `RngId` by applying the specified sampler function. If the Rng has not been used
/// before, one will be created with the base seed you defined in `set_base_random_seed`.
Expand Down Expand Up @@ -154,6 +156,12 @@ impl ContextRandomExt for Context {
rng_map.clear();
}

fn get_base_seed(&self) -> u64 {
self.get_data_container(RngPlugin)
.expect("You must initialize the random number generator with a base seed")
.base_seed
}

fn sample<R: RngId + 'static, T>(
&self,
_rng_id: R,
Expand Down
126 changes: 126 additions & 0 deletions src/runner.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
use std::path::{Path, PathBuf};

use crate::context::Context;
use crate::error::IxaError;
use crate::global_properties::ContextGlobalPropertiesExt;
use crate::random::ContextRandomExt;
use crate::report::ContextReportExt;
use clap::Parser;

#[derive(Parser, Debug)]
#[command(version, about, long_about = None)]
pub struct Args {
/// Random seed
#[arg(short, long, default_value = "0")]
pub seed: u64,

/// Optional path for a global properties config file
#[arg(short, long, default_value = "")]
pub config: String,

/// Optional path for report output
#[arg(short, long, default_value = "")]
pub output_dir: String,
}

#[allow(clippy::missing_errors_doc)]
pub fn run_with_args<F>(load: F) -> Result<(), Box<dyn std::error::Error>>
where
F: Fn(&mut Context, Args) -> Result<(), IxaError>,
{
run_with_args_internal(Args::parse(), load)
}

fn run_with_args_internal<F>(args: Args, load: F) -> Result<(), Box<dyn std::error::Error>>
where
F: Fn(&mut Context, Args) -> Result<(), IxaError>,
{
// Instantiate a context
let mut context = Context::new();

// Optionally set global properties from a file
if !args.config.is_empty() {
println!("Loading global properties from: {}", args.config);
let config_path = Path::new(&args.config);
context.load_global_properties(config_path)?;
}

// Optionally set output dir for reports
if !args.output_dir.is_empty() {
let output_dir = PathBuf::from(&args.output_dir);
let report_config = context.report_options();
report_config.directory(output_dir);
}

context.init_random(args.seed);

// Run the provided Fn
load(&mut context, args)?;

// Execute the context
context.execute();
Ok(())
}

#[cfg(test)]
mod tests {
use super::*;
use crate::define_global_property;
use serde::Deserialize;

#[test]
fn test_run_with_args_default() {
let result = run_with_args(|_, _| Ok(()));
assert!(result.is_ok());
}

#[test]
fn test_run_with_random_seed() {
let test_args = Args {
seed: 42,
config: String::new(),
output_dir: String::new(),
};
let result = run_with_args_internal(test_args, |ctx, _| {
assert_eq!(ctx.get_base_seed(), 42);
Ok(())
});
assert!(result.is_ok());
}

#[derive(Deserialize)]
pub struct RunnerPropertyType {
field_int: u32,
}
define_global_property!(RunnerProperty, RunnerPropertyType);

#[test]
fn test_run_with_config_path() {
let test_args = Args {
seed: 42,
config: "tests/data/global_properties_runner.json".to_string(),
output_dir: String::new(),
};
let result = run_with_args_internal(test_args, |ctx, _| {
let p3 = ctx.get_global_property_value(RunnerProperty).unwrap();
assert_eq!(p3.field_int, 0);
Ok(())
});
assert!(result.is_ok());
}

#[test]
fn test_run_with_output_dir() {
let test_args = Args {
seed: 42,
config: String::new(),
output_dir: "data".to_string(),
};
let result = run_with_args_internal(test_args, |ctx, _| {
let output_dir = &ctx.report_options().directory;
assert_eq!(output_dir, &PathBuf::from("data"));
Ok(())
});
assert!(result.is_ok());
}
}
5 changes: 5 additions & 0 deletions tests/data/global_properties_runner.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
{
"ixa.RunnerProperty": {
"field_int": 0
}
}

0 comments on commit e8059c7

Please sign in to comment.