allow "latest" nssp data and switch from printing to logging
damonbayer committed Oct 18, 2024
1 parent d30d4bd commit e72c164
Showing 1 changed file with 27 additions and 9 deletions.
36 changes: 27 additions & 9 deletions nssp_demo/prep_data.py
@@ -1,5 +1,6 @@
 import argparse
 import json
+import logging
 import os
 import pathlib
 from datetime import datetime, timedelta
@@ -8,6 +9,9 @@
 import polars as pl
 import pyarrow.parquet as pq
 
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
 disease_map = {
     "COVID-19": "COVID-19/Omicron",
     "Influenza": "Influenza",
Expand All @@ -25,11 +29,9 @@
)
parser.add_argument(
"--report_date",
type=lambda d: datetime.strptime(d, "%Y-%m-%d").date(),
# default=(datetime.now()).strftime("%Y-%m-%d"),
required=True,
# todo: allow this to just be "latest" and have this be the default
help="Report date in YYYY-MM-DD format (default: yesterday)",
type=str,
default="latest",
help="Report date in YYYY-MM-DD format or latest (default: latest)",
)
parser.add_argument(
"--training_day_offset",
@@ -48,15 +50,29 @@
 
 disease = args.disease
 report_date = args.report_date
+
+if report_date == "latest":
+    report_date = max(
+        f.stem
+        for f in pathlib.Path("private_data/nssp_etl_gold").glob("*.parquet")
+    )
+
+report_date = datetime.strptime(report_date, "%Y-%m-%d").date()
+
+logger.info(f"Report date: {report_date}")
 training_day_offset = args.training_day_offset
 n_training_days = args.n_training_days
 
 last_training_date = report_date - timedelta(days=training_day_offset + 1)
 # +1 because max date in dataset is report_date - 1
 first_training_date = last_training_date - timedelta(days=n_training_days - 1)
 
-nssp_data = duckdb.read_parquet(f"private_data/{report_date}.parquet")
-nnh_estimates = pl.from_arrow(pq.read_table("private_data/prod.parquet"))
+nssp_data = duckdb.read_parquet(
+    f"private_data/nssp_etl_gold/{report_date}.parquet"
+)
+nnh_estimates = pl.from_arrow(
+    pq.read_table("private_data/prod_param_estimates/prod.parquet")
+)
 
 
 generation_interval_pmf = (
@@ -102,7 +118,7 @@
 )
 
 for state_abb in all_states:
-    print(f"Processing {state_abb}")
+    logger.info(f"Processing {state_abb}")
     data_to_save = duckdb.sql(
         f"""
         SELECT report_date, reference_date, SUM(value) AS ED_admissions,
@@ -173,10 +189,12 @@
     os.makedirs(model_folder, exist_ok=True)
     state_folder = pathlib.Path(model_folder, state_abb)
     os.makedirs(state_folder, exist_ok=True)
-    print(f"Saving {state_abb}")
+    logger.info(f"Saving {state_abb}")
     data_to_save.to_csv(str(pathlib.Path(state_folder, "data.csv")))
 
     with open(
         pathlib.Path(state_folder, "data_for_model_fit.json"), "w"
     ) as json_file:
         json.dump(data_for_model_fit, json_file)
+
+logger.info("Data preparation complete.")
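
Note on the new "latest" handling (explanatory, not part of the commit): the gold ETL files are named by report date, so taking the string max() of the parquet file stems yields the most recent report date, because ISO-8601 YYYY-MM-DD strings sort lexicographically in chronological order. A minimal sketch of the same logic, assuming files named YYYY-MM-DD.parquet; the helper name resolve_report_date is hypothetical, not from the commit:

import pathlib
from datetime import date, datetime


def resolve_report_date(
    arg: str, gold_dir: str = "private_data/nssp_etl_gold"
) -> date:
    # Hypothetical helper mirroring the committed logic: "latest" resolves
    # to the lexicographic max of the file stems, which for YYYY-MM-DD
    # names is also the chronological max. Raises ValueError if gold_dir
    # contains no parquet files.
    if arg == "latest":
        arg = max(f.stem for f in pathlib.Path(gold_dir).glob("*.parquet"))
    return datetime.strptime(arg, "%Y-%m-%d").date()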

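Likewise, the print-to-logging switch follows the standard module-logger pattern. A minimal sketch of the resulting behavior, assuming the default basicConfig format and that the script is run directly so __name__ is "__main__":

import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Unlike print(), these records carry a level and logger name and can be
# filtered or redirected by handlers; with the default format this emits
# "INFO:__main__:Processing CA".
logger.info("Processing CA")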