moving locf to step_adjust_ahead instead of get_test_data
dsweber2 committed Jul 3, 2024
1 parent 13c70a1 commit ec9a2e3
Showing 20 changed files with 223 additions and 245 deletions.
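
The central change: `get_test_data()` and the forecaster wrappers no longer take the LOCF-related arguments (`fill_locf`, `n_recent`, and the `nafill_buffer` option that fed them); per the commit message, that filling is intended to move into the latency-adjustment step (`step_adjust_ahead`) instead. A rough before/after sketch of the call sites touched below, using the argument names from this diff (values shown are illustrative):

# Before: test-data construction did the forward-filling itself
latest <- get_test_data(
  epi_recipe(epi_data), epi_data,
  fill_locf = TRUE,
  n_recent = args_list$nafill_buffer,
  forecast_date = forecast_date
)

# After: only the recipe and the data are needed
latest <- get_test_data(epi_recipe(epi_data), epi_data)
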
1 change: 1 addition & 0 deletions DESCRIPTION
@@ -39,6 +39,7 @@ Imports:
quantreg,
recipes (>= 1.0.4),
rlang (>= 1.0.0),
purrr,
smoothqr,
stats,
tibble,
2 changes: 2 additions & 0 deletions NAMESPACE
@@ -226,6 +226,7 @@ importFrom(dplyr,"%>%")
importFrom(dplyr,across)
importFrom(dplyr,all_of)
importFrom(dplyr,group_by)
importFrom(dplyr,group_by_at)
importFrom(dplyr,join_by)
importFrom(dplyr,left_join)
importFrom(dplyr,mutate)
@@ -274,6 +275,7 @@ importFrom(stats,residuals)
importFrom(tibble,tibble)
importFrom(tidyr,drop_na)
importFrom(tidyr,expand_grid)
importFrom(tidyr,fill)
importFrom(tidyr,unnest)
importFrom(vctrs,as_list_of)
importFrom(vctrs,field)
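
The new imports line up with where the filling logic is headed: `purrr` (added to DESCRIPTION) presumably backs the `map_dbl()` calls visible in the get_test_data.R hunk below, while `dplyr::group_by_at()` and `tidyr::fill()` look like the machinery for grouped forward-filling inside the latency-adjustment step. A minimal sketch of that grouped-LOCF pattern, not the package's actual step code (column names are hypothetical):

library(dplyr)
library(tidyr)

# Carry each signal forward within every geo_value, oldest to newest.
filled <- epi_data %>%
  group_by_at("geo_value") %>%
  arrange(time_value, .by_group = TRUE) %>%
  fill(case_rate, death_rate, .direction = "down") %>%
  ungroup()
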
6 changes: 0 additions & 6 deletions R/arx_classifier.R
@@ -66,9 +66,6 @@ arx_classifier <- function(
target_date <- args_list$target_date %||% (forecast_date + args_list$ahead)
preds <- forecast(
wf,
fill_locf = is.null(args_list$adjust_latency),
n_recent = args_list$nafill_buffer,
forecast_date = forecast_date
) %>%
tibble::as_tibble() %>%
dplyr::select(-time_value)
@@ -292,7 +289,6 @@ arx_class_args_list <- function(
method = c("rel_change", "linear_reg", "smooth_spline", "trend_filter"),
log_scale = FALSE,
additional_gr_args = list(),
nafill_buffer = Inf,
check_enough_data_n = NULL,
check_enough_data_epi_keys = NULL,
...) {
@@ -310,7 +306,6 @@
arg_is_lgl(log_scale)
arg_is_pos(n_training)
if (is.finite(n_training)) arg_is_pos_int(n_training)
if (is.finite(nafill_buffer)) arg_is_pos_int(nafill_buffer, allow_null = TRUE)
if (!is.list(additional_gr_args)) {
cli::cli_abort(
c("`additional_gr_args` must be a {.cls list}.",
@@ -352,7 +347,6 @@
method,
log_scale,
additional_gr_args,
nafill_buffer,
check_enough_data_n,
check_enough_data_epi_keys
),
19 changes: 1 addition & 18 deletions R/arx_forecaster.R
@@ -51,12 +51,7 @@ arx_forecaster <- function(
wf <- arx_fcast_epi_workflow(epi_data, outcome, predictors, trainer, args_list)
wf <- generics::fit(wf, epi_data)

preds <- forecast(
wf,
fill_locf = is.null(args_list$adjust_latency),
n_recent = args_list$nafill_buffer,
forecast_date = args_list$forecast_date %||% max(epi_data$time_value)
) %>%
preds <- forecast(wf) %>%
tibble::as_tibble() %>%
dplyr::select(-time_value)

@@ -251,15 +246,6 @@ arx_fcast_epi_workflow <- function(
#' `character(0)` performs no grouping. This argument only applies when
#' residual quantiles are used. It is not applicable with
#' `trainer = quantile_reg()`, for example.
#' @param nafill_buffer At predict time, recent values of the training data
#' are used to create a forecast. However, these can be `NA` due to, e.g.,
#' data latency issues. By default, any missing values will get filled with
#' less recent data. Setting this value to `NULL` will result in 1 extra
#' recent row (beyond those required for lag creation) to be used. Note that
#' we require at least `min(lags)` rows of recent data per `geo_value` to
#' create a prediction. For this reason, setting `nafill_buffer < min(lags)`
#' will be treated as _additional_ allowed recent data rather than the
#' total amount of recent data to examine.
#' @param check_enough_data_n Integer. A lower limit for the number of rows per
#' epi_key that are required for training. If `NULL`, this check is ignored.
#' @param check_enough_data_epi_keys Character vector. A character vector of
@@ -286,7 +272,6 @@ arx_args_list <- function(
symmetrize = TRUE,
nonneg = TRUE,
quantile_by_key = character(0L),
nafill_buffer = Inf,
check_enough_data_n = NULL,
check_enough_data_epi_keys = NULL,
...) {
@@ -304,7 +289,6 @@
arg_is_probabilities(quantile_levels, allow_null = TRUE)
arg_is_pos(n_training)
if (is.finite(n_training)) arg_is_pos_int(n_training)
if (is.finite(nafill_buffer)) arg_is_pos_int(nafill_buffer, allow_null = TRUE)
arg_is_pos(check_enough_data_n, allow_null = TRUE)
arg_is_chr(check_enough_data_epi_keys, allow_null = TRUE)

@@ -331,7 +315,6 @@
nonneg,
max_lags,
quantile_by_key,
nafill_buffer,
check_enough_data_n,
check_enough_data_epi_keys
),
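
With `nafill_buffer` removed from `arx_args_list()` and `rlang::check_dots_empty()` guarding `...`, a call such as `arx_args_list(nafill_buffer = Inf)` now errors instead of being silently accepted. A minimal sketch of the post-change interface (column names and lag/ahead values are illustrative):

library(epipredict)

fc <- arx_forecaster(
  epi_data,
  outcome = "death_rate",
  predictors = c("case_rate", "death_rate"),
  args_list = arx_args_list(lags = c(0L, 7L, 14L), ahead = 7L)
)
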
9 changes: 2 additions & 7 deletions R/cdc_baseline_forecaster.R
@@ -79,9 +79,7 @@ cdc_baseline_forecaster <- function(


latest <- get_test_data(
epi_recipe(epi_data), epi_data, TRUE, args_list$nafill_buffer,
forecast_date
)
epi_recipe(epi_data), epi_data)

f <- frosting() %>%
layer_predict() %>%
@@ -169,7 +167,6 @@ cdc_baseline_args_list <- function(
symmetrize = TRUE,
nonneg = TRUE,
quantile_by_key = "geo_value",
nafill_buffer = Inf,
...) {
rlang::check_dots_empty()
arg_is_scalar(n_training, nsims, data_frequency)
@@ -183,7 +180,6 @@
arg_is_probabilities(quantile_levels, allow_null = TRUE)
arg_is_pos(n_training)
if (is.finite(n_training)) arg_is_pos_int(n_training)
if (is.finite(nafill_buffer)) arg_is_pos_int(nafill_buffer, allow_null = TRUE)

structure(
enlist(
@@ -195,8 +191,7 @@
nsims,
symmetrize,
nonneg,
quantile_by_key,
nafill_buffer
quantile_by_key
),
class = c("cdc_baseline_fcast", "alist")
)
6 changes: 1 addition & 5 deletions R/epi_workflow.R
@@ -269,13 +269,9 @@ forecast.epi_workflow <- function(object, ..., fill_locf = FALSE, n_recent = NUL
))
}
}

test_data <- get_test_data(
hardhat::extract_preprocessor(object),
object$original_data,
fill_locf = fill_locf,
n_recent = n_recent %||% Inf,
forecast_date = forecast_date %||% frosting_fd %||% max(object$original_data$time_value)
object$original_data
)

predict(object, new_data = test_data)
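
`forecast.epi_workflow()` keeps `fill_locf` and `n_recent` in its signature (see the truncated hunk header above) but no longer forwards them: `get_test_data()` is called with just the preprocessor and the stored training data. Downstream usage is unchanged, but any NA handling now has to live in the recipe, e.g. the latency-adjustment step, `recipes::step_impute_*()`, or `tidyr::fill()`. A minimal usage sketch, assuming the usual epipredict verbs and illustrative column names:

library(dplyr)
library(parsnip)
library(epipredict)

r <- epi_recipe(epi_data) %>%
  step_epi_lag(case_rate, lag = c(0, 7, 14)) %>%
  step_epi_ahead(case_rate, ahead = 7)
wf <- epi_workflow(r, linear_reg()) %>% fit(epi_data)
preds <- forecast(wf) # builds its test data via the simplified get_test_data() call above
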
9 changes: 2 additions & 7 deletions R/flatline_forecaster.R
@@ -66,9 +66,7 @@ flatline_forecaster <- function(
wf <- generics::fit(wf, epi_data)
preds <- suppressWarnings(forecast(
wf,
fill_locf = TRUE,
n_recent = args_list$nafill_buffer,
forecast_date = forecast_date
fill_locf = TRUE
)) %>%
tibble::as_tibble() %>%
dplyr::select(-time_value)
@@ -116,7 +114,6 @@ flatline_args_list <- function(
symmetrize = TRUE,
nonneg = TRUE,
quantile_by_key = character(0L),
nafill_buffer = Inf,
...) {
rlang::check_dots_empty()
arg_is_scalar(ahead, n_training)
@@ -128,7 +125,6 @@
arg_is_probabilities(quantile_levels, allow_null = TRUE)
arg_is_pos(n_training)
if (is.finite(n_training)) arg_is_pos_int(n_training)
if (is.finite(nafill_buffer)) arg_is_pos_int(nafill_buffer, allow_null = TRUE)

if (!is.null(forecast_date) && !is.null(target_date)) {
if (forecast_date + ahead != target_date) {
@@ -148,8 +144,7 @@
quantile_levels,
symmetrize,
nonneg,
quantile_by_key,
nafill_buffer
quantile_by_key
),
class = c("flat_fcast", "alist")
)
103 changes: 7 additions & 96 deletions R/get_test_data.R
@@ -20,13 +20,6 @@
#' @param recipe A recipe object.
#' @param x An epi_df. The typical usage is to
#' pass the same data as that used for fitting the recipe.
#' @param fill_locf Logical. Should we use `locf` to fill in missing data?
#' @param n_recent Integer or NULL. If filling missing data with `locf = TRUE`,
#' how far back are we willing to tolerate missing data? Larger values allow
#' more filling. The default `NULL` will determine this from the
#' `recipe`. For example, suppose `n_recent = 3`, then if the
#' 3 most recent observations in any `geo_value` are all `NA`’s, we won’t be
#' able to fill anything, and an error message will be thrown. (See details.)
#' @param forecast_date By default, this is set to the maximum
#' `time_value` in `x`. But if there is data latency such that recent `NA`'s
#' should be filled, this may be _after_ the last available `time_value`.
@@ -45,18 +38,8 @@

get_test_data <- function(
recipe,
x,
fill_locf = FALSE,
n_recent = NULL,
forecast_date = max(x$time_value)) {
x) {
if (!is_epi_df(x)) cli::cli_abort("`x` must be an `epi_df`.")
arg_is_lgl(fill_locf)
arg_is_scalar(fill_locf)
arg_is_scalar(n_recent, allow_null = TRUE)
if (!is.null(n_recent) && is.finite(n_recent)) {
arg_is_pos_int(n_recent, allow_null = TRUE)
}
if (!is.null(n_recent)) n_recent <- abs(n_recent) # in case they passed -Inf

check <- hardhat::check_column_names(x, colnames(recipe$template))
if (!check$ok) {
@@ -66,106 +49,34 @@ get_test_data <- function(
))
}

if (class(forecast_date) != class(x$time_value)) {
cli::cli_abort("`forecast_date` must be the same class as `x$time_value`.")
}


if (forecast_date < max(x$time_value)) {
cli::cli_abort("`forecast_date` must be no earlier than `max(x$time_value)`")
}

min_lags <- min(map_dbl(recipe$steps, ~ min(.x$lag %||% Inf)), Inf)
max_lags <- max(map_dbl(recipe$steps, ~ max(.x$lag %||% 0)), 0)
max_horizon <- max(map_dbl(recipe$steps, ~ max(.x$horizon %||% 0)), 0)
min_required <- max_lags + max_horizon
if (is.null(n_recent)) n_recent <- min_required + 1 # one extra for filling
if (n_recent <= min_required) n_recent <- min_required + n_recent
keep <- max_lags + max_horizon

# CHECK: Error out if insufficient training data
# Probably needs a fix based on the time_type of the epi_df
avail_recent <- diff(range(x$time_value))
if (avail_recent < min_required) {
if (avail_recent < keep) {
cli::cli_abort(c(
"You supplied insufficient recent data for this recipe. ",
"!" = "You need at least {min_required} days of data,",
"!" = "but `x` contains only {avail_recent}."
))
}

max_time_value <- x %>% na.omit %>% pull(time_value) %>% max
x <- arrange(x, time_value)
groups <- kill_time_value(epi_keys(recipe))

# If we skip NA completion, we remove undesirably early time values
# Happens globally, over all groups
keep <- max(n_recent, min_required + 1)
x <- dplyr::filter(x, forecast_date - time_value <= keep)

# Pad with explicit missing values up to and including the forecast_date
# x is grouped here
x <- pad_to_end(x, groups, forecast_date) %>%
epiprocess::group_by(dplyr::across(dplyr::all_of(groups)))
x <- dplyr::filter(x, max_time_value - time_value <= keep)

# If all(lags > 0), then we get rid of recent data
if (min_lags > 0 && min_lags < Inf) {
x <- dplyr::filter(x, forecast_date - time_value >= min_lags)
x <- dplyr::filter(x, max_time_value - time_value >= min_lags)
}

# Now, fill forward missing data if requested
if (fill_locf) {
cannot_be_used <- x %>%
dplyr::filter(forecast_date - time_value <= n_recent) %>%
dplyr::mutate(fillers = forecast_date - time_value > min_required) %>%
dplyr::summarise(
dplyr::across(
-tidyselect::any_of(epi_keys(recipe)),
~ all(is.na(.x[fillers])) & is.na(head(.x[!fillers], 1))
),
.groups = "drop"
) %>%
dplyr::select(-fillers) %>%
dplyr::summarise(dplyr::across(
-tidyselect::any_of(epi_keys(recipe)), ~ any(.x)
)) %>%
unlist()
if (any(cannot_be_used)) {
bad_vars <- names(cannot_be_used)[cannot_be_used]
if (recipes::is_trained(recipe)) {
cli::cli_abort(c(
"The variables {.var {bad_vars}} have too many recent missing",
`!` = "values to be filled automatically. ",
i = "You should either choose `n_recent` larger than its current ",
i = "value {n_recent}, or perform NA imputation manually, perhaps with ",
i = "{.code recipes::step_impute_*()} or with {.code tidyr::fill()}."
))
}
}
x <- tidyr::fill(x, !time_value)
}

dplyr::filter(x, forecast_date - time_value <= min_required) %>%
dplyr::filter(x, max_time_value - time_value <= keep) %>%
epiprocess::ungroup()
}

pad_to_end <- function(x, groups, end_date) {
itval <- epiprocess:::guess_period(c(x$time_value, end_date), "time_value")
completed_time_values <- x %>%
dplyr::group_by(dplyr::across(tidyselect::all_of(groups))) %>%
dplyr::summarise(
time_value = rlang::list2(
time_value = Seq(max(time_value) + itval, end_date, itval)
)
) %>%
unnest("time_value") %>%
mutate(time_value = vctrs::vec_cast(time_value, x$time_value))

dplyr::bind_rows(x, completed_time_values) %>%
dplyr::arrange(dplyr::across(tidyselect::all_of(c("time_value", groups))))
}

Seq <- function(from, to, by) {
if (from > to) {
return(NULL)
}
seq(from = from, to = to, by = by)
}
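
In the rewritten `get_test_data()`, the amount of history returned is decided entirely by the recipe: `keep = max_lags + max_horizon`, measured back from the latest fully observed `time_value` (rows containing any `NA` are skipped when finding that date), and the `min_lags` filter still trims the newest rows when every lag is positive. A worked example of that arithmetic, assuming daily data, predictor lags of c(1, 7, 14), and ahead = 7:

# Assumed recipe: lags c(1, 7, 14) on the predictors, ahead = 7 on the outcome.
max_lags <- 14; max_horizon <- 7; min_lags <- 1
keep <- max_lags + max_horizon # 21: rows within 21 days of the last fully
                               # observed time_value are returned
# Because all lags are positive, the newest min_lags (here 1) day(s) are dropped
# as well, mirroring the "get rid of recent data" filter above.
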