diff --git a/R/data_plot.R b/R/data_plot.R index 13f4730ef..11827629a 100644 --- a/R/data_plot.R +++ b/R/data_plot.R @@ -84,7 +84,7 @@ #' x <- equivalence_test(model, verbose = FALSE) #' plot(x) #' @export -data_plot <- function(x, data = NULL, ...) { +data_plot <- function(x, ...) { UseMethod("data_plot") } diff --git a/R/plot.check_model.R b/R/plot.check_model.R index 70a1ccde6..8f3b4e87c 100644 --- a/R/plot.check_model.R +++ b/R/plot.check_model.R @@ -3,6 +3,7 @@ plot.see_check_model <- function(x, style = theme_lucid, colors = NULL, + type = c("density", "discrete_dots", "discrete_interval", "discrete_both"), ...) { p <- list() @@ -21,6 +22,7 @@ plot.see_check_model <- function(x, model_info <- attr(x, "model_info") overdisp_type <- attr(x, "overdisp_type") + type <- match.arg(type) # set default values for arguments ------ @@ -59,6 +61,7 @@ plot.see_check_model <- function(x, p$PP_CHECK <- plot.see_performance_pp_check( x$PP_CHECK, style = style, + type = type, check_model = TRUE, adjust_legend = TRUE, colors = colors[1:2] diff --git a/R/plot.performance_pp_check.R b/R/plot.check_predictions.R similarity index 51% rename from R/plot.performance_pp_check.R rename to R/plot.check_predictions.R index e4796650c..98922e473 100644 --- a/R/plot.performance_pp_check.R +++ b/R/plot.check_predictions.R @@ -1,5 +1,5 @@ #' @export -data_plot.performance_pp_check <- function(x, ...) { +data_plot.performance_pp_check <- function(x, type = "density", ...) { columns <- colnames(x) dataplot <- stats::reshape( x, @@ -21,10 +21,11 @@ data_plot.performance_pp_check <- function(x, ...) 
{ attr(dataplot, "info") <- list( "xlab" = attr(x, "response_name"), - "ylab" = "Density", + "ylab" = ifelse(identical(type, "density"), "Density", "Counts"), "title" = "Posterior Predictive Check", "check_range" = attr(x, "check_range"), - "bandwidth" = attr(x, "bandwidth") + "bandwidth" = attr(x, "bandwidth"), + "model_info" = attr(x, "model_info") ) class(dataplot) <- unique(c("data_plot", "see_performance_pp_check", class(dataplot))) @@ -40,6 +41,9 @@ data_plot.performance_pp_check <- function(x, ...) { #' #' @param line_alpha Numeric value specifying alpha of lines indicating `yrep`. #' @param style A ggplot2-theme. +#' @param type Plot type for the posterior predictive checks plot. Can be `"density"` +#' (default), `"discrete_dots"`, `"discrete_interval"` or `"discrete_both"` (the +#' `discrete_*` options are only for models with binary, integer or ordinal outcomes). #' @inheritParams data_plot #' @inheritParams plot.see_check_normality #' @inheritParams plot.see_parameters_distribution @@ -48,7 +52,18 @@ data_plot.performance_pp_check <- function(x, ...) { #' #' @examplesIf require("performance") #' model <- lm(Sepal.Length ~ Species * Petal.Width + Petal.Length, data = iris) -#' check_posterior_predictions(model) +#' check_predictions(model) +#' +#' # dot-plot style for count-models +#' d <- iris +#' d$poisson_var <- rpois(150, 1) +#' model <- glm( +#' poisson_var ~ Species + Petal.Length + Petal.Width, +#' data = d, +#' family = poisson() +#' ) +#' out <- check_predictions(model) +#' plot(out, type = "discrete_dots") #' @export print.see_performance_pp_check <- function(x, size_line = 0.5, @@ -56,15 +71,17 @@ print.see_performance_pp_check <- function(x, size_bar = 0.7, style = theme_lucid, colors = unname(social_colors(c("green", "blue"))), + type = c("density", "discrete_dots", "discrete_interval", "discrete_both"), ...) 
{ orig_x <- x check_range <- isTRUE(attributes(x)$check_range) + type <- match.arg(type) if (!inherits(x, "data_plot")) { - x <- data_plot(x) + x <- data_plot(x, type) } - p1 <- .plot_pp_check(x, size_line, line_alpha, theme_style = style, colors = colors, ...) + p1 <- .plot_pp_check(x, size_line, line_alpha, theme_style = style, colors = colors, type = type, ...) if (isTRUE(check_range)) { p2 <- .plot_pp_check_range(orig_x, size_bar, colors = colors) @@ -85,15 +102,17 @@ plot.see_performance_pp_check <- function(x, size_bar = 0.7, style = theme_lucid, colors = unname(social_colors(c("green", "blue"))), + type = c("density", "discrete_dots", "discrete_interval", "discrete_both"), ...) { orig_x <- x check_range <- isTRUE(attributes(x)$check_range) + type <- match.arg(type) if (!inherits(x, "data_plot")) { - x <- data_plot(x) + x <- data_plot(x, type) } - p1 <- .plot_pp_check(x, size_line, line_alpha, theme_style = style, colors = colors, ...) + p1 <- .plot_pp_check(x, size_line, line_alpha, theme_style = style, colors = colors, type = type, ...) if (isTRUE(check_range)) { p2 <- .plot_pp_check_range(orig_x, size_bar, colors = colors) @@ -105,7 +124,7 @@ plot.see_performance_pp_check <- function(x, -.plot_pp_check <- function(x, size_line, line_alpha, theme_style, colors, ...) { +.plot_pp_check <- function(x, size_line, line_alpha, theme_style, colors, type = "density", ...) { info <- attr(x, "info") # default bandwidth, for smooting @@ -114,7 +133,46 @@ plot.see_performance_pp_check <- function(x, bandwidth <- "nrd" } - out <- ggplot2::ggplot(x) + + minfo <- info$model_info + suggest_dots <- (minfo$is_bernoulli || minfo$is_count || minfo$is_ordinal || minfo$is_categorical) + + if (!is.null(type) && type %in% c("discrete_dots", "discrete_interval", "discrete_both") && suggest_dots) { + out <- .plot_check_predictions_dots(x, colors, info, size_line, line_alpha, type, ...) 
+ } else { + if (suggest_dots) { + insight::format_alert( + "The model has an integer or a categorical response variable.", + "It is recommended to switch to a dot-plot style, e.g. `plot(check_model(model), type = \"discrete_dots\")`." + ) + } + # density plot - for models that have no binary or count/ordinal outcome + out <- .plot_check_predictions_density(x, colors, info, size_line, line_alpha, bandwidth, ...) + } + + + dots <- list(...) + if (isTRUE(dots[["check_model"]])) { + out <- out + theme_style( + base_size = 10, + plot.title.space = 3, + axis.title.space = 5 + ) + } + + if (isTRUE(dots[["adjust_legend"]]) || isTRUE(info$check_range)) { + out <- out + ggplot2::theme( + legend.position = "bottom", + legend.margin = ggplot2::margin(0, 0, 0, 0), + legend.box.margin = ggplot2::margin(-5, -5, -5, -5) + ) + } + + out +} + + +.plot_check_predictions_density <- function(x, colors, info, size_line, line_alpha, bandwidth, ...) { + ggplot2::ggplot(x) + ggplot2::stat_density( mapping = ggplot2::aes( x = .data$values, @@ -159,26 +217,150 @@ plot.see_performance_pp_check <- function(x, color = ggplot2::guide_legend(reverse = TRUE), size = ggplot2::guide_legend(reverse = TRUE) ) +} - dots <- list(...) - if (isTRUE(dots[["check_model"]])) { - out <- out + theme_style( - base_size = 10, - plot.title.space = 3, - axis.title.space = 5 - ) +.plot_check_predictions_dots <- function(x, colors, info, size_line, line_alpha, type = "discrete_dots", ...) 
{ + # make sure we have a factor, so "table()" generates frequencies for all levels + # for each group - we need tables of same size to bind data frames + x$values <- as.factor(x$values) + x <- stats::aggregate(x["values"], list(grp = x$grp), table) + x <- cbind(data.frame(key = "Model-predicted data", stringsAsFactors = FALSE), x) + x <- cbind(x[1:2], as.data.frame(x[[3]])) + x$key[nrow(x)] <- "Observed data" + x <- datawizard::data_to_long(x, select = -1:-2, names_to = "x", values_to = "count") + if (insight::n_unique(x$x) > 8) { + x$x <- datawizard::to_numeric(x$x) } - if (isTRUE(dots[["adjust_legend"]]) || isTRUE(info$check_range)) { - out <- out + ggplot2::theme( - legend.position = "bottom", - legend.margin = ggplot2::margin(0, 0, 0, 0), - legend.box.margin = ggplot2::margin(-5, -5, -5, -5) + p1 <- p2 <- NULL + + if (!is.null(type) && type %in% c("discrete_interval", "discrete_both")) { + centrality_dispersion <- function(i) { + c( + count = stats::median(i, na.rm = TRUE), + unlist(bayestestR::ci(i)[c("CI_low", "CI_high")]) + ) + } + x_errorbars <- stats::aggregate(x["count"], list(x$x), centrality_dispersion) + x_errorbars <- cbind(x_errorbars[1], as.data.frame(x_errorbars[[2]])) + colnames(x_errorbars) <- c("x", "count", "CI_low", "CI_high") + x_errorbars <- cbind( + data.frame(key = "Model-predicted data", stringsAsFactors = FALSE), + x_errorbars ) + + x_tmp <- x[x$key == "Observed data", ] + x_tmp$CI_low <- NA + x_tmp$CI_high <- NA + x_tmp$grp <- NULL + + x_errorbars <- rbind(x_errorbars, x_tmp) + p1 <- ggplot2::ggplot() + ggplot2::geom_pointrange( + data = x_errorbars[x_errorbars$key == "Model-predicted data", ], + mapping = ggplot2::aes( + x = .data$x, + y = .data$count, + ymin = .data$CI_low, + ymax = .data$CI_high, + color = .data$key + ), + position = ggplot2::position_nudge(x = 0.2), + size = 1.5 * size_line, + linewidth = 1.5 * size_line, + stroke = 0, + shape = 16 + ) + + ggplot2::geom_point( + data = x_errorbars[x_errorbars$key == "Observed 
data", ], + mapping = ggplot2::aes( + x = .data$x, + y = .data$count, + color = .data$key + ), + size = 6 * size_line, + stroke = 0, + shape = 16 + ) } - out + if (!is.null(type) && type %in% c("discrete_dots", "discrete_both")) { + if (is.null(p1)) { + p2 <- ggplot2::ggplot() + } else { + p2 <- p1 + } + p2 <- p2 + ggplot2::geom_point( + data = x[x$key == "Model-predicted data", ], + mapping = ggplot2::aes( + x = .data$x, + y = .data$count, + group = .data$grp, + color = .data$key + ), + alpha = line_alpha, + position = ggplot2::position_jitter(width = 0.1, height = 0.02), + size = 4 * size_line, + stroke = 0, + shape = 16 + ) + + # for legend + ggplot2::geom_point( + data = x[x$key == "Observed data", ], + mapping = ggplot2::aes( + x = .data$x, + y = .data$count, + group = .data$grp, + color = .data$key + ), + size = 4 * size_line + ) + + ggplot2::geom_point( + data = x[x$key == "Observed data", ], + mapping = ggplot2::aes( + x = .data$x, + y = .data$count + ), + size = 6 * size_line, + shape = 21, + colour = "white", + fill = colors[1] + ) + } + + if (is.null(p2)) { + p <- p1 + } else { + p <- p2 + } + + if (type == "discrete_interval") { + subtitle <- "Model-predicted intervals should include observed data points" + } else { + subtitle <- "Model-predicted points should be close to observed data points" + } + + p <- p + + ggplot2::scale_y_continuous() + + ggplot2::scale_color_manual(values = c( + "Observed data" = colors[1], + "Model-predicted data" = colors[2] + )) + + ggplot2::labs( + x = info$xlab, + y = info$ylab, + color = "", + size = "", + alpha = "", + title = "Posterior Predictive Check", + subtitle = subtitle + ) + + ggplot2::guides( + color = ggplot2::guide_legend(reverse = TRUE), + size = ggplot2::guide_legend(reverse = TRUE) + ) + + return(p) } diff --git a/R/plot.compare_performance.R b/R/plot.compare_performance.R index 61094dab5..85848f011 100644 --- a/R/plot.compare_performance.R +++ b/R/plot.compare_performance.R @@ -1,3 +1,4 @@ +#' @rdname 
data_plot #' @export data_plot.compare_performance <- function(x, data = NULL, ...) { x$Model <- sprintf("%s (%s)", x$Name, x$Model) diff --git a/man/data_plot.Rd b/man/data_plot.Rd index 725361f3f..4daef23d7 100644 --- a/man/data_plot.Rd +++ b/man/data_plot.Rd @@ -1,18 +1,21 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/data_plot.R +% Please edit documentation in R/data_plot.R, R/plot.compare_performance.R \name{data_plot} \alias{data_plot} +\alias{data_plot.compare_performance} \title{Prepare objects for plotting or plot objects} \usage{ -data_plot(x, data = NULL, ...) +data_plot(x, ...) + +\method{data_plot}{compare_performance}(x, data = NULL, ...) } \arguments{ \item{x}{An object.} +\item{...}{Arguments passed to or from other methods.} + \item{data}{The original data used to create this object. Can be a statistical model.} - -\item{...}{Arguments passed to or from other methods.} } \description{ \code{data_plot()} extracts and transforms an object for plotting, diff --git a/man/print.see_performance_pp_check.Rd b/man/print.see_performance_pp_check.Rd index 2d4bc0a5d..9ab399a9d 100644 --- a/man/print.see_performance_pp_check.Rd +++ b/man/print.see_performance_pp_check.Rd @@ -1,5 +1,5 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/plot.performance_pp_check.R +% Please edit documentation in R/plot.check_predictions.R \name{print.see_performance_pp_check} \alias{print.see_performance_pp_check} \alias{plot.see_performance_pp_check} @@ -12,6 +12,7 @@ size_bar = 0.7, style = theme_lucid, colors = unname(social_colors(c("green", "blue"))), + type = c("density", "discrete_dots", "discrete_interval", "discrete_both"), ... ) @@ -22,6 +23,7 @@ size_bar = 0.7, style = theme_lucid, colors = unname(social_colors(c("green", "blue"))), + type = c("density", "discrete_dots", "discrete_interval", "discrete_both"), ... 
) } @@ -39,6 +41,10 @@ \item{colors}{Character vector of length two, indicating the colors (in hex-format) for points and line.} +\item{type}{Plot type for the posterior predictive checks plot. Can be \code{"density"} +(default), \code{"discrete_dots"}, \code{"discrete_interval"} or \code{"discrete_both"} (the +\verb{discrete_*} options are only for models with binary, integer or ordinal outcomes).} + \item{...}{Arguments passed to or from other methods.} } \value{ @@ -50,6 +56,17 @@ The \code{plot()} method for the \code{performance::check_predictions()} functio \examples{ \dontshow{if (require("performance")) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} model <- lm(Sepal.Length ~ Species * Petal.Width + Petal.Length, data = iris) -check_posterior_predictions(model) +check_predictions(model) + +# dot-plot style for count-models +d <- iris +d$poisson_var <- rpois(150, 1) +model <- glm( + poisson_var ~ Species + Petal.Length + Petal.Width, + data = d, + family = poisson() +) +out <- check_predictions(model) +plot(out, type = "discrete_dots") \dontshow{\}) # examplesIf} }