Skip to content

Commit

Permalink
check infection results
Browse files Browse the repository at this point in the history
  • Loading branch information
swo committed Dec 11, 2024
1 parent 4a7d8c0 commit 622c139
Showing 1 changed file with 37 additions and 12 deletions.
49 changes: 37 additions & 12 deletions ringvax/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from typing import Any
from typing import Any, List, Optional

import numpy.random


class Simulation:
def __init__(self, params: dict[str, Any], seed: int = None):
def __init__(self, params: dict[str, Any], seed: Optional[int] = None):
self.params = params
self.seed = seed
self.rng = numpy.random.default_rng(self.seed)
Expand All @@ -16,28 +16,33 @@ def run(self):

def create_person(self) -> str:
"""Add a new person to the data"""
id = len(self.infections)
id = str(len(self.infections))
self.infections[id] = {}
return id

def get_person_property(self, id: str, property: str) -> Any:
"""Get a property of a person"""
return self.infections[id][property]

def query_people(self, query: dict[str, Any]) -> [str]:
def query_people(
self, query: Optional[dict[str, Any]] = None
) -> List[str]:
"""Get IDs of people with a given set of properties"""
return [
id
for id, person in self.infections.items()
if all(person[k] == v for k, v in query.items())
]
if query is None:
return list(self.infections.keys())
else:
return [
id
for id, person in self.infections.items()
if all(person[k] == v for k, v in query.items())
]

def update_person(self, id: str, content: dict[str, Any]) -> None:
self.infections[id] |= content

def generate_person_properties(
self, t_exposed: float, infector: str
) -> str:
self, t_exposed: float, infector: Optional[str]
) -> dict[str, Any]:
"""Generate properties of a single infected person"""
# disease state history in this individual, and when they infect others
infection_history = self.generate_infection_history(
Expand Down Expand Up @@ -83,7 +88,7 @@ def run_infections(self) -> None:

this_generation = [index_id]

for generation in range(self.params["n_generations"]):
for _ in range(self.params["n_generations"]):
next_generation = []

# instantiate the next-gen infections caused by each infection in this generation
Expand All @@ -103,6 +108,26 @@ def run_infections(self) -> None:

this_generation = next_generation

# validate that we did everything right
# we should have no more than N generations in the data
assert (
max(
self.get_person_property(id, "generation")
for id in self.query_people()
)
<= self.params["n_generations"]
)

# the number of infections generated by each generation, should equal the number
# in the following generation, except for the final generation
for g in range(self.params["n_generations"]):
n_infections = sum(
len(self.get_person_property(id, "t_infections"))
for id in self.query_people({"generation": g})
)
n_infectees = len(self.query_people({"generation": g + 1}))
assert n_infections == n_infectees

def intervene(self) -> None:
"""Draw intervention outcomes and update chains of infection"""
for generation in range(self.params["n_generations"]):
Expand Down

0 comments on commit 622c139

Please sign in to comment.