Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Right Truncation in NSSP Demo #17

Merged
merged 36 commits into from
Oct 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
84344a3
git ignore private data
damonbayer Sep 13, 2024
22837f6
update rprofile for VAP dev
damonbayer Sep 13, 2024
fb8d192
bump renv.lock
damonbayer Sep 13, 2024
ff22a96
add priors
damonbayer Sep 13, 2024
49892a8
prep_data checkin
damonbayer Sep 13, 2024
f74ac9f
activate.R whitespace
damonbayer Sep 13, 2024
f556ba8
add generation_interval_pmf
damonbayer Sep 14, 2024
df08773
fix names of pmfs
damonbayer Sep 14, 2024
83607e8
save test data in json
damonbayer Sep 16, 2024
5a2454b
Merge branch 'dmb_nssp_demo' of https://github.com/CDCgov/pyrenew-cov…
damonbayer Sep 16, 2024
7ef5b0a
move priors
damonbayer Sep 16, 2024
5a9b232
fit_model
damonbayer Sep 16, 2024
672a800
add post_process.R
damonbayer Sep 16, 2024
7fa5c1b
update for correct data source
damonbayer Sep 16, 2024
4776e74
re-arrange set_host_device_count
damonbayer Sep 16, 2024
b444712
prep all data
damonbayer Sep 16, 2024
742161b
use argparse
damonbayer Sep 16, 2024
0fc7f03
Merge branch 'dmb_nssp_demo' of https://github.com/CDCgov/pyrenew-cov…
damonbayer Sep 16, 2024
4eccaed
fit all models
damonbayer Sep 16, 2024
6352323
create all figures
damonbayer Sep 17, 2024
3ba78a9
Merge branch 'dmb_nssp_demo' of https://github.com/CDCgov/pyrenew-cov…
damonbayer Sep 17, 2024
7fa03f6
revise fit all models script
damonbayer Sep 17, 2024
b66207b
remove unused post processing code
damonbayer Sep 17, 2024
bb48570
use mid step
damonbayer Sep 17, 2024
3d398ac
first pass at right_truncation_offset
damonbayer Sep 19, 2024
5083fc1
right truncation working
damonbayer Sep 19, 2024
a4eed48
correct for new fit data
damonbayer Sep 20, 2024
b515fea
Merge branch 'main' into dmb_right_truncation
dylanhmorris Sep 27, 2024
3328e14
simplify call for throwaway RV
damonbayer Oct 1, 2024
87186c0
separate nowcast observations and eventual observations
damonbayer Oct 1, 2024
c7b1076
make t_pre_init 0
damonbayer Oct 1, 2024
a3cc849
incorporate right runcation into hosp_only_ww_model
damonbayer Oct 1, 2024
12ebd79
do 4 week forecasts, regardless of available data
damonbayer Oct 1, 2024
fe855fc
better name for combined figure
damonbayer Oct 1, 2024
707bf63
fit more recent data
damonbayer Oct 2, 2024
6893aa1
plot all forecasts, even if data isn't available
damonbayer Oct 2, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion nssp_demo/fit_all_models.sh
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
#!/bin/bash

# Base directory containing subdirectories
BASE_DIR="private_data/r_2024-09-10_f_2024-03-13_l_2024-09-09_t_2024-08-14/"

BASE_DIR="private_data/r_2024-10-01_f_2024-04-03_l_2024-09-30_t_2024-09-25/"


# Iterate over each subdirectory in the base directory
for SUBDIR in "$BASE_DIR"/*/; do
Expand Down
15 changes: 10 additions & 5 deletions nssp_demo/fit_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,13 @@
model_data["data_observed_hospital_admissions"]
)
state_pop = jnp.array(model_data["state_pop"])
n_forecast_points = len(model_data["test_ed_admissions"])

right_truncation_pmf_rv = DeterministicVariable(
"right_truncation_pmf", jnp.array(model_data["right_truncation_pmf"])
)

right_truncation_offset = model_data["right_truncation_offset"]
n_forecast_points = 28
my_model = hosp_only_ww_model(
state_pop=state_pop,
i0_first_obs_n_rv=i0_first_obs_n_rv,
Expand All @@ -79,16 +84,16 @@
autoreg_rt_rv=autoreg_rt_rv,
eta_sd_rv=eta_sd_rv, # sd of random walk for ar process,
generation_interval_pmf_rv=generation_interval_pmf_rv,
infection_feedback_pmf_rv=infection_feedback_pmf_rv,
infection_feedback_strength_rv=inf_feedback_strength_rv,
infection_feedback_pmf_rv=infection_feedback_pmf_rv,
p_hosp_mean_rv=p_hosp_mean_rv,
p_hosp_w_sd_rv=p_hosp_w_sd_rv,
autoreg_p_hosp_rv=autoreg_p_hosp_rv,
hosp_wday_effect_rv=hosp_wday_effect_rv,
phi_rv=phi_rv,
inf_to_hosp_rv=inf_to_hosp_rv,
phi_rv=phi_rv,
right_truncation_pmf_rv=right_truncation_pmf_rv,
n_initialization_points=uot,
i0_t_offset=0,
)


Expand All @@ -97,6 +102,7 @@
num_samples=500,
rng_key=jax.random.key(200),
data_observed_hospital_admissions=data_observed_hospital_admissions,
right_truncation_offset=right_truncation_offset,
mcmc_args=dict(num_chains=n_chains, progress_bar=True),
nuts_args=dict(find_heuristic_step_size=True),
)
Expand Down Expand Up @@ -126,7 +132,6 @@

idata = idata.sel(chain=chains_to_keep)


plotting.plot_predictive(idata)

idata.to_dataframe().to_csv(
Expand Down
18 changes: 13 additions & 5 deletions nssp_demo/post_process.R
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ make_forecast_fig <- function(model_dir) {
pyrenew_samples$posterior_predictive %>%
gather_draws(observed_hospital_admissions[time]) %>%
median_qi(.width = c(0.5, 0.8, 0.95)) %>%
mutate(date = dat$date[time + 1])
mutate(date = min(dat$date) + time)



Expand Down Expand Up @@ -106,11 +106,13 @@ make_forecast_fig <- function(model_dir) {
forecast_plot
}

base_dir <- path(

base_dir <- path(here(
"nssp_demo",
"private_data",
"r_2024-09-10_f_2024-03-13_l_2024-09-09_t_2024-08-14"
)
"r_2024-10-01_f_2024-04-03_l_2024-09-30_t_2024-09-25"
))


forecast_fig_tbl <-
tibble(base_model_dir = dir_ls(base_dir)) %>%
Expand All @@ -133,7 +135,13 @@ pwalk(
)

str_c(forecast_fig_tbl$figure_path, collapse = " ") %>%
str_c(path(base_dir, "all_forecasts", ext = "pdf"), sep = " ") %>%
str_c(
path(base_dir,
glue("{path_file(base_dir)}_all_forecasts"),
ext = "pdf"
),
sep = " "
) %>%
system2("pdfunite", args = .)

setdiff(usa::state.abb, path_file(forecast_fig_tbl$base_model_dir))
12 changes: 8 additions & 4 deletions nssp_demo/prep_data.R
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,11 @@ prep_and_save_data <- function(report_date,

actual_first_date <- min(dat$prepped_date$date)
actual_last_date <- max(dat$prepped_date$date)

dat$data_for_model_fit$right_truncation_offset <- as.integer(
as_date(report_date) -
as_date(last_training_date)
)
# could be off by 1


# Create folders
Expand Down Expand Up @@ -142,13 +146,13 @@ prep_and_save_data <- function(report_date,
}

walk(
usa::state.abb,
setdiff(usa::state.abb, "PR"),
\(x) {
prep_and_save_data(
report_date = "2024-09-10",
report_date = "2024-10-01",
min_reference_date = "2000-01-01",
max_reference_date = "3000-01-01",
last_training_date = "2024-08-14",
last_training_date = "2024-09-25",
state_abb = x
)
}
Expand Down
50 changes: 37 additions & 13 deletions pyrenew_covid_wastewater/hosp_only_ww_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,16 +39,14 @@ def __init__(
hosp_wday_effect_rv,
inf_to_hosp_rv,
phi_rv,
right_truncation_pmf_rv, # when unnamed deterministic variables are allowed, we could default this to 1.
n_initialization_points,
i0_t_offset,
): # numpydoc ignore=GL08
self.infection_initialization_process = InfectionInitializationProcess(
"I0_initialization",
i0_first_obs_n_rv,
InitializeInfectionsExponentialGrowth(
n_initialization_points,
initialization_rate_rv,
t_pre_init=i0_t_offset,
n_initialization_points, initialization_rate_rv, t_pre_init=0
),
)

Expand All @@ -65,6 +63,9 @@ def __init__(
differencing_order=1,
)

self.right_truncation_cdf_rv = TransformedVariable(
"right_truncation_cdf", right_truncation_pmf_rv, jnp.cumsum
)
self.autoreg_rt_rv = autoreg_rt_rv
self.eta_sd_rv = eta_sd_rv
self.log_r_mu_intercept_rv = log_r_mu_intercept_rv
Expand All @@ -84,7 +85,10 @@ def validate(self): # numpydoc ignore=GL08
return None

def sample(
self, n_datapoints=None, data_observed_hospital_admissions=None
self,
n_datapoints=None,
data_observed_hospital_admissions=None,
right_truncation_offset=None,
): # numpydoc ignore=GL08
if n_datapoints is None and data_observed_hospital_admissions is None:
raise ValueError(
Expand All @@ -109,11 +113,10 @@ def sample(
eta_sd = self.eta_sd_rv()
autoreg_rt = self.autoreg_rt_rv()
log_r_mu_intercept = self.log_r_mu_intercept_rv()
rt_init_rate_of_change_rv = DistributionalVariable(
rt_init_rate_of_change = DistributionalVariable(
"rt_init_rate_of_change",
dist.Normal(0, eta_sd / jnp.sqrt(1 - jnp.pow(autoreg_rt, 2))),
)
rt_init_rate_of_change = rt_init_rate_of_change_rv()
)()

log_rtu_weekly = self.ar_diff(
n=n_weeks_post_init,
Expand Down Expand Up @@ -190,19 +193,38 @@ def sample(
)[-n_datapoints:]
)

latent_hospital_admissions = (
latent_hospital_admissions_final = (
potential_latent_hospital_admissions
* ihr
* hosp_wday_effect
* self.state_pop
)

if right_truncation_offset is not None:
prop_already_reported_tail = jnp.flip(
self.right_truncation_cdf_rv()[right_truncation_offset:]
)
n_points_to_prepend = (
n_datapoints - prop_already_reported_tail.shape[0]
)
prop_already_reported = jnp.pad(
prop_already_reported_tail,
(n_points_to_prepend, 0),
mode="constant",
constant_values=(1, 0),
)
latent_hospital_admissions_now = (
latent_hospital_admissions_final * prop_already_reported
)
else:
latent_hospital_admissions_now = latent_hospital_admissions_final

hospital_admission_obs_rv = NegativeBinomialObservation(
"observed_hospital_admissions", concentration_rv=self.phi_rv
)

observed_hospital_admissions = hospital_admission_obs_rv(
mu=latent_hospital_admissions,
mu=latent_hospital_admissions_now,
obs=data_observed_hospital_admissions,
)

Expand Down Expand Up @@ -339,7 +361,9 @@ def create_hosp_only_ww_model_from_stan_data(stan_data_file):
state_pop = stan_data["state_pop"]

data_observed_hospital_admissions = jnp.array(stan_data["hosp"])

right_truncation_pmf_rv = DeterministicVariable(
"right_truncation_pmf", jnp.array(1)
)
my_model = hosp_only_ww_model(
state_pop=state_pop,
i0_first_obs_n_rv=i0_first_obs_n_rv,
Expand All @@ -354,10 +378,10 @@ def create_hosp_only_ww_model_from_stan_data(stan_data_file):
p_hosp_w_sd_rv=p_hosp_w_sd_rv,
autoreg_p_hosp_rv=autoreg_p_hosp_rv,
hosp_wday_effect_rv=hosp_wday_effect_rv,
phi_rv=phi_rv,
inf_to_hosp_rv=inf_to_hosp_rv,
phi_rv=phi_rv,
right_truncation_pmf_rv=right_truncation_pmf_rv,
n_initialization_points=uot,
i0_t_offset=0,
)

return my_model, data_observed_hospital_admissions
Loading