Skip to content

Commit

Permalink
rebase fixes round 2
Browse files Browse the repository at this point in the history
  • Loading branch information
dsweber2 committed Sep 4, 2024
1 parent 74ad5f1 commit aefc06c
Show file tree
Hide file tree
Showing 6 changed files with 26 additions and 42 deletions.
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,7 @@ importFrom(stats,residuals)
importFrom(tibble,as_tibble)
importFrom(tibble,tibble)
importFrom(tidyr,crossing)
importFrom(tidyr,drop_na)
importFrom(tidyr,expand_grid)
importFrom(tidyr,fill)
importFrom(tidyr,unnest)
Expand Down
8 changes: 4 additions & 4 deletions R/step_adjust_latency.R
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ then the previous `step_epi_lag`s won't work with modified data.",
}
if (length(fixed_latency > 1)) {
template <- recipe$template
data_names <- names(template)[!names(template) %in% epi_keys(template)]
data_names <- names(template)[!names(template) %in% key_colnames(template)]
wrong_names <- names(fixed_latency)[!names(fixed_latency) %in% data_names]
if (length(wrong_names) > 0) {
cli::cli_abort("{.val fixed_latency} contains names not in the template dataset: {wrong_names}", class = "epipredict__step_adjust_latency__undefined_names_error")
Expand Down Expand Up @@ -173,7 +173,7 @@ then the previous `step_epi_lag`s won't work with modified data.",
latency = fixed_latency,
latency_table = NULL,
default = default,
keys = epi_keys(recipe),
keys = key_colnames(recipe),
columns = columns,
skip = skip,
id = id
Expand Down Expand Up @@ -213,7 +213,7 @@ prep.step_adjust_latency <- function(x, training, info = NULL, ...) {
latency <- x$latency
forecast_date <- x$forecast_date %||% set_forecast_date(training, info, x$epi_keys_checked, latency)
# construct the latency table
latency_table <- names(training)[!names(training) %in% epi_keys(training)] %>%
latency_table <- names(training)[!names(training) %in% key_colnames(training)] %>%
tibble(col_name = .)
if (length(recipes_eval_select(x$terms, training, info)) > 0) {
latency_table <- latency_table %>% filter(col_name %in%
Expand Down Expand Up @@ -299,7 +299,7 @@ bake.step_adjust_latency <- function(object, new_data, ...) {
return(new_data)
} else if (object$method == "locf") {
# locf doesn't need to touch the metadata at all; it just forward-fills the requested columns
rel_keys <- setdiff(epi_keys(new_data), "time_value")
rel_keys <- setdiff(key_colnames(new_data), "time_value")
object$forecast_date
unnamed_columns <- object$columns %>% unname()
new_data %>%
Expand Down
23 changes: 4 additions & 19 deletions R/step_epi_shift.R
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,6 @@
#' be the lag or lead for each value in the vector. Lag integers must be
#' nonnegative, while ahead integers must be positive.
#' @param prefix A character string that will be prefixed to the new column.
#' @param latency_adjustment a character. Determines the method by which the forecast handles data that doesn't extend to the day the forecast is made. The options are:
#' - `"extend_ahead"`: actually forecasts from the last date. E.g. if there are 3 days of latency for a 4 day ahead forecast, the ahead used in practice is actually 7.
#' - `"locf"`: carries forward the last observed value up to the forecast date.
#' - `"extend_lags"`: per `epi_key` and `predictor`, adjusts the lag so that the shortest lag at predict time is
#' @param default Determines what fills empty rows
#' left by leading/lagging (defaults to NA).
#' @param skip A logical. Should the step be skipped when the
Expand Down Expand Up @@ -71,12 +67,6 @@ step_epi_lag <-
}
arg_is_nonneg_int(lag)
arg_is_chr_scalar(prefix, id)
if (!is.null(columns)) {
cli::cli_abort(c(
"The `columns` argument must be `NULL`.",
i = "Use `tidyselect` methods to choose columns to lag."
))
}
recipes::add_step(
recipe,
step_epi_lag_new(
Expand All @@ -87,8 +77,7 @@ step_epi_lag <-
prefix = prefix,
default = default,
keys = key_colnames(recipe),
columns = columns,
latency_adjustment = latency_adjustment,
columns = NULL,
skip = skip,
id = id
)
Expand All @@ -107,7 +96,6 @@ step_epi_ahead <-
role = "outcome",
prefix = "ahead_",
default = NA,
columns = NULL,
skip = FALSE,
id = rand_id("epi_ahead")) {
if (!is_epi_recipe(recipe)) {
Expand All @@ -121,7 +109,7 @@ step_epi_ahead <-
))
}
arg_is_nonneg_int(ahead)
arg_is_chr_scalar(prefix, id, latency_adjustment)
arg_is_chr_scalar(prefix, id)
recipes::add_step(
recipe,
step_epi_ahead_new(
Expand All @@ -132,8 +120,7 @@ step_epi_ahead <-
prefix = prefix,
default = default,
keys = key_colnames(recipe),
latency_adjustment = latency_adjustment,
columns = columns,
columns = NULL,
skip = skip,
id = id
)
Expand All @@ -143,7 +130,7 @@ step_epi_ahead <-

step_epi_lag_new <-
function(terms, role, trained, lag, prefix, default, keys,
latency_adjustment, columns, skip, id) {
columns, skip, id) {
recipes::step(
subclass = "epi_lag",
terms = terms,
Expand Down Expand Up @@ -189,7 +176,6 @@ prep.step_epi_lag <- function(x, training, info = NULL, ...) {
prefix = x$prefix,
default = x$default,
keys = x$keys,
latency_adjustment = x$latency_adjustment,
columns = recipes::recipes_eval_select(x$terms, training, info),
skip = x$skip,
id = x$id
Expand All @@ -206,7 +192,6 @@ prep.step_epi_ahead <- function(x, training, info = NULL, ...) {
prefix = x$prefix,
default = x$default,
keys = x$keys,
latency_adjustment = x$latency_adjustment,
columns = recipes::recipes_eval_select(x$terms, training, info),
skip = x$skip,
id = x$id
Expand Down
5 changes: 3 additions & 2 deletions R/utils-latency.R
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ construct_shift_tibble <- function(terms_used, recipe, rel_step_type, shift_name
#' Extract the `as_of` to use as the forecast date, and sanity-check that it is plausible.
#' @keywords internal
#' @importFrom dplyr select
#' @importFrom tidyr drop_na
set_forecast_date <- function(new_data, info, epi_keys_checked, latency) {
original_columns <- info %>%
filter(source == "original") %>%
Expand Down Expand Up @@ -166,14 +167,14 @@ fill_locf <- function(x, forecast_date) {
dplyr::mutate(fillers = forecast_date - time_value > keep) %>%
dplyr::summarise(
dplyr::across(
-tidyselect::any_of(epi_keys(recipe)),
-tidyselect::any_of(key_colnames(recipe)),
~ all(is.na(.x[fillers])) & is.na(head(.x[!fillers], 1))
),
.groups = "drop"
) %>%
dplyr::select(-fillers) %>%
dplyr::summarise(dplyr::across(
-tidyselect::any_of(epi_keys(recipe)), ~ any(.x)
-tidyselect::any_of(key_colnames(recipe)), ~ any(.x)
)) %>%
unlist()
x <- tidyr::fill(x, !time_value)
Expand Down
1 change: 0 additions & 1 deletion man/step_epi_shift.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

30 changes: 14 additions & 16 deletions tests/testthat/_snaps/snapshots.md
Original file line number Diff line number Diff line change
Expand Up @@ -1036,26 +1036,24 @@
structure(list(geo_value = c("ca", "fl", "ga", "ny", "pa", "tx"
), .pred = c(0.303244704017743, 0.531332853311082, 0.588827944685979,
0.988690249216229, 0.794801997001639, 0.306895457225321), .pred_distn = structure(list(
structure(list(values = c("5%" = 0.136509784083987, "95%" = 0.469979623951498
structure(list(values = c(0.136509784083987, 0.469979623951498
), quantile_levels = c(0.05, 0.95)), class = c("dist_quantiles",
"dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list(
values = c("5%" = 0.364597933377326, "95%" = 0.698067773244837
), quantile_levels = c(0.05, 0.95)), class = c("dist_quantiles",
"dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list(
values = c("5%" = 0.422093024752224, "95%" = 0.755562864619735
), quantile_levels = c(0.05, 0.95)), class = c("dist_quantiles",
"dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list(
values = c("5%" = 0.821955329282474, "95%" = 1.15542516914998
), quantile_levels = c(0.05, 0.95)), class = c("dist_quantiles",
values = c(0.364597933377326, 0.698067773244837), quantile_levels = c(0.05,
0.95)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd",
"vctrs_vctr")), structure(list(values = c(0.422093024752224,
0.755562864619735), quantile_levels = c(0.05, 0.95)), class = c("dist_quantiles",
"dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list(
values = c("5%" = 0.628067077067883, "95%" = 0.961536916935394
), quantile_levels = c(0.05, 0.95)), class = c("dist_quantiles",
values = c(0.821955329282474, 1.15542516914998), quantile_levels = c(0.05,
0.95)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd",
"vctrs_vctr")), structure(list(values = c(0.628067077067883,
0.961536916935394), quantile_levels = c(0.05, 0.95)), class = c("dist_quantiles",
"dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list(
values = c("5%" = 0.140160537291566, "95%" = 0.473630377159077
), quantile_levels = c(0.05, 0.95)), class = c("dist_quantiles",
"dist_default", "vctrs_rcrd", "vctrs_vctr"))), class = c("distribution",
"vctrs_vctr", "list")), forecast_date = structure(c(18997, 18997,
18997, 18997, 18997, 18997), class = "Date"), target_date = structure(c(18998,
values = c(0.140160537291566, 0.473630377159077), quantile_levels = c(0.05,
0.95)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd",
"vctrs_vctr"))), class = c("distribution", "vctrs_vctr",
"list")), forecast_date = structure(c(18997, 18997, 18997, 18997,
18997, 18997), class = "Date"), target_date = structure(c(18998,
18998, 18998, 18998, 18998, 18998), class = "Date")), row.names = c(NA,
-6L), class = c("tbl_df", "tbl", "data.frame"))

Expand Down

0 comments on commit aefc06c

Please sign in to comment.