From aefc06c23825e71e19cbd216ced37d29dfd35208 Mon Sep 17 00:00:00 2001 From: dsweber2 Date: Wed, 4 Sep 2024 12:29:11 -0500 Subject: [PATCH] rebase fixes round 2 --- NAMESPACE | 1 + R/step_adjust_latency.R | 8 ++++---- R/step_epi_shift.R | 23 ++++------------------- R/utils-latency.R | 5 +++-- man/step_epi_shift.Rd | 1 - tests/testthat/_snaps/snapshots.md | 30 ++++++++++++++---------------- 6 files changed, 26 insertions(+), 42 deletions(-) diff --git a/NAMESPACE b/NAMESPACE index 6bfb12fdb..543114967 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -309,6 +309,7 @@ importFrom(stats,residuals) importFrom(tibble,as_tibble) importFrom(tibble,tibble) importFrom(tidyr,crossing) +importFrom(tidyr,drop_na) importFrom(tidyr,expand_grid) importFrom(tidyr,fill) importFrom(tidyr,unnest) diff --git a/R/step_adjust_latency.R b/R/step_adjust_latency.R index 11e1d84fd..bb03f0237 100644 --- a/R/step_adjust_latency.R +++ b/R/step_adjust_latency.R @@ -139,7 +139,7 @@ then the previous `step_epi_lag`s won't work with modified data.", } if (length(fixed_latency > 1)) { template <- recipe$template - data_names <- names(template)[!names(template) %in% epi_keys(template)] + data_names <- names(template)[!names(template) %in% key_colnames(template)] wrong_names <- names(fixed_latency)[!names(fixed_latency) %in% data_names] if (length(wrong_names) > 0) { cli::cli_abort("{.val fixed_latency} contains names not in the template dataset: {wrong_names}", class = "epipredict__step_adjust_latency__undefined_names_error") @@ -173,7 +173,7 @@ then the previous `step_epi_lag`s won't work with modified data.", latency = fixed_latency, latency_table = NULL, default = default, - keys = epi_keys(recipe), + keys = key_colnames(recipe), columns = columns, skip = skip, id = id @@ -213,7 +213,7 @@ prep.step_adjust_latency <- function(x, training, info = NULL, ...) { latency <- x$latency forecast_date <- x$forecast_date %||% set_forecast_date(training, info, x$epi_keys_checked, latency) # construct the latency table - latency_table <- names(training)[!names(training) %in% epi_keys(training)] %>% + latency_table <- names(training)[!names(training) %in% key_colnames(training)] %>% tibble(col_name = .) if (length(recipes_eval_select(x$terms, training, info)) > 0) { latency_table <- latency_table %>% filter(col_name %in% @@ -299,7 +299,7 @@ bake.step_adjust_latency <- function(object, new_data, ...) { return(new_data) } else if (object$method == "locf") { # locf doesn't need to mess with the metadata at all, it just forward-fills the requested columns - rel_keys <- setdiff(epi_keys(new_data), "time_value") + rel_keys <- setdiff(key_colnames(new_data), "time_value") object$forecast_date unnamed_columns <- object$columns %>% unname() new_data %>% diff --git a/R/step_epi_shift.R b/R/step_epi_shift.R index bbb0fc93d..c1b0f51bc 100644 --- a/R/step_epi_shift.R +++ b/R/step_epi_shift.R @@ -19,10 +19,6 @@ #' be the lag or lead for each value in the vector. Lag integers must be #' nonnegative, while ahead integers must be positive. #' @param prefix A character string that will be prefixed to the new column. -#' @param latency_adjustment a character. Determines the method by which the forecast handles data that doesn't extend to the day the forecast is made. The options are: -#' - `"extend_ahead"`: actually forecasts from the last date. E.g. if there are 3 days of latency for a 4 day ahead forecast, the ahead used in practice is actually 7. -#' - `"locf"`: carries forward the last observed value up to the forecast date. -#' - `"extend_lags"`: per `epi_key` and `predictor`, adjusts the lag so that the shortest lag at predict time is #' @param default Determines what fills empty rows #' left by leading/lagging (defaults to NA). #' @param skip A logical. Should the step be skipped when the @@ -71,12 +67,6 @@ step_epi_lag <- } arg_is_nonneg_int(lag) arg_is_chr_scalar(prefix, id) - if (!is.null(columns)) { - cli::cli_abort(c( - "The `columns` argument must be `NULL`.", - i = "Use `tidyselect` methods to choose columns to lag." - )) - } recipes::add_step( recipe, step_epi_lag_new( @@ -87,8 +77,7 @@ step_epi_lag <- prefix = prefix, default = default, keys = key_colnames(recipe), - columns = columns, - latency_adjustment = latency_adjustment, + columns = NULL, skip = skip, id = id ) @@ -107,7 +96,6 @@ step_epi_ahead <- role = "outcome", prefix = "ahead_", default = NA, - columns = NULL, skip = FALSE, id = rand_id("epi_ahead")) { if (!is_epi_recipe(recipe)) { @@ -121,7 +109,7 @@ step_epi_ahead <- )) } arg_is_nonneg_int(ahead) - arg_is_chr_scalar(prefix, id, latency_adjustment) + arg_is_chr_scalar(prefix, id) recipes::add_step( recipe, step_epi_ahead_new( @@ -132,8 +120,7 @@ step_epi_ahead <- prefix = prefix, default = default, keys = key_colnames(recipe), - latency_adjustment = latency_adjustment, - columns = columns, + columns = NULL, skip = skip, id = id ) @@ -143,7 +130,7 @@ step_epi_ahead <- step_epi_lag_new <- function(terms, role, trained, lag, prefix, default, keys, - latency_adjustment, columns, skip, id) { + columns, skip, id) { recipes::step( subclass = "epi_lag", terms = terms, @@ -189,7 +176,6 @@ prep.step_epi_lag <- function(x, training, info = NULL, ...) { prefix = x$prefix, default = x$default, keys = x$keys, - latency_adjustment = x$latency_adjustment, columns = recipes::recipes_eval_select(x$terms, training, info), skip = x$skip, id = x$id @@ -206,7 +192,6 @@ prep.step_epi_ahead <- function(x, training, info = NULL, ...) { prefix = x$prefix, default = x$default, keys = x$keys, - latency_adjustment = x$latency_adjustment, columns = recipes::recipes_eval_select(x$terms, training, info), skip = x$skip, id = x$id diff --git a/R/utils-latency.R b/R/utils-latency.R index 8356ae1e2..5498e3ede 100644 --- a/R/utils-latency.R +++ b/R/utils-latency.R @@ -32,6 +32,7 @@ construct_shift_tibble <- function(terms_used, recipe, rel_step_type, shift_name #' Extract the as_of for the forecast date, and make sure there's nothing very off about it. #' @keywords internal #' @importFrom dplyr select +#' @importFrom tidyr drop_na set_forecast_date <- function(new_data, info, epi_keys_checked, latency) { original_columns <- info %>% filter(source == "original") %>% @@ -166,14 +167,14 @@ fill_locf <- function(x, forecast_date) { dplyr::mutate(fillers = forecast_date - time_value > keep) %>% dplyr::summarise( dplyr::across( - -tidyselect::any_of(epi_keys(recipe)), + -tidyselect::any_of(key_colnames(recipe)), ~ all(is.na(.x[fillers])) & is.na(head(.x[!fillers], 1)) ), .groups = "drop" ) %>% dplyr::select(-fillers) %>% dplyr::summarise(dplyr::across( - -tidyselect::any_of(epi_keys(recipe)), ~ any(.x) + -tidyselect::any_of(key_colnames(recipe)), ~ any(.x) )) %>% unlist() x <- tidyr::fill(x, !time_value) diff --git a/man/step_epi_shift.Rd b/man/step_epi_shift.Rd index f0f7f2a2f..30ac05d16 100644 --- a/man/step_epi_shift.Rd +++ b/man/step_epi_shift.Rd @@ -23,7 +23,6 @@ step_epi_ahead( role = "outcome", prefix = "ahead_", default = NA, - columns = NULL, skip = FALSE, id = rand_id("epi_ahead") ) diff --git a/tests/testthat/_snaps/snapshots.md b/tests/testthat/_snaps/snapshots.md index fb11026dd..6c439dbd8 100644 --- a/tests/testthat/_snaps/snapshots.md +++ b/tests/testthat/_snaps/snapshots.md @@ -1036,26 +1036,24 @@ structure(list(geo_value = c("ca", "fl", "ga", "ny", "pa", "tx" ), .pred = c(0.303244704017743, 0.531332853311082, 0.588827944685979, 0.988690249216229, 0.794801997001639, 0.306895457225321), .pred_distn = structure(list( - structure(list(values = c("5%" = 0.136509784083987, "95%" = 0.469979623951498 + structure(list(values = c(0.136509784083987, 0.469979623951498 ), quantile_levels = c(0.05, 0.95)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( - values = c("5%" = 0.364597933377326, "95%" = 0.698067773244837 - ), quantile_levels = c(0.05, 0.95)), class = c("dist_quantiles", - "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( - values = c("5%" = 0.422093024752224, "95%" = 0.755562864619735 - ), quantile_levels = c(0.05, 0.95)), class = c("dist_quantiles", - "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( - values = c("5%" = 0.821955329282474, "95%" = 1.15542516914998 - ), quantile_levels = c(0.05, 0.95)), class = c("dist_quantiles", + values = c(0.364597933377326, 0.698067773244837), quantile_levels = c(0.05, + 0.95)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", + "vctrs_vctr")), structure(list(values = c(0.422093024752224, + 0.755562864619735), quantile_levels = c(0.05, 0.95)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( - values = c("5%" = 0.628067077067883, "95%" = 0.961536916935394 - ), quantile_levels = c(0.05, 0.95)), class = c("dist_quantiles", + values = c(0.821955329282474, 1.15542516914998), quantile_levels = c(0.05, + 0.95)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", + "vctrs_vctr")), structure(list(values = c(0.628067077067883, + 0.961536916935394), quantile_levels = c(0.05, 0.95)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( - values = c("5%" = 0.140160537291566, "95%" = 0.473630377159077 - ), quantile_levels = c(0.05, 0.95)), class = c("dist_quantiles", - "dist_default", "vctrs_rcrd", "vctrs_vctr"))), class = c("distribution", - "vctrs_vctr", "list")), forecast_date = structure(c(18997, 18997, - 18997, 18997, 18997, 18997), class = "Date"), target_date = structure(c(18998, + values = c(0.140160537291566, 0.473630377159077), quantile_levels = c(0.05, + 0.95)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", + "vctrs_vctr"))), class = c("distribution", "vctrs_vctr", + "list")), forecast_date = structure(c(18997, 18997, 18997, 18997, + 18997, 18997), class = "Date"), target_date = structure(c(18998, 18998, 18998, 18998, 18998, 18998), class = "Date")), row.names = c(NA, -6L), class = c("tbl_df", "tbl", "data.frame"))