From 6cdd875ba565d349d56c04ae25cd89ad33959690 Mon Sep 17 00:00:00 2001 From: k88hudson-cfa Date: Mon, 12 Aug 2024 23:40:12 -0400 Subject: [PATCH] explicit initialization --- src/random.rs | 49 ++++++++++++++++++++++++++++++++++--------------- 1 file changed, 34 insertions(+), 15 deletions(-) diff --git a/src/random.rs b/src/random.rs index 42b36324..013c5035 100644 --- a/src/random.rs +++ b/src/random.rs @@ -51,7 +51,7 @@ crate::context::define_data_plugin!( #[allow(clippy::module_name_repetitions)] pub trait RandomContext { fn set_base_random_seed(&mut self, base_seed: u64); - + fn create_rng(&self); fn get_rng(&self) -> RefMut<'_, R::RngType>; } @@ -67,6 +67,19 @@ impl RandomContext for Context { rng_map.clear(); } + fn create_rng(&self) { + let data_container = self + .get_data_container::() + .expect("You must initialize the random number generator with a base seed"); + let mut random_holders = data_container.rng_holders.try_borrow_mut().unwrap(); + let base_seed = data_container.base_seed; + let seed_offset = fxhash::hash64(R::get_name()); + let holder = RngHolder { + rng: Box::new(R::RngType::seed_from_u64(base_seed + seed_offset)), + }; + random_holders.insert(TypeId::of::(), holder); + } + /// Gets a mutable reference to the random number generator associated with the given /// `RngId`. If the Rng has not been used before, one will be created with the base seed /// you defined in `set_base_random_seed`. Note that this will panic if `set_base_random_seed` was not called yet. @@ -74,20 +87,16 @@ impl RandomContext for Context { let data_container = self .get_data_container::() .expect("You must initialize the random number generator with a base seed"); - let random_holders = data_container.rng_holders.try_borrow_mut().unwrap(); - - let random_holder = RefMut::map(random_holders, |random_holders| { - random_holders.entry(TypeId::of::()).or_insert_with(|| { - let base_seed = data_container.base_seed; - let seed_offset = fxhash::hash64(R::get_name()); - RngHolder { - rng: Box::new(R::RngType::seed_from_u64(base_seed + seed_offset)), - } - }) - }); - - RefMut::map(random_holder, |random_holder| { - random_holder.rng.downcast_mut::().unwrap() + + let rng_holders = data_container.rng_holders.try_borrow_mut().unwrap(); + + RefMut::map(rng_holders, |holders| { + holders + .get_mut(&TypeId::of::()) + .expect("You must call initialize with create_rng") + .rng + .downcast_mut::() + .unwrap() }) } } @@ -105,6 +114,7 @@ mod test { fn get_rng_basic() { let mut context = Context::new(); context.set_base_random_seed(42); + context.create_rng::(); let mut foo_rng = context.get_rng::(); assert_eq!(foo_rng.next_u64(), 5113542052170610017); @@ -116,6 +126,7 @@ mod test { #[should_panic(expected = "You must initialize the random number generator with a base seed")] fn panic_if_not_initialized() { let context = Context::new(); + context.create_rng::(); context.get_rng::(); } @@ -124,6 +135,8 @@ mod test { fn get_rng_one_ref_per_rng_id() { let mut context = Context::new(); context.set_base_random_seed(42); + + context.create_rng::(); let mut foo_rng = context.get_rng::(); // This should panic because we already have a mutable reference to FooRng @@ -137,10 +150,13 @@ mod test { let mut context = Context::new(); context.set_base_random_seed(42); + context.create_rng::(); let mut foo_rng = context.get_rng::(); + foo_rng.next_u64(); drop(foo_rng); + context.create_rng::(); let mut bar_rng = context.get_rng::(); bar_rng.next_u64(); } @@ -150,6 +166,7 @@ mod test { let mut context = Context::new(); context.set_base_random_seed(42); + context.create_rng::(); let mut foo_rng = context.get_rng::(); let run_0 = foo_rng.next_u64(); let run_1 = foo_rng.next_u64(); @@ -157,6 +174,7 @@ mod test { // Reset with same seed, ensure we get the same values context.set_base_random_seed(42); + context.create_rng::(); let mut foo_rng = context.get_rng::(); assert_eq!(run_0, foo_rng.next_u64()); assert_eq!(run_1, foo_rng.next_u64()); @@ -164,6 +182,7 @@ mod test { // Reset with different seed, ensure we get different values context.set_base_random_seed(88); + context.create_rng::(); let mut foo_rng = context.get_rng::(); assert_ne!(run_0, foo_rng.next_u64()); assert_ne!(run_1, foo_rng.next_u64());