generated from CDCgov/template
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
f6a9f3b
commit 5e16381
Showing
1 changed file
with
31 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,7 +3,7 @@ use rand::prelude::Distribution; | |
use rand::{Rng, SeedableRng}; | ||
use std::any::{Any, TypeId}; | ||
use std::cell::{RefCell, RefMut}; | ||
use std::collections::HashMap; | ||
use std::collections::{HashMap, HashSet}; | ||
|
||
/// Use this to define a unique type which will be used as a key to retrieve | ||
/// an independent rng instance when calling `.get_rng`. | ||
|
@@ -45,6 +45,7 @@ struct RngHolder { | |
struct RngData { | ||
base_seed: u64, | ||
rng_holders: RefCell<HashMap<TypeId, RngHolder>>, | ||
rng_names: RefCell<HashSet<String>>, | ||
} | ||
|
||
// Registers a data container which stores: | ||
|
@@ -58,6 +59,7 @@ crate::context::define_data_plugin!( | |
RngData { | ||
base_seed: 0, | ||
rng_holders: RefCell::new(HashMap::new()), | ||
rng_names: RefCell::new(HashSet::new()) | ||
} | ||
); | ||
|
||
|
@@ -75,6 +77,11 @@ fn get_rng<R: RngId + 'static>(context: &Context) -> RefMut<R::RngType> { | |
.entry(TypeId::of::<R>()) | ||
// Create a new rng holder if it doesn't exist yet | ||
.or_insert_with(|| { | ||
let mut rng_names = data_container.rng_names.try_borrow_mut().unwrap(); | ||
assert!( | ||
rng_names.insert(R::get_name().to_owned()), | ||
"Rng name already exists" | ||
); | ||
let base_seed = data_container.base_seed; | ||
let seed_offset = fxhash::hash64(R::get_name()); | ||
RngHolder { | ||
|
@@ -106,8 +113,8 @@ impl ContextRandomExt for Context { | |
data_container.base_seed = base_seed; | ||
|
||
// Clear any existing Rngs to ensure they get re-seeded when `get_rng` is called | ||
let mut rng_map = data_container.rng_holders.try_borrow_mut().unwrap(); | ||
rng_map.clear(); | ||
data_container.rng_holders.try_borrow_mut().unwrap().clear(); | ||
data_container.rng_names.try_borrow_mut().unwrap().clear(); | ||
} | ||
|
||
/// Gets a random sample from the random number generator associated with the given | ||
|
@@ -136,7 +143,7 @@ impl ContextRandomExt for Context { | |
mod test { | ||
use crate::context::Context; | ||
use crate::define_data_plugin; | ||
use crate::random::ContextRandomExt; | ||
use crate::random::{ContextRandomExt, RngId}; | ||
use rand::RngCore; | ||
use rand::{distributions::WeightedIndex, prelude::Distribution}; | ||
|
||
|
@@ -237,4 +244,24 @@ mod test { | |
} | ||
assert!((zero_counter - 1000 as i32).abs() < 30); | ||
} | ||
|
||
#[test] | ||
#[should_panic(expected = "Rng name already exists")] | ||
fn name_collision() { | ||
struct FooRng; | ||
|
||
impl RngId for FooRng { | ||
// TODO([email protected]): This is hardcoded to StdRng; we should replace this | ||
type RngType = rand::rngs::StdRng; | ||
|
||
fn get_name() -> &'static str { | ||
"FooRng" | ||
} | ||
} | ||
|
||
let mut context = Context::new(); | ||
context.init_random(42); | ||
context.sample::<FooRng, ()>(|_| ()); | ||
context.sample::<crate::random::test::FooRng, ()>(|_| ()); | ||
} | ||
} |