Skip to content

Commit

Permalink
rebase fixes, error classes, unskip latency tests
Browse files Browse the repository at this point in the history
  • Loading branch information
dsweber2 committed Jul 30, 2024
1 parent a28ad82 commit 986b657
Show file tree
Hide file tree
Showing 11 changed files with 189 additions and 116 deletions.
3 changes: 2 additions & 1 deletion NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -251,8 +251,8 @@ importFrom(quantreg,rq)
importFrom(recipes,bake)
importFrom(recipes,detect_step)
importFrom(recipes,prep)
importFrom(rlang,"!!!")
importFrom(recipes,recipes_eval_select)
importFrom(rlang,"!!!")
importFrom(rlang,"!!")
importFrom(rlang,"%@%")
importFrom(rlang,"%||%")
Expand All @@ -262,6 +262,7 @@ importFrom(rlang,caller_env)
importFrom(rlang,enquos)
importFrom(rlang,global_env)
importFrom(rlang,inject)
importFrom(rlang,is_empty)
importFrom(rlang,is_logical)
importFrom(rlang,is_null)
importFrom(rlang,is_true)
Expand Down
3 changes: 2 additions & 1 deletion R/canned-epipred.R
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,8 @@ print.canned_epipred <- function(x, name, ...) {
purrr::map("columns") %>%
reduce(c)
latency_per_base_col <- latency_step$latency_table %>%
filter(col_name %in% valid_columns) %>% mutate(latency = abs(latency))
filter(col_name %in% valid_columns) %>%
mutate(latency = abs(latency))
if (latency_step$method != "locf" && nrow(latency_per_base_col) > 1) {
intro_text <- glue::glue("{type_str} adjusted per column: ")
} else if (latency_step$method != "locf") {
Expand Down
3 changes: 2 additions & 1 deletion R/cdc_baseline_forecaster.R
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,8 @@ cdc_baseline_forecaster <- function(


latest <- get_test_data(
epi_recipe(epi_data), epi_data)
epi_recipe(epi_data), epi_data
)

f <- frosting() %>%
layer_predict() %>%
Expand Down
7 changes: 4 additions & 3 deletions R/epi_shift.R
Original file line number Diff line number Diff line change
Expand Up @@ -45,16 +45,17 @@ add_shifted_columns <- function(new_data, object, amount) {
shift_sign_lat <- attributes(new_data)$metadata$shift_sign
if (!is.null(latency_table) &&
shift_sign_lat == sign_shift) {
#TODO this doesn't work on lags of transforms
# TODO this doesn't work on lags of transforms
rel_latency <- latency_table %>% filter(col_name %in% object$columns)
} else {
rel_latency <- tibble(col_name = object$columns, latency = 0L)
}
grid <- expand_grid(col = object$columns, amount = sign_shift *amount) %>%
grid <- expand_grid(col = object$columns, amount = sign_shift * amount) %>%
left_join(rel_latency, by = join_by(col == col_name), ) %>%
tidyr::replace_na(list(latency = 0)) %>%
dplyr::mutate(
shift_val = amount + latency) %>%
shift_val = amount + latency
) %>%
mutate(
newname = glue::glue("{object$prefix}{abs(shift_val)}_{col}"), # name is always positive
amount = NULL,
Expand Down
18 changes: 5 additions & 13 deletions R/get_test_data.R
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,9 @@
#' used if growth rate calculations are requested by the recipe. This is
#' calculated internally.
#'
#' It also optionally fills missing values
#' using the last-observation-carried-forward (LOCF) method. If this
#' is not possible (say because there would be only `NA`'s in some location),
#' it will produce an error suggesting alternative options to handle missing
#' values with more advanced techniques.
#'
#' @param recipe A recipe object.
#' @param x An epi_df. The typical usage is to
#' pass the same data as that used for fitting the recipe.
#' @param forecast_date By default, this is set to the maximum
#' `time_value` in `x`. But if there is data latency such that recent `NA`'s
#' should be filled, this may be _after_ the last available `time_value`.
#'
#' @return An object of the same type as `x` with columns `geo_value`, `time_value`, any additional
#' keys, as well other variables in the original dataset.
Expand All @@ -36,9 +27,7 @@
#' @importFrom rlang %@%
#' @export

get_test_data <- function(
recipe,
x) {
get_test_data <- function(recipe, x) {
if (!is_epi_df(x)) cli::cli_abort("`x` must be an `epi_df`.")

check <- hardhat::check_column_names(x, colnames(recipe$template))
Expand All @@ -64,7 +53,10 @@ get_test_data <- function(
"!" = "but `x` contains only {avail_recent}."
))
}
max_time_value <- x %>% na.omit %>% pull(time_value) %>% max
max_time_value <- x %>%
na.omit() %>%
pull(time_value) %>%
max()
x <- arrange(x, time_value)
groups <- kill_time_value(epi_keys(recipe))

Expand Down
26 changes: 15 additions & 11 deletions R/step_adjust_latency.R
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@
#' jhu_fit
#'
#' @importFrom recipes detect_step
#' @importFrom rlang enquos
#' @importFrom rlang enquos is_empty
step_adjust_latency <-
function(recipe,
...,
Expand All @@ -106,39 +106,43 @@ step_adjust_latency <-
id = recipes::rand_id("adjust_latency")) {
arg_is_chr_scalar(id, method)
if (!is_epi_recipe(recipe)) {
cli::cli_abort("This recipe step can only operate on an {.cls epi_recipe}.")
cli::cli_abort("This recipe step can only operate on an {.cls epi_recipe}.", class = "epipredict__step_adjust_latency__epi_recipe_only")
}
if (!is.null(columns)) {
cli::cli_abort(c("The `columns` argument must be `NULL`.",
i = "Use `tidyselect` methods to choose columns to lag."
))
), class = "epipredict__step_adjust_latency__cols_not_null")
}
if ((method == "extend_ahead") && (detect_step(recipe, "epi_ahead"))) {
cli::cli_warn(
"If `method` is {.val extend_ahead}, then the previous `step_epi_ahead` won't be modified."
"If `method` is {.val extend_ahead}, then the previous `step_epi_ahead` won't be modified.",
class = "epipredict__step_adjust_latency__misordered_step_warning"
)
} else if ((method == "extend_lags") && detect_step(recipe, "epi_lag")) {
cli::cli_warn(
"If `method` is {.val extend_lags} or {.val locf},
then the previous `step_epi_lag`s won't work with modified data."
then the previous `step_epi_lag`s won't work with modified data.",
class = "epipredict__step_adjust_latency__misordered_step_warning"
)
} else if ((method == "locf") && (length(recipe$steps) > 0)) {
cli::cli_warn("There are steps before `step_adjust_latency`. With the method {.val locf}, it is recommended to include this step before any others")
cli::cli_warn("There are steps before `step_adjust_latency`. With the method {.val locf}, it is recommended to include this step before any others",
class = "epipredict__step_adjust_latency__misordered_step_warning"
)
}
if (detect_step(recipe, "naomit")) {
cli::cli_abort("adjust_latency needs to occur before any `NA` removal,
as columns may be moved around")
as columns may be moved around", class = "epipredict__step_adjust_latency__post_NA_error")
}
if (!is.null(fixed_latency) && !is.null(fixed_forecast_date)) {
cli::cli_abort("Only one of `fixed_latency` and `fixed_forecast_date`
can be non-`NULL` at a time!")
can be non-`NULL` at a time!", class = "epipredict__step_adjust_latency__too_many_args_error")
}
if (length(fixed_latency > 1)) {
template <- recipe$template
data_names <- names(template)[!names(template) %in% epi_keys(template)]
wrong_names <- names(fixed_latency)[!names(fixed_latency) %in% data_names]
if (length(wrong_names) > 0) {
cli::cli_abort("{.val fixed_latency} contains names not in the template dataset: {wrong_names}")
cli::cli_abort("{.val fixed_latency} contains names not in the template dataset: {wrong_names}", class = "epipredict__step_adjust_latency__undefined_names_error")
}
}

Expand Down Expand Up @@ -258,8 +262,8 @@ prep.step_adjust_latency <- function(x, training, info = NULL, ...) {
"{time_type}."
),
"i" = "latency: {latency_table$latency[[i_latency]]}",
"i" = "`max_time` = {max_time} -> `forecast_date` = {forecast_date}"
))
"i" = "`max_time` = {max(training$time_value)} -> `forecast_date` = {forecast_date}"
), class = "epipredict__prep.step_latency__very_large_latency")
}

step_adjust_latency_new(
Expand Down
52 changes: 32 additions & 20 deletions R/utils-latency.R
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,14 @@ set_forecast_date <- function(new_data, info, epi_keys_checked, latency) {
pull(variable)
# make sure that there's enough column names
if (length(original_columns) < 3) {
cli::cli_abort(glue::glue(
"The original columns of `time_value`, ",
"`geo_value` and at least one signal. The current colums are \n",
paste(capture.output(object$info), collapse = "\n\n")
))
cli::cli_abort(
glue::glue(
"The original columns of `time_value`, ",
"`geo_value` and at least one signal. The current colums are \n",
paste(capture.output(object$info), collapse = "\n\n")
),
class = "epipredict__set_forecast_date__too_few_data_columns"
)
}
# the source data determines the actual time_values
# these are the non-na time_values;
Expand All @@ -65,25 +68,34 @@ set_forecast_date <- function(new_data, info, epi_keys_checked, latency) {
}
# make sure the as_of is sane
if (!inherits(forecast_date, class(max_time)) & !inherits(forecast_date, "POSIXt")) {
cli::cli_abort(paste(
"the data matrix `forecast_date` value is {forecast_date}, ",
"and not a valid `time_type` with type ",
"matching `time_value`'s type of ",
"{class(max_time)}."
))
cli::cli_abort(
paste(
"the data matrix `forecast_date` value is {forecast_date}, ",
"and not a valid `time_type` with type ",
"matching `time_value`'s type of ",
"{class(max_time)}."
),
class = "epipredict__set_forecast_date__wrong_time_value_type_error"
)
}
if (is.null(forecast_date) || is.na(forecast_date)) {
cli::cli_warn(paste(
"epi_data's `forecast_date` was {forecast_date}, setting to ",
"the latest time value, {max_time}."
))
cli::cli_warn(
paste(
"epi_data's `forecast_date` was {forecast_date}, setting to ",
"the latest time value, {max_time}."
),
class = "epipredict__set_forecast_date__max_time_warning"
)
forecast_date <- max_time
} else if (forecast_date < max_time) {
cli::cli_abort(paste(
"`forecast_date` ({(forecast_date)}) is before the most ",
"recent data ({max_time}). Remove before ",
"predicting."
))
cli::cli_abort(
paste(
"`forecast_date` ({(forecast_date)}) is before the most ",
"recent data ({max_time}). Remove before ",
"predicting."
),
class = "epipredict__set_forecast_date__misordered_forecast_date_error"
)
}
# TODO cover the rest of the possible types for as_of and max_time...
if (inherits(max_time, "Date")) {
Expand Down
10 changes: 0 additions & 10 deletions man/get_test_data.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion tests/testthat/test-snapshots.R
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ test_that("arx_forecaster snapshots", {

test_that("arx_forecaster output format snapshots", {
jhu <- case_death_rate_subset %>%
dplyr::filter(time_value >= as.Date("2021-12-01"))
dplyr::filter(time_value >= as.Date("2021-12-01"))
out1 <- arx_forecaster(
jhu, "death_rate",
c("case_rate", "death_rate")
Expand Down
Loading

0 comments on commit 986b657

Please sign in to comment.