Skip to content

Commit

Permalink
feature issue #1549
Browse files Browse the repository at this point in the history
  • Loading branch information
paul-buerkner committed Sep 19, 2024
1 parent c0eb374 commit f5932e2
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 11 deletions.
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
* Add priorsense support via `create_priorsense_data.brmsfit`
thanks to Noa Kallioinen. (#1354)
* Vectorize censored log likelihoods in the Stan code when possible. (#1657)
* Force Stan to activate threading without altering the Stan code
via argument `force` of function `threading`. (#1549)

### Bug Fixes

Expand Down
33 changes: 23 additions & 10 deletions R/backends.R
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ compile_model <- function(model, backend, ...) {
if (silent < 2) {
message("Compiling Stan program...")
}
if (use_threading(threads)) {
if (use_threading(threads, force = TRUE)) {
if (utils::packageVersion("rstan") >= "2.26") {
threads_per_chain_def <- rstan::rstan_options("threads_per_chain")
on.exit(rstan::rstan_options(threads_per_chain = threads_per_chain_def))
Expand Down Expand Up @@ -100,7 +100,7 @@ compile_model <- function(model, backend, ...) {
# if (cmdstanr::cmdstan_version() >= "2.29.0") {
# .canonicalize_stan_model(args$stan_file, overwrite_file = TRUE)
# }
if (use_threading(threads)) {
if (use_threading(threads, force = TRUE)) {
args$cpp_options$stan_threads <- TRUE
}
if (use_opencl(opencl)) {
Expand Down Expand Up @@ -147,7 +147,7 @@ fit_model <- function(model, backend, ...) {
seed, control, silent, future, ...) {

# some input checks and housekeeping
if (use_threading(threads)) {
if (use_threading(threads, force = TRUE)) {
if (utils::packageVersion("rstan") >= "2.26") {
threads_per_chain_def <- rstan::rstan_options("threads_per_chain")
on.exit(rstan::rstan_options(threads_per_chain = threads_per_chain_def))
Expand Down Expand Up @@ -265,6 +265,7 @@ fit_model <- function(model, backend, ...) {
if (silent < 2) {
message("Start sampling")
}
use_threading <- use_threading(threads, force = TRUE)
if (algorithm %in% c("sampling", "fixed_param")) {
c(args) <- nlist(
iter_sampling = iter - warmup,
Expand All @@ -275,7 +276,7 @@ fit_model <- function(model, backend, ...) {
show_exceptions = silent == 0,
fixed_param = algorithm == "fixed_param"
)
if (use_threading(threads)) {
if (use_threading) {
args$threads_per_chain <- threads$threads
}
if (future) {
Expand Down Expand Up @@ -304,17 +305,17 @@ fit_model <- function(model, backend, ...) {
}
} else if (algorithm %in% c("fullrank", "meanfield")) {
c(args) <- nlist(iter, algorithm)
if (use_threading(threads)) {
if (use_threading) {
args$threads <- threads$threads
}
out <- do_call(model$variational, args)
} else if (algorithm %in% c("pathfinder")) {
if (use_threading(threads)) {
if (use_threading) {
args$num_threads <- threads$threads
}
out <- do_call(model$pathfinder, args)
} else if (algorithm %in% c("laplace")) {
if (use_threading(threads)) {
if (use_threading) {
args$threads <- threads$threads
}
out <- do_call(model$laplace, args)
Expand Down Expand Up @@ -487,6 +488,10 @@ require_backend <- function(backend, x) {
#' \code{reduce_sum}? Defaults to \code{FALSE}. Setting it to \code{TRUE}
#' is required to achieve exact reproducibility of the model results
#' (if the random seed is set as well).
#' @param force Logical. Defaults to \code{FALSE}. If \code{TRUE}, this will
#' force the Stan model to compile with threading enabled without altering the
#' Stan code generated by brms. This can be useful if your own custom Stan
#' functions use threading internally.
#'
#' @return A \code{brmsthreads} object which can be passed to the
#' \code{threads} argument of \code{brm} and related functions.
Expand Down Expand Up @@ -515,7 +520,8 @@ require_backend <- function(backend, x) {
#' }
#'
#' @export
threading <- function(threads = NULL, grainsize = NULL, static = FALSE) {
threading <- function(threads = NULL, grainsize = NULL, static = FALSE,
force = FALSE) {
out <- list(threads = NULL, grainsize = NULL)
class(out) <- "brmsthreads"
if (!is.null(threads)) {
Expand All @@ -533,6 +539,7 @@ threading <- function(threads = NULL, grainsize = NULL, static = FALSE) {
out$grainsize <- grainsize
}
out$static <- as_one_logical(static)
out$force <- as_one_logical(force)
out
}

Expand All @@ -555,8 +562,14 @@ validate_threads <- function(threads) {
}

# is threading activated?
use_threading <- function(threads) {
isTRUE(validate_threads(threads)$threads > 0)
use_threading <- function(threads, force = FALSE) {
threads <- validate_threads(threads)
out <- isTRUE(threads$threads > 0)
if (!force) {
# Stan code will only be altered in non-forced mode
out <- out && !isTRUE(threads$force)
}
out
}

#' GPU support in Stan via OpenCL
Expand Down
7 changes: 6 additions & 1 deletion man/threading.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit f5932e2

Please sign in to comment.