diff --git a/DESCRIPTION b/DESCRIPTION index 3d0f7ec9f..56a35698e 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,7 +1,7 @@ Type: Package Package: see Title: Model Visualisation Toolbox for 'easystats' and 'ggplot2' -Version: 0.8.3.3 +Version: 0.8.3.4 Authors@R: c(person(given = "Daniel", family = "Lüdecke", diff --git a/NEWS.md b/NEWS.md index 7a00b2417..565575a7e 100644 --- a/NEWS.md +++ b/NEWS.md @@ -8,6 +8,9 @@ downstream plot-functions (i.e., `plot()` for `check_model()` passes arguments to change geom sizes to the underlying plot-functions). +* `plot()` for `check_predictions()` now supports Bayesian regression models from + *brms* and *rstanarm*. + # see 0.8.3 ## Major changes diff --git a/R/plot.check_predictions.R b/R/plot.check_predictions.R index d69f79dca..896c416c3 100644 --- a/R/plot.check_predictions.R +++ b/R/plot.check_predictions.R @@ -1,5 +1,19 @@ #' @export data_plot.performance_pp_check <- function(x, type = "density", ...) { + # for data from "bayesplot::pp_check()", data is already in shape + if (isTRUE(attributes(x)$is_stan) && type != "density") { + class(x) <- c("data_plot", "see_performance_pp_check", "data.frame") + attr(x, "info") <- list( + xlab = attr(x, "response_name"), + ylab = ifelse(identical(type, "density"), "Density", "Counts"), + title = "Posterior Predictive Check", + check_range = attr(x, "check_range"), + bandwidth = attr(x, "bandwidth"), + model_info = attr(x, "model_info") + ) + return(x) + } + columns <- colnames(x) dataplot <- stats::reshape( x, @@ -88,6 +102,7 @@ print.see_performance_pp_check <- function(x, orig_x <- x check_range <- isTRUE(attributes(x)$check_range) plot_type <- attributes(x)$type + is_stan <- attributes(x)$is_stan if (missing(type) && !is.null(plot_type) && plot_type %in% c("density", "discrete_dots", "discrete_interval", "discrete_both")) { type <- plot_type @@ -111,6 +126,7 @@ print.see_performance_pp_check <- function(x, size_axis_title = size_axis_title, type = type, x_limits = x_limits, + 
is_stan = is_stan, ... ) @@ -143,6 +159,7 @@ plot.see_performance_pp_check <- function(x, orig_x <- x check_range <- isTRUE(attributes(x)$check_range) plot_type <- attributes(x)$type + is_stan <- attributes(x)$is_stan if (missing(type) && !is.null(plot_type) && plot_type %in% c("density", "discrete_dots", "discrete_interval", "discrete_both")) { # nolint type <- plot_type @@ -166,6 +183,7 @@ plot.see_performance_pp_check <- function(x, colors = colors, type = type, x_limits = x_limits, + is_stan = is_stan, ... ) @@ -190,9 +208,16 @@ plot.see_performance_pp_check <- function(x, colors, type = "density", x_limits = NULL, + is_stan = NULL, ...) { info <- attr(x, "info") + # discrete plot type from "bayesplot::pp_check()" returns a different data + # structure, so we need to handle it differently + if (isTRUE(is_stan) && type != "density") { + return(.plot_check_predictions_stan_dots(x, colors, info, size_line, size_point, line_alpha, ...)) + } + # default bandwidth, for smooting bandwidth <- info$bandwidth if (is.null(bandwidth)) { @@ -450,6 +475,71 @@ plot.see_performance_pp_check <- function(x, } +.plot_check_predictions_stan_dots <- function(x, + colors, + info, + size_line, + size_point, + line_alpha, + ...) 
{ + # for discrete plot types of Stan models: draws observed counts as points and model-predicted counts as point ranges (CI_low to CI_high) + # first, rename the group labels, so they match the keys used in "scale_color_manual()" below + x$Group[x$Group == "y"] <- "Observed data" + x$Group[x$Group == "Mean"] <- "Model-predicted data" + + # sanity check, remove rows where "Count" is NA + x <- x[!is.na(x$Count), ] + + p <- ggplot2::ggplot() + + ggplot2::geom_pointrange( + data = x[x$Group == "Model-predicted data", ], + mapping = ggplot2::aes( + x = .data$x, + y = .data$Count, + ymin = .data$CI_low, + ymax = .data$CI_high, + color = .data$Group + ), + position = ggplot2::position_nudge(x = 0.2), + size = 0.4 * size_point, + linewidth = size_line, + stroke = 0, + shape = 16 + ) + + ggplot2::geom_point( + data = x[x$Group == "Observed data", ], + mapping = ggplot2::aes( + x = .data$x, + y = .data$Count, + color = .data$Group + ), + size = 1.5 * size_point, + stroke = 0, + shape = 16 + ) + + ggplot2::scale_y_continuous() + + ggplot2::scale_color_manual(values = c( + "Observed data" = colors[1], + "Model-predicted data" = colors[2] + )) + + ggplot2::labs( + x = info$xlab, + y = info$ylab, + color = "", + size = "", + alpha = "", + title = "Posterior Predictive Check", + subtitle = "Model-predicted intervals should include observed data points" + ) + + ggplot2::guides( + color = ggplot2::guide_legend(reverse = TRUE), + size = ggplot2::guide_legend(reverse = TRUE) + ) + + return(p) +} + + .plot_pp_check_range <- function(x, size_bar = 0.7, colors = unname(social_colors(c("green", "blue")))) {