diff --git a/ringvax/__init__.py b/ringvax/__init__.py index 31cb4bf..205a4a0 100644 --- a/ringvax/__init__.py +++ b/ringvax/__init__.py @@ -19,10 +19,11 @@ class Simulation: "infection_times", } - def __init__(self, params: dict[str, Any], seed: Optional[int] = None): + def __init__( + self, params: dict[str, Any], rng: Optional[numpy.random.Generator] = None + ): self.params = params - self.seed = seed - self.rng = numpy.random.default_rng(self.seed) + self.rng = rng if rng is not None else numpy.random.default_rng() self.infections = {} def create_person(self) -> str: diff --git a/ringvax/app.py b/ringvax/app.py index cf9282e..c824d24 100644 --- a/ringvax/app.py +++ b/ringvax/app.py @@ -3,6 +3,7 @@ import altair as alt import graphviz +import numpy.random import polars as pl import streamlit as st @@ -171,8 +172,12 @@ def app(): sims = [] with st.spinner("Running simulation..."): tic = time.perf_counter() + + # initialize rngs + rngs = numpy.random.default_rng(seed).spawn(nsim) + for i in range(nsim): - sims.append(Simulation(params=params, seed=seed + i)) + sims.append(Simulation(params=params, rng=rngs[i])) sims[-1].run() toc = time.perf_counter() diff --git a/tests/test_simulate.py b/tests/test_simulate.py index 19ea7d5..2e89425 100644 --- a/tests/test_simulate.py +++ b/tests/test_simulate.py @@ -54,7 +54,7 @@ def test_generate_disease_history(rng): "infectious_duration": 1.0, "infection_rate": 2.0, } - s = ringvax.Simulation(params=params, seed=rng) + s = ringvax.Simulation(params=params, rng=rng) history = s.generate_disease_history(t_exposed=0.0) # for ease of testing, make this a list of rounded numbers @@ -74,7 +74,7 @@ def test_generate_disease_history_nonzero(rng): "infectious_duration": 1.0, "infection_rate": 2.0, } - s = ringvax.Simulation(params=params, seed=rng) + s = ringvax.Simulation(params=params, rng=rng) history = s.generate_disease_history(t_exposed=10.0) assert history == { "t_exposed": 10.0, @@ -100,7 +100,7 @@ def base_params(): def test_simulate(rng, base_params): - s = ringvax.Simulation(params=base_params, seed=rng) + s = ringvax.Simulation(params=base_params, rng=rng) s.run() assert len(s.infections) == 19 @@ -108,20 +108,20 @@ def test_simulate(rng, base_params): def test_simulate_max_infections(rng, base_params): params = base_params params["max_infections"] = 10 - s = ringvax.Simulation(params=params, seed=rng) + s = ringvax.Simulation(params=params, rng=rng) s.run() assert len(s.infections) == 10 def test_simulate_set_field(rng, base_params): - s = ringvax.Simulation(params=base_params, seed=rng) + s = ringvax.Simulation(params=base_params, rng=rng) id = s.create_person() s.update_person(id, {"generation": 0}) assert s.get_person_property(id, "generation") == 0 def test_simulate_error_on_bad_get_property(rng, base_params): - s = ringvax.Simulation(params=base_params, seed=rng) + s = ringvax.Simulation(params=base_params, rng=rng) id = s.create_person() with pytest.raises(RuntimeError, match="foo"): @@ -129,7 +129,7 @@ def test_simulate_error_on_bad_get_property(rng, base_params): def test_simulate_error_on_bad_update_property(rng, base_params): - s = ringvax.Simulation(params=base_params, seed=rng) + s = ringvax.Simulation(params=base_params, rng=rng) id = s.create_person() with pytest.raises(RuntimeError, match="foo"):