Skip to content

Commit

Permalink
add functions that set up model specification and mcmc options
Browse files Browse the repository at this point in the history
  • Loading branch information
kaitejohnson committed Jul 3, 2024
1 parent d9bd1d7 commit 2a94332
Show file tree
Hide file tree
Showing 5 changed files with 234 additions and 67 deletions.
2 changes: 2 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ export(get_count_data_sizes)
export(get_count_indices)
export(get_count_values)
export(get_ind_m)
export(get_mcmc_options)
export(get_params)
export(get_stan_data)
export(get_subpop_data)
Expand All @@ -23,6 +24,7 @@ export(indicate_ww_exclusions)
export(make_hospital_onset_delay_pmf)
export(make_incubation_period_pmf)
export(make_reporting_delay_pmf)
export(model_spec)
export(preprocess_hosp_data)
export(preprocess_ww_data)
export(simulate_double_censored_pmf)
Expand Down
136 changes: 113 additions & 23 deletions R/wwinference.R
Original file line number Diff line number Diff line change
@@ -1,29 +1,10 @@
wwinference <- function(ww_data,
count_data,
model_spec = list(
forecast_date = "2023-12-06",
calibration_time = 90,
forecast_horizon = 28,
generation_interval =
wwinference::generation_interval,
inf_to_count_delay = wwinference::inf_to_hosp,
infection_feedback_pmf =
wwinference::generation_interval,
params = get_params(
system.file("extdata", "example_params.toml",
package = "wwinference"
)
),
# Default MCMC settings
iter_warmup = 750,
iter_sampling = 500,
n_chains = 4,
seed = 123,
adapt_delta = 0.95,
max_treedepth = 12,
# Default fitting to data
compute_likelihood = 1
model_spec = get_model_spec(
forecast_date =
"2023-12-06"
),
mcmc_options = get_mcmc_options(),
compiled_model = compile_model()) {
# Check that data is compatible with specifications
check_date(ww_data, model_spec$forecast_date)
Expand Down Expand Up @@ -103,3 +84,112 @@ wwinference <- function(ww_data,

return(out)
}

#' Get MCMC options
#'
#' @description
#' This function returns a list of MCMC settings to pass to the
#' `cmdstanr::sample()` function to fit the model. The default settings are
#' specified for production-level runs, consider adjusting to optimize
#' for speed while iterating.
#'
#'
#' @param iter_warmup integer indicating the number of warm-up iterations,
#' default is `750`
#' @param iter_sampling integer indicating the number of sampling iterations,
#' default is `500`
#' @param n_chains integer indicating the number of MCMC chains to run, default
#' is `4`
#' @param seed set of integers indicating the random seed of the stan sampler,
#' default is `123`
#' @param adapt_delta float between 0 and 1 indicating the average acceptance
#' probability, default is `0.95`
#' @param max_treedepth integer indicating the maximum tree depth of the
#' sampler, default is 12
#' @param compute_likelihood integer indicating whether or not to compute the
#' likelihood using the data, default is `1` which will fit the model to the
#' data. If set to 0, the model will sample from the prior only
#'
#' @return a list of mcmc settings with the values given by the function
#' arguments
#' @export
#'
#' @examples
#' mcmc_settings <- get_mcmc_options()
get_mcmc_options <- function(
iter_warmup = 750,
iter_sampling = 500,
n_chains = 4,
seed = 123,
adapt_delta = 0.95,
max_treedepth = 12,
compute_likelihood = 1) {
mcmc_settings <- list(
iter_warmup = iter_warmup,
iter_sampling = iter_sampling,
n_chains = n_chains,
seed = seed,
adapt_delta = adapt_delta,
max_treedepth = max_treedepth,
compute_likelihood = compute_likelihood
)

return(mcmc_settings)
}

#' Get model specificaitons
#' @description
#' This function returns a nested list containing the model specifications
#' in the function arguments. All defaults are set for the case of fitting a
#' post-omicron COVID-19 model with joint inference of hospital admissions
#' and data on wastewater viral concentrations
#'
#'
#' @param forecast_date a character string in ISO8 format (YYYY-MM-DD)
#' indicating the date that the forecast is to be made. Default is
#' @param calibration_time integer indicating the number of days to calibrate
#' the model for, default is `90`
#' @param forecast_horizon integer indicating the number of days, including the
#' forecast date, to produce forecasts for, default is `28`
#' @param generation_interval vector of a simplex (must sum to 1) describing
#' the daily probability of onwards transmission, default is package data
#' provided for the COVID-19 generation interval post-Omicron
#' @param inf_to_count_delay vector of a simplex (must sum to 1) describing the
#' daily probability of transitioning from infection to whatever the count
#' variable is, e.g. hospital admissions or cases. Default corresonds to the
#' delay distribution from COVID-19 infection to hospital admission
#' @param infection_feedback_pmf vector of a simplex (must sum to 1) describing
#' the delay from incident infection to feedback in the transmission dynamics.
#' The default is the COVID-19 generation interval
#' @param params a 1 row dataframe of parameters corresponding to model
#' priors and disease/data specific parameters. Default is for COVID-19 hospital
#' admissions and viral concentrations in wastewater
#'
#' @return a list of model specs to be passed to the `get_stan_data()` function
#' @export
#'
#' @examples
#' model_spec_list <- model_spec(forecast_date = "2023-12-06")
model_spec <- function(
forecast_date,
calibration_time = 90,
forecast_horizon = 28,
generation_interval = wwinference::generation_interval,
inf_to_count_delay = wwinference::inf_to_hosp,
infection_feedback_pmf = wwinference::generation_interval,
params = get_params(
system.file("extdata", "example_params.toml",
package = "wwinference"
)
)) {
model_specs <- list(
forecast_date = forecast_date,
calibration_time = calibration_time,
forecast_horizon = forecast_horizon,
generation_interval = generation_interval,
inf_to_count_delay = inf_to_hosp,
infection_feedback_pmf = infection_feedback_pmf,
params = params
)
return(model_specs)
}
52 changes: 52 additions & 0 deletions man/get_mcmc_options.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

56 changes: 56 additions & 0 deletions man/model_spec.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

55 changes: 11 additions & 44 deletions vignettes/wwinference.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -252,27 +252,9 @@ inf_to_hosp <- wwinference::inf_to_hosp
# Assign infection feedback equal to the generation interval
infection_feedback_pmf <- generation_interval
```
We will pass these to the `model_spec()` function of the `wwinference()` model,
along with the other specified parameters above.

## Combine into a single list called `model_spec`
```{r}
model_spec <- list(
forecast_date = forecast_date,
calibration_time = calibration_time,
forecast_horizon = forecast_horizon,
generation_interval = generation_interval,
inf_to_count_delay = inf_to_hosp,
infection_feedback_pmf = infection_feedback_pmf,
params = params,
iter_warmup = 750,
iter_sampling = 500,
n_chains = 4,
seed = 123,
adapt_delta = 0.95,
max_treedepth = 12,
exclude_ww_outliers = TRUE,
compute_likelihood = 1
)
```

# Precompiling the model
As `wwinference` uses `cmdstan` to fit its models, it is necessary to first
Expand Down Expand Up @@ -301,7 +283,15 @@ fit the model.
fit <- wwinference(
ww_data_to_fit,
hosp_data_preprocessed,
model_spec,
model_spec = get_model_spec(
forecast_date = forecast_date,
calibration_time = calibration_time,
forecast_horizon = forecast_horizon,
generation_interval = generation_interval,
inf_to_count_delay = inf_to_hosp,
infection_feedback_pmf = infection_feedback_pmf
),
mcmc_options = get_mcmc_options(),
model
)
```
Expand All @@ -313,26 +303,3 @@ diagnostic information, the data used for for fitting, and the underlying
nowcasted, and forecasted expected observed hospital admissions and wastewater
concentrations, as well as the latent variables of interest including the site-
level $R(t)$ estimates and the state-level $R(t)$ estimate.

```{r}
stan_data <- get_stan_data(
input_count_data = hosp_data_preprocessed,
input_ww_data = ww_data_preprocessed,
forecast_date = forecast_date,
calibration_time = calibration_time,
forecast_horizon = forecast_horizon,
generation_interval = generation_interval,
inf_to_count_delay = inf_to_hosp,
infection_feedback_pmf = infection_feedback_pmf,
params = params
)
ww_fit_obj <- model$sample(
data = stan_data,
seed = 123,
iter_sampling = 500,
iter_warmup = 750,
max_treedepth = 12,
chains = 4,
parallel_chains = 4
)
```

0 comments on commit 2a94332

Please sign in to comment.