Skip to content

Commit

Permalink
Added random module
Browse files Browse the repository at this point in the history
  • Loading branch information
k88hudson-cfa committed Aug 8, 2024
1 parent a8c4e9d commit 74bc1f2
Show file tree
Hide file tree
Showing 3 changed files with 165 additions and 0 deletions.
2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,5 @@ license = "Apache-2.0"
homepage = "https://github.com/CDCgov/ixa"

[dependencies]
fxhash = "0.2.1"
rand = "0.8.5"
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,4 @@
//! person trying to infect susceptible people in the population.
pub mod context;
pub mod plan;
pub mod random;
162 changes: 162 additions & 0 deletions src/random.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
use crate::context::Context;
use rand::SeedableRng;
use std::any::{Any, TypeId};
use std::cell::{RefCell, RefMut};
use std::collections::HashMap;

#[macro_export]
macro_rules! define_rng {
($random_id:ident) => {
struct $random_id {}

impl $crate::random::RngId for $random_id {
// TODO: This is hardcoded to StdRng; we should replace this
type RngType = rand::rngs::StdRng;

fn get_name() -> &'static str {
stringify!($random_id)
}
}
};
}
pub use define_rng;

pub trait RngId: Any {
type RngType: SeedableRng;
fn get_name() -> &'static str;
}

struct RngHolder {
rng: Box<dyn Any>,
}

struct RngData {
base_seed: u64,
rng_holders: RefCell<HashMap<TypeId, RngHolder>>,
}

crate::context::define_data_plugin!(
RngPlugin,
RngData,
RngData {
base_seed: 0,
rng_holders: RefCell::new(HashMap::new()),
}
);

#[allow(clippy::module_name_repetitions)]
pub trait RandomContext {
fn init_random(&mut self, base_seed: u64);

fn get_rng<R: RngId>(&self) -> RefMut<'_, R::RngType>;
}

impl RandomContext for Context {
fn init_random(&mut self, base_seed: u64) {
let data_container = self.get_data_container_mut::<RngPlugin>();
data_container.base_seed = base_seed;

// Clear any existing Rngs to ensure they are reseeded
let mut rng_map = data_container.rng_holders.try_borrow_mut().unwrap();
rng_map.clear();
}

fn get_rng<R: RngId>(&self) -> RefMut<'_, R::RngType> {
let data_container = self
.get_data_container::<RngPlugin>()
.expect("You must initialize the random number generator with a base seed");
let random_holders = data_container.rng_holders.try_borrow_mut().unwrap();

let random_holder = RefMut::map(random_holders, |random_holders| {
random_holders.entry(TypeId::of::<R>()).or_insert_with(|| {
let base_seed = data_container.base_seed;
let seed_offset = fxhash::hash64(R::get_name());
RngHolder {
rng: Box::new(R::RngType::seed_from_u64(base_seed + seed_offset)),
}
})
});

RefMut::map(random_holder, |random_holder| {
random_holder.rng.downcast_mut::<R::RngType>().unwrap()
})
}
}

#[cfg(test)]
mod test {
use crate::context::Context;
use crate::random::RandomContext;
use rand::RngCore;

define_rng!(FooRng);
define_rng!(BarRng);

#[test]
fn get_rng_basic() {
let mut context = Context::new();
context.init_random(42);

let mut foo_rng = context.get_rng::<FooRng>();
assert_eq!(foo_rng.next_u64(), 5113542052170610017);
assert_eq!(foo_rng.next_u64(), 8640506012583485895);
assert_eq!(foo_rng.next_u64(), 16699691489468094833);
}

#[test]
#[should_panic(expected = "You must initialize the random number generator with a base seed")]
fn panic_if_not_initialized() {
let context = Context::new();
context.get_rng::<FooRng>();
}

#[test]
#[should_panic]
fn get_rng_one_ref_per_rng_id() {
let mut context = Context::new();
context.init_random(42);
let mut foo_rng = context.get_rng::<FooRng>();

// This should panic because we already have a mutable reference to FooRng
let mut foo_rng_2 = context.get_rng::<BarRng>();
foo_rng.next_u64();
foo_rng_2.next_u64();
}

#[test]
fn get_rng_two_types() {
let mut context = Context::new();
context.init_random(42);

let mut foo_rng = context.get_rng::<FooRng>();
foo_rng.next_u64();
drop(foo_rng);

let mut bar_rng = context.get_rng::<BarRng>();
bar_rng.next_u64();
}

#[test]
fn reset_seed() {
let mut context = Context::new();
context.init_random(42);

let mut foo_rng = context.get_rng::<FooRng>();
let run_0 = foo_rng.next_u64();
let run_1 = foo_rng.next_u64();
drop(foo_rng);

// Reset with same seed, ensure we get the same values
context.init_random(42);
let mut foo_rng = context.get_rng::<FooRng>();
assert_eq!(run_0, foo_rng.next_u64());
assert_eq!(run_1, foo_rng.next_u64());
drop(foo_rng);

// Reset with different seed, ensure we get different values
context.init_random(88);
let mut foo_rng = context.get_rng::<FooRng>();
assert_ne!(run_0, foo_rng.next_u64());
assert_ne!(run_1, foo_rng.next_u64());
}
}

0 comments on commit 74bc1f2

Please sign in to comment.