Skip to content

Commit

Permalink
Forecasting Interface (#241)
Browse files Browse the repository at this point in the history
* add arviz plot for posterior predictive

* broken simplerandomwalk

* slight simplification

* downgrade jax and jaxlib

* fix scan in simplerandomwalkprocess

* add ppc_plot to hosp_model

* test prior/posterior_predictive plots

* Damon check in

* add plot_ppc, plot_lm

* fix tests

* formatting

* more formatting

* fix plot_lm output issue

* suggestion from code review

Co-authored-by: Damon Bayer <[email protected]>

* Update example_with_datasets.qmd

* change plot_lm kwargs

* add details on plot_ppc, fig labels and titles

* add figure descriptions

* code review suggestions

Co-authored-by: Damon Bayer <[email protected]>

* remove link from code

* formatting

* simplify models

* rename hospital_admissions_model

* fix bad merges

* fix hospital_admissions_model.qmd

* fix most tests

* fix remaining tests

* rename test variable

* adjust padding handling

* rename internal variables

* formatting

* correcting merge errors

* remove print debug statements and unused variables

* fix test

* rename for clarity

* cleanup variable names

* fix figure

* delete example_with_datasets

* remove unused code

* rename init in test_random_walk

* clarify forecasting

* update numpyro and jax

* update numpyro and jax

* enforce minimum version of numpyro and jax

* try to enforce numpyro version again

* add test_forecast

* Update docs/source/tutorials/hospital_admissions_model.qmd

Co-authored-by: Dylan H. Morris <[email protected]>

---------

Co-authored-by: sbidari <[email protected]>
Co-authored-by: Subekshya Bidari <[email protected]>
Co-authored-by: Dylan H. Morris <[email protected]>
  • Loading branch information
4 people authored Jul 12, 2024
1 parent 1795426 commit d0b93be
Show file tree
Hide file tree
Showing 13 changed files with 215 additions and 111 deletions.
6 changes: 3 additions & 3 deletions docs/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@ package-mode = false
[tool.poetry.dependencies]
python = "^3.12"
sphinx = "^7.2.6"
jax = "^0.4.25"
jaxlib = "^0.4.25"
numpyro = "^0.15.0"
jax = ">=0.4.30"
jaxlib = ">=0.4.30"
numpyro = ">=0.15.1"
sphinxcontrib-mermaid = "^0.9.2"
polars = "^0.20.16"
matplotlib = "^3.8.3"
Expand Down
8 changes: 6 additions & 2 deletions docs/source/tutorials/basic_renewal_model.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,9 @@ model1.run(
num_samples=1000,
data_observed_infections=sim_data.observed_infections,
rng_key=jax.random.PRNGKey(54),
mcmc_args=dict(progress_bar=False, num_chains=2, chain_method="sequential"),
mcmc_args=dict(
progress_bar=False, num_chains=2, chain_method="sequential"
),
)
```

Expand Down Expand Up @@ -335,7 +337,9 @@ az.plot_hdi(
)
# Add mean of the posterior to the figure
mean_latent_infection = np.mean(idata.posterior["all_latent_infections"], axis=1)
mean_latent_infection = np.mean(
idata.posterior["all_latent_infections"], axis=1
)
axes.plot(x_data, mean_latent_infection[0], color="C0", label="Mean")
axes.legend()
axes.set_title("Posterior Latent Infections", fontsize=10)
Expand Down
103 changes: 67 additions & 36 deletions docs/source/tutorials/hospital_admissions_model.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -81,11 +81,12 @@ Let's take a look at the daily prevalence of hospital admissions.
# | fig-cap: Daily hospital admissions from the simulated data
import matplotlib.pyplot as plt
daily_hosp_admits = dat["daily_hosp_admits"].to_numpy()
# Rotating the x-axis labels, and only showing ~10 labels
ax = plt.gca()
ax.xaxis.set_major_locator(plt.MaxNLocator(nbins=10))
ax.xaxis.set_tick_params(rotation=45)
plt.plot(dat["date"].to_numpy(), dat["daily_hosp_admits"].to_numpy())
plt.plot(dat["date"].to_numpy(), daily_hosp_admits)
plt.xlabel("Date")
plt.ylabel("Admissions")
plt.show()
Expand Down Expand Up @@ -147,7 +148,10 @@ The `inf_hosp_int` is a `DeterministicPMF` object that takes the infection to ho
```{python}
# | label: initializing-rest-of-model
from pyrenew import model, process, observation, metaclass, transformation
from pyrenew.latent import InfectionSeedingProcess, SeedInfectionsExponentialGrowth
from pyrenew.latent import (
InfectionSeedingProcess,
SeedInfectionsExponentialGrowth,
)
# Infection process
latent_inf = latent.Infections()
Expand Down Expand Up @@ -237,9 +241,11 @@ npro.set_host_device_count(jax.local_device_count())
hosp_model.run(
num_samples=1000,
num_warmup=1000,
data_observed_hosp_admissions=dat["daily_hosp_admits"].to_numpy(),
data_observed_hosp_admissions=daily_hosp_admits,
rng_key=jax.random.PRNGKey(54),
mcmc_args=dict(progress_bar=False, num_chains=2, chain_method="sequential"),
mcmc_args=dict(
progress_bar=False, num_chains=2, chain_method="sequential"
),
)
```

Expand All @@ -254,7 +260,7 @@ out = hosp_model.plot_posterior(
var="latent_hospital_admissions",
ylab="Hospital Admissions",
obs_signal=np.pad(
dat["daily_hosp_admits"].to_numpy().astype(float),
daily_hosp_admits.astype(float),
(gen_int_array.size, 0),
constant_values=np.nan,
),
Expand All @@ -270,10 +276,10 @@ import arviz as az
idata = az.from_numpyro(
hosp_model.mcmc,
posterior_predictive=hosp_model.posterior_predictive(
n_timepoints_to_simulate=len(dat["daily_hosp_admits"])
n_timepoints_to_simulate=len(daily_hosp_admits)
),
prior=hosp_model.prior_predictive(
n_timepoints_to_simulate=len(dat["daily_hosp_admits"]),
n_timepoints_to_simulate=len(daily_hosp_admits),
numpyro_predictive_args={"num_samples": 1000},
),
)
Expand Down Expand Up @@ -308,22 +314,15 @@ We can use the padding argument to solve the overestimation of hospital admissio

```{python}
# | label: model-fit-padding
days_to_impute = 21
# Prepend 21 NaN values to the observed admissions to build dat_w_padding
dat_w_padding = np.pad(
dat["daily_hosp_admits"].to_numpy().astype(float),
(days_to_impute, 0),
constant_values=np.nan,
)
pad_size = 21
hosp_model.run(
num_samples=1000,
num_warmup=1000,
data_observed_hosp_admissions=dat_w_padding,
data_observed_hosp_admissions=daily_hosp_admits,
rng_key=jax.random.PRNGKey(54),
mcmc_args=dict(progress_bar=False, num_chains=2),
padding=days_to_impute, # Padding the model
padding=pad_size, # Padding the model
)
```

Expand All @@ -336,7 +335,9 @@ out = hosp_model.plot_posterior(
var="latent_hospital_admissions",
ylab="Hospital Admissions",
obs_signal=np.pad(
dat_w_padding, (gen_int_array.size, 0), constant_values=np.nan
daily_hosp_admits.astype(float),
(gen_int_array.size + pad_size, 0),
constant_values=np.nan,
),
)
```
Expand Down Expand Up @@ -407,7 +408,9 @@ We can look at individual draws from the posterior distribution of latent infect
```{python}
# | label: fig-output-infections-with-padding
# | fig-cap: Latent infections
out2 = hosp_model.plot_posterior(var="all_latent_infections", ylab="Latent Infections")
out2 = hosp_model.plot_posterior(
var="all_latent_infections", ylab="Latent Infections"
)
```

We can also look at credible intervals for the posterior distribution of latent infections:
Expand Down Expand Up @@ -440,7 +443,9 @@ az.plot_hdi(
)
# Add mean of the posterior to the figure
mean_latent_infection = np.mean(idata.posterior["all_latent_infections"], axis=1)
mean_latent_infection = np.mean(
idata.posterior["all_latent_infections"], axis=1
)
axes.plot(x_data, mean_latent_infection[0], color="C0", label="Mean")
axes.legend()
axes.set_title("Posterior Latent Infections", fontsize=10)
Expand Down Expand Up @@ -520,10 +525,10 @@ Running the model (with the same padding as before):
hosp_model_weekday.run(
num_samples=2000,
num_warmup=2000,
data_observed_hosp_admissions=dat_w_padding,
data_observed_hosp_admissions=daily_hosp_admits,
rng_key=jax.random.PRNGKey(54),
mcmc_args=dict(progress_bar=False),
padding=days_to_impute,
padding=pad_size,
)
```

Expand All @@ -535,21 +540,29 @@ And plotting the results:
out = hosp_model_weekday.plot_posterior(
var="latent_hospital_admissions",
ylab="Hospital Admissions",
obs_signal=np.pad(dat_w_padding, (gen_int_array.size, 0), constant_values=np.nan),
obs_signal=np.pad(
daily_hosp_admits.astype(float),
(gen_int_array.size + pad_size, 0),
constant_values=np.nan,
),
)
```

We will use ArviZ to visualize the posterior/prior predictive distributions.
We will use ArviZ to visualize the posterior and prior predictive distributions.
By setting `n_timepoints_to_simulate` to more than the number of observed data points, the posterior predictive distribution extends past the end of the data, which gives us a forecast.

```{python}
# | label: posterior-predictive-distribution
n_forecast_points = 28
idata_weekday = az.from_numpyro(
hosp_model_weekday.mcmc,
posterior_predictive=hosp_model_weekday.posterior_predictive(
n_timepoints_to_simulate=len(dat_w_padding)
n_timepoints_to_simulate=len(daily_hosp_admits) + n_forecast_points,
padding=pad_size,
),
prior=hosp_model_weekday.prior_predictive(
n_timepoints_to_simulate=len(dat_w_padding),
n_timepoints_to_simulate=len(daily_hosp_admits),
padding=pad_size,
numpyro_predictive_args={"num_samples": 1000},
),
)
Expand All @@ -569,7 +582,9 @@ def compute_eti(dataset, eti_prob):
fig, axes = plt.subplots(figsize=(6, 5))
az.plot_hdi(
idata_weekday.prior_predictive["negbinom_rv_dim_0"],
idata_weekday.prior_predictive["negbinom_rv_dim_0"]
+ pad_size
+ gen_int.size(),
hdi_data=compute_eti(idata_weekday.prior_predictive["negbinom_rv"], 0.9),
color="C0",
smooth=False,
Expand All @@ -578,7 +593,9 @@ az.plot_hdi(
)
az.plot_hdi(
idata_weekday.prior_predictive["negbinom_rv_dim_0"],
idata_weekday.prior_predictive["negbinom_rv_dim_0"]
+ pad_size
+ gen_int.size(),
hdi_data=compute_eti(idata_weekday.prior_predictive["negbinom_rv"], 0.5),
color="C0",
smooth=False,
Expand All @@ -587,7 +604,9 @@ az.plot_hdi(
)
plt.scatter(
idata_weekday.observed_data["negbinom_rv_dim_0"] + days_to_impute,
idata_weekday.observed_data["negbinom_rv_dim_0"]
+ pad_size
+ gen_int.size(),
idata_weekday.observed_data["negbinom_rv"],
color="black",
)
Expand All @@ -599,23 +618,31 @@ plt.yscale("log")
plt.show()
```

And now we plot the posterior predictive distributions:
And now we plot the posterior predictive distributions with a `{python} n_forecast_points`-day-ahead forecast:
```{python}
# | label: fig-output-posterior-predictive
# | fig-cap: Posterior Predictive Hospital Admissions
fig, axes = plt.subplots(figsize=(6, 5))
az.plot_hdi(
idata_weekday.posterior_predictive["negbinom_rv_dim_0"],
hdi_data=compute_eti(idata_weekday.posterior_predictive["negbinom_rv"], 0.9),
idata_weekday.posterior_predictive["negbinom_rv_dim_0"]
+ pad_size
+ gen_int.size(),
hdi_data=compute_eti(
idata_weekday.posterior_predictive["negbinom_rv"], 0.9
),
color="C0",
smooth=False,
fill_kwargs={"alpha": 0.3},
ax=axes,
)
az.plot_hdi(
idata_weekday.posterior_predictive["negbinom_rv_dim_0"],
hdi_data=compute_eti(idata_weekday.posterior_predictive["negbinom_rv"], 0.5),
idata_weekday.posterior_predictive["negbinom_rv_dim_0"]
+ pad_size
+ gen_int.size(),
hdi_data=compute_eti(
idata_weekday.posterior_predictive["negbinom_rv"], 0.5
),
color="C0",
smooth=False,
fill_kwargs={"alpha": 0.6},
Expand All @@ -628,13 +655,17 @@ mean_latent_infection = np.mean(
)
plt.plot(
idata_weekday.posterior_predictive["negbinom_rv_dim_0"],
idata_weekday.posterior_predictive["negbinom_rv_dim_0"]
+ pad_size
+ gen_int.size(),
mean_latent_infection[0],
color="C0",
label="Mean",
)
plt.scatter(
idata_weekday.observed_data["negbinom_rv_dim_0"] + days_to_impute,
idata_weekday.observed_data["negbinom_rv_dim_0"]
+ pad_size
+ gen_int.size(),
idata_weekday.observed_data["negbinom_rv"],
color="black",
)
Expand Down
5 changes: 3 additions & 2 deletions model/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,9 @@ exclude = [{path = "datasets/*.rds"}]

[tool.poetry.dependencies]
python = "^3.12"
numpyro = "^0.15.0"
jax = "^0.4.25"
numpyro = ">=0.15.1"
jax = ">=0.4.30"
jaxlib = ">=0.4.30"
numpy = "^1.26.4"
polars = "^0.20.16"
pillow = "^10.3.0" # See #56 on CDCgov/multisignal-epi-inference
Expand Down
50 changes: 16 additions & 34 deletions model/src/pyrenew/model/admissionsmodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,8 @@

from typing import NamedTuple

import jax.numpy as jnp
import pyrenew.arrayutils as au
from jax.typing import ArrayLike
from pyrenew.deterministic import NullObservation
from pyrenew.metaclass import Model, RandomVariable, _assert_sample_and_rtype
from pyrenew.model.rtinfectionsrenewalmodel import RtInfectionsRenewalModel

Expand Down Expand Up @@ -99,6 +98,9 @@ def __init__(
)

self.latent_hosp_admissions_rv = latent_hosp_admissions_rv
if hosp_admission_obs_process_rv is None:
hosp_admission_obs_process_rv = NullObservation()

self.hosp_admission_obs_process_rv = hosp_admission_obs_process_rv

@staticmethod
Expand Down Expand Up @@ -178,13 +180,13 @@ def sample(
"Cannot pass both n_timepoints_to_simulate and data_observed_hosp_admissions."
)
elif n_timepoints_to_simulate is None:
n_timepoints = len(data_observed_hosp_admissions)
n_datapoints = len(data_observed_hosp_admissions)
else:
n_timepoints = n_timepoints_to_simulate
n_datapoints = n_timepoints_to_simulate

# Getting the initial quantities from the basic model
basic_model = self.basic_renewal.sample(
n_timepoints_to_simulate=n_timepoints,
n_timepoints_to_simulate=n_datapoints,
data_observed_infections=None,
padding=padding,
**kwargs,
Expand All @@ -199,35 +201,15 @@ def sample(
latent_infections=basic_model.latent_infections,
**kwargs,
)
i0_size = len(latent_hosp_admissions) - n_timepoints
if self.hosp_admission_obs_process_rv is None:
observed_hosp_admissions = None
else:
if data_observed_hosp_admissions is None:
(
observed_hosp_admissions,
*_,
) = self.hosp_admission_obs_process_rv.sample(
mu=latent_hosp_admissions[i0_size + padding :],
obs=data_observed_hosp_admissions,
**kwargs,
)
else:
data_observed_hosp_admissions = au.pad_x_to_match_y(
data_observed_hosp_admissions,
latent_hosp_admissions,
jnp.nan,
pad_direction="start",
)

(
observed_hosp_admissions,
*_,
) = self.hosp_admission_obs_process_rv.sample(
mu=latent_hosp_admissions[i0_size + padding :],
obs=data_observed_hosp_admissions[i0_size + padding :],
**kwargs,
)

(
observed_hosp_admissions,
*_,
) = self.hosp_admission_obs_process_rv.sample(
mu=latent_hosp_admissions[-n_datapoints:],
obs=data_observed_hosp_admissions,
**kwargs,
)

return HospModelSample(
Rt=basic_model.Rt,
Expand Down
Loading

0 comments on commit d0b93be

Please sign in to comment.