Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added random module #18

Merged
merged 15 commits into from
Aug 27, 2024
2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,5 @@ license = "Apache-2.0"
homepage = "https://github.com/CDCgov/ixa"

[dependencies]
fxhash = "0.2.1"
rand = "0.8.5"
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,4 @@
//! person trying to infect susceptible people in the population.
pub mod context;
pub mod plan;
pub mod random;
178 changes: 178 additions & 0 deletions src/random.rs
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is conceptually like the eosim code, but we also talked about a version where you didn't get the RNG but just told it what distribution to draw from. @jasonasher did you ever prototype that.

Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
use crate::context::Context;
use rand::SeedableRng;
use std::any::{Any, TypeId};
use std::cell::{RefCell, RefMut};
use std::collections::HashMap;

/// Use this to define a unique type which will be used as a key to retrieve
/// an independent rng instance when calling `.get_rng`.
#[macro_export]
macro_rules! define_rng {
($random_id:ident) => {
struct $random_id {}

impl $crate::random::RngId for $random_id {
// TODO([email protected]): This is hardcoded to StdRng; we should replace this
type RngType = rand::rngs::StdRng;

fn get_name() -> &'static str {
stringify!($random_id)
ekr-cfa marked this conversation as resolved.
Show resolved Hide resolved
}
}
};
}
pub use define_rng;

pub trait RngId: Any {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why does this need to implement Any

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After some discussion, turns out problem here was about the lifetime of R (the type id) in get_rng<R: RngId> , which apparently is scoped to the function whereas the closure where it is used has a static lifetime. The right fix here was to change the signature to fn get_rng<R: RngId + 'static>

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you could get away with just RngId: 'static so that it has a TypeId, or move the static lifetime requirement into the get_rng / sample method signatures.

type RngType: SeedableRng;
fn get_name() -> &'static str;
}

// This is a wrapper which allows for future support for different types of
// random number generators (anything that implements SeedableRng is valid).
struct RngHolder {
rng: Box<dyn Any>,
}

struct RngData {
base_seed: u64,
rng_holders: RefCell<HashMap<TypeId, RngHolder>>,
}

// Registers a data container which stores:
// * base_seed: A base seed for all rngs
// * rng_holders: A map of rngs, keyed by their RngId. Note that this is
// stored in a RefCell to allow for mutable borrow without requiring a
// mutable borrow of the Context itself.
crate::context::define_data_plugin!(
ekr-cfa marked this conversation as resolved.
Show resolved Hide resolved
RngPlugin,
RngData,
RngData {
base_seed: 0,
rng_holders: RefCell::new(HashMap::new()),
}
);

// This is a trait exension on Context
pub trait ContextRandomExt {
fn init_random(&mut self, base_seed: u64);

fn get_rng<R: RngId>(&self) -> RefMut<R::RngType>;
}

impl ContextRandomExt for Context {
/// Initializes the `RngPlugin` data container to store rngs as well as a base
/// seed. Note that rngs are created lazily when `get_rng` is called.
fn init_random(&mut self, base_seed: u64) {
let data_container = self.get_data_container_mut::<RngPlugin>();
data_container.base_seed = base_seed;

// Clear any existing Rngs to ensure they get re-seeded when `get_rng` is called
let mut rng_map = data_container.rng_holders.try_borrow_mut().unwrap();
rng_map.clear();
}

/// Gets a mutable reference to the random number generator associated with the given
/// `RngId`. If the Rng has not been used before, one will be created with the base seed
/// you defined in `init`. Note that this will panic if `init` was not called yet.
fn get_rng<R: RngId + 'static>(&self) -> RefMut<R::RngType> {
let data_container = self
.get_data_container::<RngPlugin>()
.expect("You must initialize the random number generator with a base seed");

let rng_holders = data_container.rng_holders.try_borrow_mut().unwrap();
RefMut::map(rng_holders, |holders| {
holders
.entry(TypeId::of::<R>())
// Create a new rng holder if it doesn't exist yet
.or_insert_with(|| {
let base_seed = data_container.base_seed;
let seed_offset = fxhash::hash64(R::get_name());
RngHolder {
rng: Box::new(R::RngType::seed_from_u64(base_seed + seed_offset)),
}
})
.rng
.downcast_mut::<R::RngType>()
.unwrap()
})
}
}

#[cfg(test)]
mod test {
use crate::context::Context;
use crate::random::ContextRandomExt;
use rand::RngCore;

define_rng!(FooRng);
define_rng!(BarRng);

#[test]
fn get_rng_basic() {
let mut context = Context::new();
context.init_random(42);

let mut foo_rng = context.get_rng::<FooRng>();

assert_ne!(foo_rng.next_u64(), foo_rng.next_u64());
}

#[test]
#[should_panic(expected = "You must initialize the random number generator with a base seed")]
fn panic_if_not_initialized() {
let context = Context::new();
context.get_rng::<FooRng>();
}

#[test]
#[should_panic]
fn no_multiple_references_to_rngs() {
let mut context = Context::new();
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test doesn't seem to do what it says.

context.init_random(42);
let mut foo_rng = context.get_rng::<FooRng>();

// This should panic because we already have a mutable reference to FooRng
let mut foo_rng_2 = context.get_rng::<BarRng>();
foo_rng.next_u64();
foo_rng_2.next_u64();
}

#[test]
fn multiple_references_with_drop() {
let mut context = Context::new();
context.init_random(42);

let mut foo_rng = context.get_rng::<FooRng>();
foo_rng.next_u64();
// If you drop the first reference, you should be able to get a reference to a different rng
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// If you drop the first reference, you should be able to get a reference to a different rng
// If you drop the first reference, you should be able to get another reference to an rng

The same or different, right?

drop(foo_rng);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe comment this drop to explain why it's different from the previous test.


let mut bar_rng = context.get_rng::<BarRng>();
bar_rng.next_u64();
}

#[test]
fn reset_seed() {
let mut context = Context::new();
context.init_random(42);

let mut foo_rng = context.get_rng::<FooRng>();
let run_0 = foo_rng.next_u64();
let run_1 = foo_rng.next_u64();
drop(foo_rng);

// Reset with same seed, ensure we get the same values
context.init_random(42);
let mut foo_rng = context.get_rng::<FooRng>();
assert_eq!(run_0, foo_rng.next_u64());
assert_eq!(run_1, foo_rng.next_u64());
drop(foo_rng);

// Reset with different seed, ensure we get different values
context.init_random(88);
let mut foo_rng = context.get_rng::<FooRng>();
assert_ne!(run_0, foo_rng.next_u64());
assert_ne!(run_1, foo_rng.next_u64());
}
}
Loading