Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

arx_* without latency adjustment #343

Closed
wants to merge 52 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
52 commits
Select commit Hold shift + click to select a range
99f26c6
fix warnings and empty tests
dsweber2 Mar 16, 2024
9c9651d
first draft of extend_ahead
dsweber2 Mar 18, 2024
bc687d5
extend_ahead version bump and news
dsweber2 May 3, 2024
80153c5
styler has opinions
dsweber2 Mar 18, 2024
a2ed2e9
separate step version
dsweber2 Mar 29, 2024
dd857ef
styler
dsweber2 Mar 29, 2024
3060021
tests for utils-latency and accompanying fixes
dsweber2 Apr 1, 2024
b4d66e1
adding stringr
dsweber2 Apr 1, 2024
ef02b29
old snapshots, select prefers `all_of` for vectors
dsweber2 Apr 1, 2024
63e6878
local renv way out of date
dsweber2 Apr 1, 2024
0961644
pkgdown needs @keywords internal
dsweber2 Apr 1, 2024
bcce194
passes local tests after updating
dsweber2 Apr 2, 2024
cdd8aed
back to skipping some population_scaling tests
dsweber2 Apr 2, 2024
3c28475
step_adjust_latency works on tests
dsweber2 Apr 25, 2024
df48298
spurious lifecycle addition removed
dsweber2 May 3, 2024
a20dcf1
fixing RMDcheck remote
dsweber2 May 3, 2024
8d37492
nothing but `rlang::abort` -> `cli::cli_abort`s
dsweber2 May 3, 2024
0541948
smaller suggestions and styling
dsweber2 May 3, 2024
40e86ba
smaller suggestions: local tests passing again
dsweber2 May 6, 2024
36328de
moving shift detection earlier,dropping string*dep
dsweber2 May 8, 2024
9abae59
+purrr, styling
dsweber2 May 8, 2024
db71207
glue -> glue::glue
dsweber2 May 8, 2024
2868cbd
fix get_latent_column_tibble docs
dsweber2 May 8, 2024
a4a33db
step_adjust_latency arg docs
dsweber2 May 8, 2024
2f0305e
rec formatting things, dropping `purrr`
dsweber2 May 13, 2024
6c3ba35
glue->paste, dropping zoo
dsweber2 May 14, 2024
a6b0d3f
Detecting required/forbidden steps beforehand
dsweber2 May 16, 2024
1ada3d0
minor rebase woes
dsweber2 May 17, 2024
1b6a0af
tests for utils-latency and accompanying fixes
dsweber2 Apr 1, 2024
b9bed37
adding stringr
dsweber2 Apr 1, 2024
9c914e8
nothing but `rlang::abort` -> `cli::cli_abort`s
dsweber2 May 3, 2024
6411c76
moving shift detection earlier,dropping string*dep
dsweber2 May 8, 2024
655b141
+purrr, styling
dsweber2 May 8, 2024
fea65c4
rec formatting things, dropping `purrr`
dsweber2 May 13, 2024
1a84212
initial layer adjustments
dsweber2 May 15, 2024
b9189fb
namespace and doc fixes
dsweber2 May 17, 2024
5fdc3e4
full rebase fixes
dsweber2 May 17, 2024
2b339a2
adding latency adjusting to arx_forecaster
dsweber2 May 17, 2024
7b6f933
arx_classifier more or less free
dsweber2 May 17, 2024
2808f6a
formatting and snapshots
dsweber2 May 17, 2024
2becf68
updated man pages
dsweber2 May 22, 2024
a170ec1
group_by options to get the max_time_value
dsweber2 May 24, 2024
c5d3c9d
PR review recs
dsweber2 May 29, 2024
b0239e8
typo in multiline pipe replacement
dsweber2 May 29, 2024
240583c
happy styler
dsweber2 Jun 3, 2024
795abeb
various requested changes, check passes
dsweber2 Jun 14, 2024
28db575
style fix
dsweber2 Jun 14, 2024
c5136b3
inheritParams, correct print, test adjust subset
dsweber2 Jun 14, 2024
cf8fed6
space
dsweber2 Jun 14, 2024
f36f6fa
Merge pull request #334 from cmu-delphi/adjustAheadLayerAdditions
dsweber2 Jun 17, 2024
c1ca6d3
print fix and tests
dsweber2 Jun 22, 2024
8820c00
arx_* without latency adjustment
dsweber2 Jun 24, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ S3method(autoplot,canned_epipred)
S3method(autoplot,epi_workflow)
S3method(bake,check_enough_train_data)
S3method(bake,epi_recipe)
S3method(bake,step_adjust_latency)
S3method(bake,step_epi_ahead)
S3method(bake,step_epi_lag)
S3method(bake,step_growth_rate)
Expand Down Expand Up @@ -60,6 +61,7 @@ S3method(predict,epi_workflow)
S3method(predict,flatline)
S3method(prep,check_enough_train_data)
S3method(prep,epi_recipe)
S3method(prep,step_adjust_latency)
S3method(prep,step_epi_ahead)
S3method(prep,step_epi_lag)
S3method(prep,step_growth_rate)
Expand Down Expand Up @@ -88,6 +90,7 @@ S3method(print,layer_quantile_distn)
S3method(print,layer_residual_quantiles)
S3method(print,layer_threshold)
S3method(print,layer_unnest)
S3method(print,step_adjust_latency)
S3method(print,step_epi_ahead)
S3method(print,step_epi_lag)
S3method(print,step_growth_rate)
Expand Down Expand Up @@ -191,6 +194,7 @@ export(remove_frosting)
export(remove_model)
export(slather)
export(smooth_quantile_reg)
export(step_adjust_latency)
export(step_epi_ahead)
export(step_epi_lag)
export(step_epi_naomit)
Expand Down Expand Up @@ -218,30 +222,39 @@ importFrom(checkmate,assert_number)
importFrom(checkmate,assert_numeric)
importFrom(checkmate,assert_scalar)
importFrom(cli,cli_abort)
importFrom(dplyr,"%>%")
importFrom(dplyr,across)
importFrom(dplyr,all_of)
importFrom(dplyr,group_by)
importFrom(dplyr,join_by)
importFrom(dplyr,left_join)
importFrom(dplyr,n)
importFrom(dplyr,pull)
importFrom(dplyr,rowwise)
importFrom(dplyr,select)
importFrom(dplyr,summarise)
importFrom(dplyr,tibble)
importFrom(dplyr,ungroup)
importFrom(epiprocess,growth_rate)
importFrom(generics,augment)
importFrom(generics,fit)
importFrom(generics,forecast)
importFrom(ggplot2,autoplot)
importFrom(glue,glue)
importFrom(hardhat,refresh_blueprint)
importFrom(hardhat,run_mold)
importFrom(magrittr,"%>%")
importFrom(quantreg,rq)
importFrom(recipes,bake)
importFrom(recipes,detect_step)
importFrom(recipes,prep)
importFrom(rlang,"!!")
importFrom(rlang,"%@%")
importFrom(rlang,"%||%")
importFrom(rlang,":=")
importFrom(rlang,abort)
importFrom(rlang,as_function)
importFrom(rlang,caller_env)
importFrom(rlang,enquos)
importFrom(rlang,global_env)
importFrom(rlang,is_null)
importFrom(rlang,set_names)
Expand All @@ -258,6 +271,7 @@ importFrom(stats,quantile)
importFrom(stats,residuals)
importFrom(tibble,tibble)
importFrom(tidyr,drop_na)
importFrom(tidyr,unnest)
importFrom(vctrs,as_list_of)
importFrom(vctrs,field)
importFrom(vctrs,new_rcrd)
Expand All @@ -267,3 +281,4 @@ importFrom(vctrs,vec_data)
importFrom(vctrs,vec_ptype_abbr)
importFrom(vctrs,vec_ptype_full)
importFrom(vctrs,vec_recycle_common)
importFrom(workflows,extract_preprocessor)
3 changes: 3 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
# epipredict (development)

Pre-1.0.0 numbering scheme: 0.x will indicate releases, while 0.0.x will indicate PR's.
# epipredict 0.2

- add `latency_adjustment` as an option for `add_epi_ahead`, which adjusts the `ahead` so that the prediction is `ahead` relative to the `as_of` date for the `epi_data`, rather than relative to the last day of data.

# epipredict 0.1

Expand Down
4 changes: 2 additions & 2 deletions R/arx_classifier.R
Original file line number Diff line number Diff line change
Expand Up @@ -119,10 +119,10 @@ arx_class_epi_workflow <- function(
args_list = arx_class_args_list()) {
validate_forecaster_inputs(epi_data, outcome, predictors)
if (!inherits(args_list, c("arx_class", "alist"))) {
rlang::abort("args_list was not created using `arx_class_args_list().")
cli::cli_abort("`args_list` was not created using `arx_class_args_list().`")
}
if (!(is.null(trainer) || is_classification(trainer))) {
rlang::abort("`trainer` must be a `{parsnip}` model of mode 'classification'.")
cli::cli_abort("`trainer` must be a `{.pkg parsnip}` model of mode 'classification'.")
}
lags <- arx_lags_validator(predictors, args_list$lags)

Expand Down
39 changes: 17 additions & 22 deletions R/arx_forecaster.R
Original file line number Diff line number Diff line change
Expand Up @@ -129,22 +129,16 @@ arx_fcast_epi_workflow <- function(
r <- r %>%
step_epi_ahead(!!outcome, ahead = args_list$ahead) %>%
step_epi_naomit() %>%
step_training_window(n_recent = args_list$n_training) %>%
{
if (!is.null(args_list$check_enough_data_n)) {
check_enough_train_data(
.,
all_predictors(),
!!outcome,
n = args_list$check_enough_data_n,
epi_keys = args_list$check_enough_data_epi_keys,
drop_na = FALSE
)
} else {
.
}
}

step_training_window(n_recent = args_list$n_training)
if (!is.null(args_list$check_enough_data_n)) {
r <- r %>% check_enough_train_data(
all_predictors(),
!!outcome,
n = args_list$check_enough_data_n,
epi_keys = args_list$check_enough_data_epi_keys,
drop_na = FALSE
)
}
forecast_date <- args_list$forecast_date %||% max(epi_data$time_value)
target_date <- args_list$target_date %||% (forecast_date + args_list$ahead)

Expand All @@ -158,19 +152,20 @@ arx_fcast_epi_workflow <- function(
))
args_list$quantile_levels <- quantile_levels
trainer$args$quantile_levels <- rlang::enquo(quantile_levels)
f <- layer_quantile_distn(f, quantile_levels = quantile_levels) %>%
f <- f %>%
layer_quantile_distn(quantile_levels = quantile_levels) %>%
layer_point_from_distn()
} else {
f <- layer_residual_quantiles(
f,
f <- f %>% layer_residual_quantiles(
quantile_levels = args_list$quantile_levels,
symmetrize = args_list$symmetrize,
by_key = args_list$quantile_by_key
)
}
f <- layer_add_forecast_date(f, forecast_date = forecast_date) %>%
f <- f %>%
layer_add_forecast_date(forecast_date = forecast_date) %>%
layer_add_target_date(target_date = target_date)
if (args_list$nonneg) f <- layer_threshold(f, dplyr::starts_with(".pred"))
if (args_list$nonneg) f <- f %>% layer_threshold(dplyr::starts_with(".pred"))

epi_workflow(r, trainer, f)
}
Expand Down Expand Up @@ -316,7 +311,7 @@ compare_quantile_args <- function(alist, tlist) {
if (setequal(alist, tlist)) {
return(sort(unique(alist)))
}
rlang::abort(c(
cli::cli_abort(c(
"You have specified different, non-default, quantiles in the trainier and `arx_args` options.",
i = "Please only specify quantiles in one location."
))
Expand Down
2 changes: 1 addition & 1 deletion R/compat-recipes.R
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ inline_check <- function(x) {
funs <- fun_calls(x)
funs <- funs[!(funs %in% c("~", "+", "-"))]
if (length(funs) > 0) {
rlang::abort(paste0(
cli::cli_abort(paste0(
"No in-line functions should be used here; ",
"use steps to define baking actions."
))
Expand Down
2 changes: 1 addition & 1 deletion R/epi_check_training_set.R
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ epi_check_training_set <- function(x, rec) {
if (!is.null(old_ok)) {
if (all(old_ok %in% colnames(x))) { # case 1
if (!all(old_ok %in% new_ok)) {
cli::cli_warn(c(
cli::cli_warn(paste(
"The recipe specifies additional keys. Because these are available,",
"they are being added to the metadata of the training data."
))
Expand Down
20 changes: 9 additions & 11 deletions R/epi_recipe.R
Original file line number Diff line number Diff line change
Expand Up @@ -59,15 +59,15 @@ epi_recipe.epi_df <-
function(x, formula = NULL, ..., vars = NULL, roles = NULL) {
if (!is.null(formula)) {
if (!is.null(vars)) {
rlang::abort(
cli::cli_abort(
paste0(
"This `vars` specification will be ignored ",
"when a formula is used"
)
)
}
if (!is.null(roles)) {
rlang::abort(
cli::cli_abort(
paste0(
"This `roles` specification will be ignored ",
"when a formula is used"
Expand All @@ -80,10 +80,10 @@ epi_recipe.epi_df <-
}
if (is.null(vars)) vars <- colnames(x)
if (any(table(vars) > 1)) {
rlang::abort("`vars` should have unique members")
cli::cli_abort("`vars` should have unique members")
}
if (any(!(vars %in% colnames(x)))) {
rlang::abort("1 or more elements of `vars` are not in the data")
cli::cli_abort("1 or more elements of `vars` are not in the data")
}

keys <- epi_keys(x) # we know x is an epi_df
Expand All @@ -94,7 +94,7 @@ epi_recipe.epi_df <-
## Check and add roles when available
if (!is.null(roles)) {
if (length(roles) != length(vars)) {
rlang::abort(c(
cli::cli_abort(paste(
"The number of roles should be the same as the number of ",
"variables."
))
Expand Down Expand Up @@ -140,7 +140,6 @@ epi_recipe.epi_df <-


#' @rdname epi_recipe
#' @importFrom rlang abort
#' @export
epi_recipe.formula <- function(formula, data, ...) {
# we ensure that there's only 1 row in the template
Expand All @@ -152,7 +151,7 @@ epi_recipe.formula <- function(formula, data, ...) {

f_funcs <- recipes:::fun_calls(formula)
if (any(f_funcs == "-")) {
abort("`-` is not allowed in a recipe formula. Use `step_rm()` instead.")
cli::cli_abort("`-` is not allowed in a recipe formula. Use `step_rm()` instead.")
}

# Check for other in-line functions
Expand Down Expand Up @@ -432,7 +431,7 @@ prep.epi_recipe <- function(
x, training = NULL, fresh = FALSE, verbose = FALSE,
retain = TRUE, log_changes = FALSE, strings_as_factors = TRUE, ...) {
if (is.null(training)) {
cli::cli_warn(c(
cli::cli_warn(paste(
"!" = "No training data was supplied to {.fn prep}.",
"!" = "Unlike a {.cls recipe}, an {.cls epi_recipe} does not ",
"!" = "store the full template data in the object.",
Expand All @@ -457,7 +456,7 @@ prep.epi_recipe <- function(
}
skippers <- map_lgl(x$steps, recipes:::is_skipable)
if (any(skippers) & !retain) {
cli::cli_warn(c(
cli::cli_warn(paste(
"Since some operations have `skip = TRUE`, using ",
"`retain = TRUE` will allow those steps results to ",
"be accessible."
Expand All @@ -475,7 +474,7 @@ prep.epi_recipe <- function(
"You cannot `prep()` a tuneable recipe. Argument(s) with `tune()`: ",
arg, ". Do you want to use a tuning function such as `tune_grid()`?"
)
rlang::abort(msg)
cli::cli_abort(msg)
}
note <- paste("oper", i, gsub("_", " ", class(x$steps[[i]])[1]))
if (!x$steps[[i]]$trained | fresh) {
Expand Down Expand Up @@ -578,7 +577,6 @@ bake.epi_recipe <- function(object, new_data, ..., composition = "epi_df") {
new_data
}


kill_levels <- function(x, keys) {
for (i in which(names(x) %in% keys)) x[[i]] <- list(values = NA, ordered = NA)
x
Expand Down
82 changes: 52 additions & 30 deletions R/epi_shift.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,43 +2,65 @@
#'
#' This is a lower-level function. As such it performs no error checking.
#'
#' @param x Data frame. Variables to shift
#' @param shifts List. Each list element is a vector of shifts.
#' Negative values produce leads. The list should have the same
#' length as the number of columns in `x`.
#' @param time_value Vector. Same length as `x` giving time stamps.
#' @param keys Data frame, vector, or `NULL`. Additional grouping vars.
#' @param out_name Chr. The output list will use this as a prefix.
#' @param x Data frame.
#' @param shift_val a single integer. Negative values produce leads.
#' @param newname the name for the newly shifted column
#' @param key_cols vector, or `NULL`. Additional grouping vars.
#'
#' @keywords internal
#'
#' @return a list of tibbles
epi_shift <- function(x, shifts, time_value, keys = NULL, out_name = "x") {
if (!is.data.frame(x)) x <- data.frame(x)
if (is.null(keys)) keys <- rep("empty", nrow(x))
p_in <- ncol(x)
out_list <- tibble::tibble(i = 1:p_in, shift = shifts) %>%
tidyr::unchop(shift) %>% # what is chop
dplyr::mutate(name = paste0(out_name, 1:nrow(.))) %>%
# One list element for each shifted feature
pmap(function(i, shift, name) {
tibble(keys,
time_value = time_value + shift, # Shift back
!!name := x[[i]]
)
})
if (is.data.frame(keys)) {
common_names <- c(names(keys), "time_value")
} else {
common_names <- c("keys", "time_value")
}

reduce(out_list, dplyr::full_join, by = common_names)
}

epi_shift_single <- function(x, col, shift_val, newname, key_cols) {
x %>%
dplyr::select(tidyselect::all_of(c(key_cols, col))) %>%
dplyr::mutate(time_value = time_value + shift_val) %>%
dplyr::rename(!!newname := {{ col }})
}

#' lags move columns forward to bring the past up to today, while aheads drag
#' the future back to today
#' @keywords internal
get_sign <- function(object) {
if (object$prefix == "lag_") {
return(1)
} else {
return(-1)
}
}

#' backend for both `bake.step_epi_ahead` and `bake.step_epi_lag`, performs the
#' checks missing in `epi_shift_single`
#' @keywords internal
add_shifted_columns <- function(new_data, object, amount) {
sign_shift <- get_sign(object)
grid <- tidyr::expand_grid(col = object$columns, amount = amount) %>%
dplyr::mutate(
newname = glue::glue("{object$prefix}{amount}_{col}"),
shift_val = sign_shift * amount,
amount = NULL
)

## ensure no name clashes
new_data_names <- colnames(new_data)
intersection <- new_data_names %in% grid$newname
if (any(intersection)) {
cli::cli_abort(
paste0(
"Name collision occured in `", class(object)[1],
"`. The following variable names already exists: ",
paste0(new_data_names[intersection], collapse = ", "),
"."
)
)
}
ok <- object$keys
shifted <- reduce(
pmap(grid, epi_shift_single, x = new_data, key_cols = ok),
dplyr::full_join,
by = ok
)
dplyr::full_join(new_data, shifted, by = ok) %>%
dplyr::group_by(dplyr::across(dplyr::all_of(ok[-1]))) %>%
dplyr::arrange(time_value) %>%
dplyr::ungroup()
}
6 changes: 3 additions & 3 deletions R/epi_workflow.R
Original file line number Diff line number Diff line change
Expand Up @@ -187,10 +187,10 @@ augment.epi_workflow <- function(x, new_data, ...) {
if (epiprocess::is_epi_df(predictions)) {
join_by <- epi_keys(predictions)
} else {
rlang::abort(
cli::cli_abort(
c(
"Cannot determine how to join new_data with the predictions.",
"Try converting new_data to an epi_df with `as_epi_df(new_data)`."
"Cannot determine how to join `new_data` with the `predictions`.",
"Try converting `new_data` to an {.cls epi_df} with `as_epi_df(new_data)`."
)
)
}
Expand Down
Loading
Loading