diff --git a/nssp_demo/fit_all_models.sh b/nssp_demo/fit_all_models.sh index 4c4be11b..b5a6ceb1 100644 --- a/nssp_demo/fit_all_models.sh +++ b/nssp_demo/fit_all_models.sh @@ -1,7 +1,9 @@ #!/bin/bash # Base directory containing subdirectories -BASE_DIR="private_data/r_2024-09-10_f_2024-03-13_l_2024-09-09_t_2024-08-14/" + +BASE_DIR="private_data/r_2024-10-01_f_2024-04-03_l_2024-09-30_t_2024-09-25/" + # Iterate over each subdirectory in the base directory for SUBDIR in "$BASE_DIR"/*/; do diff --git a/nssp_demo/fit_model.py b/nssp_demo/fit_model.py index 7c93dcfb..eea2cf08 100644 --- a/nssp_demo/fit_model.py +++ b/nssp_demo/fit_model.py @@ -69,8 +69,13 @@ model_data["data_observed_hospital_admissions"] ) state_pop = jnp.array(model_data["state_pop"]) -n_forecast_points = len(model_data["test_ed_admissions"]) +right_truncation_pmf_rv = DeterministicVariable( + "right_truncation_pmf", jnp.array(model_data["right_truncation_pmf"]) +) + +right_truncation_offset = model_data["right_truncation_offset"] +n_forecast_points = 28 my_model = hosp_only_ww_model( state_pop=state_pop, i0_first_obs_n_rv=i0_first_obs_n_rv, @@ -79,16 +84,16 @@ autoreg_rt_rv=autoreg_rt_rv, eta_sd_rv=eta_sd_rv, # sd of random walk for ar process, generation_interval_pmf_rv=generation_interval_pmf_rv, - infection_feedback_pmf_rv=infection_feedback_pmf_rv, infection_feedback_strength_rv=inf_feedback_strength_rv, + infection_feedback_pmf_rv=infection_feedback_pmf_rv, p_hosp_mean_rv=p_hosp_mean_rv, p_hosp_w_sd_rv=p_hosp_w_sd_rv, autoreg_p_hosp_rv=autoreg_p_hosp_rv, hosp_wday_effect_rv=hosp_wday_effect_rv, - phi_rv=phi_rv, inf_to_hosp_rv=inf_to_hosp_rv, + phi_rv=phi_rv, + right_truncation_pmf_rv=right_truncation_pmf_rv, n_initialization_points=uot, - i0_t_offset=0, ) @@ -97,6 +102,7 @@ num_samples=500, rng_key=jax.random.key(200), data_observed_hospital_admissions=data_observed_hospital_admissions, + right_truncation_offset=right_truncation_offset, mcmc_args=dict(num_chains=n_chains, progress_bar=True), nuts_args=dict(find_heuristic_step_size=True), ) @@ -126,7 +132,6 @@ idata = idata.sel(chain=chains_to_keep) - plotting.plot_predictive(idata) idata.to_dataframe().to_csv( diff --git a/nssp_demo/post_process.R b/nssp_demo/post_process.R index 0632e0c7..7eae6ef4 100644 --- a/nssp_demo/post_process.R +++ b/nssp_demo/post_process.R @@ -63,7 +63,7 @@ make_forecast_fig <- function(model_dir) { pyrenew_samples$posterior_predictive %>% gather_draws(observed_hospital_admissions[time]) %>% median_qi(.width = c(0.5, 0.8, 0.95)) %>% - mutate(date = dat$date[time + 1]) + mutate(date = min(dat$date) + time) @@ -106,11 +106,13 @@ make_forecast_fig <- function(model_dir) { forecast_plot } -base_dir <- path( + +base_dir <- path(here( "nssp_demo", "private_data", - "r_2024-09-10_f_2024-03-13_l_2024-09-09_t_2024-08-14" -) + "r_2024-10-01_f_2024-04-03_l_2024-09-30_t_2024-09-25" +)) + forecast_fig_tbl <- tibble(base_model_dir = dir_ls(base_dir)) %>% @@ -133,7 +135,13 @@ pwalk( ) str_c(forecast_fig_tbl$figure_path, collapse = " ") %>% - str_c(path(base_dir, "all_forecasts", ext = "pdf"), sep = " ") %>% + str_c( + path(base_dir, + glue("{path_file(base_dir)}_all_forecasts"), + ext = "pdf" + ), + sep = " " + ) %>% system2("pdfunite", args = .) setdiff(usa::state.abb, path_file(forecast_fig_tbl$base_model_dir)) diff --git a/nssp_demo/prep_data.R b/nssp_demo/prep_data.R index ebdd2659..d923c7e1 100644 --- a/nssp_demo/prep_data.R +++ b/nssp_demo/prep_data.R @@ -113,7 +113,11 @@ prep_and_save_data <- function(report_date, actual_first_date <- min(dat$prepped_date$date) actual_last_date <- max(dat$prepped_date$date) - + dat$data_for_model_fit$right_truncation_offset <- as.integer( + as_date(report_date) - + as_date(last_training_date) + ) + # could be off by 1 # Create folders @@ -142,13 +146,13 @@ prep_and_save_data <- function(report_date, } walk( - usa::state.abb, + setdiff(usa::state.abb, "PR"), \(x) { prep_and_save_data( - report_date = "2024-09-10", + report_date = "2024-10-01", min_reference_date = "2000-01-01", max_reference_date = "3000-01-01", - last_training_date = "2024-08-14", + last_training_date = "2024-09-25", state_abb = x ) } diff --git a/pyrenew_covid_wastewater/hosp_only_ww_model.py b/pyrenew_covid_wastewater/hosp_only_ww_model.py index 8bb2afe0..db4496cc 100644 --- a/pyrenew_covid_wastewater/hosp_only_ww_model.py +++ b/pyrenew_covid_wastewater/hosp_only_ww_model.py @@ -39,16 +39,14 @@ def __init__( hosp_wday_effect_rv, inf_to_hosp_rv, phi_rv, + right_truncation_pmf_rv, # when unnamed deterministic variables are allowed, we could default this to 1. n_initialization_points, - i0_t_offset, ): # numpydoc ignore=GL08 self.infection_initialization_process = InfectionInitializationProcess( "I0_initialization", i0_first_obs_n_rv, InitializeInfectionsExponentialGrowth( - n_initialization_points, - initialization_rate_rv, - t_pre_init=i0_t_offset, + n_initialization_points, initialization_rate_rv, t_pre_init=0 ), ) @@ -65,6 +63,9 @@ def __init__( differencing_order=1, ) + self.right_truncation_cdf_rv = TransformedVariable( + "right_truncation_cdf", right_truncation_pmf_rv, jnp.cumsum + ) self.autoreg_rt_rv = autoreg_rt_rv self.eta_sd_rv = eta_sd_rv self.log_r_mu_intercept_rv = log_r_mu_intercept_rv @@ -84,7 +85,10 @@ def validate(self): # numpydoc ignore=GL08 return None def sample( - self, n_datapoints=None, data_observed_hospital_admissions=None + self, + n_datapoints=None, + data_observed_hospital_admissions=None, + right_truncation_offset=None, ): # numpydoc ignore=GL08 if n_datapoints is None and data_observed_hospital_admissions is None: raise ValueError( @@ -109,11 +113,10 @@ def sample( eta_sd = self.eta_sd_rv() autoreg_rt = self.autoreg_rt_rv() log_r_mu_intercept = self.log_r_mu_intercept_rv() - rt_init_rate_of_change_rv = DistributionalVariable( + rt_init_rate_of_change = DistributionalVariable( "rt_init_rate_of_change", dist.Normal(0, eta_sd / jnp.sqrt(1 - jnp.pow(autoreg_rt, 2))), - ) - rt_init_rate_of_change = rt_init_rate_of_change_rv() + )() log_rtu_weekly = self.ar_diff( n=n_weeks_post_init, @@ -190,19 +193,38 @@ def sample( )[-n_datapoints:] ) - latent_hospital_admissions = ( + latent_hospital_admissions_final = ( potential_latent_hospital_admissions * ihr * hosp_wday_effect * self.state_pop ) + if right_truncation_offset is not None: + prop_already_reported_tail = jnp.flip( + self.right_truncation_cdf_rv()[right_truncation_offset:] + ) + n_points_to_prepend = ( + n_datapoints - prop_already_reported_tail.shape[0] + ) + prop_already_reported = jnp.pad( + prop_already_reported_tail, + (n_points_to_prepend, 0), + mode="constant", + constant_values=(1, 0), + ) + latent_hospital_admissions_now = ( + latent_hospital_admissions_final * prop_already_reported + ) + else: + latent_hospital_admissions_now = latent_hospital_admissions_final + hospital_admission_obs_rv = NegativeBinomialObservation( "observed_hospital_admissions", concentration_rv=self.phi_rv ) observed_hospital_admissions = hospital_admission_obs_rv( - mu=latent_hospital_admissions, + mu=latent_hospital_admissions_now, obs=data_observed_hospital_admissions, ) @@ -339,7 +361,9 @@ def create_hosp_only_ww_model_from_stan_data(stan_data_file): state_pop = stan_data["state_pop"] data_observed_hospital_admissions = jnp.array(stan_data["hosp"]) - + right_truncation_pmf_rv = DeterministicVariable( + "right_truncation_pmf", jnp.array(1) + ) my_model = hosp_only_ww_model( state_pop=state_pop, i0_first_obs_n_rv=i0_first_obs_n_rv, @@ -354,10 +378,10 @@ def create_hosp_only_ww_model_from_stan_data(stan_data_file): p_hosp_w_sd_rv=p_hosp_w_sd_rv, autoreg_p_hosp_rv=autoreg_p_hosp_rv, hosp_wday_effect_rv=hosp_wday_effect_rv, - phi_rv=phi_rv, inf_to_hosp_rv=inf_to_hosp_rv, + phi_rv=phi_rv, + right_truncation_pmf_rv=right_truncation_pmf_rv, n_initialization_points=uot, - i0_t_offset=0, ) return my_model, data_observed_hospital_admissions