Skip to content

Commit

Permalink
refactor: step_epi_slide
Browse files Browse the repository at this point in the history
* remove f_name arg and document prefix argument better
* remove clean_f_name
* validate_slide_fun now rejects formula f
* remove warning about optimized slide functions until that PR
* fix tests
* remove try_period and replace with epiprocess internal
* remove slider dependency
* update documentation
  • Loading branch information
dshemetov committed Aug 3, 2024
1 parent 0d7c001 commit 9c97485
Show file tree
Hide file tree
Showing 8 changed files with 81 additions and 267 deletions.
3 changes: 1 addition & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,10 @@ Imports:
ggplot2,
glue,
hardhat (>= 1.3.0),
lubridate,
magrittr,
quantreg,
recipes (>= 1.0.4),
rlang (>= 1.0.0),
slider,
smoothqr,
stats,
tibble,
Expand All @@ -55,6 +53,7 @@ Suggests:
epidatr (>= 1.0.0),
fs,
knitr,
lubridate,
poissonreg,
purrr,
ranger,
Expand Down
1 change: 0 additions & 1 deletion NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,6 @@ export(bake)
export(cdc_baseline_args_list)
export(cdc_baseline_forecaster)
export(check_enough_train_data)
export(clean_f_name)
export(default_epi_recipe_blueprint)
export(detect_layer)
export(dist_quantiles)
Expand Down
2 changes: 1 addition & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,4 +52,4 @@ Pre-1.0.0 numbering scheme: 0.x will indicate releases, while 0.0.x will indicat
`...` args intended for `predict.model_fit()`
- `bake.epi_recipe()` will now re-infer the geo and time type in case baking the
steps has changed the appropriate values
- Add a step to produce generic sliding computations over an `epi_df`
- Add `step_epi_slide` to produce generic sliding computations over an `epi_df`
182 changes: 52 additions & 130 deletions R/step_epi_slide.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
#' that will generate one or more new columns of derived data by "sliding"
#' a computation along existing data.
#'
#'
#' @inheritParams step_epi_lag
#' @param .f A function in one of the following formats:
#' 1. An unquoted function name with no arguments, e.g., `mean`
Expand All @@ -20,27 +19,12 @@
#' argument must be named `.x`. A common, though very difficult to debug
#' error is using something like `function(x) mean`. This will not work
#' because it returns the function mean, rather than `mean(x)`
#' @param f_name a character string of at most 20 characters that describes
#' the function. This will be combined with `prefix` and the columns in `...`
#' to name the result using `{prefix}{f_name}_{column}`. By default it will be determined
#' automatically using `clean_f_name()`.
#' @param before,after non-negative integers.
#' How far `before` and `after` each `time_value` should
#' the sliding window extend? Any value provided for either
#' argument must be a single, non-`NA`, non-negative,
#' [integer-compatible][vctrs::vec_cast] number of time steps. Endpoints of
#' the window are inclusive. Common settings:
#' * For trailing/right-aligned windows from `time_value - time_step(k)` to
#' `time_value`, use `before=k, after=0`. This is the most likely use case
#' for the purposes of forecasting.
#' * For center-aligned windows from `time_value - time_step(k)` to
#' `time_value + time_step(k)`, use `before=k, after=k`.
#' * For leading/left-aligned windows from `time_value` to
#' `time_value + time_step(k)`, use `before=0, after=k`.
#' @param before,after the size of the sliding window on the left and the right
#' of the center. Usually non-negative integers for data indexed by date, but
#' more restrictive in other cases (see [epiprocess::epi_slide()] for details).
#' @param prefix A character string that will be prefixed to the new column.
#' Make sure this is unique for every step.
#'
#' You may also pass a [lubridate::period], like `lubridate::weeks(1)` or a
#' character string that is coercible to a [lubridate::period], like
#' `"2 weeks"`.
#' @template step-return
#'
#' @export
Expand All @@ -62,16 +46,14 @@ step_epi_slide <-
after = 0L,
role = "predictor",
prefix = "epi_slide_",
f_name = clean_f_name(.f),
skip = FALSE,
id = rand_id("epi_slide")) {
if (!is_epi_recipe(recipe)) {
rlang::abort("This recipe step can only operate on an `epi_recipe`.")
}
.f <- validate_slide_fun(.f)
arg_is_scalar(before, after)
before <- try_period(before)
after <- try_period(after)
epiprocess:::validate_slide_window_arg(before, attributes(recipe$template)$metadata$time_type)
epiprocess:::validate_slide_window_arg(after, attributes(recipe$template)$metadata$time_type)
arg_is_chr_scalar(role, prefix, id)
arg_is_lgl_scalar(skip)

Expand All @@ -82,7 +64,6 @@ step_epi_slide <-
before = before,
after = after,
.f = .f,
f_name = f_name,
role = role,
trained = FALSE,
prefix = prefix,
Expand All @@ -100,7 +81,6 @@ step_epi_slide_new <-
before,
after,
.f,
f_name,
role,
trained,
prefix,
Expand All @@ -114,7 +94,6 @@ step_epi_slide_new <-
before = before,
after = after,
.f = .f,
f_name = f_name,
role = role,
trained = trained,
prefix = prefix,
Expand All @@ -126,7 +105,6 @@ step_epi_slide_new <-
}



#' @export
prep.step_epi_slide <- function(x, training, info = NULL, ...) {
col_names <- recipes::recipes_eval_select(x$terms, data = training, info = info)
Expand All @@ -138,7 +116,6 @@ prep.step_epi_slide <- function(x, training, info = NULL, ...) {
before = x$before,
after = x$after,
.f = x$.f,
f_name = x$f_name,
role = x$role,
trained = TRUE,
prefix = x$prefix,
Expand All @@ -150,12 +127,11 @@ prep.step_epi_slide <- function(x, training, info = NULL, ...) {
}



#' @export
bake.step_epi_slide <- function(object, new_data, ...) {
recipes::check_new_data(names(object$columns), object, new_data)
col_names <- object$columns
name_prefix <- paste0(object$prefix, object$f_name, "_")
name_prefix <- object$prefix
newnames <- glue::glue("{name_prefix}{col_names}")
## ensure no name clashes
new_data_names <- colnames(new_data)
Expand All @@ -170,65 +146,70 @@ bake.step_epi_slide <- function(object, new_data, ...) {
class = "epipredict__step__name_collision_error"
)
}
if (any(vapply(c(mean, sum), \(x) identical(x, object$.f), logical(1L)))) {
cli_warn(
c("There is an optimized version of both mean and sum. See `step_epi_slide_mean`, `step_epi_slide_sum`, or `step_epi_slide_opt`."),
class = "epipredict__step_epi_slide__optimized_version"
)
}
# TODO: Uncomment this whenever we make the optimized versions available.
# if (any(vapply(c(mean, sum), \(x) identical(x, object$.f), logical(1L)))) {
# cli_warn(
# c(
# "There is an optimized version of both mean and sum. See `step_epi_slide_mean`, `step_epi_slide_sum`,
# or `step_epi_slide_opt`."
# ),
# class = "epipredict__step_epi_slide__optimized_version"
# )
# }
epi_slide_wrapper(
new_data,
object$before,
object$after,
object$columns,
c(object$.f),
object$f_name,
object$keys[-1],
object$prefix
)
}
#' wrapper to handle epi_slide particulars


#' Wrapper to handle epi_slide particulars
#'
#' @description
#' This should simplify somewhat in the future when we can run `epi_slide` on
#' columns. Surprisingly, lapply is several orders of magnitude faster than
#' using roughly equivalent tidy select style.
#'
#' @param fns vector of functions, even if it's length 1.
#' @param group_keys the keys to group by. likely `epi_keys[-1]` (to remove time_value)
#'
#' @importFrom tidyr crossing
#' @importFrom dplyr bind_cols group_by ungroup
#' @importFrom epiprocess epi_slide
#' @keywords internal
epi_slide_wrapper <- function(new_data, before, after, columns, fns, fn_names, group_keys, name_prefix) {
cols_fns <- tidyr::crossing(col_name = columns, fn_name = fn_names, fn = fns)
epi_slide_wrapper <- function(new_data, before, after, columns, fns, group_keys, name_prefix) {
cols_fns <- tidyr::crossing(col_name = columns, fn = fns)
# Iterate over the rows of cols_fns. For each row number, we will output a
# transformed column. The first result returns all the original columns along
# with the new column. The rest just return the new column.
seq_len(nrow(cols_fns)) %>%
lapply( # iterate over the rows of cols_fns
# takes in the row number, outputs the transformed column
function(comp_i) {
# extract values from the row
col_name <- cols_fns[[comp_i, "col_name"]]
fn_name <- cols_fns[[comp_i, "fn_name"]]
fn <- cols_fns[[comp_i, "fn"]][[1L]]
result_name <- paste(name_prefix, fn_name, col_name, sep = "_")
result <- new_data %>%
group_by(across(all_of(group_keys))) %>%
epi_slide(
before = before,
after = after,
new_col_name = result_name,
f = function(slice, geo_key, ref_time_value) {
fn(slice[[col_name]])
}
) %>%
ungroup()
# the first result needs to include all of the original columns
if (comp_i == 1L) {
result
} else {
# everything else just needs that column transformed
result[result_name]
}
lapply(function(comp_i) {
col_name <- cols_fns[[comp_i, "col_name"]]
fn <- cols_fns[[comp_i, "fn"]][[1L]]
result_name <- paste(name_prefix, col_name, sep = "_")
result <- new_data %>%
group_by(across(all_of(group_keys))) %>%
epi_slide(
before = before,
after = after,
new_col_name = result_name,
f = function(slice, geo_key, ref_time_value) {
fn(slice[[col_name]])
}
) %>%
ungroup()

if (comp_i == 1L) {
result
} else {
result[result_name]
}
) %>%
}) %>%
bind_cols()
}

Expand All @@ -238,81 +219,22 @@ print.step_epi_slide <- function(x, width = max(20, options()$width - 30), ...)
print_epi_step(
x$columns, x$terms, x$trained,
title = "Calculating epi_slide for ",
conjunction = "with", extra_text = x$f_name
conjunction = "with", extra_text = x$prefix
)
invisible(x)
}

#' Create short function names
#'
#' Builds a short label describing `.f`. Formulas are labelled with
#' `rlang::f_name()`, character strings are used as given, and functions are
#' labelled from the expression that was supplied for `.f`. Labels longer than
#' `max_length` characters are truncated and end in `"..."`.
#'
#' @param .f a function, character string, or lambda. For example, `mean`,
#' `"mean"`, `~ mean(.x)` or `\(x) mean(x, na.rm = TRUE)`.
#' @param max_length integer determining how long names can be
#'
#' @return a character string of at most `max_length` characters that
#' (partially) describes the function. An error is raised if `.f` is not a
#' formula, a character string, or a function.
#' @export
#'
#' @examples
#' clean_f_name(mean)
#' clean_f_name("mean")
#' clean_f_name(~ mean(.x, na.rm = TRUE))
#' clean_f_name(\(x) mean(x, na.rm = TRUE))
#' clean_f_name(function(x) mean(x, na.rm = TRUE, trim = 0.2357862))
clean_f_name <- function(.f, max_length = 20L) {
  if (rlang::is_formula(.f, scoped = TRUE)) {
    f_name <- rlang::f_name(.f)
  } else if (rlang::is_character(.f)) {
    f_name <- .f
  } else if (rlang::is_function(.f)) {
    # Work from the expression the caller wrote for `.f`, not from its value.
    f_name <- as.character(substitute(.f))
    if (length(f_name) > 1L) {
      # More than one piece means `.f` was not a bare name; the third piece is
      # kept (presumably the body of an inline function -- confirm).
      f_name <- f_name[3]
      # Leave room for the 5 characters added by the "[ ]{...}" wrapper below.
      if (nchar(f_name) > max_length - 5L) {
        f_name <- paste0(substr(f_name, 1L, max(max_length - 8L, 5L)), "...")
      }
      f_name <- paste0("[ ]{", f_name, "}")
    }
  } else {
    # Without this branch `f_name` would never be assigned and the function
    # would fail below with "object 'f_name' not found".
    cli_abort("`.f` must be a function, a formula, or a character string.")
  }
  if (nchar(f_name) > max_length) {
    f_name <- paste0(substr(f_name, 1L, max_length - 3L), "...")
  }
  f_name
}


# Validate the `.f` argument of `step_epi_slide()` and return it as a function.
#
# * A missing `.f` is an error.
# * A formula is always an error (a two-sided formula gets its own message).
# * A character string is converted with `rlang::as_function()`.
# * Anything else must already be a function.
#
# Returns the (possibly converted) function.
validate_slide_fun <- function(.f) {
  # `missing()` also reports TRUE when the caller forwarded its own missing
  # argument. `rlang::quo(.f)` would quote the literal symbol `.f`, which is
  # never "missing", so that form of the check could never trigger.
  if (missing(.f)) {
    cli_abort("In, `step_epi_slide()`, `.f` may not be missing.")
  }
  if (rlang::is_formula(.f, scoped = TRUE)) {
    if (!is.null(rlang::f_lhs(.f))) {
      cli_abort("In, `step_epi_slide()`, `.f` must be a one-sided formula.")
    }
    cli_abort("In, `step_epi_slide()`, `.f` cannot be a formula.")
  } else if (rlang::is_character(.f)) {
    .f <- rlang::as_function(.f)
  } else if (!rlang::is_function(.f)) {
    cli_abort("In, `step_epi_slide()`, `.f` must be a function.")
  }
  .f
}

# Validate a `before` / `after` window size.
#
# Numeric input must be a non-negative whole number and is returned unchanged.
# Any other input is passed through `lubridate::as.period()` and the converted
# value is returned. Input that is not a single value, is `NA`, or converts to
# `NA` aborts with an explanatory message.
try_period <- function(x) {
  # Treat zero-length and multi-element input as invalid here; otherwise the
  # `if ()` below would fail with a generic "condition" error instead of the
  # message written for this function.
  err <- length(x) != 1L || is.na(x)
  if (!err) {
    if (is.numeric(x)) {
      err <- !rlang::is_integerish(x) || x < 0
    } else {
      x <- lubridate::as.period(x)
      err <- is.na(x)
    }
  }
  if (err) {
    cli_abort(paste(
      "The value supplied to `before` or `after` must be a non-negative integer",
      "a {.cls lubridate::period} or a character scalar that can be coerced",
      'as a {.cls lubridate::period}, e.g., `"1 week"`.'
    ))
  }
  x
}
28 changes: 0 additions & 28 deletions man/clean_f_name.Rd

This file was deleted.

3 changes: 1 addition & 2 deletions man/epi_slide_wrapper.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 9c97485

Please sign in to comment.