Skip to content

Commit

Permalink
Switch api to sample / sample_distr
Browse files Browse the repository at this point in the history
  • Loading branch information
k88hudson-cfa committed Aug 20, 2024
1 parent 377efb2 commit 23ef220
Showing 1 changed file with 112 additions and 67 deletions.
179 changes: 112 additions & 67 deletions src/random.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use crate::context::Context;
use rand::SeedableRng;
use rand::prelude::Distribution;
use rand::{Rng, SeedableRng};
use std::any::{Any, TypeId};
use std::cell::{RefCell, RefMut};
use std::collections::HashMap;
Expand Down Expand Up @@ -58,11 +59,41 @@ crate::context::define_data_plugin!(
}
);

/// Gets a mutable reference to the random number generator associated with the given
/// `RngId`. If the Rng has not been used before, one will be created with the base seed
/// you defined in `init`. Note that this will panic if `init` was not called yet.
fn get_rng<R: RngId + 'static>(context: &Context) -> RefMut<R::RngType> {
let data_container = context
.get_data_container::<RngPlugin>()
.expect("You must initialize the random number generator with a base seed");

let rng_holders = data_container.rng_holders.try_borrow_mut().unwrap();
RefMut::map(rng_holders, |holders| {
holders
.entry(TypeId::of::<R>())
// Create a new rng holder if it doesn't exist yet
.or_insert_with(|| {
let base_seed = data_container.base_seed;
let seed_offset = fxhash::hash64(R::get_name());
RngHolder {
rng: Box::new(R::RngType::seed_from_u64(base_seed + seed_offset)),
}
})
.rng
.downcast_mut::<R::RngType>()
.unwrap()
})
}

// This is a trait exension on Context
pub trait ContextRandomExt {
fn init_random(&mut self, base_seed: u64);

fn get_rng<R: RngId + 'static>(&self) -> RefMut<R::RngType>;
fn sample<R: RngId + 'static, T>(&self, sampler: impl FnOnce(&mut R::RngType) -> T) -> T;

fn sample_distr<R: RngId + 'static, T>(&self, distribution: impl Distribution<T>) -> T
where
R::RngType: Rng;
}

impl ContextRandomExt for Context {
Expand All @@ -77,39 +108,35 @@ impl ContextRandomExt for Context {
rng_map.clear();
}

/// Gets a mutable reference to the random number generator associated with the given
/// `RngId`. If the Rng has not been used before, one will be created with the base seed
/// you defined in `init`. Note that this will panic if `init` was not called yet.
fn get_rng<R: RngId + 'static>(&self) -> RefMut<R::RngType> {
let data_container = self
.get_data_container::<RngPlugin>()
.expect("You must initialize the random number generator with a base seed");

let rng_holders = data_container.rng_holders.try_borrow_mut().unwrap();
RefMut::map(rng_holders, |holders| {
holders
.entry(TypeId::of::<R>())
// Create a new rng holder if it doesn't exist yet
.or_insert_with(|| {
let base_seed = data_container.base_seed;
let seed_offset = fxhash::hash64(R::get_name());
RngHolder {
rng: Box::new(R::RngType::seed_from_u64(base_seed + seed_offset)),
}
})
.rng
.downcast_mut::<R::RngType>()
.unwrap()
})
/// Gets a random sample from the random number generator associated with the given
/// `RngId` by applying the specified sampler function. If the Rng has not been used
/// before, one will be created with the base seed you defined in `set_base_random_seed`.
/// Note that this will panic if `set_base_random_seed` was not called yet.
fn sample<R: RngId + 'static, T>(&self, sampler: impl FnOnce(&mut R::RngType) -> T) -> T {
let mut rng = get_rng::<R>(self);
sampler(&mut rng)
}

/// Gets a random sample from the specified distribution using a random number generator
/// associated with the given `RngId`. If the Rng has not been used before, one will be
/// created with the base seed you defined in `set_base_random_seed`.
/// Note that this will panic if `set_base_random_seed` was not called yet.
fn sample_distr<R: RngId + 'static, T>(&self, distribution: impl Distribution<T>) -> T
where
R::RngType: Rng,
{
let mut rng = get_rng::<R>(self);
distribution.sample::<R::RngType>(&mut rng)
}
}

#[cfg(test)]
mod test {
use crate::context::Context;
use crate::define_data_plugin;
use crate::random::ContextRandomExt;
use rand::RngCore;
use rand_distr::{Distribution, Exp};
use rand::{distributions::WeightedIndex, prelude::Distribution};

define_rng!(FooRng);
define_rng!(BarRng);
Expand All @@ -119,75 +146,93 @@ mod test {
let mut context = Context::new();
context.init_random(42);

let mut foo_rng = context.get_rng::<FooRng>();

assert_ne!(foo_rng.next_u64(), foo_rng.next_u64());
assert_ne!(
context.sample::<FooRng, u64>(|rng| rng.next_u64()),
context.sample::<FooRng, u64>(|rng| rng.next_u64())
);
}

#[test]
#[should_panic(expected = "You must initialize the random number generator with a base seed")]
fn panic_if_not_initialized() {
let context = Context::new();
context.get_rng::<FooRng>();
context.sample::<FooRng, u64>(|rng| rng.next_u64());
}

#[test]
#[should_panic]
fn no_multiple_references_to_rngs() {
fn multiple_references_with_drop() {
let mut context = Context::new();
context.init_random(42);
let mut foo_rng = context.get_rng::<FooRng>();

// This should panic because we already have a mutable reference to FooRng
let mut foo_rng_2 = context.get_rng::<BarRng>();
foo_rng.next_u64();
foo_rng_2.next_u64();
assert_ne!(
context.sample::<FooRng, u64>(|rng| rng.next_u64()),
context.sample::<BarRng, u64>(|rng| rng.next_u64())
);
}

#[test]
fn multiple_references_with_drop() {
fn reset_seed() {
let mut context = Context::new();
context.init_random(42);

let mut foo_rng = context.get_rng::<FooRng>();
foo_rng.next_u64();
// If you drop the first reference, you should be able to get a reference to a different rng
drop(foo_rng);
let run_0 = context.sample::<FooRng, u64>(|rng| rng.next_u64());
let run_1 = context.sample::<FooRng, u64>(|rng| rng.next_u64());

// Reset with same seed, ensure we get the same values
context.init_random(42);
assert_eq!(run_0, context.sample::<FooRng, u64>(|rng| rng.next_u64()));
assert_eq!(run_1, context.sample::<FooRng, u64>(|rng| rng.next_u64()));

let mut bar_rng = context.get_rng::<BarRng>();
bar_rng.next_u64();
// Reset with different seed, ensure we get different values
context.init_random(88);
assert_ne!(run_0, context.sample::<FooRng, u64>(|rng| rng.next_u64()));
assert_ne!(run_1, context.sample::<FooRng, u64>(|rng| rng.next_u64()));
}

define_data_plugin!(
SamplerData,
WeightedIndex<f64>,
WeightedIndex::new(vec![1.0]).unwrap()
);

#[test]
fn usage_with_distribution() {
fn sampler_function_closure_capture() {
let mut context = Context::new();
context.init_random(42);
let mut rng = context.get_rng::<FooRng>();
let dist = Exp::new(1.0).unwrap();
assert_ne!(dist.sample(&mut *rng), dist.sample(&mut *rng));
// Initialize weighted sampler
*context.get_data_container_mut::<SamplerData>() =
WeightedIndex::new(vec![1.0, 2.0]).unwrap();

let parameters = context.get_data_container::<SamplerData>().unwrap();
let n_samples = 3000;
let mut zero_counter = 0;
for _ in 0..n_samples {
let sample = context.sample::<FooRng, usize>(|rng| parameters.sample(rng));
if sample == 0 {
zero_counter += 1;
}
}
assert!((zero_counter - 1000 as i32).abs() < 30);
}

#[test]
fn reset_seed() {
fn sample_distribution() {
let mut context = Context::new();
context.init_random(42);

let mut foo_rng = context.get_rng::<FooRng>();
let run_0 = foo_rng.next_u64();
let run_1 = foo_rng.next_u64();
drop(foo_rng);

// Reset with same seed, ensure we get the same values
context.init_random(42);
let mut foo_rng = context.get_rng::<FooRng>();
assert_eq!(run_0, foo_rng.next_u64());
assert_eq!(run_1, foo_rng.next_u64());
drop(foo_rng);

// Reset with different seed, ensure we get different values
context.init_random(88);
let mut foo_rng = context.get_rng::<FooRng>();
assert_ne!(run_0, foo_rng.next_u64());
assert_ne!(run_1, foo_rng.next_u64());
// Initialize weighted sampler
*context.get_data_container_mut::<SamplerData>() =
WeightedIndex::new(vec![1.0, 2.0]).unwrap();

let parameters = context.get_data_container::<SamplerData>().unwrap();
let n_samples = 3000;
let mut zero_counter = 0;
for _ in 0..n_samples {
let sample = context.sample_distr::<FooRng, usize>(parameters);
if sample == 0 {
zero_counter += 1;
}
}
assert!((zero_counter - 1000 as i32).abs() < 30);
}
}

0 comments on commit 23ef220

Please sign in to comment.