diff --git a/Cargo.toml b/Cargo.toml index be2f210f..ca2369d4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,3 +8,5 @@ license = "Apache-2.0" homepage = "https://github.com/CDCgov/ixa" [dependencies] +fxhash = "0.2.1" +rand = "0.8.5" diff --git a/src/lib.rs b/src/lib.rs index dfe09e70..de495e01 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -28,3 +28,4 @@ //! person trying to infect susceptible people in the population. pub mod context; pub mod plan; +pub mod random; diff --git a/src/random.rs b/src/random.rs new file mode 100644 index 00000000..6bc95209 --- /dev/null +++ b/src/random.rs @@ -0,0 +1,162 @@ +use crate::context::Context; +use rand::SeedableRng; +use std::any::{Any, TypeId}; +use std::cell::{RefCell, RefMut}; +use std::collections::HashMap; + +#[macro_export] +macro_rules! define_rng { + ($random_id:ident) => { + struct $random_id {} + + impl $crate::random::RngId for $random_id { + // TODO: This is hardcoded to StdRng; we should replace this + type RngType = rand::rngs::StdRng; + + fn get_name() -> &'static str { + stringify!($random_id) + } + } + }; +} +pub use define_rng; + +pub trait RngId: Any { + type RngType: SeedableRng; + fn get_name() -> &'static str; +} + +struct RngHolder { + rng: Box, +} + +struct RngData { + base_seed: u64, + rng_holders: RefCell>, +} + +crate::context::define_data_plugin!( + RngPlugin, + RngData, + RngData { + base_seed: 0, + rng_holders: RefCell::new(HashMap::new()), + } +); + +#[allow(clippy::module_name_repetitions)] +pub trait RandomContext { + fn init_random(&mut self, base_seed: u64); + + fn get_rng(&self) -> RefMut<'_, R::RngType>; +} + +impl RandomContext for Context { + fn init_random(&mut self, base_seed: u64) { + let data_container = self.get_data_container_mut::(); + data_container.base_seed = base_seed; + + // Clear any existing Rngs to ensure they are reseeded + let mut rng_map = data_container.rng_holders.try_borrow_mut().unwrap(); + rng_map.clear(); + } + + fn get_rng(&self) -> RefMut<'_, R::RngType> { + let data_container = self + .get_data_container::() + .expect("You must initialize the random number generator with a base seed"); + let random_holders = data_container.rng_holders.try_borrow_mut().unwrap(); + + let random_holder = RefMut::map(random_holders, |random_holders| { + random_holders.entry(TypeId::of::()).or_insert_with(|| { + let base_seed = data_container.base_seed; + let seed_offset = fxhash::hash64(R::get_name()); + RngHolder { + rng: Box::new(R::RngType::seed_from_u64(base_seed + seed_offset)), + } + }) + }); + + RefMut::map(random_holder, |random_holder| { + random_holder.rng.downcast_mut::().unwrap() + }) + } +} + +#[cfg(test)] +mod test { + use crate::context::Context; + use crate::random::RandomContext; + use rand::RngCore; + + define_rng!(FooRng); + define_rng!(BarRng); + + #[test] + fn get_rng_basic() { + let mut context = Context::new(); + context.init_random(42); + + let mut foo_rng = context.get_rng::(); + assert_eq!(foo_rng.next_u64(), 5113542052170610017); + assert_eq!(foo_rng.next_u64(), 8640506012583485895); + assert_eq!(foo_rng.next_u64(), 16699691489468094833); + } + + #[test] + #[should_panic(expected = "You must initialize the random number generator with a base seed")] + fn panic_if_not_initialized() { + let context = Context::new(); + context.get_rng::(); + } + + #[test] + #[should_panic] + fn get_rng_one_ref_per_rng_id() { + let mut context = Context::new(); + context.init_random(42); + let mut foo_rng = context.get_rng::(); + + // This should panic because we already have a mutable reference to FooRng + let mut foo_rng_2 = context.get_rng::(); + foo_rng.next_u64(); + foo_rng_2.next_u64(); + } + + #[test] + fn get_rng_two_types() { + let mut context = Context::new(); + context.init_random(42); + + let mut foo_rng = context.get_rng::(); + foo_rng.next_u64(); + drop(foo_rng); + + let mut bar_rng = context.get_rng::(); + bar_rng.next_u64(); + } + + #[test] + fn reset_seed() { + let mut context = Context::new(); + context.init_random(42); + + let mut foo_rng = context.get_rng::(); + let run_0 = foo_rng.next_u64(); + let run_1 = foo_rng.next_u64(); + drop(foo_rng); + + // Reset with same seed, ensure we get the same values + context.init_random(42); + let mut foo_rng = context.get_rng::(); + assert_eq!(run_0, foo_rng.next_u64()); + assert_eq!(run_1, foo_rng.next_u64()); + drop(foo_rng); + + // Reset with different seed, ensure we get different values + context.init_random(88); + let mut foo_rng = context.get_rng::(); + assert_ne!(run_0, foo_rng.next_u64()); + assert_ne!(run_1, foo_rng.next_u64()); + } +}