Skip to content

Commit

Permalink
Improve plot() for check_predictions() (#291)
Browse files Browse the repository at this point in the history
* Improve plot() for check_predictions()
Fixes #290

* Improve plot() for check_predictions()
Fixes #290

* allow type arg

* minor

* fix y axis label

* add example

* example

* fix some issues

* add discrete_options

* subtitle

* set subtitle

* no lollipop

* return value

* add namespace

* fix issues
  • Loading branch information
strengejacke authored May 25, 2023
1 parent b5606c7 commit 9fcdf21
Show file tree
Hide file tree
Showing 6 changed files with 236 additions and 30 deletions.
2 changes: 1 addition & 1 deletion R/data_plot.R
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@
#' x <- equivalence_test(model, verbose = FALSE)
#' plot(x)
#' @export
data_plot <- function(x, data = NULL, ...) {
data_plot <- function(x, ...) {
UseMethod("data_plot")
}

Expand Down
3 changes: 3 additions & 0 deletions R/plot.check_model.R
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
plot.see_check_model <- function(x,
style = theme_lucid,
colors = NULL,
type = c("density", "discrete_dots", "discrete_interval", "discrete_both"),
...) {
p <- list()

Expand All @@ -21,6 +22,7 @@ plot.see_check_model <- function(x,
model_info <- attr(x, "model_info")
overdisp_type <- attr(x, "overdisp_type")

type <- match.arg(type)

# set default values for arguments ------

Expand Down Expand Up @@ -59,6 +61,7 @@ plot.see_check_model <- function(x,
p$PP_CHECK <- plot.see_performance_pp_check(
x$PP_CHECK,
style = style,
type = type,
check_model = TRUE,
adjust_legend = TRUE,
colors = colors[1:2]
Expand Down
228 changes: 205 additions & 23 deletions R/plot.performance_pp_check.R → R/plot.check_predictions.R
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#' @export
data_plot.performance_pp_check <- function(x, ...) {
data_plot.performance_pp_check <- function(x, type = "density", ...) {
columns <- colnames(x)
dataplot <- stats::reshape(
x,
Expand All @@ -21,10 +21,11 @@ data_plot.performance_pp_check <- function(x, ...) {

attr(dataplot, "info") <- list(
"xlab" = attr(x, "response_name"),
"ylab" = "Density",
"ylab" = ifelse(identical(type, "density"), "Density", "Counts"),
"title" = "Posterior Predictive Check",
"check_range" = attr(x, "check_range"),
"bandwidth" = attr(x, "bandwidth")
"bandwidth" = attr(x, "bandwidth"),
"model_info" = attr(x, "model_info")
)

class(dataplot) <- unique(c("data_plot", "see_performance_pp_check", class(dataplot)))
Expand All @@ -40,6 +41,9 @@ data_plot.performance_pp_check <- function(x, ...) {
#'
#' @param line_alpha Numeric value specifying alpha of lines indicating `yrep`.
#' @param style A ggplot2-theme.
#' @param type Plot type for the posterior predictive checks plot. Can be `"density"`
#' (default), `"discrete_dots"`, `"discrete_interval"` or `"discrete_both"` (the
#' `discrete_*` options are only for models with binary, integer or ordinal outcomes).
#' @inheritParams data_plot
#' @inheritParams plot.see_check_normality
#' @inheritParams plot.see_parameters_distribution
Expand All @@ -48,23 +52,36 @@ data_plot.performance_pp_check <- function(x, ...) {
#'
#' @examplesIf require("performance")
#' model <- lm(Sepal.Length ~ Species * Petal.Width + Petal.Length, data = iris)
#' check_posterior_predictions(model)
#' check_predictions(model)
#'
#' # dot-plot style for count-models
#' d <- iris
#' d$poisson_var <- rpois(150, 1)
#' model <- glm(
#' poisson_var ~ Species + Petal.Length + Petal.Width,
#' data = d,
#' family = poisson()
#' )
#' out <- check_predictions(model)
#' plot(out, type = "discrete_dots")
#' @export
print.see_performance_pp_check <- function(x,
size_line = 0.5,
line_alpha = 0.15,
size_bar = 0.7,
style = theme_lucid,
colors = unname(social_colors(c("green", "blue"))),
type = c("density", "discrete_dots", "discrete_interval", "discrete_both"),
...) {
orig_x <- x
check_range <- isTRUE(attributes(x)$check_range)
type <- match.arg(type)

if (!inherits(x, "data_plot")) {
x <- data_plot(x)
x <- data_plot(x, type)
}

p1 <- .plot_pp_check(x, size_line, line_alpha, theme_style = style, colors = colors, ...)
p1 <- .plot_pp_check(x, size_line, line_alpha, theme_style = style, colors = colors, type = type, ...)

if (isTRUE(check_range)) {
p2 <- .plot_pp_check_range(orig_x, size_bar, colors = colors)
Expand All @@ -85,15 +102,17 @@ plot.see_performance_pp_check <- function(x,
size_bar = 0.7,
style = theme_lucid,
colors = unname(social_colors(c("green", "blue"))),
type = c("density", "discrete_dots", "discrete_interval", "discrete_both"),
...) {
orig_x <- x
check_range <- isTRUE(attributes(x)$check_range)
type <- match.arg(type)

if (!inherits(x, "data_plot")) {
x <- data_plot(x)
x <- data_plot(x, type)
}

p1 <- .plot_pp_check(x, size_line, line_alpha, theme_style = style, colors = colors, ...)
p1 <- .plot_pp_check(x, size_line, line_alpha, theme_style = style, colors = colors, type = type, ...)

if (isTRUE(check_range)) {
p2 <- .plot_pp_check_range(orig_x, size_bar, colors = colors)
Expand All @@ -105,7 +124,7 @@ plot.see_performance_pp_check <- function(x,



.plot_pp_check <- function(x, size_line, line_alpha, theme_style, colors, ...) {
.plot_pp_check <- function(x, size_line, line_alpha, theme_style, colors, type = "density", ...) {
info <- attr(x, "info")

# default bandwidth, for smoothing
Expand All @@ -114,7 +133,46 @@ plot.see_performance_pp_check <- function(x,
bandwidth <- "nrd"
}

out <- ggplot2::ggplot(x) +
minfo <- info$model_info
suggest_dots <- (minfo$is_bernoulli || minfo$is_count || minfo$is_ordinal || minfo$is_categorical)

if (!is.null(type) && type %in% c("discrete_dots", "discrete_interval", "discrete_both") && suggest_dots) {
out <- .plot_check_predictions_dots(x, colors, info, size_line, line_alpha, type, ...)
} else {
if (suggest_dots) {
insight::format_alert(
"The model has an integer or a categorical response variable.",
"It is recommended to switch to a dot-plot style, e.g. `plot(check_model(model), type = \"discrete_dots\")`."
)
}
# density plot - for models that have no binary or count/ordinal outcome
out <- .plot_check_predictions_density(x, colors, info, size_line, line_alpha, bandwidth, ...)
}


dots <- list(...)
if (isTRUE(dots[["check_model"]])) {
out <- out + theme_style(
base_size = 10,
plot.title.space = 3,
axis.title.space = 5
)
}

if (isTRUE(dots[["adjust_legend"]]) || isTRUE(info$check_range)) {
out <- out + ggplot2::theme(
legend.position = "bottom",
legend.margin = ggplot2::margin(0, 0, 0, 0),
legend.box.margin = ggplot2::margin(-5, -5, -5, -5)
)
}

out
}


.plot_check_predictions_density <- function(x, colors, info, size_line, line_alpha, bandwidth, ...) {
ggplot2::ggplot(x) +
ggplot2::stat_density(
mapping = ggplot2::aes(
x = .data$values,
Expand Down Expand Up @@ -159,26 +217,150 @@ plot.see_performance_pp_check <- function(x,
color = ggplot2::guide_legend(reverse = TRUE),
size = ggplot2::guide_legend(reverse = TRUE)
)
}


dots <- list(...)
if (isTRUE(dots[["check_model"]])) {
out <- out + theme_style(
base_size = 10,
plot.title.space = 3,
axis.title.space = 5
)
# Dot- and interval-style posterior predictive check for discrete outcomes.
#
# Tabulates how often each value of `x$values` occurs within every `x$grp`
# and plots these counts for the model-predicted groups against the counts
# of the observed data.
#
# @param x Data frame with (at least) the columns `values` and `grp`.
# @param colors Character vector; `colors[1]` is used for the observed data,
#   `colors[2]` for the model-predicted data.
# @param info List providing the axis labels in `xlab` and `ylab`.
# @param size_line Base size; point sizes and line widths are multiples of it.
# @param line_alpha Alpha of the jittered dots of the model-predicted counts.
# @param type One of `"discrete_dots"`, `"discrete_interval"` or
#   `"discrete_both"`. `NULL` falls back to `"discrete_dots"`; any other
#   value is rejected with an informative error.
# @param ... Currently not used.
#
# @return A ggplot object.
.plot_check_predictions_dots <- function(x,
                                         colors,
                                         info,
                                         size_line,
                                         line_alpha,
                                         type = "discrete_dots",
                                         ...) {
  # validate once, up front: a NULL or unsupported `type` used to slip past
  # the layer-building code and fail later with an obscure error
  type <- match.arg(type, c("discrete_dots", "discrete_interval", "discrete_both"))
  show_interval <- type %in% c("discrete_interval", "discrete_both")
  show_dots <- type %in% c("discrete_dots", "discrete_both")

  # make sure we have a factor, so "table()" generates frequencies for all levels
  # for each group - we need tables of same size to bind data frames
  x$values <- as.factor(x$values)
  x <- stats::aggregate(x["values"], list(grp = x$grp), table)
  x <- cbind(data.frame(key = "Model-predicted data", stringsAsFactors = FALSE), x)
  x <- cbind(x[1:2], as.data.frame(x[[3]]))
  # NOTE(review): assumes the observed response is the last group after
  # aggregation - confirm against the object created by check_predictions()
  x$key[nrow(x)] <- "Observed data"
  x <- datawizard::data_to_long(x, select = -1:-2, names_to = "x", values_to = "count")
  # many distinct outcome values: use a continuous instead of a discrete x-axis
  if (insight::n_unique(x$x) > 8) {
    x$x <- datawizard::to_numeric(x$x)
  }

  observed <- x[x$key == "Observed data", ]
  predicted <- x[x$key == "Model-predicted data", ]

  p <- ggplot2::ggplot()

  if (show_interval) {
    # median and credible interval of the counts, per outcome value
    centrality_dispersion <- function(i) {
      c(
        count = stats::median(i, na.rm = TRUE),
        unlist(bayestestR::ci(i)[c("CI_low", "CI_high")])
      )
    }
    # summarise the model-predicted counts only - including the observed
    # counts would pull the interval towards the data it is compared against
    x_errorbars <- stats::aggregate(
      predicted["count"],
      list(predicted$x),
      centrality_dispersion
    )
    x_errorbars <- cbind(x_errorbars[1], as.data.frame(x_errorbars[[2]]))
    colnames(x_errorbars) <- c("x", "count", "CI_low", "CI_high")
    x_errorbars <- cbind(
      data.frame(key = "Model-predicted data", stringsAsFactors = FALSE),
      x_errorbars
    )

    p <- p +
      ggplot2::geom_pointrange(
        data = x_errorbars,
        mapping = ggplot2::aes(
          x = .data$x,
          y = .data$count,
          ymin = .data$CI_low,
          ymax = .data$CI_high,
          color = .data$key
        ),
        # nudge intervals, so they do not hide the observed data points
        position = ggplot2::position_nudge(x = 0.2),
        size = 1.5 * size_line,
        linewidth = 1.5 * size_line,
        stroke = 0,
        shape = 16
      ) +
      ggplot2::geom_point(
        data = observed,
        mapping = ggplot2::aes(
          x = .data$x,
          y = .data$count,
          color = .data$key
        ),
        size = 6 * size_line,
        stroke = 0,
        shape = 16
      )
  }

  if (show_dots) {
    p <- p +
      ggplot2::geom_point(
        data = predicted,
        mapping = ggplot2::aes(
          x = .data$x,
          y = .data$count,
          group = .data$grp,
          color = .data$key
        ),
        alpha = line_alpha,
        position = ggplot2::position_jitter(width = 0.1, height = 0.02),
        size = 4 * size_line,
        stroke = 0,
        shape = 16
      ) +
      # for legend
      ggplot2::geom_point(
        data = observed,
        mapping = ggplot2::aes(
          x = .data$x,
          y = .data$count,
          group = .data$grp,
          color = .data$key
        ),
        size = 4 * size_line
      ) +
      # observed data with a white border, drawn on top of the legend layer
      ggplot2::geom_point(
        data = observed,
        mapping = ggplot2::aes(
          x = .data$x,
          y = .data$count
        ),
        size = 6 * size_line,
        shape = 21,
        colour = "white",
        fill = colors[1]
      )
  }

  if (type == "discrete_interval") {
    subtitle <- "Model-predicted intervals should include observed data points"
  } else {
    subtitle <- "Model-predicted points should be close to observed data points"
  }

  p +
    ggplot2::scale_y_continuous() +
    ggplot2::scale_color_manual(values = c(
      "Observed data" = colors[1],
      "Model-predicted data" = colors[2]
    )) +
    ggplot2::labs(
      x = info$xlab,
      y = info$ylab,
      color = "",
      size = "",
      alpha = "",
      title = "Posterior Predictive Check",
      subtitle = subtitle
    ) +
    ggplot2::guides(
      color = ggplot2::guide_legend(reverse = TRUE),
      size = ggplot2::guide_legend(reverse = TRUE)
    )
}


Expand Down
1 change: 1 addition & 0 deletions R/plot.compare_performance.R
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#' @rdname data_plot
#' @export
data_plot.compare_performance <- function(x, data = NULL, ...) {
x$Model <- sprintf("%s (%s)", x$Name, x$Model)
Expand Down
11 changes: 7 additions & 4 deletions man/data_plot.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 9fcdf21

Please sign in to comment.