Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Denominator Forecasting #68

Merged
merged 21 commits into from
Oct 25, 2024
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions notebooks/hosp_only_ww_model.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ We begin by loading the Stan data, converting it the correct inputs for our mode

```{python}
# | label: create model
my_hosp_only_ww_model, data_observed_hospital_admissions = (
my_hosp_only_ww_model, data_observed_disease_hospital_admissions = (
create_hosp_only_ww_model_from_stan_data(
"data/fit_hosp_only/stan_data.json"
)
Expand All @@ -50,7 +50,7 @@ We check that we can simulate from the prior predictive
n_forecast_days = 35

prior_predictive = my_hosp_only_ww_model.prior_predictive(
n_datapoints=len(data_observed_hospital_admissions) + n_forecast_days,
n_datapoints=len(data_observed_disease_hospital_admissions) + n_forecast_days,
numpyro_predictive_args={"num_samples": 200},
)
```
Expand All @@ -64,7 +64,7 @@ my_hosp_only_ww_model.run(
num_warmup=500,
num_samples=500,
rng_key=jax.random.key(200),
data_observed_hospital_admissions=data_observed_hospital_admissions,
data_observed_disease_hospital_admissions=data_observed_disease_hospital_admissions,
mcmc_args=dict(num_chains=4, progress_bar=False),
nuts_args=dict(find_heuristic_step_size=True),
)
Expand All @@ -75,7 +75,7 @@ Create the posterior predictive and forecast:
```{python}
# | label: posterior predictive
posterior_predictive = my_hosp_only_ww_model.posterior_predictive(
n_datapoints=len(data_observed_hospital_admissions) + n_forecast_days
n_datapoints=len(data_observed_disease_hospital_admissions) + n_forecast_days
)
```

Expand Down
10 changes: 7 additions & 3 deletions nssp_demo/build_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@ def build_model_from_dir(model_dir):
jnp.array(model_data["generation_interval_pmf"]),
) # check if off by 1 or reversed

data_observed_hospital_admissions = jnp.array(
model_data["data_observed_hospital_admissions"]
data_observed_disease_hospital_admissions = jnp.array(
model_data["data_observed_disease_hospital_admissions"]
)
state_pop = jnp.array(model_data["state_pop"])

Expand Down Expand Up @@ -84,4 +84,8 @@ def build_model_from_dir(model_dir):
n_initialization_points=uot,
)

return my_model, data_observed_hospital_admissions, right_truncation_offset
return (
my_model,
data_observed_disease_hospital_admissions,
right_truncation_offset,
)
10 changes: 6 additions & 4 deletions nssp_demo/fit_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,16 @@
args = parser.parse_args()
model_dir = args.model_dir

my_model, data_observed_hospital_admissions, right_truncation_offset = (
build_model_from_dir(model_dir)
)
(
my_model,
data_observed_disease_hospital_admissions,
right_truncation_offset,
) = build_model_from_dir(model_dir)
my_model.run(
num_warmup=500,
num_samples=500,
rng_key=jax.random.key(200),
data_observed_hospital_admissions=data_observed_hospital_admissions,
data_observed_disease_hospital_admissions=data_observed_disease_hospital_admissions,
right_truncation_offset=right_truncation_offset,
mcmc_args=dict(num_chains=n_chains, progress_bar=True),
nuts_args=dict(find_heuristic_step_size=True),
Expand Down
13 changes: 8 additions & 5 deletions nssp_demo/generate_predictive.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,11 @@
args = parser.parse_args()
model_dir = args.model_dir
n_forecast_points = args.n_forecast_points
my_model, data_observed_hospital_admissions, right_truncation_offset = (
build_model_from_dir(model_dir)
)
(
my_model,
data_observed_disease_hospital_admissions,
right_truncation_offset,
) = build_model_from_dir(model_dir)

my_model._init_model(1, 1)
fresh_sampler = my_model.mcmc.sampler
Expand All @@ -47,12 +49,13 @@
# "num_samples": my_model.mcmc.num_samples * my_model.mcmc.num_chains,
# "batch_ndims":1
# },
# n_datapoints=len(data_observed_hospital_admissions) + n_forecast_points,
# n_datapoints=len(data_observed_disease_hospital_admissions) + n_forecast_points,
# )
# need to figure out a way to generate these as distinct chains, so that the result of the to_datarame method is more compact

posterior_predictive = my_model.posterior_predictive(
n_datapoints=len(data_observed_hospital_admissions) + n_forecast_points
n_datapoints=len(data_observed_disease_hospital_admissions)
+ n_forecast_points
)

idata = az.from_numpyro(
Expand Down
Loading
Loading