Skip to content

Commit

Permalink
Use rng, not seed
Browse files Browse the repository at this point in the history
  • Loading branch information
swo committed Dec 19, 2024
1 parent 239a728 commit deed8e2
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 11 deletions.
7 changes: 4 additions & 3 deletions ringvax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,11 @@ class Simulation:
"infection_times",
}

def __init__(self, params: dict[str, Any], seed: Optional[int] = None):
def __init__(
self, params: dict[str, Any], rng: Optional[numpy.random.Generator] = None
):
self.params = params
self.seed = seed
self.rng = numpy.random.default_rng(self.seed)
self.rng = rng if rng is not None else numpy.random.default_rng()
self.infections = {}

def create_person(self) -> str:
Expand Down
7 changes: 6 additions & 1 deletion ringvax/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import altair as alt
import graphviz
import numpy.random
import polars as pl
import streamlit as st

Expand Down Expand Up @@ -171,8 +172,12 @@ def app():
sims = []
with st.spinner("Running simulation..."):
tic = time.perf_counter()

# initialize rngs
rngs = numpy.random.default_rng(seed).spawn(nsim)

for i in range(nsim):
sims.append(Simulation(params=params, seed=seed + i))
sims.append(Simulation(params=params, rng=rngs[i]))
sims[-1].run()
toc = time.perf_counter()

Expand Down
14 changes: 7 additions & 7 deletions tests/test_simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def test_generate_disease_history(rng):
"infectious_duration": 1.0,
"infection_rate": 2.0,
}
s = ringvax.Simulation(params=params, seed=rng)
s = ringvax.Simulation(params=params, rng=rng)
history = s.generate_disease_history(t_exposed=0.0)
# for ease of testing, make this a list of rounded numbers

Expand All @@ -74,7 +74,7 @@ def test_generate_disease_history_nonzero(rng):
"infectious_duration": 1.0,
"infection_rate": 2.0,
}
s = ringvax.Simulation(params=params, seed=rng)
s = ringvax.Simulation(params=params, rng=rng)
history = s.generate_disease_history(t_exposed=10.0)
assert history == {
"t_exposed": 10.0,
Expand All @@ -100,36 +100,36 @@ def base_params():


def test_simulate(rng, base_params):
s = ringvax.Simulation(params=base_params, seed=rng)
s = ringvax.Simulation(params=base_params, rng=rng)
s.run()
assert len(s.infections) == 19


def test_simulate_max_infections(rng, base_params):
params = base_params
params["max_infections"] = 10
s = ringvax.Simulation(params=params, seed=rng)
s = ringvax.Simulation(params=params, rng=rng)
s.run()
assert len(s.infections) == 10


def test_simulate_set_field(rng, base_params):
s = ringvax.Simulation(params=base_params, seed=rng)
s = ringvax.Simulation(params=base_params, rng=rng)
id = s.create_person()
s.update_person(id, {"generation": 0})
assert s.get_person_property(id, "generation") == 0


def test_simulate_error_on_bad_get_property(rng, base_params):
s = ringvax.Simulation(params=base_params, seed=rng)
s = ringvax.Simulation(params=base_params, rng=rng)
id = s.create_person()

with pytest.raises(RuntimeError, match="foo"):
s.get_person_property(id, "foo")


def test_simulate_error_on_bad_update_property(rng, base_params):
s = ringvax.Simulation(params=base_params, seed=rng)
s = ringvax.Simulation(params=base_params, rng=rng)
id = s.create_person()

with pytest.raises(RuntimeError, match="foo"):
Expand Down

0 comments on commit deed8e2

Please sign in to comment.