generated from CDCgov/template
-
Notifications
You must be signed in to change notification settings - Fork 1
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Added random module #18
Merged
Merged
Changes from all commits
Commits
Show all changes
15 commits
Select commit
Hold shift + click to select a range
6505831
Added random module
k88hudson-cfa def51f2
Refactor get_rng a bit
k88hudson-cfa a99cedc
Review fixes
k88hudson-cfa 4870f23
Added a test as an example of how to use with a distribution
k88hudson-cfa 1d01a69
Remove Any
k88hudson-cfa 377efb2
Ensure uniqueness of
k88hudson-cfa 23ef220
Switch api to sample / sample_distr
k88hudson-cfa f6a9f3b
namespace type collision guard
k88hudson-cfa 7404c87
add sample_range
k88hudson-cfa 522f24a
Use struct literals and things are better
k88hudson-cfa fc51d73
Goodbye turbofish
k88hudson-cfa d0e2aed
add sample bool
k88hudson-cfa f61ae10
typo
k88hudson-cfa fa7acb4
multiple rng types types name
k88hudson-cfa de14d50
Better test for sample_bool
k88hudson-cfa File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,299 @@ | ||
use crate::context::Context; | ||
use rand::distributions::uniform::SampleRange; | ||
use rand::distributions::uniform::SampleUniform; | ||
use rand::prelude::Distribution; | ||
use rand::{Rng, SeedableRng}; | ||
use std::any::{Any, TypeId}; | ||
use std::cell::{RefCell, RefMut}; | ||
use std::collections::HashMap; | ||
|
||
/// Use this to define a unique type which will be used as a key to retrieve | ||
/// an independent rng instance when calling `.get_rng`. | ||
#[macro_export] | ||
macro_rules! define_rng { | ||
($random_id:ident) => { | ||
struct $random_id; | ||
|
||
impl $crate::random::RngId for $random_id { | ||
// TODO([email protected]): This is hardcoded to StdRng; we should replace this | ||
type RngType = rand::rngs::StdRng; | ||
|
||
fn get_name() -> &'static str { | ||
stringify!($random_id) | ||
ekr-cfa marked this conversation as resolved.
Show resolved
Hide resolved
|
||
} | ||
} | ||
|
||
// This ensures that you can't define two RngIds with the same name | ||
paste::paste! { | ||
#[doc(hidden)] | ||
#[no_mangle] | ||
pub static [<rng_name_duplication_guard_ $random_id>]: () = (); | ||
} | ||
}; | ||
} | ||
pub use define_rng; | ||
|
||
pub trait RngId { | ||
type RngType: SeedableRng; | ||
fn get_name() -> &'static str; | ||
} | ||
|
||
// This is a wrapper which allows for future support for different types of | ||
// random number generators (anything that implements SeedableRng is valid). | ||
struct RngHolder { | ||
rng: Box<dyn Any>, | ||
} | ||
|
||
struct RngData { | ||
base_seed: u64, | ||
rng_holders: RefCell<HashMap<TypeId, RngHolder>>, | ||
} | ||
|
||
// Registers a data container which stores: | ||
// * base_seed: A base seed for all rngs | ||
// * rng_holders: A map of rngs, keyed by their RngId. Note that this is | ||
// stored in a RefCell to allow for mutable borrow without requiring a | ||
// mutable borrow of the Context itself. | ||
crate::context::define_data_plugin!( | ||
ekr-cfa marked this conversation as resolved.
Show resolved
Hide resolved
|
||
RngPlugin, | ||
RngData, | ||
RngData { | ||
base_seed: 0, | ||
rng_holders: RefCell::new(HashMap::new()), | ||
} | ||
); | ||
|
||
/// Gets a mutable reference to the random number generator associated with the given | ||
/// `RngId`. If the Rng has not been used before, one will be created with the base seed | ||
/// you defined in `init`. Note that this will panic if `init` was not called yet. | ||
fn get_rng<R: RngId + 'static>(context: &Context) -> RefMut<R::RngType> { | ||
let data_container = context | ||
.get_data_container::<RngPlugin>() | ||
.expect("You must initialize the random number generator with a base seed"); | ||
|
||
let rng_holders = data_container.rng_holders.try_borrow_mut().unwrap(); | ||
RefMut::map(rng_holders, |holders| { | ||
holders | ||
.entry(TypeId::of::<R>()) | ||
// Create a new rng holder if it doesn't exist yet | ||
.or_insert_with(|| { | ||
let base_seed = data_container.base_seed; | ||
let seed_offset = fxhash::hash64(R::get_name()); | ||
RngHolder { | ||
rng: Box::new(R::RngType::seed_from_u64(base_seed + seed_offset)), | ||
} | ||
}) | ||
.rng | ||
.downcast_mut::<R::RngType>() | ||
.unwrap() | ||
}) | ||
} | ||
|
||
// This is a trait exension on Context | ||
pub trait ContextRandomExt { | ||
fn init_random(&mut self, base_seed: u64); | ||
|
||
/// Gets a random sample from the random number generator associated with the given | ||
/// `RngId` by applying the specified sampler function. If the Rng has not been used | ||
/// before, one will be created with the base seed you defined in `set_base_random_seed`. | ||
/// Note that this will panic if `set_base_random_seed` was not called yet. | ||
fn sample<R: RngId + 'static, T>( | ||
&self, | ||
_rng_type: R, | ||
sampler: impl FnOnce(&mut R::RngType) -> T, | ||
) -> T; | ||
|
||
/// Gets a random sample from the specified distribution using a random number generator | ||
/// associated with the given `RngId`. If the Rng has not been used before, one will be | ||
/// created with the base seed you defined in `set_base_random_seed`. | ||
/// Note that this will panic if `set_base_random_seed` was not called yet. | ||
fn sample_distr<R: RngId + 'static, T>( | ||
&self, | ||
_rng_type: R, | ||
distribution: impl Distribution<T>, | ||
) -> T | ||
where | ||
R::RngType: Rng; | ||
|
||
fn sample_range<R: RngId + 'static, S, T>(&self, rng_type: R, range: S) -> T | ||
where | ||
R::RngType: Rng, | ||
S: SampleRange<T>, | ||
T: SampleUniform; | ||
|
||
fn sample_bool<R: RngId + 'static>(&self, rng_id: R, p: f64) -> bool | ||
where | ||
R::RngType: Rng; | ||
} | ||
|
||
impl ContextRandomExt for Context { | ||
/// Initializes the `RngPlugin` data container to store rngs as well as a base | ||
/// seed. Note that rngs are created lazily when `get_rng` is called. | ||
fn init_random(&mut self, base_seed: u64) { | ||
let data_container = self.get_data_container_mut::<RngPlugin>(); | ||
data_container.base_seed = base_seed; | ||
|
||
// Clear any existing Rngs to ensure they get re-seeded when `get_rng` is called | ||
let mut rng_map = data_container.rng_holders.try_borrow_mut().unwrap(); | ||
rng_map.clear(); | ||
} | ||
|
||
fn sample<R: RngId + 'static, T>( | ||
&self, | ||
_rng_id: R, | ||
sampler: impl FnOnce(&mut R::RngType) -> T, | ||
) -> T { | ||
let mut rng = get_rng::<R>(self); | ||
sampler(&mut rng) | ||
} | ||
|
||
fn sample_distr<R: RngId + 'static, T>( | ||
&self, | ||
_rng_id: R, | ||
distribution: impl Distribution<T>, | ||
) -> T | ||
where | ||
R::RngType: Rng, | ||
{ | ||
let mut rng = get_rng::<R>(self); | ||
distribution.sample::<R::RngType>(&mut rng) | ||
} | ||
|
||
fn sample_range<R: RngId + 'static, S, T>(&self, rng_id: R, range: S) -> T | ||
where | ||
R::RngType: Rng, | ||
S: SampleRange<T>, | ||
T: SampleUniform, | ||
{ | ||
self.sample(rng_id, |rng| rng.gen_range(range)) | ||
} | ||
|
||
fn sample_bool<R: RngId + 'static>(&self, rng_id: R, p: f64) -> bool | ||
where | ||
R::RngType: Rng, | ||
{ | ||
self.sample(rng_id, |rng| rng.gen_bool(p)) | ||
} | ||
} | ||
|
||
#[cfg(test)] | ||
mod test { | ||
use crate::context::Context; | ||
use crate::define_data_plugin; | ||
use crate::random::ContextRandomExt; | ||
use rand::RngCore; | ||
use rand::{distributions::WeightedIndex, prelude::Distribution}; | ||
|
||
define_rng!(FooRng); | ||
define_rng!(BarRng); | ||
|
||
#[test] | ||
fn get_rng_basic() { | ||
let mut context = Context::new(); | ||
context.init_random(42); | ||
|
||
assert_ne!( | ||
context.sample(FooRng, |rng| rng.next_u64()), | ||
context.sample(FooRng, |rng| rng.next_u64()) | ||
); | ||
} | ||
|
||
#[test] | ||
#[should_panic(expected = "You must initialize the random number generator with a base seed")] | ||
fn panic_if_not_initialized() { | ||
let context = Context::new(); | ||
context.sample(FooRng, |rng| rng.next_u64()); | ||
} | ||
|
||
#[test] | ||
fn multiple_rng_types() { | ||
let mut context = Context::new(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This test doesn't seem to do what it says. |
||
context.init_random(42); | ||
|
||
assert_ne!( | ||
context.sample(FooRng, |rng| rng.next_u64()), | ||
context.sample(BarRng, |rng| rng.next_u64()) | ||
); | ||
} | ||
|
||
#[test] | ||
fn reset_seed() { | ||
let mut context = Context::new(); | ||
context.init_random(42); | ||
|
||
let run_0 = context.sample(FooRng, |rng| rng.next_u64()); | ||
let run_1 = context.sample(FooRng, |rng| rng.next_u64()); | ||
|
||
// Reset with same seed, ensure we get the same values | ||
context.init_random(42); | ||
assert_eq!(run_0, context.sample(FooRng, |rng| rng.next_u64())); | ||
assert_eq!(run_1, context.sample(FooRng, |rng| rng.next_u64())); | ||
|
||
// Reset with different seed, ensure we get different values | ||
context.init_random(88); | ||
assert_ne!(run_0, context.sample(FooRng, |rng| rng.next_u64())); | ||
assert_ne!(run_1, context.sample(FooRng, |rng| rng.next_u64())); | ||
} | ||
|
||
define_data_plugin!( | ||
SamplerData, | ||
WeightedIndex<f64>, | ||
WeightedIndex::new(vec![1.0]).unwrap() | ||
); | ||
|
||
#[test] | ||
fn sampler_function_closure_capture() { | ||
let mut context = Context::new(); | ||
context.init_random(42); | ||
// Initialize weighted sampler | ||
*context.get_data_container_mut::<SamplerData>() = | ||
WeightedIndex::new(vec![1.0, 2.0]).unwrap(); | ||
|
||
let parameters = context.get_data_container::<SamplerData>().unwrap(); | ||
let n_samples = 3000; | ||
let mut zero_counter = 0; | ||
for _ in 0..n_samples { | ||
let sample = context.sample(FooRng, |rng| parameters.sample(rng)); | ||
if sample == 0 { | ||
zero_counter += 1; | ||
} | ||
} | ||
assert!((zero_counter - 1000 as i32).abs() < 30); | ||
} | ||
|
||
#[test] | ||
fn sample_distribution() { | ||
let mut context = Context::new(); | ||
context.init_random(42); | ||
|
||
// Initialize weighted sampler | ||
*context.get_data_container_mut::<SamplerData>() = | ||
WeightedIndex::new(vec![1.0, 2.0]).unwrap(); | ||
|
||
let parameters = context.get_data_container::<SamplerData>().unwrap(); | ||
let n_samples = 3000; | ||
let mut zero_counter = 0; | ||
for _ in 0..n_samples { | ||
let sample = context.sample_distr(FooRng, parameters); | ||
if sample == 0 { | ||
zero_counter += 1; | ||
} | ||
} | ||
assert!((zero_counter - 1000 as i32).abs() < 30); | ||
} | ||
|
||
#[test] | ||
fn sample_range() { | ||
let mut context = Context::new(); | ||
context.init_random(42); | ||
let result = context.sample_range(FooRng, 0..10); | ||
assert!(result >= 0 && result < 10); | ||
} | ||
|
||
#[test] | ||
fn sample_bool() { | ||
let mut context = Context::new(); | ||
context.init_random(42); | ||
let _r: bool = context.sample_bool(FooRng, 0.5); | ||
} | ||
} |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is conceptually like the eosim code, but we also talked about a version where you didn't get the RNG but just told it what distribution to draw from. @jasonasher did you ever prototype that.