diff --git a/NEWS.md b/NEWS.md index 378905346..0c5c79876 100644 --- a/NEWS.md +++ b/NEWS.md @@ -9,6 +9,8 @@ * Add priorsense support via `create_priorsense_data.brmsfit` thanks to Noa Kallioinen. (#1354) * Vectorize censored log likelihoods in the Stan code when possible. (#1657) +* Force Stan to activate threading without altering the Stan code +via argument `force` of function `threading`. (#1549) ### Bug Fixes diff --git a/R/backends.R b/R/backends.R index 3000383f6..b95cbd45c 100644 --- a/R/backends.R +++ b/R/backends.R @@ -70,7 +70,7 @@ compile_model <- function(model, backend, ...) { if (silent < 2) { message("Compiling Stan program...") } - if (use_threading(threads)) { + if (use_threading(threads, force = TRUE)) { if (utils::packageVersion("rstan") >= "2.26") { threads_per_chain_def <- rstan::rstan_options("threads_per_chain") on.exit(rstan::rstan_options(threads_per_chain = threads_per_chain_def)) @@ -100,7 +100,7 @@ compile_model <- function(model, backend, ...) { # if (cmdstanr::cmdstan_version() >= "2.29.0") { # .canonicalize_stan_model(args$stan_file, overwrite_file = TRUE) # } - if (use_threading(threads)) { + if (use_threading(threads, force = TRUE)) { args$cpp_options$stan_threads <- TRUE } if (use_opencl(opencl)) { @@ -147,7 +147,7 @@ fit_model <- function(model, backend, ...) { seed, control, silent, future, ...) { # some input checks and housekeeping - if (use_threading(threads)) { + if (use_threading(threads, force = TRUE)) { if (utils::packageVersion("rstan") >= "2.26") { threads_per_chain_def <- rstan::rstan_options("threads_per_chain") on.exit(rstan::rstan_options(threads_per_chain = threads_per_chain_def)) @@ -265,6 +265,7 @@ fit_model <- function(model, backend, ...) { if (silent < 2) { message("Start sampling") } + use_threading <- use_threading(threads, force = TRUE) if (algorithm %in% c("sampling", "fixed_param")) { c(args) <- nlist( iter_sampling = iter - warmup, @@ -275,7 +276,7 @@ fit_model <- function(model, backend, ...) { show_exceptions = silent == 0, fixed_param = algorithm == "fixed_param" ) - if (use_threading(threads)) { + if (use_threading) { args$threads_per_chain <- threads$threads } if (future) { @@ -304,17 +305,17 @@ fit_model <- function(model, backend, ...) { } } else if (algorithm %in% c("fullrank", "meanfield")) { c(args) <- nlist(iter, algorithm) - if (use_threading(threads)) { + if (use_threading) { args$threads <- threads$threads } out <- do_call(model$variational, args) } else if (algorithm %in% c("pathfinder")) { - if (use_threading(threads)) { + if (use_threading) { args$num_threads <- threads$threads } out <- do_call(model$pathfinder, args) } else if (algorithm %in% c("laplace")) { - if (use_threading(threads)) { + if (use_threading) { args$threads <- threads$threads } out <- do_call(model$laplace, args) @@ -487,6 +488,10 @@ require_backend <- function(backend, x) { #' \code{reduce_sum}? Defaults to \code{FALSE}. Setting it to \code{TRUE} #' is required to achieve exact reproducibility of the model results #' (if the random seed is set as well). +#' @param force Logical. Defaults to \code{FALSE}. If \code{TRUE}, this will +#' force the Stan model to compile with threading enabled without altering the +#' Stan code generated by brms. This can be useful if your own custom Stan +#' functions use threading internally. #' #' @return A \code{brmsthreads} object which can be passed to the #' \code{threads} argument of \code{brm} and related functions. @@ -515,7 +520,8 @@ require_backend <- function(backend, x) { #' } #' #' @export -threading <- function(threads = NULL, grainsize = NULL, static = FALSE) { +threading <- function(threads = NULL, grainsize = NULL, static = FALSE, + force = FALSE) { out <- list(threads = NULL, grainsize = NULL) class(out) <- "brmsthreads" if (!is.null(threads)) { @@ -533,6 +539,7 @@ threading <- function(threads = NULL, grainsize = NULL, static = FALSE) { out$grainsize <- grainsize } out$static <- as_one_logical(static) + out$force <- as_one_logical(force) out } @@ -555,8 +562,14 @@ validate_threads <- function(threads) { } # is threading activated? -use_threading <- function(threads) { - isTRUE(validate_threads(threads)$threads > 0) +use_threading <- function(threads, force = FALSE) { + threads <- validate_threads(threads) + out <- isTRUE(threads$threads > 0) + if (!force) { + # Stan code will only be altered in non-forced mode + out <- out && !isTRUE(threads$force) + } + out } #' GPU support in Stan via OpenCL diff --git a/man/threading.Rd b/man/threading.Rd index 15e2cc614..b16614f26 100644 --- a/man/threading.Rd +++ b/man/threading.Rd @@ -4,7 +4,7 @@ \alias{threading} \title{Threading in Stan} \usage{ -threading(threads = NULL, grainsize = NULL, static = FALSE) +threading(threads = NULL, grainsize = NULL, static = FALSE, force = FALSE) } \arguments{ \item{threads}{Number of threads to use in within-chain parallelization.} @@ -19,6 +19,11 @@ default is experimental and may change in the future without prior notice.} \code{reduce_sum}? Defaults to \code{FALSE}. Setting it to \code{TRUE} is required to achieve exact reproducibility of the model results (if the random seed is set as well).} + +\item{force}{Logical. Defaults to \code{FALSE}. If \code{TRUE}, this will +force the Stan model to compile with threading enabled without altering the +Stan code generated by brms. This can be useful if your own custom Stan +functions use threading internally.} } \value{ A \code{brmsthreads} object which can be passed to the