diff --git a/examples/runner/main.rs b/examples/runner/main.rs index 106be5d..ce08e16 100644 --- a/examples/runner/main.rs +++ b/examples/runner/main.rs @@ -13,7 +13,7 @@ fn main() { context.add_plan(1.0, |_| { println!("Hello, world!"); }); - println!("{}", args.seed); + println!("{}", args.random_seed); if let Some(extra) = extra { println!("{}", extra.foo); } diff --git a/src/random.rs b/src/random.rs index 61e9e0e..0f60055 100644 --- a/src/random.rs +++ b/src/random.rs @@ -94,8 +94,6 @@ fn get_rng(context: &Context) -> RefMut { pub trait ContextRandomExt { fn init_random(&mut self, base_seed: u64); - fn get_base_seed(&self) -> u64; - /// Gets a random sample from the random number generator associated with the given /// `RngId` by applying the specified sampler function. If the Rng has not been used /// before, one will be created with the base seed you defined in `set_base_random_seed`. @@ -156,12 +154,6 @@ impl ContextRandomExt for Context { rng_map.clear(); } - fn get_base_seed(&self) -> u64 { - self.get_data_container(RngPlugin) - .expect("You must initialize the random number generator with a base seed") - .base_seed - } - fn sample( &self, _rng_id: R, diff --git a/src/runner.rs b/src/runner.rs index 6691b36..66d6ef9 100644 --- a/src/runner.rs +++ b/src/runner.rs @@ -11,7 +11,7 @@ use clap::{Args, Command, FromArgMatches as _}; pub struct BaseArgs { /// Random seed #[arg(short, long, default_value = "0")] - pub seed: u64, + pub random_seed: u64, /// Optional path for a global properties config file #[arg(short, long, default_value = "")] @@ -26,12 +26,12 @@ pub struct BaseArgs { pub struct PlaceholderCustom {} fn create_ixa_cli() -> Command { - let cli = Command::new("Ixa"); + let cli = Command::new("ixa"); BaseArgs::augment_args(cli) } #[allow(clippy::missing_errors_doc)] -pub fn run_with_custom_args(load: F) -> Result<(), Box> +pub fn run_with_custom_args(setup_fn: F) -> Result<(), Box> where A: Args, F: Fn(&mut Context, BaseArgs, Option) -> Result<(), IxaError>, @@ -42,11 +42,11 @@ where let base_args_matches = BaseArgs::from_arg_matches(&matches)?; let custom_matches = A::from_arg_matches(&matches)?; - run_with_args_internal(base_args_matches, Some(custom_matches), load) + run_with_args_internal(base_args_matches, Some(custom_matches), setup_fn) } #[allow(clippy::missing_errors_doc)] -pub fn run_with_args(load: F) -> Result<(), Box> +pub fn run_with_args(setup_fn: F) -> Result<(), Box> where F: Fn(&mut Context, BaseArgs, Option) -> Result<(), IxaError>, { @@ -54,10 +54,17 @@ where let matches = cli.get_matches(); let base_args_matches = BaseArgs::from_arg_matches(&matches)?; - run_with_args_internal(base_args_matches, None, load) + run_with_args_internal(base_args_matches, None, setup_fn) } -fn setup_context(args: &BaseArgs) -> Result { +fn run_with_args_internal( + args: BaseArgs, + custom_args: Option, + setup_fn: F, +) -> Result<(), Box> +where + F: Fn(&mut Context, BaseArgs, Option) -> Result<(), IxaError>, +{ // Instantiate a context let mut context = Context::new(); @@ -75,24 +82,10 @@ fn setup_context(args: &BaseArgs) -> Result { report_config.directory(output_dir); } - context.init_random(args.seed); - Ok(context) -} - -fn run_with_args_internal( - args: BaseArgs, - custom_args: Option, - load: F, -) -> Result<(), Box> -where - A: Args, - F: Fn(&mut Context, BaseArgs, Option) -> Result<(), IxaError>, -{ - // Create a context - let mut context = setup_context(&args)?; + context.init_random(args.random_seed); // Run the provided Fn - load(&mut context, args, custom_args)?; + setup_fn(&mut context, args, custom_args)?; // Execute the context context.execute(); @@ -102,7 +95,7 @@ where #[cfg(test)] mod tests { use super::*; - use crate::define_global_property; + use crate::{define_global_property, define_rng}; use serde::Deserialize; #[derive(Args, Debug)] @@ -126,12 +119,20 @@ mod tests { #[test] fn test_run_with_random_seed() { let test_args = BaseArgs { - seed: 42, + random_seed: 42, config: String::new(), output_dir: String::new(), }; - let result = run_with_args_internal(test_args, None, |ctx, _, _: Option| { - assert_eq!(ctx.get_base_seed(), 42); + + // Use a comparison context to verify the random seed was set + let mut compare_ctx = Context::new(); + compare_ctx.init_random(42); + define_rng!(TestRng); + let result = run_with_args_internal(test_args, None, |ctx, _, _: Option<()>| { + assert_eq!( + ctx.sample_range(TestRng, 0..100), + compare_ctx.sample_range(TestRng, 0..100) + ); Ok(()) }); assert!(result.is_ok()); @@ -146,11 +147,11 @@ mod tests { #[test] fn test_run_with_config_path() { let test_args = BaseArgs { - seed: 42, + random_seed: 42, config: "tests/data/global_properties_runner.json".to_string(), output_dir: String::new(), }; - let result = run_with_args_internal(test_args, None, |ctx, _, _: Option| { + let result = run_with_args_internal(test_args, None, |ctx, _, _: Option<()>| { let p3 = ctx.get_global_property_value(RunnerProperty).unwrap(); assert_eq!(p3.field_int, 0); Ok(()) @@ -161,11 +162,11 @@ mod tests { #[test] fn test_run_with_output_dir() { let test_args = BaseArgs { - seed: 42, + random_seed: 42, config: String::new(), output_dir: "data".to_string(), }; - let result = run_with_args_internal(test_args, None, |ctx, _, _: Option| { + let result = run_with_args_internal(test_args, None, |ctx, _, _: Option<()>| { let output_dir = &ctx.report_options().directory; assert_eq!(output_dir, &PathBuf::from("data")); Ok(()) @@ -176,7 +177,7 @@ mod tests { #[test] fn test_run_with_custom() { let test_args = BaseArgs { - seed: 42, + random_seed: 42, config: String::new(), output_dir: String::new(), };