Skip to content

Commit

Permalink
explicit initialization
Browse files Browse the repository at this point in the history
  • Loading branch information
k88hudson-cfa committed Aug 13, 2024
1 parent 6505831 commit 6cdd875
Showing 1 changed file with 34 additions and 15 deletions.
49 changes: 34 additions & 15 deletions src/random.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ crate::context::define_data_plugin!(
#[allow(clippy::module_name_repetitions)]
pub trait RandomContext {
fn set_base_random_seed(&mut self, base_seed: u64);

fn create_rng<R: RngId>(&self);
fn get_rng<R: RngId>(&self) -> RefMut<'_, R::RngType>;
}

Expand All @@ -67,27 +67,36 @@ impl RandomContext for Context {
rng_map.clear();
}

fn create_rng<R: RngId>(&self) {
let data_container = self
.get_data_container::<RngPlugin>()
.expect("You must initialize the random number generator with a base seed");
let mut random_holders = data_container.rng_holders.try_borrow_mut().unwrap();
let base_seed = data_container.base_seed;
let seed_offset = fxhash::hash64(R::get_name());
let holder = RngHolder {
rng: Box::new(R::RngType::seed_from_u64(base_seed + seed_offset)),
};
random_holders.insert(TypeId::of::<R>(), holder);
}

/// Gets a mutable reference to the random number generator associated with the given
/// `RngId`. If the Rng has not been used before, one will be created with the base seed
/// you defined in `set_base_random_seed`. Note that this will panic if `set_base_random_seed` was not called yet.
fn get_rng<R: RngId>(&self) -> RefMut<'_, R::RngType> {
let data_container = self
.get_data_container::<RngPlugin>()
.expect("You must initialize the random number generator with a base seed");
let random_holders = data_container.rng_holders.try_borrow_mut().unwrap();

let random_holder = RefMut::map(random_holders, |random_holders| {
random_holders.entry(TypeId::of::<R>()).or_insert_with(|| {
let base_seed = data_container.base_seed;
let seed_offset = fxhash::hash64(R::get_name());
RngHolder {
rng: Box::new(R::RngType::seed_from_u64(base_seed + seed_offset)),
}
})
});

RefMut::map(random_holder, |random_holder| {
random_holder.rng.downcast_mut::<R::RngType>().unwrap()

let rng_holders = data_container.rng_holders.try_borrow_mut().unwrap();

RefMut::map(rng_holders, |holders| {
holders
.get_mut(&TypeId::of::<R>())
.expect("You must call initialize with create_rng")
.rng
.downcast_mut::<R::RngType>()
.unwrap()
})
}
}
Expand All @@ -105,6 +114,7 @@ mod test {
fn get_rng_basic() {
let mut context = Context::new();
context.set_base_random_seed(42);
context.create_rng::<FooRng>();

let mut foo_rng = context.get_rng::<FooRng>();
assert_eq!(foo_rng.next_u64(), 5113542052170610017);
Expand All @@ -116,6 +126,7 @@ mod test {
#[should_panic(expected = "You must initialize the random number generator with a base seed")]
fn panic_if_not_initialized() {
let context = Context::new();
context.create_rng::<FooRng>();
context.get_rng::<FooRng>();
}

Expand All @@ -124,6 +135,8 @@ mod test {
fn get_rng_one_ref_per_rng_id() {
let mut context = Context::new();
context.set_base_random_seed(42);

context.create_rng::<FooRng>();
let mut foo_rng = context.get_rng::<FooRng>();

// This should panic because we already have a mutable reference to FooRng
Expand All @@ -137,10 +150,13 @@ mod test {
let mut context = Context::new();
context.set_base_random_seed(42);

context.create_rng::<FooRng>();
let mut foo_rng = context.get_rng::<FooRng>();

foo_rng.next_u64();
drop(foo_rng);

context.create_rng::<BarRng>();
let mut bar_rng = context.get_rng::<BarRng>();
bar_rng.next_u64();
}
Expand All @@ -150,20 +166,23 @@ mod test {
let mut context = Context::new();
context.set_base_random_seed(42);

context.create_rng::<FooRng>();
let mut foo_rng = context.get_rng::<FooRng>();
let run_0 = foo_rng.next_u64();
let run_1 = foo_rng.next_u64();
drop(foo_rng);

// Reset with same seed, ensure we get the same values
context.set_base_random_seed(42);
context.create_rng::<FooRng>();
let mut foo_rng = context.get_rng::<FooRng>();
assert_eq!(run_0, foo_rng.next_u64());
assert_eq!(run_1, foo_rng.next_u64());
drop(foo_rng);

// Reset with different seed, ensure we get different values
context.set_base_random_seed(88);
context.create_rng::<FooRng>();
let mut foo_rng = context.get_rng::<FooRng>();
assert_ne!(run_0, foo_rng.next_u64());
assert_ne!(run_1, foo_rng.next_u64());
Expand Down

0 comments on commit 6cdd875

Please sign in to comment.