From f58f33c8709d40019b11da36e679d98d21700807 Mon Sep 17 00:00:00 2001 From: sbidari Date: Mon, 1 Jul 2024 18:51:09 -0400 Subject: [PATCH 01/47] add arviz plot for posterior predictive --- model/docs/example_with_datasets.qmd | 59 ++++++++++++++++++++++++++++ 1 file changed, 59 insertions(+) diff --git a/model/docs/example_with_datasets.qmd b/model/docs/example_with_datasets.qmd index fc0321a6..01729416 100644 --- a/model/docs/example_with_datasets.qmd +++ b/model/docs/example_with_datasets.qmd @@ -501,3 +501,62 @@ out = hosp_model_weekday.plot_posterior( ), ) ``` +We can use ArviZ to visualize the posterior predictive distributions. +```{python} +# | label: posterior-predictive-distribution +idata_weekday = az.from_numpyro( + hosp_model_weekday.mcmc, + posterior_predictive=hosp_model_weekday.posterior_predictive(n_timepoints_to_simulate=1) +) +``` +Below we plot the posterior predictive distributions using equal tailed bayesian credible intervals: + +```{python} +# | label: fig-output-posterior-predictive +# | fig-cap: Posterior Predictive Infections + + +def compute_eti(dataset, eti_prob): + eti_bdry = dataset.quantile( + ((1 - eti_prob) / 2, 1 / 2 + eti_prob / 2), dim=("chain", "draw") + ) + return eti_bdry.values.T + + +fig, axes = plt.subplots(figsize=(6, 5)) +az.plot_hdi( + idata_weekday.posterior_predictive["negbinom_rv_dim_0"], + hdi_data=compute_eti(idata_weekday.posterior_predictive["negbinom_rv"], 0.9), + color="C0", + smooth=False, + fill_kwargs={"alpha": 0.3}, + ax=axes, +) + +az.plot_hdi( + idata_weekday.posterior_predictive["negbinom_rv_dim_0"], + hdi_data=compute_eti(idata_weekday.posterior_predictive["negbinom_rv"], 0.5), + color="C0", + smooth=False, + fill_kwargs={"alpha": 0.6}, + ax=axes, +) + +# Add mean of the posterior to the figure +mean_latent_infection = np.mean( + idata_weekday.posterior_predictive["negbinom_rv"], axis=1 +) + +plt.plot(x_data, mean_latent_infection[0], color="C0", label="Mean") +plt.scatter( + idata_weekday.observed_data["negbinom_rv_dim_0"] + + gen_int_array.size + + days_to_impute, + idata_weekday.observed_data["negbinom_rv"], + color="black", +) +axes.legend() +axes.set_title("Posterior Predictive Infections", fontsize=10) +axes.set_xlabel("Time", fontsize=10) +axes.set_ylabel("Observed Infections", fontsize=10); +``` From da8db9399f710f4aca5758b2f50a390eb2a82884 Mon Sep 17 00:00:00 2001 From: damonbayer Date: Tue, 2 Jul 2024 16:59:10 -0500 Subject: [PATCH 02/47] broken simplerandomwalk --- model/src/pyrenew/process/simplerandomwalk.py | 22 ++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/model/src/pyrenew/process/simplerandomwalk.py b/model/src/pyrenew/process/simplerandomwalk.py index d2c233b3..eb7af67e 100644 --- a/model/src/pyrenew/process/simplerandomwalk.py +++ b/model/src/pyrenew/process/simplerandomwalk.py @@ -4,6 +4,7 @@ import jax.numpy as jnp import numpyro as npro import numpyro.distributions as dist +from numpyro.contrib.control_flow import scan from pyrenew.metaclass import RandomVariable @@ -62,12 +63,23 @@ def sample( if init is None: init = npro.sample(name + "_init", self.error_distribution) - diffs = npro.sample( - name + "_diffs", - self.error_distribution.expand((n_timepoints - 1,)), - ) - return (init + jnp.cumsum(jnp.pad(diffs, [1, 0], constant_values=0)),) + # diffs = npro.sample( + # name + "_diffs", + # self.error_distribution.expand((n_timepoints - 1,)), + # ) + + # return (init + jnp.cumsum(jnp.pad(diffs, [1, 0], constant_values=0)),) + + def transition(x_prev, _): # numpydoc ignore=GL08 + diff = npro.sample(name + "_diffs", self.error_distribution) + x_curr = x_prev + diff + return x_curr, x_curr + + # Error occurs when I call scan: + _, x = scan(transition, init=init, xs=None, length=n_timepoints - 1) + + return (jnp.hstack([jnp.atleast_1d(init), x]),) @staticmethod def validate(): From 3c10b3e5118724f6d01739d7e2014682dcb4b9b8 Mon Sep 17 00:00:00 2001 From: damonbayer Date: Tue, 2 Jul 2024 17:08:36 -0500 Subject: [PATCH 03/47] slight simplification --- model/src/pyrenew/process/simplerandomwalk.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/model/src/pyrenew/process/simplerandomwalk.py b/model/src/pyrenew/process/simplerandomwalk.py index eb7af67e..0ae2fd0a 100644 --- a/model/src/pyrenew/process/simplerandomwalk.py +++ b/model/src/pyrenew/process/simplerandomwalk.py @@ -71,7 +71,8 @@ def sample( # return (init + jnp.cumsum(jnp.pad(diffs, [1, 0], constant_values=0)),) - def transition(x_prev, _): # numpydoc ignore=GL08 + def transition(x_prev, _): + # numpydoc ignore=GL08 diff = npro.sample(name + "_diffs", self.error_distribution) x_curr = x_prev + diff return x_curr, x_curr @@ -79,7 +80,7 @@ def transition(x_prev, _): # numpydoc ignore=GL08 # Error occurs when I call scan: _, x = scan(transition, init=init, xs=None, length=n_timepoints - 1) - return (jnp.hstack([jnp.atleast_1d(init), x]),) + return (jnp.hstack([init, x]),) @staticmethod def validate(): From 797b1d12a285650d512aec3f59a552352f669bd1 Mon Sep 17 00:00:00 2001 From: damonbayer Date: Wed, 3 Jul 2024 08:48:59 -0500 Subject: [PATCH 04/47] downgrade jax and jaxlib --- docs/pyproject.toml | 4 ++-- model/pyproject.toml | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/docs/pyproject.toml b/docs/pyproject.toml index 1efb8f2e..81ffd5a2 100644 --- a/docs/pyproject.toml +++ b/docs/pyproject.toml @@ -9,8 +9,8 @@ package-mode = false [tool.poetry.dependencies] python = "^3.12" sphinx = "^7.2.6" -jax = "^0.4.25" -jaxlib = "^0.4.25" +jax = "0.4.29" +jaxlib = "^0.4.29" numpyro = "^0.15.0" sphinxcontrib-mermaid = "^0.9.2" polars = "^0.20.16" diff --git a/model/pyproject.toml b/model/pyproject.toml index 63917f66..f7656dd8 100755 --- a/model/pyproject.toml +++ b/model/pyproject.toml @@ -12,7 +12,8 @@ exclude = [{path = "datasets/*.rds"}] [tool.poetry.dependencies] python = "^3.12" numpyro = "^0.15.0" -jax = "^0.4.25" +jax = "0.4.29" +jaxlib = "0.4.29" numpy = "^1.26.4" polars = "^0.20.16" pillow = "^10.3.0" # See #56 on CDCgov/multisignal-epi-inference From 20fb2a4d432ea4d6caf6021fba394341397c5ecb Mon Sep 17 00:00:00 2001 From: damonbayer Date: Wed, 3 Jul 2024 08:49:12 -0500 Subject: [PATCH 05/47] fix scan in simplerandomwalkprocess --- model/src/pyrenew/process/simplerandomwalk.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/model/src/pyrenew/process/simplerandomwalk.py b/model/src/pyrenew/process/simplerandomwalk.py index 0ae2fd0a..c18bec67 100644 --- a/model/src/pyrenew/process/simplerandomwalk.py +++ b/model/src/pyrenew/process/simplerandomwalk.py @@ -64,21 +64,17 @@ def sample( if init is None: init = npro.sample(name + "_init", self.error_distribution) - # diffs = npro.sample( - # name + "_diffs", - # self.error_distribution.expand((n_timepoints - 1,)), - # ) - - # return (init + jnp.cumsum(jnp.pad(diffs, [1, 0], constant_values=0)),) - def transition(x_prev, _): # numpydoc ignore=GL08 diff = npro.sample(name + "_diffs", self.error_distribution) x_curr = x_prev + diff return x_curr, x_curr - # Error occurs when I call scan: - _, x = scan(transition, init=init, xs=None, length=n_timepoints - 1) + _, x = scan( + transition, + init=init, + xs=jnp.arange(n_timepoints - 1), + ) return (jnp.hstack([init, x]),) From df7ac7b92665bad9f5b4eea94575ce3de51439fa Mon Sep 17 00:00:00 2001 From: sbidari Date: Wed, 3 Jul 2024 14:12:30 -0400 Subject: [PATCH 06/47] add ppc_plot to hosp_model --- model/docs/example_with_datasets.qmd | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/model/docs/example_with_datasets.qmd b/model/docs/example_with_datasets.qmd index 01729416..b2b7edaa 100644 --- a/model/docs/example_with_datasets.qmd +++ b/model/docs/example_with_datasets.qmd @@ -560,3 +560,13 @@ axes.set_title("Posterior Predictive Infections", fontsize=10) axes.set_xlabel("Time", fontsize=10) axes.set_ylabel("Observed Infections", fontsize=10); ``` + +We can also use the [`plot_ppc`](https://python.arviz.org/en/latest/api/generated/arviz.plot_ppc.html) to compare the posterior predictive samples with the observed data. +```{python} +# | label: fig-output-posterior-predictive-check-plot +# | fig-cap: Posterior Predictive Infections +az.plot_ppc( + data=idata_weekday, + kind="cumulative", # As observed data is a timeseries, use 'cumulative' rather than 'kde' +) +``` From c3ec0829472ef50ed5b34dccebe93d5cf1a3743c Mon Sep 17 00:00:00 2001 From: sbidari Date: Wed, 3 Jul 2024 14:30:06 -0400 Subject: [PATCH 07/47] test prior/posterior_predictive plots --- model/docs/example_with_datasets.qmd | 23 ++++++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/model/docs/example_with_datasets.qmd b/model/docs/example_with_datasets.qmd index b2b7edaa..d6db7150 100644 --- a/model/docs/example_with_datasets.qmd +++ b/model/docs/example_with_datasets.qmd @@ -238,8 +238,29 @@ hosp_model.run( num_warmup=2000, data_observed_hosp_admissions=dat["daily_hosp_admits"].to_numpy(), rng_key=jax.random.PRNGKey(54), - mcmc_args=dict(progress_bar=False, num_chains=2), + mcmc_args=dict(progress_bar=False, num_chains=1), +) +``` +Test `posterior_predictive` and `prior_predictive` methods +```{python} +import arviz as az + +idata = az.from_numpyro(hosp_model.mcmc, + posterior_predictive=hosp_model.posterior_predictive(data_observed_hosp_admissions=dat["daily_hosp_admits"].to_numpy().astype(float))) + +az.plot_lm("negbinom_rv", + idata=idata, +); +``` +The posterior predictive samples are simply the observed data. Looking at prior predictive + +```{python} +prior_pred = hosp_model.prior_predictive( + data_observed_hosp_admissions=dat["daily_hosp_admits"].to_numpy().astype(float), + numpyro_predictive_args={"num_samples": 100}, ) +# Not the best way but can still see prior_predictive is the same as observed hospital data +prior_pred["negbinom_rv"] ``` We can use the `plot_posterior` method to visualize the results[^capture]: From f13c04d2157cd91d294396247f458553067e0ab2 Mon Sep 17 00:00:00 2001 From: damonbayer Date: Wed, 3 Jul 2024 14:09:39 -0500 Subject: [PATCH 08/47] Damon check in --- model/docs/example_with_datasets.qmd | 4 +++- model/src/pyrenew/model/admissionsmodel.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/model/docs/example_with_datasets.qmd b/model/docs/example_with_datasets.qmd index d6db7150..bc9c85b9 100644 --- a/model/docs/example_with_datasets.qmd +++ b/model/docs/example_with_datasets.qmd @@ -527,7 +527,9 @@ We can use ArviZ to visualize the posterior predictive distributions. # | label: posterior-predictive-distribution idata_weekday = az.from_numpyro( hosp_model_weekday.mcmc, - posterior_predictive=hosp_model_weekday.posterior_predictive(n_timepoints_to_simulate=1) + posterior_predictive=hosp_model_weekday.posterior_predictive( + n_timepoints_to_simulate=len(dat_w_padding) + ), ) ``` Below we plot the posterior predictive distributions using equal tailed bayesian credible intervals: diff --git a/model/src/pyrenew/model/admissionsmodel.py b/model/src/pyrenew/model/admissionsmodel.py index 5db372f8..a190bcfc 100644 --- a/model/src/pyrenew/model/admissionsmodel.py +++ b/model/src/pyrenew/model/admissionsmodel.py @@ -208,7 +208,7 @@ def sample( observed_hosp_admissions, *_, ) = self.hosp_admission_obs_process_rv.sample( - mu=latent_hosp_admissions, + mu=latent_hosp_admissions[i0_size + padding :], obs=data_observed_hosp_admissions, **kwargs, ) From c7a596a6a853a2e973ded6b8bf9abeb2417678ae Mon Sep 17 00:00:00 2001 From: sbidari Date: Fri, 5 Jul 2024 14:21:14 -0400 Subject: [PATCH 09/47] add plot_ppc, plot_lm --- model/docs/example_with_datasets.qmd | 97 ++++++++++++---------------- 1 file changed, 41 insertions(+), 56 deletions(-) diff --git a/model/docs/example_with_datasets.qmd b/model/docs/example_with_datasets.qmd index bc9c85b9..11273ba5 100644 --- a/model/docs/example_with_datasets.qmd +++ b/model/docs/example_with_datasets.qmd @@ -227,42 +227,18 @@ plt.show() ## Fitting the model We can fit the model to the data. We will use the `run` method of the model object: - - ```{python} # | label: model-fit import jax hosp_model.run( - num_samples=2000, - num_warmup=2000, + num_samples=1000, + num_warmup=1000, data_observed_hosp_admissions=dat["daily_hosp_admits"].to_numpy(), rng_key=jax.random.PRNGKey(54), - mcmc_args=dict(progress_bar=False, num_chains=1), -) -``` -Test `posterior_predictive` and `prior_predictive` methods -```{python} -import arviz as az - -idata = az.from_numpyro(hosp_model.mcmc, - posterior_predictive=hosp_model.posterior_predictive(data_observed_hosp_admissions=dat["daily_hosp_admits"].to_numpy().astype(float))) - -az.plot_lm("negbinom_rv", - idata=idata, -); -``` -The posterior predictive samples are simply the observed data. Looking at prior predictive - -```{python} -prior_pred = hosp_model.prior_predictive( - data_observed_hosp_admissions=dat["daily_hosp_admits"].to_numpy().astype(float), - numpyro_predictive_args={"num_samples": 100}, + mcmc_args=dict(progress_bar=False, num_chains=2), ) -# Not the best way but can still see prior_predictive is the same as observed hospital data -prior_pred["negbinom_rv"] ``` - We can use the `plot_posterior` method to visualize the results[^capture]: [^capture]: The output is captured to avoid `quarto` from displaying the output twice. @@ -280,6 +256,38 @@ out = hosp_model.plot_posterior( ), ) ``` +Test `posterior_predictive` and `prior_predictive` methods +```{python} +# | label: demonstrate us of predictive methods +import arviz as az + +idata = az.from_numpyro( + hosp_model.mcmc, + posterior_predictive=hosp_model.posterior_predictive( + n_timepoints_to_simulate=len(dat["daily_hosp_admits"]) + ), + prior=hosp_model.prior_predictive( + n_timepoints_to_simulate=len(dat["daily_hosp_admits"]), + numpyro_predictive_args={"num_samples": 1000}, + ), +) + +az.plot_lm( + "negbinom_rv", + idata=idata, +) +``` +We can also use the [`plot_ppc`](https://python.arviz.org/en/latest/api/generated/arviz.plot_ppc.html) to compare the posterior predictive samples with the observed data. +```{python} +# | label: fig-output-posterior-predictive-check-plot +# | fig-cap: Posterior Predictive Infections +az.plot_ppc( + idata, + kind="scatter", + coords={"negbinom_rv_dim_0": [0, 1]}, + flatten=[], +); +``` The first half of the model is not looking good. The reason is that the infection to hospitalization interval PMF makes it unlikely to observe admissions from the beginning. The following section shows how to fix this. @@ -298,17 +306,15 @@ dat_w_padding = np.pad( constant_values=np.nan, ) - hosp_model.run( - num_samples=2000, - num_warmup=2000, + num_samples=1000, + num_warmup=1000, data_observed_hosp_admissions=dat_w_padding, rng_key=jax.random.PRNGKey(54), mcmc_args=dict(progress_bar=False, num_chains=2), padding=days_to_impute, # Padding the model ) ``` - And plotting the results: ```{python} @@ -317,17 +323,13 @@ And plotting the results: out = hosp_model.plot_posterior( var="latent_hospital_admissions", ylab="Hospital Admissions", - obs_signal=np.pad( - dat_w_padding, (gen_int_array.size, 0), constant_values=np.nan - ), + obs_signal=np.pad(dat_w_padding, (gen_int_array.size, 0), constant_values=np.nan), ) ``` We can use [ArviZ](https://www.arviz.org/) to visualize the results. Let's start by converting the fitted model to Arviz InferenceData object: ```{python} # | label: convert-inferenceData -import arviz as az - idata = az.from_numpyro(hosp_model.mcmc) ``` We obtain the summary of model diagnostics and print the diagnostics for `latent_hospital_admissions[1]` @@ -525,20 +527,14 @@ out = hosp_model_weekday.plot_posterior( We can use ArviZ to visualize the posterior predictive distributions. ```{python} # | label: posterior-predictive-distribution -idata_weekday = az.from_numpyro( - hosp_model_weekday.mcmc, - posterior_predictive=hosp_model_weekday.posterior_predictive( - n_timepoints_to_simulate=len(dat_w_padding) - ), -) +idata_weekday = az.from_numpyro(hosp_model_weekday.mcmc, +posterior_predictive=hosp_model_weekday.posterior_predictive(n_timepoints_to_simulate=len(dat_w_padding))) ``` Below we plot the posterior predictive distributions using equal tailed bayesian credible intervals: ```{python} # | label: fig-output-posterior-predictive # | fig-cap: Posterior Predictive Infections - - def compute_eti(dataset, eti_prob): eti_bdry = dataset.quantile( ((1 - eti_prob) / 2, 1 / 2 + eti_prob / 2), dim=("chain", "draw") @@ -570,10 +566,9 @@ mean_latent_infection = np.mean( idata_weekday.posterior_predictive["negbinom_rv"], axis=1 ) -plt.plot(x_data, mean_latent_infection[0], color="C0", label="Mean") +plt.plot(idata_weekday.posterior_predictive["negbinom_rv_dim_0"], mean_latent_infection[0], color="C0", label="Mean") plt.scatter( idata_weekday.observed_data["negbinom_rv_dim_0"] - + gen_int_array.size + days_to_impute, idata_weekday.observed_data["negbinom_rv"], color="black", @@ -583,13 +578,3 @@ axes.set_title("Posterior Predictive Infections", fontsize=10) axes.set_xlabel("Time", fontsize=10) axes.set_ylabel("Observed Infections", fontsize=10); ``` - -We can also use the [`plot_ppc`](https://python.arviz.org/en/latest/api/generated/arviz.plot_ppc.html) to compare the posterior predictive samples with the observed data. -```{python} -# | label: fig-output-posterior-predictive-check-plot -# | fig-cap: Posterior Predictive Infections -az.plot_ppc( - data=idata_weekday, - kind="cumulative", # As observed data is a timeseries, use 'cumulative' rather than 'kde' -) -``` From 5fc270490438da6a4efc6bdffb2882b13502f520 Mon Sep 17 00:00:00 2001 From: damonbayer Date: Mon, 8 Jul 2024 09:40:00 -0500 Subject: [PATCH 10/47] fix tests --- model/src/test/test_model_hospitalizations.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/model/src/test/test_model_hospitalizations.py b/model/src/test/test_model_hospitalizations.py index 01c21134..195004b7 100644 --- a/model/src/test/test_model_hospitalizations.py +++ b/model/src/test/test_model_hospitalizations.py @@ -572,10 +572,11 @@ def test_model_hosp_with_obs_model_weekday_phosp(): with npro.handlers.seed(rng_seed=np.random.randint(1, 600)): model1_samp = model1.sample(n_timepoints_to_simulate=n_obs_to_generate) + pad_size = 5 obs = jnp.hstack( [ - jnp.repeat(jnp.nan, 5), - model1_samp.observed_hosp_admissions[5 + gen_int.size() :], + jnp.repeat(jnp.nan, pad_size), + model1_samp.observed_hosp_admissions[pad_size:], ] ) # Running with padding @@ -584,7 +585,7 @@ def test_model_hosp_with_obs_model_weekday_phosp(): num_samples=500, rng_key=jr.key(272), data_observed_hosp_admissions=obs, - padding=5, + padding=pad_size, ) inf = model1.spread_draws(["latent_hospital_admissions"]) From 4f0e74f4e7d8790e1f6e22e464c6349afcce81f3 Mon Sep 17 00:00:00 2001 From: damonbayer Date: Mon, 8 Jul 2024 09:41:16 -0500 Subject: [PATCH 11/47] formatting --- model/docs/example_with_datasets.qmd | 35 +++++++++++++++++++++------- 1 file changed, 27 insertions(+), 8 deletions(-) diff --git a/model/docs/example_with_datasets.qmd b/model/docs/example_with_datasets.qmd index 11273ba5..355668f3 100644 --- a/model/docs/example_with_datasets.qmd +++ b/model/docs/example_with_datasets.qmd @@ -323,7 +323,9 @@ And plotting the results: out = hosp_model.plot_posterior( var="latent_hospital_admissions", ylab="Hospital Admissions", - obs_signal=np.pad(dat_w_padding, (gen_int_array.size, 0), constant_values=np.nan), + obs_signal=np.pad( + dat_w_padding, (gen_int_array.size, 0), constant_values=np.nan + ), ) ``` @@ -332,7 +334,9 @@ We can use [ArviZ](https://www.arviz.org/) to visualize the results. Let's start # | label: convert-inferenceData idata = az.from_numpyro(hosp_model.mcmc) ``` + We obtain the summary of model diagnostics and print the diagnostics for `latent_hospital_admissions[1]` + ```{python} # | label: diagnostics # | warning: false @@ -524,12 +528,19 @@ out = hosp_model_weekday.plot_posterior( ), ) ``` + We can use ArviZ to visualize the posterior predictive distributions. + ```{python} # | label: posterior-predictive-distribution -idata_weekday = az.from_numpyro(hosp_model_weekday.mcmc, -posterior_predictive=hosp_model_weekday.posterior_predictive(n_timepoints_to_simulate=len(dat_w_padding))) +idata_weekday = az.from_numpyro( + hosp_model_weekday.mcmc, + posterior_predictive=hosp_model_weekday.posterior_predictive( + n_timepoints_to_simulate=len(dat_w_padding) + ), +) ``` + Below we plot the posterior predictive distributions using equal tailed bayesian credible intervals: ```{python} @@ -545,7 +556,9 @@ def compute_eti(dataset, eti_prob): fig, axes = plt.subplots(figsize=(6, 5)) az.plot_hdi( idata_weekday.posterior_predictive["negbinom_rv_dim_0"], - hdi_data=compute_eti(idata_weekday.posterior_predictive["negbinom_rv"], 0.9), + hdi_data=compute_eti( + idata_weekday.posterior_predictive["negbinom_rv"], 0.9 + ), color="C0", smooth=False, fill_kwargs={"alpha": 0.3}, @@ -554,7 +567,9 @@ az.plot_hdi( az.plot_hdi( idata_weekday.posterior_predictive["negbinom_rv_dim_0"], - hdi_data=compute_eti(idata_weekday.posterior_predictive["negbinom_rv"], 0.5), + hdi_data=compute_eti( + idata_weekday.posterior_predictive["negbinom_rv"], 0.5 + ), color="C0", smooth=False, fill_kwargs={"alpha": 0.6}, @@ -566,10 +581,14 @@ mean_latent_infection = np.mean( idata_weekday.posterior_predictive["negbinom_rv"], axis=1 ) -plt.plot(idata_weekday.posterior_predictive["negbinom_rv_dim_0"], mean_latent_infection[0], color="C0", label="Mean") +plt.plot( + idata_weekday.posterior_predictive["negbinom_rv_dim_0"], + mean_latent_infection[0], + color="C0", + label="Mean", +) plt.scatter( - idata_weekday.observed_data["negbinom_rv_dim_0"] - + days_to_impute, + idata_weekday.observed_data["negbinom_rv_dim_0"] + days_to_impute, idata_weekday.observed_data["negbinom_rv"], color="black", ) From 21dd93daf564e5c80ac060e20c706ee00794d85c Mon Sep 17 00:00:00 2001 From: damonbayer Date: Mon, 8 Jul 2024 09:48:51 -0500 Subject: [PATCH 12/47] more formatting --- model/docs/example_with_datasets.qmd | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/model/docs/example_with_datasets.qmd b/model/docs/example_with_datasets.qmd index 355668f3..1a46cedc 100644 --- a/model/docs/example_with_datasets.qmd +++ b/model/docs/example_with_datasets.qmd @@ -227,6 +227,7 @@ plt.show() ## Fitting the model We can fit the model to the data. We will use the `run` method of the model object: + ```{python} # | label: model-fit import jax @@ -239,6 +240,7 @@ hosp_model.run( mcmc_args=dict(progress_bar=False, num_chains=2), ) ``` + We can use the `plot_posterior` method to visualize the results[^capture]: [^capture]: The output is captured to avoid `quarto` from displaying the output twice. @@ -256,7 +258,9 @@ out = hosp_model.plot_posterior( ), ) ``` + Test `posterior_predictive` and `prior_predictive` methods + ```{python} # | label: demonstrate us of predictive methods import arviz as az @@ -277,7 +281,9 @@ az.plot_lm( idata=idata, ) ``` + We can also use the [`plot_ppc`](https://python.arviz.org/en/latest/api/generated/arviz.plot_ppc.html) to compare the posterior predictive samples with the observed data. + ```{python} # | label: fig-output-posterior-predictive-check-plot # | fig-cap: Posterior Predictive Infections @@ -315,6 +321,7 @@ hosp_model.run( padding=days_to_impute, # Padding the model ) ``` + And plotting the results: ```{python} @@ -330,6 +337,7 @@ out = hosp_model.plot_posterior( ``` We can use [ArviZ](https://www.arviz.org/) to visualize the results. Let's start by converting the fitted model to Arviz InferenceData object: + ```{python} # | label: convert-inferenceData idata = az.from_numpyro(hosp_model.mcmc) @@ -397,6 +405,7 @@ out2 = hosp_model.plot_posterior( var="all_latent_infections", ylab="Latent Infections" ) ``` + and the distribution of latent infections ```{python} From 76932df1cd4f5eeb1e39b563f6a967d394e2ccdd Mon Sep 17 00:00:00 2001 From: sbidari Date: Mon, 8 Jul 2024 11:22:49 -0400 Subject: [PATCH 13/47] fix plot_lm output issue --- model/docs/example_with_datasets.qmd | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/model/docs/example_with_datasets.qmd b/model/docs/example_with_datasets.qmd index 1a46cedc..9ce25f1d 100644 --- a/model/docs/example_with_datasets.qmd +++ b/model/docs/example_with_datasets.qmd @@ -279,7 +279,7 @@ idata = az.from_numpyro( az.plot_lm( "negbinom_rv", idata=idata, -) +); ``` We can also use the [`plot_ppc`](https://python.arviz.org/en/latest/api/generated/arviz.plot_ppc.html) to compare the posterior predictive samples with the observed data. From e4525ed3c22fb1568a0e3165604fdeb6cad85a24 Mon Sep 17 00:00:00 2001 From: Subekshya Bidari <37636707+sbidari@users.noreply.github.com> Date: Mon, 8 Jul 2024 11:36:46 -0400 Subject: [PATCH 14/47] suggestion from code review Co-authored-by: Damon Bayer --- model/docs/example_with_datasets.qmd | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/model/docs/example_with_datasets.qmd b/model/docs/example_with_datasets.qmd index 9ce25f1d..31fa7ab5 100644 --- a/model/docs/example_with_datasets.qmd +++ b/model/docs/example_with_datasets.qmd @@ -282,7 +282,7 @@ az.plot_lm( ); ``` -We can also use the [`plot_ppc`](https://python.arviz.org/en/latest/api/generated/arviz.plot_ppc.html) to compare the posterior predictive samples with the observed data. +We can also use the `plot_ppc` method to compare the posterior predictive samples with the observed data. ```{python} # | label: fig-output-posterior-predictive-check-plot From 320cfbe6f6f2e68a0957ab0915f08ad0cdb616d5 Mon Sep 17 00:00:00 2001 From: damonbayer Date: Mon, 8 Jul 2024 11:22:33 -0500 Subject: [PATCH 15/47] Update example_with_datasets.qmd --- model/docs/example_with_datasets.qmd | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/model/docs/example_with_datasets.qmd b/model/docs/example_with_datasets.qmd index 31fa7ab5..704e3a55 100644 --- a/model/docs/example_with_datasets.qmd +++ b/model/docs/example_with_datasets.qmd @@ -262,7 +262,8 @@ out = hosp_model.plot_posterior( Test `posterior_predictive` and `prior_predictive` methods ```{python} -# | label: demonstrate us of predictive methods +# | label: fig-demonstrate-use-of-predictive-methods +# | fig-cap: Hospital Admissions posterior distribution with plot_lm import arviz as az idata = az.from_numpyro( From 3b2a3a0cef1b1ec410795c1544c032547ad65f00 Mon Sep 17 00:00:00 2001 From: sbidari Date: Mon, 8 Jul 2024 12:53:35 -0400 Subject: [PATCH 16/47] change plot_lm kwargs --- model/docs/example_with_datasets.qmd | 3 +++ 1 file changed, 3 insertions(+) diff --git a/model/docs/example_with_datasets.qmd b/model/docs/example_with_datasets.qmd index 9ce25f1d..28ed7f07 100644 --- a/model/docs/example_with_datasets.qmd +++ b/model/docs/example_with_datasets.qmd @@ -279,6 +279,9 @@ idata = az.from_numpyro( az.plot_lm( "negbinom_rv", idata=idata, + kind_pp="hdi", + y_kwargs={"color": "black"}, + y_hat_fill_kwargs={"color": "C0"}, ); ``` From ce58119995df08abb7811e0018d4b70a61a8a215 Mon Sep 17 00:00:00 2001 From: sbidari Date: Mon, 8 Jul 2024 14:03:11 -0400 Subject: [PATCH 17/47] add details on plot_ppc, fig labels and titles --- model/docs/example_with_datasets.qmd | 24 +++++++++++++++++------- 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/model/docs/example_with_datasets.qmd b/model/docs/example_with_datasets.qmd index 28ed7f07..3b3cf04f 100644 --- a/model/docs/example_with_datasets.qmd +++ b/model/docs/example_with_datasets.qmd @@ -276,26 +276,36 @@ idata = az.from_numpyro( ), ) +fig, ax = plt.subplots() az.plot_lm( "negbinom_rv", idata=idata, kind_pp="hdi", y_kwargs={"color": "black"}, y_hat_fill_kwargs={"color": "C0"}, -); + axes=ax, +) + +ax.set_title("Posterior Predictive Plot") +ax.set_ylabel("Hospital Admissions") +ax.set_xlabel("Days") +plt.show() ``` -We can also use the [`plot_ppc`](https://python.arviz.org/en/latest/api/generated/arviz.plot_ppc.html) to compare the posterior predictive samples with the observed data. +We can assess model fit by comparing the model prediction with the observed data. Below we use the [`plot_ppc`](https://python.arviz.org/en/latest/api/generated/arviz.plot_ppc.html) to compare the posterior predictive samples with the observed data for the first two time period. This allows us to check how well the model predictions approximate observed data. ```{python} # | label: fig-output-posterior-predictive-check-plot # | fig-cap: Posterior Predictive Infections +fig, ax = plt.subplots(1, 2, figsize=(9, 5)) az.plot_ppc( - idata, - kind="scatter", - coords={"negbinom_rv_dim_0": [0, 1]}, - flatten=[], -); + idata, kind="scatter", coords={"negbinom_rv_dim_0": [0, 1]}, flatten=[], ax=ax +) + +fig.suptitle("Posterior Predictive Check Plots") +ax[0].set_xlabel("Hospital Admissions[0]") +ax[1].set_xlabel("Hospital Admissions[1]") +plt.show() ``` The first half of the model is not looking good. The reason is that the infection to hospitalization interval PMF makes it unlikely to observe admissions from the beginning. The following section shows how to fix this. From 7954f116e32fff066121dd41329af00cd4b71e05 Mon Sep 17 00:00:00 2001 From: sbidari Date: Mon, 8 Jul 2024 14:39:14 -0400 Subject: [PATCH 18/47] add figure descriptions --- model/docs/example_with_datasets.qmd | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/model/docs/example_with_datasets.qmd b/model/docs/example_with_datasets.qmd index fadd5b3d..f72d1b0c 100644 --- a/model/docs/example_with_datasets.qmd +++ b/model/docs/example_with_datasets.qmd @@ -259,11 +259,10 @@ out = hosp_model.plot_posterior( ) ``` -Test `posterior_predictive` and `prior_predictive` methods +Test `posterior_predictive` and `prior_predictive` methods to generate posterior and prior predictive samples. ```{python} -# | label: fig-demonstrate-use-of-predictive-methods -# | fig-cap: Hospital Admissions posterior distribution with plot_lm +# | label: demonstrate-use-of-predictive-methods import arviz as az idata = az.from_numpyro( @@ -276,7 +275,11 @@ idata = az.from_numpyro( numpyro_predictive_args={"num_samples": 1000}, ), ) - +``` +We will use `plot_lm` method from ArviZ to plot the posterior predictive distribution and observed data below: +```{python} +# | label: fig-posterior-predictive +# | fig-cap: Hospital Admissions posterior distribution with plot_lm fig, ax = plt.subplots() az.plot_lm( "negbinom_rv", From 05583a0e408afbcf17adfbd8c555936135543826 Mon Sep 17 00:00:00 2001 From: Subekshya Bidari <37636707+sbidari@users.noreply.github.com> Date: Mon, 8 Jul 2024 14:40:43 -0400 Subject: [PATCH 19/47] code review suggestions Co-authored-by: Damon Bayer --- model/docs/example_with_datasets.qmd | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/model/docs/example_with_datasets.qmd b/model/docs/example_with_datasets.qmd index f72d1b0c..ec0b14bc 100644 --- a/model/docs/example_with_datasets.qmd +++ b/model/docs/example_with_datasets.qmd @@ -296,7 +296,7 @@ ax.set_xlabel("Days") plt.show() ``` -We can assess model fit by comparing the model prediction with the observed data. Below we use the [`plot_ppc`](https://python.arviz.org/en/latest/api/generated/arviz.plot_ppc.html) to compare the posterior predictive samples with the observed data for the first two time period. This allows us to check how well the model predictions approximate observed data. +We can assess model fit by comparing the model prediction with the observed data. Below we use the [`plot_ppc`](https://python.arviz.org/en/latest/api/generated/arviz.plot_ppc.html) to compare the posterior predictive samples with the observed data for the first two time steps. This allows us to check how well the model predictions approximate observed data. ```{python} # | label: fig-output-posterior-predictive-check-plot From 034a624c7980b993eb14762ee90b733617c8ef51 Mon Sep 17 00:00:00 2001 From: sbidari Date: Mon, 8 Jul 2024 14:50:47 -0400 Subject: [PATCH 20/47] remove link from code --- model/docs/example_with_datasets.qmd | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/model/docs/example_with_datasets.qmd b/model/docs/example_with_datasets.qmd index ec0b14bc..911c478a 100644 --- a/model/docs/example_with_datasets.qmd +++ b/model/docs/example_with_datasets.qmd @@ -296,7 +296,7 @@ ax.set_xlabel("Days") plt.show() ``` -We can assess model fit by comparing the model prediction with the observed data. Below we use the [`plot_ppc`](https://python.arviz.org/en/latest/api/generated/arviz.plot_ppc.html) to compare the posterior predictive samples with the observed data for the first two time steps. This allows us to check how well the model predictions approximate observed data. +We can assess model fit by comparing the model prediction with the observed data. Below we use the `plot_ppc` to compare the posterior predictive samples with the observed data for the first two time steps. This allows us to check how well the model predictions approximate observed data. ```{python} # | label: fig-output-posterior-predictive-check-plot From 47fbf8d109656da0bd47ab5f20de0397dc62c450 Mon Sep 17 00:00:00 2001 From: damonbayer Date: Mon, 8 Jul 2024 14:29:30 -0500 Subject: [PATCH 21/47] formatting --- model/docs/example_with_datasets.qmd | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/model/docs/example_with_datasets.qmd b/model/docs/example_with_datasets.qmd index 911c478a..81321b09 100644 --- a/model/docs/example_with_datasets.qmd +++ b/model/docs/example_with_datasets.qmd @@ -303,7 +303,11 @@ We can assess model fit by comparing the model prediction with the observed data # | fig-cap: Posterior Predictive Infections fig, ax = plt.subplots(1, 2, figsize=(9, 5)) az.plot_ppc( - idata, kind="scatter", coords={"negbinom_rv_dim_0": [0, 1]}, flatten=[], ax=ax + idata, + kind="scatter", + coords={"negbinom_rv_dim_0": [0, 1]}, + flatten=[], + ax=ax, ) fig.suptitle("Posterior Predictive Check Plots") @@ -347,7 +351,9 @@ And plotting the results: out = hosp_model.plot_posterior( var="latent_hospital_admissions", ylab="Hospital Admissions", - obs_signal=np.pad(dat_w_padding, (gen_int_array.size, 0), constant_values=np.nan), + obs_signal=np.pad( + dat_w_padding, (gen_int_array.size, 0), constant_values=np.nan + ), ) ``` @@ -451,7 +457,9 @@ az.plot_hdi( ) # Add mean of the posterior to the figure -mean_latent_infection = np.mean(idata.posterior["all_latent_infections"], axis=1) +mean_latent_infection = np.mean( + idata.posterior["all_latent_infections"], axis=1 +) axes.plot(x_data, mean_latent_infection[0], color="C0", label="Mean") axes.legend() axes.set_title("Posterior Latent Infections", fontsize=10) From c744deef71a3a6f809db4f22c3a75c468d9c7758 Mon Sep 17 00:00:00 2001 From: damonbayer Date: Mon, 8 Jul 2024 16:03:46 -0500 Subject: [PATCH 22/47] simplify models --- model/src/pyrenew/model/admissionsmodel.py | 37 +++++-------------- .../pyrenew/model/rtinfectionsrenewalmodel.py | 3 -- 2 files changed, 9 insertions(+), 31 deletions(-) diff --git a/model/src/pyrenew/model/admissionsmodel.py b/model/src/pyrenew/model/admissionsmodel.py index 5db372f8..c406b883 100644 --- a/model/src/pyrenew/model/admissionsmodel.py +++ b/model/src/pyrenew/model/admissionsmodel.py @@ -5,8 +5,6 @@ from typing import NamedTuple -import jax.numpy as jnp -import pyrenew.arrayutils as au from jax.typing import ArrayLike from pyrenew.metaclass import Model, RandomVariable, _assert_sample_and_rtype from pyrenew.model.rtinfectionsrenewalmodel import RtInfectionsRenewalModel @@ -202,32 +200,15 @@ def sample( i0_size = len(latent_hosp_admissions) - n_timepoints if self.hosp_admission_obs_process_rv is None: observed_hosp_admissions = None - else: - if data_observed_hosp_admissions is None: - ( - observed_hosp_admissions, - *_, - ) = self.hosp_admission_obs_process_rv.sample( - mu=latent_hosp_admissions, - obs=data_observed_hosp_admissions, - **kwargs, - ) - else: - data_observed_hosp_admissions = au.pad_x_to_match_y( - data_observed_hosp_admissions, - latent_hosp_admissions, - jnp.nan, - pad_direction="start", - ) - - ( - observed_hosp_admissions, - *_, - ) = self.hosp_admission_obs_process_rv.sample( - mu=latent_hosp_admissions[i0_size + padding :], - obs=data_observed_hosp_admissions[i0_size + padding :], - **kwargs, - ) + + ( + observed_hosp_admissions, + *_, + ) = self.hosp_admission_obs_process_rv.sample( + mu=latent_hosp_admissions[i0_size + padding :], + obs=data_observed_hosp_admissions, + **kwargs, + ) return HospModelSample( Rt=basic_model.Rt, diff --git a/model/src/pyrenew/model/rtinfectionsrenewalmodel.py b/model/src/pyrenew/model/rtinfectionsrenewalmodel.py index 2b5498f8..08a3f75b 100644 --- a/model/src/pyrenew/model/rtinfectionsrenewalmodel.py +++ b/model/src/pyrenew/model/rtinfectionsrenewalmodel.py @@ -210,9 +210,6 @@ def sample( **kwargs, ) - if data_observed_infections is not None: - data_observed_infections = data_observed_infections[padding:] - observed_infections, *_ = self.infection_obs_process_rv.sample( mu=post_seed_latent_infections[padding:], obs=data_observed_infections, From cf6ff45ccb7f2525086f9ab816df3a844c440d5d Mon Sep 17 00:00:00 2001 From: damonbayer Date: Wed, 10 Jul 2024 11:20:53 -0500 Subject: [PATCH 23/47] rename hospital_admissions_model --- ...{example_with_datasets.rst => hospital_admissions_model.rst} | 2 +- ...{example_with_datasets.qmd => hospital_admissions_model.qmd} | 0 2 files changed, 1 insertion(+), 1 deletion(-) rename docs/source/tutorials/{example_with_datasets.rst => hospital_admissions_model.rst} (80%) rename model/docs/{example_with_datasets.qmd => hospital_admissions_model.qmd} (100%) diff --git a/docs/source/tutorials/example_with_datasets.rst b/docs/source/tutorials/hospital_admissions_model.rst similarity index 80% rename from docs/source/tutorials/example_with_datasets.rst rename to docs/source/tutorials/hospital_admissions_model.rst index 377d7b1d..3323aee6 100644 --- a/docs/source/tutorials/example_with_datasets.rst +++ b/docs/source/tutorials/hospital_admissions_model.rst @@ -2,4 +2,4 @@ .. Please do not edit this file directly. .. This file is just a placeholder. .. For the source file, see: -.. +.. diff --git a/model/docs/example_with_datasets.qmd b/model/docs/hospital_admissions_model.qmd similarity index 100% rename from model/docs/example_with_datasets.qmd rename to model/docs/hospital_admissions_model.qmd From 19ad4c5088681e8fdf9522a5a92b1493d462e7c2 Mon Sep 17 00:00:00 2001 From: damonbayer Date: Wed, 10 Jul 2024 11:37:21 -0500 Subject: [PATCH 24/47] fix bad merges --- docs/source/tutorials/basic_renewal_model.rst | 4 ---- docs/source/tutorials/hospital_admissions_model.rst | 4 ---- 2 files changed, 8 deletions(-) diff --git a/docs/source/tutorials/basic_renewal_model.rst b/docs/source/tutorials/basic_renewal_model.rst index d9e6afaa..7b6407d8 100644 --- a/docs/source/tutorials/basic_renewal_model.rst +++ b/docs/source/tutorials/basic_renewal_model.rst @@ -2,8 +2,4 @@ .. Please do not edit this file directly. .. This file is just a placeholder. .. For the source file, see: -<<<<<<<< HEAD:docs/source/tutorials/hospital_admissions_model.rst -.. -======== .. ->>>>>>>> main:docs/source/tutorials/basic_renewal_model.rst diff --git a/docs/source/tutorials/hospital_admissions_model.rst b/docs/source/tutorials/hospital_admissions_model.rst index d9e6afaa..3323aee6 100644 --- a/docs/source/tutorials/hospital_admissions_model.rst +++ b/docs/source/tutorials/hospital_admissions_model.rst @@ -2,8 +2,4 @@ .. Please do not edit this file directly. .. This file is just a placeholder. .. For the source file, see: -<<<<<<<< HEAD:docs/source/tutorials/hospital_admissions_model.rst .. -======== -.. ->>>>>>>> main:docs/source/tutorials/basic_renewal_model.rst From 21f8e7765ca0a9d0498dd3f5a396aac60df166a6 Mon Sep 17 00:00:00 2001 From: damonbayer Date: Wed, 10 Jul 2024 12:06:30 -0500 Subject: [PATCH 25/47] fix hospital_admissions_model.qmd --- model/docs/hospital_admissions_model.qmd | 1 - 1 file changed, 1 deletion(-) diff --git a/model/docs/hospital_admissions_model.qmd b/model/docs/hospital_admissions_model.qmd index 387e7bf6..cd6b7c0e 100644 --- a/model/docs/hospital_admissions_model.qmd +++ b/model/docs/hospital_admissions_model.qmd @@ -555,7 +555,6 @@ idata_weekday = az.from_numpyro( n_timepoints_to_simulate=len(dat["daily_hosp_admits"].to_numpy()) + days_to_impute + 28 - ) ), prior=hosp_model_weekday.prior_predictive( n_timepoints_to_simulate=len(dat["daily_hosp_admits"].to_numpy()) From 0c4c906cfd49d4d297976a68fc39faefa13ab175 Mon Sep 17 00:00:00 2001 From: damonbayer Date: Wed, 10 Jul 2024 13:58:02 -0500 Subject: [PATCH 26/47] fix most tests --- model/src/pyrenew/model/admissionsmodel.py | 4 ++++ .../pyrenew/model/rtinfectionsrenewalmodel.py | 2 +- model/src/test/test_model_basic_renewal.py | 11 +++++------ model/src/test/test_model_hospitalizations.py | 18 +++++++----------- model/src/test/test_random_walk.py | 4 ++-- 5 files changed, 19 insertions(+), 20 deletions(-) diff --git a/model/src/pyrenew/model/admissionsmodel.py b/model/src/pyrenew/model/admissionsmodel.py index f3883e53..582d21ff 100644 --- a/model/src/pyrenew/model/admissionsmodel.py +++ b/model/src/pyrenew/model/admissionsmodel.py @@ -6,6 +6,7 @@ from typing import NamedTuple from jax.typing import ArrayLike +from pyrenew.deterministic import NullObservation from pyrenew.metaclass import Model, RandomVariable, _assert_sample_and_rtype from pyrenew.model.rtinfectionsrenewalmodel import RtInfectionsRenewalModel @@ -97,6 +98,9 @@ def __init__( ) self.latent_hosp_admissions_rv = latent_hosp_admissions_rv + if hosp_admission_obs_process_rv is None: + hosp_admission_obs_process_rv = NullObservation() + self.hosp_admission_obs_process_rv = hosp_admission_obs_process_rv @staticmethod diff --git a/model/src/pyrenew/model/rtinfectionsrenewalmodel.py b/model/src/pyrenew/model/rtinfectionsrenewalmodel.py index 08a3f75b..246d0899 100644 --- a/model/src/pyrenew/model/rtinfectionsrenewalmodel.py +++ b/model/src/pyrenew/model/rtinfectionsrenewalmodel.py @@ -187,7 +187,7 @@ def sample( "Cannot pass both n_timepoints_to_simulate and data_observed_infections." ) elif n_timepoints_to_simulate is None: - n_timepoints = len(data_observed_infections) + n_timepoints = len(data_observed_infections) + padding else: n_timepoints = n_timepoints_to_simulate # Sampling from Rt (possibly with a given Rt, depending on diff --git a/model/src/test/test_model_basic_renewal.py b/model/src/test/test_model_basic_renewal.py index c8c53ec3..f79a7e4c 100644 --- a/model/src/test/test_model_basic_renewal.py +++ b/model/src/test/test_model_basic_renewal.py @@ -267,18 +267,17 @@ def test_model_basicrenewal_padding() -> None: # numpydoc ignore=GL08 # Sampling and fitting model 1 (with obs infections) np.random.seed(2203) + pad_size = 5 with npro.handlers.seed(rng_seed=np.random.randint(1, 600)): - model1_samp = model1.sample(n_timepoints_to_simulate=30) - - new_obs = jnp.hstack( - [jnp.repeat(jnp.nan, 5), model1_samp.observed_infections[5:]], - ) + model1_samp = model1.sample( + n_timepoints_to_simulate=30, padding=pad_size + ) model1.run( num_warmup=500, num_samples=500, rng_key=jr.key(22), - data_observed_infections=new_obs, + data_observed_infections=model1_samp.observed_infections, padding=5, ) diff --git a/model/src/test/test_model_hospitalizations.py b/model/src/test/test_model_hospitalizations.py index d6d5c023..69c5d2d8 100644 --- a/model/src/test/test_model_hospitalizations.py +++ b/model/src/test/test_model_hospitalizations.py @@ -253,13 +253,13 @@ def test_model_hosp_no_obs_model(): with npro.handlers.seed(rng_seed=np.random.randint(1, 600)): model0_samp = model0.sample(n_timepoints_to_simulate=30) - model0.observation_process = NullObservation() + model0.hosp_admission_obs_process_rv = NullObservation() np.random.seed(223) with npro.handlers.seed(rng_seed=np.random.randint(1, 600)): model1_samp = model0.sample(n_timepoints_to_simulate=30) - np.testing.assert_array_equal(model0_samp.Rt, model1_samp.Rt) + np.testing.assert_array_almost_equal(model0_samp.Rt, model1_samp.Rt) np.testing.assert_array_equal( model0_samp.latent_infections, model1_samp.latent_infections ) @@ -572,23 +572,19 @@ def test_model_hosp_with_obs_model_weekday_phosp(): ) # Sampling and fitting model 0 (with no obs for infections) + pad_size = 5 np.random.seed(223) with npro.handlers.seed(rng_seed=np.random.randint(1, 600)): - model1_samp = model1.sample(n_timepoints_to_simulate=n_obs_to_generate) + model1_samp = model1.sample( + n_timepoints_to_simulate=n_obs_to_generate, padding=pad_size + ) - pad_size = 5 - obs = jnp.hstack( - [ - jnp.repeat(jnp.nan, pad_size), - model1_samp.observed_hosp_admissions[pad_size:], - ] - ) # Running with padding model1.run( num_warmup=500, num_samples=500, rng_key=jr.key(272), - data_observed_hosp_admissions=obs, + data_observed_hosp_admissions=model1_samp.observed_hosp_admissions, padding=pad_size, ) diff --git a/model/src/test/test_random_walk.py b/model/src/test/test_random_walk.py index 9f1335e1..3e766b02 100755 --- a/model/src/test/test_random_walk.py +++ b/model/src/test/test_random_walk.py @@ -16,7 +16,7 @@ def test_rw_can_be_sampled(): with numpyro.handlers.seed(rng_seed=62): # can sample with and without inits - ans0 = rw_normal.sample(3532, init=jnp.array([50.0])) + ans0 = rw_normal.sample(3532, init=50.0) ans1 = rw_normal.sample(5023) # check that the samples are of the right shape @@ -35,7 +35,7 @@ def test_rw_samples_correctly_distributed(): [0, 2.253, -3.2521, 1052, 1e-6], [1, 0.025, 3, 1, 0.02] ): rw_normal = SimpleRandomWalkProcess(dist.Normal(step_mean, step_sd)) - init_arr = jnp.array([532.0]) + init_arr = 532.0 with numpyro.handlers.seed(rng_seed=62): samples, *_ = rw_normal.sample(n_samples, init=init_arr) From 464420839f31a2b5ecda15b50961e0e40867609b Mon Sep 17 00:00:00 2001 From: damonbayer Date: Wed, 10 Jul 2024 14:08:45 -0500 Subject: [PATCH 27/47] fix remaining tests --- model/src/test/test_random_key.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/model/src/test/test_random_key.py b/model/src/test/test_random_key.py index f173012c..a2342552 100644 --- a/model/src/test/test_random_key.py +++ b/model/src/test/test_random_key.py @@ -94,6 +94,9 @@ def test_rng_keys_produce_correct_samples(): # set up base models for testing models = [create_test_model() for _ in range(5)] n_timepoints_to_simulate = [30] * len(models) + n_timepoints_posterior_predictive = [ + x + models[0].gen_int_rv.size() for x in n_timepoints_to_simulate + ] # sample only a single model and use that model's samples # as the observed_infections for the rest of the models with npro.handlers.seed(rng_seed=np.random.randint(1, 600)): @@ -115,7 +118,9 @@ def test_rng_keys_produce_correct_samples(): posterior_predictive_list = [ posterior_predictive_test_model(*elt) - for elt in list(zip(models, n_timepoints_to_simulate, rng_keys)) + for elt in list( + zip(models, n_timepoints_posterior_predictive, rng_keys) + ) ] # using same rng_key should get same run samples assert_array_equal( From b8cb7e876fb1e4aea23d0ed558c4843a2dd922f8 Mon Sep 17 00:00:00 2001 From: damonbayer Date: Wed, 10 Jul 2024 14:27:34 -0500 Subject: [PATCH 28/47] rename test variable --- model/src/test/test_random_walk.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/model/src/test/test_random_walk.py b/model/src/test/test_random_walk.py index 3e766b02..9f8c6839 100755 --- a/model/src/test/test_random_walk.py +++ b/model/src/test/test_random_walk.py @@ -35,9 +35,9 @@ def test_rw_samples_correctly_distributed(): [0, 2.253, -3.2521, 1052, 1e-6], [1, 0.025, 3, 1, 0.02] ): rw_normal = SimpleRandomWalkProcess(dist.Normal(step_mean, step_sd)) - init_arr = 532.0 + init = 532.0 with numpyro.handlers.seed(rng_seed=62): - samples, *_ = rw_normal.sample(n_samples, init=init_arr) + samples, *_ = rw_normal.sample(n_samples, init=init) # Checking the shape assert samples.shape == (n_samples,) @@ -60,4 +60,4 @@ def test_rw_samples_correctly_distributed(): assert jnp.abs(jnp.log(jnp.std(diffs) / step_sd)) < jnp.log(1.1) # first value should be the init value - assert_almost_equal(samples[0], init_arr) + assert_almost_equal(samples[0], init) From 98175619b6b5941357eec166a953df67a60ec795 Mon Sep 17 00:00:00 2001 From: damonbayer Date: Wed, 10 Jul 2024 16:08:38 -0500 Subject: [PATCH 29/47] adjust padding handling --- model/docs/hospital_admissions_model.qmd | 37 +++++++++++++------ model/src/pyrenew/model/admissionsmodel.py | 14 +++++-- .../pyrenew/model/rtinfectionsrenewalmodel.py | 2 +- 3 files changed, 37 insertions(+), 16 deletions(-) diff --git a/model/docs/hospital_admissions_model.qmd b/model/docs/hospital_admissions_model.qmd index cd6b7c0e..c56c5c6d 100644 --- a/model/docs/hospital_admissions_model.qmd +++ b/model/docs/hospital_admissions_model.qmd @@ -552,13 +552,12 @@ By increasing `n_timepoints_to_simulate`, we can perform forecasting. idata_weekday = az.from_numpyro( hosp_model_weekday.mcmc, posterior_predictive=hosp_model_weekday.posterior_predictive( - n_timepoints_to_simulate=len(dat["daily_hosp_admits"].to_numpy()) - + days_to_impute - + 28 + n_timepoints_to_simulate=len(dat["daily_hosp_admits"].to_numpy()) + 28, + padding=days_to_impute, ), prior=hosp_model_weekday.prior_predictive( - n_timepoints_to_simulate=len(dat["daily_hosp_admits"].to_numpy()) - + days_to_impute, + n_timepoints_to_simulate=len(dat["daily_hosp_admits"].to_numpy()), + padding=days_to_impute, numpyro_predictive_args={"num_samples": 1000}, ), ) @@ -578,7 +577,9 @@ def compute_eti(dataset, eti_prob): fig, axes = plt.subplots(figsize=(6, 5)) az.plot_hdi( - idata_weekday.prior_predictive["negbinom_rv_dim_0"], + idata_weekday.prior_predictive["negbinom_rv_dim_0"] + + days_to_impute + + gen_int.size(), hdi_data=compute_eti(idata_weekday.prior_predictive["negbinom_rv"], 0.9), color="C0", smooth=False, @@ -587,7 +588,9 @@ az.plot_hdi( ) az.plot_hdi( - idata_weekday.prior_predictive["negbinom_rv_dim_0"], + idata_weekday.prior_predictive["negbinom_rv_dim_0"] + + days_to_impute + + gen_int.size(), hdi_data=compute_eti(idata_weekday.prior_predictive["negbinom_rv"], 0.5), color="C0", smooth=False, @@ -596,7 +599,9 @@ az.plot_hdi( ) plt.scatter( - idata_weekday.observed_data["negbinom_rv_dim_0"] + days_to_impute, + idata_weekday.observed_data["negbinom_rv_dim_0"] + + days_to_impute + + gen_int.size(), idata_weekday.observed_data["negbinom_rv"], color="black", ) @@ -614,7 +619,9 @@ And now we plot the posterior predictive distributions: # | fig-cap: Posterior Predictive Infections fig, axes = plt.subplots(figsize=(6, 5)) az.plot_hdi( - idata_weekday.posterior_predictive["negbinom_rv_dim_0"], + idata_weekday.posterior_predictive["negbinom_rv_dim_0"] + + days_to_impute + + gen_int.size(), hdi_data=compute_eti( idata_weekday.posterior_predictive["negbinom_rv"], 0.9 ), @@ -625,7 +632,9 @@ az.plot_hdi( ) az.plot_hdi( - idata_weekday.posterior_predictive["negbinom_rv_dim_0"], + idata_weekday.posterior_predictive["negbinom_rv_dim_0"] + + days_to_impute + + gen_int.size(), hdi_data=compute_eti( idata_weekday.posterior_predictive["negbinom_rv"], 0.5 ), @@ -641,13 +650,17 @@ mean_latent_infection = np.mean( ) plt.plot( - idata_weekday.posterior_predictive["negbinom_rv_dim_0"], + idata_weekday.posterior_predictive["negbinom_rv_dim_0"] + + days_to_impute + + gen_int.size(), mean_latent_infection[0], color="C0", label="Mean", ) plt.scatter( - idata_weekday.observed_data["negbinom_rv_dim_0"] + days_to_impute, + idata_weekday.observed_data["negbinom_rv_dim_0"] + + days_to_impute + + gen_int.size(), idata_weekday.observed_data["negbinom_rv"], color="black", ) diff --git a/model/src/pyrenew/model/admissionsmodel.py b/model/src/pyrenew/model/admissionsmodel.py index 582d21ff..fe8fcb00 100644 --- a/model/src/pyrenew/model/admissionsmodel.py +++ b/model/src/pyrenew/model/admissionsmodel.py @@ -182,11 +182,11 @@ def sample( elif n_timepoints_to_simulate is None: n_timepoints = len(data_observed_hosp_admissions) + padding else: - n_timepoints = n_timepoints_to_simulate + n_timepoints = n_timepoints_to_simulate + padding # Getting the initial quantities from the basic model basic_model = self.basic_renewal.sample( - n_timepoints_to_simulate=n_timepoints, + n_timepoints_to_simulate=n_timepoints - padding, data_observed_infections=None, padding=padding, **kwargs, @@ -202,6 +202,14 @@ def sample( **kwargs, ) i0_size = len(latent_hosp_admissions) - n_timepoints + print(f"len(latent_hosp_admissions): {len(latent_hosp_admissions)}") + print(f"n_timepoints: {n_timepoints}") + print(f"padding: {padding}") + print(f"i0_size: {i0_size}") + if data_observed_hosp_admissions is not None: + print( + f"len(data_observed_hosp_admissions): {len(data_observed_hosp_admissions)}" + ) if self.hosp_admission_obs_process_rv is None: observed_hosp_admissions = None @@ -209,7 +217,7 @@ def sample( observed_hosp_admissions, *_, ) = self.hosp_admission_obs_process_rv.sample( - mu=latent_hosp_admissions[i0_size + padding :], + mu=latent_hosp_admissions[-(n_timepoints - padding) :], obs=data_observed_hosp_admissions, **kwargs, ) diff --git a/model/src/pyrenew/model/rtinfectionsrenewalmodel.py b/model/src/pyrenew/model/rtinfectionsrenewalmodel.py index 246d0899..afbdb8d0 100644 --- a/model/src/pyrenew/model/rtinfectionsrenewalmodel.py +++ b/model/src/pyrenew/model/rtinfectionsrenewalmodel.py @@ -189,7 +189,7 @@ def sample( elif n_timepoints_to_simulate is None: n_timepoints = len(data_observed_infections) + padding else: - n_timepoints = n_timepoints_to_simulate + n_timepoints = n_timepoints_to_simulate + padding # Sampling from Rt (possibly with a given Rt, depending on # the Rt_process (RandomVariable) object.) Rt, *_ = self.Rt_process_rv.sample( From 9c48df9510a5356beda670e6672bfb62a6bceab5 Mon Sep 17 00:00:00 2001 From: damonbayer Date: Wed, 10 Jul 2024 16:17:12 -0500 Subject: [PATCH 30/47] rename internal variables --- model/src/pyrenew/model/admissionsmodel.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/model/src/pyrenew/model/admissionsmodel.py b/model/src/pyrenew/model/admissionsmodel.py index fe8fcb00..d8ff8763 100644 --- a/model/src/pyrenew/model/admissionsmodel.py +++ b/model/src/pyrenew/model/admissionsmodel.py @@ -180,13 +180,15 @@ def sample( "Cannot pass both n_timepoints_to_simulate and data_observed_hosp_admissions." ) elif n_timepoints_to_simulate is None: - n_timepoints = len(data_observed_hosp_admissions) + padding + n_datapoints = len(data_observed_hosp_admissions) else: - n_timepoints = n_timepoints_to_simulate + padding + n_datapoints = n_timepoints_to_simulate + + n_timepoints = n_datapoints + padding # Getting the initial quantities from the basic model basic_model = self.basic_renewal.sample( - n_timepoints_to_simulate=n_timepoints - padding, + n_timepoints_to_simulate=n_datapoints, data_observed_infections=None, padding=padding, **kwargs, @@ -217,7 +219,7 @@ def sample( observed_hosp_admissions, *_, ) = self.hosp_admission_obs_process_rv.sample( - mu=latent_hosp_admissions[-(n_timepoints - padding) :], + mu=latent_hosp_admissions[-n_datapoints:], obs=data_observed_hosp_admissions, **kwargs, ) From 7b392d3cd236d7059f9d299ce706a5f90330a4d7 Mon Sep 17 00:00:00 2001 From: damonbayer Date: Wed, 10 Jul 2024 16:24:35 -0500 Subject: [PATCH 31/47] formatting --- model/docs/basic_renewal_model.qmd | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/model/docs/basic_renewal_model.qmd b/model/docs/basic_renewal_model.qmd index 3fc88342..e2433c99 100644 --- a/model/docs/basic_renewal_model.qmd +++ b/model/docs/basic_renewal_model.qmd @@ -214,7 +214,9 @@ model1.run( num_samples=1000, data_observed_infections=sim_data.observed_infections, rng_key=jax.random.PRNGKey(54), - mcmc_args=dict(progress_bar=False, num_chains=2, chain_method="sequential"), + mcmc_args=dict( + progress_bar=False, num_chains=2, chain_method="sequential" + ), ) ``` @@ -335,7 +337,9 @@ az.plot_hdi( ) # Add mean of the posterior to the figure -mean_latent_infection = np.mean(idata.posterior["all_latent_infections"], axis=1) +mean_latent_infection = np.mean( + idata.posterior["all_latent_infections"], axis=1 +) axes.plot(x_data, mean_latent_infection[0], color="C0", label="Mean") axes.legend() axes.set_title("Posterior Latent Infections", fontsize=10) From 30a9a311f2e70e92f2887455139432d97e8bf1f2 Mon Sep 17 00:00:00 2001 From: damonbayer Date: Wed, 10 Jul 2024 17:26:16 -0500 Subject: [PATCH 32/47] correcting merge errors --- .../tutorials/hospital_admissions_model.qmd | 76 +++++++++++++------ 1 file changed, 51 insertions(+), 25 deletions(-) diff --git a/docs/source/tutorials/hospital_admissions_model.qmd b/docs/source/tutorials/hospital_admissions_model.qmd index f15e1c5b..c56c5c6d 100644 --- a/docs/source/tutorials/hospital_admissions_model.qmd +++ b/docs/source/tutorials/hospital_admissions_model.qmd @@ -237,9 +237,13 @@ npro.set_host_device_count(jax.local_device_count()) hosp_model.run( num_samples=1000, num_warmup=1000, - data_observed_hosp_admissions=dat["daily_hosp_admits"].to_numpy(), + data_observed_hosp_admissions=dat["daily_hosp_admits"] + .to_numpy() + .astype(float), rng_key=jax.random.PRNGKey(54), - mcmc_args=dict(progress_bar=False, num_chains=2, chain_method="sequential"), + mcmc_args=dict( + progress_bar=False, num_chains=2, chain_method="sequential" + ), ) ``` @@ -310,17 +314,10 @@ We can use the padding argument to solve the overestimation of hospital admissio # | label: model-fit-padding days_to_impute = 21 -# Add 21 Nas to the beginning of dat_w_padding -dat_w_padding = np.pad( - dat["daily_hosp_admits"].to_numpy().astype(float), - (days_to_impute, 0), - constant_values=np.nan, -) - hosp_model.run( num_samples=1000, num_warmup=1000, - data_observed_hosp_admissions=dat_w_padding, + data_observed_hosp_admissions=dat["daily_hosp_admits"].to_numpy(), rng_key=jax.random.PRNGKey(54), mcmc_args=dict(progress_bar=False, num_chains=2), padding=days_to_impute, # Padding the model @@ -336,7 +333,9 @@ out = hosp_model.plot_posterior( var="latent_hospital_admissions", ylab="Hospital Admissions", obs_signal=np.pad( - dat_w_padding, (gen_int_array.size, 0), constant_values=np.nan + dat["daily_hosp_admits"].to_numpy().astype(float), + (gen_int_array.size + days_to_impute, 0), + constant_values=np.nan, ), ) ``` @@ -440,7 +439,9 @@ az.plot_hdi( ) # Add mean of the posterior to the figure -mean_latent_infection = np.mean(idata.posterior["all_latent_infections"], axis=1) +mean_latent_infection = np.mean( + idata.posterior["all_latent_infections"], axis=1 +) axes.plot(x_data, mean_latent_infection[0], color="C0", label="Mean") axes.legend() axes.set_title("Posterior Latent Infections", fontsize=10) @@ -520,7 +521,7 @@ Running the model (with the same padding as before): hosp_model_weekday.run( num_samples=2000, num_warmup=2000, - data_observed_hosp_admissions=dat_w_padding, + data_observed_hosp_admissions=dat["daily_hosp_admits"].to_numpy(), rng_key=jax.random.PRNGKey(54), mcmc_args=dict(progress_bar=False), padding=days_to_impute, @@ -535,21 +536,28 @@ And plotting the results: out = hosp_model_weekday.plot_posterior( var="latent_hospital_admissions", ylab="Hospital Admissions", - obs_signal=np.pad(dat_w_padding, (gen_int_array.size, 0), constant_values=np.nan), + obs_signal=np.pad( + dat["daily_hosp_admits"].to_numpy().astype(float), + (gen_int_array.size + days_to_impute, 0), + constant_values=np.nan, + ), ) ``` We will use ArviZ to visualize the posterior/prior predictive distributions. +By increasing `n_timepoints_to_simulate`, we can perform forecasting. ```{python} # | label: posterior-predictive-distribution idata_weekday = az.from_numpyro( hosp_model_weekday.mcmc, posterior_predictive=hosp_model_weekday.posterior_predictive( - n_timepoints_to_simulate=len(dat_w_padding) + n_timepoints_to_simulate=len(dat["daily_hosp_admits"].to_numpy()) + 28, + padding=days_to_impute, ), prior=hosp_model_weekday.prior_predictive( - n_timepoints_to_simulate=len(dat_w_padding), + n_timepoints_to_simulate=len(dat["daily_hosp_admits"].to_numpy()), + padding=days_to_impute, numpyro_predictive_args={"num_samples": 1000}, ), ) @@ -569,7 +577,9 @@ def compute_eti(dataset, eti_prob): fig, axes = plt.subplots(figsize=(6, 5)) az.plot_hdi( - idata_weekday.prior_predictive["negbinom_rv_dim_0"], + idata_weekday.prior_predictive["negbinom_rv_dim_0"] + + days_to_impute + + gen_int.size(), hdi_data=compute_eti(idata_weekday.prior_predictive["negbinom_rv"], 0.9), color="C0", smooth=False, @@ -578,7 +588,9 @@ az.plot_hdi( ) az.plot_hdi( - idata_weekday.prior_predictive["negbinom_rv_dim_0"], + idata_weekday.prior_predictive["negbinom_rv_dim_0"] + + days_to_impute + + gen_int.size(), hdi_data=compute_eti(idata_weekday.prior_predictive["negbinom_rv"], 0.5), color="C0", smooth=False, @@ -587,7 +599,9 @@ az.plot_hdi( ) plt.scatter( - idata_weekday.observed_data["negbinom_rv_dim_0"] + days_to_impute, + idata_weekday.observed_data["negbinom_rv_dim_0"] + + days_to_impute + + gen_int.size(), idata_weekday.observed_data["negbinom_rv"], color="black", ) @@ -605,8 +619,12 @@ And now we plot the posterior predictive distributions: # | fig-cap: Posterior Predictive Infections fig, axes = plt.subplots(figsize=(6, 5)) az.plot_hdi( - idata_weekday.posterior_predictive["negbinom_rv_dim_0"], - hdi_data=compute_eti(idata_weekday.posterior_predictive["negbinom_rv"], 0.9), + idata_weekday.posterior_predictive["negbinom_rv_dim_0"] + + days_to_impute + + gen_int.size(), + hdi_data=compute_eti( + idata_weekday.posterior_predictive["negbinom_rv"], 0.9 + ), color="C0", smooth=False, fill_kwargs={"alpha": 0.3}, @@ -614,8 +632,12 @@ az.plot_hdi( ) az.plot_hdi( - idata_weekday.posterior_predictive["negbinom_rv_dim_0"], - hdi_data=compute_eti(idata_weekday.posterior_predictive["negbinom_rv"], 0.5), + idata_weekday.posterior_predictive["negbinom_rv_dim_0"] + + days_to_impute + + gen_int.size(), + hdi_data=compute_eti( + idata_weekday.posterior_predictive["negbinom_rv"], 0.5 + ), color="C0", smooth=False, fill_kwargs={"alpha": 0.6}, @@ -628,13 +650,17 @@ mean_latent_infection = np.mean( ) plt.plot( - idata_weekday.posterior_predictive["negbinom_rv_dim_0"], + idata_weekday.posterior_predictive["negbinom_rv_dim_0"] + + days_to_impute + + gen_int.size(), mean_latent_infection[0], color="C0", label="Mean", ) plt.scatter( - idata_weekday.observed_data["negbinom_rv_dim_0"] + days_to_impute, + idata_weekday.observed_data["negbinom_rv_dim_0"] + + days_to_impute + + gen_int.size(), idata_weekday.observed_data["negbinom_rv"], color="black", ) From 4cf685a629462700d08d26b283d17d4431dc419f Mon Sep 17 00:00:00 2001 From: damonbayer Date: Wed, 10 Jul 2024 17:30:30 -0500 Subject: [PATCH 33/47] remove print debug statements and unused variables --- model/src/pyrenew/model/admissionsmodel.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/model/src/pyrenew/model/admissionsmodel.py b/model/src/pyrenew/model/admissionsmodel.py index d8ff8763..f3ef32e7 100644 --- a/model/src/pyrenew/model/admissionsmodel.py +++ b/model/src/pyrenew/model/admissionsmodel.py @@ -184,8 +184,6 @@ def sample( else: n_datapoints = n_timepoints_to_simulate - n_timepoints = n_datapoints + padding - # Getting the initial quantities from the basic model basic_model = self.basic_renewal.sample( n_timepoints_to_simulate=n_datapoints, @@ -203,15 +201,6 @@ def sample( latent_infections=basic_model.latent_infections, **kwargs, ) - i0_size = len(latent_hosp_admissions) - n_timepoints - print(f"len(latent_hosp_admissions): {len(latent_hosp_admissions)}") - print(f"n_timepoints: {n_timepoints}") - print(f"padding: {padding}") - print(f"i0_size: {i0_size}") - if data_observed_hosp_admissions is not None: - print( - f"len(data_observed_hosp_admissions): {len(data_observed_hosp_admissions)}" - ) if self.hosp_admission_obs_process_rv is None: observed_hosp_admissions = None From 45d9950296c3dc6398fe3a0ecb4f7cd2dafc0f29 Mon Sep 17 00:00:00 2001 From: damonbayer Date: Wed, 10 Jul 2024 17:49:36 -0500 Subject: [PATCH 34/47] fix test --- model/src/test/test_model_hospitalizations.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/model/src/test/test_model_hospitalizations.py b/model/src/test/test_model_hospitalizations.py index 69c5d2d8..4e48d163 100644 --- a/model/src/test/test_model_hospitalizations.py +++ b/model/src/test/test_model_hospitalizations.py @@ -540,13 +540,13 @@ def test_model_hosp_with_obs_model_weekday_phosp(): weekday = weekday / weekday.sum() weekday = jnp.tile(weekday, 10) # weekday = weekday[:n_obs_to_generate] - weekday = weekday[:34] + weekday = weekday[:39] weekday = DeterministicVariable(weekday, name="weekday") hosp_report_prob_dist = jnp.array([0.9, 0.8, 0.7, 0.7, 0.6, 0.4]) hosp_report_prob_dist = jnp.tile(hosp_report_prob_dist, 10) - hosp_report_prob_dist = hosp_report_prob_dist[:34] + hosp_report_prob_dist = hosp_report_prob_dist[:39] hosp_report_prob_dist = hosp_report_prob_dist / hosp_report_prob_dist.sum() hosp_report_prob_dist = DeterministicVariable( From b3bff64170c510ea5b66d96c83be953b78aa5c97 Mon Sep 17 00:00:00 2001 From: damonbayer Date: Wed, 10 Jul 2024 17:53:36 -0500 Subject: [PATCH 35/47] rename for clarity --- model/src/test/test_model_hospitalizations.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/model/src/test/test_model_hospitalizations.py b/model/src/test/test_model_hospitalizations.py index 4e48d163..dc2b090a 100644 --- a/model/src/test/test_model_hospitalizations.py +++ b/model/src/test/test_model_hospitalizations.py @@ -493,6 +493,7 @@ def test_model_hosp_with_obs_model_weekday_phosp(): jnp.array([0.25, 0.25, 0.25, 0.25]), name="gen_int" ) n_obs_to_generate = 30 + pad_size = 5 I0 = InfectionSeedingProcess( "I0_seeding", @@ -536,17 +537,17 @@ def test_model_hosp_with_obs_model_weekday_phosp(): ) # Other random components + total_length = n_obs_to_generate + pad_size + gen_int.size() weekday = jnp.array([1, 1, 1, 1, 2, 2]) weekday = weekday / weekday.sum() weekday = jnp.tile(weekday, 10) - # weekday = weekday[:n_obs_to_generate] - weekday = weekday[:39] + weekday = weekday[:total_length] weekday = DeterministicVariable(weekday, name="weekday") hosp_report_prob_dist = jnp.array([0.9, 0.8, 0.7, 0.7, 0.6, 0.4]) hosp_report_prob_dist = jnp.tile(hosp_report_prob_dist, 10) - hosp_report_prob_dist = hosp_report_prob_dist[:39] + hosp_report_prob_dist = hosp_report_prob_dist[:total_length] hosp_report_prob_dist = hosp_report_prob_dist / hosp_report_prob_dist.sum() hosp_report_prob_dist = DeterministicVariable( @@ -572,7 +573,7 @@ def test_model_hosp_with_obs_model_weekday_phosp(): ) # Sampling and fitting model 0 (with no obs for infections) - pad_size = 5 + np.random.seed(223) with npro.handlers.seed(rng_seed=np.random.randint(1, 600)): model1_samp = model1.sample( From f710e8d0295980cf5bcb0d13953c785a48991cfc Mon Sep 17 00:00:00 2001 From: damonbayer Date: Wed, 10 Jul 2024 21:38:50 -0500 Subject: [PATCH 36/47] cleanup variable names --- .../tutorials/hospital_admissions_model.qmd | 65 ++++++++++--------- 1 file changed, 34 insertions(+), 31 deletions(-) diff --git a/docs/source/tutorials/hospital_admissions_model.qmd b/docs/source/tutorials/hospital_admissions_model.qmd index c56c5c6d..649740a4 100644 --- a/docs/source/tutorials/hospital_admissions_model.qmd +++ b/docs/source/tutorials/hospital_admissions_model.qmd @@ -81,11 +81,12 @@ Let's take a look at the daily prevalence of hospital admissions. # | fig-cap: Daily hospital admissions from the simulated data import matplotlib.pyplot as plt +daily_hosp_admits = dat["daily_hosp_admits"].to_numpy() # Rotating the x-axis labels, and only showing ~10 labels ax = plt.gca() ax.xaxis.set_major_locator(plt.MaxNLocator(nbins=10)) ax.xaxis.set_tick_params(rotation=45) -plt.plot(dat["date"].to_numpy(), dat["daily_hosp_admits"].to_numpy()) +plt.plot(dat["date"].to_numpy(), daily_hosp_admits) plt.xlabel("Date") plt.ylabel("Admissions") plt.show() @@ -147,7 +148,10 @@ The `inf_hosp_int` is a `DeterministicPMF` object that takes the infection to ho ```{python} # | label: initializing-rest-of-model from pyrenew import model, process, observation, metaclass, transformation -from pyrenew.latent import InfectionSeedingProcess, SeedInfectionsExponentialGrowth +from pyrenew.latent import ( + InfectionSeedingProcess, + SeedInfectionsExponentialGrowth, +) # Infection process latent_inf = latent.Infections() @@ -237,9 +241,7 @@ npro.set_host_device_count(jax.local_device_count()) hosp_model.run( num_samples=1000, num_warmup=1000, - data_observed_hosp_admissions=dat["daily_hosp_admits"] - .to_numpy() - .astype(float), + data_observed_hosp_admissions=daily_hosp_admits, rng_key=jax.random.PRNGKey(54), mcmc_args=dict( progress_bar=False, num_chains=2, chain_method="sequential" @@ -258,7 +260,7 @@ out = hosp_model.plot_posterior( var="latent_hospital_admissions", ylab="Hospital Admissions", obs_signal=np.pad( - dat["daily_hosp_admits"].to_numpy().astype(float), + daily_hosp_admits.astype(float), (gen_int_array.size, 0), constant_values=np.nan, ), @@ -274,10 +276,10 @@ import arviz as az idata = az.from_numpyro( hosp_model.mcmc, posterior_predictive=hosp_model.posterior_predictive( - n_timepoints_to_simulate=len(dat["daily_hosp_admits"]) + n_timepoints_to_simulate=len(daily_hosp_admits) ), prior=hosp_model.prior_predictive( - n_timepoints_to_simulate=len(dat["daily_hosp_admits"]), + n_timepoints_to_simulate=len(daily_hosp_admits), numpyro_predictive_args={"num_samples": 1000}, ), ) @@ -312,15 +314,15 @@ We can use the padding argument to solve the overestimation of hospital admissio ```{python} # | label: model-fit-padding -days_to_impute = 21 +pad_size = 21 hosp_model.run( num_samples=1000, num_warmup=1000, - data_observed_hosp_admissions=dat["daily_hosp_admits"].to_numpy(), + data_observed_hosp_admissions=daily_hosp_admits, rng_key=jax.random.PRNGKey(54), mcmc_args=dict(progress_bar=False, num_chains=2), - padding=days_to_impute, # Padding the model + padding=pad_size, # Padding the model ) ``` @@ -333,8 +335,8 @@ out = hosp_model.plot_posterior( var="latent_hospital_admissions", ylab="Hospital Admissions", obs_signal=np.pad( - dat["daily_hosp_admits"].to_numpy().astype(float), - (gen_int_array.size + days_to_impute, 0), + daily_hosp_admits.astype(float), + (gen_int_array.size + pad_size, 0), constant_values=np.nan, ), ) @@ -406,7 +408,9 @@ We can look at individual draws from the posterior distribution of latent infect ```{python} # | label: fig-output-infections-with-padding # | fig-cap: Latent infections -out2 = hosp_model.plot_posterior(var="all_latent_infections", ylab="Latent Infections") +out2 = hosp_model.plot_posterior( + var="all_latent_infections", ylab="Latent Infections" +) ``` We can also look at credible intervals for the posterior distribution of latent infections: @@ -521,10 +525,10 @@ Running the model (with the same padding as before): hosp_model_weekday.run( num_samples=2000, num_warmup=2000, - data_observed_hosp_admissions=dat["daily_hosp_admits"].to_numpy(), + data_observed_hosp_admissions=daily_hosp_admits, rng_key=jax.random.PRNGKey(54), mcmc_args=dict(progress_bar=False), - padding=days_to_impute, + padding=pad_size, ) ``` @@ -537,8 +541,8 @@ out = hosp_model_weekday.plot_posterior( var="latent_hospital_admissions", ylab="Hospital Admissions", obs_signal=np.pad( - dat["daily_hosp_admits"].to_numpy().astype(float), - (gen_int_array.size + days_to_impute, 0), + daily_hosp_admits.astype(float), + (gen_int_array.size + pad_size, 0), constant_values=np.nan, ), ) @@ -549,15 +553,16 @@ By increasing `n_timepoints_to_simulate`, we can perform forecasting. ```{python} # | label: posterior-predictive-distribution +n_forecast_points = 28 idata_weekday = az.from_numpyro( hosp_model_weekday.mcmc, posterior_predictive=hosp_model_weekday.posterior_predictive( - n_timepoints_to_simulate=len(dat["daily_hosp_admits"].to_numpy()) + 28, - padding=days_to_impute, + n_timepoints_to_simulate=len(daily_hosp_admits) + n_forecast_points, + padding=pad_size, ), prior=hosp_model_weekday.prior_predictive( - n_timepoints_to_simulate=len(dat["daily_hosp_admits"].to_numpy()), - padding=days_to_impute, + n_timepoints_to_simulate=len(daily_hosp_admits), + padding=pad_size, numpyro_predictive_args={"num_samples": 1000}, ), ) @@ -577,9 +582,7 @@ def compute_eti(dataset, eti_prob): fig, axes = plt.subplots(figsize=(6, 5)) az.plot_hdi( - idata_weekday.prior_predictive["negbinom_rv_dim_0"] - + days_to_impute - + gen_int.size(), + idata_weekday.prior_predictive["negbinom_rv_dim_0"] + +gen_int.size(), hdi_data=compute_eti(idata_weekday.prior_predictive["negbinom_rv"], 0.9), color="C0", smooth=False, @@ -589,7 +592,7 @@ az.plot_hdi( az.plot_hdi( idata_weekday.prior_predictive["negbinom_rv_dim_0"] - + days_to_impute + + pad_size + gen_int.size(), hdi_data=compute_eti(idata_weekday.prior_predictive["negbinom_rv"], 0.5), color="C0", @@ -600,7 +603,7 @@ az.plot_hdi( plt.scatter( idata_weekday.observed_data["negbinom_rv_dim_0"] - + days_to_impute + + pad_size + gen_int.size(), idata_weekday.observed_data["negbinom_rv"], color="black", @@ -620,7 +623,7 @@ And now we plot the posterior predictive distributions: fig, axes = plt.subplots(figsize=(6, 5)) az.plot_hdi( idata_weekday.posterior_predictive["negbinom_rv_dim_0"] - + days_to_impute + + pad_size + gen_int.size(), hdi_data=compute_eti( idata_weekday.posterior_predictive["negbinom_rv"], 0.9 @@ -633,7 +636,7 @@ az.plot_hdi( az.plot_hdi( idata_weekday.posterior_predictive["negbinom_rv_dim_0"] - + days_to_impute + + pad_size + gen_int.size(), hdi_data=compute_eti( idata_weekday.posterior_predictive["negbinom_rv"], 0.5 @@ -651,7 +654,7 @@ mean_latent_infection = np.mean( plt.plot( idata_weekday.posterior_predictive["negbinom_rv_dim_0"] - + days_to_impute + + pad_size + gen_int.size(), mean_latent_infection[0], color="C0", @@ -659,7 +662,7 @@ plt.plot( ) plt.scatter( idata_weekday.observed_data["negbinom_rv_dim_0"] - + days_to_impute + + pad_size + gen_int.size(), idata_weekday.observed_data["negbinom_rv"], color="black", From 5dff94b69ee8a5ddd83c1b9a892a57ba5d9cc8b2 Mon Sep 17 00:00:00 2001 From: damonbayer Date: Wed, 10 Jul 2024 21:42:02 -0500 Subject: [PATCH 37/47] fix figure --- docs/source/tutorials/hospital_admissions_model.qmd | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/docs/source/tutorials/hospital_admissions_model.qmd b/docs/source/tutorials/hospital_admissions_model.qmd index 649740a4..85aa5cb7 100644 --- a/docs/source/tutorials/hospital_admissions_model.qmd +++ b/docs/source/tutorials/hospital_admissions_model.qmd @@ -582,7 +582,9 @@ def compute_eti(dataset, eti_prob): fig, axes = plt.subplots(figsize=(6, 5)) az.plot_hdi( - idata_weekday.prior_predictive["negbinom_rv_dim_0"] + +gen_int.size(), + idata_weekday.prior_predictive["negbinom_rv_dim_0"] + + pad_size + + gen_int.size(), hdi_data=compute_eti(idata_weekday.prior_predictive["negbinom_rv"], 0.9), color="C0", smooth=False, From b0df76e272216503cbfb36ac88971cc49a4e5ea8 Mon Sep 17 00:00:00 2001 From: damonbayer Date: Wed, 10 Jul 2024 21:45:45 -0500 Subject: [PATCH 38/47] delete example_with_datasets --- .../tutorials/example_with_datasets.qmd | 672 ------------------ 1 file changed, 672 deletions(-) delete mode 100644 docs/source/tutorials/example_with_datasets.qmd diff --git a/docs/source/tutorials/example_with_datasets.qmd b/docs/source/tutorials/example_with_datasets.qmd deleted file mode 100644 index c56c5c6d..00000000 --- a/docs/source/tutorials/example_with_datasets.qmd +++ /dev/null @@ -1,672 +0,0 @@ ---- -title: Fitting a hospital admissions-only model -format: gfm -engine: jupyter ---- - -This document illustrates how a hospital admissions-only model can be fitted using data from the Pyrenew package, particularly the wastewater dataset. The CFA wastewater team created this dataset, which contains simulated data. - -## Model definition - -In this section, we provide the formal definition of the model. The hospitalization model is a semi-mechanistic model that describes the number of observed hospital admissions as a function of a set of latent variables. Mainly, the observed number of hospital admissions is discretely distributed with location at the number of latent hospital admissions: - -$$ -h(t) \sim \text{HospDist}\left(H(t)\right) -$$ - -Where $h(t)$ is the observed number of hospital admissions at time $t$, and $H(t)$ is the number of latent hospital admissions at time $t$. The distribution $\text{HospDist}$ is discrete. For this example, we will use a negative binomial distribution: - -$$ -\begin{align*} -h(t) & \sim \text{NegativeBinomial}\left(\text{concentration} = 1, \text{mean} = H(t)\right) \\ -H(t) & = \omega(t) p_\mathrm{hosp}(t) \sum_{\tau = 0}^{T_d} d(\tau) I(t-\tau) -\end{align*} -$$ - -Were $d(\tau)$ is the infection to hospitalization interval, $I(t)$ is the number of latent infections at time $t$, $p_\mathrm{hosp}(t)$ is the infection to hospitalization rate, and $\omega(t)$ is the day-of-the-week effect at time $t$; the last section provides an example building such a `RandomVariable`. - -The number of latent hospital admissions at time $t$ is a function of the number of latent infections at time $t$ and the infection to hospitalization rate. The latent infections are modeled as a renewal process: - -$$ -\begin{align*} -I(t) &= R(t) \times \sum_{\tau < t} I(\tau) g(t - \tau) \\ -I(0) &\sim \text{LogNormal}(\mu = \log(80/0.05), \sigma = 1.5) -\end{align*} -$$ - -The reproductive number $R(t)$ is modeled as a random walk process: - -$$ -\begin{align*} -R(t) & = R(t-1) + \epsilon\\ -\log{\epsilon} & \sim \text{Normal}(\mu=0, \sigma=0.1) \\ -R(0) &\sim \text{TruncatedNormal}(\text{loc}=1.2, \text{scale}=0.2, \text{min}=0) -\end{align*} -$$ - - -## Data processing - -We start by loading the data and inspecting the first five rows. - -```{python} -# | label: data-inspect -import polars as pl -from pyrenew import datasets - -dat = datasets.load_wastewater() -dat.head(5) -``` - -The data shows one entry per site, but the way it was simulated, the number of admissions is the same across sites. Thus, we will only keep the first observation per day. - -```{python} -# | label: aggregation -# Keeping the first observation of each date -dat = dat.group_by("date").first().select(["date", "daily_hosp_admits"]) - -# Now, sorting by date -dat = dat.sort("date") - -# Keeping the first 90 days -dat = dat.head(90) - -dat.head(5) -``` - -Let's take a look at the daily prevalence of hospital admissions. - -```{python} -# | label: fig-plot-hospital-admissions -# | fig-cap: Daily hospital admissions from the simulated data -import matplotlib.pyplot as plt - -# Rotating the x-axis labels, and only showing ~10 labels -ax = plt.gca() -ax.xaxis.set_major_locator(plt.MaxNLocator(nbins=10)) -ax.xaxis.set_tick_params(rotation=45) -plt.plot(dat["date"].to_numpy(), dat["daily_hosp_admits"].to_numpy()) -plt.xlabel("Date") -plt.ylabel("Admissions") -plt.show() -``` - -## Building the model - -First, we will extract two datasets we will use as deterministic quantities: the generation interval and the infection to hospitalization interval. - -```{python} -# | label: fig-data-extract -# | fig-cap: Generation interval and infection to hospitalization interval -gen_int = datasets.load_generation_interval() -inf_hosp_int = datasets.load_infection_admission_interval() - -# We only need the probability_mass column of each dataset -gen_int_array = gen_int["probability_mass"].to_numpy() -gen_int = gen_int_array -inf_hosp_int = inf_hosp_int["probability_mass"].to_numpy() - -# Taking a pick at the first 5 elements of each -gen_int[:5], inf_hosp_int[:5] - -# Visualizing both quantities side by side -fig, axs = plt.subplots(1, 2) - -axs[0].plot(gen_int) -axs[0].set_title("Generation interval") -axs[1].plot(inf_hosp_int) -axs[1].set_title("Infection to hospitalization interval") -plt.show() -``` - -With these two in hand, we can start building the model. First, we will define the latent hospital admissions: - -```{python} -# | label: latent-hosp -from pyrenew import latent, deterministic, metaclass -import jax.numpy as jnp -import numpyro.distributions as dist - -inf_hosp_int = deterministic.DeterministicPMF( - inf_hosp_int, name="inf_hosp_int" -) - -hosp_rate = metaclass.DistributionalRV( - dist=dist.LogNormal(jnp.log(0.05), 0.1), - name="IHR", -) - -latent_hosp = latent.HospitalAdmissions( - infection_to_admission_interval_rv=inf_hosp_int, - infect_hosp_rate_rv=hosp_rate, -) -``` - -The `inf_hosp_int` is a `DeterministicPMF` object that takes the infection to hospitalization interval as input. The `hosp_rate` is a `DistributionalRV` object that takes a numpyro distribution to represent the infection to hospitalization rate. The `HospitalAdmissions` class is a `RandomVariable` that takes two distributions as inputs: the infection to admission interval and the infection to hospitalization rate. Now, we can define the rest of the other components: - -```{python} -# | label: initializing-rest-of-model -from pyrenew import model, process, observation, metaclass, transformation -from pyrenew.latent import InfectionSeedingProcess, SeedInfectionsExponentialGrowth - -# Infection process -latent_inf = latent.Infections() -I0 = InfectionSeedingProcess( - "I0_seeding", - metaclass.DistributionalRV( - dist=dist.LogNormal(loc=jnp.log(100), scale=0.5), name="I0" - ), - SeedInfectionsExponentialGrowth( - gen_int_array.size, - deterministic.DeterministicVariable(0.5, name="rate"), - ), - t_unit=1, -) - -# Generation interval and Rt -gen_int = deterministic.DeterministicPMF(gen_int, name="gen_int") -rtproc = process.RtRandomWalkProcess( - Rt0_dist=dist.TruncatedNormal(loc=1.2, scale=0.2, low=0), - Rt_transform=transformation.ExpTransform().inv, - Rt_rw_dist=dist.Normal(0, 0.025), -) - -# The observation model -obs = observation.NegativeBinomialObservation(concentration_prior=1.0) -``` - -Notice all the components are `RandomVariable` instances. We can now build the model: - -```{python} -# | label: init-model -hosp_model = model.HospitalAdmissionsModel( - latent_infections_rv=latent_inf, - latent_hosp_admissions_rv=latent_hosp, - I0_rv=I0, - gen_int_rv=gen_int, - Rt_process_rv=rtproc, - hosp_admission_obs_process_rv=obs, -) -``` - -Let's simulate to check if the model is working: - -```{python} -# | label: simulation -import numpyro as npro -import numpy as np - -timeframe = 120 - -np.random.seed(223) -with npro.handlers.seed(rng_seed=np.random.randint(1, timeframe)): - sim_data = hosp_model.sample(n_timepoints_to_simulate=timeframe) -``` - -```{python} -# | label: fig-basic -# | fig-cap: Rt and Infections -import matplotlib.pyplot as plt - -fig, axs = plt.subplots(1, 2) - -# Rt plot -axs[0].plot(sim_data.Rt) -axs[0].set_ylabel("Rt") - -# Infections plot -axs[1].plot(sim_data.observed_hosp_admissions) -axs[1].set_ylabel("Infections") -axs[1].set_yscale("log") - -fig.suptitle("Basic renewal model") -fig.supxlabel("Time") -plt.tight_layout() -plt.show() -``` - -## Fitting the model - -We can fit the model to the data. We will use the `run` method of the model object: - -```{python} -# | label: model-fit -import jax - -npro.set_host_device_count(jax.local_device_count()) -hosp_model.run( - num_samples=1000, - num_warmup=1000, - data_observed_hosp_admissions=dat["daily_hosp_admits"] - .to_numpy() - .astype(float), - rng_key=jax.random.PRNGKey(54), - mcmc_args=dict( - progress_bar=False, num_chains=2, chain_method="sequential" - ), -) -``` - -We can use the `plot_posterior` method to visualize the results[^capture]: - -[^capture]: The output is captured to avoid `quarto` from displaying the output twice. - -```{python} -# | label: fig-output-hospital-admissions -# | fig-cap: Hospital Admissions posterior distribution -out = hosp_model.plot_posterior( - var="latent_hospital_admissions", - ylab="Hospital Admissions", - obs_signal=np.pad( - dat["daily_hosp_admits"].to_numpy().astype(float), - (gen_int_array.size, 0), - constant_values=np.nan, - ), -) -``` - -Test `posterior_predictive` and `prior_predictive` methods to generate posterior and prior predictive samples. - -```{python} -# | label: demonstrate-use-of-predictive-methods -import arviz as az - -idata = az.from_numpyro( - hosp_model.mcmc, - posterior_predictive=hosp_model.posterior_predictive( - n_timepoints_to_simulate=len(dat["daily_hosp_admits"]) - ), - prior=hosp_model.prior_predictive( - n_timepoints_to_simulate=len(dat["daily_hosp_admits"]), - numpyro_predictive_args={"num_samples": 1000}, - ), -) -``` - -We will use `plot_lm` method from ArviZ to plot the posterior predictive distribution and observed data below: - -```{python} -# | label: fig-posterior-predictive -# | fig-cap: Hospital Admissions posterior distribution with plot_lm -fig, ax = plt.subplots() -az.plot_lm( - "negbinom_rv", - idata=idata, - kind_pp="hdi", - y_kwargs={"color": "black"}, - y_hat_fill_kwargs={"color": "C0"}, - axes=ax, -) - -ax.set_title("Posterior Predictive Plot") -ax.set_ylabel("Hospital Admissions") -ax.set_xlabel("Days") -plt.show() -``` - -The first half of the model is not looking good. The reason is that the infection to hospitalization interval PMF makes it unlikely to observe admissions from the beginning. The following section shows how to fix this. - -## Padding the model - -We can use the padding argument to solve the overestimation of hospital admissions in the first half of the model. By setting `padding > 0`, the model then assumes that the first `padding` observations are missing; thus, only observations after `padding` will count towards the likelihood of the model. In practice, the model will extend the estimated Rt and latent infections by `padding` days, given time to adjust to the observed data. The following code will add 21 days of missing data at the beginning of the model and re-estimate it with `padding = 21`: - -```{python} -# | label: model-fit-padding -days_to_impute = 21 - -hosp_model.run( - num_samples=1000, - num_warmup=1000, - data_observed_hosp_admissions=dat["daily_hosp_admits"].to_numpy(), - rng_key=jax.random.PRNGKey(54), - mcmc_args=dict(progress_bar=False, num_chains=2), - padding=days_to_impute, # Padding the model -) -``` - -And plotting the results: - -```{python} -# | label: fig-output-admissions-with-padding -# | fig-cap: Hospital Admissions -out = hosp_model.plot_posterior( - var="latent_hospital_admissions", - ylab="Hospital Admissions", - obs_signal=np.pad( - dat["daily_hosp_admits"].to_numpy().astype(float), - (gen_int_array.size + days_to_impute, 0), - constant_values=np.nan, - ), -) -``` - -We can use [ArviZ](https://www.arviz.org/) to visualize the results. Let's start by converting the fitted model to Arviz InferenceData object: - -```{python} -# | label: convert-inferenceData -idata = az.from_numpyro(hosp_model.mcmc) -``` - -We obtain the summary of model diagnostics and print the diagnostics for `latent_hospital_admissions[1]` - -```{python} -# | label: diagnostics -# | warning: false -diagnostic_stats_summary = az.summary( - idata.posterior, - kind="diagnostics", -) - -print(diagnostic_stats_summary.loc["latent_hospital_admissions[1]"]) -``` - -Below we plot 90% and 50% highest density intervals for latent hospital admissions using [plot_hdi](https://python.arviz.org/en/stable/api/generated/arviz.plot_hdi.html): - -```{python} -# | label: fig-output-admission-distribution -# | fig-cap: Hospital Admissions posterior distribution -x_data = idata.posterior["latent_hospital_admissions_dim_0"] -y_data = idata.posterior["latent_hospital_admissions"] - -fig, axes = plt.subplots(figsize=(6, 5)) -az.plot_hdi( - x_data, - y_data, - hdi_prob=0.9, - color="C0", - smooth=False, - fill_kwargs={"alpha": 0.3}, - ax=axes, -) - -az.plot_hdi( - x_data, - y_data, - hdi_prob=0.5, - color="C0", - smooth=False, - fill_kwargs={"alpha": 0.6}, - ax=axes, -) - -# Add mean of the posterior to the figure -mean_latent_hosp_admission = np.mean( - idata.posterior["latent_hospital_admissions"], axis=1 -) -axes.plot(x_data, mean_latent_hosp_admission[0], color="C0", label="Mean") -axes.legend() -axes.set_title("Posterior Hospital Admissions", fontsize=10) -axes.set_xlabel("Time", fontsize=10) -axes.set_ylabel("Hospital Admissions", fontsize=10) -plt.show() -``` - -We can look at individual draws from the posterior distribution of latent infections: - -```{python} -# | label: fig-output-infections-with-padding -# | fig-cap: Latent infections -out2 = hosp_model.plot_posterior(var="all_latent_infections", ylab="Latent Infections") -``` - -We can also look at credible intervals for the posterior distribution of latent infections: - -```{python} -# | label: fig-output-infections-distribution -# | fig-cap: Posterior Latent Infections -x_data = idata.posterior["all_latent_infections_dim_0"] -y_data = idata.posterior["all_latent_infections"] - -fig, axes = plt.subplots(figsize=(6, 5)) -az.plot_hdi( - x_data, - y_data, - hdi_prob=0.9, - color="C0", - smooth=False, - fill_kwargs={"alpha": 0.3}, - ax=axes, -) - -az.plot_hdi( - x_data, - y_data, - hdi_prob=0.5, - color="C0", - smooth=False, - fill_kwargs={"alpha": 0.6}, - ax=axes, -) - -# Add mean of the posterior to the figure -mean_latent_infection = np.mean( - idata.posterior["all_latent_infections"], axis=1 -) -axes.plot(x_data, mean_latent_infection[0], color="C0", label="Mean") -axes.legend() -axes.set_title("Posterior Latent Infections", fontsize=10) -axes.set_xlabel("Time", fontsize=10) -axes.set_ylabel("Latent Infections", fontsize=10) -plt.show() -``` - -## Round 2: Incorporating day-of-the-week effects - -We will re-use the infection to admission interval and infection to hospitalization rate from the previous model. But we will also add a day-of-the-week effect distribution. To do this, we will create a new instance of `RandomVariable` to model the effect. The class will be based on a truncated normal distribution with a mean of 1.0 and a standard deviation of 0.5. The distribution will be truncated between 0.1 and 10.0. The random variable will be repeated for the number of weeks in the dataset. -Note a similar weekday effect is implemented in its own module, with example code [here](periodic_effects.html). - -```{python} -# | label: weekly-effect -from pyrenew import metaclass -import numpyro as npro - - -class DayOfWeekEffect(metaclass.RandomVariable): - """Day of the week effect""" - - def __init__(self, len: int): - """Initialize the day of the week effect distribution - Parameters - ---------- - len : int - The number of observations - """ - self.nweeks = int(jnp.ceil(len / 7)) - self.len = len - - @staticmethod - def validate(): - return None - - def sample(self, **kwargs): - ans = npro.sample( - name="dayofweek_effect", - fn=npro.distributions.TruncatedNormal( - loc=1.0, scale=0.5, low=0.1, high=10.0 - ), - sample_shape=(7,), - ) - - return jnp.tile(ans, self.nweeks)[: self.len] - - -# Initializing the RV -dayofweek_effect = DayOfWeekEffect(dat.shape[0]) -``` - -Notice that the instance's `nweeks` and `len` members are passed during construction. Trying to compute the number of weeks and the length of the dataset in the `validate` method will raise a `jit` error in `jax` as the shape and size of elements are not known during the validation step, which happens before the model is run. With the new effect, we can rebuild the latent hospitalization model: - -```{python} -# | label: latent-hosp-weekday -latent_hosp_wday_effect = latent.HospitalAdmissions( - infection_to_admission_interval_rv=inf_hosp_int, - infect_hosp_rate_rv=hosp_rate, - day_of_week_effect_rv=dayofweek_effect, -) - -hosp_model_weekday = model.HospitalAdmissionsModel( - latent_infections_rv=latent_inf, - latent_hosp_admissions_rv=latent_hosp_wday_effect, - I0_rv=I0, - gen_int_rv=gen_int, - Rt_process_rv=rtproc, - hosp_admission_obs_process_rv=obs, -) -``` - -Running the model (with the same padding as before): - -```{python} -# | label: model-2-run -hosp_model_weekday.run( - num_samples=2000, - num_warmup=2000, - data_observed_hosp_admissions=dat["daily_hosp_admits"].to_numpy(), - rng_key=jax.random.PRNGKey(54), - mcmc_args=dict(progress_bar=False), - padding=days_to_impute, -) -``` - -And plotting the results: - -```{python} -# | label: fig-output-admissions-padding-and-weekday -# | fig-cap: Hospital Admissions posterior distribution -out = hosp_model_weekday.plot_posterior( - var="latent_hospital_admissions", - ylab="Hospital Admissions", - obs_signal=np.pad( - dat["daily_hosp_admits"].to_numpy().astype(float), - (gen_int_array.size + days_to_impute, 0), - constant_values=np.nan, - ), -) -``` - -We will use ArviZ to visualize the posterior/prior predictive distributions. -By increasing `n_timepoints_to_simulate`, we can perform forecasting. - -```{python} -# | label: posterior-predictive-distribution -idata_weekday = az.from_numpyro( - hosp_model_weekday.mcmc, - posterior_predictive=hosp_model_weekday.posterior_predictive( - n_timepoints_to_simulate=len(dat["daily_hosp_admits"].to_numpy()) + 28, - padding=days_to_impute, - ), - prior=hosp_model_weekday.prior_predictive( - n_timepoints_to_simulate=len(dat["daily_hosp_admits"].to_numpy()), - padding=days_to_impute, - numpyro_predictive_args={"num_samples": 1000}, - ), -) -``` - -Below we plot the prior predictive distributions using equal tailed bayesian credible intervals: - -```{python} -# | label: fig-output-prior-predictive -# | fig-cap: Prior Predictive Infections -def compute_eti(dataset, eti_prob): - eti_bdry = dataset.quantile( - ((1 - eti_prob) / 2, 1 / 2 + eti_prob / 2), dim=("chain", "draw") - ) - return eti_bdry.values.T - - -fig, axes = plt.subplots(figsize=(6, 5)) -az.plot_hdi( - idata_weekday.prior_predictive["negbinom_rv_dim_0"] - + days_to_impute - + gen_int.size(), - hdi_data=compute_eti(idata_weekday.prior_predictive["negbinom_rv"], 0.9), - color="C0", - smooth=False, - fill_kwargs={"alpha": 0.3}, - ax=axes, -) - -az.plot_hdi( - idata_weekday.prior_predictive["negbinom_rv_dim_0"] - + days_to_impute - + gen_int.size(), - hdi_data=compute_eti(idata_weekday.prior_predictive["negbinom_rv"], 0.5), - color="C0", - smooth=False, - fill_kwargs={"alpha": 0.6}, - ax=axes, -) - -plt.scatter( - idata_weekday.observed_data["negbinom_rv_dim_0"] - + days_to_impute - + gen_int.size(), - idata_weekday.observed_data["negbinom_rv"], - color="black", -) - -axes.set_title("Prior Predictive Infections", fontsize=10) -axes.set_xlabel("Time", fontsize=10) -axes.set_ylabel("Observed Infections", fontsize=10) -plt.yscale("log") -plt.show() -``` - -And now we plot the posterior predictive distributions: -```{python} -# | label: fig-output-posterior-predictive -# | fig-cap: Posterior Predictive Infections -fig, axes = plt.subplots(figsize=(6, 5)) -az.plot_hdi( - idata_weekday.posterior_predictive["negbinom_rv_dim_0"] - + days_to_impute - + gen_int.size(), - hdi_data=compute_eti( - idata_weekday.posterior_predictive["negbinom_rv"], 0.9 - ), - color="C0", - smooth=False, - fill_kwargs={"alpha": 0.3}, - ax=axes, -) - -az.plot_hdi( - idata_weekday.posterior_predictive["negbinom_rv_dim_0"] - + days_to_impute - + gen_int.size(), - hdi_data=compute_eti( - idata_weekday.posterior_predictive["negbinom_rv"], 0.5 - ), - color="C0", - smooth=False, - fill_kwargs={"alpha": 0.6}, - ax=axes, -) - -# Add mean of the posterior to the figure -mean_latent_infection = np.mean( - idata_weekday.posterior_predictive["negbinom_rv"], axis=1 -) - -plt.plot( - idata_weekday.posterior_predictive["negbinom_rv_dim_0"] - + days_to_impute - + gen_int.size(), - mean_latent_infection[0], - color="C0", - label="Mean", -) -plt.scatter( - idata_weekday.observed_data["negbinom_rv_dim_0"] - + days_to_impute - + gen_int.size(), - idata_weekday.observed_data["negbinom_rv"], - color="black", -) -axes.legend() -axes.set_title("Posterior Predictive Infections", fontsize=10) -axes.set_xlabel("Time", fontsize=10) -axes.set_ylabel("Observed Infections", fontsize=10) -plt.show() -``` From c943309c35e692293f520417e32bd7d999397051 Mon Sep 17 00:00:00 2001 From: damonbayer Date: Thu, 11 Jul 2024 09:59:25 -0500 Subject: [PATCH 39/47] remove unused code --- model/src/pyrenew/model/admissionsmodel.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/model/src/pyrenew/model/admissionsmodel.py b/model/src/pyrenew/model/admissionsmodel.py index f3ef32e7..caf5bbdb 100644 --- a/model/src/pyrenew/model/admissionsmodel.py +++ b/model/src/pyrenew/model/admissionsmodel.py @@ -201,8 +201,6 @@ def sample( latent_infections=basic_model.latent_infections, **kwargs, ) - if self.hosp_admission_obs_process_rv is None: - observed_hosp_admissions = None ( observed_hosp_admissions, From bddc731ac171e679065e0ca01d9f63f1bb487c9e Mon Sep 17 00:00:00 2001 From: damonbayer Date: Thu, 11 Jul 2024 10:30:29 -0500 Subject: [PATCH 40/47] rename init in test_random_walk --- model/src/test/test_random_walk.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/model/src/test/test_random_walk.py b/model/src/test/test_random_walk.py index 9f8c6839..bd56d910 100755 --- a/model/src/test/test_random_walk.py +++ b/model/src/test/test_random_walk.py @@ -35,9 +35,9 @@ def test_rw_samples_correctly_distributed(): [0, 2.253, -3.2521, 1052, 1e-6], [1, 0.025, 3, 1, 0.02] ): rw_normal = SimpleRandomWalkProcess(dist.Normal(step_mean, step_sd)) - init = 532.0 + rw_init = 532.0 with numpyro.handlers.seed(rng_seed=62): - samples, *_ = rw_normal.sample(n_samples, init=init) + samples, *_ = rw_normal.sample(n_samples, init=rw_init) # Checking the shape assert samples.shape == (n_samples,) @@ -60,4 +60,4 @@ def test_rw_samples_correctly_distributed(): assert jnp.abs(jnp.log(jnp.std(diffs) / step_sd)) < jnp.log(1.1) # first value should be the init value - assert_almost_equal(samples[0], init) + assert_almost_equal(samples[0], rw_init) From a84ca20e4e1183033df3c65e0c3b061ad2f410e7 Mon Sep 17 00:00:00 2001 From: damonbayer Date: Thu, 11 Jul 2024 15:33:52 -0500 Subject: [PATCH 41/47] clarify forecasting --- docs/source/tutorials/hospital_admissions_model.qmd | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/source/tutorials/hospital_admissions_model.qmd b/docs/source/tutorials/hospital_admissions_model.qmd index 85aa5cb7..44b4f8bd 100644 --- a/docs/source/tutorials/hospital_admissions_model.qmd +++ b/docs/source/tutorials/hospital_admissions_model.qmd @@ -548,8 +548,8 @@ out = hosp_model_weekday.plot_posterior( ) ``` -We will use ArviZ to visualize the posterior/prior predictive distributions. -By increasing `n_timepoints_to_simulate`, we can perform forecasting. +We will use ArviZ to visualize the posterior and prior predictive distributions. +By increasing `n_timepoints_to_simulate`, we can perform forecasting in the posterior predictive. ```{python} # | label: posterior-predictive-distribution @@ -618,7 +618,7 @@ plt.yscale("log") plt.show() ``` -And now we plot the posterior predictive distributions: +And now we plot the posterior predictive distributions with a `{python} n_forecast_points`-day-ahead forecast: ```{python} # | label: fig-output-posterior-predictive # | fig-cap: Posterior Predictive Infections From 85c62c11165167a7469a5f732aeb7266af4e3a75 Mon Sep 17 00:00:00 2001 From: damonbayer Date: Thu, 11 Jul 2024 18:09:50 -0500 Subject: [PATCH 42/47] update numpyro and jax --- docs/pyproject.toml | 6 +++--- pyproject.toml | 1 + 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/docs/pyproject.toml b/docs/pyproject.toml index 81ffd5a2..00d54396 100644 --- a/docs/pyproject.toml +++ b/docs/pyproject.toml @@ -9,9 +9,9 @@ package-mode = false [tool.poetry.dependencies] python = "^3.12" sphinx = "^7.2.6" -jax = "0.4.29" -jaxlib = "^0.4.29" -numpyro = "^0.15.0" +jax = "^0.4.30" +jaxlib = "^0.4.30" +numpyro = "^0.15.1" sphinxcontrib-mermaid = "^0.9.2" polars = "^0.20.16" matplotlib = "^3.8.3" diff --git a/pyproject.toml b/pyproject.toml index 27ab0504..396c63ac 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,6 +9,7 @@ packages = [{include = "multisignal_epi_inference"}] [tool.poetry.dependencies] python = "^3.12" +numpyro = "^0.15.1" [tool.poetry.group.dev] optional = true From b360093db34eb9e99392c595f158187188413d26 Mon Sep 17 00:00:00 2001 From: damonbayer Date: Thu, 11 Jul 2024 18:10:45 -0500 Subject: [PATCH 43/47] update numpyro and jax --- model/pyproject.toml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/model/pyproject.toml b/model/pyproject.toml index f7656dd8..ad8c429e 100755 --- a/model/pyproject.toml +++ b/model/pyproject.toml @@ -11,9 +11,9 @@ exclude = [{path = "datasets/*.rds"}] [tool.poetry.dependencies] python = "^3.12" -numpyro = "^0.15.0" -jax = "0.4.29" -jaxlib = "0.4.29" +numpyro = "^0.15.1" +jax = "^0.4.30" +jaxlib = "^0.4.30" numpy = "^1.26.4" polars = "^0.20.16" pillow = "^10.3.0" # See #56 on CDCgov/multisignal-epi-inference From a7c4673778b530ef530c63810e6a38d4fb8bd183 Mon Sep 17 00:00:00 2001 From: damonbayer Date: Thu, 11 Jul 2024 18:20:26 -0500 Subject: [PATCH 44/47] enforce minimum version of numpyro and jax --- docs/pyproject.toml | 6 +++--- model/pyproject.toml | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/docs/pyproject.toml b/docs/pyproject.toml index 00d54396..dc67ded7 100644 --- a/docs/pyproject.toml +++ b/docs/pyproject.toml @@ -9,9 +9,9 @@ package-mode = false [tool.poetry.dependencies] python = "^3.12" sphinx = "^7.2.6" -jax = "^0.4.30" -jaxlib = "^0.4.30" -numpyro = "^0.15.1" +jax = ">=0.4.30" +jaxlib = ">=0.4.30" +numpyro = ">=0.15.1" sphinxcontrib-mermaid = "^0.9.2" polars = "^0.20.16" matplotlib = "^3.8.3" diff --git a/model/pyproject.toml b/model/pyproject.toml index ad8c429e..fb164c2e 100755 --- a/model/pyproject.toml +++ b/model/pyproject.toml @@ -11,9 +11,9 @@ exclude = [{path = "datasets/*.rds"}] [tool.poetry.dependencies] python = "^3.12" -numpyro = "^0.15.1" -jax = "^0.4.30" -jaxlib = "^0.4.30" +numpyro = ">=0.15.1" +jax = ">=0.4.30" +jaxlib = ">=0.4.30" numpy = "^1.26.4" polars = "^0.20.16" pillow = "^10.3.0" # See #56 on CDCgov/multisignal-epi-inference From f6b34ab8285c15c0db438a1572e85186506fcdda Mon Sep 17 00:00:00 2001 From: damonbayer Date: Thu, 11 Jul 2024 18:35:16 -0500 Subject: [PATCH 45/47] try to enforce numpyro version again --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 396c63ac..8227e76f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,7 +9,7 @@ packages = [{include = "multisignal_epi_inference"}] [tool.poetry.dependencies] python = "^3.12" -numpyro = "^0.15.1" +numpyro = ">=0.15.1" [tool.poetry.group.dev] optional = true From 3342f24f61a698693e1998257f2517476bee584a Mon Sep 17 00:00:00 2001 From: damonbayer Date: Thu, 11 Jul 2024 18:38:28 -0500 Subject: [PATCH 46/47] add test_forecast --- model/src/test/test_forecast.py | 78 +++++++++++++++++++++++++++++++++ 1 file changed, 78 insertions(+) create mode 100644 model/src/test/test_forecast.py diff --git a/model/src/test/test_forecast.py b/model/src/test/test_forecast.py new file mode 100644 index 00000000..677755c5 --- /dev/null +++ b/model/src/test/test_forecast.py @@ -0,0 +1,78 @@ +# numpydoc ignore=GL08 + +import jax.numpy as jnp +import jax.random as jr +import numpy as np +import numpyro as npro +import numpyro.distributions as dist +import pyrenew.transformation as t +from numpy.testing import assert_array_equal +from pyrenew.deterministic import DeterministicPMF +from pyrenew.latent import ( + Infections, + InfectionSeedingProcess, + SeedInfectionsZeroPad, +) +from pyrenew.metaclass import DistributionalRV +from pyrenew.model import RtInfectionsRenewalModel +from pyrenew.observation import PoissonObservation +from pyrenew.process import RtRandomWalkProcess + + +def test_forecast(): + """Check that forecasts are the right length and match the posterior up until forecast begins.""" + pmf_array = jnp.array([0.25, 0.25, 0.25, 0.25]) + gen_int = DeterministicPMF(pmf_array, name="gen_int") + I0 = InfectionSeedingProcess( + "I0_seeding", + DistributionalRV(dist=dist.LogNormal(0, 1), name="I0"), + SeedInfectionsZeroPad(n_timepoints=gen_int.size()), + t_unit=1, + ) + latent_infections = Infections() + observed_infections = PoissonObservation() + rt = RtRandomWalkProcess( + Rt0_dist=dist.TruncatedNormal(loc=1.2, scale=0.2, low=0), + Rt_transform=t.ExpTransform().inv, + Rt_rw_dist=dist.Normal(0, 0.025), + ) + model = RtInfectionsRenewalModel( + I0_rv=I0, + gen_int_rv=gen_int, + latent_infections_rv=latent_infections, + infection_obs_process_rv=observed_infections, + Rt_process_rv=rt, + ) + + n_timepoints_to_simulate = 30 + n_forecast_points = 10 + with npro.handlers.seed(rng_seed=np.random.randint(1, 600)): + model_sample = model.sample( + n_timepoints_to_simulate=n_timepoints_to_simulate + ) + + model.run( + num_warmup=5, + num_samples=5, + data_observed_infections=model_sample.observed_infections, + rng_key=jr.key(54), + ) + + posterior_predictive_samples = model.posterior_predictive( + n_timepoints_to_simulate=n_timepoints_to_simulate + n_forecast_points, + ) + + # Check the length of the predictive distribution + assert ( + len(posterior_predictive_samples["poisson_rv"][0]) + == n_timepoints_to_simulate + n_forecast_points + ) + + # Check the first elements of the posterior predictive Rt are the same as the + # posterior Rt + assert_array_equal( + model.mcmc.get_samples()["Rt"][0], + posterior_predictive_samples["Rt"][0][ + : len(model.mcmc.get_samples()["Rt"][0]) + ], + ) From 45879b180da87fb6794b07f2b67503eb809e325d Mon Sep 17 00:00:00 2001 From: Damon Bayer Date: Thu, 11 Jul 2024 23:34:20 -0500 Subject: [PATCH 47/47] Update docs/source/tutorials/hospital_admissions_model.qmd Co-authored-by: Dylan H. Morris --- docs/source/tutorials/hospital_admissions_model.qmd | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/tutorials/hospital_admissions_model.qmd b/docs/source/tutorials/hospital_admissions_model.qmd index 44b4f8bd..f88a884e 100644 --- a/docs/source/tutorials/hospital_admissions_model.qmd +++ b/docs/source/tutorials/hospital_admissions_model.qmd @@ -549,7 +549,7 @@ out = hosp_model_weekday.plot_posterior( ``` We will use ArviZ to visualize the posterior and prior predictive distributions. -By increasing `n_timepoints_to_simulate`, we can perform forecasting in the posterior predictive. +By increasing `n_timepoints_to_simulate`, we can perform forecasting using the posterior predictive distribution. ```{python} # | label: posterior-predictive-distribution