Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Create Hosp Only Model from cdcgov/wastewater-informed-covid-forecasting #313

Closed
wants to merge 51 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
51 commits
Select commit Hold shift + click to select a range
8007fee
first checkin
damonbayer Jul 23, 2024
7695435
ignore .json
damonbayer Jul 23, 2024
a45cd58
initialization working
damonbayer Jul 24, 2024
e8a3fad
load variables from stan_data
damonbayer Jul 24, 2024
b4eb2e8
Merge branch 'main' into dmb_hosp_only_ww_model
damonbayer Jul 25, 2024
8940186
Merge branch 'main' into dmb_hosp_only_ww_model
damonbayer Jul 25, 2024
8519449
Merge branch 'main' into dmb_hosp_only_ww_model
damonbayer Jul 25, 2024
2315d0f
prep for update from main
damonbayer Jul 26, 2024
179f741
Merge branch 'main' into dmb_hosp_only_ww_model
damonbayer Jul 26, 2024
4a1cdc6
checkin (broken)
damonbayer Jul 26, 2024
3915fc4
working Rt
damonbayer Jul 29, 2024
428b37f
fix a prior
damonbayer Jul 29, 2024
e5252ab
infection with feedback (works but looks wrong)
damonbayer Jul 29, 2024
b25fc8e
cleanup
damonbayer Jul 29, 2024
3759770
fixing some things
damonbayer Jul 30, 2024
60013ff
check in
damonbayer Jul 30, 2024
a73a85f
Merge branch 'main' into dmb_hosp_only_ww_model
damonbayer Jul 30, 2024
a162235
fix AR process
damonbayer Jul 30, 2024
c8b71c2
Merge branch 'main' into dmb_hosp_only_ww_model
damonbayer Jul 30, 2024
9a66578
move n_timepoints to sample arg
damonbayer Jul 31, 2024
6ce532c
fix ihr broadcasting
damonbayer Aug 1, 2024
0ac1006
cleanup initialization
damonbayer Aug 1, 2024
023d91a
day of week effect
damonbayer Aug 1, 2024
b6a39f8
cleanup day of week effect
damonbayer Aug 1, 2024
68f1977
correct dotwe
damonbayer Aug 1, 2024
6e7e0b2
add latent hospitalizations
damonbayer Aug 2, 2024
5cfb946
rename hospitalizations
damonbayer Aug 2, 2024
61b5c2d
data observation model
damonbayer Aug 2, 2024
5be7e9a
rename n_timepoints
damonbayer Aug 2, 2024
0e00066
work with supplied data
damonbayer Aug 2, 2024
e52da09
replace ceil with integer division
damonbayer Aug 2, 2024
f5e5e13
prior predictive
damonbayer Aug 2, 2024
2f3e611
posterior
damonbayer Aug 2, 2024
49f879f
remove json from gitignore
damonbayer Aug 2, 2024
c2d48dd
add stan data
damonbayer Aug 2, 2024
ebfec69
convert to tutorial
damonbayer Aug 2, 2024
468743c
fix document title
damonbayer Aug 2, 2024
c5178a7
try avoiding ipython error
damonbayer Aug 2, 2024
babe31f
Merge branch 'main' into dmb_hosp_only_ww_model
dylanhmorris Aug 7, 2024
0c518b2
more posterior plots
damonbayer Aug 12, 2024
bef40ab
more posterior plots
damonbayer Aug 12, 2024
4cb7ea3
fix t_pre_init
damonbayer Aug 12, 2024
6740123
adjust t_pre_init for experimentation
damonbayer Aug 12, 2024
0fcbe08
save latent infections for debug
damonbayer Aug 12, 2024
5addeeb
save rtu for debugging
damonbayer Aug 12, 2024
22bc13e
more rtu observations
damonbayer Aug 12, 2024
41bb7cd
correctly scale by population size
damonbayer Aug 14, 2024
d4ca1b7
note about population size
damonbayer Aug 14, 2024
08289e2
use tile_unitl_n
damonbayer Aug 14, 2024
153e1eb
more chains
damonbayer Aug 14, 2024
7232238
fix latent -> observed hosp admission offset
damonbayer Aug 15, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions _typos.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,5 @@ extend-exclude = [
".gitignore",
".pre-commit-config.yaml",
".github/ISSUE_TEMPLATE/general_issue.md",
"*.json"
]
386 changes: 386 additions & 0 deletions docs/source/tutorials/hosp_only_ww_model.qmd
Original file line number Diff line number Diff line change
@@ -0,0 +1,386 @@
---
title: "Replicating Hospital Only Model from `cdcgov/wastewater-informed-covid-forecasting`"
format: gfm
engine: jupyter
---

```{python}
# | label: setup
import json

import jax
import jax.numpy as jnp
import numpy as np
import numpyro
import numpyro.distributions as dist
import numpyro.distributions.transforms as transforms
from pyrenew.deterministic import DeterministicVariable
from pyrenew.metaclass import DistributionalRV, TransformedRandomVariable
from pyrenew.model import hosp_only_ww_model

numpyro.set_host_device_count(1)
# model crashes if run in parallel
# see https://github.com/pyro-ppl/numpyro/issues/1836
```

## Background

This tutorial provides a demonstration of our reimplementation of "Model 2" from the `wastewater-informed-covid-forecasting` project.
The model is described [here](https://github.com/CDCgov/wastewater-informed-covid-forecasting/blob/prod/model_definition.md).
Stan code for the model is [here](https://github.com/CDCgov/wastewater-informed-covid-forecasting/blob/prod/cfaforecastrenewalww/inst/stan/renewal_ww_hosp_site_level_inf_dynamics.stan).

The model we provide is designed to be fully-compatible with the stan_data generated in the that project.
We provide the stan data used in the `toy_data_vignette` [vignette](https://github.com/CDCgov/wastewater-informed-covid-forecasting/blob/prod/cfaforecastrenewalww/vignettes/toy_data_vignette.Rmd) in the `wastewater-informed-covid-forecasting` project.
The data is available in `scratch/stan_data_hosp_only.json`.
This data was generated by running `scratch/save_from_vignette.R` after running all the cells in the vignette.
This script also saves the posterior samples from the model for comparison to our own model.

## Load Data and create Priors

We begin by loading the stan_data and converting it to priors used in our model.
```{python}
# | label: Load data and create priors
# | code-fold: true
def convert_to_logmean_log_sd(mean, sd):
logmean = np.log(
np.power(mean, 2) / np.sqrt(np.power(sd, 2) + np.power(mean, 2))
)
logsd = np.sqrt(np.log(1 + (np.power(sd, 2) / np.power(mean, 2))))
return logmean, logsd


# Load the JSON file
import builtins

with builtins.open(
"../../../scratch/stan_data_hosp_only.json",
"r",
) as file:
stan_data = json.load(file)

i0_over_n_prior_a = stan_data["i0_over_n_prior_a"][0]
i0_over_n_prior_b = stan_data["i0_over_n_prior_b"][0]
i0_over_n_rv = DistributionalRV(
"i0_over_n_rv", dist.Beta(i0_over_n_prior_a, i0_over_n_prior_b)
)

initial_growth_prior_mean = stan_data["initial_growth_prior_mean"][0]
initial_growth_prior_sd = stan_data["initial_growth_prior_sd"][0]
initialization_rate_rv = DistributionalRV(
"rate",
dist.TruncatedNormal(
loc=initial_growth_prior_mean,
scale=initial_growth_prior_sd,
low=-1,
high=1,
),
)
# could reasonably switch to non-Truncated

r_prior_mean = stan_data["r_prior_mean"][0]
r_prior_sd = stan_data["r_prior_sd"][0]
r_logmean, r_logsd = convert_to_logmean_log_sd(r_prior_mean, r_prior_sd)
log_r_mu_intercept_rv = DistributionalRV(
"log_r_mu_intercept_rv", dist.Normal(r_logmean, r_logsd)
)


eta_sd_sd = stan_data["eta_sd_sd"][0]
eta_sd_rv = DistributionalRV(
"eta_sd", dist.TruncatedNormal(0, eta_sd_sd, low=0)
)

autoreg_rt_a = stan_data["autoreg_rt_a"][0]
autoreg_rt_b = stan_data["autoreg_rt_b"][0]
autoreg_rt_rv = DistributionalRV(
"autoreg_rt", dist.Beta(autoreg_rt_a, autoreg_rt_b)
)


generation_interval_pmf_rv = DeterministicVariable(
"generation_interval_pmf", jnp.array(stan_data["generation_interval"])
)

infection_feedback_pmf_rv = DeterministicVariable(
"infection_feedback_pmf", jnp.array(stan_data["infection_feedback_pmf"])
)

inf_feedback_prior_logmean = stan_data["inf_feedback_prior_logmean"][0]
inf_feedback_prior_logsd = stan_data["inf_feedback_prior_logsd"][0]
inf_feedback_strength_rv = TransformedRandomVariable(
"inf_feedback",
DistributionalRV(
"inf_feedback_raw",
dist.LogNormal(inf_feedback_prior_logmean, inf_feedback_prior_logsd),
),
transforms=transforms.AffineTransform(loc=0, scale=-1),
)
# Could be reparameterized?

p_hosp_prior_mean = stan_data["p_hosp_prior_mean"][0]
p_hosp_sd_logit = stan_data["p_hosp_sd_logit"][0]

p_hosp_mean_rv = DistributionalRV(
"p_hosp_mean",
dist.Normal(transforms.logit(p_hosp_prior_mean), p_hosp_sd_logit),
) # logit scale

p_hosp_w_sd_sd = stan_data["p_hosp_w_sd_sd"][0]
p_hosp_w_sd_rv = DistributionalRV(
"p_hosp_w_sd_sd", dist.TruncatedNormal(0, p_hosp_w_sd_sd, low=0)
)

autoreg_p_hosp_a = stan_data["autoreg_p_hosp_a"][0]
autoreg_p_hosp_b = stan_data["autoreg_p_hosp_b"][0]
autoreg_p_hosp_rv = DistributionalRV(
"autoreg_p_hosp", dist.Beta(autoreg_p_hosp_a, autoreg_p_hosp_b)
)

# hosp_wday_effect ~ normal(effect_mean, wday_effect_prior_sd);
# wday_effect_prior_mean = stan_data["wday_effect_prior_mean"][0]
# wday_effect_prior_sd = stan_data["wday_effect_prior_sd"][0]
# Instead of the above, use a Dirichlet prior (see https://github.com/CDCgov/ww-inference-model/issues/42)

hosp_wday_effect_rv = TransformedRandomVariable(
"hosp_wday_effect",
DistributionalRV(
"hosp_wday_effect_raw", dist.Dirichlet(concentration=jnp.ones(7))
),
transforms.AffineTransform(loc=0, scale=7),
)

inf_to_hosp_rv = DeterministicVariable(
"inf_to_hosp", jnp.array(stan_data["inf_to_hosp"])
)

inv_sqrt_phi_prior_mean = stan_data["inv_sqrt_phi_prior_mean"][0]
inv_sqrt_phi_prior_sd = stan_data["inv_sqrt_phi_prior_sd"][0]

phi_rv = TransformedRandomVariable(
"phi",
DistributionalRV(
"inv_sqrt_phi",
dist.TruncatedNormal(
loc=inv_sqrt_phi_prior_mean,
scale=inv_sqrt_phi_prior_sd,
low=1 / jnp.sqrt(5000),
),
),
transforms=transforms.PowerTransform(-2),
)

uot = stan_data["uot"][0]
state_pop = stan_data["state_pop"][0]

data_observed_hospital_admissions = jnp.array(stan_data["hosp"])
```

# Simulate from the model

Next, we define the model:

```{python}
# | label: define the model
my_model = hosp_only_ww_model(
state_pop=state_pop,
i0_over_n_rv=i0_over_n_rv,
initialization_rate_rv=initialization_rate_rv,
log_r_mu_intercept_rv=log_r_mu_intercept_rv,
autoreg_rt_rv=autoreg_rt_rv, # ar process
eta_sd_rv=eta_sd_rv, # sd of random walk for ar process,
generation_interval_pmf_rv=generation_interval_pmf_rv,
infection_feedback_pmf_rv=infection_feedback_pmf_rv,
infection_feedback_strength_rv=inf_feedback_strength_rv,
p_hosp_mean_rv=p_hosp_mean_rv,
p_hosp_w_sd_rv=p_hosp_w_sd_rv,
autoreg_p_hosp_rv=autoreg_p_hosp_rv,
hosp_wday_effect_rv=hosp_wday_effect_rv,
phi_rv=phi_rv,
inf_to_hosp_rv=inf_to_hosp_rv,
# n_initialization_points=uot,
n_initialization_points=len(jnp.array(stan_data["inf_to_hosp"])),
# i0_t_offset=-50, # to match stan model
i0_t_offset=0, # a better way of parameterizing
)
```


We check that we can simulate from the prior predictive
```{python}
# | label: prior predictive
# | eval: false
# for some reason the posterior inference crashes if we do the prior predictive first
prior_predictive = my_model.prior_predictive(
n_datapoints=len(data_observed_hospital_admissions),
numpyro_predictive_args={"num_samples": 200},
)
```

# Fit the model

Now we can fit the model to the observed data:
```{python}
# | label: fit the model
my_model.run(
num_warmup=500,
num_samples=500,
rng_key=jax.random.key(200),
data_observed_hospital_admissions=data_observed_hospital_admissions,
mcmc_args=dict(num_chains=4),
)
```

Check the posterior predictive:

```{python}
# | label: posterior predictive
posterior_predictive = my_model.posterior_predictive(
n_datapoints=len(data_observed_hospital_admissions)
)
```

Forecasting is broken (dependent on https://github.com/CDCgov/multisignal-epi-inference/issues/328)

```{python}
# | label: posterior forecast
# | eval: false
my_model.posterior_predictive(
n_datapoints=len(data_observed_hospital_admissions) + 2
)
```


## Prepare for plotting
```{python}
import arviz as az

idata = az.from_numpyro(
my_model.mcmc, posterior_predictive=posterior_predictive
)
```

```{python}
def compute_eti(dataset, eti_prob):
eti_bdry = dataset.quantile(
((1 - eti_prob) / 2, 1 / 2 + eti_prob / 2), dim=("chain", "draw")
)
return eti_bdry.values.T
```


```{python}
import matplotlib.pyplot as plt


def plot_posterior(name, predictive=False):
if predictive:
posterior_object = idata.posterior_predictive
else:
posterior_object = idata.posterior
x_data = posterior_object[f"{name}_dim_0"]
y_data = posterior_object[name]
fig, axes = plt.subplots(figsize=(6, 5))
az.plot_hdi(
x_data,
hdi_data=compute_eti(y_data, 0.9),
color="C0",
smooth=False,
fill_kwargs={"alpha": 0.3},
ax=axes,
)

az.plot_hdi(
x_data,
hdi_data=compute_eti(y_data, 0.5),
color="C0",
smooth=False,
fill_kwargs={"alpha": 0.6},
ax=axes,
)

# Add median of the posterior to the figure
median_ts = y_data.median(dim=["chain", "draw"])

plt.plot(
x_data,
median_ts,
color="C0",
label="Median",
)

axes.legend()
axes.set_title(name, fontsize=10)
axes.set_xlabel("Time", fontsize=10)
axes.set_ylabel(name, fontsize=10)
return fig
```

## Plot all posteriors
# Do we know why some univariate sites have a dimension and others do not?
```{python}
for key in list(idata.posterior.keys()):
try:
fig = plot_posterior(key)
fig.show()
except Exception as e:
print(f"An error occurred while plotting {key}: {e}")
```

## Posterior predictive hospital admissions
```{python}
import matplotlib.pyplot as plt

x_data = idata.posterior_predictive["observed_hospital_admissions_dim_0"] + uot
y_data = idata.posterior_predictive["observed_hospital_admissions"]
fig, axes = plt.subplots(figsize=(6, 5))
az.plot_hdi(
x_data,
hdi_data=compute_eti(y_data, 0.9),
color="C0",
smooth=False,
fill_kwargs={"alpha": 0.3},
ax=axes,
)

az.plot_hdi(
x_data,
hdi_data=compute_eti(y_data, 0.5),
color="C0",
smooth=False,
fill_kwargs={"alpha": 0.6},
ax=axes,
)


# Add median of the posterior to the figure
median_ts = y_data.median(dim=["chain", "draw"])

plt.plot(
x_data,
median_ts,
color="C0",
label="Median",
)
plt.scatter(
idata.observed_data["observed_hospital_admissions_dim_0"] + uot,
idata.observed_data["observed_hospital_admissions"],
color="black",
)
axes.legend()
axes.set_title("Posterior Predictive Admissions", fontsize=10)
axes.set_xlabel("Time", fontsize=10)
axes.set_ylabel("Hospital Admissions", fontsize=10)
plt.show()
```



```{python}
idata.posterior["log_r_mu_intercept_rv"]
az.summary(
idata,
var_names=["log_rt", "periodic_diff_sd", "autoreg_rt_det", "rtu"],
stat_focus="median",
)
```
Why is log_rt 2 dimensional?
5 changes: 5 additions & 0 deletions docs/source/tutorials/hosp_only_ww_model.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
.. WARNING
.. Please do not edit this file directly.
.. This file is just a placeholder.
.. For the source file, see:
.. <https://github.com/CDCgov/multisignal-epi-inference/tree/main/docs/source/tutorials/hosp_only_ww_model.qmd>
Loading