From 65058313fd50c38190a3e0a623686ed166403ea3 Mon Sep 17 00:00:00 2001 From: k88hudson-cfa Date: Thu, 8 Aug 2024 11:45:46 -0400 Subject: [PATCH 01/15] Added random module --- Cargo.toml | 2 + src/lib.rs | 1 + src/random.rs | 171 ++++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 174 insertions(+) create mode 100644 src/random.rs diff --git a/Cargo.toml b/Cargo.toml index be2f210f..ca2369d4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,3 +8,5 @@ license = "Apache-2.0" homepage = "https://github.com/CDCgov/ixa" [dependencies] +fxhash = "0.2.1" +rand = "0.8.5" diff --git a/src/lib.rs b/src/lib.rs index dfe09e70..de495e01 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -28,3 +28,4 @@ //! person trying to infect susceptible people in the population. pub mod context; pub mod plan; +pub mod random; diff --git a/src/random.rs b/src/random.rs new file mode 100644 index 00000000..42b36324 --- /dev/null +++ b/src/random.rs @@ -0,0 +1,171 @@ +use crate::context::Context; +use rand::SeedableRng; +use std::any::{Any, TypeId}; +use std::cell::{RefCell, RefMut}; +use std::collections::HashMap; + +/// Use this to define a unique type which will be used as a key to retrieve +/// an independent rng instance when calling `.get_rng`. +#[macro_export] +macro_rules! define_rng { + ($random_id:ident) => { + struct $random_id {} + + impl $crate::random::RngId for $random_id { + // TODO: This is hardcoded to StdRng; we should replace this + type RngType = rand::rngs::StdRng; + + fn get_name() -> &'static str { + stringify!($random_id) + } + } + }; +} +pub use define_rng; + +pub trait RngId: Any { + type RngType: SeedableRng; + fn get_name() -> &'static str; +} + +// This is a wrapper which allows for future support for different types of +// random number generators (anything that implements SeedableRng is valid). +struct RngHolder { + rng: Box, +} + +struct RngData { + base_seed: u64, + rng_holders: RefCell>, +} + +crate::context::define_data_plugin!( + RngPlugin, + RngData, + RngData { + base_seed: 0, + rng_holders: RefCell::new(HashMap::new()), + } +); + +#[allow(clippy::module_name_repetitions)] +pub trait RandomContext { + fn set_base_random_seed(&mut self, base_seed: u64); + + fn get_rng(&self) -> RefMut<'_, R::RngType>; +} + +impl RandomContext for Context { + /// Initializes the `RngPlugin` data container to store rngs as well as a base + /// seed. Note that rngs are created lazily when `get_rng` is called. + fn set_base_random_seed(&mut self, base_seed: u64) { + let data_container = self.get_data_container_mut::(); + data_container.base_seed = base_seed; + + // Clear any existing Rngs to ensure they get re-seeded when `get_rng` is called + let mut rng_map = data_container.rng_holders.try_borrow_mut().unwrap(); + rng_map.clear(); + } + + /// Gets a mutable reference to the random number generator associated with the given + /// `RngId`. If the Rng has not been used before, one will be created with the base seed + /// you defined in `set_base_random_seed`. Note that this will panic if `set_base_random_seed` was not called yet. + fn get_rng(&self) -> RefMut<'_, R::RngType> { + let data_container = self + .get_data_container::() + .expect("You must initialize the random number generator with a base seed"); + let random_holders = data_container.rng_holders.try_borrow_mut().unwrap(); + + let random_holder = RefMut::map(random_holders, |random_holders| { + random_holders.entry(TypeId::of::()).or_insert_with(|| { + let base_seed = data_container.base_seed; + let seed_offset = fxhash::hash64(R::get_name()); + RngHolder { + rng: Box::new(R::RngType::seed_from_u64(base_seed + seed_offset)), + } + }) + }); + + RefMut::map(random_holder, |random_holder| { + random_holder.rng.downcast_mut::().unwrap() + }) + } +} + +#[cfg(test)] +mod test { + use crate::context::Context; + use crate::random::RandomContext; + use rand::RngCore; + + define_rng!(FooRng); + define_rng!(BarRng); + + #[test] + fn get_rng_basic() { + let mut context = Context::new(); + context.set_base_random_seed(42); + + let mut foo_rng = context.get_rng::(); + assert_eq!(foo_rng.next_u64(), 5113542052170610017); + assert_eq!(foo_rng.next_u64(), 8640506012583485895); + assert_eq!(foo_rng.next_u64(), 16699691489468094833); + } + + #[test] + #[should_panic(expected = "You must initialize the random number generator with a base seed")] + fn panic_if_not_initialized() { + let context = Context::new(); + context.get_rng::(); + } + + #[test] + #[should_panic] + fn get_rng_one_ref_per_rng_id() { + let mut context = Context::new(); + context.set_base_random_seed(42); + let mut foo_rng = context.get_rng::(); + + // This should panic because we already have a mutable reference to FooRng + let mut foo_rng_2 = context.get_rng::(); + foo_rng.next_u64(); + foo_rng_2.next_u64(); + } + + #[test] + fn get_rng_two_types() { + let mut context = Context::new(); + context.set_base_random_seed(42); + + let mut foo_rng = context.get_rng::(); + foo_rng.next_u64(); + drop(foo_rng); + + let mut bar_rng = context.get_rng::(); + bar_rng.next_u64(); + } + + #[test] + fn reset_seed() { + let mut context = Context::new(); + context.set_base_random_seed(42); + + let mut foo_rng = context.get_rng::(); + let run_0 = foo_rng.next_u64(); + let run_1 = foo_rng.next_u64(); + drop(foo_rng); + + // Reset with same seed, ensure we get the same values + context.set_base_random_seed(42); + let mut foo_rng = context.get_rng::(); + assert_eq!(run_0, foo_rng.next_u64()); + assert_eq!(run_1, foo_rng.next_u64()); + drop(foo_rng); + + // Reset with different seed, ensure we get different values + context.set_base_random_seed(88); + let mut foo_rng = context.get_rng::(); + assert_ne!(run_0, foo_rng.next_u64()); + assert_ne!(run_1, foo_rng.next_u64()); + } +} From def51f29700732316bd4082c2b0a263b4810d1a1 Mon Sep 17 00:00:00 2001 From: k88hudson-cfa Date: Mon, 12 Aug 2024 23:46:22 -0400 Subject: [PATCH 02/15] Refactor get_rng a bit --- src/random.rs | 31 +++++++++++++++++-------------- 1 file changed, 17 insertions(+), 14 deletions(-) diff --git a/src/random.rs b/src/random.rs index 42b36324..d274c5e3 100644 --- a/src/random.rs +++ b/src/random.rs @@ -74,20 +74,23 @@ impl RandomContext for Context { let data_container = self .get_data_container::() .expect("You must initialize the random number generator with a base seed"); - let random_holders = data_container.rng_holders.try_borrow_mut().unwrap(); - - let random_holder = RefMut::map(random_holders, |random_holders| { - random_holders.entry(TypeId::of::()).or_insert_with(|| { - let base_seed = data_container.base_seed; - let seed_offset = fxhash::hash64(R::get_name()); - RngHolder { - rng: Box::new(R::RngType::seed_from_u64(base_seed + seed_offset)), - } - }) - }); - - RefMut::map(random_holder, |random_holder| { - random_holder.rng.downcast_mut::().unwrap() + + let rng_holders = data_container.rng_holders.try_borrow_mut().unwrap(); + + RefMut::map(rng_holders, |holders| { + holders + .entry(TypeId::of::()) + // Create a new rng holder if it doesn't exist yet + .or_insert_with(|| { + let base_seed = data_container.base_seed; + let seed_offset = fxhash::hash64(R::get_name()); + RngHolder { + rng: Box::new(R::RngType::seed_from_u64(base_seed + seed_offset)), + } + }) + .rng + .downcast_mut::() + .unwrap() }) } } From a99cedcc22efb1adee6c8621f2fe47a927e6ed92 Mon Sep 17 00:00:00 2001 From: k88hudson-cfa Date: Sat, 17 Aug 2024 13:10:54 -0400 Subject: [PATCH 03/15] Review fixes --- src/random.rs | 48 ++++++++++++++++++++++++++---------------------- 1 file changed, 26 insertions(+), 22 deletions(-) diff --git a/src/random.rs b/src/random.rs index d274c5e3..f4b998d4 100644 --- a/src/random.rs +++ b/src/random.rs @@ -12,7 +12,7 @@ macro_rules! define_rng { struct $random_id {} impl $crate::random::RngId for $random_id { - // TODO: This is hardcoded to StdRng; we should replace this + // TODO(ryl8@cdc.gov): This is hardcoded to StdRng; we should replace this type RngType = rand::rngs::StdRng; fn get_name() -> &'static str { @@ -39,6 +39,11 @@ struct RngData { rng_holders: RefCell>, } +// Registers a data container which stores: +// * base_seed: A base seed for all rngs +// * rng_holders: A map of rngs, keyed by their RngId. Note that this is +// stored in a RefCell to allow for mutable borrow without requiring a +// mutable borrow of the Context itself. crate::context::define_data_plugin!( RngPlugin, RngData, @@ -48,17 +53,17 @@ crate::context::define_data_plugin!( } ); -#[allow(clippy::module_name_repetitions)] -pub trait RandomContext { - fn set_base_random_seed(&mut self, base_seed: u64); +// This is a trait exension on Context +pub trait ContextRandomExt { + fn init_random(&mut self, base_seed: u64); - fn get_rng(&self) -> RefMut<'_, R::RngType>; + fn get_rng(&self) -> RefMut; } -impl RandomContext for Context { +impl ContextRandomExt for Context { /// Initializes the `RngPlugin` data container to store rngs as well as a base /// seed. Note that rngs are created lazily when `get_rng` is called. - fn set_base_random_seed(&mut self, base_seed: u64) { + fn init_random(&mut self, base_seed: u64) { let data_container = self.get_data_container_mut::(); data_container.base_seed = base_seed; @@ -69,14 +74,13 @@ impl RandomContext for Context { /// Gets a mutable reference to the random number generator associated with the given /// `RngId`. If the Rng has not been used before, one will be created with the base seed - /// you defined in `set_base_random_seed`. Note that this will panic if `set_base_random_seed` was not called yet. - fn get_rng(&self) -> RefMut<'_, R::RngType> { + /// you defined in `init`. Note that this will panic if `init` was not called yet. + fn get_rng(&self) -> RefMut { let data_container = self .get_data_container::() .expect("You must initialize the random number generator with a base seed"); let rng_holders = data_container.rng_holders.try_borrow_mut().unwrap(); - RefMut::map(rng_holders, |holders| { holders .entry(TypeId::of::()) @@ -98,7 +102,7 @@ impl RandomContext for Context { #[cfg(test)] mod test { use crate::context::Context; - use crate::random::RandomContext; + use crate::random::ContextRandomExt; use rand::RngCore; define_rng!(FooRng); @@ -107,12 +111,11 @@ mod test { #[test] fn get_rng_basic() { let mut context = Context::new(); - context.set_base_random_seed(42); + context.init_random(42); let mut foo_rng = context.get_rng::(); - assert_eq!(foo_rng.next_u64(), 5113542052170610017); - assert_eq!(foo_rng.next_u64(), 8640506012583485895); - assert_eq!(foo_rng.next_u64(), 16699691489468094833); + + assert_ne!(foo_rng.next_u64(), foo_rng.next_u64()); } #[test] @@ -124,9 +127,9 @@ mod test { #[test] #[should_panic] - fn get_rng_one_ref_per_rng_id() { + fn no_multiple_references_to_rngs() { let mut context = Context::new(); - context.set_base_random_seed(42); + context.init_random(42); let mut foo_rng = context.get_rng::(); // This should panic because we already have a mutable reference to FooRng @@ -136,12 +139,13 @@ mod test { } #[test] - fn get_rng_two_types() { + fn multiple_references_with_drop() { let mut context = Context::new(); - context.set_base_random_seed(42); + context.init_random(42); let mut foo_rng = context.get_rng::(); foo_rng.next_u64(); + // If you drop the first reference, you should be able to get a reference to a different rng drop(foo_rng); let mut bar_rng = context.get_rng::(); @@ -151,7 +155,7 @@ mod test { #[test] fn reset_seed() { let mut context = Context::new(); - context.set_base_random_seed(42); + context.init_random(42); let mut foo_rng = context.get_rng::(); let run_0 = foo_rng.next_u64(); @@ -159,14 +163,14 @@ mod test { drop(foo_rng); // Reset with same seed, ensure we get the same values - context.set_base_random_seed(42); + context.init_random(42); let mut foo_rng = context.get_rng::(); assert_eq!(run_0, foo_rng.next_u64()); assert_eq!(run_1, foo_rng.next_u64()); drop(foo_rng); // Reset with different seed, ensure we get different values - context.set_base_random_seed(88); + context.init_random(88); let mut foo_rng = context.get_rng::(); assert_ne!(run_0, foo_rng.next_u64()); assert_ne!(run_1, foo_rng.next_u64()); From 4870f23e16bd4f6be7eacfa8a324e2fb82c6ad18 Mon Sep 17 00:00:00 2001 From: k88hudson-cfa Date: Sat, 17 Aug 2024 17:37:29 -0400 Subject: [PATCH 04/15] Added a test as an example of how to use with a distribution --- Cargo.toml | 3 +++ src/random.rs | 10 ++++++++++ 2 files changed, 13 insertions(+) diff --git a/Cargo.toml b/Cargo.toml index ca2369d4..defc117b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,3 +10,6 @@ homepage = "https://github.com/CDCgov/ixa" [dependencies] fxhash = "0.2.1" rand = "0.8.5" + +[dev-dependencies] +rand_distr = "0.4.3" diff --git a/src/random.rs b/src/random.rs index f4b998d4..a584476a 100644 --- a/src/random.rs +++ b/src/random.rs @@ -104,6 +104,7 @@ mod test { use crate::context::Context; use crate::random::ContextRandomExt; use rand::RngCore; + use rand_distr::{Distribution, Exp}; define_rng!(FooRng); define_rng!(BarRng); @@ -152,6 +153,15 @@ mod test { bar_rng.next_u64(); } + #[test] + fn usage_with_distribution() { + let mut context = Context::new(); + context.init_random(42); + let mut rng = context.get_rng::(); + let dist = Exp::new(1.0).unwrap(); + assert_ne!(dist.sample(&mut *rng), dist.sample(&mut *rng)); + } + #[test] fn reset_seed() { let mut context = Context::new(); From 1d01a693f94a178edafc99a0708f471f3a749e83 Mon Sep 17 00:00:00 2001 From: k88hudson-cfa Date: Mon, 19 Aug 2024 19:05:35 -0400 Subject: [PATCH 05/15] Remove Any --- src/random.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/random.rs b/src/random.rs index a584476a..9e8a3c20 100644 --- a/src/random.rs +++ b/src/random.rs @@ -23,7 +23,7 @@ macro_rules! define_rng { } pub use define_rng; -pub trait RngId: Any { +pub trait RngId { type RngType: SeedableRng; fn get_name() -> &'static str; } @@ -57,7 +57,7 @@ crate::context::define_data_plugin!( pub trait ContextRandomExt { fn init_random(&mut self, base_seed: u64); - fn get_rng(&self) -> RefMut; + fn get_rng(&self) -> RefMut; } impl ContextRandomExt for Context { From 377efb24d0ece813d7705a3d6f5c392d57ff2bc8 Mon Sep 17 00:00:00 2001 From: k88hudson-cfa Date: Mon, 19 Aug 2024 20:06:13 -0400 Subject: [PATCH 06/15] Ensure uniqueness of --- src/random.rs | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/random.rs b/src/random.rs index 9e8a3c20..854de548 100644 --- a/src/random.rs +++ b/src/random.rs @@ -19,6 +19,11 @@ macro_rules! define_rng { stringify!($random_id) } } + + // This ensures that you can't define two RngIds with the same name + #[doc(hidden)] + #[no_mangle] + pub static $random_id: () = (); }; } pub use define_rng; From 23ef2206b4a6225cd4efcbc80fc9868104466ceb Mon Sep 17 00:00:00 2001 From: k88hudson-cfa Date: Mon, 19 Aug 2024 20:39:26 -0400 Subject: [PATCH 07/15] Switch api to sample / sample_distr --- src/random.rs | 179 +++++++++++++++++++++++++++++++------------------- 1 file changed, 112 insertions(+), 67 deletions(-) diff --git a/src/random.rs b/src/random.rs index 854de548..01939b97 100644 --- a/src/random.rs +++ b/src/random.rs @@ -1,5 +1,6 @@ use crate::context::Context; -use rand::SeedableRng; +use rand::prelude::Distribution; +use rand::{Rng, SeedableRng}; use std::any::{Any, TypeId}; use std::cell::{RefCell, RefMut}; use std::collections::HashMap; @@ -58,11 +59,41 @@ crate::context::define_data_plugin!( } ); +/// Gets a mutable reference to the random number generator associated with the given +/// `RngId`. If the Rng has not been used before, one will be created with the base seed +/// you defined in `init`. Note that this will panic if `init` was not called yet. +fn get_rng(context: &Context) -> RefMut { + let data_container = context + .get_data_container::() + .expect("You must initialize the random number generator with a base seed"); + + let rng_holders = data_container.rng_holders.try_borrow_mut().unwrap(); + RefMut::map(rng_holders, |holders| { + holders + .entry(TypeId::of::()) + // Create a new rng holder if it doesn't exist yet + .or_insert_with(|| { + let base_seed = data_container.base_seed; + let seed_offset = fxhash::hash64(R::get_name()); + RngHolder { + rng: Box::new(R::RngType::seed_from_u64(base_seed + seed_offset)), + } + }) + .rng + .downcast_mut::() + .unwrap() + }) +} + // This is a trait exension on Context pub trait ContextRandomExt { fn init_random(&mut self, base_seed: u64); - fn get_rng(&self) -> RefMut; + fn sample(&self, sampler: impl FnOnce(&mut R::RngType) -> T) -> T; + + fn sample_distr(&self, distribution: impl Distribution) -> T + where + R::RngType: Rng; } impl ContextRandomExt for Context { @@ -77,39 +108,35 @@ impl ContextRandomExt for Context { rng_map.clear(); } - /// Gets a mutable reference to the random number generator associated with the given - /// `RngId`. If the Rng has not been used before, one will be created with the base seed - /// you defined in `init`. Note that this will panic if `init` was not called yet. - fn get_rng(&self) -> RefMut { - let data_container = self - .get_data_container::() - .expect("You must initialize the random number generator with a base seed"); - - let rng_holders = data_container.rng_holders.try_borrow_mut().unwrap(); - RefMut::map(rng_holders, |holders| { - holders - .entry(TypeId::of::()) - // Create a new rng holder if it doesn't exist yet - .or_insert_with(|| { - let base_seed = data_container.base_seed; - let seed_offset = fxhash::hash64(R::get_name()); - RngHolder { - rng: Box::new(R::RngType::seed_from_u64(base_seed + seed_offset)), - } - }) - .rng - .downcast_mut::() - .unwrap() - }) + /// Gets a random sample from the random number generator associated with the given + /// `RngId` by applying the specified sampler function. If the Rng has not been used + /// before, one will be created with the base seed you defined in `set_base_random_seed`. + /// Note that this will panic if `set_base_random_seed` was not called yet. + fn sample(&self, sampler: impl FnOnce(&mut R::RngType) -> T) -> T { + let mut rng = get_rng::(self); + sampler(&mut rng) + } + + /// Gets a random sample from the specified distribution using a random number generator + /// associated with the given `RngId`. If the Rng has not been used before, one will be + /// created with the base seed you defined in `set_base_random_seed`. + /// Note that this will panic if `set_base_random_seed` was not called yet. + fn sample_distr(&self, distribution: impl Distribution) -> T + where + R::RngType: Rng, + { + let mut rng = get_rng::(self); + distribution.sample::(&mut rng) } } #[cfg(test)] mod test { use crate::context::Context; + use crate::define_data_plugin; use crate::random::ContextRandomExt; use rand::RngCore; - use rand_distr::{Distribution, Exp}; + use rand::{distributions::WeightedIndex, prelude::Distribution}; define_rng!(FooRng); define_rng!(BarRng); @@ -119,75 +146,93 @@ mod test { let mut context = Context::new(); context.init_random(42); - let mut foo_rng = context.get_rng::(); - - assert_ne!(foo_rng.next_u64(), foo_rng.next_u64()); + assert_ne!( + context.sample::(|rng| rng.next_u64()), + context.sample::(|rng| rng.next_u64()) + ); } #[test] #[should_panic(expected = "You must initialize the random number generator with a base seed")] fn panic_if_not_initialized() { let context = Context::new(); - context.get_rng::(); + context.sample::(|rng| rng.next_u64()); } #[test] - #[should_panic] - fn no_multiple_references_to_rngs() { + fn multiple_references_with_drop() { let mut context = Context::new(); context.init_random(42); - let mut foo_rng = context.get_rng::(); - // This should panic because we already have a mutable reference to FooRng - let mut foo_rng_2 = context.get_rng::(); - foo_rng.next_u64(); - foo_rng_2.next_u64(); + assert_ne!( + context.sample::(|rng| rng.next_u64()), + context.sample::(|rng| rng.next_u64()) + ); } #[test] - fn multiple_references_with_drop() { + fn reset_seed() { let mut context = Context::new(); context.init_random(42); - let mut foo_rng = context.get_rng::(); - foo_rng.next_u64(); - // If you drop the first reference, you should be able to get a reference to a different rng - drop(foo_rng); + let run_0 = context.sample::(|rng| rng.next_u64()); + let run_1 = context.sample::(|rng| rng.next_u64()); + + // Reset with same seed, ensure we get the same values + context.init_random(42); + assert_eq!(run_0, context.sample::(|rng| rng.next_u64())); + assert_eq!(run_1, context.sample::(|rng| rng.next_u64())); - let mut bar_rng = context.get_rng::(); - bar_rng.next_u64(); + // Reset with different seed, ensure we get different values + context.init_random(88); + assert_ne!(run_0, context.sample::(|rng| rng.next_u64())); + assert_ne!(run_1, context.sample::(|rng| rng.next_u64())); } + define_data_plugin!( + SamplerData, + WeightedIndex, + WeightedIndex::new(vec![1.0]).unwrap() + ); + #[test] - fn usage_with_distribution() { + fn sampler_function_closure_capture() { let mut context = Context::new(); context.init_random(42); - let mut rng = context.get_rng::(); - let dist = Exp::new(1.0).unwrap(); - assert_ne!(dist.sample(&mut *rng), dist.sample(&mut *rng)); + // Initialize weighted sampler + *context.get_data_container_mut::() = + WeightedIndex::new(vec![1.0, 2.0]).unwrap(); + + let parameters = context.get_data_container::().unwrap(); + let n_samples = 3000; + let mut zero_counter = 0; + for _ in 0..n_samples { + let sample = context.sample::(|rng| parameters.sample(rng)); + if sample == 0 { + zero_counter += 1; + } + } + assert!((zero_counter - 1000 as i32).abs() < 30); } #[test] - fn reset_seed() { + fn sample_distribution() { let mut context = Context::new(); context.init_random(42); - let mut foo_rng = context.get_rng::(); - let run_0 = foo_rng.next_u64(); - let run_1 = foo_rng.next_u64(); - drop(foo_rng); - - // Reset with same seed, ensure we get the same values - context.init_random(42); - let mut foo_rng = context.get_rng::(); - assert_eq!(run_0, foo_rng.next_u64()); - assert_eq!(run_1, foo_rng.next_u64()); - drop(foo_rng); - - // Reset with different seed, ensure we get different values - context.init_random(88); - let mut foo_rng = context.get_rng::(); - assert_ne!(run_0, foo_rng.next_u64()); - assert_ne!(run_1, foo_rng.next_u64()); + // Initialize weighted sampler + *context.get_data_container_mut::() = + WeightedIndex::new(vec![1.0, 2.0]).unwrap(); + + let parameters = context.get_data_container::().unwrap(); + let n_samples = 3000; + let mut zero_counter = 0; + for _ in 0..n_samples { + let sample = context.sample_distr::(parameters); + if sample == 0 { + zero_counter += 1; + } + } + assert!((zero_counter - 1000 as i32).abs() < 30); } } From f6a9f3b173d58d51d0c20b28c1617db7cda704ca Mon Sep 17 00:00:00 2001 From: k88hudson-cfa Date: Tue, 20 Aug 2024 17:04:58 -0400 Subject: [PATCH 08/15] namespace type collision guard --- Cargo.toml | 1 + src/random.rs | 8 +++++--- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index defc117b..a704bfc3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,6 +9,7 @@ homepage = "https://github.com/CDCgov/ixa" [dependencies] fxhash = "0.2.1" +paste = "1.0.15" rand = "0.8.5" [dev-dependencies] diff --git a/src/random.rs b/src/random.rs index 01939b97..341fd994 100644 --- a/src/random.rs +++ b/src/random.rs @@ -22,9 +22,11 @@ macro_rules! define_rng { } // This ensures that you can't define two RngIds with the same name - #[doc(hidden)] - #[no_mangle] - pub static $random_id: () = (); + paste::paste! { + #[doc(hidden)] + #[no_mangle] + pub static []: () = (); + } }; } pub use define_rng; From 7404c87cb7653e4980de389f8790d6dd5cdaf33e Mon Sep 17 00:00:00 2001 From: k88hudson-cfa Date: Tue, 27 Aug 2024 14:48:06 -0400 Subject: [PATCH 09/15] add sample_range --- src/random.rs | 41 +++++++++++++++++++++++++++++++++-------- 1 file changed, 33 insertions(+), 8 deletions(-) diff --git a/src/random.rs b/src/random.rs index 341fd994..ca3dc98d 100644 --- a/src/random.rs +++ b/src/random.rs @@ -1,4 +1,6 @@ use crate::context::Context; +use rand::distributions::uniform::SampleRange; +use rand::distributions::uniform::SampleUniform; use rand::prelude::Distribution; use rand::{Rng, SeedableRng}; use std::any::{Any, TypeId}; @@ -91,11 +93,25 @@ fn get_rng(context: &Context) -> RefMut { pub trait ContextRandomExt { fn init_random(&mut self, base_seed: u64); + /// Gets a random sample from the random number generator associated with the given + /// `RngId` by applying the specified sampler function. If the Rng has not been used + /// before, one will be created with the base seed you defined in `set_base_random_seed`. + /// Note that this will panic if `set_base_random_seed` was not called yet. fn sample(&self, sampler: impl FnOnce(&mut R::RngType) -> T) -> T; + /// Gets a random sample from the specified distribution using a random number generator + /// associated with the given `RngId`. If the Rng has not been used before, one will be + /// created with the base seed you defined in `set_base_random_seed`. + /// Note that this will panic if `set_base_random_seed` was not called yet. fn sample_distr(&self, distribution: impl Distribution) -> T where R::RngType: Rng; + + fn sample_range(&self, range: S) -> T + where + R::RngType: Rng, + S: SampleRange, + T: SampleUniform; } impl ContextRandomExt for Context { @@ -110,19 +126,11 @@ impl ContextRandomExt for Context { rng_map.clear(); } - /// Gets a random sample from the random number generator associated with the given - /// `RngId` by applying the specified sampler function. If the Rng has not been used - /// before, one will be created with the base seed you defined in `set_base_random_seed`. - /// Note that this will panic if `set_base_random_seed` was not called yet. fn sample(&self, sampler: impl FnOnce(&mut R::RngType) -> T) -> T { let mut rng = get_rng::(self); sampler(&mut rng) } - /// Gets a random sample from the specified distribution using a random number generator - /// associated with the given `RngId`. If the Rng has not been used before, one will be - /// created with the base seed you defined in `set_base_random_seed`. - /// Note that this will panic if `set_base_random_seed` was not called yet. fn sample_distr(&self, distribution: impl Distribution) -> T where R::RngType: Rng, @@ -130,6 +138,15 @@ impl ContextRandomExt for Context { let mut rng = get_rng::(self); distribution.sample::(&mut rng) } + + fn sample_range(&self, range: S) -> T + where + R::RngType: Rng, + S: SampleRange, + T: SampleUniform, + { + self.sample::(|rng| rng.gen_range(range)) + } } #[cfg(test)] @@ -237,4 +254,12 @@ mod test { } assert!((zero_counter - 1000 as i32).abs() < 30); } + + #[test] + fn sample_range() { + let mut context = Context::new(); + context.init_random(42); + + context.sample_range::, usize>(0..10); + } } From 522f24a1110609ac3662f33f7f857bce21e468c3 Mon Sep 17 00:00:00 2001 From: k88hudson-cfa Date: Tue, 27 Aug 2024 15:20:21 -0400 Subject: [PATCH 10/15] Use struct literals and things are better --- src/random.rs | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/random.rs b/src/random.rs index ca3dc98d..d5eaaab9 100644 --- a/src/random.rs +++ b/src/random.rs @@ -12,7 +12,7 @@ use std::collections::HashMap; #[macro_export] macro_rules! define_rng { ($random_id:ident) => { - struct $random_id {} + struct $random_id; impl $crate::random::RngId for $random_id { // TODO(ryl8@cdc.gov): This is hardcoded to StdRng; we should replace this @@ -107,7 +107,7 @@ pub trait ContextRandomExt { where R::RngType: Rng; - fn sample_range(&self, range: S) -> T + fn sample_range(&self, rng_type: R, range: S) -> T where R::RngType: Rng, S: SampleRange, @@ -139,7 +139,7 @@ impl ContextRandomExt for Context { distribution.sample::(&mut rng) } - fn sample_range(&self, range: S) -> T + fn sample_range(&self, _rng_id: R, range: S) -> T where R::RngType: Rng, S: SampleRange, @@ -259,7 +259,7 @@ mod test { fn sample_range() { let mut context = Context::new(); context.init_random(42); - - context.sample_range::, usize>(0..10); + let result = context.sample_range(FooRng, 0..10); + assert!(result >= 0 && result < 10); } } From fc51d73b0821fa229c29d4da9623f2bef18c7871 Mon Sep 17 00:00:00 2001 From: k88hudson-cfa Date: Tue, 27 Aug 2024 16:06:16 -0400 Subject: [PATCH 11/15] Goodbye turbofish --- src/random.rs | 54 +++++++++++++++++++++++++++++++++------------------ 1 file changed, 35 insertions(+), 19 deletions(-) diff --git a/src/random.rs b/src/random.rs index d5eaaab9..1f553550 100644 --- a/src/random.rs +++ b/src/random.rs @@ -97,13 +97,21 @@ pub trait ContextRandomExt { /// `RngId` by applying the specified sampler function. If the Rng has not been used /// before, one will be created with the base seed you defined in `set_base_random_seed`. /// Note that this will panic if `set_base_random_seed` was not called yet. - fn sample(&self, sampler: impl FnOnce(&mut R::RngType) -> T) -> T; + fn sample( + &self, + _rang_type: R, + sampler: impl FnOnce(&mut R::RngType) -> T, + ) -> T; /// Gets a random sample from the specified distribution using a random number generator /// associated with the given `RngId`. If the Rng has not been used before, one will be /// created with the base seed you defined in `set_base_random_seed`. /// Note that this will panic if `set_base_random_seed` was not called yet. - fn sample_distr(&self, distribution: impl Distribution) -> T + fn sample_distr( + &self, + _rng_type: R, + distribution: impl Distribution, + ) -> T where R::RngType: Rng; @@ -126,12 +134,20 @@ impl ContextRandomExt for Context { rng_map.clear(); } - fn sample(&self, sampler: impl FnOnce(&mut R::RngType) -> T) -> T { + fn sample( + &self, + _rng_id: R, + sampler: impl FnOnce(&mut R::RngType) -> T, + ) -> T { let mut rng = get_rng::(self); sampler(&mut rng) } - fn sample_distr(&self, distribution: impl Distribution) -> T + fn sample_distr( + &self, + _rng_id: R, + distribution: impl Distribution, + ) -> T where R::RngType: Rng, { @@ -139,13 +155,13 @@ impl ContextRandomExt for Context { distribution.sample::(&mut rng) } - fn sample_range(&self, _rng_id: R, range: S) -> T + fn sample_range(&self, rng_id: R, range: S) -> T where R::RngType: Rng, S: SampleRange, T: SampleUniform, { - self.sample::(|rng| rng.gen_range(range)) + self.sample(rng_id, |rng| rng.gen_range(range)) } } @@ -166,8 +182,8 @@ mod test { context.init_random(42); assert_ne!( - context.sample::(|rng| rng.next_u64()), - context.sample::(|rng| rng.next_u64()) + context.sample(FooRng, |rng| rng.next_u64()), + context.sample(FooRng, |rng| rng.next_u64()) ); } @@ -175,7 +191,7 @@ mod test { #[should_panic(expected = "You must initialize the random number generator with a base seed")] fn panic_if_not_initialized() { let context = Context::new(); - context.sample::(|rng| rng.next_u64()); + context.sample(FooRng, |rng| rng.next_u64()); } #[test] @@ -184,8 +200,8 @@ mod test { context.init_random(42); assert_ne!( - context.sample::(|rng| rng.next_u64()), - context.sample::(|rng| rng.next_u64()) + context.sample(FooRng, |rng| rng.next_u64()), + context.sample(BarRng, |rng| rng.next_u64()) ); } @@ -194,18 +210,18 @@ mod test { let mut context = Context::new(); context.init_random(42); - let run_0 = context.sample::(|rng| rng.next_u64()); - let run_1 = context.sample::(|rng| rng.next_u64()); + let run_0 = context.sample(FooRng, |rng| rng.next_u64()); + let run_1 = context.sample(FooRng, |rng| rng.next_u64()); // Reset with same seed, ensure we get the same values context.init_random(42); - assert_eq!(run_0, context.sample::(|rng| rng.next_u64())); - assert_eq!(run_1, context.sample::(|rng| rng.next_u64())); + assert_eq!(run_0, context.sample(FooRng, |rng| rng.next_u64())); + assert_eq!(run_1, context.sample(FooRng, |rng| rng.next_u64())); // Reset with different seed, ensure we get different values context.init_random(88); - assert_ne!(run_0, context.sample::(|rng| rng.next_u64())); - assert_ne!(run_1, context.sample::(|rng| rng.next_u64())); + assert_ne!(run_0, context.sample(FooRng, |rng| rng.next_u64())); + assert_ne!(run_1, context.sample(FooRng, |rng| rng.next_u64())); } define_data_plugin!( @@ -226,7 +242,7 @@ mod test { let n_samples = 3000; let mut zero_counter = 0; for _ in 0..n_samples { - let sample = context.sample::(|rng| parameters.sample(rng)); + let sample = context.sample(FooRng, |rng| parameters.sample(rng)); if sample == 0 { zero_counter += 1; } @@ -247,7 +263,7 @@ mod test { let n_samples = 3000; let mut zero_counter = 0; for _ in 0..n_samples { - let sample = context.sample_distr::(parameters); + let sample = context.sample_distr(FooRng, parameters); if sample == 0 { zero_counter += 1; } From d0e2aed5feca0398a0bfe09262e724cbfefb65ef Mon Sep 17 00:00:00 2001 From: k88hudson-cfa Date: Tue, 27 Aug 2024 16:29:17 -0400 Subject: [PATCH 12/15] add sample bool --- src/random.rs | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/src/random.rs b/src/random.rs index 1f553550..12f2cb58 100644 --- a/src/random.rs +++ b/src/random.rs @@ -120,6 +120,10 @@ pub trait ContextRandomExt { R::RngType: Rng, S: SampleRange, T: SampleUniform; + + fn sample_bool(&self, rng_id: R, p: f64) -> bool + where + R::RngType: Rng; } impl ContextRandomExt for Context { @@ -163,6 +167,13 @@ impl ContextRandomExt for Context { { self.sample(rng_id, |rng| rng.gen_range(range)) } + + fn sample_bool(&self, rng_id: R, p: f64) -> bool + where + R::RngType: Rng, + { + self.sample(rng_id, |rng| rng.gen_bool(p)) + } } #[cfg(test)] @@ -278,4 +289,12 @@ mod test { let result = context.sample_range(FooRng, 0..10); assert!(result >= 0 && result < 10); } + + #[test] + fn sample_bool() { + let mut context = Context::new(); + context.init_random(42); + let result = context.sample_bool(FooRng, 0.5); + assert!(result == true || result == false); + } } From f61ae10f5a9ea8ab0b37c6bbe3f046009970a54f Mon Sep 17 00:00:00 2001 From: k88hudson-cfa Date: Tue, 27 Aug 2024 16:43:12 -0400 Subject: [PATCH 13/15] typo --- src/random.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/random.rs b/src/random.rs index 12f2cb58..519cd96c 100644 --- a/src/random.rs +++ b/src/random.rs @@ -99,7 +99,7 @@ pub trait ContextRandomExt { /// Note that this will panic if `set_base_random_seed` was not called yet. fn sample( &self, - _rang_type: R, + _rng_type: R, sampler: impl FnOnce(&mut R::RngType) -> T, ) -> T; From fa7acb4ed6745d1234a1ecf233e6c14fdaf75f81 Mon Sep 17 00:00:00 2001 From: k88hudson-cfa Date: Tue, 27 Aug 2024 17:01:00 -0400 Subject: [PATCH 14/15] multiple rng types types name --- src/random.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/random.rs b/src/random.rs index 519cd96c..7ed9b0e1 100644 --- a/src/random.rs +++ b/src/random.rs @@ -206,7 +206,7 @@ mod test { } #[test] - fn multiple_references_with_drop() { + fn multiple_rng_types() { let mut context = Context::new(); context.init_random(42); From de14d50ba9262ed3c39f3480fa9155933c8b343b Mon Sep 17 00:00:00 2001 From: k88hudson-cfa Date: Tue, 27 Aug 2024 17:14:12 -0400 Subject: [PATCH 15/15] Better test for sample_bool --- src/random.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/random.rs b/src/random.rs index 7ed9b0e1..4498a18a 100644 --- a/src/random.rs +++ b/src/random.rs @@ -294,7 +294,6 @@ mod test { fn sample_bool() { let mut context = Context::new(); context.init_random(42); - let result = context.sample_bool(FooRng, 0.5); - assert!(result == true || result == false); + let _r: bool = context.sample_bool(FooRng, 0.5); } }