Skip to content

Commit

Permalink
Merge pull request #1643 from avehtari/loo_epred
Browse files Browse the repository at this point in the history
add loo_epred
  • Loading branch information
paul-buerkner authored May 27, 2024
2 parents 7174bea + ea557cf commit 652a7c0
Show file tree
Hide file tree
Showing 6 changed files with 136 additions and 42 deletions.
4 changes: 2 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ Package: brms
Encoding: UTF-8
Type: Package
Title: Bayesian Regression Models using 'Stan'
Version: 2.21.3
Date: 2024-05-09
Version: 2.21.4
Date: 2024-05-27
Authors@R:
c(person("Paul-Christian", "Bürkner", email = "[email protected]",
role = c("aut", "cre")),
Expand Down
2 changes: 2 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,7 @@ S3method(log_posterior,brmsfit)
S3method(loo,brmsfit)
S3method(loo_R2,brmsfit)
S3method(loo_compare,brmsfit)
S3method(loo_epred,brmsfit)
S3method(loo_linpred,brmsfit)
S3method(loo_model_weights,brmsfit)
S3method(loo_moment_match,brmsfit)
Expand Down Expand Up @@ -471,6 +472,7 @@ export(lognormal)
export(loo)
export(loo_R2)
export(loo_compare)
export(loo_epred)
export(loo_linpred)
export(loo_model_weights)
export(loo_moment_match)
Expand Down
6 changes: 6 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# brms 2.21.0++

### New Features

* Add method `loo_epred` thanks to Aki Vehtari. (#1641)

### Bug Fixes

* Fix a bug that led to partially duplicated Stan code in multilevel terms
Expand All @@ -10,6 +14,8 @@ thanks to Henrik Singmann. (#1651)
* Refactor some of the internal code base to avoid evaluating
many data-dependent quantities several times. (#1653)
* Make argument `loo` optional in `loo_moment_match`.
* Change the output format of `loo_predict` and `loo_linpred` to be
more consistent with other post-processing functions.

# brms 2.21.0

Expand Down
93 changes: 69 additions & 24 deletions R/loo_predict.R
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
#' These functions are wrappers around the \code{\link[loo]{E_loo}}
#' function of the \pkg{loo} package.
#'
#' @aliases loo_predict loo_linpred loo_predictive_interval
#' @aliases loo_predict loo_epred loo_linpred loo_predictive_interval
#'
#' @param object An object of class \code{brmsfit}.
#' @param type The statistic to be computed on the results.
Expand All @@ -19,22 +19,21 @@
#' internally, which may be time consuming for models fit to very large datasets.
#' @param ... Optional arguments passed to the underlying methods, that is,
#' \code{\link[brms:log_lik.brmsfit]{log_lik}}, as well as
#' \code{\link[brms:posterior_predict.brmsfit]{posterior_predict}} or
#' \code{\link[brms:posterior_predict.brmsfit]{posterior_predict}},
#' \code{\link[brms:posterior_epred.brmsfit]{posterior_epred}} or
#' \code{\link[brms:posterior_linpred.brmsfit]{posterior_linpred}}.
#' @inheritParams posterior_predict.brmsfit
#'
#' @return \code{loo_predict} and \code{loo_linpred} return a vector with one
#' element per observation. The only exception is if \code{type = "quantile"}
#' and \code{length(probs) >= 2}, in which case a separate vector for each
#' element of \code{probs} is computed and they are returned in a matrix with
#' \code{length(probs)} rows and one column per observation.
#' @return \code{loo_predict}, \code{loo_epred}, \code{loo_linpred}, and
#' \code{loo_predictive_interval} all return a matrix with one row per
#' observation and one column per summary statistic as specified by
#' arguments \code{type} and \code{probs}. In multivariate or categorical models
#' a third dimension is added to represent the response variables or categories,
#' respectively.
#'
#' \code{loo_predictive_interval} returns a matrix with one row per
#' observation and two columns.
#' \code{loo_predictive_interval(..., prob = p)} is equivalent to
#' \code{loo_predict(..., type = "quantile", probs = c(a, 1-a))} with
#' \code{a = (1 - p)/2}, except it transposes the result and adds informative
#' column names.
#' \code{a = (1 - p)/2}.
#'
#' @examples
#' \dontrun{
Expand All @@ -52,6 +51,7 @@
#' psis <- loo::psis(-log_lik(fit), cores = 2)
#' loo_predictive_interval(fit, prob = 0.8, psis_object = psis)
#' loo_predict(fit, type = "var", psis_object = psis)
#' loo_epred(fit, type = "var", psis_object = psis)
#' }
#'
#' @method loo_predict brmsfit
Expand All @@ -62,13 +62,37 @@ loo_predict.brmsfit <- function(object, type = c("mean", "var", "quantile"),
probs = 0.5, psis_object = NULL, resp = NULL,
...) {
type <- match.arg(type)
stopifnot_resp(object, resp)
if (is.null(psis_object)) {
message("Running PSIS to compute weights")
psis_object <- compute_loo(object, criterion = "psis", resp = resp, ...)
}
preds <- posterior_predict(object, resp = resp, ...)
loo::E_loo(preds, psis_object, type = type, probs = probs)$value
E_loo_value(preds, psis_object, type = type, probs = probs)
}

# #' @importFrom rstantools loo_epred
#' @rdname loo_predict.brmsfit
#' @method loo_epred brmsfit
#' @export loo_epred
#' @export
loo_epred.brmsfit <- function(object, type = c("mean", "var", "quantile"),
                              probs = 0.5, psis_object = NULL, resp = NULL,
                              ...) {
  # resolve 'type' to exactly one of the supported summary statistics
  type <- match.arg(type)
  # stopifnot_resp(object, resp)  # disabled; NOTE(review): confirm removal
  # PSIS weights are only computed if the caller did not supply them
  if (is.null(psis_object)) {
    message("Running PSIS to compute weights")
    psis_object <- compute_loo(object, criterion = "psis", resp = resp, ...)
  }
  # draws as returned by posterior_epred(), forwarding 'resp' and '...'
  epred_draws <- posterior_epred(object, resp = resp, ...)
  # weighted summaries in the layout produced by E_loo_value()
  E_loo_value(epred_draws, psis_object, type = type, probs = probs)
}

#' @rdname loo_predict.brmsfit
#' @export
loo_epred <- function(object, ...) {
  # S3 generic: the method is selected based on the class of 'object'
  # TODO: drop this local generic as soon as rstantools provides one
  UseMethod("loo_epred", object)
}

#' @rdname loo_predict.brmsfit
Expand All @@ -80,18 +104,12 @@ loo_linpred.brmsfit <- function(object, type = c("mean", "var", "quantile"),
probs = 0.5, psis_object = NULL, resp = NULL,
...) {
type <- match.arg(type)
stopifnot_resp(object, resp)
family <- family(object, resp = resp)
if (is_ordinal(family) || is_categorical(family)) {
stop2("Method 'loo_linpred' is not implemented ",
"for categorical or ordinal models")
}
if (is.null(psis_object)) {
message("Running PSIS to compute weights")
psis_object <- compute_loo(object, criterion = "psis", resp = resp, ...)
}
preds <- posterior_linpred(object, resp = resp, ...)
loo::E_loo(preds, psis_object, type = type, probs = probs)$value
E_loo_value(preds, psis_object, type = type, probs = probs)
}

#' @rdname loo_predict.brmsfit
Expand All @@ -106,13 +124,40 @@ loo_predictive_interval.brmsfit <- function(object, prob = 0.9,
}
alpha <- (1 - prob) / 2
probs <- c(alpha, 1 - alpha)
labs <- paste0(100 * probs, "%")
intervals <- loo_predict(
object, type = "quantile", probs = probs,
psis_object = psis_object, ...
)
rownames(intervals) <- labs
t(intervals)
intervals
}

# Convenient wrapper around loo::E_loo that arranges its output so that
# observations are stored as rows and summary statistics as columns.
#
# @param x Draws to summarize: either passed to loo::E_loo as is, or, if it
#   is a 3D array, processed slice by slice along its third dimension.
#   NOTE(review): presumably draws x observations (x responses/categories);
#   confirm against the callers.
# @param psis_object Object forwarded unchanged to loo::E_loo.
# @param type Summary statistic forwarded to loo::E_loo; also used as the
#   column name of the output unless it is "quantile".
# @param probs Probabilities forwarded to loo::E_loo; for type "quantile"
#   they define the column names ("q" followed by the percentage).
# @return A matrix with named columns; for 3D input the per-slice matrices
#   bound together along a new last dimension via abind::abind.
E_loo_value <- function(x, psis_object, type = "mean", probs = 0.5) {
  .E_loo_value <- function(x) {
    y <- loo::E_loo(x, psis_object, type = type, probs = probs)$value
    # loo::E_loo has output dimensions inconsistent with brms conventions
    # ensure that observations are stored as rows and summaries as columns
    if (is.matrix(y) && ncol(x) == ncol(y)) {
      y <- t(y)
    } else if (is.vector(y)) {
      # create a matrix with one column representing the summary statistic
      y <- matrix(y)
    }
    # ensure names consistent with the posterior package
    labs <- type
    if (labs == "quantile") {
      labs <- paste0("q", probs * 100)
    }
    colnames(y) <- labs
    y
  }
  if (length(dim(x)) == 3) {
    # Slice manually rather than via apply(x, 3, ., simplify = FALSE): the
    # 'simplify' argument of apply() only exists in R >= 4.1.0 and would be
    # forwarded to .E_loo_value (and fail) on older R versions.
    slice_dim <- dim(x)[1:2]
    slice_dimnames <- dimnames(x)[1:2]
    out <- lapply(seq_len(dim(x)[3]), function(k) {
      # rebuild the slice explicitly so that it stays a matrix even if one
      # of the first two dimensions has length one
      .E_loo_value(array(x[, , k], dim = slice_dim, dimnames = slice_dimnames))
    })
    # keep the names of the third dimension so abind can label the result
    names(out) <- dimnames(x)[[3]]
    out <- abind::abind(out, rev.along = 0)
  } else {
    out <- .E_loo_value(x)
  }
  out
}

#' Compute a LOO-adjusted R-squared for regression models
Expand Down Expand Up @@ -204,7 +249,7 @@ loo_R2.brmsfit <- function(object, resp = NULL, summary = TRUE,
ypredloo <- loo::E_loo(ypred, psis_object, log_ratios = -ll)$value
err_loo <- ypredloo - y

# simulated dirichlet weights
# simulated Dirichlet weights
S <- nrow(ypred)
N <- ncol(ypred)
exp_draws <- matrix(rexp(S * N, rate = 1), nrow = S, ncol = N)
Expand Down
33 changes: 23 additions & 10 deletions man/loo_predict.brmsfit.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

40 changes: 34 additions & 6 deletions tests/testthat/tests.brmsfit-methods.R
Original file line number Diff line number Diff line change
Expand Up @@ -479,21 +479,47 @@ test_that("loo_R2 has reasonable outputs", {
expect_equal(dim(R2), c(ndraws(fit1), 1))
})

test_that("loo_epred has reasonable outputs", {
  skip_on_cran()

  # default call: the result has one row per observation of fit1
  epred_default <- SW(loo_epred(fit1))
  expect_equal(nrow(epred_default), nobs(fit1))

  # two quantiles on new data: rows follow newdata, one column per prob
  nd <- data.frame(
    Age = 0, visit = c("a", "b"), Trt = 0,
    count = 20, patient = 1, Exp = 2, volume = 0
  )
  epred_quantiles <- SW(loo_epred(
    fit1,
    newdata = nd,
    type = "quantile",
    probs = c(0.25, 0.75),
    allow_new_levels = TRUE
  ))
  expect_equal(dim(epred_quantiles), c(nrow(nd), 2))

  # fit4: the result carries a third dimension of length 4
  epred_fit4 <- SW(loo_epred(fit4))
  expect_equal(nrow(epred_fit4), nobs(fit4))
  expect_equal(dim(epred_fit4)[3], 4)
})

test_that("loo_linpred has reasonable outputs", {
skip_on_cran()

llp <- SW(loo_linpred(fit1))
expect_equal(length(llp), nobs(fit1))
expect_error(loo_linpred(fit4), "Method 'loo_linpred'")
expect_equal(nrow(llp), nobs(fit1))

llp <- SW(loo_linpred(fit4))
expect_equal(nrow(llp), nobs(fit4))
expect_equal(dim(llp)[3], 3)

llp <- SW(loo_linpred(fit2, scale = "response", type = "var"))
expect_equal(length(llp), nobs(fit2))
expect_equal(nrow(llp), nobs(fit2))
})

test_that("loo_predict has reasonable outputs", {
skip_on_cran()

llp <- SW(loo_predict(fit1))
expect_equal(length(llp), nobs(fit1))
expect_equal(nrow(llp), nobs(fit1))

newdata <- data.frame(
Age = 0, visit = c("a", "b"), Trt = 0,
Expand All @@ -504,9 +530,11 @@ test_that("loo_predict has reasonable outputs", {
type = "quantile", probs = c(0.25, 0.75),
allow_new_levels = TRUE
))
expect_equal(dim(llp), c(2, nrow(newdata)))
expect_equal(dim(llp), c(nrow(newdata), 2))

llp <- SW(loo_predict(fit4))
expect_equal(length(llp), nobs(fit4))
expect_equal(nrow(llp), nobs(fit4))
expect_equal(length(dim(llp)), 2)
})

test_that("loo_predictive_interval has reasonable outputs", {
Expand Down

0 comments on commit 652a7c0

Please sign in to comment.