Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

174 cmdstanr sample args #175

Merged
merged 63 commits into from
Sep 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
63 commits
Select commit Hold shift + click to select a range
65312a5
add all non depracted cmdstanr::sample args to mcmc options
kaitejohnson Sep 19, 2024
4ae40f3
set show messages to false in the test data mcmc options
kaitejohnson Sep 19, 2024
d65a5cc
update test data
kaitejohnson Sep 19, 2024
10711ed
add documentation for refresh
kaitejohnson Sep 19, 2024
4ae2e9d
update tests to expect all args to mcmc_options
kaitejohnson Sep 19, 2024
7543c2b
replace very verbose handling of passing cmdstanr::sample with a do.call
kaitejohnson Sep 19, 2024
35754b5
replace very verbose handling of passing cmdstanr::sample with a do.call
kaitejohnson Sep 19, 2024
221304c
update docs
kaitejohnson Sep 19, 2024
dcb5dc0
fix test
kaitejohnson Sep 19, 2024
702db94
try to suppress ESS warning
kaitejohnson Sep 19, 2024
7d4e01f
Update tests/testthat/test_ww_model.R
kaitejohnson Sep 19, 2024
4831258
set fit_opts as an empty list
kaitejohnson Sep 19, 2024
c780542
replace test with checkmate
kaitejohnson Sep 19, 2024
a25cac1
Merge branch '174-cmdstanr-sample-args' of https://github.com/CDCgov/…
kaitejohnson Sep 19, 2024
8447704
add wrapper function to silence warnings on mcmc tests
kaitejohnson Sep 19, 2024
5dc640c
Update R/wwinference.R
kaitejohnson Sep 22, 2024
2dba97b
remove extra args from get mcmc options
kaitejohnson Sep 22, 2024
3dc8b95
fix example and rerun documentation
kaitejohnson Sep 22, 2024
3c76895
add checkmates expect names to only allow cmdstanr$ sample args to be…
kaitejohnson Sep 22, 2024
f7d7d39
make get_mcmc_options an internal function
kaitejohnson Sep 22, 2024
ccc7792
modify wwinference call to pass a list not function to fit_opts
kaitejohnson Sep 22, 2024
0a9c49b
fix call to get number of chains
kaitejohnson Sep 22, 2024
9965832
update package data
kaitejohnson Sep 22, 2024
47ec753
modify language in vignette to no longer reference get_mcmc_options
kaitejohnson Sep 22, 2024
388a81c
Update R/wwinference.R
kaitejohnson Sep 22, 2024
6e41735
update documentation
kaitejohnson Sep 22, 2024
be2aaa3
run precommit
kaitejohnson Sep 22, 2024
fc635a2
remove example for function thts not exported
kaitejohnson Sep 22, 2024
1932126
fix test
kaitejohnson Sep 22, 2024
4af34d6
modify test data and rerun
kaitejohnson Sep 22, 2024
94900a6
Update R/wwinference.R
kaitejohnson Sep 23, 2024
53f4772
Update R/wwinference.R
kaitejohnson Sep 23, 2024
811c4f3
Update R/wwinference.R
kaitejohnson Sep 23, 2024
fea1582
Update R/wwinference.R
kaitejohnson Sep 23, 2024
d02c83c
Update R/wwinference.R
kaitejohnson Sep 23, 2024
8aeddc0
Update R/wwinference.R
kaitejohnson Sep 23, 2024
abac574
Update R/wwinference.R
kaitejohnson Sep 23, 2024
68db82c
Update R/wwinference.R
kaitejohnson Sep 23, 2024
5bff3df
Update R/wwinference.R
kaitejohnson Sep 23, 2024
77c7b73
Update R/wwinference.R
kaitejohnson Sep 23, 2024
1608f23
Update R/wwinference.R
kaitejohnson Sep 23, 2024
3bc1ac1
Update R/wwinference.R
kaitejohnson Sep 23, 2024
a255fc2
Update R/wwinference.R
kaitejohnson Sep 23, 2024
f202eb0
Update R/wwinference.R
kaitejohnson Sep 23, 2024
c2c65a1
export function and remove args that are already defaults
kaitejohnson Sep 23, 2024
f684518
fix tests
kaitejohnson Sep 23, 2024
c594a28
fix test, add formalArgs to imports
kaitejohnson Sep 23, 2024
d7d77e7
remove parallel chains and set to default, document things
kaitejohnson Sep 23, 2024
43571c1
Update R/wwinference.R
kaitejohnson Sep 24, 2024
718ebcf
Update R/wwinference.R
kaitejohnson Sep 24, 2024
606eff3
update docs
kaitejohnson Sep 24, 2024
becd549
run precommit and add new test
kaitejohnson Sep 24, 2024
dc67274
add reg exp to test
kaitejohnson Sep 24, 2024
d94d580
add setting mc cores to 4
kaitejohnson Sep 24, 2024
b0ffc23
add the chains back so that we use them in inits
kaitejohnson Sep 24, 2024
d70c609
update test data
kaitejohnson Sep 24, 2024
acd0c4b
Merge branch 'main' into 174-cmdstanr-sample-args
dylanhmorris Sep 24, 2024
6d7a305
remove methods formal args from package
kaitejohnson Sep 24, 2024
492d793
Merge branch '174-cmdstanr-sample-args' of https://github.com/CDCgov/…
kaitejohnson Sep 24, 2024
ee27de0
try updating testing data
kaitejohnson Sep 25, 2024
a6e78b0
Update R/wwinference.R
kaitejohnson Sep 26, 2024
1489d7c
Update R/wwinference.R
kaitejohnson Sep 26, 2024
cfe8b0a
fix documentation
kaitejohnson Sep 26, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file modified R/sysdata.rda
Binary file not shown.
111 changes: 67 additions & 44 deletions R/wwinference.R
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,15 @@
#' `get_model_spec()`. The default here pertains to the `forecast_date` in the
#' example data provided by the package, but this should be specified by the
#' user based on the date they are producing a forecast
#' @param fit_opts The fit options, which in this case default to the
#' MCMC parameters as defined using `get_mcmc_options()`. This includes
#' the following arguments, which are passed to
#' [`$sample()`][cmdstanr::model-method-sample]:
#' the number of chains, the number of warmup
#' and sampling iterations, the maximum tree depth, the average acceptance
#' probability, and the stan PRNG seed
#' @param fit_opts MCMC fitting options, as a list of keys and values.
#' These are passed as keyword arguments to
#' [`compiled_model$sample()`][cmdstanr::model-method-sample].
#' Where no option is specified, [wwinference()] will fall back first on a
#' package-specific default value given by [get_mcmc_options()], if one exists.
#' If no package-specific default exists, [wwinference()] will fall back on
#' the default value defined in [`$sample()`][cmdstanr::model-method-sample].
kaitejohnson marked this conversation as resolved.
Show resolved Hide resolved
#' See the documentation for [`$sample()`][cmdstanr::model-method-sample] for
#' details on available options.
#' @param generate_initial_values Boolean indicating whether or not to specify
#' the initialization of the sampler, default is `TRUE`, meaning that
#' initialization lists will be generated and passed as the `init` argument
Expand Down Expand Up @@ -124,24 +126,27 @@
#' calibration_time <- 90
#' forecast_horizon <- 28
#' include_ww <- 1
#' ww_fit <- wwinference(input_ww_data,
#' input_count_data,
#'
#' ww_fit <- wwinference(
#' ww_data = input_ww_data,
#' count_data = input_count_data,
#' forecast_date = forecast_date,
#' calibration_time = calibration_time,
#' forecast_horizon = forecast_horizon,
#' model_spec = get_model_spec(
#' forecast_date = forecast_date,
#' calibration_time = calibration_time,
#' forecast_horizon = forecast_horizon,
#' generation_interval = generation_interval,
#' inf_to_count_delay = inf_to_coutn_delay,
#' inf_to_count_delay = inf_to_count_delay,
#' infection_feedback_pmf = infection_feedback_pmf,
#' params = params
#' ),
#' fit_opts = get_mcmc_options(
#' fit_opts = list(
#' iter_warmup = 250,
#' iter_sampling = 250,
#' n_chains = 2
#' chains = 2
#' )
#' )
#' }
#'
#' @rdname wwinference
#' @aliases wwinference_fit
wwinference <- function(ww_data,
Expand All @@ -150,7 +155,7 @@ wwinference <- function(ww_data,
calibration_time = 90,
forecast_horizon = 28,
model_spec = get_model_spec(),
fit_opts = get_mcmc_options(),
fit_opts = list(),
generate_initial_values = TRUE,
initial_values_seed = NULL,
compiled_model = compile_model()) {
Expand All @@ -160,6 +165,18 @@ wwinference <- function(ww_data,
)
}

fit_opts_use <- get_mcmc_options() # get defaults
# this overwrites defaults with all and only the values the user sets in
# `fit_opts`
fit_opts_use[names(fit_opts)] <- fit_opts

# Check that the fit options passed to wwinference are valid cmdstanr::sample
# arguments
checkmate::assert_names(names(fit_opts),
dylanhmorris marked this conversation as resolved.
Show resolved Hide resolved
subset.of = formalArgs(compiled_model$sample)
)


# Check that data is compatible with specifications
assert_no_dates_after_max(ww_data$date, forecast_date)
assert_no_dates_after_max(count_data$date, forecast_date)
Expand Down Expand Up @@ -204,7 +221,7 @@ wwinference <- function(ww_data,
if (generate_initial_values) {
withr::with_seed(initial_values_seed, {
init_lists <- lapply(
1:fit_opts$n_chains,
1:fit_opts_use$chains,
\(x) {
get_inits_for_one_chain(stan_data_list)
}
Expand All @@ -220,7 +237,7 @@ wwinference <- function(ww_data,
fit <- safe_fit_model(
compiled_model = compiled_model,
stan_data_list = stan_data_list,
fit_opts = fit_opts,
fit_opts = fit_opts_use,
init_lists = init_lists
)

Expand Down Expand Up @@ -329,15 +346,18 @@ fit_model <- function(compiled_model,
stan_data_list,
fit_opts,
init_lists) {
fit <- compiled_model$sample(
data = stan_data_list,
init = init_lists,
seed = fit_opts$seed,
iter_sampling = fit_opts$iter_sampling,
iter_warmup = fit_opts$iter_warmup,
max_treedepth = fit_opts$max_treedepth,
chains = fit_opts$n_chains,
parallel_chains = fit_opts$n_chains
args_for_stan_sampling <-
c(
list(
data = stan_data_list,
init = init_lists
),
fit_opts
)

fit <- do.call(
compiled_model$sample,
args_for_stan_sampling
)

return(fit)
Expand All @@ -348,42 +368,45 @@ fit_model <- function(compiled_model,
#'
#' @description
#' This function returns a list of MCMC settings to pass to the
#' `cmdstanr::sample()` function to fit the model. The default settings are
#' specified for production-level runs, consider adjusting to optimize
#' for speed while iterating.
#' [`$sample()`][cmdstanr::model-method-sample] function to fit the model.
#' The default settings are specified for production-level runs.
#' All input arguments to [`$sample()`][cmdstanr::model-method-sample]
#' are configurable by the user. See
#' [`$sample()`][cmdstanr::model-method-sample] documentation
#' for details of the available arguments.
#'
#'
#' @param iter_warmup integer indicating the number of warm-up iterations,
#' default is `750`
#' default is `750`.
#' @param iter_sampling integer indicating the number of sampling iterations,
#' default is `500`
#' @param n_chains integer indicating the number of MCMC chains to run, default
#' is `4`
#' @param seed set of integers indicating the random seed of the stan sampler,
#' default is NULL
#' default is `500`.
#' @param seed integer, A seed for the (P)RNG to pass to CmdStan. In the case
#' of multi-chain sampling the single seed will automatically be augmented by
#' the the run (chain) ID so that each chain uses a different seed.
#' Default is `NULL`.
#' @param chains integer indicating the number of MCMC chains to run, default
#' is `4`.
#' @param adapt_delta float between 0 and 1 indicating the average acceptance
#' probability, default is `0.95`
#' probability, default is `0.95`.
#' @param max_treedepth integer indicating the maximum tree depth of the
#' sampler, default is 12
#' sampler, default is 12.
#'
#' @return a list of mcmc settings with the values given by the function
#' @return A list of MCMC settings with the values given by the function.
#' arguments
#' @export
#'
#' @examples
#' mcmc_settings <- get_mcmc_options()
#' @export
get_mcmc_options <- function(
dylanhmorris marked this conversation as resolved.
Show resolved Hide resolved
iter_warmup = 750,
iter_sampling = 500,
n_chains = 4,
seed = NULL,
chains = 4,
adapt_delta = 0.95,
max_treedepth = 12) {
mcmc_settings <- list(
iter_warmup = iter_warmup,
iter_sampling = iter_sampling,
n_chains = n_chains,
seed = seed,
chains = chains,
adapt_delta = adapt_delta,
max_treedepth = max_treedepth
)
Expand Down
9 changes: 5 additions & 4 deletions data-raw/test_data.R
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,12 @@ model_spec <- wwinference::get_model_spec(
params = params
)

mcmc_options <- wwinference::get_mcmc_options(
seed = 55,
mcmc_options <- list(
seed = 5,
iter_warmup = 25,
iter_sampling = 25,
n_chains = 1
chains = 1,
show_messages = FALSE
)

generate_initial_values <- TRUE
Expand All @@ -66,7 +67,7 @@ model_test_data <- list(
generate_initial_values = generate_initial_values
)

withr::with_seed(5, {
withr::with_seed(55, {
fit <- do.call(
wwinference::wwinference,
model_test_data
Expand Down
34 changes: 18 additions & 16 deletions man/get_mcmc_options.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

37 changes: 21 additions & 16 deletions man/wwinference.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 7 additions & 0 deletions tests/testthat/helper.R
Original file line number Diff line number Diff line change
Expand Up @@ -131,3 +131,10 @@ diff_ar1_from_z_scores_alt <- function(x0, ar, sd, z, stationary = FALSE) {

return(x)
}

silent_wwinference <- function(...) {
utils::capture.output(
fit <- suppressMessages(wwinference(...))
)
return(fit)
}
9 changes: 7 additions & 2 deletions tests/testthat/test_ww_model.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,18 @@ test_that("Test the wastewater inference model on simulated data.", {
#######
# run model briefly on the simulated data
#######
withr::with_seed(5, {

# This seed sets the initial values seed. Must be the same as the one used
# in generating the test data.
# model_test_data contains the seed that gets passed to stan
withr::with_seed(55, {
fit <- do.call(
wwinference::wwinference,
silent_wwinference,
model_test_data
)
})


params <- model_test_data$model_spec$params
obs_last_draw <- posterior::subset_draws(fit$fit$result$draws(),
draw = 25
Expand Down
Loading