-
Notifications
You must be signed in to change notification settings - Fork 1
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Added random module #18
Changes from 3 commits
6505831
def51f2
a99cedc
4870f23
1d01a69
377efb2
23ef220
f6a9f3b
7404c87
522f24a
fc51d73
d0e2aed
f61ae10
fa7acb4
de14d50
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,3 +8,5 @@ license = "Apache-2.0" | |
homepage = "https://github.com/CDCgov/ixa" | ||
|
||
[dependencies] | ||
fxhash = "0.2.1" | ||
rand = "0.8.5" |
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
@@ -0,0 +1,178 @@ | ||||||
use crate::context::Context; | ||||||
use rand::SeedableRng; | ||||||
use std::any::{Any, TypeId}; | ||||||
use std::cell::{RefCell, RefMut}; | ||||||
use std::collections::HashMap; | ||||||
|
||||||
/// Use this to define a unique type which will be used as a key to retrieve | ||||||
/// an independent rng instance when calling `.get_rng`. | ||||||
#[macro_export] | ||||||
macro_rules! define_rng { | ||||||
($random_id:ident) => { | ||||||
struct $random_id {} | ||||||
|
||||||
impl $crate::random::RngId for $random_id { | ||||||
// TODO([email protected]): This is hardcoded to StdRng; we should replace this | ||||||
type RngType = rand::rngs::StdRng; | ||||||
|
||||||
fn get_name() -> &'static str { | ||||||
stringify!($random_id) | ||||||
ekr-cfa marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
} | ||||||
} | ||||||
}; | ||||||
} | ||||||
pub use define_rng; | ||||||
|
||||||
pub trait RngId: Any { | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why does this need to implement There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. After some discussion, turns out problem here was about the lifetime of There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think you could get away with just |
||||||
type RngType: SeedableRng; | ||||||
fn get_name() -> &'static str; | ||||||
} | ||||||
|
||||||
// This is a wrapper which allows for future support for different types of | ||||||
// random number generators (anything that implements SeedableRng is valid). | ||||||
struct RngHolder { | ||||||
rng: Box<dyn Any>, | ||||||
} | ||||||
|
||||||
struct RngData { | ||||||
base_seed: u64, | ||||||
rng_holders: RefCell<HashMap<TypeId, RngHolder>>, | ||||||
} | ||||||
|
||||||
// Registers a data container which stores: | ||||||
// * base_seed: A base seed for all rngs | ||||||
// * rng_holders: A map of rngs, keyed by their RngId. Note that this is | ||||||
// stored in a RefCell to allow for mutable borrow without requiring a | ||||||
// mutable borrow of the Context itself. | ||||||
crate::context::define_data_plugin!( | ||||||
ekr-cfa marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
RngPlugin, | ||||||
RngData, | ||||||
RngData { | ||||||
base_seed: 0, | ||||||
rng_holders: RefCell::new(HashMap::new()), | ||||||
} | ||||||
); | ||||||
|
||||||
// This is a trait exension on Context | ||||||
pub trait ContextRandomExt { | ||||||
fn init_random(&mut self, base_seed: u64); | ||||||
|
||||||
fn get_rng<R: RngId>(&self) -> RefMut<R::RngType>; | ||||||
} | ||||||
|
||||||
impl ContextRandomExt for Context { | ||||||
/// Initializes the `RngPlugin` data container to store rngs as well as a base | ||||||
/// seed. Note that rngs are created lazily when `get_rng` is called. | ||||||
fn init_random(&mut self, base_seed: u64) { | ||||||
let data_container = self.get_data_container_mut::<RngPlugin>(); | ||||||
data_container.base_seed = base_seed; | ||||||
|
||||||
// Clear any existing Rngs to ensure they get re-seeded when `get_rng` is called | ||||||
let mut rng_map = data_container.rng_holders.try_borrow_mut().unwrap(); | ||||||
rng_map.clear(); | ||||||
} | ||||||
|
||||||
/// Gets a mutable reference to the random number generator associated with the given | ||||||
/// `RngId`. If the Rng has not been used before, one will be created with the base seed | ||||||
/// you defined in `init`. Note that this will panic if `init` was not called yet. | ||||||
fn get_rng<R: RngId + 'static>(&self) -> RefMut<R::RngType> { | ||||||
let data_container = self | ||||||
.get_data_container::<RngPlugin>() | ||||||
.expect("You must initialize the random number generator with a base seed"); | ||||||
|
||||||
let rng_holders = data_container.rng_holders.try_borrow_mut().unwrap(); | ||||||
RefMut::map(rng_holders, |holders| { | ||||||
holders | ||||||
.entry(TypeId::of::<R>()) | ||||||
// Create a new rng holder if it doesn't exist yet | ||||||
.or_insert_with(|| { | ||||||
let base_seed = data_container.base_seed; | ||||||
let seed_offset = fxhash::hash64(R::get_name()); | ||||||
RngHolder { | ||||||
rng: Box::new(R::RngType::seed_from_u64(base_seed + seed_offset)), | ||||||
} | ||||||
}) | ||||||
.rng | ||||||
.downcast_mut::<R::RngType>() | ||||||
.unwrap() | ||||||
}) | ||||||
} | ||||||
} | ||||||
|
||||||
#[cfg(test)] | ||||||
mod test { | ||||||
use crate::context::Context; | ||||||
use crate::random::ContextRandomExt; | ||||||
use rand::RngCore; | ||||||
|
||||||
define_rng!(FooRng); | ||||||
define_rng!(BarRng); | ||||||
|
||||||
#[test] | ||||||
fn get_rng_basic() { | ||||||
let mut context = Context::new(); | ||||||
context.init_random(42); | ||||||
|
||||||
let mut foo_rng = context.get_rng::<FooRng>(); | ||||||
|
||||||
assert_ne!(foo_rng.next_u64(), foo_rng.next_u64()); | ||||||
} | ||||||
|
||||||
#[test] | ||||||
#[should_panic(expected = "You must initialize the random number generator with a base seed")] | ||||||
fn panic_if_not_initialized() { | ||||||
let context = Context::new(); | ||||||
context.get_rng::<FooRng>(); | ||||||
} | ||||||
|
||||||
#[test] | ||||||
#[should_panic] | ||||||
fn no_multiple_references_to_rngs() { | ||||||
let mut context = Context::new(); | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This test doesn't seem to do what it says. |
||||||
context.init_random(42); | ||||||
let mut foo_rng = context.get_rng::<FooRng>(); | ||||||
|
||||||
// This should panic because we already have a mutable reference to FooRng | ||||||
let mut foo_rng_2 = context.get_rng::<BarRng>(); | ||||||
foo_rng.next_u64(); | ||||||
foo_rng_2.next_u64(); | ||||||
} | ||||||
|
||||||
#[test] | ||||||
fn multiple_references_with_drop() { | ||||||
let mut context = Context::new(); | ||||||
context.init_random(42); | ||||||
|
||||||
let mut foo_rng = context.get_rng::<FooRng>(); | ||||||
foo_rng.next_u64(); | ||||||
// If you drop the first reference, you should be able to get a reference to a different rng | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
The same or different, right? |
||||||
drop(foo_rng); | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe comment this drop to explain why it's different from the previous test. |
||||||
|
||||||
let mut bar_rng = context.get_rng::<BarRng>(); | ||||||
bar_rng.next_u64(); | ||||||
} | ||||||
|
||||||
#[test] | ||||||
fn reset_seed() { | ||||||
let mut context = Context::new(); | ||||||
context.init_random(42); | ||||||
|
||||||
let mut foo_rng = context.get_rng::<FooRng>(); | ||||||
let run_0 = foo_rng.next_u64(); | ||||||
let run_1 = foo_rng.next_u64(); | ||||||
drop(foo_rng); | ||||||
|
||||||
// Reset with same seed, ensure we get the same values | ||||||
context.init_random(42); | ||||||
let mut foo_rng = context.get_rng::<FooRng>(); | ||||||
assert_eq!(run_0, foo_rng.next_u64()); | ||||||
assert_eq!(run_1, foo_rng.next_u64()); | ||||||
drop(foo_rng); | ||||||
|
||||||
// Reset with different seed, ensure we get different values | ||||||
context.init_random(88); | ||||||
let mut foo_rng = context.get_rng::<FooRng>(); | ||||||
assert_ne!(run_0, foo_rng.next_u64()); | ||||||
assert_ne!(run_1, foo_rng.next_u64()); | ||||||
} | ||||||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is conceptually like the eosim code, but we also talked about a version where you didn't get the RNG but just told it what distribution to draw from. @jasonasher did you ever prototype that.