From 5e163812a11a62e18ece7f5a52caf26beb201122 Mon Sep 17 00:00:00 2001 From: Jason Asher Date: Tue, 20 Aug 2024 22:48:58 -0400 Subject: [PATCH] Demonstrate runtime collision guard --- src/random.rs | 35 +++++++++++++++++++++++++++++++---- 1 file changed, 31 insertions(+), 4 deletions(-) diff --git a/src/random.rs b/src/random.rs index 341fd994..80a76a68 100644 --- a/src/random.rs +++ b/src/random.rs @@ -3,7 +3,7 @@ use rand::prelude::Distribution; use rand::{Rng, SeedableRng}; use std::any::{Any, TypeId}; use std::cell::{RefCell, RefMut}; -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; /// Use this to define a unique type which will be used as a key to retrieve /// an independent rng instance when calling `.get_rng`. @@ -45,6 +45,7 @@ struct RngHolder { struct RngData { base_seed: u64, rng_holders: RefCell>, + rng_names: RefCell>, } // Registers a data container which stores: @@ -58,6 +59,7 @@ crate::context::define_data_plugin!( RngData { base_seed: 0, rng_holders: RefCell::new(HashMap::new()), + rng_names: RefCell::new(HashSet::new()) } ); @@ -75,6 +77,11 @@ fn get_rng(context: &Context) -> RefMut { .entry(TypeId::of::()) // Create a new rng holder if it doesn't exist yet .or_insert_with(|| { + let mut rng_names = data_container.rng_names.try_borrow_mut().unwrap(); + assert!( + rng_names.insert(R::get_name().to_owned()), + "Rng name already exists" + ); let base_seed = data_container.base_seed; let seed_offset = fxhash::hash64(R::get_name()); RngHolder { @@ -106,8 +113,8 @@ impl ContextRandomExt for Context { data_container.base_seed = base_seed; // Clear any existing Rngs to ensure they get re-seeded when `get_rng` is called - let mut rng_map = data_container.rng_holders.try_borrow_mut().unwrap(); - rng_map.clear(); + data_container.rng_holders.try_borrow_mut().unwrap().clear(); + data_container.rng_names.try_borrow_mut().unwrap().clear(); } /// Gets a random sample from the random number generator associated with the given @@ -136,7 +143,7 @@ impl ContextRandomExt for Context { mod test { use crate::context::Context; use crate::define_data_plugin; - use crate::random::ContextRandomExt; + use crate::random::{ContextRandomExt, RngId}; use rand::RngCore; use rand::{distributions::WeightedIndex, prelude::Distribution}; @@ -237,4 +244,24 @@ mod test { } assert!((zero_counter - 1000 as i32).abs() < 30); } + + #[test] + #[should_panic(expected = "Rng name already exists")] + fn name_collision() { + struct FooRng; + + impl RngId for FooRng { + // TODO(ryl8@cdc.gov): This is hardcoded to StdRng; we should replace this + type RngType = rand::rngs::StdRng; + + fn get_name() -> &'static str { + "FooRng" + } + } + + let mut context = Context::new(); + context.init_random(42); + context.sample::(|_| ()); + context.sample::(|_| ()); + } }