Skip to content

Commit

Permalink
edits
Browse files Browse the repository at this point in the history
  • Loading branch information
DominiqueMakowski committed May 1, 2024
1 parent 08b374d commit 7b3ed6e
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 56 deletions.
124 changes: 74 additions & 50 deletions R/report.compare.loo.R
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
#' Automatically report the results of Bayesian model comparison using the `loo` package.
#'
#' @param x An object of class [brms::loo_compare].
#' @param index type if index to report - expected log pointwise predictive
#' density (ELPD) or information criteria (IC).
#' @param include_IC Whether to include the information criteria (IC).
#' @param include_ENP Whether to include the effective number of parameters (ENP).
#' @param ... Additional arguments (not used for now).
#'
#' @examplesIf require("brms", quietly = TRUE)
Expand All @@ -13,24 +13,33 @@
#'
#' m1 <- brms::brm(mpg ~ qsec, data = mtcars)
#' m2 <- brms::brm(mpg ~ qsec + drat, data = mtcars)
#' m3 <- brms::brm(mpg ~ qsec + drat + wt, data = mtcars)
#'
#' x <- brms::loo_compare(brms::add_criterion(m1, "loo"),
#' x <- brms::loo_compare(
#' brms::add_criterion(m1, "loo"),
#' brms::add_criterion(m2, "loo"),
#' model_names = c("m1", "m2")
#' brms::add_criterion(m3, "loo"),
#' model_names = c("m1", "m2", "m3")
#' )
#' report(x)
#' report(x, include_IC = FALSE)
#' report(x, include_ENP = TRUE)
#' }
#'
#' @details
#' The rule of thumb is that the models are "very similar" if |elpd_diff| (the
#' absolute value of elpd_diff) is less than 4 (Sivula, Magnusson and Vehtari, 2020).
#' If superior to 4, then one can use the SE to obtain a standardized difference
#' (Z-diff) and interpret it as such, assuming that the difference is normally
#' distributed.
#' distributed. The corresponding p-value is then calculated as `2 * pnorm(-abs(Z-diff))`.
#' However, note that if the raw ELPD difference is small (less than 4), it doesn't
#' make much sense to rely on its standardized value: it is not very useful to
#' conclude that a model is much better than another if both models make very
#' similar predictions.
#'
#' @return Objects of class [report_text()].
#' @export
report.compare.loo <- function(x, index = c("ELPD", "IC"), ...) {
report.compare.loo <- function(x, include_IC = TRUE, include_ENP = FALSE, ...) {
# nolint start
# https://stats.stackexchange.com/questions/608881/how-to-interpret-elpd-diff-of-bayesian-loo-estimate-in-bayesian-logistic-regress
# nolint end
Expand All @@ -40,72 +49,87 @@ report.compare.loo <- function(x, index = c("ELPD", "IC"), ...) {
# The difference in expected log predictive density (elpd) between each model
# and the best model as well as the standard error of this difference (assuming
# the difference is approximately normal).
index <- match.arg(index)
x <- as.data.frame(x)

# The values in the first row are 0s because the models are ordered from best to worst according to their elpd.
x <- as.data.frame(x)
modnames <- rownames(x)

elpd_diff <- x[["elpd_diff"]]
se_elpd_diff <- x[["se_diff"]]
ic_diff <- -2 * elpd_diff

z_elpd_diff <- elpd_diff / x[["se_diff"]]
z_elpd_diff <- elpd_diff / se_elpd_diff
p_elpd_diff <- 2 * pnorm(-abs(z_elpd_diff))
z_ic_diff <- -z_elpd_diff

if ("looic" %in% colnames(x)) {
type <- "LOO"
ENP <- x[["p_loo"]]
elpd <- x[["elpd_loo"]]
enp <- x[["p_loo"]]
index_label <- "ELPD-LOO"
ic <- x[["looic"]]
index_ic <- "LOOIC"
} else {
type <- "WAIC"
ENP <- x[["p_waic"]]
elpd <- x[["elpd_waic"]]
enp <- x[["p_waic"]]
index_label <- "ELPD-WAIC"
ic <- x[["waic"]]
index_ic <- "WAIC"
}

if (index == "ELPD") {
index_label <- sprintf("Expected Log Predictive Density (ELPD-%s)", type)
} else if (type == "LOO") {
index_label <- "Leave-One-Out CV Information Criterion (LOOIC)"
} else {
index_label <- "Widely Applicable Information Criterion (WAIC)"
}
# TODO: The above indices-computation and name-matching should be implemented
# in a parameters.compare.loo() function which would be run here.

out_text <- sprintf(
paste(
"The difference in predictive accuracy, as index by %s, suggests that '%s' ",
"is the best model (effective number of parameters (ENP) = %.2f), followed by"
# Starting text -----
text <- sprintf(

Check warning on line 83 in R/report.compare.loo.R

View workflow job for this annotation

GitHub Actions / lint-changed-files / lint-changed-files

file=R/report.compare.loo.R,line=83,col=3,[object_overwrite_linter] 'text' is an exported object from package 'graphics'. Avoid re-using such symbols.
paste0(
"The difference in predictive accuracy, as indexed by Expected Log ",
"Predictive Density (%s), suggests that '%s' is the best model ("
),
index_label, modnames[1], ENP[1]
index_label, modnames[1]
)

if (index == "ELPD") {
other_texts <- sprintf(
"'%s' (diff = %.2f, ENP = %.2f, z-diff = %.2f)",
modnames[-1],
elpd_diff[-1],
ENP[-1],
z_elpd_diff[-1]
)
if(all(c(include_IC, include_ENP) == FALSE)) {

Check warning on line 90 in R/report.compare.loo.R

View workflow job for this annotation

GitHub Actions / lint-changed-files / lint-changed-files

file=R/report.compare.loo.R,line=90,col=5,[spaces_left_parentheses_linter] Place a space before left parenthesis, except in a function call.

Check warning on line 90 in R/report.compare.loo.R

View workflow job for this annotation

GitHub Actions / lint-changed-files / lint-changed-files

file=R/report.compare.loo.R,line=90,col=10,[redundant_equals_linter] Using == on a logical vector is redundant. Well-named logical vectors can be used directly in filtering. For data.table's `i` argument, wrap the column name in (), like `DT[(is_treatment)]`.
text <- sprintf(paste0(text, "ELPD = %.2f)"), elpd[1])

Check warning on line 91 in R/report.compare.loo.R

View workflow job for this annotation

GitHub Actions / lint-changed-files / lint-changed-files

file=R/report.compare.loo.R,line=91,col=5,[object_overwrite_linter] 'text' is an exported object from package 'graphics'. Avoid re-using such symbols.
} else {
other_texts <- sprintf(
"'%s' (diff = %.2f, ENP = %.2f, z-diff = %.2f)",
modnames[-1],
ic_diff[-1],
ENP[-1],
z_ic_diff[-1]
)
if(include_IC) {

Check warning on line 93 in R/report.compare.loo.R

View workflow job for this annotation

GitHub Actions / lint-changed-files / lint-changed-files

file=R/report.compare.loo.R,line=93,col=7,[spaces_left_parentheses_linter] Place a space before left parenthesis, except in a function call.
text <- sprintf(paste0(text, "%s = %.2f"), index_ic, ic[1])

Check warning on line 94 in R/report.compare.loo.R

View workflow job for this annotation

GitHub Actions / lint-changed-files / lint-changed-files

file=R/report.compare.loo.R,line=94,col=7,[object_overwrite_linter] 'text' is an exported object from package 'graphics'. Avoid re-using such symbols.
}
if(include_ENP) {

Check warning on line 96 in R/report.compare.loo.R

View workflow job for this annotation

GitHub Actions / lint-changed-files / lint-changed-files

file=R/report.compare.loo.R,line=96,col=7,[spaces_left_parentheses_linter] Place a space before left parenthesis, except in a function call.
if(include_IC) {

Check warning on line 97 in R/report.compare.loo.R

View workflow job for this annotation

GitHub Actions / lint-changed-files / lint-changed-files

file=R/report.compare.loo.R,line=97,col=9,[spaces_left_parentheses_linter] Place a space before left parenthesis, except in a function call.
text <- sprintf(paste0(text, ", ENP = %.2f)"), enp[1])

Check warning on line 98 in R/report.compare.loo.R

View workflow job for this annotation

GitHub Actions / lint-changed-files / lint-changed-files

file=R/report.compare.loo.R,line=98,col=9,[object_overwrite_linter] 'text' is an exported object from package 'graphics'. Avoid re-using such symbols.
} else {
text <- sprintf(paste0(text, "ENP = %.2f)"), enp[1])

Check warning on line 100 in R/report.compare.loo.R

View workflow job for this annotation

GitHub Actions / lint-changed-files / lint-changed-files

file=R/report.compare.loo.R,line=100,col=9,[object_overwrite_linter] 'text' is an exported object from package 'graphics'. Avoid re-using such symbols.
}
} else {
text <- paste0(text, ")")
}
}

sep <- "."
nothermods <- length(other_texts)
if (nothermods > 1L) {
if (nothermods == 2L) {
sep <- c(" and ", sep)
# Other models ---
text_models <- sprintf("'%s' (diff-ELPD = %.2f +- %.2f, %s",
modnames[-1],
elpd_diff[-1],
se_elpd_diff[-1],
insight::format_p(p_elpd_diff[-1]))

if(all(c(include_IC, include_ENP) == FALSE)) {
text_models <- paste0(text_models, ")")
} else {
if(include_IC) {
text_models <- sprintf(paste0(text_models, ", %s = %.2f"), index_ic, ic[-1])
}
if(include_ENP) {
if(include_IC) {
text_models <- sprintf(paste0(text_models, ", ENP = %.2f)"), enp[-1])
} else {
text_models <- sprintf(paste0(text_models, "ENP = %.2f)"), enp[-1])
}
} else {
sep <- c(rep(", ", length = nothermods - 2), ", and ", sep)
text_models <- sprintf(paste0(text_models, ")"))
}
}

other_texts <- paste0(other_texts, sep, collapse = "")

out_text <- paste(out_text, other_texts, collapse = "")
text <- paste0(text, ", followed by ", datawizard::text_concatenate( text_models))
class(text) <- c("report_text", class(text))
out_text
text
}
20 changes: 14 additions & 6 deletions man/report.compare.loo.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 7b3ed6e

Please sign in to comment.