Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Right Truncation in NSSP Demo #17

Merged
merged 36 commits into from
Oct 2, 2024
Merged
Show file tree
Hide file tree
Changes from 27 commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
84344a3
git ignore private data
damonbayer Sep 13, 2024
22837f6
update rprofile for VAP dev
damonbayer Sep 13, 2024
fb8d192
bump renv.lock
damonbayer Sep 13, 2024
ff22a96
add priors
damonbayer Sep 13, 2024
49892a8
prep_data checkin
damonbayer Sep 13, 2024
f74ac9f
activate.R whitespace
damonbayer Sep 13, 2024
f556ba8
add generation_interval_pmf
damonbayer Sep 14, 2024
df08773
fix names of pmfs
damonbayer Sep 14, 2024
83607e8
save test data in json
damonbayer Sep 16, 2024
5a2454b
Merge branch 'dmb_nssp_demo' of https://github.com/CDCgov/pyrenew-cov…
damonbayer Sep 16, 2024
7ef5b0a
move priors
damonbayer Sep 16, 2024
5a9b232
fit_model
damonbayer Sep 16, 2024
672a800
add post_process.R
damonbayer Sep 16, 2024
7fa5c1b
update for correct data source
damonbayer Sep 16, 2024
4776e74
re-arrange set_host_device_count
damonbayer Sep 16, 2024
b444712
prep all data
damonbayer Sep 16, 2024
742161b
use argparse
damonbayer Sep 16, 2024
0fc7f03
Merge branch 'dmb_nssp_demo' of https://github.com/CDCgov/pyrenew-cov…
damonbayer Sep 16, 2024
4eccaed
fit all models
damonbayer Sep 16, 2024
6352323
create all figures
damonbayer Sep 17, 2024
3ba78a9
Merge branch 'dmb_nssp_demo' of https://github.com/CDCgov/pyrenew-cov…
damonbayer Sep 17, 2024
7fa03f6
revise fit all models script
damonbayer Sep 17, 2024
b66207b
remove unused post processing code
damonbayer Sep 17, 2024
bb48570
use mid step
damonbayer Sep 17, 2024
3d398ac
first pass at right_truncation_offset
damonbayer Sep 19, 2024
5083fc1
right truncation working
damonbayer Sep 19, 2024
a4eed48
correct for new fit data
damonbayer Sep 20, 2024
b515fea
Merge branch 'main' into dmb_right_truncation
dylanhmorris Sep 27, 2024
3328e14
simplify call for throwaway RV
damonbayer Oct 1, 2024
87186c0
separate nowcast observations and eventual observations
damonbayer Oct 1, 2024
c7b1076
make t_pre_init 0
damonbayer Oct 1, 2024
a3cc849
incorporate right runcation into hosp_only_ww_model
damonbayer Oct 1, 2024
12ebd79
do 4 week forecasts, regardless of available data
damonbayer Oct 1, 2024
fe855fc
better name for combined figure
damonbayer Oct 1, 2024
707bf63
fit more recent data
damonbayer Oct 2, 2024
6893aa1
plot all forecasts, even if data isn't available
damonbayer Oct 2, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .Rprofile
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
source("renv/activate.R")
source("~/.Rprofile")
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -387,3 +387,5 @@ poetry.lock

notebooks/*_files/
notebooks/*.md

nssp_demo/private_data/*
11 changes: 11 additions & 0 deletions nssp_demo/fit_all_models.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
#!/bin/bash

# Base directory containing subdirectories
BASE_DIR="private_data/r_2024-09-19_f_2024-03-22_l_2024-09-18_t_2024-09-15/"

# Iterate over each subdirectory in the base directory
for SUBDIR in "$BASE_DIR"/*/; do
# Run the Python script with the current subdirectory as the model_dir argument
echo "$SUBDIR"
python fit_model.py --model_dir "$SUBDIR"
done
140 changes: 140 additions & 0 deletions nssp_demo/fit_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
import argparse
import json
from pathlib import Path

import arviz as az
import jax
import jax.numpy as jnp
import numpy as np
import numpyro
from pyrenew.deterministic import DeterministicVariable

import pyrenew_covid_wastewater.plotting as plotting
from pyrenew_covid_wastewater.hosp_only_ww_right_truncation_model import (
hosp_only_ww_right_truncation_model,
)

n_chains = 4
numpyro.set_host_device_count(n_chains)

# load priors
# have to run this from the right directory
from priors import ( # noqa: E402
autoreg_p_hosp_rv,
autoreg_rt_rv,
eta_sd_rv,
hosp_wday_effect_rv,
i0_first_obs_n_rv,
inf_feedback_strength_rv,
initialization_rate_rv,
log_r_mu_intercept_rv,
p_hosp_mean_rv,
p_hosp_w_sd_rv,
phi_rv,
uot,
)

parser = argparse.ArgumentParser(
description="Fit the hospital-only wastewater model."
)
parser.add_argument(
"--model_dir",
type=str,
required=True,
help="Path to the model directory containing the data.",
)
args = parser.parse_args()
model_dir = Path(args.model_dir)
data_path = model_dir / "data_for_model_fit.json"

with open(
data_path,
"r",
) as file:
model_data = json.load(file)


inf_to_hosp_rv = DeterministicVariable(
"inf_to_hosp", jnp.array(model_data["inf_to_hosp_pmf"])
) # check if off by 1 or reversed

generation_interval_pmf_rv = DeterministicVariable(
"generation_interval_pmf", jnp.array(model_data["generation_interval_pmf"])
) # check if off by 1 or reversed

infection_feedback_pmf_rv = DeterministicVariable(
"infection_feedback_pmf", jnp.array(model_data["generation_interval_pmf"])
) # check if off by 1 or reversed

data_observed_hospital_admissions = jnp.array(
model_data["data_observed_hospital_admissions"]
)
state_pop = jnp.array(model_data["state_pop"])

right_truncation_pmf_rv = DeterministicVariable(
"right_truncation_pmf", jnp.array(model_data["right_truncation_pmf"])
)

right_truncation_offset = model_data["right_truncation_offset"]
n_forecast_points = len(model_data["test_ed_admissions"])
my_model = hosp_only_ww_right_truncation_model(
state_pop=state_pop,
i0_first_obs_n_rv=i0_first_obs_n_rv,
initialization_rate_rv=initialization_rate_rv,
log_r_mu_intercept_rv=log_r_mu_intercept_rv,
autoreg_rt_rv=autoreg_rt_rv,
eta_sd_rv=eta_sd_rv, # sd of random walk for ar process,
generation_interval_pmf_rv=generation_interval_pmf_rv,
infection_feedback_pmf_rv=infection_feedback_pmf_rv,
infection_feedback_strength_rv=inf_feedback_strength_rv,
p_hosp_mean_rv=p_hosp_mean_rv,
p_hosp_w_sd_rv=p_hosp_w_sd_rv,
autoreg_p_hosp_rv=autoreg_p_hosp_rv,
hosp_wday_effect_rv=hosp_wday_effect_rv,
phi_rv=phi_rv,
inf_to_hosp_rv=inf_to_hosp_rv,
right_truncation_pmf_rv=right_truncation_pmf_rv,
n_initialization_points=uot,
)


my_model.run(
num_warmup=500,
num_samples=500,
rng_key=jax.random.key(200),
data_observed_hospital_admissions=data_observed_hospital_admissions,
right_truncation_offset=right_truncation_offset,
mcmc_args=dict(num_chains=n_chains, progress_bar=True),
nuts_args=dict(find_heuristic_step_size=True),
)


posterior_predictive = my_model.posterior_predictive(
n_datapoints=len(data_observed_hospital_admissions) + n_forecast_points
)

idata = az.from_numpyro(
my_model.mcmc,
posterior_predictive=posterior_predictive,
)

chain_ll = (
idata["log_likelihood"]
.mean(dim=["observed_hospital_admissions_dim_0", "draw"])[
"observed_hospital_admissions"
]
.values
)

chains_to_keep = np.arange(n_chains)[
((chain_ll - chain_ll.max()) / chain_ll.max()) < 2
]
# would like to not have to run this

idata = idata.sel(chain=chains_to_keep)

plotting.plot_predictive(idata)

idata.to_dataframe().to_csv(
model_dir / "pyrenew_inference_data.csv", index=False
)
139 changes: 139 additions & 0 deletions nssp_demo/post_process.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
library(tidyverse)
library(tidybayes)
library(fs)
library(cowplot)
library(glue)
library(scales)

theme_set(theme_minimal_grid())

make_forecast_fig <- function(model_dir) {
state_abb <- model_dir %>%
path_split() %>%
pluck(1) %>%
tail(1)


data_path <- path(model_dir, "data", ext = "csv")
posterior_samples_path <- path(model_dir, "pyrenew_inference_data",
ext = "csv"
)


dat <- read_csv(data_path) %>%
arrange(date) %>%
mutate(time = row_number() - 1) %>%
rename(.value = COVID_ED_admissions)

last_training_date <- dat %>%
filter(data_type == "train") %>%
pull(date) %>%
max()
last_data_date <- dat %>%
pull(date) %>%
max()

arviz_split <- function(x) {
x %>%
select(-distribution) %>%
split(f = as.factor(x$distribution))
}

pyrenew_samples <-
read_csv(posterior_samples_path) %>%
rename_with(\(varname) str_remove_all(varname, "\\(|\\)|\\'|(, \\d+)")) |>
rename(
.chain = chain,
.iteration = draw
) |>
mutate(across(c(.chain, .iteration), \(x) as.integer(x + 1))) |>
mutate(
.draw = tidybayes:::draw_from_chain_and_iteration_(.chain, .iteration),
.after = .iteration
) |>
pivot_longer(-starts_with("."),
names_sep = ", ",
names_to = c("distribution", "name")
) |>
arviz_split() |>
map(\(x) pivot_wider(x, names_from = name) |> tidy_draws())


hosp_ci <-
pyrenew_samples$posterior_predictive %>%
gather_draws(observed_hospital_admissions[time]) %>%
median_qi(.width = c(0.5, 0.8, 0.95)) %>%
mutate(date = dat$date[time + 1])



forecast_plot <-
ggplot(mapping = aes(date, .value)) +
geom_lineribbon(
data = hosp_ci,
mapping = aes(ymin = .lower, ymax = .upper),
color = "#08519c", key_glyph = draw_key_rect, step = "mid"
) +
geom_point(mapping = aes(shape = data_type), data = dat) +
scale_y_continuous("Emergency Department Admissions") +
scale_x_date("Date") +
scale_fill_brewer(
name = "Credible Interval Width",
labels = ~ percent(as.numeric(.))
) +
scale_shape_discrete("Data Type", labels = str_to_title) +
geom_vline(xintercept = last_training_date, linetype = "dashed") +
annotate(
geom = "text",
x = last_training_date,
y = -Inf,
label = "Fit Period ←\n",
hjust = "right",
vjust = "bottom"
) +
annotate(
geom = "text",
x = last_training_date,
y = -Inf, label = "→ Forecast Period\n",
hjust = "left",
vjust = "bottom",
) +
ggtitle(glue("NSSP-based forecast for {state_abb}"),
subtitle = glue("as of {last_data_date}")
) +
theme(legend.position = "bottom")

forecast_plot
}

base_dir <- here(path(
"nssp_demo",
"private_data",
"r_2024-09-19_f_2024-03-22_l_2024-09-18_t_2024-09-15/"
))

forecast_fig_tbl <-
tibble(base_model_dir = dir_ls(base_dir)) %>%
filter(
path(base_model_dir, "pyrenew_inference_data", ext = "csv") %>%
file_exists()
) %>%
mutate(forecast_fig = map(base_model_dir, make_forecast_fig)) %>%
mutate(figure_path = path(base_model_dir, "forecast_plot", ext = "pdf"))

pwalk(
forecast_fig_tbl %>% select(forecast_fig, figure_path),
function(forecast_fig, figure_path) {
save_plot(
filename = figure_path,
plot = forecast_fig,
device = cairo_pdf, base_height = 6
)
}
)

str_c(forecast_fig_tbl$figure_path, collapse = " ") %>%
str_c(path(base_dir, "all_forecasts", ext = "pdf"), sep = " ") %>%
system2("pdfunite", args = .)

setdiff(usa::state.abb, path_file(forecast_fig_tbl$base_model_dir))
Loading
Loading