Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

plot.check_predictions for Stan models #336

Merged
merged 3 commits into from
Mar 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
Type: Package
Package: see
Title: Model Visualisation Toolbox for 'easystats' and 'ggplot2'
Version: 0.8.3.3
Version: 0.8.3.4
Authors@R:
c(person(given = "Daniel",
family = "Lüdecke",
Expand Down
3 changes: 3 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
downstream plot-functions (i.e., `plot()` for `check_model()` passes arguments
to change geom sizes to the underlying plot-functions).

* `plot()` for `check_predictions()` now supports Bayesian regression models from
*brms* and *rstanarm*.

# see 0.8.3

## Major changes
Expand Down
90 changes: 90 additions & 0 deletions R/plot.check_predictions.R
Original file line number Diff line number Diff line change
@@ -1,5 +1,19 @@
#' @export
data_plot.performance_pp_check <- function(x, type = "density", ...) {
# for data from "bayesplot::pp_check()", data is already in shape
if (isTRUE(attributes(x)$is_stan) && type != "density") {
class(x) <- c("data_plot", "see_performance_pp_check", "data.frame")
attr(x, "info") <- list(
xlab = attr(x, "response_name"),
ylab = ifelse(identical(type, "density"), "Density", "Counts"),
title = "Posterior Predictive Check",
check_range = attr(x, "check_range"),
bandwidth = attr(x, "bandwidth"),
model_info = attr(x, "model_info")
)
return(x)
}

columns <- colnames(x)
dataplot <- stats::reshape(
x,
Expand Down Expand Up @@ -53,7 +67,7 @@
#'
#' @return A ggplot2-object.
#'
#' @seealso See also the vignette about [`check_model()`](https://easystats.github.io/performance/articles/check_model.html).

Check warning on line 70 in R/plot.check_predictions.R

View workflow job for this annotation

GitHub Actions / lint / lint

file=R/plot.check_predictions.R,line=70,col=121,[line_length_linter] Lines should not be more than 120 characters. This line is 125 characters.

Check warning on line 70 in R/plot.check_predictions.R

View workflow job for this annotation

GitHub Actions / lint-changed-files / lint-changed-files

file=R/plot.check_predictions.R,line=70,col=121,[line_length_linter] Lines should not be more than 120 characters. This line is 125 characters.
#'
#' @examples
#' library(performance)
Expand Down Expand Up @@ -88,8 +102,9 @@
orig_x <- x
check_range <- isTRUE(attributes(x)$check_range)
plot_type <- attributes(x)$type
is_stan <- attributes(x)$is_stan

if (missing(type) && !is.null(plot_type) && plot_type %in% c("density", "discrete_dots", "discrete_interval", "discrete_both")) {

Check warning on line 107 in R/plot.check_predictions.R

View workflow job for this annotation

GitHub Actions / lint / lint

file=R/plot.check_predictions.R,line=107,col=121,[line_length_linter] Lines should not be more than 120 characters. This line is 131 characters.

Check warning on line 107 in R/plot.check_predictions.R

View workflow job for this annotation

GitHub Actions / lint-changed-files / lint-changed-files

file=R/plot.check_predictions.R,line=107,col=121,[line_length_linter] Lines should not be more than 120 characters. This line is 131 characters.
type <- plot_type
} else {
type <- match.arg(type)
Expand All @@ -111,6 +126,7 @@
size_axis_title = size_axis_title,
type = type,
x_limits = x_limits,
is_stan = is_stan,
...
)

Expand Down Expand Up @@ -143,6 +159,7 @@
orig_x <- x
check_range <- isTRUE(attributes(x)$check_range)
plot_type <- attributes(x)$type
is_stan <- attributes(x)$is_stan

if (missing(type) && !is.null(plot_type) && plot_type %in% c("density", "discrete_dots", "discrete_interval", "discrete_both")) { # nolint
type <- plot_type
Expand All @@ -166,6 +183,7 @@
colors = colors,
type = type,
x_limits = x_limits,
is_stan = is_stan,
...
)

Expand All @@ -187,12 +205,19 @@
base_size = 10,
size_axis_title = 10,
size_title = 12,
colors,

Check warning on line 208 in R/plot.check_predictions.R

View workflow job for this annotation

GitHub Actions / lint / lint

file=R/plot.check_predictions.R,line=208,col=28,[function_argument_linter] Arguments without defaults should come before arguments with defaults.

Check warning on line 208 in R/plot.check_predictions.R

View workflow job for this annotation

GitHub Actions / lint-changed-files / lint-changed-files

file=R/plot.check_predictions.R,line=208,col=28,[function_argument_linter] Arguments without defaults should come before arguments with defaults.
type = "density",
x_limits = NULL,
is_stan = NULL,
...) {
info <- attr(x, "info")

# discrete plot type from "bayesplot::pp_check()" returns a different data
# structure, so we need to handle it differently
if (isTRUE(is_stan) && type != "density") {
return(.plot_check_predictions_stan_dots(x, colors, info, size_line, size_point, line_alpha, ...))
}

# default bandwidth, for smooting
bandwidth <- info$bandwidth
if (is.null(bandwidth)) {
Expand Down Expand Up @@ -450,6 +475,71 @@
}


.plot_check_predictions_stan_dots <- function(x,
colors,
info,
size_line,
size_point,
line_alpha,
...) {
# make sure we have a factor, so "table()" generates frequencies for all levels
# for each group - we need tables of same size to bind data frames
x$Group[x$Group == "y"] <- "Observed data"
x$Group[x$Group == "Mean"] <- "Model-predicted data"

# sanity check, remove NA rows
x <- x[!is.na(x$Count), ]

p <- ggplot2::ggplot() +
ggplot2::geom_pointrange(
data = x[x$Group == "Model-predicted data", ],
mapping = ggplot2::aes(
x = .data$x,
y = .data$Count,
ymin = .data$CI_low,
ymax = .data$CI_high,
color = .data$Group
),
position = ggplot2::position_nudge(x = 0.2),
size = 0.4 * size_point,
linewidth = size_line,
stroke = 0,
shape = 16
) +
ggplot2::geom_point(
data = x[x$Group == "Observed data", ],
mapping = ggplot2::aes(
x = .data$x,
y = .data$Count,
color = .data$Group
),
size = 1.5 * size_point,
stroke = 0,
shape = 16
) +
ggplot2::scale_y_continuous() +
ggplot2::scale_color_manual(values = c(
"Observed data" = colors[1],
"Model-predicted data" = colors[2]
)) +
ggplot2::labs(
x = info$xlab,
y = info$ylab,
color = "",
size = "",
alpha = "",
title = "Posterior Predictive Check",
subtitle = "Model-predicted intervals should include observed data points"
) +
ggplot2::guides(
color = ggplot2::guide_legend(reverse = TRUE),
size = ggplot2::guide_legend(reverse = TRUE)
)

return(p)

Check warning on line 539 in R/plot.check_predictions.R

View workflow job for this annotation

GitHub Actions / lint-changed-files / lint-changed-files

file=R/plot.check_predictions.R,line=539,col=3,[return_linter] Use implicit return behavior; explicit return() is not needed.
}


.plot_pp_check_range <- function(x,
size_bar = 0.7,
colors = unname(social_colors(c("green", "blue")))) {
Expand Down
Loading