Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add S3 methods for sorting & filtering hierarchical tables #2097

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ S3method(plot,tbl_regression)
S3method(plot,tbl_uvregression)
S3method(print,gtsummary)
S3method(print,tbl_split)
S3method(tbl_filter,tbl_hierarchical)
S3method(tbl_regression,brmsfit)
S3method(tbl_regression,crr)
S3method(tbl_regression,default)
Expand All @@ -77,6 +78,7 @@ S3method(tbl_regression,multinom)
S3method(tbl_regression,stanreg)
S3method(tbl_regression,survreg)
S3method(tbl_regression,workflow)
S3method(tbl_sort,tbl_hierarchical)
S3method(tbl_split,gtsummary)
S3method(tbl_survfit,data.frame)
S3method(tbl_survfit,list)
Expand Down Expand Up @@ -210,11 +212,13 @@ export(tbl_butcher)
export(tbl_continuous)
export(tbl_cross)
export(tbl_custom_summary)
export(tbl_filter)
export(tbl_hierarchical)
export(tbl_hierarchical_count)
export(tbl_likert)
export(tbl_merge)
export(tbl_regression)
export(tbl_sort)
export(tbl_split)
export(tbl_stack)
export(tbl_strata)
Expand Down
125 changes: 125 additions & 0 deletions R/filter_tbl_hierarchical.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
#' Filter Hierarchical Tables
#'
#' @description `r lifecycle::badge('experimental')`\cr
#'
#' This function is used to filter hierarchical table rows by frequency row sum.
#'
#' @param x (`tbl_hierarchical`, `tbl_hierarchical_count`)\cr
#' A hierarchical gtsummary table of class `'tbl_hierarchical'` or `'tbl_hierarchical_count'`.
#' @param t (scalar `numeric`)\cr
#' Threshold used to determine which rows will be retained.
#' @param gt (scalar `logical`)\cr
#' Whether to filter for row sums greater than `t` or less than `t`. Default is greater than (`gt = TRUE`).
#' @param eq (scalar `logical`)\cr
#' Whether to include the value of `t` in the filtered range, i.e. whether to use exclusive comparators (`>`, `<`) or
#' inclusive comparators (`>=`, `<=`) when filtering. Default is `FALSE`.
#' @param .stat (`string`)\cr
#' Statistic to use to calculate row sums. This statistic must be present in the table for all hierarchy levels.
#' Default is `"n"`.
#' @inheritParams rlang::args_dots_empty
#'
#' @return A `gtsummary` of the same class as `x`.
#'
#' @name filter_tbl_hierarchical
#' @seealso [tbl_sort()]
#'
#' @examplesIf (identical(Sys.getenv("NOT_CRAN"), "true") || identical(Sys.getenv("IN_PKGDOWN"), "true"))
#' ADAE_subset <- cards::ADAE |>
#' dplyr::filter(AETERM %in% unique(cards::ADAE$AETERM)[1:5])
#'
#' tbl <- tbl_hierarchical(
#' data = ADAE_subset,
#' variables = c(SEX, RACE, AETERM),
#' by = TRTA,
#' denominator = cards::ADSL |> mutate(TRTA = ARM),
#' id = USUBJID,
#' overall_row = TRUE
#' )
#'
#' # Example 1 - Row Sums > 10 ------------------
#' tbl_filter(tbl, t = 10)
#'
#' # Example 2 - Row Sums <= 5 ------------------
#' tbl_filter(tbl, t = 10, gt = FALSE, eq = TRUE)
NULL

#' @rdname filter_tbl_hierarchical
#' @export
tbl_filter <- function(x, ...) {
check_not_missing(x)
check_class(x, "gtsummary")

UseMethod("tbl_filter")
}

#' @export
#' @rdname filter_tbl_hierarchical
tbl_filter.tbl_hierarchical <- function(x, t, gt = TRUE, eq = FALSE, .stat = "n", ...) {
set_cli_abort_call()

# process and check inputs ----------------------------------------------------------------------
check_numeric(t)
check_scalar_logical(gt)
check_scalar_logical(eq)
check_string(.stat)

outer_cols <- sapply(
x$table_body |> select(cards::all_ard_groups("names")),
function(x) dplyr::last(unique(stats::na.omit(x)))
)

# get row sums ----------------------------------------------------------------------------------
x <- .append_hierarchy_row_sums(x, .stat)

# keep all summary rows (removed later if no sub-rows are kept)
if (!gt) x$table_body$sum_row[x$table_body$variable %in% outer_cols] <- t - 1

# create and apply filtering expression ---------------------------------------------------------
filt_expr <- paste(
"sum_row",
dplyr::case_when(
gt && eq ~ ">=",
!gt && eq ~ "<=",
!gt ~ "<",
TRUE ~ ">"
),
t
)
x$table_body <- x$table_body |>
dplyr::filter(!!parse_expr(filt_expr))

# remove any summary rows with no sub-rows still present ----------------------------------------
if (!gt) {
for (i in rev(seq_along(outer_cols))) {
gp_empty <- x$table_body |>
dplyr::group_by(across(c(names(outer_cols[1:i]), paste0(names(outer_cols[1:i]), "_level")))) |>
dplyr::summarize(is_empty := dplyr::n() == 1) |>
stats::na.omit()

if (!all(!gp_empty$is_empty)) {
x$table_body <- x$table_body |>
dplyr::left_join(
gp_empty,
by = gp_empty |> select(cards::all_ard_groups()) |> names()
) |>
dplyr::filter(!is_empty | is.na(is_empty)) |>
dplyr::select(-"is_empty")
} else {
break
}
}
if (nrow(x$table_body) > 0) {
cli::cli_inform(
"For readability, all summary rows preceding at least one row that meets the filtering criteria are kept
regardless of whether they meet the filtering criteria themselves.",
.frequency = "once",
.frequency_id = "sum_rows_lt"
)
}
}

x$table_body <- x$table_body |>
dplyr::select(-"sum_row")

x
}
235 changes: 235 additions & 0 deletions R/sort_tbl_hierarchical.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,235 @@
#' Sort Hierarchical Tables
#'
#' @description `r lifecycle::badge('experimental')`\cr
#'
#' This function is used to sort hierarchical tables. Options for sorting criteria are:
#'
#' 1. Frequency - within each section of the hierarchy table, frequency sums are calculated for each row and rows are
#' ordered accordingly (default).
#' 2. Alphanumeric - rows are ordered alphanumerically by label text. By default, [tbl_hierarchical()] sorts tables
#' in ascending alphanumeric order (i.e. A to Z).
#'
#' @param x (`tbl_hierarchical`, `tbl_hierarchical_count`)\cr
#' A hierarchical gtsummary table of class `'tbl_hierarchical'` or `'tbl_hierarchical_count'`.
#' @param sort (`string`)\cr
#' Specifies sorting to perform. Values must be one of `c("frequency", "alphanumeric")`. Default is `"frequency"`.
#' @param desc (scalar `logical`)\cr
#' Whether to sort rows in ascending or descending order. Default is descending (`TRUE`) when `sort = "frequency"`
#' and ascending (`FALSE`) when `sort = "alphanumeric"`.
#' @param .stat (`string`)\cr
#' Statistic to use to calculate row sums when `sort = "frequency"`. This statistic must be present in the table for
#' all hierarchy levels. Default is `"n"`.
#' @inheritParams rlang::args_dots_empty
#'
#' @return A `gtsummary` of the same class as `x`.
#'
#' @name sort_tbl_hierarchical
#' @seealso [tbl_filter()]
#'
#' @examplesIf (identical(Sys.getenv("NOT_CRAN"), "true") || identical(Sys.getenv("IN_PKGDOWN"), "true"))
#' ADAE_subset <- cards::ADAE |>
#' dplyr::filter(AETERM %in% unique(cards::ADAE$AETERM)[1:5])
#'
#' tbl <- tbl_hierarchical(
#' data = ADAE_subset,
#' variables = c(SEX, RACE, AETERM),
#' by = TRTA,
#' denominator = cards::ADSL |> mutate(TRTA = ARM),
#' id = USUBJID,
#' overall_row = TRUE
#' )
#'
#' # Example 1 - Descending Frequency Sort ------------------
#' tbl_sort(tbl)
#'
#' # Example 2 - Descending Alphanumeric Sort (Z to A) ------
#' tbl_sort(tbl, sort = "alphanumeric", desc = TRUE)
NULL

#' @rdname sort_tbl_hierarchical
#' @export
tbl_sort <- function(x, ...) {
check_not_missing(x)
check_class(x, "gtsummary")

UseMethod("tbl_sort")
}

#' @rdname sort_tbl_hierarchical
#' @export
tbl_sort.tbl_hierarchical <- function(x, sort = "frequency", desc = (sort == "frequency"), .stat = "n", ...) {
set_cli_abort_call()

# process and check inputs ----------------------------------------------------------------------
check_scalar_logical(desc)
check_string(.stat)

if (!sort %in% c("frequency", "alphanumeric")) {
cli::cli_abort(
"The {.arg sort} argument must be either {.val frequency} or {.val alphanumeric}.",
call = get_cli_abort_call()
)
}

overall <- "..ard_hierarchical_overall.." %in% x$table_body$variable
outer_cols <- sapply(
x$table_body |> select(cards::all_ard_groups("names")),
function(x) dplyr::last(unique(stats::na.omit(x)))
)
inner_col <- setdiff(
x$table_body$variable,
x$table_body |> select(cards::all_ard_groups("names")) |> unlist() |> unique()
)

if (sort == "alphanumeric") {
# summary rows remain at the top of each sub-section
rep_str <- if (desc) "zzzz" else " "

# overall row always appears first
if (desc && overall) {
ovrl_row <- x$table_body[1, ]
x$table_body <- x$table_body[-1, ]
}

# sort by label -------------------------------------------------------------------------------
sort_cols <- c(x$table_body |> select(cards::all_ard_groups("levels")) |> names(), "inner_var", "label")

x$table_body <- x$table_body |>
dplyr::rowwise() |>
dplyr::mutate(inner_var = if (!.data$variable %in% inner_col) rep_str else .data$variable) |>
dplyr::ungroup() |>
dplyr::mutate(across(cards::all_ard_groups(), .fns = ~ tidyr::replace_na(., rep_str))) |>
dplyr::arrange(across(all_of(sort_cols), ~ if (desc) dplyr::desc(.x) else .x)) |>
dplyr::mutate(across(cards::all_ard_groups(), .fns = ~ str_replace(., paste0("^", rep_str, "$"), NA))) |>
select(-"inner_var")

if (desc && overall) x$table_body <- dplyr::bind_rows(ovrl_row, x$table_body)
} else {
# get row sums --------------------------------------------------------------------------------
x <- .append_hierarchy_row_sums(x, .stat)

# append outer hierarchy level sums in each row to sort at all levels -------------------------
for (g in names(outer_cols)) {
x$table_body <- x$table_body |> dplyr::group_by(across(all_of(c(g, paste0(g, "_level")))), .add = TRUE)
x$table_body <- x$table_body |>
dplyr::left_join(
x$table_body |>
dplyr::summarize(!!paste0("sum_", g) := dplyr::first(.data$sum_row)),
by = x$table_body |> dplyr::group_vars()
)
}

# summary rows remain at the top of each sub-section
x$table_body <- x$table_body |>
dplyr::ungroup() |>
dplyr::mutate(across(cards::all_ard_groups(), .fns = ~ tidyr::replace_na(., " "))) |>
dplyr::rowwise() |>
dplyr::mutate(inner_var = if (!.data$variable %in% inner_col) " " else .data$variable) |>
dplyr::ungroup()

# sort by row sum -----------------------------------------------------------------------------
sort_cols <- c(rbind(
x$table_body |> select(cards::all_ard_groups("names")) |> names(),
x$table_body |> select(starts_with("sum_group")) |> names(),
x$table_body |> select(cards::all_ard_groups("levels")) |> names()
), "inner_var", "sum_row", "label")

x$table_body <- x$table_body |>
dplyr::arrange(across(all_of(sort_cols), ~ if (is.numeric(.x) && desc) dplyr::desc(.x) else .x)) |>
dplyr::mutate(across(cards::all_ard_groups(), .fns = ~ str_replace(., "^ $", NA))) |>
select(-starts_with("sum_"), -"inner_var")
}

x
}

.append_hierarchy_row_sums <- function(x, .stat) {
cards <- x$cards$tbl_hierarchical

if (!.stat %in% cards$stat_name) {
cli::cli_abort(
"The {.arg .stat} argument is {.val {(.stat)}} but this statistic is not present in {.arg x}. For all valid
statistic options see the {.val stat_name} column of {.code x$cards$tbl_hierarchical}.",
call = get_cli_abort_call()
)
}

by_cols <- if (ncol(x$table_body |> select(starts_with("stat_"))) > 1) c("group1", "group1_level") else NA
outer_cols <- sapply(
x$table_body |> select(cards::all_ard_groups("names")),
function(x) dplyr::last(unique(stats::na.omit(x)))
)

# update logical variable_level entries from overall row to character
cards$variable_level[cards$variable == "..ard_hierarchical_overall.."] <- x$table_body |>
dplyr::filter(.data$variable == "..ard_hierarchical_overall..") |>
dplyr::pull("label") |>
as.list()

# extract row sums ------------------------------------------------------------------------------
cards <- cards |>
dplyr::filter(.data$stat_name == .stat, .data$variable %in% x$table_body$variable) |>
dplyr::group_by(across(c(cards::all_ard_groups(), cards::all_ard_variables(), -all_of(by_cols)))) |>
dplyr::summarise(sum_row = sum(unlist(.data$stat))) |>
dplyr::ungroup() |>
dplyr::rename(label = "variable_level") |>
tidyr::unnest(cols = everything())

# match cards names to x$table_body -------------------------------------------------------------
if (length(by_cols) > 1) {
names(cards)[grep("group", names(cards))] <- x$table_body |>
select(cards::all_ard_groups()) |>
names()
}
cards[cards$variable == "..ard_hierarchical_overall..", 1] <- "..ard_hierarchical_overall.."

# fill in NAs to align cards layout with x$table_body -------------------------------------------
cards <- cards |>
dplyr::rowwise() |>
dplyr::mutate(across(
cards::all_ard_groups(),
~ if (is.na(.x) && !grepl("_level", dplyr::cur_column()) && .data$variable == outer_cols[dplyr::cur_column()]) {
.data$variable
} else if (is.na(.x) && .data$variable %in% outer_cols[gsub("_level", "", dplyr::cur_column())]) {
.data$label
} else {
.x
}
))

# for any variables not in include, calculate group sums ----------------------------------------
if (!all(outer_cols %in% cards$variable)) {
gp_vars <- outer_cols[outer_cols %in% setdiff(outer_cols, cards$variable)]
gp_cols <- names(gp_vars)

cli::cli_inform(
"Not all hierarchy variables present in the table were included in the {.arg include} argument.
These variables ({gp_vars}) do not have event rate data available so the total sum of the event
rates for this hierarchy section will be used instead. To use event rates for all sections of the table,
set {.code include = everything()} when creating your table via {.fun tbl_hierarchical}."
)

for (i in seq_along(gp_cols)) {
cards <- cards |>
dplyr::bind_rows(
cards |>
dplyr::filter(.data$variable != "..ard_hierarchical_overall..") |>
dplyr::group_by(across(c(gp_cols[1:i], paste0(gp_cols[1:i], "_level")))) |>
dplyr::summarize(sum_row = sum(.data$sum_row)) |>
dplyr::mutate(
variable = .data[[gp_cols[i]]],
label = .data[[paste0(gp_cols[i], "_level")]]
)
)
}
}

# append row sums to x$table_body ---------------------------------------------------------------
x$table_body <- x$table_body |>
dplyr::left_join(
cards,
by = c(cards |> select(-"sum_row") |> names())
)

x
}
Loading
Loading