Forecasting Interface #241

Merged
merged 55 commits into from
Jul 12, 2024
Changes from 47 commits
Commits
f58f33c
add arviz plot for posterior predictive
sbidari Jul 1, 2024
da8db93
broken simplerandomwalk
damonbayer Jul 2, 2024
3c10b3e
slight simplification
damonbayer Jul 2, 2024
797b1d1
downgrade jax and jaxlib
damonbayer Jul 3, 2024
20fb2a4
fix scan in simplerandomwalkprocess
damonbayer Jul 3, 2024
df7ac7b
add ppc_plot to hosp_model
sbidari Jul 3, 2024
c3ec082
test prior/posterior_predictive plots
sbidari Jul 3, 2024
3cb7755
Merge branch 'main' into demonstrate-use-of-predictive-distributions-…
damonbayer Jul 3, 2024
f13c04d
Damon check in
damonbayer Jul 3, 2024
c7a596a
add plot_ppc, plot_lm
sbidari Jul 5, 2024
5fc2704
fix tests
damonbayer Jul 8, 2024
4f0e74f
formatting
damonbayer Jul 8, 2024
21dd93d
more formatting
damonbayer Jul 8, 2024
76932df
fix plot_lm output issue
sbidari Jul 8, 2024
e4525ed
suggestion from code review
sbidari Jul 8, 2024
320cfbe
Update example_with_datasets.qmd
damonbayer Jul 8, 2024
3b2a3a0
change plot_lm kwargs
sbidari Jul 8, 2024
ce58119
add details on plot_ppc, fig labels and titles
sbidari Jul 8, 2024
d691a63
resolve merge conflict
sbidari Jul 8, 2024
7954f11
add figure descriptions
sbidari Jul 8, 2024
05583a0
code review suggestions
sbidari Jul 8, 2024
034a624
remove link from code
sbidari Jul 8, 2024
47fbf8d
formatting
damonbayer Jul 8, 2024
4b5a1bc
Merge branch 'main' into dmb_forecast
damonbayer Jul 8, 2024
c744dee
simplify models
damonbayer Jul 8, 2024
12bb942
demonstrate forecasting
damonbayer Jul 8, 2024
cf6ff45
rename hospital_admissions_model
damonbayer Jul 10, 2024
50dd053
Merge branch 'main' into dmb_forecast
damonbayer Jul 10, 2024
19ad4c5
fix bad merges
damonbayer Jul 10, 2024
21f8e77
fix hospital_admissions_model.qmd
damonbayer Jul 10, 2024
0c4c906
fix most tests
damonbayer Jul 10, 2024
4644208
fix remaining tests
damonbayer Jul 10, 2024
b8cb7e8
rename test variable
damonbayer Jul 10, 2024
9817561
adjust padding handling
damonbayer Jul 10, 2024
9c48df9
rename internal variables
damonbayer Jul 10, 2024
e430d67
Merge branch 'main' into dmb_forecast
damonbayer Jul 10, 2024
7b392d3
formatting
damonbayer Jul 10, 2024
6d543f7
Merge branch 'main' into dmb_forecast
damonbayer Jul 10, 2024
30a9a31
correcting merge errors
damonbayer Jul 10, 2024
4cf685a
remove print debug statements and unused variables
damonbayer Jul 10, 2024
45d9950
fix test
damonbayer Jul 10, 2024
b3bff64
rename for clarity
damonbayer Jul 10, 2024
f710e8d
cleanup variable names
damonbayer Jul 11, 2024
5dff94b
fix figure
damonbayer Jul 11, 2024
b0df76e
delete example_with_datasets
damonbayer Jul 11, 2024
188cb0f
Merge branch 'main' into dmb_forecast
damonbayer Jul 11, 2024
c943309
remove unused code
damonbayer Jul 11, 2024
bddc731
rename init in test_random_walk
damonbayer Jul 11, 2024
a84ca20
clarify forecasting
damonbayer Jul 11, 2024
85c62c1
update numpyro and jax
damonbayer Jul 11, 2024
b360093
update numpyro and jax
damonbayer Jul 11, 2024
a7c4673
enforce minimum version of numpyro and jax
damonbayer Jul 11, 2024
f6b34ab
try to enforce numpyro version again
damonbayer Jul 11, 2024
3342f24
add test_forecast
damonbayer Jul 11, 2024
45879b1
Update docs/source/tutorials/hospital_admissions_model.qmd
damonbayer Jul 12, 2024
4 changes: 2 additions & 2 deletions docs/pyproject.toml
@@ -9,8 +9,8 @@ package-mode = false
[tool.poetry.dependencies]
python = "^3.12"
sphinx = "^7.2.6"
jax = "^0.4.25"
jaxlib = "^0.4.25"
jax = "0.4.29"
jaxlib = "^0.4.29"
numpyro = "^0.15.0"
sphinxcontrib-mermaid = "^0.9.2"
polars = "^0.20.16"
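
Note on the two constraint styles in this hunk: `jax = "0.4.29"` is an exact pin, while `jaxlib = "^0.4.29"` is a caret constraint. Under Poetry's documented caret rule the leftmost non-zero version component may not change, so `^0.4.29` admits `>=0.4.29,<0.5.0`. The sketch below illustrates that rule for plain `MAJOR.MINOR.PATCH` strings only. `caret_bounds` and `satisfies_caret` are hypothetical helper names written for this note, not Poetry APIs.

```python
def caret_bounds(version: str) -> tuple[tuple[int, ...], tuple[int, ...]]:
    """Return (inclusive lower, exclusive upper) bounds for a caret constraint.

    Simplified illustration: handles only MAJOR.MINOR.PATCH with at least
    one non-zero component.
    """
    parts = tuple(int(p) for p in version.split("."))
    if len(parts) != 3:
        raise ValueError("this sketch only handles MAJOR.MINOR.PATCH")
    for i, component in enumerate(parts):
        if component != 0:
            # Bump the leftmost non-zero component and zero out the rest.
            upper = parts[:i] + (component + 1,) + (0,) * (len(parts) - i - 1)
            return parts, upper
    raise ValueError("all-zero versions are outside the scope of this sketch")


def satisfies_caret(candidate: str, constraint: str) -> bool:
    """Check whether `candidate` falls inside the caret range of `constraint`."""
    lower, upper = caret_bounds(constraint)
    cand = tuple(int(p) for p in candidate.split("."))
    return lower <= cand < upper


print(caret_bounds("0.4.29"))
print(satisfies_caret("0.4.30", "0.4.29"))
print(satisfies_caret("0.5.0", "0.4.29"))
```

Read this way, the docs environment pins `jax` to one release but still lets `jaxlib` float within the `0.4.x` series, whereas `model/pyproject.toml` further down pins both.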
8 changes: 6 additions & 2 deletions docs/source/tutorials/basic_renewal_model.qmd
@@ -214,7 +214,9 @@ model1.run(
num_samples=1000,
data_observed_infections=sim_data.observed_infections,
rng_key=jax.random.PRNGKey(54),
mcmc_args=dict(progress_bar=False, num_chains=2, chain_method="sequential"),
mcmc_args=dict(
progress_bar=False, num_chains=2, chain_method="sequential"
),
)
```

@@ -335,7 +337,9 @@ az.plot_hdi(
)

# Add mean of the posterior to the figure
mean_latent_infection = np.mean(idata.posterior["all_latent_infections"], axis=1)
mean_latent_infection = np.mean(
idata.posterior["all_latent_infections"], axis=1
)
axes.plot(x_data, mean_latent_infection[0], color="C0", label="Mean")
axes.legend()
axes.set_title("Posterior Latent Infections", fontsize=10)
99 changes: 65 additions & 34 deletions docs/source/tutorials/hospital_admissions_model.qmd
@@ -81,11 +81,12 @@ Let's take a look at the daily prevalence of hospital admissions.
# | fig-cap: Daily hospital admissions from the simulated data
import matplotlib.pyplot as plt

daily_hosp_admits = dat["daily_hosp_admits"].to_numpy()
# Rotating the x-axis labels, and only showing ~10 labels
ax = plt.gca()
ax.xaxis.set_major_locator(plt.MaxNLocator(nbins=10))
ax.xaxis.set_tick_params(rotation=45)
plt.plot(dat["date"].to_numpy(), dat["daily_hosp_admits"].to_numpy())
plt.plot(dat["date"].to_numpy(), daily_hosp_admits)
plt.xlabel("Date")
plt.ylabel("Admissions")
plt.show()
@@ -147,7 +148,10 @@ The `inf_hosp_int` is a `DeterministicPMF` object that takes the infection to ho
```{python}
# | label: initializing-rest-of-model
from pyrenew import model, process, observation, metaclass, transformation
from pyrenew.latent import InfectionSeedingProcess, SeedInfectionsExponentialGrowth
from pyrenew.latent import (
InfectionSeedingProcess,
SeedInfectionsExponentialGrowth,
)

# Infection process
latent_inf = latent.Infections()
@@ -237,9 +241,11 @@ npro.set_host_device_count(jax.local_device_count())
hosp_model.run(
num_samples=1000,
num_warmup=1000,
data_observed_hosp_admissions=dat["daily_hosp_admits"].to_numpy(),
data_observed_hosp_admissions=daily_hosp_admits,
rng_key=jax.random.PRNGKey(54),
mcmc_args=dict(progress_bar=False, num_chains=2, chain_method="sequential"),
mcmc_args=dict(
progress_bar=False, num_chains=2, chain_method="sequential"
),
)
```

@@ -254,7 +260,7 @@ out = hosp_model.plot_posterior(
var="latent_hospital_admissions",
ylab="Hospital Admissions",
obs_signal=np.pad(
dat["daily_hosp_admits"].to_numpy().astype(float),
daily_hosp_admits.astype(float),
(gen_int_array.size, 0),
constant_values=np.nan,
),
@@ -270,10 +276,10 @@ import arviz as az
idata = az.from_numpyro(
hosp_model.mcmc,
posterior_predictive=hosp_model.posterior_predictive(
n_timepoints_to_simulate=len(dat["daily_hosp_admits"])
n_timepoints_to_simulate=len(daily_hosp_admits)
),
prior=hosp_model.prior_predictive(
n_timepoints_to_simulate=len(dat["daily_hosp_admits"]),
n_timepoints_to_simulate=len(daily_hosp_admits),
numpyro_predictive_args={"num_samples": 1000},
),
)
@@ -308,22 +314,15 @@ We can use the padding argument to solve the overestimation of hospital admissio

```{python}
# | label: model-fit-padding
days_to_impute = 21

# Add 21 Nas to the beginning of dat_w_padding
dat_w_padding = np.pad(
dat["daily_hosp_admits"].to_numpy().astype(float),
(days_to_impute, 0),
constant_values=np.nan,
)
pad_size = 21

hosp_model.run(
num_samples=1000,
num_warmup=1000,
data_observed_hosp_admissions=dat_w_padding,
data_observed_hosp_admissions=daily_hosp_admits,
rng_key=jax.random.PRNGKey(54),
mcmc_args=dict(progress_bar=False, num_chains=2),
padding=days_to_impute, # Padding the model
padding=pad_size, # Padding the model
)
```
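
The hunk above drops the hand-built `dat_w_padding` array and instead passes the raw observations together with `padding=pad_size`. A minimal NumPy-only sketch of why the two approaches compare the same values, using made-up counts and a shorter pad than the tutorial's 21. `latent` is a hypothetical placeholder for the model's post-seed latent series, and the code only mimics the slicing visible in this PR; it does not call pyrenew.

```python
import numpy as np

pad_size = 3  # the tutorial uses 21; shortened here for readability
observed = np.array([5, 7, 6, 9, 12], dtype=float)  # made-up admissions

# Old approach: the caller prepends NaNs, and the model later discards
# the first `pad_size` entries of the data before comparing to `mu`.
old_input = np.pad(observed, (pad_size, 0), constant_values=np.nan)
old_compared = old_input[pad_size:]

# New approach: the caller passes the raw data; the model simulates
# len(observed) + pad_size time points and only slices the latent series.
n_timepoints = len(observed) + pad_size
latent = np.arange(n_timepoints, dtype=float)  # placeholder latent values
mu = latent[pad_size:]
new_compared = observed

print(np.array_equal(old_compared, new_compared))
print(mu.shape == new_compared.shape)
```

Either way the observation process sees the same five data points; the change moves the bookkeeping from the caller into the model.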

@@ -336,7 +335,9 @@ out = hosp_model.plot_posterior(
var="latent_hospital_admissions",
ylab="Hospital Admissions",
obs_signal=np.pad(
dat_w_padding, (gen_int_array.size, 0), constant_values=np.nan
daily_hosp_admits.astype(float),
(gen_int_array.size + pad_size, 0),
constant_values=np.nan,
),
)
```
@@ -407,7 +408,9 @@ We can look at individual draws from the posterior distribution of latent infect
```{python}
# | label: fig-output-infections-with-padding
# | fig-cap: Latent infections
out2 = hosp_model.plot_posterior(var="all_latent_infections", ylab="Latent Infections")
out2 = hosp_model.plot_posterior(
var="all_latent_infections", ylab="Latent Infections"
)
```

We can also look at credible intervals for the posterior distribution of latent infections:
@@ -440,7 +443,9 @@ az.plot_hdi(
)

# Add mean of the posterior to the figure
mean_latent_infection = np.mean(idata.posterior["all_latent_infections"], axis=1)
mean_latent_infection = np.mean(
idata.posterior["all_latent_infections"], axis=1
)
axes.plot(x_data, mean_latent_infection[0], color="C0", label="Mean")
axes.legend()
axes.set_title("Posterior Latent Infections", fontsize=10)
@@ -520,10 +525,10 @@ Running the model (with the same padding as before):
hosp_model_weekday.run(
num_samples=2000,
num_warmup=2000,
data_observed_hosp_admissions=dat_w_padding,
data_observed_hosp_admissions=daily_hosp_admits,
rng_key=jax.random.PRNGKey(54),
mcmc_args=dict(progress_bar=False),
padding=days_to_impute,
padding=pad_size,
)
```

@@ -535,21 +540,29 @@ And plotting the results:
out = hosp_model_weekday.plot_posterior(
var="latent_hospital_admissions",
ylab="Hospital Admissions",
obs_signal=np.pad(dat_w_padding, (gen_int_array.size, 0), constant_values=np.nan),
obs_signal=np.pad(
daily_hosp_admits.astype(float),
(gen_int_array.size + pad_size, 0),
constant_values=np.nan,
),
)
```

We will use ArviZ to visualize the posterior/prior predictive distributions.
By increasing `n_timepoints_to_simulate`, we can perform forecasting.
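
To make the index arithmetic concrete: the posterior predictive call that follows requests `len(daily_hosp_admits) + n_forecast_points` time points, so the last `n_forecast_points` entries have no matching observation. A rough sketch with made-up sizes; `gen_int_size` is an illustrative stand-in for the generation-interval length, the predictive coordinate is assumed to start at 0, and nothing here calls pyrenew or ArviZ.

```python
import numpy as np

n_obs = 10             # stand-in for len(daily_hosp_admits)
n_forecast_points = 4  # the tutorial uses 28
pad_size = 3           # the tutorial uses 21
gen_int_size = 2       # hypothetical generation-interval length

n_timepoints_to_simulate = n_obs + n_forecast_points

# Shift both coordinates by the seeded and padded days, mirroring the
# `+ pad_size + gen_int.size()` offset applied to the plots in this PR.
offset = pad_size + gen_int_size
x_pred = np.arange(n_timepoints_to_simulate) + offset
x_obs = np.arange(n_obs) + offset

# Predictive time points beyond the last observation are the forecast.
is_forecast = x_pred > x_obs[-1]
print(int(is_forecast.sum()))
```

Because observed and predictive points share the same offset, the scatter of observations and the predictive intervals line up on the shared x-axis, and the forecast region is whatever extends past the final observed point.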

```{python}
# | label: posterior-predictive-distribution
n_forecast_points = 28
idata_weekday = az.from_numpyro(
hosp_model_weekday.mcmc,
posterior_predictive=hosp_model_weekday.posterior_predictive(
n_timepoints_to_simulate=len(dat_w_padding)
n_timepoints_to_simulate=len(daily_hosp_admits) + n_forecast_points,
padding=pad_size,
),
prior=hosp_model_weekday.prior_predictive(
n_timepoints_to_simulate=len(dat_w_padding),
n_timepoints_to_simulate=len(daily_hosp_admits),
padding=pad_size,
numpyro_predictive_args={"num_samples": 1000},
),
)
@@ -569,7 +582,9 @@ def compute_eti(dataset, eti_prob):

fig, axes = plt.subplots(figsize=(6, 5))
az.plot_hdi(
idata_weekday.prior_predictive["negbinom_rv_dim_0"],
idata_weekday.prior_predictive["negbinom_rv_dim_0"]
+ pad_size
+ gen_int.size(),
hdi_data=compute_eti(idata_weekday.prior_predictive["negbinom_rv"], 0.9),
color="C0",
smooth=False,
@@ -578,7 +593,9 @@ az.plot_hdi(
)

az.plot_hdi(
idata_weekday.prior_predictive["negbinom_rv_dim_0"],
idata_weekday.prior_predictive["negbinom_rv_dim_0"]
+ pad_size
+ gen_int.size(),
hdi_data=compute_eti(idata_weekday.prior_predictive["negbinom_rv"], 0.5),
color="C0",
smooth=False,
@@ -587,7 +604,9 @@ az.plot_hdi(
)

plt.scatter(
idata_weekday.observed_data["negbinom_rv_dim_0"] + days_to_impute,
idata_weekday.observed_data["negbinom_rv_dim_0"]
+ pad_size
+ gen_int.size(),
idata_weekday.observed_data["negbinom_rv"],
color="black",
)
@@ -605,17 +624,25 @@ And now we plot the posterior predictive distributions:
# | fig-cap: Posterior Predictive Infections
fig, axes = plt.subplots(figsize=(6, 5))
az.plot_hdi(
idata_weekday.posterior_predictive["negbinom_rv_dim_0"],
hdi_data=compute_eti(idata_weekday.posterior_predictive["negbinom_rv"], 0.9),
idata_weekday.posterior_predictive["negbinom_rv_dim_0"]
+ pad_size
+ gen_int.size(),
hdi_data=compute_eti(
idata_weekday.posterior_predictive["negbinom_rv"], 0.9
),
color="C0",
smooth=False,
fill_kwargs={"alpha": 0.3},
ax=axes,
)

az.plot_hdi(
idata_weekday.posterior_predictive["negbinom_rv_dim_0"],
hdi_data=compute_eti(idata_weekday.posterior_predictive["negbinom_rv"], 0.5),
idata_weekday.posterior_predictive["negbinom_rv_dim_0"]
+ pad_size
+ gen_int.size(),
hdi_data=compute_eti(
idata_weekday.posterior_predictive["negbinom_rv"], 0.5
),
color="C0",
smooth=False,
fill_kwargs={"alpha": 0.6},
@@ -628,13 +655,17 @@ mean_latent_infection = np.mean(
)

plt.plot(
idata_weekday.posterior_predictive["negbinom_rv_dim_0"],
idata_weekday.posterior_predictive["negbinom_rv_dim_0"]
+ pad_size
+ gen_int.size(),
mean_latent_infection[0],
color="C0",
label="Mean",
)
plt.scatter(
idata_weekday.observed_data["negbinom_rv_dim_0"] + days_to_impute,
idata_weekday.observed_data["negbinom_rv_dim_0"]
+ pad_size
+ gen_int.size(),
idata_weekday.observed_data["negbinom_rv"],
color="black",
)
3 changes: 2 additions & 1 deletion model/pyproject.toml
@@ -12,7 +12,8 @@ exclude = [{path = "datasets/*.rds"}]
[tool.poetry.dependencies]
python = "^3.12"
numpyro = "^0.15.0"
jax = "^0.4.25"
jax = "0.4.29"
jaxlib = "0.4.29"
numpy = "^1.26.4"
polars = "^0.20.16"
pillow = "^10.3.0" # See #56 on CDCgov/multisignal-epi-inference
50 changes: 16 additions & 34 deletions model/src/pyrenew/model/admissionsmodel.py
@@ -5,9 +5,8 @@

from typing import NamedTuple

import jax.numpy as jnp
import pyrenew.arrayutils as au
from jax.typing import ArrayLike
from pyrenew.deterministic import NullObservation
from pyrenew.metaclass import Model, RandomVariable, _assert_sample_and_rtype
from pyrenew.model.rtinfectionsrenewalmodel import RtInfectionsRenewalModel

@@ -99,6 +98,9 @@ def __init__(
)

self.latent_hosp_admissions_rv = latent_hosp_admissions_rv
if hosp_admission_obs_process_rv is None:
hosp_admission_obs_process_rv = NullObservation()

self.hosp_admission_obs_process_rv = hosp_admission_obs_process_rv

@staticmethod
@@ -178,13 +180,13 @@ def sample(
"Cannot pass both n_timepoints_to_simulate and data_observed_hosp_admissions."
)
elif n_timepoints_to_simulate is None:
n_timepoints = len(data_observed_hosp_admissions)
n_datapoints = len(data_observed_hosp_admissions)
else:
n_timepoints = n_timepoints_to_simulate
n_datapoints = n_timepoints_to_simulate

# Getting the initial quantities from the basic model
basic_model = self.basic_renewal.sample(
n_timepoints_to_simulate=n_timepoints,
n_timepoints_to_simulate=n_datapoints,
data_observed_infections=None,
padding=padding,
**kwargs,
@@ -199,35 +201,15 @@ def sample(
latent_infections=basic_model.latent_infections,
**kwargs,
)
i0_size = len(latent_hosp_admissions) - n_timepoints
if self.hosp_admission_obs_process_rv is None:
observed_hosp_admissions = None
else:
if data_observed_hosp_admissions is None:
(
observed_hosp_admissions,
*_,
) = self.hosp_admission_obs_process_rv.sample(
mu=latent_hosp_admissions[i0_size + padding :],
obs=data_observed_hosp_admissions,
**kwargs,
)
else:
data_observed_hosp_admissions = au.pad_x_to_match_y(
data_observed_hosp_admissions,
latent_hosp_admissions,
jnp.nan,
pad_direction="start",
)

(
observed_hosp_admissions,
*_,
) = self.hosp_admission_obs_process_rv.sample(
mu=latent_hosp_admissions[i0_size + padding :],
obs=data_observed_hosp_admissions[i0_size + padding :],
**kwargs,
)

(
observed_hosp_admissions,
*_,
) = self.hosp_admission_obs_process_rv.sample(
mu=latent_hosp_admissions[-n_datapoints:],
obs=data_observed_hosp_admissions,
**kwargs,
)

return HospModelSample(
Rt=basic_model.Rt,
7 changes: 2 additions & 5 deletions model/src/pyrenew/model/rtinfectionsrenewalmodel.py
@@ -187,9 +187,9 @@ def sample(
"Cannot pass both n_timepoints_to_simulate and data_observed_infections."
)
elif n_timepoints_to_simulate is None:
n_timepoints = len(data_observed_infections)
n_timepoints = len(data_observed_infections) + padding
else:
n_timepoints = n_timepoints_to_simulate
n_timepoints = n_timepoints_to_simulate + padding
# Sampling from Rt (possibly with a given Rt, depending on
# the Rt_process (RandomVariable) object.)
Rt, *_ = self.Rt_process_rv.sample(
@@ -210,9 +210,6 @@ def sample(
**kwargs,
)

if data_observed_infections is not None:
data_observed_infections = data_observed_infections[padding:]

observed_infections, *_ = self.infection_obs_process_rv.sample(
mu=post_seed_latent_infections[padding:],
obs=data_observed_infections,
Expand Down