From 8946d155ebab153c8afffaa112029a66f1fc1d5a Mon Sep 17 00:00:00 2001 From: "Dylan H. Morris" Date: Sun, 20 Oct 2024 15:58:25 -0400 Subject: [PATCH 01/15] Split prep_data.py into functions and make it a bit more configurable --- nssp_demo/prep_data.py | 306 ++++++++++++++++++++++++----------------- 1 file changed, 180 insertions(+), 126 deletions(-) diff --git a/nssp_demo/prep_data.py b/nssp_demo/prep_data.py index a1e06a16..0032a414 100644 --- a/nssp_demo/prep_data.py +++ b/nssp_demo/prep_data.py @@ -9,122 +9,28 @@ import polars as pl import pyarrow.parquet as pq -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) -disease_map = { - "COVID-19": "COVID-19/Omicron", - "Influenza": "Influenza", - "RSV": "RSV", -} - -parser = argparse.ArgumentParser( - description="Create fit data for disease modeling." -) -parser.add_argument( - "--disease", - type=str, - required=True, - help="Disease to model (e.g., COVID-19, Influenza, RSV)", -) -parser.add_argument( - "--report_date", - type=str, - default="latest", - help="Report date in YYYY-MM-DD format or latest (default: latest)", -) -parser.add_argument( - "--training_day_offset", - type=int, - default=7, - help="Number of days before the reference day to use as test data (default: 7)", -) -parser.add_argument( - "--n_training_days", - type=int, - default=90, - help="Number of training days (default: 90)", -) - -args = parser.parse_args() - -disease = args.disease -report_date = args.report_date - -if report_date == "latest": - report_date = max( - f.stem - for f in pathlib.Path("private_data/nssp_etl_gold").glob("*.parquet") - ) - -report_date = datetime.strptime(report_date, "%Y-%m-%d").date() - -logger.info(f"Report date: {report_date}") -training_day_offset = args.training_day_offset -n_training_days = args.n_training_days - -last_training_date = report_date - timedelta(days=training_day_offset + 1) -# +1 because max date in dataset is report_date - 1 -first_training_date = last_training_date 
- timedelta(days=n_training_days - 1) - -nssp_data = duckdb.read_parquet( - f"private_data/nssp_etl_gold/{report_date}.parquet" -) -nnh_estimates = pl.from_arrow( - pq.read_table("private_data/prod_param_estimates/prod.parquet") -) - - -generation_interval_pmf = ( - nnh_estimates.filter( - (pl.col("geo_value").is_null()) - & (pl.col("disease") == disease) - & (pl.col("parameter") == "generation_interval") - & (pl.col("end_date").is_null()) # most recent estimate - ) - .get_column("value") - .to_list()[0] -) - -delay_pmf = ( - nnh_estimates.filter( - (pl.col("geo_value").is_null()) - & (pl.col("disease") == disease) - & (pl.col("parameter") == "delay") - & (pl.col("end_date").is_null()) # most recent estimate - ) - .get_column("value") - .to_list()[0] -) - -excluded_states = ["GU", "MO", "WY"] -all_states = ( - nssp_data.unique("geo_value") - .filter(f"geo_value NOT IN {excluded_states}") - .order("geo_value") - .pl()["geo_value"] - .to_list() -) - -facts = pl.read_csv( - "https://raw.githubusercontent.com/k5cents/usa/refs/heads/master/data-raw/facts.csv" -) -states = pl.read_csv( - "https://raw.githubusercontent.com/k5cents/usa/refs/heads/master/data-raw/states.csv" -) - -state_pop_df = facts.join(states, on="name").select( - ["abb", "name", "population"] -) +def process_and_save_state(state_abb, + disease, + report_date, + first_training_date, + last_training_date, + state_pop_df, + param_estimates, + output_data_dir, + logger=None): + disease_map = { + "COVID-19": "COVID-19/Omicron", + "Influenza": "Influenza", + "RSV": "RSV", + } -for state_abb in all_states: - logger.info(f"Processing {state_abb}") data_to_save = duckdb.sql( f""" SELECT report_date, reference_date, SUM(value) AS ED_admissions, CASE WHEN reference_date <= '{last_training_date}' - THEN 'train' - ELSE 'test' END AS data_type + THEN 'train' + ELSE 'test' END AS data_type FROM nssp_data WHERE disease = '{disease_map[disease]}' AND metric = 'count_ed_visits' AND geo_value = '{state_abb}' @@ -137,17 
+43,37 @@ data_to_save_pl = data_to_save.pl() - actual_first_date = data_to_save_pl["reference_date"].min() - actual_last_date = data_to_save_pl["reference_date"].max() - state_pop = ( state_pop_df.filter(pl.col("abb") == state_abb) .get_column("population") .to_list()[0] ) + generation_interval_pmf = ( + param_estimates.filter( + (pl.col("geo_value").is_null()) + & (pl.col("disease") == disease) + & (pl.col("parameter") == "generation_interval") + & (pl.col("end_date").is_null()) # most recent estimate + ) + .get_column("value") + .to_list()[0] + ) + + delay_pmf = ( + param_estimates.filter( + (pl.col("geo_value").is_null()) + & (pl.col("disease") == disease) + & (pl.col("parameter") == "delay") + & (pl.col("end_date").is_null()) # most recent estimate + ) + .get_column("value") + .to_list()[0] + ) + + right_truncation_pmf = ( - nnh_estimates.filter( + param_estimates.filter( (pl.col("geo_value") == state_abb) & (pl.col("disease") == disease) & (pl.col("parameter") == "right_truncation") @@ -183,18 +109,146 @@ "state_pop": state_pop, } - model_folder_name = f"{disease.lower()}_r_{report_date}_f_{actual_first_date}_l_{actual_last_date}_t_{last_training_date}" + state_dir = os.path.join(output_data_dir, state_abb) + os.makedirs(state_dir, exist_ok=True) + if logger is not None: + logger.info(f"Saving {state_abb} to {state_dir}") + data_to_save.to_csv(str(pathlib.Path(state_dir, "data.csv"))) - model_folder = pathlib.Path("private_data", model_folder_name) - os.makedirs(model_folder, exist_ok=True) - state_folder = pathlib.Path(model_folder, state_abb) - os.makedirs(state_folder, exist_ok=True) - logger.info(f"Saving {state_abb}") - data_to_save.to_csv(str(pathlib.Path(state_folder, "data.csv"))) - - with open( - pathlib.Path(state_folder, "data_for_model_fit.json"), "w" - ) as json_file: + with open(os.path.join(state_dir, "data_for_model_fit.json"), "w" + ) as json_file: json.dump(data_for_model_fit, json_file) -logger.info("Data preparation complete.") + +def 
main(disease, + report_date, + nssp_data_dir, + param_estimate_dir, + output_data_dir, + training_day_offset, + n_training_days): + + logging.basicConfig(level=logging.INFO) + logger = logging.getLogger(__name__) + + if report_date == "latest": + report_date = max( + f.stem + for f in pathlib.Path(nssp_data_dir).glob("*.parquet") + ) + + report_date = datetime.strptime(report_date, "%Y-%m-%d").date() + + logger.info(f"Report date: {report_date}") + + last_training_date = (report_date - + timedelta(days=training_day_offset + 1)) + # +1 because max date in dataset is report_date - 1 + first_training_date = (last_training_date - + timedelta(days=n_training_days - 1)) + + datafile = f"{report_date}.parquet" + nssp_data = duckdb.read_parquet( + os.path.join(nssp_data_dir, datafile)) + param_estimates = pl.from_arrow( + pq.read_table(os.path.join( + param_estimate_dir, + "prod.parquet"))) + + excluded_states = ["GU", "MO", "WY"] + all_states = ( + nssp_data.unique("geo_value") + .filter(f"geo_value NOT IN {excluded_states}") + .order("geo_value") + .pl()["geo_value"] + .to_list() + ) + + facts = pl.read_csv( + "https://raw.githubusercontent.com/k5cents/usa/" + "refs/heads/master/data-raw/facts.csv" + ) + states = pl.read_csv( + "https://raw.githubusercontent.com/k5cents/usa/" + "refs/heads/master/data-raw/states.csv" + ) + + state_pop_df = facts.join(states, on="name").select( + ["abb", "name", "population"] + ) + + model_folder_name = ( + f"{disease.lower()}_r_{report_date}_f_" + f"{first_training_date}_t_{last_training_date}") + + model_folder = os.path.join(output_data_dir, model_folder_name) + os.makedirs(model_folder, exist_ok=True) + + for state_abb in all_states: + logger.info(f"Processing {state_abb}") + process_and_save_state( + state_abb=state_abb, + disease=disease, + report_date=report_date, + first_training_date=first_training_date, + last_training_date=last_training_date, + state_pop_df=state_pop_df, + param_estimates=param_estimates, + 
output_data_dir=output_data_dir, + logger=logger) + logger.info("Data preparation complete.") + + +parser = argparse.ArgumentParser( + description="Create fit data for disease modeling." +) +parser.add_argument( + "--disease", + type=str, + required=True, + help="Disease to model (e.g., COVID-19, Influenza, RSV)", +) +parser.add_argument( + "--report-date", + type=str, + default="latest", + help="Report date in YYYY-MM-DD format or latest (default: latest)", +) + +parser.add_argument( + "--nssp-data-dir", + type=str, + default=os.path.join("private_data", "nssp_etl_gold"), + help="Directory in which to look for NSSP input data.") + +parser.add_argument( + "--param-data-dir", + type=str, + default=os.path.join("private_data", "prod_param_estimates"), + help=( + "Directory in which to look for parameter estimates" + "such as delay PMFs.")) + +parser.add_argument( + "--output-data-dir", + type=str, + default=os.path.join("private_data"), + help="Directory in which to save output data.") + +parser.add_argument( + "--training-day-offset", + type=int, + default=7, + help="Number of days before the reference day to use as test data (default: 7)", +) + +parser.add_argument( + "--n-training-days", + type=int, + default=90, + help="Number of training days (default: 90)", +) + +args = parser.parse_args() + +main(**args) From 8a18e4bc2510f8a9057010d41a2b35ee26d858ee Mon Sep 17 00:00:00 2001 From: "Dylan H. Morris" Date: Sun, 20 Oct 2024 16:03:16 -0400 Subject: [PATCH 02/15] use parser output as dictionary --- nssp_demo/prep_data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nssp_demo/prep_data.py b/nssp_demo/prep_data.py index 0032a414..8c50f952 100644 --- a/nssp_demo/prep_data.py +++ b/nssp_demo/prep_data.py @@ -251,4 +251,4 @@ def main(disease, args = parser.parse_args() -main(**args) +main(**vars(args)) From 46e2cded273d54ce6df96ba54bfc8ef3b4a0647c Mon Sep 17 00:00:00 2001 From: "Dylan H. 
Morris" Date: Sun, 20 Oct 2024 16:04:05 -0400 Subject: [PATCH 03/15] Fix misnamed variable --- nssp_demo/prep_data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nssp_demo/prep_data.py b/nssp_demo/prep_data.py index 8c50f952..79abb1df 100644 --- a/nssp_demo/prep_data.py +++ b/nssp_demo/prep_data.py @@ -123,7 +123,7 @@ def process_and_save_state(state_abb, def main(disease, report_date, nssp_data_dir, - param_estimate_dir, + param_data_dir, output_data_dir, training_day_offset, n_training_days): From 1eaf0334a5fdd7b04057c8cc1dd05481b9ed3f0b Mon Sep 17 00:00:00 2001 From: "Dylan H. Morris" Date: Sun, 20 Oct 2024 22:14:05 +0000 Subject: [PATCH 04/15] Working prep_data script --- nssp_demo/prep_data.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/nssp_demo/prep_data.py b/nssp_demo/prep_data.py index 79abb1df..9721de39 100644 --- a/nssp_demo/prep_data.py +++ b/nssp_demo/prep_data.py @@ -12,6 +12,7 @@ def process_and_save_state(state_abb, disease, + nssp_data, report_date, first_training_date, last_training_date, @@ -71,7 +72,6 @@ def process_and_save_state(state_abb, .to_list()[0] ) - right_truncation_pmf = ( param_estimates.filter( (pl.col("geo_value") == state_abb) @@ -152,7 +152,7 @@ def main(disease, os.path.join(nssp_data_dir, datafile)) param_estimates = pl.from_arrow( pq.read_table(os.path.join( - param_estimate_dir, + param_data_dir, "prod.parquet"))) excluded_states = ["GU", "MO", "WY"] @@ -189,6 +189,7 @@ def main(disease, process_and_save_state( state_abb=state_abb, disease=disease, + nssp_data=nssp_data, report_date=report_date, first_training_date=first_training_date, last_training_date=last_training_date, From fef178b35b7d963bd4d45d0406ac31333f1cf4e1 Mon Sep 17 00:00:00 2001 From: "Dylan H. 
Morris" Date: Sun, 20 Oct 2024 22:19:11 +0000 Subject: [PATCH 05/15] Fix output path bug --- nssp_demo/prep_data.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/nssp_demo/prep_data.py b/nssp_demo/prep_data.py index 9721de39..a13137fe 100644 --- a/nssp_demo/prep_data.py +++ b/nssp_demo/prep_data.py @@ -18,7 +18,7 @@ def process_and_save_state(state_abb, last_training_date, state_pop_df, param_estimates, - output_data_dir, + model_data_dir, logger=None): disease_map = { "COVID-19": "COVID-19/Omicron", @@ -109,7 +109,7 @@ def process_and_save_state(state_abb, "state_pop": state_pop, } - state_dir = os.path.join(output_data_dir, state_abb) + state_dir = os.path.join(model_data_dir, state_abb) os.makedirs(state_dir, exist_ok=True) if logger is not None: logger.info(f"Saving {state_abb} to {state_dir}") @@ -177,12 +177,12 @@ def main(disease, ["abb", "name", "population"] ) - model_folder_name = ( + model_dir_name = ( f"{disease.lower()}_r_{report_date}_f_" f"{first_training_date}_t_{last_training_date}") - model_folder = os.path.join(output_data_dir, model_folder_name) - os.makedirs(model_folder, exist_ok=True) + model_data_dir = os.path.join(output_data_dir, model_dir_name) + os.makedirs(model_data_dir, exist_ok=True) for state_abb in all_states: logger.info(f"Processing {state_abb}") @@ -195,7 +195,7 @@ def main(disease, last_training_date=last_training_date, state_pop_df=state_pop_df, param_estimates=param_estimates, - output_data_dir=output_data_dir, + model_data_dir=model_data_dir, logger=logger) logger.info("Data preparation complete.") From cbb20661d5ad72b2784d2c5ffd19f866f8676c9d Mon Sep 17 00:00:00 2001 From: "Dylan H. 
Morris" Date: Sun, 20 Oct 2024 22:32:46 +0000 Subject: [PATCH 06/15] Add right_truncation_offset calculation to python prep_data.py --- nssp_demo/prep_data.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/nssp_demo/prep_data.py b/nssp_demo/prep_data.py index a13137fe..70978630 100644 --- a/nssp_demo/prep_data.py +++ b/nssp_demo/prep_data.py @@ -89,11 +89,20 @@ def process_and_save_state(state_abb, .to_list()[0] ) + last_actual_training_date = ( + data_to_save_pl.filter( + pl.col("data_type") == "train" + ).get_column("reference_date").max()) + + right_truncation_offset = ( + report_date - last_actual_training_date).days + train_ed_admissions = ( data_to_save_pl.filter(pl.col("data_type") == "train") .get_column("ED_admissions") .to_list() ) + test_ed_admissions = ( data_to_save_pl.filter(pl.col("data_type") == "test") .get_column("ED_admissions") @@ -107,6 +116,7 @@ def process_and_save_state(state_abb, "data_observed_hospital_admissions": train_ed_admissions, "test_ed_admissions": test_ed_admissions, "state_pop": state_pop, + "right_truncation_offset": right_truncation_offset } state_dir = os.path.join(model_data_dir, state_abb) From c122b2c84ba8c3d161e4741e9b93df9634bd9216 Mon Sep 17 00:00:00 2001 From: "Dylan H. 
Morris" Date: Mon, 21 Oct 2024 00:35:24 +0000 Subject: [PATCH 07/15] Make post_process.R more generic --- nssp_demo/post_process.R | 86 +++++++++++++++++++++++----------------- 1 file changed, 49 insertions(+), 37 deletions(-) diff --git a/nssp_demo/post_process.R b/nssp_demo/post_process.R index 489902de..2f44a1f4 100644 --- a/nssp_demo/post_process.R +++ b/nssp_demo/post_process.R @@ -10,7 +10,8 @@ theme_set(theme_minimal_grid()) disease_name_formatter <- c("covid-19" = "COVID-19", "influenza" = "Flu") -make_forecast_fig <- function(model_dir) { +make_forecast_fig <- function(model_dir, + base_dir) { disease_name_raw <- base_dir %>% path_file() %>% str_extract("^.+(?=_r_)") @@ -28,9 +29,10 @@ make_forecast_fig <- function(model_dir) { dat <- read_csv(data_path) %>% + rename(date = reference_date) %>% arrange(date) %>% mutate(time = row_number() - 1) %>% - rename(.value = COVID_ED_admissions) + rename(.value = ED_admissions) last_training_date <- dat %>% filter(data_type == "train") %>% @@ -119,39 +121,49 @@ make_forecast_fig <- function(model_dir) { } -base_dir <- path(here( - "nssp_demo", - "private_data", - "covid-19_r_2024-10-10_f_2024-04-12_l_2024-10-09_t_2024-10-05" -)) - - -forecast_fig_tbl <- - tibble(base_model_dir = dir_ls(base_dir)) %>% - filter( - path(base_model_dir, "inference_data", ext = "csv") %>% - file_exists() - ) %>% - mutate(forecast_fig = map(base_model_dir, make_forecast_fig)) %>% - mutate(figure_path = path(base_model_dir, "forecast_plot", ext = "pdf")) - -pwalk( - forecast_fig_tbl %>% select(forecast_fig, figure_path), - function(forecast_fig, figure_path) { - save_plot( - filename = figure_path, - plot = forecast_fig, - device = cairo_pdf, base_height = 6 +postprocess <- function(base_dir){ + forecast_fig_tbl <- + tibble(base_model_dir = dir_ls(base_dir)) %>% + filter( + path(base_model_dir, "inference_data", ext = "csv") %>% + file_exists() + ) %>% + mutate(forecast_fig = map( + base_model_dir, + \(x) make_forecast_fig(x, 
base_dir=base_dir)), + figure_path = path(base_model_dir, + "forecast_plot", ext = "pdf")) + + pwalk( + forecast_fig_tbl %>% select(forecast_fig, figure_path), + function(forecast_fig, figure_path) { + save_plot( + filename = figure_path, + plot = forecast_fig, + device = cairo_pdf, base_height = 6 + ) + } ) - } -) - -str_c(forecast_fig_tbl$figure_path, collapse = " ") %>% - str_c( - path(base_dir, - glue("{path_file(base_dir)}_all_forecasts"), - ext = "pdf" - ), - sep = " " - ) %>% - system2("pdfunite", args = .) + + str_c(forecast_fig_tbl$figure_path, collapse = " ") %>% + str_c( + path(base_dir, + glue("{path_file(base_dir)}_all_forecasts"), + ext = "pdf" + ), + sep = " " + ) %>% + system2("pdfunite", args = .) +} + +argv_parser <- argparser::arg_parser(paste0( + "Postprocess Pyrenew HEW model fits" + )) |> + argparser::add_argument( + "model_dir", + help = "Directory of forecasts to postprocess" + ) + +argv <- argparser::parse_args(argv_parser) + +postprocess(argv$model_dir) From f15eec6758bb9bd6026b007ba06ad1e478b722ea Mon Sep 17 00:00:00 2001 From: "Dylan H. 
Morris" Date: Mon, 21 Oct 2024 00:47:49 +0000 Subject: [PATCH 08/15] Quiet read_csv, progress messages --- nssp_demo/post_process.R | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/nssp_demo/post_process.R b/nssp_demo/post_process.R index 2f44a1f4..26e95c9d 100644 --- a/nssp_demo/post_process.R +++ b/nssp_demo/post_process.R @@ -21,6 +21,9 @@ make_forecast_fig <- function(model_dir, pluck(1) %>% tail(1) + message( + glue::glue("Making forecast fig for {state_abb}/{disease_name_raw}" + )) data_path <- path(model_dir, "data", ext = "csv") inference_data_path <- path(model_dir, "inference_data", @@ -28,7 +31,7 @@ make_forecast_fig <- function(model_dir, ) - dat <- read_csv(data_path) %>% + dat <- read_csv(data_path, show_col_types=FALSE) %>% rename(date = reference_date) %>% arrange(date) %>% mutate(time = row_number() - 1) %>% @@ -50,7 +53,7 @@ make_forecast_fig <- function(model_dir, } pyrenew_samples <- - read_csv(inference_data_path) %>% + read_csv(inference_data_path, show_col_types=FALSE) %>% rename_with(\(varname) str_remove_all(varname, "\\(|\\)|\\'|(, \\d+)")) |> rename( .chain = chain, From 2d6e2b348c8306b14de877729afdb7376e2e5836 Mon Sep 17 00:00:00 2001 From: "Dylan H. 
Morris" Date: Mon, 21 Oct 2024 13:33:04 -0400 Subject: [PATCH 09/15] Format files --- nssp_demo/prep_data.py | 91 +++++++++++++++++++++++------------------- 1 file changed, 49 insertions(+), 42 deletions(-) diff --git a/nssp_demo/prep_data.py b/nssp_demo/prep_data.py index 70978630..6b6ee863 100644 --- a/nssp_demo/prep_data.py +++ b/nssp_demo/prep_data.py @@ -10,16 +10,18 @@ import pyarrow.parquet as pq -def process_and_save_state(state_abb, - disease, - nssp_data, - report_date, - first_training_date, - last_training_date, - state_pop_df, - param_estimates, - model_data_dir, - logger=None): +def process_and_save_state( + state_abb, + disease, + nssp_data, + report_date, + first_training_date, + last_training_date, + state_pop_df, + param_estimates, + model_data_dir, + logger=None, +): disease_map = { "COVID-19": "COVID-19/Omicron", "Influenza": "Influenza", @@ -90,12 +92,12 @@ def process_and_save_state(state_abb, ) last_actual_training_date = ( - data_to_save_pl.filter( - pl.col("data_type") == "train" - ).get_column("reference_date").max()) + data_to_save_pl.filter(pl.col("data_type") == "train") + .get_column("reference_date") + .max() + ) - right_truncation_offset = ( - report_date - last_actual_training_date).days + right_truncation_offset = (report_date - last_actual_training_date).days train_ed_admissions = ( data_to_save_pl.filter(pl.col("data_type") == "train") @@ -116,7 +118,7 @@ def process_and_save_state(state_abb, "data_observed_hospital_admissions": train_ed_admissions, "test_ed_admissions": test_ed_admissions, "state_pop": state_pop, - "right_truncation_offset": right_truncation_offset + "right_truncation_offset": right_truncation_offset, } state_dir = os.path.join(model_data_dir, state_abb) @@ -125,45 +127,44 @@ def process_and_save_state(state_abb, logger.info(f"Saving {state_abb} to {state_dir}") data_to_save.to_csv(str(pathlib.Path(state_dir, "data.csv"))) - with open(os.path.join(state_dir, "data_for_model_fit.json"), "w" - ) as json_file: + 
with open( + os.path.join(state_dir, "data_for_model_fit.json"), "w" + ) as json_file: json.dump(data_for_model_fit, json_file) -def main(disease, - report_date, - nssp_data_dir, - param_data_dir, - output_data_dir, - training_day_offset, - n_training_days): - +def main( + disease, + report_date, + nssp_data_dir, + param_data_dir, + output_data_dir, + training_day_offset, + n_training_days, +): logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) if report_date == "latest": report_date = max( - f.stem - for f in pathlib.Path(nssp_data_dir).glob("*.parquet") + f.stem for f in pathlib.Path(nssp_data_dir).glob("*.parquet") ) report_date = datetime.strptime(report_date, "%Y-%m-%d").date() logger.info(f"Report date: {report_date}") - last_training_date = (report_date - - timedelta(days=training_day_offset + 1)) + last_training_date = report_date - timedelta(days=training_day_offset + 1) # +1 because max date in dataset is report_date - 1 - first_training_date = (last_training_date - - timedelta(days=n_training_days - 1)) + first_training_date = last_training_date - timedelta( + days=n_training_days - 1 + ) datafile = f"{report_date}.parquet" - nssp_data = duckdb.read_parquet( - os.path.join(nssp_data_dir, datafile)) + nssp_data = duckdb.read_parquet(os.path.join(nssp_data_dir, datafile)) param_estimates = pl.from_arrow( - pq.read_table(os.path.join( - param_data_dir, - "prod.parquet"))) + pq.read_table(os.path.join(param_data_dir, "prod.parquet")) + ) excluded_states = ["GU", "MO", "WY"] all_states = ( @@ -189,7 +190,8 @@ def main(disease, model_dir_name = ( f"{disease.lower()}_r_{report_date}_f_" - f"{first_training_date}_t_{last_training_date}") + f"{first_training_date}_t_{last_training_date}" + ) model_data_dir = os.path.join(output_data_dir, model_dir_name) os.makedirs(model_data_dir, exist_ok=True) @@ -206,7 +208,8 @@ def main(disease, state_pop_df=state_pop_df, param_estimates=param_estimates, model_data_dir=model_data_dir, - 
logger=logger) + logger=logger, + ) logger.info("Data preparation complete.") @@ -230,7 +233,8 @@ def main(disease, "--nssp-data-dir", type=str, default=os.path.join("private_data", "nssp_etl_gold"), - help="Directory in which to look for NSSP input data.") + help="Directory in which to look for NSSP input data.", +) parser.add_argument( "--param-data-dir", @@ -238,13 +242,16 @@ def main(disease, default=os.path.join("private_data", "prod_param_estimates"), help=( "Directory in which to look for parameter estimates" - "such as delay PMFs.")) + "such as delay PMFs." + ), +) parser.add_argument( "--output-data-dir", type=str, default=os.path.join("private_data"), - help="Directory in which to save output data.") + help="Directory in which to save output data.", +) parser.add_argument( "--training-day-offset", From cb327888601d4efcc0571843ff7db007273ab2cf Mon Sep 17 00:00:00 2001 From: Damon Bayer Date: Mon, 21 Oct 2024 15:49:23 +0000 Subject: [PATCH 10/15] only run main if we're running in main --- nssp_demo/prep_data.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/nssp_demo/prep_data.py b/nssp_demo/prep_data.py index 6b6ee863..d7d634c6 100644 --- a/nssp_demo/prep_data.py +++ b/nssp_demo/prep_data.py @@ -42,7 +42,6 @@ def process_and_save_state( ORDER BY report_date, reference_date """ ) - # why not count_admitted_ed_visits ? 
data_to_save_pl = data_to_save.pl() @@ -267,6 +266,6 @@ def main( help="Number of training days (default: 90)", ) -args = parser.parse_args() - -main(**vars(args)) +if __name__ == "__main__": + args = parser.parse_args() + main(**vars(args)) From 28543286edb543c53b1ae7837cf285678b76c164 Mon Sep 17 00:00:00 2001 From: Damon Bayer Date: Mon, 21 Oct 2024 15:49:46 +0000 Subject: [PATCH 11/15] clear up temporary code info for post_process --- nssp_demo/post_process.R | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/nssp_demo/post_process.R b/nssp_demo/post_process.R index 8ef9bea0..f768a39d 100644 --- a/nssp_demo/post_process.R +++ b/nssp_demo/post_process.R @@ -7,6 +7,10 @@ library(scales) library(here) library(argparser) +theme_set(theme_minimal_grid()) + +disease_name_formatter <- c("covid-19" = "COVID-19", "influenza" = "Flu") + # Create a parser p <- arg_parser("Generate forecast figures") %>% add_argument(p, "--model_dir", @@ -29,9 +33,7 @@ good_chain_tol <- argv$good_chain_tol base_dir <- path_dir(model_dir) -theme_set(theme_minimal_grid()) -disease_name_formatter <- c("covid-19" = "COVID-19", "influenza" = "Flu") read_pyrenew_samples <- function(inference_data_path, filter_bad_chains = TRUE, @@ -178,8 +180,9 @@ save_plot( device = cairo_pdf, base_height = 6 ) - -# Temp code while command line version doesn't work +# File will end here once command line version is working +# Temp code to run for all states while command line version doesn't work +# Command line version is dependent on https://github.com/rstudio/renv/pull/2018 base_dir <- path( "nssp_demo", "private_data", From e3f80ab6db0593f24e4571cc77465f0a976b6cbf Mon Sep 17 00:00:00 2001 From: Damon Bayer Date: Mon, 21 Oct 2024 18:41:19 +0000 Subject: [PATCH 12/15] remove duckdb --- nssp_demo/prep_data.py | 72 ++++++++++++++++++++---------------------- pyproject.toml | 1 - 2 files changed, 35 insertions(+), 38 deletions(-) diff --git a/nssp_demo/prep_data.py 
b/nssp_demo/prep_data.py index d7d634c6..60b2be27 100644 --- a/nssp_demo/prep_data.py +++ b/nssp_demo/prep_data.py @@ -2,8 +2,8 @@ import json import logging import os -import pathlib from datetime import datetime, timedelta +from pathlib import Path import duckdb import polars as pl @@ -28,23 +28,24 @@ def process_and_save_state( "RSV": "RSV", } - data_to_save = duckdb.sql( - f""" - SELECT report_date, reference_date, SUM(value) AS ED_admissions, - CASE WHEN reference_date <= '{last_training_date}' - THEN 'train' - ELSE 'test' END AS data_type - FROM nssp_data - WHERE disease = '{disease_map[disease]}' AND metric = 'count_ed_visits' - AND geo_value = '{state_abb}' - and reference_date >= '{first_training_date}' - GROUP BY report_date, reference_date - ORDER BY report_date, reference_date - """ + data_to_save = ( + nssp_data.filter( + (pl.col("disease") == disease_map[disease]) + & (pl.col("metric") == "count_ed_visits") + & (pl.col("geo_value") == state_abb) + & (pl.col("reference_date") >= first_training_date) + ) + .group_by(["report_date", "reference_date"]) + .agg(pl.col("value").sum().alias("ED_admissions")) + .with_columns( + pl.when(pl.col("reference_date") <= last_training_date) + .then(pl.lit("train")) + .otherwise(pl.lit("test")) + .alias("data_type") + ) + .sort(["report_date", "reference_date"]) ) - data_to_save_pl = data_to_save.pl() - state_pop = ( state_pop_df.filter(pl.col("abb") == state_abb) .get_column("population") @@ -90,22 +91,18 @@ def process_and_save_state( .to_list()[0] ) - last_actual_training_date = ( - data_to_save_pl.filter(pl.col("data_type") == "train") - .get_column("reference_date") - .max() - ) - - right_truncation_offset = (report_date - last_actual_training_date).days + right_truncation_offset = (report_date - last_training_date).days train_ed_admissions = ( - data_to_save_pl.filter(pl.col("data_type") == "train") + data_to_save.filter(pl.col("data_type") == "train") + .collect() .get_column("ED_admissions") .to_list() ) 
test_ed_admissions = ( - data_to_save_pl.filter(pl.col("data_type") == "test") + data_to_save.filter(pl.col("data_type") == "test") + .collect() .get_column("ED_admissions") .to_list() ) @@ -122,13 +119,12 @@ def process_and_save_state( state_dir = os.path.join(model_data_dir, state_abb) os.makedirs(state_dir, exist_ok=True) + if logger is not None: logger.info(f"Saving {state_abb} to {state_dir}") - data_to_save.to_csv(str(pathlib.Path(state_dir, "data.csv"))) + data_to_save.sink_csv(Path(state_dir, "data.csv")) - with open( - os.path.join(state_dir, "data_for_model_fit.json"), "w" - ) as json_file: + with open(Path(state_dir, "data_for_model_fit.json"), "w") as json_file: json.dump(data_for_model_fit, json_file) @@ -146,7 +142,7 @@ def main( if report_date == "latest": report_date = max( - f.stem for f in pathlib.Path(nssp_data_dir).glob("*.parquet") + f.stem for f in Path(nssp_data_dir).glob("*.parquet") ) report_date = datetime.strptime(report_date, "%Y-%m-%d").date() @@ -160,19 +156,21 @@ def main( ) datafile = f"{report_date}.parquet" - nssp_data = duckdb.read_parquet(os.path.join(nssp_data_dir, datafile)) + nssp_data = pl.scan_parquet(Path(nssp_data_dir, datafile)) param_estimates = pl.from_arrow( pq.read_table(os.path.join(param_data_dir, "prod.parquet")) - ) + ) # make this lazy excluded_states = ["GU", "MO", "WY"] + all_states = ( - nssp_data.unique("geo_value") - .filter(f"geo_value NOT IN {excluded_states}") - .order("geo_value") - .pl()["geo_value"] + nssp_data.select(pl.col("geo_value").unique()) + .filter(~pl.col("geo_value").is_in(excluded_states)) + .collect() + .get_column("geo_value") .to_list() ) + all_states.sort() facts = pl.read_csv( "https://raw.githubusercontent.com/k5cents/usa/" @@ -192,7 +190,7 @@ def main( f"{first_training_date}_t_{last_training_date}" ) - model_data_dir = os.path.join(output_data_dir, model_dir_name) + model_data_dir = Path(output_data_dir, model_dir_name) os.makedirs(model_data_dir, exist_ok=True) for state_abb in 
all_states: diff --git a/pyproject.toml b/pyproject.toml index 398646c8..3c393ec7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,7 +14,6 @@ pyyaml = "^6.0.2" jupyter = "^1.0.0" ipykernel = "^6.29.5" polars = "^1.5.0" -duckdb = "^1.1.2" pyarrow = "^17.0.0" From 10d1239f469ed3957b834a2a1b154b662a2bed39 Mon Sep 17 00:00:00 2001 From: Damon Bayer Date: Mon, 21 Oct 2024 18:42:38 +0000 Subject: [PATCH 13/15] actually remove duckdb import --- nssp_demo/prep_data.py | 1 - 1 file changed, 1 deletion(-) diff --git a/nssp_demo/prep_data.py b/nssp_demo/prep_data.py index 60b2be27..18e85501 100644 --- a/nssp_demo/prep_data.py +++ b/nssp_demo/prep_data.py @@ -5,7 +5,6 @@ from datetime import datetime, timedelta from pathlib import Path -import duckdb import polars as pl import pyarrow.parquet as pq From 9c6cb51cb46b2eb989d643e409d70428da4af819 Mon Sep 17 00:00:00 2001 From: Damon Bayer Date: Mon, 21 Oct 2024 18:49:24 +0000 Subject: [PATCH 14/15] make param_estimates lazy --- nssp_demo/prep_data.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/nssp_demo/prep_data.py b/nssp_demo/prep_data.py index 18e85501..991085c8 100644 --- a/nssp_demo/prep_data.py +++ b/nssp_demo/prep_data.py @@ -58,6 +58,7 @@ def process_and_save_state( & (pl.col("parameter") == "generation_interval") & (pl.col("end_date").is_null()) # most recent estimate ) + .collect() .get_column("value") .to_list()[0] ) @@ -69,6 +70,7 @@ def process_and_save_state( & (pl.col("parameter") == "delay") & (pl.col("end_date").is_null()) # most recent estimate ) + .collect() .get_column("value") .to_list()[0] ) @@ -86,6 +88,7 @@ def process_and_save_state( .filter( pl.col("reference_date") == pl.col("reference_date").max() ) # estimates nearest the report date + .collect() .get_column("value") .to_list()[0] ) @@ -156,9 +159,7 @@ def main( datafile = f"{report_date}.parquet" nssp_data = pl.scan_parquet(Path(nssp_data_dir, datafile)) - param_estimates = pl.from_arrow( - 
pq.read_table(os.path.join(param_data_dir, "prod.parquet")) - ) # make this lazy + param_estimates = pl.scan_parquet(Path(param_data_dir, "prod.parquet")) excluded_states = ["GU", "MO", "WY"] From fbca252c0092d4d9735040a7dd349901b28eb96a Mon Sep 17 00:00:00 2001 From: Damon Bayer Date: Mon, 21 Oct 2024 18:52:32 +0000 Subject: [PATCH 15/15] use Path for argument types where appropriate --- nssp_demo/prep_data.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/nssp_demo/prep_data.py b/nssp_demo/prep_data.py index 991085c8..fb078094 100644 --- a/nssp_demo/prep_data.py +++ b/nssp_demo/prep_data.py @@ -228,15 +228,15 @@ def main( parser.add_argument( "--nssp-data-dir", - type=str, - default=os.path.join("private_data", "nssp_etl_gold"), + type=Path, + default=Path("private_data", "nssp_etl_gold"), help="Directory in which to look for NSSP input data.", ) parser.add_argument( "--param-data-dir", - type=str, - default=os.path.join("private_data", "prod_param_estimates"), + type=Path, + default=Path("private_data", "prod_param_estimates"), help=( "Directory in which to look for parameter estimates" "such as delay PMFs." ), ) parser.add_argument( "--output-data-dir", - type=str, - default=os.path.join("private_data"), + type=Path, + default="private_data", help="Directory in which to save output data.", )