From 622c139161d6510665bd0ef190bd1eadfd467a47 Mon Sep 17 00:00:00 2001 From: Scott Olesen Date: Wed, 11 Dec 2024 16:51:50 -0500 Subject: [PATCH] check infection results --- ringvax/__init__.py | 49 ++++++++++++++++++++++++++++++++++----------- 1 file changed, 37 insertions(+), 12 deletions(-) diff --git a/ringvax/__init__.py b/ringvax/__init__.py index 1c2a0d8..9059b70 100644 --- a/ringvax/__init__.py +++ b/ringvax/__init__.py @@ -1,10 +1,10 @@ -from typing import Any +from typing import Any, List, Optional import numpy.random class Simulation: - def __init__(self, params: dict[str, Any], seed: int = None): + def __init__(self, params: dict[str, Any], seed: Optional[int] = None): self.params = params self.seed = seed self.rng = numpy.random.default_rng(self.seed) @@ -16,7 +16,7 @@ def run(self): def create_person(self) -> str: """Add a new person to the data""" - id = len(self.infections) + id = str(len(self.infections)) self.infections[id] = {} return id @@ -24,20 +24,25 @@ def get_person_property(self, id: str, property: str) -> Any: """Get a property of a person""" return self.infections[id][property] - def query_people(self, query: dict[str, Any]) -> [str]: + def query_people( + self, query: Optional[dict[str, Any]] = None + ) -> List[str]: """Get IDs of people with a given set of properties""" - return [ - id - for id, person in self.infections.items() - if all(person[k] == v for k, v in query.items()) - ] + if query is None: + return list(self.infections.keys()) + else: + return [ + id + for id, person in self.infections.items() + if all(person[k] == v for k, v in query.items()) + ] def update_person(self, id: str, content: dict[str, Any]) -> None: self.infections[id] |= content def generate_person_properties( - self, t_exposed: float, infector: str - ) -> str: + self, t_exposed: float, infector: Optional[str] + ) -> dict[str, Any]: """Generate properties of a single infected person""" # disease state history in this individual, and when they infect others infection_history = self.generate_infection_history( @@ -83,7 +88,7 @@ def run_infections(self) -> None: this_generation = [index_id] - for generation in range(self.params["n_generations"]): + for _ in range(self.params["n_generations"]): next_generation = [] # instantiate the next-gen infections caused by each infection in this generation @@ -103,6 +108,26 @@ def run_infections(self) -> None: this_generation = next_generation + # validate that we did everything right + # we should have no more than N generations in the data + assert ( + max( + self.get_person_property(id, "generation") + for id in self.query_people() + ) + <= self.params["n_generations"] + ) + + # the number of infections generated by each generation, should equal the number + # in the following generation, except for the final generation + for g in range(self.params["n_generations"]): + n_infections = sum( + len(self.get_person_property(id, "t_infections")) + for id in self.query_people({"generation": g}) + ) + n_infectees = len(self.query_people({"generation": g + 1})) + assert n_infections == n_infectees + def intervene(self) -> None: """Draw intervention outcomes and update chains of infection""" for generation in range(self.params["n_generations"]):