Skip to content

Commit

Permalink
Merge pull request #157 from kylelang/plot_trace_in_loop
Browse files Browse the repository at this point in the history
Adjust the variable parsing for `vrb` argument
  • Loading branch information
hanneoberman authored Jul 26, 2024
2 parents 3210413 + 7fe9028 commit 41222e3
Show file tree
Hide file tree
Showing 15 changed files with 292 additions and 110 deletions.
5 changes: 3 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ Authors@R: c(
person("Thom", "Volker", role = "ctb", comment = c(ORCID = "0000-0002-2408-7820")),
person("Gerko", "Vink", role = "ctb", comment = c(ORCID = "0000-0001-9767-1924")),
person("Pepijn", "Vink", role = "ctb", comment = c(ORCID = "0000-0001-6960-9904")),
person("Jamie", "Wallis", role = "ctb", comment = c(ORCID = "0000-0003-2765-3813"))
person("Jamie", "Wallis", role = "ctb", comment = c(ORCID = "0000-0003-2765-3813")),
person("Kyle", "Lang", role = "ctb", comment = c(ORCID = "0000-0001-5340-7849"))
)
Description: Enhance a 'mice' imputation workflow with visualizations for
incomplete and/or imputed data. The plotting functions produce
Expand Down Expand Up @@ -46,4 +47,4 @@ Config/testthat/edition: 3
Copyright: 'ggmice' authors
Encoding: UTF-8
Roxygen: list(markdown = TRUE)
RoxygenNote: 7.2.3
RoxygenNote: 7.3.2
3 changes: 2 additions & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

## Bug fixes

* Correct labeling of 'exclusion-restriction' variables in `plot_pred()` (#128)
* Correct labeling of 'exclusion-restriction' variables in `plot_pred()` (#128)
* Parsing of the `vrb` argument in all `plot_*()` functions: variable name(s) stored in an object in the global environment are now recognized when the object is unquoted with `!!` (#157)

## Minor changes

Expand Down
14 changes: 5 additions & 9 deletions R/ggmice.R
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,7 @@ ggmice <- function(data = NULL,
}
if (length(vrbs) > length(unique(vrbs))) {
cli::cli_abort(
c("The data must have unique column names.",
"x" = "Duplication found in {vrbs[duplicated(vrbs)]}")
c("The data must have unique column names.", "x" = "Duplication found in {vrbs[duplicated(vrbs)]}")
)
}
# extract mapping variables
Expand Down Expand Up @@ -107,8 +106,8 @@ ggmice <- function(data = NULL,
.imp = 0,
.id = rownames(data$data),
data$data
)[!miss_xy,],
data.frame(.where = "imputed", mice::complete(data, action = "long"))[where_xy,]
)[!miss_xy, ],
data.frame(.where = "imputed", mice::complete(data, action = "long"))[where_xy, ]
),
.where = factor(
.where,
Expand Down Expand Up @@ -154,7 +153,6 @@ ggmice <- function(data = NULL,
return(gg)
}


#' Utils function to extract mapping variables
#'
#' @param data Incomplete dataset or mids object.
Expand Down Expand Up @@ -197,11 +195,9 @@ match_mapping <- function(data, vrbs, mapping_in) {
inherits(try(dplyr::mutate(mapping_data,

Check warning on line 195 in R/ggmice.R

View workflow job for this annotation

GitHub Actions / lint

file=R/ggmice.R,line=195,col=8,[indentation_linter] Indentation should be 10 spaces but is 8 spaces.
!!rlang::parse_quo(mapping_text, env = rlang::current_env())),
silent = TRUE)
,
"try-error")) {
, "try-error")) {
cli::cli_abort(
c("Must provide a valid mapping variable.",
"x" = "Mapping variable '{mapping_text}' not found in the data or imputations.")
c("Must provide a valid mapping variable.", "x" = "Mapping variable '{mapping_text}' not found in the data or imputations.")
)
} else {
cli::cli_warn(
Expand Down
55 changes: 32 additions & 23 deletions R/plot_corr.R
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,22 @@
#' @return An object of class [ggplot2::ggplot].
#'
#' @examples
#' plot_corr(mice::nhanes, label = TRUE)
#' # plot correlations for all columns
#' plot_corr(mice::nhanes)
#'
#' # plot correlations for specific columns by supplying a character vector
#' plot_corr(mice::nhanes, c("chl", "hyp"))
#'
#' # plot correlations for specific columns by supplying unquoted variable names
#' plot_corr(mice::nhanes, c(chl, hyp))
#'
#' # plot correlations for specific columns by passing an object with variable names
#' # from the environment, unquoted with `!!`
#' my_variables <- c("chl", "hyp")
#' plot_corr(mice::nhanes, !!my_variables)
#' # object with variable names must be unquoted with `!!`
#' try(plot_corr(mice::nhanes, my_variables))
#'
#' @export
plot_corr <-
function(data,
Expand All @@ -25,38 +40,33 @@ plot_corr <-
data <- as.data.frame(data)
}
verify_data(data = data, df = TRUE)
vrb <- substitute(vrb)
if (vrb != "all" && length(vrb) < 2) {
vrb <- rlang::enexpr(vrb)
vrb_matched <- match_vrb(vrb, names(data))
if (length(vrb_matched) < 2) {
cli::cli_abort("The number of variables should be two or more to compute correlations.")
}
if (vrb[1] == "all") {
vrb <- names(data)
} else {
data <- dplyr::select(data, {{vrb}})
vrb <- names(data)
}
# check if any column is constant
constants <- apply(data, MARGIN = 2, function(x) {
constants <- apply(data[, vrb_matched], MARGIN = 2, function(x) {
all(is.na(x)) || max(x, na.rm = TRUE) == min(x, na.rm = TRUE)
})
if (any(constants)) {
vrb <- names(data[, !constants])
vrb_matched <- vrb_matched[!constants]
cli::cli_inform(
c(
"No correlations computed for variable(s):",
" " = paste(names(constants[which(constants)]), collapse = ", "),
"x" = "Correlation undefined for constants."
"i" = "Correlations are undefined for constants."
)
)
}

p <- length(vrb)
# compute correlations
p <- length(vrb_matched)
corrs <- data.frame(
vrb = rep(vrb, each = p),
prd = vrb,
vrb = rep(vrb_matched, each = p),
prd = vrb_matched,
corr = matrix(
round(stats::cov2cor(
stats::cov(data.matrix(data[, vrb]), use = "pairwise.complete.obs")
stats::cov(data.matrix(data[, vrb_matched]), use = "pairwise.complete.obs")
), 2),
nrow = p * p,
byrow = TRUE
Expand All @@ -65,6 +75,7 @@ plot_corr <-
if (!diagonal) {
corrs[corrs$vrb == corrs$prd, "corr"] <- NA
}
# create plot
gg <-
ggplot2::ggplot(corrs,
ggplot2::aes(
Expand All @@ -74,8 +85,8 @@ plot_corr <-
fill = .data$corr
)) +
ggplot2::geom_tile(color = "black", alpha = 0.6) +
ggplot2::scale_x_discrete(limits = vrb, position = "top") +
ggplot2::scale_y_discrete(limits = rev(vrb)) +
ggplot2::scale_x_discrete(limits = vrb_matched, position = "top") +
ggplot2::scale_y_discrete(limits = rev(vrb_matched)) +
ggplot2::scale_fill_gradient2(
low = ggplot2::alpha("deepskyblue", 0.6),
mid = "lightyellow",
Expand All @@ -91,13 +102,11 @@ plot_corr <-
y = "Variable to impute",
fill = "Correlation*
",
caption = "*pairwise complete observations"
caption = "*pairwise complete observations"
)
} else {
gg <- gg +
ggplot2::labs(x = "Imputation model predictor",
y = "Variable to impute",
fill = "Correlation")
ggplot2::labs(x = "Imputation model predictor", y = "Variable to impute", fill = "Correlation")
}
if (label) {
gg <-
Expand Down
30 changes: 21 additions & 9 deletions R/plot_flux.R
Original file line number Diff line number Diff line change
Expand Up @@ -8,25 +8,37 @@
#' @return An object of class [ggplot2::ggplot].
#'
#' @examples
#' # plot flux for all columns
#' plot_flux(mice::nhanes)
#'
#' # plot flux for specific columns by supplying a character vector
#' plot_flux(mice::nhanes, c("chl", "hyp"))
#'
#' # plot flux for specific columns by supplying unquoted variable names
#' plot_flux(mice::nhanes, c(chl, hyp))
#'
#' # plot flux for specific columns by passing an object with variable names
#' # from the environment, unquoted with `!!`
#' my_variables <- c("chl", "hyp")
#' plot_flux(mice::nhanes, !!my_variables)
#' # object with variable names must be unquoted with `!!`
#' try(plot_flux(mice::nhanes, my_variables))
#'
#' @export
plot_flux <-
function(data,
vrb = "all",
label = TRUE,
caption = TRUE) {
verify_data(data, df = TRUE)
vrb <- substitute(vrb)
if (vrb != "all" && length(vrb) < 2) {
vrb <- rlang::enexpr(vrb)
vrb_matched <- match_vrb(vrb, names(data))
if (length(vrb_matched) < 2) {
cli::cli_abort("The number of variables should be two or more to compute flux.")
}
if (vrb[1] == "all") {
vrb <- names(data)
} else {
vrb <- names(dplyr::select(data, {{vrb}}))
}
# plot in and outflux
flx <- mice::flux(data[, vrb])[, c("influx", "outflux")]
# compute flux
flx <- mice::flux(data[, vrb_matched])[, c("influx", "outflux")]
# create plot
gg <-
data.frame(
vrb = rownames(flx),
Expand Down
32 changes: 21 additions & 11 deletions R/plot_pattern.R
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,22 @@
#' @return An object of class [ggplot2::ggplot].
#'
#' @examples
#' # plot missing data pattern for all columns
#' plot_pattern(mice::nhanes)
#'
#' # plot missing data pattern for specific columns by supplying a character vector
#' plot_pattern(mice::nhanes, c("chl", "hyp"))
#'
#' # plot missing data pattern for specific columns by supplying unquoted variable names
#' plot_pattern(mice::nhanes, c(chl, hyp))
#'
#' # plot missing data pattern for specific columns by passing an object with variable names
#' # from the environment, unquoted with `!!`
#' my_variables <- c("chl", "hyp")
#' plot_pattern(mice::nhanes, !!my_variables)
#' # object with variable names must be unquoted with `!!`
#' try(plot_pattern(mice::nhanes, my_variables))
#'
#' @export
plot_pattern <-
function(data,
Expand All @@ -21,21 +36,16 @@ plot_pattern <-
cluster = NULL,
npat = NULL,
caption = TRUE) {
# input processing
if (is.matrix(data) && ncol(data) > 1) {
data <- as.data.frame(data)
}
verify_data(data, df = TRUE)
vrb <- substitute(vrb)
if (vrb != "all" && length(vrb) < 2) {
vrb <- rlang::enexpr(vrb)
vrb_matched <- match_vrb(vrb, names(data))
if (length(vrb_matched) < 2) {
cli::cli_abort("The number of variables should be two or more to compute missing data patterns.")
}
if (vrb[1] == "all") {
vrb <- names(data)
} else {
vrb <- names(dplyr::select(as.data.frame(data), {{vrb}}))
}
if (".x" %in% vrb || ".y" %in% vrb) {
if (".x" %in% vrb_matched || ".y" %in% vrb_matched) {
cli::cli_abort(
c(
"The variable names '.x' and '.y' are used internally to produce the missing data pattern plot.",
Expand All @@ -44,7 +54,7 @@ plot_pattern <-
)
}
if (!is.null(cluster)) {
if (cluster %nin% names(data[, vrb])) {
if (cluster %nin% names(data[, vrb_matched])) {
cli::cli_abort(
c("Cluster variable not recognized.",
"i" = "Please provide the variable name as a character string.")
Expand All @@ -61,7 +71,7 @@ plot_pattern <-
}

# get missing data pattern
pat <- mice::md.pattern(data[, vrb], plot = FALSE)
pat <- mice::md.pattern(data[, vrb_matched], plot = FALSE)
rows_pat_full <-
(nrow(pat) - 1) # full number of missing data patterns

Expand Down
36 changes: 24 additions & 12 deletions R/plot_pred.R
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,25 @@
#' @return An object of class `ggplot2::ggplot`.
#'
#' @examples
#' # generate a predictor matrix
#' pred <- mice::quickpred(mice::nhanes)
#'
#' # plot predictor matrix for all columns
#' plot_pred(pred)
#'
#' # plot predictor matrix for specific columns by supplying a character vector
#' plot_pred(pred, c("chl", "hyp"))
#'
#' # plot predictor matrix for specific columns by supplying unquoted variable names
#' plot_pred(pred, c(chl, hyp))
#'
#' # plot predictor matrix for specific columns by passing an object with variable names
#' # from the environment, unquoted with `!!`
#' my_variables <- c("chl", "hyp")
#' plot_pred(pred, !!my_variables)
#' # object with variable names must be unquoted with `!!`
#' try(plot_pred(pred, my_variables))
#'
#' @export
plot_pred <-
function(data,
Expand All @@ -21,7 +38,9 @@ plot_pred <-
square = TRUE,
rotate = FALSE) {
verify_data(data, pred = TRUE)
p <- nrow(data)
vrb <- rlang::enexpr(vrb)
vrb_matched <- match_vrb(vrb, row.names(data))
p <- length(vrb_matched)
if (!is.null(method) && is.character(method)) {
if (length(method) == 1) {
method <- rep(method, p)
Expand All @@ -37,17 +56,10 @@ plot_pred <-
if (!is.character(method) || length(method) != p) {
cli::cli_abort("Method should be NULL or a character string or vector (of length 1 or `ncol(data)`).")
}
vrb <- substitute(vrb)
if (vrb[1] == "all") {
vrb <- names(data)
} else {
vrb <- names(dplyr::select(as.data.frame(data), {{vrb}}))
}
vrbs <- row.names(data)
long <- data.frame(
vrb = 1:p,
prd = rep(vrbs, each = p),
ind = matrix(data, nrow = p * p, byrow = TRUE)
prd = rep(vrb_matched, each = p),
ind = matrix(data[vrb_matched, vrb_matched], nrow = p * p, byrow = TRUE)
) %>% dplyr::mutate(clr = factor(
.data$ind,
levels = c(-3, -2, 0, 1, 2),
Expand All @@ -70,10 +82,10 @@ plot_pred <-
fill = .data$clr
)) +
ggplot2::geom_tile(color = "black", alpha = 0.6) +
ggplot2::scale_x_discrete(limits = vrbs, position = "top") +
ggplot2::scale_x_discrete(limits = vrb_matched, position = "top") +
ggplot2::scale_y_reverse(
breaks = 1:p,
labels = vrbs,
labels = vrb_matched,
sec.axis = ggplot2::dup_axis(labels = method, name = ylabel)
) +
ggplot2::scale_fill_manual(
Expand Down
Loading

0 comments on commit 41222e3

Please sign in to comment.