Skip to content

Commit

Permalink
refactor lists
Browse files Browse the repository at this point in the history
  • Loading branch information
swo committed Dec 18, 2024
1 parent dd086ad commit 8a12f18
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 26 deletions.
38 changes: 13 additions & 25 deletions ringvax/summary.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from collections import Counter
from typing import Container, Sequence
from typing import Sequence

import numpy as np
import polars as pl
Expand All @@ -25,21 +24,6 @@
"""


def prepare_for_df(infection: dict) -> dict:
"""
Handle vector-valued infection properties for downstream use in pl.DataFrame
"""
dfable = {}
for k, v in infection.items():
if isinstance(v, np.ndarray):
assert k == "infection_times"
dfable |= {k: [float(vv) for vv in v]}
else:
assert isinstance(v, str) or not isinstance(v, Container)
dfable |= {k: v}
return dfable


def get_all_person_properties(
sims: Sequence[Simulation], exclude_termination_if: list[str] = ["max_infections"]
) -> pl.DataFrame:
Expand All @@ -56,23 +40,27 @@ def get_all_person_properties(

return pl.concat(
[
get_person_properties(sim).with_columns(simulation=sim_idx)
_get_person_properties(sim).with_columns(simulation=sim_idx)
for sim_idx, sim in enumerate(sims)
if sim.termination["criterion"] not in exclude_termination_if
]
)


def get_person_properties(sim: Simulation) -> pl.DataFrame:
def _get_person_properties(sim: Simulation) -> pl.DataFrame:
"""Get a DataFrame of all properties of all infections in a simulation"""
sims_dict = {k: [] for k in infection_schema.keys()}
return pl.from_dicts(
[_prepare_for_df(x) for x in sim.infections.values()], schema=infection_schema
)

for infection in sim.infections.values():
prep = prepare_for_df(infection)
for k in infection_schema.keys():
sims_dict[k].append(prep[k])

return pl.DataFrame(sims_dict, schema=infection_schema)
def _prepare_for_df(infection: dict) -> dict:
"""
Convert numpy arrays in a dictionary to lists, for DataFrame compatibility
"""
return {
k: v.tolist() if isinstance(v, np.ndarray) else v for k, v in infection.items()
}


def summarize_detections(df: pl.DataFrame) -> pl.DataFrame:
Expand Down
2 changes: 1 addition & 1 deletion tests/test_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

def test_prep_for_df():
infection = {"infection_times": np.array([0, 1, 2]), "detected": False}
assert ringvax.summary.prepare_for_df(infection) == {
assert ringvax.summary._prepare_for_df(infection) == {
"infection_times": [0, 1, 2],
"detected": False,
}
Expand Down

0 comments on commit 8a12f18

Please sign in to comment.