Skip to content

Commit

Permalink
Merge branch 'master' into re-predictors
Browse files Browse the repository at this point in the history
  • Loading branch information
paul-buerkner committed Sep 16, 2024
2 parents d913c5f + 5bb6531 commit 285adf2
Show file tree
Hide file tree
Showing 59 changed files with 520 additions and 292 deletions.
4 changes: 2 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ Package: brms
Encoding: UTF-8
Type: Package
Title: Bayesian Regression Models using 'Stan'
Version: 2.21.7
Date: 2024-07-19
Version: 2.21.9
Date: 2024-09-16
Authors@R:
c(person("Paul-Christian", "Bürkner", email = "[email protected]",
role = c("aut", "cre")),
Expand Down
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -571,6 +571,7 @@ export(read_csv_as_stanfit)
export(recompile_model)
export(reloo)
export(rename_pars)
export(resp_bhaz)
export(resp_cat)
export(resp_cens)
export(resp_dec)
Expand Down
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

### New Features

* Support stratified `cox` models via the new addition term `bhaz`. (#1489)
* Support futures for parallelization in the `cmdstanr` backend. (#1684)
* Add method `loo_epred` thanks to Aki Vehtari. (#1641)
* Add priorsense support via `create_priorsense_data.brmsfit` thanks to Noa Kallioinen. (#1354)

Expand Down
50 changes: 41 additions & 9 deletions R/backends.R
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ fit_model <- function(model, backend, ...) {
} else if (is.character(init) && !init %in% c("random", "0")) {
init <- get(init, mode = "function", envir = parent.frame())
}
future <- future && algorithm %in% "sampling"
args <- nlist(
object = model, data = sdata, iter, seed,
init = init, pars = exclude, include = FALSE
Expand All @@ -187,7 +188,7 @@ fit_model <- function(model, backend, ...) {
warning2("Argument 'cores' is ignored when using 'future'.")
}
args$chains <- 1L
futures <- fits <- vector("list", chains)
out <- futures <- vector("list", chains)
for (i in seq_len(chains)) {
args$chain_id <- i
if (is.list(init)) {
Expand All @@ -200,10 +201,10 @@ fit_model <- function(model, backend, ...) {
)
}
for (i in seq_len(chains)) {
fits[[i]] <- future::value(futures[[i]])
out[[i]] <- future::value(futures[[i]])
}
out <- rstan::sflist2stanfit(fits)
rm(futures, fits)
out <- rstan::sflist2stanfit(out)
rm(futures)
} else {
c(args) <- nlist(chains, cores)
out <- do_call(rstan::sampling, args)
Expand Down Expand Up @@ -239,9 +240,7 @@ fit_model <- function(model, backend, ...) {
} else if (is_equal(init, "0")) {
init <- 0
}
if (future) {
stop2("Argument 'future' is not supported by backend 'cmdstanr'.")
}
future <- future && algorithm %in% "sampling"
args <- nlist(data = sdata, seed, init)
if (use_opencl(opencl)) {
args$opencl_ids <- opencl$ids
Expand Down Expand Up @@ -279,7 +278,30 @@ fit_model <- function(model, backend, ...) {
if (use_threading(threads)) {
args$threads_per_chain <- threads$threads
}
out <- do_call(model$sample, args)
if (future) {
if (cores > 1L) {
warning2("Argument 'cores' is ignored when using 'future'.")
}
args$chains <- 1L
out <- futures <- vector("list", chains)
for (i in seq_len(chains)) {
args$chain_ids <- i
if (is.list(init)) {
args$init <- init[i]
}
futures[[i]] <- future::future(
brms::do_call(model$sample, args),
packages = "cmdstanr",
seed = TRUE
)
}
for (i in seq_len(chains)) {
out[[i]] <- future::value(futures[[i]])
}
rm(futures)
} else {
out <- do_call(model$sample, args)
}
} else if (algorithm %in% c("fullrank", "meanfield")) {
c(args) <- nlist(iter, algorithm)
if (use_threading(threads)) {
Expand All @@ -300,8 +322,18 @@ fit_model <- function(model, backend, ...) {
stop2("Algorithm '", algorithm, "' is not supported.")
}

if (future) {
# 'out' is a list of fitted models
output_files <- ulapply(out, function(x) x$output_files())
stan_variables <- out[[1]]$metadata()$stan_variables
} else {
# 'out' is a single fitted model
output_files <- out$output_files()
stan_variables <- out$metadata()$stan_variables
}

out <- read_csv_as_stanfit(
out$output_files(), variables = out$metadata()$stan_variables,
output_files, variables = stan_variables,
model = model, exclude = exclude, algorithm = algorithm
)

Expand Down
1 change: 1 addition & 0 deletions R/brm.R
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@
#' variational inference with independent normal distributions,
#' \code{"fullrank"} for variational inference with a multivariate normal
#' distribution, \code{"pathfinder"} for the pathfinder algorithm,
#' \code{"laplace"} for the Laplace approximation,
#' or \code{"fixed_param"} for sampling from fixed parameter
#' values. Can be set globally for the current \R session via the
#' \code{"brms.algorithm"} option (see \code{\link{options}}).
Expand Down
18 changes: 12 additions & 6 deletions R/brmsformula.R
Original file line number Diff line number Diff line change
Expand Up @@ -1296,12 +1296,6 @@ validate_formula.brmsformula <- function(
out$family$thres <- extract_thres_names(out, data)
out$family$cats <- extract_cat_names(out, data)
}
if (is.mixfamily(out$family)) {
# every mixture family needs to know about response categories
for (i in seq_along(out$family$mix)) {
out$family$mix[[i]]$thres <- out$family$thres
}
}
}
conv_cats_dpars <- conv_cats_dpars(out$family)
if (conv_cats_dpars && !is.null(data)) {
Expand Down Expand Up @@ -1337,6 +1331,18 @@ validate_formula.brmsformula <- function(
out$family$dpars <- union(dp_dpars, out$family$dpars)
}
}
if (is_cox(out$family) && !is.null(data)) {
# for easy access of baseline hazards
out$family$bhaz <- extract_bhaz(out, data)
}
if (is.mixfamily(out$family)) {
# every mixture family needs to know about additional response information
for (i in seq_along(out$family$mix)) {
for (term in c("cats", "thres", "bhaz")) {
out$family$mix[[i]][[term]] <- out$family[[term]]
}
}
}

# incorporate deprecated arguments
require_threshold <- is_ordinal(out$family) && is.null(out$family$threshold)
Expand Down
6 changes: 3 additions & 3 deletions R/brmsframe.R
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ brmsframe.brmsterms <- function(x, data, frame = NULL, basis = NULL, ...) {
# this must be a multivariate model
stopifnot(is.list(frame))
x$frame <- frame
x$frame$re <- subset(x$frame$re, resp = x$resp)
x$frame$re <- subset2(x$frame$re, resp = x$resp)
}
data <- subset_data(data, x)
x$frame$resp <- frame_resp(x, data = data)
Expand Down Expand Up @@ -418,8 +418,8 @@ frame_basis_bhaz <- function(x, data, ...) {
if (is_cox(x$family)) {
# compute basis matrix of the baseline hazard for the Cox model
y <- model.response(model.frame(x$respform, data, na.action = na.pass))
out$basis_matrix <- bhaz_basis_matrix(y, args = x$family$bhaz)
args <- family_info(x, "bhaz")$args
out$basis_matrix <- bhaz_basis_matrix(y, args = args)
}
out
}

4 changes: 4 additions & 0 deletions R/data-helpers.R
Original file line number Diff line number Diff line change
Expand Up @@ -558,6 +558,10 @@ validate_newdata <- function(
new_levels <- get_levels(bterms, data = newdata)
for (g in names(old_levels)) {
unknown_levels <- setdiff(new_levels[[g]], old_levels[[g]])
# NA is not found by get_levels but still behaves like a new level (#1652)
if (anyNA(newdata[[g]])) {
c(unknown_levels) <- NA
}
if (length(unknown_levels)) {
unknown_levels <- collapse_comma(unknown_levels)
stop2(
Expand Down
61 changes: 50 additions & 11 deletions R/data-response.R
Original file line number Diff line number Diff line change
Expand Up @@ -469,14 +469,34 @@ data_bhaz <- function(bframe, data, data2, prior) {
return(out)
}
y <- bframe$frame$resp$values
args <- bframe$family$bhaz
bhaz <- family_info(bframe, "bhaz")
bs <- bframe$basis$bhaz$basis_matrix
out$Zbhaz <- bhaz_basis_matrix(y, args, basis = bs)
out$Zcbhaz <- bhaz_basis_matrix(y, args, integrate = TRUE, basis = bs)
out$Zbhaz <- bhaz_basis_matrix(y, bhaz$args, basis = bs)
out$Zcbhaz <- bhaz_basis_matrix(y, bhaz$args, integrate = TRUE, basis = bs)
out$Kbhaz <- NCOL(out$Zbhaz)
sbhaz_prior <- subset2(prior, class = "sbhaz", resp = bframe$resp)
con_sbhaz <- eval_dirichlet(sbhaz_prior$prior, out$Kbhaz, data2)
out$con_sbhaz <- as.array(con_sbhaz)
groups <- bhaz$groups
if (!is.null(groups)) {
out$ngrbhaz <- length(groups)
gr <- get_ad_values(bframe, "bhaz", "gr", data)
gr <- factor(rename(gr), levels = groups)
out$Jgrbhaz <- match(gr, groups)
out$con_sbhaz <- matrix(nrow = out$ngrbhaz, ncol = out$Kbhaz)
sbhaz_prior <- subset2(prior, class = "sbhaz", resp = bframe$resp)
sbhaz_prior_global <- subset2(sbhaz_prior, group = "")
con_sbhaz_global <- eval_dirichlet(sbhaz_prior_global$prior, out$Kbhaz, data2)
for (k in seq_along(groups)) {
sbhaz_prior_group <- subset2(sbhaz_prior, group = groups[k])
if (nzchar(sbhaz_prior_group$prior)) {
out$con_sbhaz[k, ] <- eval_dirichlet(sbhaz_prior_group$prior, out$Kbhaz, data2)
} else {
out$con_sbhaz[k, ] <- con_sbhaz_global
}
}
} else {
sbhaz_prior <- subset2(prior, class = "sbhaz", resp = bframe$resp)
con_sbhaz <- eval_dirichlet(sbhaz_prior$prior, out$Kbhaz, data2)
out$con_sbhaz <- as.array(con_sbhaz)
}
out
}

Expand All @@ -502,9 +522,6 @@ bhaz_basis_matrix <- function(y, args = list(), integrate = FALSE,
}
stopifnot(is.list(args))
args$x <- y
if (!is.null(args$intercept)) {
args$intercept <- as_one_logical(args$intercept)
}
if (is.null(args$Boundary.knots)) {
# avoid 'knots' outside 'Boundary.knots' error (#1143)
# we also need a smaller lower boundary knot to avoid lp = -Inf
Expand All @@ -524,6 +541,29 @@ bhaz_basis_matrix <- function(y, args = list(), integrate = FALSE,
out
}

# collect baseline hazard information of a Cox model so that it can be
# stored in the model family
# @param x a brmsformula or brmsterms object for which is_cox(x) holds
# @param data user specified data; forwarded to get_ad_values() to look up
#   the 'gr' values of the 'bhaz' addition term
# @return a named list with elements:
#   args: arguments that can be passed to bhaz_basis_matrix
#   groups: optional names of the groups for which to stratify;
#     only set if the 'bhaz' addition term provides 'gr' values
extract_bhaz <- function(x, data) {
  stopifnot(is.brmsformula(x) || is.brmsterms(x), is_cox(x))
  # addition terms may not have been parsed yet
  if (is.null(x$adforms)) {
    x$adforms <- terms_ad(x$formula, x$family)
  }
  bhaz_term <- x$adforms$bhaz
  if (is.null(bhaz_term)) {
    # 'bhaz' is optional, so the defaults have to be supplied here as well
    return(list(args = list(df = 5, intercept = TRUE)))
  }
  info <- list()
  # the '$flags' element of the evaluated addition term serves as the arguments
  info$args <- eval_rhs(bhaz_term)$flags
  strata <- get_ad_values(x, "bhaz", "gr", data)
  if (!is.null(strata)) {
    # store the (renamed) factor levels of the grouping values
    info$groups <- rename(levels(factor(strata)))
  }
  info
}

# extract names of response categories
# @param x a brmsterms object or one that can be coerced to it
# @param data user specified data
Expand All @@ -550,7 +590,6 @@ extract_cat_names <- function(x, data) {
# @return a data.frame with columns 'thres' and 'group'
extract_thres_names <- function(x, data) {
stopifnot(is.brmsformula(x) || is.brmsterms(x), has_thres(x))

if (is.null(x$adforms)) {
x$adforms <- terms_ad(x$formula, x$family)
}
Expand Down Expand Up @@ -609,7 +648,7 @@ extract_thres_names <- function(x, data) {
data.frame(thres, group, stringsAsFactors = FALSE)
}

# extract threshold names from the response values
# extract number of thresholds from the response values
# @param formula with the response on the LHS
# @param data a data.frame from which to extract responses
# @param extra_cat is the first category an extra (hurdle) category?
Expand Down
45 changes: 19 additions & 26 deletions R/families.R
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@
#' category is used as the reference. If \code{NA}, all categories will be
#' predicted, which requires strong priors or carefully specified predictor
#' terms in order to lead to an identified model.
#' @param bhaz Currently for experimental purposes only.
#'
#' @details
#' Below, we list common use cases for the different families.
Expand Down Expand Up @@ -199,7 +198,7 @@ brmsfamily <- function(family, link = NULL, link_sigma = "log",
link_alpha = "identity",
link_quantile = "logit",
threshold = "flexible",
refcat = NULL, bhaz = NULL) {
refcat = NULL) {
slink <- substitute(link)
.brmsfamily(
family, link = link, slink = slink,
Expand All @@ -212,8 +211,7 @@ brmsfamily <- function(family, link = NULL, link_sigma = "log",
link_ndt = link_ndt, link_bias = link_bias,
link_alpha = link_alpha, link_xi = link_xi,
link_quantile = link_quantile,
threshold = threshold, refcat = refcat,
bhaz = bhaz
threshold = threshold, refcat = refcat
)
}

Expand All @@ -227,7 +225,7 @@ brmsfamily <- function(family, link = NULL, link_sigma = "log",
# @return an object of 'brmsfamily' which inherits from 'family'
.brmsfamily <- function(family, link = NULL, slink = link,
threshold = "flexible",
refcat = NULL, bhaz = NULL, ...) {
refcat = NULL, ...) {
family <- tolower(as_one_character(family))
aux_links <- list(...)
pattern <- c("^normal$", "^zi_", "^hu_")
Expand Down Expand Up @@ -300,23 +298,6 @@ brmsfamily <- function(family, link = NULL, link_sigma = "log",
out$refcat <- as_one_character(refcat, allow_na = allow_na_ref)
}
}
if (is_cox(out$family)) {
if (!is.null(bhaz)) {
if (!is.list(bhaz)) {
stop2("'bhaz' should be a list.")
}
out$bhaz <- bhaz
} else {
out$bhaz <- list()
}
# set default arguments
if (is.null(out$bhaz$df)) {
out$bhaz$df <- 5L
}
if (is.null(out$bhaz$intercept)) {
out$bhaz$intercept <- TRUE
}
}
out
}

Expand Down Expand Up @@ -475,8 +456,8 @@ combine_family_info <- function(x, y, ...) {
clb <- !any(ulapply(x[, 1], isFALSE))
cub <- !any(ulapply(x[, 2], isFALSE))
x <- c(clb, cub)
} else if (y == "thres") {
# thresholds are the same across mixture components
} else if (y %in% c("thres", "bhaz")) {
# same across mixture components
x <- x[[1]]
}
x
Expand Down Expand Up @@ -687,9 +668,9 @@ zero_inflated_asym_laplace <- function(link = "identity", link_sigma = "log",

#' @rdname brmsfamily
#' @export
cox <- function(link = "log", bhaz = NULL) {
cox <- function(link = "log") {
slink <- substitute(link)
.brmsfamily("cox", link = link, bhaz = bhaz)
.brmsfamily("cox", link = link)
}

#' @rdname brmsfamily
Expand Down Expand Up @@ -1750,6 +1731,18 @@ has_thres_groups <- function(family) {
any(nzchar(groups))
}

# names of the groups used to stratify the baseline hazard
# @param family object from which family_info() can extract the 'bhaz' info
# @return the unique entries of the 'groups' element of the 'bhaz' info;
#   NULL if that element is absent
get_bhaz_groups <- function(family) {
  unique(family_info(family, "bhaz")$groups)
}

# does the model use group specific baseline hazards?
# @param family object passed on to get_bhaz_groups()
# @return TRUE if at least one baseline hazard group has a non-empty name;
#   FALSE otherwise, including the case of no groups at all
has_bhaz_groups <- function(family) {
  any(nzchar(get_bhaz_groups(family)))
}

# is "ndt" among the classes of the family's distributional parameters?
# @param family object from which family_info() can extract the 'dpars'
# @return a single logical value
has_ndt <- function(family) {
  dpars <- family_info(family, "dpars")
  classes <- dpar_class(dpars)
  "ndt" %in% classes
}
Expand Down
Loading

0 comments on commit 285adf2

Please sign in to comment.