Skip to content

Commit

Permalink
Demonstrate runtime collision guard
Browse files Browse the repository at this point in the history
  • Loading branch information
jasonasher committed Aug 21, 2024
1 parent f6a9f3b commit 5e16381
Showing 1 changed file with 31 additions and 4 deletions.
35 changes: 31 additions & 4 deletions src/random.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use rand::prelude::Distribution;
use rand::{Rng, SeedableRng};
use std::any::{Any, TypeId};
use std::cell::{RefCell, RefMut};
use std::collections::HashMap;
use std::collections::{HashMap, HashSet};

/// Use this to define a unique type which will be used as a key to retrieve
/// an independent rng instance when calling `.get_rng`.
Expand Down Expand Up @@ -45,6 +45,7 @@ struct RngHolder {
struct RngData {
base_seed: u64,
rng_holders: RefCell<HashMap<TypeId, RngHolder>>,
rng_names: RefCell<HashSet<String>>,
}

// Registers a data container which stores:
Expand All @@ -58,6 +59,7 @@ crate::context::define_data_plugin!(
RngData {
base_seed: 0,
rng_holders: RefCell::new(HashMap::new()),
rng_names: RefCell::new(HashSet::new())
}
);

Expand All @@ -75,6 +77,11 @@ fn get_rng<R: RngId + 'static>(context: &Context) -> RefMut<R::RngType> {
.entry(TypeId::of::<R>())
// Create a new rng holder if it doesn't exist yet
.or_insert_with(|| {
let mut rng_names = data_container.rng_names.try_borrow_mut().unwrap();
assert!(
rng_names.insert(R::get_name().to_owned()),
"Rng name already exists"
);
let base_seed = data_container.base_seed;
let seed_offset = fxhash::hash64(R::get_name());
RngHolder {
Expand Down Expand Up @@ -106,8 +113,8 @@ impl ContextRandomExt for Context {
data_container.base_seed = base_seed;

// Clear any existing Rngs to ensure they get re-seeded when `get_rng` is called
let mut rng_map = data_container.rng_holders.try_borrow_mut().unwrap();
rng_map.clear();
data_container.rng_holders.try_borrow_mut().unwrap().clear();
data_container.rng_names.try_borrow_mut().unwrap().clear();
}

/// Gets a random sample from the random number generator associated with the given
Expand Down Expand Up @@ -136,7 +143,7 @@ impl ContextRandomExt for Context {
mod test {
use crate::context::Context;
use crate::define_data_plugin;
use crate::random::ContextRandomExt;
use crate::random::{ContextRandomExt, RngId};
use rand::RngCore;
use rand::{distributions::WeightedIndex, prelude::Distribution};

Expand Down Expand Up @@ -237,4 +244,24 @@ mod test {
}
assert!((zero_counter - 1000 as i32).abs() < 30);
}

#[test]
#[should_panic(expected = "Rng name already exists")]
fn name_collision() {
struct FooRng;

impl RngId for FooRng {
// TODO([email protected]): This is hardcoded to StdRng; we should replace this
type RngType = rand::rngs::StdRng;

fn get_name() -> &'static str {
"FooRng"
}
}

let mut context = Context::new();
context.init_random(42);
context.sample::<FooRng, ()>(|_| ());
context.sample::<crate::random::test::FooRng, ()>(|_| ());
}
}

0 comments on commit 5e16381

Please sign in to comment.