Skip to content

Commit

Permalink
remove the get_random_seed helper and write a new test
Browse files Browse the repository at this point in the history
  • Loading branch information
k88hudson-cfa committed Dec 2, 2024
1 parent e03d876 commit 36315c5
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 41 deletions.
2 changes: 1 addition & 1 deletion examples/runner/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ fn main() {
context.add_plan(1.0, |_| {
println!("Hello, world!");
});
println!("{}", args.seed);
println!("{}", args.random_seed);
if let Some(extra) = extra {
println!("{}", extra.foo);
}
Expand Down
8 changes: 0 additions & 8 deletions src/random.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,6 @@ fn get_rng<R: RngId + 'static>(context: &Context) -> RefMut<R::RngType> {
pub trait ContextRandomExt {
fn init_random(&mut self, base_seed: u64);

fn get_base_seed(&self) -> u64;

/// Gets a random sample from the random number generator associated with the given
/// `RngId` by applying the specified sampler function. If the Rng has not been used
/// before, one will be created with the base seed you defined in `set_base_random_seed`.
Expand Down Expand Up @@ -156,12 +154,6 @@ impl ContextRandomExt for Context {
rng_map.clear();
}

fn get_base_seed(&self) -> u64 {
self.get_data_container(RngPlugin)
.expect("You must initialize the random number generator with a base seed")
.base_seed
}

fn sample<R: RngId + 'static, T>(
&self,
_rng_id: R,
Expand Down
65 changes: 33 additions & 32 deletions src/runner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use clap::{Args, Command, FromArgMatches as _};
pub struct BaseArgs {
/// Random seed
#[arg(short, long, default_value = "0")]
pub seed: u64,
pub random_seed: u64,

/// Optional path for a global properties config file
#[arg(short, long, default_value = "")]
Expand All @@ -26,12 +26,12 @@ pub struct BaseArgs {
pub struct PlaceholderCustom {}

fn create_ixa_cli() -> Command {
let cli = Command::new("Ixa");
let cli = Command::new("ixa");
BaseArgs::augment_args(cli)
}

#[allow(clippy::missing_errors_doc)]
pub fn run_with_custom_args<A, F>(load: F) -> Result<(), Box<dyn std::error::Error>>
pub fn run_with_custom_args<A, F>(setup_fn: F) -> Result<(), Box<dyn std::error::Error>>
where
A: Args,
F: Fn(&mut Context, BaseArgs, Option<A>) -> Result<(), IxaError>,
Expand All @@ -42,22 +42,29 @@ where

let base_args_matches = BaseArgs::from_arg_matches(&matches)?;
let custom_matches = A::from_arg_matches(&matches)?;
run_with_args_internal(base_args_matches, Some(custom_matches), load)
run_with_args_internal(base_args_matches, Some(custom_matches), setup_fn)
}

#[allow(clippy::missing_errors_doc)]
pub fn run_with_args<F>(load: F) -> Result<(), Box<dyn std::error::Error>>
pub fn run_with_args<F>(setup_fn: F) -> Result<(), Box<dyn std::error::Error>>
where
F: Fn(&mut Context, BaseArgs, Option<PlaceholderCustom>) -> Result<(), IxaError>,
{
let cli = create_ixa_cli();
let matches = cli.get_matches();

let base_args_matches = BaseArgs::from_arg_matches(&matches)?;
run_with_args_internal(base_args_matches, None, load)
run_with_args_internal(base_args_matches, None, setup_fn)
}

fn setup_context(args: &BaseArgs) -> Result<Context, IxaError> {
fn run_with_args_internal<A, F>(
args: BaseArgs,
custom_args: Option<A>,
setup_fn: F,
) -> Result<(), Box<dyn std::error::Error>>
where
F: Fn(&mut Context, BaseArgs, Option<A>) -> Result<(), IxaError>,
{
// Instantiate a context
let mut context = Context::new();

Expand All @@ -75,24 +82,10 @@ fn setup_context(args: &BaseArgs) -> Result<Context, IxaError> {
report_config.directory(output_dir);
}

context.init_random(args.seed);
Ok(context)
}

fn run_with_args_internal<A, F>(
args: BaseArgs,
custom_args: Option<A>,
load: F,
) -> Result<(), Box<dyn std::error::Error>>
where
A: Args,
F: Fn(&mut Context, BaseArgs, Option<A>) -> Result<(), IxaError>,
{
// Create a context
let mut context = setup_context(&args)?;
context.init_random(args.random_seed);

// Run the provided Fn
load(&mut context, args, custom_args)?;
setup_fn(&mut context, args, custom_args)?;

// Execute the context
context.execute();
Expand All @@ -102,7 +95,7 @@ where
#[cfg(test)]
mod tests {
use super::*;
use crate::define_global_property;
use crate::{define_global_property, define_rng};
use serde::Deserialize;

#[derive(Args, Debug)]
Expand All @@ -126,12 +119,20 @@ mod tests {
#[test]
fn test_run_with_random_seed() {
let test_args = BaseArgs {
seed: 42,
random_seed: 42,
config: String::new(),
output_dir: String::new(),
};
let result = run_with_args_internal(test_args, None, |ctx, _, _: Option<CustomArgs>| {
assert_eq!(ctx.get_base_seed(), 42);

// Use a comparison context to verify the random seed was set
let mut compare_ctx = Context::new();
compare_ctx.init_random(42);
define_rng!(TestRng);
let result = run_with_args_internal(test_args, None, |ctx, _, _: Option<()>| {
assert_eq!(
ctx.sample_range(TestRng, 0..100),
compare_ctx.sample_range(TestRng, 0..100)
);
Ok(())
});
assert!(result.is_ok());
Expand All @@ -146,11 +147,11 @@ mod tests {
#[test]
fn test_run_with_config_path() {
let test_args = BaseArgs {
seed: 42,
random_seed: 42,
config: "tests/data/global_properties_runner.json".to_string(),
output_dir: String::new(),
};
let result = run_with_args_internal(test_args, None, |ctx, _, _: Option<CustomArgs>| {
let result = run_with_args_internal(test_args, None, |ctx, _, _: Option<()>| {
let p3 = ctx.get_global_property_value(RunnerProperty).unwrap();
assert_eq!(p3.field_int, 0);
Ok(())
Expand All @@ -161,11 +162,11 @@ mod tests {
#[test]
fn test_run_with_output_dir() {
let test_args = BaseArgs {
seed: 42,
random_seed: 42,
config: String::new(),
output_dir: "data".to_string(),
};
let result = run_with_args_internal(test_args, None, |ctx, _, _: Option<CustomArgs>| {
let result = run_with_args_internal(test_args, None, |ctx, _, _: Option<()>| {
let output_dir = &ctx.report_options().directory;
assert_eq!(output_dir, &PathBuf::from("data"));
Ok(())
Expand All @@ -176,7 +177,7 @@ mod tests {
#[test]
fn test_run_with_custom() {
let test_args = BaseArgs {
seed: 42,
random_seed: 42,
config: String::new(),
output_dir: String::new(),
};
Expand Down

0 comments on commit 36315c5

Please sign in to comment.