Skip to content

Commit

Permalink
support col_names as tidyselect
Browse files Browse the repository at this point in the history
  • Loading branch information
nmdefries committed Mar 23, 2024
1 parent c6ee7f9 commit 56bed8c
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 61 deletions.
2 changes: 2 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ importFrom(rlang,"!!")
importFrom(rlang,.data)
importFrom(rlang,.env)
importFrom(rlang,arg_match)
importFrom(rlang,as_label)
importFrom(rlang,caller_arg)
importFrom(rlang,caller_env)
importFrom(rlang,enquo)
Expand All @@ -146,6 +147,7 @@ importFrom(rlang,is_missing)
importFrom(rlang,is_quosure)
importFrom(rlang,missing_arg)
importFrom(rlang,new_function)
importFrom(rlang,quo_get_expr)
importFrom(rlang,quo_is_missing)
importFrom(rlang,sym)
importFrom(rlang,syms)
Expand Down
46 changes: 32 additions & 14 deletions R/slide.R
Original file line number Diff line number Diff line change
Expand Up @@ -393,8 +393,8 @@ epi_slide <- function(x, f, ..., before, after, ref_time_values,
#' @param x The `epi_df` object under consideration, [grouped][dplyr::group_by]
#' or ungrouped. If ungrouped, all data in `x` will be treated as part of a
#' single data group.
#' @param col_names A character vector of the names of one or more columns for
#' which to calculate the rolling mean.
#' @param col_names A tidyselection naming the column(s) for which to
#' calculate the rolling mean: either a single tidy column name or a vector
#' (`c()`) of tidy column names. Wrapping the names in `list()` is an error.
#' @param ... Additional arguments to pass to `data.table::frollmean`, for
#' example, `na.rm` and `algo`. `data.table::frollmean` is automatically
#' passed the data `x` to operate on, the window size `n`, and the alignment
Expand Down Expand Up @@ -473,7 +473,8 @@ epi_slide <- function(x, f, ..., before, after, ref_time_values,
#' leading window was intended, but the `after` argument was forgotten or
#' misspelled.)
#'
#' @importFrom dplyr bind_rows mutate %>% arrange tibble
#' @importFrom dplyr bind_rows mutate %>% arrange tibble select
#' @importFrom rlang enquo quo_get_expr as_label
#' @importFrom purrr map
#' @importFrom data.table frollmean
#' @importFrom lubridate as.period
Expand All @@ -484,7 +485,7 @@ epi_slide <- function(x, f, ..., before, after, ref_time_values,
#' # slide a 7-day trailing average on cases
#' jhu_csse_daily_subset %>%
#' group_by(geo_value) %>%
#' epi_slide_mean("cases", new_col_names = "cases_7dav", names_sep = NULL, before = 6) %>%
#' epi_slide_mean(cases, new_col_names = "cases_7dav", names_sep = NULL, before = 6) %>%
#' # Remove a nonessential var. to ensure new col is printed
#' dplyr::select(geo_value, time_value, cases, cases_7dav) %>%
#' ungroup()
Expand All @@ -493,7 +494,7 @@ epi_slide <- function(x, f, ..., before, after, ref_time_values,
#' # and accuracy, and to allow partially-missing windows.
#' jhu_csse_daily_subset %>%
#' group_by(geo_value) %>%
#' epi_slide_mean("cases",
#' epi_slide_mean(cases,
#' new_col_names = "cases_7dav", names_sep = NULL, before = 6,
#' # `frollmean` options
#' na.rm = TRUE, algo = "exact", hasNA = TRUE
Expand All @@ -504,23 +505,23 @@ epi_slide <- function(x, f, ..., before, after, ref_time_values,
#' # slide a 7-day leading average
#' jhu_csse_daily_subset %>%
#' group_by(geo_value) %>%
#' epi_slide_mean("cases", new_col_names = "cases_7dav", names_sep = NULL, after = 6) %>%
#' epi_slide_mean(cases, new_col_names = "cases_7dav", names_sep = NULL, after = 6) %>%
#' # Remove a nonessential var. to ensure new col is printed
#' dplyr::select(geo_value, time_value, cases, cases_7dav) %>%
#' ungroup()
#'
#' # slide a 7-day centre-aligned average
#' jhu_csse_daily_subset %>%
#' group_by(geo_value) %>%
#' epi_slide_mean("cases", new_col_names = "cases_7dav", names_sep = NULL, before = 3, after = 3) %>%
#' epi_slide_mean(cases, new_col_names = "cases_7dav", names_sep = NULL, before = 3, after = 3) %>%
#' # Remove a nonessential var. to ensure new col is printed
#' dplyr::select(geo_value, time_value, cases, cases_7dav) %>%
#' ungroup()
#'
#' # slide a 14-day centre-aligned average
#' jhu_csse_daily_subset %>%
#' group_by(geo_value) %>%
#' epi_slide_mean("cases", new_col_names = "cases_14dav", names_sep = NULL, before = 6, after = 7) %>%
#' epi_slide_mean(cases, new_col_names = "cases_14dav", names_sep = NULL, before = 6, after = 7) %>%
#' # Remove a nonessential var. to ensure new col is printed
#' dplyr::select(geo_value, time_value, cases, cases_14dav) %>%
#' ungroup()
Expand Down Expand Up @@ -604,29 +605,46 @@ epi_slide_mean <- function(x, col_names, ..., before, after, ref_time_values,
# `before` and `after` params.
m <- before + after + 1L

col_names_quo <- enquo(col_names)
col_names_chr <- as.character(rlang::quo_get_expr(col_names_quo))
if (startsWith(rlang::as_label(col_names_quo), "c(")) {
# Vector of col names built with `c()`. Coercing the call expression to
# character puts the function name ("c") first, so drop that element.
# (`list()` input is rejected in the branch below.)
col_names_chr <- col_names_chr[-1]
} else if (startsWith(rlang::as_label(col_names_quo), "list(")) {
cli_abort(
"`col_names` must be a single tidy column name or a vector
(`c()`) of tidy column names",
class = "epiprocess__epi_slide_mean__col_names_in_list",
epiprocess__col_names = col_names_chr
)
}
# If single column name, do nothing.

if (is.null(names_sep)) {
if (length(new_col_names) != length(col_names)) {
if (length(new_col_names) != length(col_names_chr)) {
cli_abort(
c(
"`new_col_names` must be the same length as `col_names` when
`names_sep` is NULL to avoid duplicate output column names."
),
class = "epiprocess__epi_slide_mean__col_names_length_mismatch",
epiprocess__new_col_names = new_col_names,
epiprocess__col_names = col_names
epiprocess__col_names = col_names_chr
)
}
result_col_names <- new_col_names
} else {
if (length(new_col_names) != 1L && length(new_col_names) != length(col_names)) {
if (length(new_col_names) != 1L && length(new_col_names) != length(col_names_chr)) {
cli_abort(
"`new_col_names` must be either length 1 or the same length as `col_names`.",
class = "epiprocess__epi_slide_mean__col_names_length_mismatch_and_not_one",
epiprocess__new_col_names = new_col_names,
epiprocess__col_names = col_names
epiprocess__col_names = col_names_chr
)
}
result_col_names <- paste(new_col_names, col_names, sep = names_sep)
result_col_names <- paste(new_col_names, col_names_chr, sep = names_sep)
}

slide_one_grp <- function(.data_group, .group_key, ...) {
Expand Down Expand Up @@ -675,7 +693,7 @@ epi_slide_mean <- function(x, col_names, ..., before, after, ref_time_values,
}

roll_output <- data.table::frollmean(
x = .data_group[, col_names], n = m, align = "right", ...
x = select(.data_group, {{ col_names }}), n = m, align = "right", ...
)

if (after >= 1) {
Expand Down
15 changes: 8 additions & 7 deletions man/epi_slide_mean.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 56bed8c

Please sign in to comment.