From e4617d00e3fe52ad1807a15271894279d9bedeb7 Mon Sep 17 00:00:00 2001 From: "Daniel J. McDonald" Date: Tue, 27 Aug 2024 09:44:10 -0700 Subject: [PATCH 01/16] checks pass --- NAMESPACE | 16 ++++-- NEWS.md | 1 + R/autoplot.R | 6 +- R/cdc_baseline_forecaster.R | 2 +- R/epi_keys.R | 56 ------------------- R/epi_recipe.R | 12 ++-- R/epi_workflow.R | 4 +- R/epipredict-package.R | 2 + R/flatline_forecaster.R | 2 +- R/get_test_data.R | 16 +++--- R/key_colnames.R | 28 ++++++++++ R/layer_cdc_flatline_quantiles.R | 2 +- R/layer_population_scaling.R | 2 +- R/step_epi_shift.R | 6 +- R/step_epi_slide.R | 2 +- R/step_growth_rate.R | 4 +- R/step_lag_difference.R | 4 +- R/step_training_window.R | 2 +- R/utils-misc.R | 4 +- man/epi_keys.Rd | 20 ------- tests/testthat/test-epi_recipe.R | 10 ++-- tests/testthat/test-epi_shift.R | 2 +- .../{test-epi_keys.R => test-key_colnames.R} | 30 ++-------- vignettes/epipredict.Rmd | 2 +- 24 files changed, 87 insertions(+), 148 deletions(-) delete mode 100644 R/epi_keys.R create mode 100644 R/key_colnames.R delete mode 100644 man/epi_keys.Rd rename tests/testthat/{test-epi_keys.R => test-key_colnames.R} (63%) diff --git a/NAMESPACE b/NAMESPACE index 608cea18b..b07321768 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -28,11 +28,6 @@ S3method(bake,step_population_scaling) S3method(bake,step_training_window) S3method(detect_layer,frosting) S3method(detect_layer,workflow) -S3method(epi_keys,data.frame) -S3method(epi_keys,default) -S3method(epi_keys,epi_df) -S3method(epi_keys,epi_workflow) -S3method(epi_keys,recipe) S3method(epi_recipe,default) S3method(epi_recipe,epi_df) S3method(epi_recipe,formula) @@ -55,6 +50,9 @@ S3method(forecast,epi_workflow) S3method(format,dist_quantiles) S3method(is.na,dist_quantiles) S3method(is.na,distribution) +S3method(key_colnames,epi_workflow) +S3method(key_colnames,list) +S3method(key_colnames,recipe) S3method(mean,dist_quantiles) S3method(median,dist_quantiles) S3method(predict,epi_workflow) @@ -154,7 +152,6 @@ export(clean_f_name) export(default_epi_recipe_blueprint) export(detect_layer) export(dist_quantiles) -export(epi_keys) export(epi_recipe) export(epi_recipe_blueprint) export(epi_workflow) @@ -231,9 +228,16 @@ importFrom(checkmate,assert_scalar) importFrom(cli,cli_abort) importFrom(dplyr,across) importFrom(dplyr,all_of) +importFrom(dplyr,any_of) +importFrom(dplyr,arrange) importFrom(dplyr,bind_cols) +importFrom(dplyr,bind_rows) +importFrom(dplyr,filter) importFrom(dplyr,group_by) +importFrom(dplyr,left_join) +importFrom(dplyr,mutate) importFrom(dplyr,n) +importFrom(dplyr,select) importFrom(dplyr,summarise) importFrom(dplyr,ungroup) importFrom(epiprocess,epi_slide) diff --git a/NEWS.md b/NEWS.md index 62095be73..15aa6de29 100644 --- a/NEWS.md +++ b/NEWS.md @@ -56,3 +56,4 @@ Pre-1.0.0 numbering scheme: 0.x will indicate releases, while 0.0.x will indicat - add functionality to calculate weighted interval scores for `dist_quantiles()` - Add `step_epi_slide` to produce generic sliding computations over an `epi_df` - Add quantile random forests (via `{grf}`) as a parsnip engine +- Replace `epi_keys()` with `epiprocess::key_colnames()`, #352 diff --git a/R/autoplot.R b/R/autoplot.R index 77f04dde7..0d0e48e35 100644 --- a/R/autoplot.R +++ b/R/autoplot.R @@ -125,7 +125,7 @@ autoplot.epi_workflow <- function( if (!is.null(shift)) { edf <- dplyr::mutate(edf, time_value = time_value + shift) } - extra_keys <- setdiff(epi_keys_mold(mold), c("time_value", "geo_value")) + extra_keys <- setdiff(key_colnames(mold), c("time_value", "geo_value")) if (length(extra_keys) == 0L) extra_keys <- NULL edf <- as_epi_df(edf, as_of = object$fit$meta$as_of, @@ -145,7 +145,7 @@ autoplot.epi_workflow <- function( } predictions <- dplyr::rename(predictions, time_value = target_date) } - pred_cols_ok <- hardhat::check_column_names(predictions, epi_keys(edf)) + pred_cols_ok <- hardhat::check_column_names(predictions, key_colnames(edf)) if (!pred_cols_ok$ok) { cli::cli_warn(c( "`predictions` is missing required variables: {.var {pred_cols_ok$missing_names}}.", @@ -165,7 +165,7 @@ autoplot.epi_workflow <- function( ) # Now, prepare matching facets in the predictions - ek <- kill_time_value(epi_keys(edf)) + ek <- kill_time_value(key_colnames(edf)) predictions <- predictions %>% dplyr::mutate( .facets = interaction(!!!rlang::syms(as.list(ek)), sep = "/"), diff --git a/R/cdc_baseline_forecaster.R b/R/cdc_baseline_forecaster.R index d5b74a9c3..31194daae 100644 --- a/R/cdc_baseline_forecaster.R +++ b/R/cdc_baseline_forecaster.R @@ -63,7 +63,7 @@ cdc_baseline_forecaster <- function( if (!inherits(args_list, c("cdc_flat_fcast", "alist"))) { cli_stop("args_list was not created using `cdc_baseline_args_list().") } - keys <- epi_keys(epi_data) + keys <- key_colnames(epi_data) ek <- kill_time_value(keys) outcome <- rlang::sym(outcome) diff --git a/R/epi_keys.R b/R/epi_keys.R deleted file mode 100644 index 08e4595c3..000000000 --- a/R/epi_keys.R +++ /dev/null @@ -1,56 +0,0 @@ -#' Grab any keys associated to an epi_df -#' -#' @param x a data.frame, tibble, or epi_df -#' @param ... additional arguments passed on to methods -#' -#' @return If an `epi_df`, this returns all "keys". Otherwise `NULL` -#' @keywords internal -#' @export -epi_keys <- function(x, ...) { - UseMethod("epi_keys") -} - -#' @export -epi_keys.default <- function(x, ...) { - character(0L) -} - -#' @export -epi_keys.data.frame <- function(x, other_keys = character(0L), ...) { - arg_is_chr(other_keys, allow_empty = TRUE) - nm <- c("time_value", "geo_value", other_keys) - intersect(nm, names(x)) -} - -#' @export -epi_keys.epi_df <- function(x, ...) { - c("time_value", "geo_value", attr(x, "metadata")$other_keys) -} - -#' @export -epi_keys.recipe <- function(x, ...) { - x$var_info$variable[x$var_info$role %in% c("time_value", "geo_value", "key")] -} - -#' @export -epi_keys.epi_workflow <- function(x, ...) { - epi_keys_mold(hardhat::extract_mold(x)) -} - -# a mold is a list extracted from a fitted workflow, gives info about -# training data. But it doesn't have a class -epi_keys_mold <- function(mold) { - keys <- c("time_value", "geo_value", "key") - molded_names <- names(mold$extras$roles) - mold_keys <- map(mold$extras$roles[molded_names %in% keys], names) - unname(unlist(mold_keys)) -} - -kill_time_value <- function(v) { - arg_is_chr(v) - v[v != "time_value"] -} - -epi_keys_only <- function(x, ...) { - kill_time_value(epi_keys(x, ...)) -} diff --git a/R/epi_recipe.R b/R/epi_recipe.R index 6d01d718f..311decf62 100644 --- a/R/epi_recipe.R +++ b/R/epi_recipe.R @@ -90,7 +90,7 @@ epi_recipe.epi_df <- rlang::abort("1 or more elements of `vars` are not in the data") } - keys <- epi_keys(x) # we know x is an epi_df + keys <- key_colnames(x) # we know x is an epi_df var_info <- tibble(variable = vars) key_roles <- c("time_value", "geo_value", rep("key", length(keys) - 2)) @@ -186,7 +186,7 @@ epi_form2args <- function(formula, data, ...) { ## use rlang to get both sides of the formula outcomes <- recipes:::get_lhs_vars(formula, data) predictors <- recipes:::get_rhs_vars(formula, data, no_lhs = TRUE) - keys <- epi_keys(data) + keys <- key_colnames(data) ## if . was used on the rhs, subtract out the outcomes predictors <- predictors[!(predictors %in% outcomes)] @@ -444,9 +444,9 @@ prep.epi_recipe <- function( } training <- recipes:::check_training_set(training, x, fresh) training <- epi_check_training_set(training, x) - training <- dplyr::relocate(training, tidyselect::all_of(epi_keys(training))) + training <- dplyr::relocate(training, dplyr::all_of(key_colnames(training))) tr_data <- recipes:::train_info(training) - keys <- epi_keys(x) + keys <- key_colnames(x) orig_lvls <- lapply(training, recipes:::get_levels) orig_lvls <- kill_levels(orig_lvls, keys) @@ -498,10 +498,10 @@ prep.epi_recipe <- function( # tidymodels killed our class # for now, we only allow step_epi_* to alter the metadata training <- dplyr::dplyr_reconstruct( - epiprocess::as_epi_df(training), before_template + as_epi_df(training), before_template ) } - training <- dplyr::relocate(training, tidyselect::all_of(epi_keys(training))) + training <- dplyr::relocate(training, dplyr::all_of(key_colnames(training))) x$term_info <- recipes:::merge_term_info(get_types(training), x$term_info) if (!is.na(x$steps[[i]]$role)) { new_vars <- setdiff(x$term_info$variable, running_info$variable) diff --git a/R/epi_workflow.R b/R/epi_workflow.R index 0bdeece4f..3660b87e1 100644 --- a/R/epi_workflow.R +++ b/R/epi_workflow.R @@ -184,8 +184,8 @@ predict.epi_workflow <- function(object, new_data, type = NULL, opts = list(), . #' @export augment.epi_workflow <- function(x, new_data, ...) { predictions <- predict(x, new_data, ...) - if (epiprocess::is_epi_df(predictions)) { - join_by <- epi_keys(predictions) + if (is_epi_df(predictions)) { + join_by <- key_colnames(predictions) } else { rlang::abort( c( diff --git a/R/epipredict-package.R b/R/epipredict-package.R index 6ca349570..d3c7a8a4a 100644 --- a/R/epipredict-package.R +++ b/R/epipredict-package.R @@ -3,6 +3,8 @@ #' @importFrom rlang := !! %||% as_function global_env set_names !!! #' @importFrom rlang is_logical is_true inject enquo enquos expr #' @importFrom stats poly predict lm residuals quantile +#' @importFrom dplyr arrange across all_of any_of bind_rows group_by summarise +#' filter mutate select left_join #' @importFrom cli cli_abort #' @importFrom checkmate assert assert_character assert_int assert_scalar #' assert_logical assert_numeric assert_number assert_integer diff --git a/R/flatline_forecaster.R b/R/flatline_forecaster.R index dac87e5c4..42970c569 100644 --- a/R/flatline_forecaster.R +++ b/R/flatline_forecaster.R @@ -36,7 +36,7 @@ flatline_forecaster <- function( if (!inherits(args_list, c("flat_fcast", "alist"))) { cli_stop("args_list was not created using `flatline_args_list().") } - keys <- epi_keys(epi_data) + keys <- key_colnames(epi_data) ek <- kill_time_value(keys) outcome <- rlang::sym(outcome) diff --git a/R/get_test_data.R b/R/get_test_data.R index 0a7d0dc2a..88ecf4054 100644 --- a/R/get_test_data.R +++ b/R/get_test_data.R @@ -92,7 +92,7 @@ get_test_data <- function( } x <- arrange(x, time_value) - groups <- kill_time_value(epi_keys(recipe)) + groups <- kill_time_value(key_colnames(recipe)) # If we skip NA completion, we remove undesirably early time values # Happens globally, over all groups @@ -102,7 +102,7 @@ get_test_data <- function( # Pad with explicit missing values up to and including the forecast_date # x is grouped here x <- pad_to_end(x, groups, forecast_date) %>% - epiprocess::group_by(dplyr::across(dplyr::all_of(groups))) + group_by(dplyr::across(dplyr::all_of(groups))) # If all(lags > 0), then we get rid of recent data if (min_lags > 0 && min_lags < Inf) { @@ -116,14 +116,14 @@ get_test_data <- function( dplyr::mutate(fillers = forecast_date - time_value > min_required) %>% dplyr::summarize( dplyr::across( - -tidyselect::any_of(epi_keys(recipe)), + -dplyr::any_of(key_colnames(recipe)), ~ all(is.na(.x[fillers])) & is.na(head(.x[!fillers], 1)) ), .groups = "drop" ) %>% dplyr::select(-fillers) %>% dplyr::summarise(dplyr::across( - -tidyselect::any_of(epi_keys(recipe)), ~ any(.x) + -dplyr::any_of(key_colnames(recipe)), ~ any(.x) )) %>% unlist() if (any(cannot_be_used)) { @@ -142,13 +142,13 @@ get_test_data <- function( } dplyr::filter(x, forecast_date - time_value <= min_required) %>% - epiprocess::ungroup() + ungroup() } pad_to_end <- function(x, groups, end_date) { - itval <- epiprocess:::guess_period(c(x$time_value, end_date), "time_value") + itval <- guess_period(c(x$time_value, end_date), "time_value") completed_time_values <- x %>% - dplyr::group_by(dplyr::across(tidyselect::all_of(groups))) %>% + dplyr::group_by(dplyr::across(dplyr::all_of(groups))) %>% dplyr::summarise( time_value = rlang::list2( time_value = Seq(max(time_value) + itval, end_date, itval) @@ -158,7 +158,7 @@ pad_to_end <- function(x, groups, end_date) { mutate(time_value = vctrs::vec_cast(time_value, x$time_value)) dplyr::bind_rows(x, completed_time_values) %>% - dplyr::arrange(dplyr::across(tidyselect::all_of(c("time_value", groups)))) + dplyr::arrange(dplyr::across(dplyr::all_of(c("time_value", groups)))) } Seq <- function(from, to, by) { diff --git a/R/key_colnames.R b/R/key_colnames.R new file mode 100644 index 000000000..e16ce1f42 --- /dev/null +++ b/R/key_colnames.R @@ -0,0 +1,28 @@ +#' @export +key_colnames.recipe <- function(x, ...) { + x$var_info$variable[x$var_info$role %in% c("time_value", "geo_value", "key")] +} + +#' @export +key_colnames.epi_workflow <- function(x, ...) { + NextMethod(hardhat::extract_mold(x)) +} + +# a mold is a list extracted from a fitted workflow, gives info about +# training data. But it doesn't have a class +#' @export +key_colnames.list <- function(x, ...) { + keys <- c("time_value", "geo_value", "key") + molded_names <- names(x$extras$roles) + mold_keys <- map(x$extras$roles[molded_names %in% keys], names) + unname(unlist(mold_keys)) %||% character(0L) +} + +kill_time_value <- function(v) { + arg_is_chr(v) + v[v != "time_value"] +} + +epi_keys_only <- function(x, ...) { + kill_time_value(key_colnames(x, ...)) +} diff --git a/R/layer_cdc_flatline_quantiles.R b/R/layer_cdc_flatline_quantiles.R index f54c1da78..db1440b03 100644 --- a/R/layer_cdc_flatline_quantiles.R +++ b/R/layer_cdc_flatline_quantiles.R @@ -170,7 +170,7 @@ slather.layer_cdc_flatline_quantiles <- ) } p <- components$predictions - ek <- kill_time_value(epi_keys_mold(components$mold)) + ek <- kill_time_value(key_colnames(components$mold)) r <- grab_residuals(the_fit, components) avail_grps <- character(0L) diff --git a/R/layer_population_scaling.R b/R/layer_population_scaling.R index 33183198d..3829eec5a 100644 --- a/R/layer_population_scaling.R +++ b/R/layer_population_scaling.R @@ -136,7 +136,7 @@ slather.layer_population_scaling <- if (is.null(object$by)) { object$by <- intersect( - kill_time_value(epi_keys(components$predictions)), + kill_time_value(key_colnames(components$predictions)), colnames(dplyr::select(object$df, !object$df_pop_col)) ) } diff --git a/R/step_epi_shift.R b/R/step_epi_shift.R index 52f51de16..f45f5d8f4 100644 --- a/R/step_epi_shift.R +++ b/R/step_epi_shift.R @@ -89,7 +89,7 @@ step_epi_lag <- lag = as.integer(lag), prefix = prefix, default = default, - keys = epi_keys(recipe), + keys = key_colnames(recipe), columns = columns, skip = skip, id = id @@ -133,13 +133,13 @@ step_epi_ahead <- add_step( recipe, step_epi_ahead_new( - terms = dplyr::enquos(...), + terms = enquos(...), role = role, trained = trained, ahead = as.integer(ahead), prefix = prefix, default = default, - keys = epi_keys(recipe), + keys = key_colnames(recipe), columns = columns, skip = skip, id = id diff --git a/R/step_epi_slide.R b/R/step_epi_slide.R index 637d31a54..a8ad66c85 100644 --- a/R/step_epi_slide.R +++ b/R/step_epi_slide.R @@ -72,7 +72,7 @@ step_epi_slide <- role = role, trained = FALSE, prefix = prefix, - keys = epi_keys(recipe), + keys = key_colnames(recipe), columns = NULL, skip = skip, id = id diff --git a/R/step_growth_rate.R b/R/step_growth_rate.R index e5edb18d4..48d8b4394 100644 --- a/R/step_growth_rate.R +++ b/R/step_growth_rate.R @@ -87,7 +87,7 @@ step_growth_rate <- add_step( recipe, step_growth_rate_new( - terms = dplyr::enquos(...), + terms = enquos(...), role = role, trained = trained, horizon = horizon, @@ -95,7 +95,7 @@ step_growth_rate <- log_scale = log_scale, replace_Inf = replace_Inf, prefix = prefix, - keys = epi_keys(recipe), + keys = key_colnames(recipe), columns = columns, skip = skip, id = id, diff --git a/R/step_lag_difference.R b/R/step_lag_difference.R index e954bd9a0..87852be2d 100644 --- a/R/step_lag_difference.R +++ b/R/step_lag_difference.R @@ -52,12 +52,12 @@ step_lag_difference <- add_step( recipe, step_lag_difference_new( - terms = dplyr::enquos(...), + terms = enquos(...), role = role, trained = trained, horizon = horizon, prefix = prefix, - keys = epi_keys(recipe), + keys = key_colnames(recipe), columns = columns, skip = skip, id = id diff --git a/R/step_training_window.R b/R/step_training_window.R index 7102d29d8..90de468ce 100644 --- a/R/step_training_window.R +++ b/R/step_training_window.R @@ -87,7 +87,7 @@ step_training_window_new <- #' @export prep.step_training_window <- function(x, training, info = NULL, ...) { - ekt <- kill_time_value(epi_keys(training)) + ekt <- kill_time_value(key_colnames(training)) ek <- x$epi_keys %||% ekt %||% character(0L) hardhat::validate_column_names(training, ek) diff --git a/R/utils-misc.R b/R/utils-misc.R index 18f6380df..231a8f60f 100644 --- a/R/utils-misc.R +++ b/R/utils-misc.R @@ -39,9 +39,9 @@ grab_forged_keys <- function(forged, mold, new_data) { # 1. these are the keys in the test data after prep/bake new_keys <- names(extras) # 2. these are the keys in the training data - old_keys <- epi_keys_mold(mold) + old_keys <- key_colnames(mold) # 3. these are the keys in the test data as input - new_df_keys <- epi_keys(new_data, extra_keys = setdiff(new_keys, keys[1:2])) + new_df_keys <- key_colnames(new_data, extra_keys = setdiff(new_keys, keys[1:2])) if (!(setequal(old_keys, new_df_keys) && setequal(new_keys, new_df_keys))) { cli::cli_warn(c( "Not all epi keys that were present in the training data are available", diff --git a/man/epi_keys.Rd b/man/epi_keys.Rd deleted file mode 100644 index 8026fc140..000000000 --- a/man/epi_keys.Rd +++ /dev/null @@ -1,20 +0,0 @@ -% Generated by roxygen2: do not edit by hand -% Please edit documentation in R/epi_keys.R -\name{epi_keys} -\alias{epi_keys} -\title{Grab any keys associated to an epi_df} -\usage{ -epi_keys(x, ...) -} -\arguments{ -\item{x}{a data.frame, tibble, or epi_df} - -\item{...}{additional arguments passed on to methods} -} -\value{ -If an \code{epi_df}, this returns all "keys". Otherwise \code{NULL} -} -\description{ -Grab any keys associated to an epi_df -} -\keyword{internal} diff --git a/tests/testthat/test-epi_recipe.R b/tests/testthat/test-epi_recipe.R index 75726652d..ed27d88c0 100644 --- a/tests/testthat/test-epi_recipe.R +++ b/tests/testthat/test-epi_recipe.R @@ -128,7 +128,7 @@ test_that("add/update/adjust/remove epi_recipe works as intended", { wf <- epi_workflow() %>% add_epi_recipe(r) - steps <- extract_preprocessor(wf)$steps + steps <- workflows::extract_preprocessor(wf)$steps expect_equal(length(steps), 3) expect_equal(class(steps[[1]]), c("step_epi_lag", "step")) expect_equal(steps[[1]]$lag, c(0, 7, 14)) @@ -143,7 +143,7 @@ test_that("add/update/adjust/remove epi_recipe works as intended", { wf <- update_epi_recipe(wf, r2) - steps <- extract_preprocessor(wf)$steps + steps <- workflows::extract_preprocessor(wf)$steps expect_equal(length(steps), 2) expect_equal(class(steps[[1]]), c("step_epi_lag", "step")) expect_equal(steps[[1]]$lag, c(0, 1)) @@ -152,7 +152,7 @@ test_that("add/update/adjust/remove epi_recipe works as intended", { # adjust_epi_recipe using step number wf <- adjust_epi_recipe(wf, which_step = 2, ahead = 7) - steps <- extract_preprocessor(wf)$steps + steps <- workflows::extract_preprocessor(wf)$steps expect_equal(length(steps), 2) expect_equal(class(steps[[1]]), c("step_epi_lag", "step")) expect_equal(steps[[1]]$lag, c(0, 1)) @@ -161,7 +161,7 @@ test_that("add/update/adjust/remove epi_recipe works as intended", { # adjust_epi_recipe using step name wf <- adjust_epi_recipe(wf, which_step = "step_epi_ahead", ahead = 8) - steps <- extract_preprocessor(wf)$steps + steps <- workflows::extract_preprocessor(wf)$steps expect_equal(length(steps), 2) expect_equal(class(steps[[1]]), c("step_epi_lag", "step")) expect_equal(steps[[1]]$lag, c(0, 1)) @@ -170,6 +170,6 @@ test_that("add/update/adjust/remove epi_recipe works as intended", { wf <- remove_epi_recipe(wf) - expect_error(extract_preprocessor(wf)$steps) + expect_error(workflows::extract_preprocessor(wf)$steps) expect_equal(wf$pre$actions$recipe$recipe, NULL) }) diff --git a/tests/testthat/test-epi_shift.R b/tests/testthat/test-epi_shift.R index b0ab3a21f..245a39c0d 100644 --- a/tests/testthat/test-epi_shift.R +++ b/tests/testthat/test-epi_shift.R @@ -24,7 +24,7 @@ test_that("epi shift single works, renames", { time_value = seq(as.Date("2020-01-01"), by = 1, length.out = 5), geo_value = "ca" ) %>% epiprocess::as_epi_df() - ess <- epi_shift_single(tib, "x", 1, "test", epi_keys(tib)) + ess <- epi_shift_single(tib, "x", 1, "test", key_colnames(tib)) expect_named(ess, c("time_value", "geo_value", "test")) expect_equal(ess$time_value, tib$time_value + 1) }) diff --git a/tests/testthat/test-epi_keys.R b/tests/testthat/test-key_colnames.R similarity index 63% rename from tests/testthat/test-epi_keys.R rename to tests/testthat/test-key_colnames.R index 3e794542e..3fecd9e44 100644 --- a/tests/testthat/test-epi_keys.R +++ b/tests/testthat/test-key_colnames.R @@ -1,25 +1,5 @@ -library(parsnip) -library(workflows) -library(dplyr) - -test_that("epi_keys returns empty for an object that isn't an epi_df", { - expect_identical(epi_keys(data.frame(x = 1:3, y = 2:4)), character(0L)) -}) - -test_that("epi_keys returns possible keys if they exist", { - expect_identical( - epi_keys(data.frame(time_value = 1:3, geo_value = 2:4)), - c("time_value", "geo_value") - ) -}) - - -test_that("Extracts keys from an epi_df", { - expect_equal(epi_keys(case_death_rate_subset), c("time_value", "geo_value")) -}) - test_that("Extracts keys from a recipe; roles are NA, giving an empty vector", { - expect_equal(epi_keys(recipe(case_death_rate_subset)), character(0L)) + expect_equal(key_colnames(recipe(case_death_rate_subset)), character(0L)) }) test_that("epi_keys_mold extracts time_value and geo_value, but not raw", { @@ -35,12 +15,12 @@ test_that("epi_keys_mold extracts time_value and geo_value, but not raw", { fit(data = case_death_rate_subset) expect_setequal( - epi_keys_mold(my_workflow$pre$mold), + key_colnames(my_workflow$pre$mold), c("time_value", "geo_value") ) }) -test_that("epi_keys_mold extracts additional keys when they are present", { +test_that("key_colnames extracts additional keys when they are present", { my_data <- tibble::tibble( geo_value = rep(c("ca", "fl", "pa"), each = 3), time_value = rep(seq(as.Date("2020-06-01"), as.Date("2020-06-03"), @@ -50,7 +30,7 @@ test_that("epi_keys_mold extracts additional keys when they are present", { state = rep(c("ca", "fl", "pa"), each = 3), # extra key value = 1:length(geo_value) + 0.01 * rnorm(length(geo_value)) ) %>% - epiprocess::as_epi_df( + as_epi_df( additional_metadata = list(other_keys = c("state", "pol")) ) @@ -61,7 +41,7 @@ test_that("epi_keys_mold extracts additional keys when they are present", { my_workflow <- epi_workflow(my_recipe, linear_reg()) %>% fit(my_data) expect_setequal( - epi_keys_mold(my_workflow$pre$mold), + key_colnames(my_workflow$pre$mold), c("time_value", "geo_value", "state", "pol") ) }) diff --git a/vignettes/epipredict.Rmd b/vignettes/epipredict.Rmd index af83dc321..7e24b04c6 100644 --- a/vignettes/epipredict.Rmd +++ b/vignettes/epipredict.Rmd @@ -448,7 +448,7 @@ To illustrate everything above, here is (roughly) the code for the r <- epi_recipe(jhu) %>% step_epi_ahead(case_rate, ahead = 7, skip = TRUE) %>% update_role(case_rate, new_role = "predictor") %>% - add_role(all_of(epi_keys(jhu)), new_role = "predictor") + add_role(all_of(key_colnames(jhu)), new_role = "predictor") f <- frosting() %>% layer_predict() %>% From c531007c577ce94ac2d6a40dc6f64fbdd18c7767 Mon Sep 17 00:00:00 2001 From: "Daniel J. McDonald" Date: Tue, 27 Aug 2024 09:44:29 -0700 Subject: [PATCH 02/16] bump version --- DESCRIPTION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/DESCRIPTION b/DESCRIPTION index 4da92afe6..6d1217587 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,6 +1,6 @@ Package: epipredict Title: Basic epidemiology forecasting methods -Version: 0.0.19 +Version: 0.0.20 Authors@R: c( person("Daniel", "McDonald", , "daniel@stat.ubc.ca", role = c("aut", "cre")), person("Ryan", "Tibshirani", , "ryantibs@cmu.edu", role = "aut"), From a0d09f9c8f0db963bc590bb56ef2f5d949203aae Mon Sep 17 00:00:00 2001 From: "Daniel J. McDonald" Date: Tue, 27 Aug 2024 11:00:47 -0700 Subject: [PATCH 03/16] remove lots of extraneous prefixing --- NAMESPACE | 14 ++++++-- R/arx_classifier.R | 51 +++++++++++++------------- R/arx_forecaster.R | 53 ++++++++++++++------------- R/autoplot.R | 68 ++++++++++++++++++----------------- R/canned-epipred.R | 12 +++---- R/cdc_baseline_forecaster.R | 14 ++++---- R/check_enough_train_data.R | 30 +++++++--------- R/epi_check_training_set.R | 8 ++--- R/epi_recipe.R | 12 ++++--- R/epi_shift.R | 10 +++--- R/epi_workflow.R | 24 ++++++------- R/epipredict-package.R | 12 ++++--- R/flatline.R | 23 ++++++------ R/flatline_forecaster.R | 8 ++--- R/flusight_hub_formatter.R | 61 ++++++++++++++++--------------- R/frosting.R | 40 +++++++++++---------- R/step_epi_naomit.R | 2 +- man/add_frosting.Rd | 5 +-- man/adjust_frosting.Rd | 3 +- man/arx_class_epi_workflow.Rd | 6 ++-- man/arx_classifier.Rd | 2 +- man/arx_fcast_epi_workflow.Rd | 5 +-- man/arx_forecaster.Rd | 2 +- man/autoplot-epipred.Rd | 14 ++++---- man/epi_recipe.Rd | 10 +++--- man/flusight_hub_formatter.Rd | 39 ++++++++++---------- man/frosting.Rd | 4 +-- 27 files changed, 270 insertions(+), 262 deletions(-) diff --git a/NAMESPACE b/NAMESPACE index b07321768..d1bd615b4 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -226,6 +226,7 @@ importFrom(checkmate,assert_number) importFrom(checkmate,assert_numeric) importFrom(checkmate,assert_scalar) importFrom(cli,cli_abort) +importFrom(cli,cli_warn) importFrom(dplyr,across) importFrom(dplyr,all_of) importFrom(dplyr,any_of) @@ -233,10 +234,12 @@ importFrom(dplyr,arrange) importFrom(dplyr,bind_cols) importFrom(dplyr,bind_rows) importFrom(dplyr,filter) +importFrom(dplyr,full_join) importFrom(dplyr,group_by) importFrom(dplyr,left_join) importFrom(dplyr,mutate) -importFrom(dplyr,n) +importFrom(dplyr,relocate) +importFrom(dplyr,rename) importFrom(dplyr,select) importFrom(dplyr,summarise) importFrom(dplyr,ungroup) @@ -245,7 +248,12 @@ importFrom(epiprocess,growth_rate) importFrom(generics,augment) importFrom(generics,fit) importFrom(generics,forecast) +importFrom(ggplot2,aes) importFrom(ggplot2,autoplot) +importFrom(ggplot2,geom_line) +importFrom(ggplot2,geom_linerange) +importFrom(ggplot2,geom_point) +importFrom(ggplot2,geom_ribbon) importFrom(hardhat,refresh_blueprint) importFrom(hardhat,run_mold) importFrom(magrittr,"%>%") @@ -257,6 +265,7 @@ importFrom(rlang,"%@%") importFrom(rlang,"%||%") importFrom(rlang,":=") importFrom(rlang,abort) +importFrom(rlang,arg_match) importFrom(rlang,as_function) importFrom(rlang,caller_env) importFrom(rlang,enquo) @@ -268,6 +277,7 @@ importFrom(rlang,is_logical) importFrom(rlang,is_null) importFrom(rlang,is_true) importFrom(rlang,set_names) +importFrom(rlang,sym) importFrom(stats,as.formula) importFrom(stats,family) importFrom(stats,lm) @@ -278,9 +288,9 @@ importFrom(stats,predict) importFrom(stats,qnorm) importFrom(stats,quantile) importFrom(stats,residuals) +importFrom(tibble,as_tibble) importFrom(tibble,tibble) importFrom(tidyr,crossing) -importFrom(tidyr,drop_na) importFrom(vctrs,as_list_of) importFrom(vctrs,field) importFrom(vctrs,new_rcrd) diff --git a/R/arx_classifier.R b/R/arx_classifier.R index 44acb9b30..d5f5bf05d 100644 --- a/R/arx_classifier.R +++ b/R/arx_classifier.R @@ -45,14 +45,14 @@ arx_classifier <- function( epi_data, outcome, predictors, - trainer = parsnip::logistic_reg(), + trainer = logistic_reg(), args_list = arx_class_args_list()) { if (!is_classification(trainer)) { - cli::cli_abort("`trainer` must be a {.pkg parsnip} model of mode 'classification'.") + cli_abort("`trainer` must be a {.pkg parsnip} model of mode 'classification'.") } wf <- arx_class_epi_workflow(epi_data, outcome, predictors, trainer, args_list) - wf <- generics::fit(wf, epi_data) + wf <- fit(wf, epi_data) preds <- forecast( wf, @@ -60,8 +60,8 @@ arx_classifier <- function( n_recent = args_list$nafill_buffer, forecast_date = args_list$forecast_date %||% max(epi_data$time_value) ) %>% - tibble::as_tibble() %>% - dplyr::select(-time_value) + as_tibble() %>% + select(-time_value) structure( list( @@ -95,9 +95,9 @@ arx_classifier <- function( #' @export #' @seealso [arx_classifier()] #' @examples -#' +#' library(dplyr) #' jhu <- case_death_rate_subset %>% -#' dplyr::filter(time_value >= as.Date("2021-11-01")) +#' filter(time_value >= as.Date("2021-11-01")) #' #' arx_class_epi_workflow(jhu, "death_rate", c("case_rate", "death_rate")) #' @@ -105,7 +105,7 @@ arx_classifier <- function( #' jhu, #' "death_rate", #' c("case_rate", "death_rate"), -#' trainer = parsnip::multinom_reg(), +#' trainer = multinom_reg(), #' args_list = arx_class_args_list( #' breaks = c(-.05, .1), ahead = 14, #' horizon = 14, method = "linear_reg" @@ -119,10 +119,10 @@ arx_class_epi_workflow <- function( args_list = arx_class_args_list()) { validate_forecaster_inputs(epi_data, outcome, predictors) if (!inherits(args_list, c("arx_class", "alist"))) { - rlang::abort("args_list was not created using `arx_class_args_list().") + cli_abort("`args_list` was not created using `arx_class_args_list()`.") } if (!(is.null(trainer) || is_classification(trainer))) { - rlang::abort("`trainer` must be a `{parsnip}` model of mode 'classification'.") + cli_abort("`trainer` must be a {.pkg parsnip} model of mode 'classification'.") } lags <- arx_lags_validator(predictors, args_list$lags) @@ -130,7 +130,7 @@ arx_class_epi_workflow <- function( # ------- predictors r <- epi_recipe(epi_data) %>% step_growth_rate( - tidyselect::all_of(predictors), + all_of(predictors), role = "grp", horizon = args_list$horizon, method = args_list$method, @@ -173,26 +173,23 @@ arx_class_epi_workflow <- function( o2 <- rlang::sym(paste0("ahead_", args_list$ahead, "_", o)) r <- r %>% step_epi_ahead(!!o, ahead = args_list$ahead, role = "pre-outcome") %>% - step_mutate( + recipes::step_mutate( outcome_class = cut(!!o2, breaks = args_list$breaks), role = "outcome" ) %>% step_epi_naomit() %>% - step_training_window(n_recent = args_list$n_training) %>% - { - if (!is.null(args_list$check_enough_data_n)) { - check_enough_train_data( - ., - all_predictors(), - !!outcome, - n = args_list$check_enough_data_n, - epi_keys = args_list$check_enough_data_epi_keys, - drop_na = FALSE - ) - } else { - . - } - } + step_training_window(n_recent = args_list$n_training) + + if (!is.null(args_list$check_enough_data_n)) { + r <- check_enough_train_data( + r, + all_predictors(), + !!outcome, + n = args_list$check_enough_data_n, + epi_keys = args_list$check_enough_data_epi_keys, + drop_na = FALSE + ) + } forecast_date <- args_list$forecast_date %||% max(epi_data$time_value) target_date <- args_list$target_date %||% (forecast_date + args_list$ahead) diff --git a/R/arx_forecaster.R b/R/arx_forecaster.R index 1b9e3d503..37c9aae86 100644 --- a/R/arx_forecaster.R +++ b/R/arx_forecaster.R @@ -42,14 +42,14 @@ arx_forecaster <- function( epi_data, outcome, predictors = outcome, - trainer = parsnip::linear_reg(), + trainer = linear_reg(), args_list = arx_args_list()) { if (!is_regression(trainer)) { - cli::cli_abort("`trainer` must be a {.pkg parsnip} model of mode 'regression'.") + cli_abort("`trainer` must be a {.pkg parsnip} model of mode 'regression'.") } wf <- arx_fcast_epi_workflow(epi_data, outcome, predictors, trainer, args_list) - wf <- generics::fit(wf, epi_data) + wf <- fit(wf, epi_data) preds <- forecast( wf, @@ -57,8 +57,8 @@ arx_forecaster <- function( n_recent = args_list$nafill_buffer, forecast_date = args_list$forecast_date %||% max(epi_data$time_value) ) %>% - tibble::as_tibble() %>% - dplyr::select(-time_value) + as_tibble() %>% + select(-time_value) structure( list( @@ -91,8 +91,9 @@ arx_forecaster <- function( #' @seealso [arx_forecaster()] #' #' @examples +#' library(dplyr) #' jhu <- case_death_rate_subset %>% -#' dplyr::filter(time_value >= as.Date("2021-12-01")) +#' filter(time_value >= as.Date("2021-12-01")) #' #' arx_fcast_epi_workflow( #' jhu, "death_rate", @@ -108,15 +109,15 @@ arx_fcast_epi_workflow <- function( epi_data, outcome, predictors = outcome, - trainer = parsnip::linear_reg(), + trainer = linear_reg(), args_list = arx_args_list()) { # --- validation validate_forecaster_inputs(epi_data, outcome, predictors) if (!inherits(args_list, c("arx_fcast", "alist"))) { - cli::cli_abort("args_list was not created using `arx_args_list().") + cli_abort("`args_list` was not created using `arx_args_list()`.") } if (!(is.null(trainer) || is_regression(trainer))) { - cli::cli_abort("{trainer} must be a `{parsnip}` model of mode 'regression'.") + cli_abort("`trainer` must be a {.pkg parsnip} model of mode 'regression'.") } lags <- arx_lags_validator(predictors, args_list$lags) @@ -129,21 +130,19 @@ arx_fcast_epi_workflow <- function( r <- r %>% step_epi_ahead(!!outcome, ahead = args_list$ahead) %>% step_epi_naomit() %>% - step_training_window(n_recent = args_list$n_training) %>% - { - if (!is.null(args_list$check_enough_data_n)) { - check_enough_train_data( - ., - all_predictors(), - !!outcome, - n = args_list$check_enough_data_n, - epi_keys = args_list$check_enough_data_epi_keys, - drop_na = FALSE - ) - } else { - . - } - } + step_training_window(n_recent = args_list$n_training) + + if (!is.null(args_list$check_enough_data_n)) { + r <- check_enough_train_data( + r, + all_predictors(), + !!outcome, + n = args_list$check_enough_data_n, + epi_keys = args_list$check_enough_data_epi_keys, + drop_na = FALSE + ) + } + forecast_date <- args_list$forecast_date %||% max(epi_data$time_value) target_date <- args_list$target_date %||% (forecast_date + args_list$ahead) @@ -157,7 +156,7 @@ arx_fcast_epi_workflow <- function( rlang::eval_tidy(trainer$args$quantile_levels) )) args_list$quantile_levels <- quantile_levels - trainer$args$quantile_levels <- rlang::enquo(quantile_levels) + trainer$args$quantile_levels <- enquo(quantile_levels) f <- layer_quantile_distn(f, quantile_levels = quantile_levels) %>% layer_point_from_distn() } else { @@ -265,7 +264,7 @@ arx_args_list <- function( if (!is.null(forecast_date) && !is.null(target_date)) { if (forecast_date + ahead != target_date) { - cli::cli_warn(c( + cli_warn(c( "`forecast_date` + `ahead` must equal `target_date`.", i = "{.val {forecast_date}} + {.val {ahead}} != {.val {target_date}}." )) @@ -316,7 +315,7 @@ compare_quantile_args <- function(alist, tlist) { if (setequal(alist, tlist)) { return(sort(unique(alist))) } - rlang::abort(c( + cli_abort(c( "You have specified different, non-default, quantiles in the trainier and `arx_args` options.", i = "Please only specify quantiles in one location." )) diff --git a/R/autoplot.R b/R/autoplot.R index 0d0e48e35..1be30dd29 100644 --- a/R/autoplot.R +++ b/R/autoplot.R @@ -1,4 +1,4 @@ -#' @importFrom ggplot2 autoplot +#' @importFrom ggplot2 autoplot aes geom_point geom_line geom_ribbon geom_linerange #' @export ggplot2::autoplot @@ -28,6 +28,7 @@ ggplot2::autoplot #' #' @name autoplot-epipred #' @examples +#' library(dplyr) #' jhu <- case_death_rate_subset %>% #' filter(time_value >= as.Date("2021-11-01")) #' @@ -41,26 +42,26 @@ ggplot2::autoplot #' layer_residual_quantiles( #' quantile_levels = c(.025, .1, .25, .75, .9, .975) #' ) %>% -#' layer_threshold(dplyr::starts_with(".pred")) %>% +#' layer_threshold(starts_with(".pred")) %>% #' layer_add_target_date() #' -#' wf <- epi_workflow(r, parsnip::linear_reg(), f) %>% fit(jhu) +#' wf <- epi_workflow(r, linear_reg(), f) %>% fit(jhu) #' #' autoplot(wf) #' -#' latest <- jhu %>% dplyr::filter(time_value >= max(time_value) - 14) +#' latest <- jhu %>% filter(time_value >= max(time_value) - 14) #' preds <- predict(wf, latest) #' autoplot(wf, preds, .max_facets = 4) #' #' # ------- Show multiple horizons #' -#' p <- lapply(c(7, 14, 21, 28), \(h) { +#' p <- lapply(c(7, 14, 21, 28), function(h) { #' r <- epi_recipe(jhu) %>% #' step_epi_lag(death_rate, lag = c(0, 7, 14)) %>% #' step_epi_ahead(death_rate, ahead = h) %>% #' step_epi_lag(case_rate, lag = c(0, 7, 14)) %>% #' step_epi_naomit() -#' ewf <- epi_workflow(r, parsnip::linear_reg(), f) %>% fit(jhu) +#' ewf <- epi_workflow(r, linear_reg(), f) %>% fit(jhu) #' forecast(ewf) #' }) #' @@ -69,7 +70,8 @@ ggplot2::autoplot #' #' # ------- Plotting canned forecaster output #' -#' jhu <- case_death_rate_subset %>% filter(time_value >= as.Date("2021-11-01")) +#' jhu <- case_death_rate_subset %>% +#' filter(time_value >= as.Date("2021-11-01")) #' flat <- flatline_forecaster(jhu, "death_rate") #' autoplot(flat, .max_facets = 4) #' @@ -95,7 +97,7 @@ autoplot.epi_workflow <- function( rlang::arg_match(.facet_by) if (!workflows::is_trained_workflow(object)) { - cli::cli_abort(c( + cli_abort(c( "Can't plot an untrained {.cls epi_workflow}.", i = "Do you need to call `fit()`?" )) @@ -105,25 +107,25 @@ autoplot.epi_workflow <- function( y <- mold$outcomes if (ncol(y) > 1) { y <- y[, 1] - cli::cli_warn("Multiple outcome variables were detected. Displaying only 1.") + cli_warn("Multiple outcome variables were detected. Displaying only 1.") } keys <- c("time_value", "geo_value", "key") mold_roles <- names(mold$extras$roles) - edf <- dplyr::bind_cols(mold$extras$roles[mold_roles %in% keys], y) + edf <- bind_cols(mold$extras$roles[mold_roles %in% keys], y) if (starts_with_impl("ahead_", names(y))) { old_name_y <- unlist(strsplit(names(y), "_")) shift <- as.numeric(old_name_y[2]) new_name_y <- paste(old_name_y[-c(1:2)], collapse = "_") - edf <- dplyr::rename(edf, !!new_name_y := !!names(y)) + edf <- rename(edf, !!new_name_y := !!names(y)) } else if (starts_with_impl("lag_", names(y))) { old_name_y <- unlist(strsplit(names(y), "_")) shift <- -as.numeric(old_name_y[2]) new_name_y <- paste(old_name_y[-c(1:2)], collapse = "_") - edf <- dplyr::rename(edf, !!new_name_y := !!names(y)) + edf <- rename(edf, !!new_name_y := !!names(y)) } if (!is.null(shift)) { - edf <- dplyr::mutate(edf, time_value = time_value + shift) + edf <- mutate(edf, time_value = time_value + shift) } extra_keys <- setdiff(key_colnames(mold), c("time_value", "geo_value")) if (length(extra_keys) == 0L) extra_keys <- NULL @@ -141,13 +143,13 @@ autoplot.epi_workflow <- function( if ("target_date" %in% names(predictions)) { if ("time_value" %in% names(predictions)) { - predictions <- dplyr::select(predictions, -time_value) + predictions <- select(predictions, -time_value) } - predictions <- dplyr::rename(predictions, time_value = target_date) + predictions <- rename(predictions, time_value = target_date) } pred_cols_ok <- hardhat::check_column_names(predictions, key_colnames(edf)) if (!pred_cols_ok$ok) { - cli::cli_warn(c( + cli_warn(c( "`predictions` is missing required variables: {.var {pred_cols_ok$missing_names}}.", i = "Plotting the original data." )) @@ -167,13 +169,13 @@ autoplot.epi_workflow <- function( # Now, prepare matching facets in the predictions ek <- kill_time_value(key_colnames(edf)) predictions <- predictions %>% - dplyr::mutate( + mutate( .facets = interaction(!!!rlang::syms(as.list(ek)), sep = "/"), ) if (.max_facets < Inf) { top_n <- levels(as.factor(bp$data$.facets))[seq_len(.max_facets)] - predictions <- dplyr::filter(predictions, .facets %in% top_n) %>% - dplyr::mutate(.facets = droplevels(.facets)) + predictions <- filter(predictions, .facets %in% top_n) %>% + mutate(.facets = droplevels(.facets)) } @@ -182,17 +184,17 @@ autoplot.epi_workflow <- function( } if (".pred" %in% names(predictions)) { - ntarget_dates <- dplyr::n_distinct(predictions$time_value) + ntarget_dates <- n_distinct(predictions$time_value) if (ntarget_dates > 1L) { bp <- bp + - ggplot2::geom_line( - data = predictions, ggplot2::aes(y = .data$.pred), + geom_line( + data = predictions, aes(y = .data$.pred), color = .point_pred_color ) } else { bp <- bp + - ggplot2::geom_point( - data = predictions, ggplot2::aes(y = .data$.pred), + geom_point( + data = predictions, aes(y = .data$.pred), color = .point_pred_color ) } @@ -243,7 +245,7 @@ plot_bands <- function( ntarget_dates <- dplyr::n_distinct(predictions$time_value) predictions <- predictions %>% - dplyr::mutate(.pred_distn = dist_quantiles(quantile(.pred_distn, l), l)) %>% + mutate(.pred_distn = dist_quantiles(quantile(.pred_distn, l), l)) %>% pivot_quantiles_wider(.pred_distn) qnames <- setdiff(names(predictions), innames) @@ -253,32 +255,32 @@ plot_bands <- function( if (i == 1) { if (ntarget_dates > 1L) { base_plot <- base_plot + - ggplot2::geom_ribbon( + geom_ribbon( data = predictions, - ggplot2::aes(ymin = .data[[bottom]], ymax = .data[[top]]), + aes(ymin = .data[[bottom]], ymax = .data[[top]]), alpha = 0.2, linewidth = linewidth, fill = fill ) } else { base_plot <- base_plot + - ggplot2::geom_linerange( + geom_linerange( data = predictions, - ggplot2::aes(ymin = .data[[bottom]], ymax = .data[[top]]), + aes(ymin = .data[[bottom]], ymax = .data[[top]]), alpha = 0.2, linewidth = 2, color = fill ) } } else { if (ntarget_dates > 1L) { base_plot <- base_plot + - ggplot2::geom_ribbon( + geom_ribbon( data = predictions, - ggplot2::aes(ymin = .data[[bottom]], ymax = .data[[top]]), + aes(ymin = .data[[bottom]], ymax = .data[[top]]), fill = fill, alpha = alpha ) } else { base_plot <- base_plot + - ggplot2::geom_linerange( + geom_linerange( data = predictions, - ggplot2::aes(ymin = .data[[bottom]], ymax = .data[[top]]), + aes(ymin = .data[[bottom]], ymax = .data[[top]]), color = fill, alpha = alpha, linewidth = 2 ) } diff --git a/R/canned-epipred.R b/R/canned-epipred.R index 5a87a0d2e..0adc0536a 100644 --- a/R/canned-epipred.R +++ b/R/canned-epipred.R @@ -1,6 +1,6 @@ validate_forecaster_inputs <- function(epi_data, outcome, predictors) { - if (!epiprocess::is_epi_df(epi_data)) { - cli::cli_abort(c( + if (!is_epi_df(epi_data)) { + cli_abort(c( "`epi_data` must be an {.cls epi_df}.", "!" = "This one is a {.cls {class(epi_data)}}." )) @@ -8,11 +8,11 @@ validate_forecaster_inputs <- function(epi_data, outcome, predictors) { arg_is_chr_scalar(outcome) arg_is_chr(predictors) if (!outcome %in% names(epi_data)) { - cli::cli_abort("{.var {outcome}} was not found in the training data.") + cli_abort("{.var {outcome}} was not found in the training data.") } check <- hardhat::check_column_names(epi_data, predictors) if (!check$ok) { - cli::cli_abort(c( + cli_abort(c( "At least one predictor was not found in the training data.", "!" = "The following required columns are missing: {.val {check$missing_names}}." )) @@ -29,7 +29,7 @@ arx_lags_validator <- function(predictors, lags) { if (l == 1) { lags <- rep(lags, p) } else if (length(lags) != p) { - cli::cli_abort(c( + cli_abort(c( "You have requested {p} predictor(s) but {l} different lags.", i = "Lags must be a vector or a list with length == number of predictors." )) @@ -39,7 +39,7 @@ arx_lags_validator <- function(predictors, lags) { lags <- lags[order(match(names(lags), predictors))] } else { predictors_miss <- setdiff(predictors, names(lags)) - cli::cli_abort(c( + cli_abort(c( "If lags is a named list, then all predictors must be present.", i = "The predictors are {.var {predictors}}.", i = "So lags is missing {.var {predictors_miss}}'." diff --git a/R/cdc_baseline_forecaster.R b/R/cdc_baseline_forecaster.R index 31194daae..21b5e8ece 100644 --- a/R/cdc_baseline_forecaster.R +++ b/R/cdc_baseline_forecaster.R @@ -98,14 +98,14 @@ cdc_baseline_forecaster <- function( # layer_add_target_date(target_date = target_date) if (args_list$nonneg) f <- layer_threshold(f, ".pred") - eng <- parsnip::linear_reg() %>% parsnip::set_engine("flatline") + eng <- linear_reg(engine = "flatline") wf <- epi_workflow(r, eng, f) - wf <- generics::fit(wf, epi_data) + wf <- fit(wf, epi_data) preds <- suppressWarnings(predict(wf, new_data = latest)) %>% - tibble::as_tibble() %>% - dplyr::select(-time_value) %>% - dplyr::mutate(target_date = forecast_date + ahead * args_list$data_frequency) + as_tibble() %>% + select(-time_value) %>% + mutate(target_date = forecast_date + ahead * args_list$data_frequency) structure( list( @@ -218,11 +218,11 @@ parse_period <- function(x) { mult <- switch(mult, day = 1L, wee = 7L, - cli::cli_abort("incompatible timespan in `aheads`.") + cli_abort("incompatible timespan in `aheads`.") ) x <- as.numeric(x[1]) * mult } - if (length(x) > 2L) cli::cli_abort("incompatible timespan in `aheads`.") + if (length(x) > 2L) cli_abort("incompatible timespan in `aheads`.") } stopifnot(rlang::is_integerish(x)) as.integer(x) diff --git a/R/check_enough_train_data.R b/R/check_enough_train_data.R index af2183d15..1279a3712 100644 --- a/R/check_enough_train_data.R +++ b/R/check_enough_train_data.R @@ -49,13 +49,13 @@ check_enough_train_data <- columns = NULL, skip = TRUE, id = rand_id("enough_train_data")) { - add_check( + recipes::add_check( recipe, check_enough_train_data_new( n = n, epi_keys = epi_keys, drop_na = drop_na, - terms = rlang::enquos(...), + terms = enquos(...), role = role, trained = trained, columns = columns, @@ -67,7 +67,7 @@ check_enough_train_data <- check_enough_train_data_new <- function(n, epi_keys, drop_na, terms, role, trained, columns, skip, id) { - check( + recipes::check( subclass = "enough_train_data", prefix = "check_", n = n, @@ -83,30 +83,24 @@ check_enough_train_data_new <- } #' @export -#' @importFrom dplyr group_by summarise ungroup across all_of n -#' @importFrom tidyr drop_na prep.check_enough_train_data <- function(x, training, info = NULL, ...) { - col_names <- recipes_eval_select(x$terms, training, info) + col_names <- recipes::recipes_eval_select(x$terms, training, info) if (is.null(x$n)) { x$n <- length(col_names) } + if (x$drop_na) { + training <- tidyr::drop_na(training) + } cols_not_enough_data <- training %>% - { - if (x$drop_na) { - drop_na(.) - } else { - . - } - } %>% group_by(across(all_of(.env$x$epi_keys))) %>% - summarise(across(all_of(.env$col_names), ~ n() < .env$x$n), .groups = "drop") %>% + summarise(across(all_of(.env$col_names), ~ dplyr::n() < .env$x$n), .groups = "drop") %>% summarise(across(all_of(.env$col_names), any), .groups = "drop") %>% unlist() %>% names(.)[.] if (length(cols_not_enough_data) > 0) { - cli::cli_abort( + cli_abort( "The following columns don't have enough data to predict: {cols_not_enough_data}." ) } @@ -132,16 +126,16 @@ bake.check_enough_train_data <- function(object, new_data, ...) { #' @export print.check_enough_train_data <- function(x, width = max(20, options()$width - 30), ...) { title <- paste0("Check enough data (n = ", x$n, ") for ") - print_step(x$columns, x$terms, x$trained, title, width) + recipes::print_step(x$columns, x$terms, x$trained, title, width) invisible(x) } #' @export tidy.check_enough_train_data <- function(x, ...) { - if (is_trained(x)) { + if (recipes::is_trained(x)) { res <- tibble(terms = unname(x$columns)) } else { - res <- tibble(terms = sel2char(x$terms)) + res <- tibble(terms = recipes::sel2char(x$terms)) } res$id <- x$id res$n <- x$n diff --git a/R/epi_check_training_set.R b/R/epi_check_training_set.R index 0c7dc9036..596e99887 100644 --- a/R/epi_check_training_set.R +++ b/R/epi_check_training_set.R @@ -16,7 +16,7 @@ epi_check_training_set <- function(x, rec) { if (!is.null(old_ok)) { if (all(old_ok %in% colnames(x))) { # case 1 if (!all(old_ok %in% new_ok)) { - cli::cli_warn(c( + cli_warn(c( "The recipe specifies additional keys. Because these are available,", "they are being added to the metadata of the training data." )) @@ -25,7 +25,7 @@ epi_check_training_set <- function(x, rec) { } missing_ok <- setdiff(old_ok, colnames(x)) if (length(missing_ok) > 0) { # case 2 - cli::cli_abort(c( + cli_abort(c( "The recipe specifies keys which are not in the training data.", i = "The training set is missing columns for {missing_ok}." )) @@ -45,8 +45,8 @@ validate_meta_match <- function(x, template, meta, warn_or_abort = "warn") { ) if (new_meta != old_meta) { switch(warn_or_abort, - warn = cli::cli_warn(msg), - abort = cli::cli_abort(msg) + warn = cli_warn(msg), + abort = cli_abort(msg) ) } } diff --git a/R/epi_recipe.R b/R/epi_recipe.R index 311decf62..c40fda019 100644 --- a/R/epi_recipe.R +++ b/R/epi_recipe.R @@ -46,17 +46,19 @@ epi_recipe.default <- function(x, ...) { #' #' @export #' @examples +#' library(dplyr) +#' library(recipes) #' jhu <- case_death_rate_subset %>% -#' dplyr::filter(time_value > "2021-08-01") %>% -#' dplyr::arrange(geo_value, time_value) +#' filter(time_value > "2021-08-01") %>% +#' arrange(geo_value, time_value) #' #' r <- epi_recipe(jhu) %>% #' step_epi_lag(death_rate, lag = c(0, 7, 14)) %>% #' step_epi_ahead(death_rate, ahead = 7) %>% #' step_epi_lag(case_rate, lag = c(0, 7, 14)) %>% -#' recipes::step_naomit(recipes::all_predictors()) %>% +#' step_naomit(all_predictors()) %>% #' # below, `skip` means we don't do this at predict time -#' recipes::step_naomit(recipes::all_outcomes(), skip = TRUE) +#' step_naomit(all_outcomes(), skip = TRUE) #' #' r epi_recipe.epi_df <- @@ -501,7 +503,7 @@ prep.epi_recipe <- function( as_epi_df(training), before_template ) } - training <- dplyr::relocate(training, dplyr::all_of(key_colnames(training))) + training <- dplyr::relocate(training, all_of(key_colnames(training))) x$term_info <- recipes:::merge_term_info(get_types(training), x$term_info) if (!is.na(x$steps[[i]]$role)) { new_vars <- setdiff(x$term_info$variable, running_info$variable) diff --git a/R/epi_shift.R b/R/epi_shift.R index b40b36ecc..eb534f1ea 100644 --- a/R/epi_shift.R +++ b/R/epi_shift.R @@ -17,9 +17,9 @@ epi_shift <- function(x, shifts, time_value, keys = NULL, out_name = "x") { if (!is.data.frame(x)) x <- data.frame(x) if (is.null(keys)) keys <- rep("empty", nrow(x)) p_in <- ncol(x) - out_list <- tibble::tibble(i = 1:p_in, shift = shifts) %>% + out_list <- tibble(i = 1:p_in, shift = shifts) %>% tidyr::unchop(shift) %>% # what is chop - dplyr::mutate(name = paste0(out_name, 1:nrow(.))) %>% + mutate(name = paste0(out_name, 1:nrow(.))) %>% # One list element for each shifted feature pmap(function(i, shift, name) { tibble(keys, @@ -38,7 +38,7 @@ epi_shift <- function(x, shifts, time_value, keys = NULL, out_name = "x") { epi_shift_single <- function(x, col, shift_val, newname, key_cols) { x %>% - dplyr::select(tidyselect::all_of(c(key_cols, col))) %>% - dplyr::mutate(time_value = time_value + shift_val) %>% - dplyr::rename(!!newname := {{ col }}) + select(all_of(c(key_cols, col))) %>% + mutate(time_value = time_value + shift_val) %>% + rename(!!newname := {{ col }}) } diff --git a/R/epi_workflow.R b/R/epi_workflow.R index 3660b87e1..f715dc9b0 100644 --- a/R/epi_workflow.R +++ b/R/epi_workflow.R @@ -187,24 +187,20 @@ augment.epi_workflow <- function(x, new_data, ...) { if (is_epi_df(predictions)) { join_by <- key_colnames(predictions) } else { - rlang::abort( - c( - "Cannot determine how to join new_data with the predictions.", - "Try converting new_data to an epi_df with `as_epi_df(new_data)`." - ) - ) + cli_abort(c( + "Cannot determine how to join new_data with the predictions.", + "Try converting new_data to an epi_df with `as_epi_df(new_data)`." + )) } complete_overlap <- intersect(names(new_data), join_by) if (length(complete_overlap) < length(join_by)) { - rlang::warn( - glue::glue( - "Your original training data had keys {join_by}, but", - "`new_data` only has {complete_overlap}. The output", - "may be strange." - ) - ) + rlang::warn(glue::glue( + "Your original training data had keys {join_by}, but", + "`new_data` only has {complete_overlap}. The output", + "may be strange." + )) } - dplyr::full_join(predictions, new_data, by = join_by) + full_join(predictions, new_data, by = join_by) } new_epi_workflow <- function( diff --git a/R/epipredict-package.R b/R/epipredict-package.R index d3c7a8a4a..785c183b7 100644 --- a/R/epipredict-package.R +++ b/R/epipredict-package.R @@ -1,14 +1,16 @@ ## usethis namespace: start -#' @importFrom tibble tibble +#' @importFrom tibble tibble as_tibble #' @importFrom rlang := !! %||% as_function global_env set_names !!! -#' @importFrom rlang is_logical is_true inject enquo enquos expr +#' is_logical is_true inject enquo enquos expr sym arg_match #' @importFrom stats poly predict lm residuals quantile -#' @importFrom dplyr arrange across all_of any_of bind_rows group_by summarise -#' filter mutate select left_join -#' @importFrom cli cli_abort +#' @importFrom dplyr arrange across all_of any_of bind_cols bind_rows group_by +#' summarise filter mutate select left_join rename ungroup full_join +#' relocate +#' @importFrom cli cli_abort cli_warn #' @importFrom checkmate assert assert_character assert_int assert_scalar #' assert_logical assert_numeric assert_number assert_integer #' assert_integerish assert_date assert_function assert_class +#' @importFrom generics fit #' @import epiprocess parsnip ## usethis namespace: end NULL diff --git a/R/flatline.R b/R/flatline.R index 0f98b0e2b..fb60c920d 100644 --- a/R/flatline.R +++ b/R/flatline.R @@ -43,26 +43,26 @@ flatline <- function(formula, data) { observed <- rhs[n] # DANGER!! ek <- rhs[-n] if (length(response) > 1) { - cli_stop("flatline forecaster can accept only 1 observed time series.") + cli_abort("flatline forecaster can accept only 1 observed time series.") } keys <- kill_time_value(ek) preds <- data %>% - dplyr::mutate( + mutate( .pred = !!rlang::sym(observed), .resid = !!rlang::sym(response) - .pred ) .pred <- preds %>% - dplyr::filter(!is.na(.pred)) %>% - dplyr::group_by(!!!rlang::syms(keys)) %>% - dplyr::arrange(time_value) %>% + filter(!is.na(.pred)) %>% + group_by(!!!rlang::syms(keys)) %>% + arrange(time_value) %>% dplyr::slice_tail(n = 1L) %>% - dplyr::ungroup() %>% - dplyr::select(tidyselect::all_of(c(keys, ".pred"))) + ungroup() %>% + select(all_of(c(keys, ".pred"))) structure( list( - residuals = dplyr::select(preds, dplyr::all_of(c(keys, ".resid"))), + residuals = select(preds, all_of(c(keys, ".resid"))), .pred = .pred ), class = "flatline" @@ -80,14 +80,13 @@ predict.flatline <- function(object, newdata, ...) { metadata <- names(object)[names(object) != ".pred"] ek <- names(newdata) if (!all(metadata %in% ek)) { - cli_stop( + cli_abort(c( "`newdata` has different metadata than was used", "to fit the flatline forecaster" - ) + )) } - dplyr::left_join(newdata, object, by = metadata) %>% - dplyr::pull(.pred) + left_join(newdata, object, by = metadata)$.pred } #' @export diff --git a/R/flatline_forecaster.R b/R/flatline_forecaster.R index 42970c569..50cccc908 100644 --- a/R/flatline_forecaster.R +++ b/R/flatline_forecaster.R @@ -61,18 +61,18 @@ flatline_forecaster <- function( layer_add_target_date(target_date = target_date) if (args_list$nonneg) f <- layer_threshold(f, dplyr::starts_with(".pred")) - eng <- parsnip::linear_reg() %>% parsnip::set_engine("flatline") + eng <- linear_reg(engine = "flatline") wf <- epi_workflow(r, eng, f) - wf <- generics::fit(wf, epi_data) + wf <- fit(wf, epi_data) preds <- suppressWarnings(forecast( wf, fill_locf = TRUE, n_recent = args_list$nafill_buffer, forecast_date = forecast_date )) %>% - tibble::as_tibble() %>% - dplyr::select(-time_value) + as_tibble() %>% + select(-time_value) structure( list( diff --git a/R/flusight_hub_formatter.R b/R/flusight_hub_formatter.R index e45f88189..c91f738ae 100644 --- a/R/flusight_hub_formatter.R +++ b/R/flusight_hub_formatter.R @@ -56,27 +56,26 @@ abbr_to_location <- function(abbr) { #' @export #' #' @examples -#' if (require(dplyr)) { -#' weekly_deaths <- case_death_rate_subset %>% -#' filter( -#' time_value >= as.Date("2021-09-01"), -#' geo_value %in% c("ca", "ny", "dc", "ga", "vt") -#' ) %>% -#' select(geo_value, time_value, death_rate) %>% -#' left_join(state_census %>% select(pop, abbr), by = c("geo_value" = "abbr")) %>% -#' mutate(deaths = pmax(death_rate / 1e5 * pop * 7, 0)) %>% -#' select(-pop, -death_rate) %>% -#' group_by(geo_value) %>% -#' epi_slide(~ sum(.$deaths), before = 6, new_col_name = "deaths") %>% -#' ungroup() %>% -#' filter(weekdays(time_value) == "Saturday") +#' library(dplyr) +#' weekly_deaths <- case_death_rate_subset %>% +#' filter( +#' time_value >= as.Date("2021-09-01"), +#' geo_value %in% c("ca", "ny", "dc", "ga", "vt") +#' ) %>% +#' select(geo_value, time_value, death_rate) %>% +#' left_join(state_census %>% select(pop, abbr), by = c("geo_value" = "abbr")) %>% +#' mutate(deaths = pmax(death_rate / 1e5 * pop * 7, 0)) %>% +#' select(-pop, -death_rate) %>% +#' group_by(geo_value) %>% +#' epi_slide(~ sum(.$deaths), before = 6, new_col_name = "deaths") %>% +#' ungroup() %>% +#' filter(weekdays(time_value) == "Saturday") #' -#' cdc <- cdc_baseline_forecaster(weekly_deaths, "deaths") -#' flusight_hub_formatter(cdc) -#' flusight_hub_formatter(cdc, target = "wk inc covid deaths") -#' flusight_hub_formatter(cdc, target = paste(horizon, "wk inc covid deaths")) -#' flusight_hub_formatter(cdc, target = "wk inc covid deaths", output_type = "quantile") -#' } +#' cdc <- cdc_baseline_forecaster(weekly_deaths, "deaths") +#' flusight_hub_formatter(cdc) +#' flusight_hub_formatter(cdc, target = "wk inc covid deaths") +#' flusight_hub_formatter(cdc, target = paste(horizon, "wk inc covid deaths")) +#' flusight_hub_formatter(cdc, target = "wk inc covid deaths", output_type = "quantile") flusight_hub_formatter <- function( object, ..., .fcast_period = c("daily", "weekly")) { @@ -98,7 +97,7 @@ flusight_hub_formatter.data.frame <- function( optional_names <- c("ahead", "target_date") hardhat::validate_column_names(object, required_names) if (!any(optional_names %in% names(object))) { - cli::cli_abort("At least one of {.val {optional_names}} must be present.") + cli_abort("At least one of {.val {optional_names}} must be present.") } dots <- enquos(..., .named = TRUE) @@ -106,38 +105,38 @@ flusight_hub_formatter.data.frame <- function( object <- object %>% # combine the predictions and the distribution - dplyr::mutate(.pred_distn = nested_quantiles(.pred_distn)) %>% + mutate(.pred_distn = nested_quantiles(.pred_distn)) %>% tidyr::unnest(.pred_distn) %>% # now we create the correct column names - dplyr::rename( + rename( value = values, output_type_id = quantile_levels, reference_date = forecast_date ) %>% # convert to fips codes, and add any constant cols passed in ... - dplyr::mutate(location = abbr_to_location(tolower(geo_value)), geo_value = NULL) + mutate(location = abbr_to_location(tolower(geo_value)), geo_value = NULL) # create target_end_date / horizon, depending on what is available pp <- ifelse(match.arg(.fcast_period) == "daily", 1L, 7L) has_ahead <- charmatch("ahead", names(object)) if ("target_date" %in% names(object) && !is.na(has_ahead)) { object <- object %>% - dplyr::rename( + rename( target_end_date = target_date, horizon = !!names(object)[has_ahead] ) } else if (!is.na(has_ahead)) { # ahead present, not target date object <- object %>% - dplyr::rename(horizon = !!names(object)[has_ahead]) %>% - dplyr::mutate(target_end_date = horizon * pp + reference_date) + rename(horizon = !!names(object)[has_ahead]) %>% + mutate(target_end_date = horizon * pp + reference_date) } else { # target_date present, not ahead object <- object %>% - dplyr::rename(target_end_date = target_date) %>% - dplyr::mutate(horizon = as.integer((target_end_date - reference_date)) / pp) + rename(target_end_date = target_date) %>% + mutate(horizon = as.integer((target_end_date - reference_date)) / pp) } object %>% - dplyr::relocate( + relocate( reference_date, horizon, target_end_date, location, output_type_id, value ) %>% - dplyr::mutate(!!!dots) + mutate(!!!dots) } diff --git a/R/frosting.R b/R/frosting.R index 4fc0caec3..8474edbdf 100644 --- a/R/frosting.R +++ b/R/frosting.R @@ -8,15 +8,16 @@ #' @export #' #' @examples +#' library(dplyr) #' jhu <- case_death_rate_subset %>% #' filter(time_value > "2021-11-01", geo_value %in% c("ak", "ca", "ny")) #' r <- epi_recipe(jhu) %>% #' step_epi_lag(death_rate, lag = c(0, 7, 14)) %>% #' step_epi_ahead(death_rate, ahead = 7) #' -#' wf <- epi_workflow(r, parsnip::linear_reg()) %>% fit(jhu) +#' wf <- epi_workflow(r, linear_reg()) %>% fit(jhu) #' latest <- jhu %>% -#' dplyr::filter(time_value >= max(time_value) - 14) +#' filter(time_value >= max(time_value) - 14) #' #' # Add frosting to a workflow and predict #' f <- frosting() %>% @@ -84,7 +85,8 @@ validate_has_postprocessor <- function(x, ..., call = caller_env()) { rlang::check_dots_empty() has_postprocessor <- has_postprocessor_frosting(x) if (!has_postprocessor) { - message <- c("The workflow must have a frosting postprocessor.", + message <- c( + "The workflow must have a frosting postprocessor.", i = "Provide one with `add_frosting()`." ) rlang::abort(message, call = call) @@ -125,6 +127,7 @@ update_frosting <- function(x, frosting, ...) { #' #' @export #' @examples +#' library(dplyr) #' jhu <- case_death_rate_subset %>% #' filter(time_value > "2021-11-01", geo_value %in% c("ak", "ca", "ny")) #' r <- epi_recipe(jhu) %>% @@ -132,7 +135,7 @@ update_frosting <- function(x, frosting, ...) { #' step_epi_ahead(death_rate, ahead = 7) %>% #' step_epi_naomit() #' -#' wf <- epi_workflow(r, parsnip::linear_reg()) %>% fit(jhu) +#' wf <- epi_workflow(r, linear_reg()) %>% fit(jhu) #' #' # in the frosting from the workflow #' f1 <- frosting() %>% @@ -177,11 +180,10 @@ adjust_frosting.epi_workflow <- function( adjust_frosting.frosting <- function( x, which_layer, ...) { if (!(is.numeric(which_layer) || is.character(which_layer))) { - cli::cli_abort( - c("`which_layer` must be a number or a character.", - i = "`which_layer` has class {.cls {class(which_layer)[1]}}." - ) - ) + cli_abort(c( + "`which_layer` must be a number or a character.", + i = "`which_layer` has class {.cls {class(which_layer)[1]}}." + )) } else if (is.numeric(which_layer)) { x$layers[[which_layer]] <- update(x$layers[[which_layer]], ...) } else { @@ -190,7 +192,7 @@ adjust_frosting.frosting <- function( if (!starts_with_layer) which_layer <- paste0("layer_", which_layer) if (!(which_layer %in% layer_names)) { - cli::cli_abort(c( + cli_abort(c( "`which_layer` does not appear in the available `frosting` layer names. ", i = "The layer names are {.val {layer_names}}." )) @@ -199,7 +201,7 @@ adjust_frosting.frosting <- function( if (length(which_layer_idx) == 1) { x$layers[[which_layer_idx]] <- update(x$layers[[which_layer_idx]], ...) } else { - cli::cli_abort(c( + cli_abort(c( "`which_layer` is not unique. Matches layers: {.val {which_layer_idx}}.", i = "Please use the layer number instead for precise alterations." )) @@ -216,7 +218,7 @@ add_postprocessor <- function(x, postprocessor, ..., call = caller_env()) { if (is_frosting(postprocessor)) { return(add_frosting(x, postprocessor)) } - cli::cli_abort("`postprocessor` must be a frosting object.", call = call) + cli_abort("`postprocessor` must be a frosting object.", call = call) } is_frosting <- function(x) { @@ -227,7 +229,7 @@ is_frosting <- function(x) { validate_frosting <- function(x, ..., arg = "`x`", call = caller_env()) { rlang::check_dots_empty() if (!is_frosting(x)) { - cli::cli_abort( + cli_abort( "{arg} must be a frosting postprocessor, not a {.cls {class(x)[[1]]}}.", .call = call ) @@ -260,14 +262,14 @@ new_frosting <- function() { #' @export #' #' @examples -#' +#' library(dplyr) #' # Toy example to show that frosting can be created and added for postprocessing #' f <- frosting() #' wf <- epi_workflow() %>% add_frosting(f) #' #' # A more realistic example #' jhu <- case_death_rate_subset %>% -#' dplyr::filter(time_value > "2021-11-01", geo_value %in% c("ak", "ca", "ny")) +#' filter(time_value > "2021-11-01", geo_value %in% c("ak", "ca", "ny")) #' #' r <- epi_recipe(jhu) %>% #' step_epi_lag(death_rate, lag = c(0, 7, 14)) %>% @@ -307,7 +309,7 @@ extract_frosting <- function(x, ...) { #' @export extract_frosting.default <- function(x, ...) { - cli::cli_abort(c( + cli_abort(c( "Frosting is only available for epi_workflows currently.", i = "Can you use `epi_workflow()` instead of `workflow()`?" )) @@ -319,7 +321,7 @@ extract_frosting.epi_workflow <- function(x, ...) { if (has_postprocessor_frosting(x)) { return(x$post$actions$frosting$frosting) } else { - cli_stop("The epi_workflow does not have a postprocessor.") + cli_abort("The epi_workflow does not have a postprocessor.") } } @@ -342,7 +344,7 @@ apply_frosting <- function(workflow, ...) { #' @export apply_frosting.default <- function(workflow, components, ...) { if (has_postprocessor(workflow)) { - cli::cli_abort(c( + cli_abort(c( "Postprocessing is only available for epi_workflows currently.", i = "Can you use `epi_workflow()` instead of `workflow()`?" )) @@ -373,7 +375,7 @@ apply_frosting.epi_workflow <- } if (!has_postprocessor_frosting(workflow)) { - cli::cli_warn(c( + cli_warn(c( "Only postprocessors of class {.cls frosting} are allowed.", "Returning unpostprocessed predictions." )) diff --git a/R/step_epi_naomit.R b/R/step_epi_naomit.R index 1cbc9c5d9..d81ba398d 100644 --- a/R/step_epi_naomit.R +++ b/R/step_epi_naomit.R @@ -22,6 +22,6 @@ step_epi_naomit <- function(recipe) { print.step_naomit <- # not exported from recipes package function(x, width = max(20, options()$width - 30), ...) { title <- "Removing rows with NA values in " - print_step(x$columns, x$terms, x$trained, title, width) + recipes::print_step(x$columns, x$terms, x$trained, title, width) invisible(x) } diff --git a/man/add_frosting.Rd b/man/add_frosting.Rd index 161a540e2..94812cbe2 100644 --- a/man/add_frosting.Rd +++ b/man/add_frosting.Rd @@ -26,15 +26,16 @@ update_frosting(x, frosting, ...) Add frosting to a workflow } \examples{ +library(dplyr) jhu <- case_death_rate_subset \%>\% filter(time_value > "2021-11-01", geo_value \%in\% c("ak", "ca", "ny")) r <- epi_recipe(jhu) \%>\% step_epi_lag(death_rate, lag = c(0, 7, 14)) \%>\% step_epi_ahead(death_rate, ahead = 7) -wf <- epi_workflow(r, parsnip::linear_reg()) \%>\% fit(jhu) +wf <- epi_workflow(r, linear_reg()) \%>\% fit(jhu) latest <- jhu \%>\% - dplyr::filter(time_value >= max(time_value) - 14) + filter(time_value >= max(time_value) - 14) # Add frosting to a workflow and predict f <- frosting() \%>\% diff --git a/man/adjust_frosting.Rd b/man/adjust_frosting.Rd index 6cdc13b30..c089b3443 100644 --- a/man/adjust_frosting.Rd +++ b/man/adjust_frosting.Rd @@ -35,6 +35,7 @@ must be inputted as \code{...}. See the examples below for brief illustrations of the different types of updates. } \examples{ +library(dplyr) jhu <- case_death_rate_subset \%>\% filter(time_value > "2021-11-01", geo_value \%in\% c("ak", "ca", "ny")) r <- epi_recipe(jhu) \%>\% @@ -42,7 +43,7 @@ r <- epi_recipe(jhu) \%>\% step_epi_ahead(death_rate, ahead = 7) \%>\% step_epi_naomit() -wf <- epi_workflow(r, parsnip::linear_reg()) \%>\% fit(jhu) +wf <- epi_workflow(r, linear_reg()) \%>\% fit(jhu) # in the frosting from the workflow f1 <- frosting() \%>\% diff --git a/man/arx_class_epi_workflow.Rd b/man/arx_class_epi_workflow.Rd index bfce7cdaa..713365f17 100644 --- a/man/arx_class_epi_workflow.Rd +++ b/man/arx_class_epi_workflow.Rd @@ -47,9 +47,9 @@ before fitting and predicting. Supplying a trainer to the function may alter the returned \code{epi_workflow} object but can be omitted. } \examples{ - +library(dplyr) jhu <- case_death_rate_subset \%>\% - dplyr::filter(time_value >= as.Date("2021-11-01")) + filter(time_value >= as.Date("2021-11-01")) arx_class_epi_workflow(jhu, "death_rate", c("case_rate", "death_rate")) @@ -57,7 +57,7 @@ arx_class_epi_workflow( jhu, "death_rate", c("case_rate", "death_rate"), - trainer = parsnip::multinom_reg(), + trainer = multinom_reg(), args_list = arx_class_args_list( breaks = c(-.05, .1), ahead = 14, horizon = 14, method = "linear_reg" diff --git a/man/arx_classifier.Rd b/man/arx_classifier.Rd index 350352ae9..36297b00c 100644 --- a/man/arx_classifier.Rd +++ b/man/arx_classifier.Rd @@ -8,7 +8,7 @@ arx_classifier( epi_data, outcome, predictors, - trainer = parsnip::logistic_reg(), + trainer = logistic_reg(), args_list = arx_class_args_list() ) } diff --git a/man/arx_fcast_epi_workflow.Rd b/man/arx_fcast_epi_workflow.Rd index 4ed279351..4070a3337 100644 --- a/man/arx_fcast_epi_workflow.Rd +++ b/man/arx_fcast_epi_workflow.Rd @@ -8,7 +8,7 @@ arx_fcast_epi_workflow( epi_data, outcome, predictors = outcome, - trainer = parsnip::linear_reg(), + trainer = linear_reg(), args_list = arx_args_list() ) } @@ -42,8 +42,9 @@ may alter the returned \code{epi_workflow} object (e.g., if you intend to use \code{\link[=quantile_reg]{quantile_reg()}}) but can be omitted. } \examples{ +library(dplyr) jhu <- case_death_rate_subset \%>\% - dplyr::filter(time_value >= as.Date("2021-12-01")) + filter(time_value >= as.Date("2021-12-01")) arx_fcast_epi_workflow( jhu, "death_rate", diff --git a/man/arx_forecaster.Rd b/man/arx_forecaster.Rd index af05c0682..d8c7671dc 100644 --- a/man/arx_forecaster.Rd +++ b/man/arx_forecaster.Rd @@ -8,7 +8,7 @@ arx_forecaster( epi_data, outcome, predictors = outcome, - trainer = parsnip::linear_reg(), + trainer = linear_reg(), args_list = arx_args_list() ) } diff --git a/man/autoplot-epipred.Rd b/man/autoplot-epipred.Rd index dd6b37dcd..27bfdf5f7 100644 --- a/man/autoplot-epipred.Rd +++ b/man/autoplot-epipred.Rd @@ -70,6 +70,7 @@ will be shown as well. Unfit workflows will result in an error, (you can simply call \code{autoplot()} on the original \code{epi_df}). } \examples{ +library(dplyr) jhu <- case_death_rate_subset \%>\% filter(time_value >= as.Date("2021-11-01")) @@ -83,26 +84,26 @@ f <- frosting() \%>\% layer_residual_quantiles( quantile_levels = c(.025, .1, .25, .75, .9, .975) ) \%>\% - layer_threshold(dplyr::starts_with(".pred")) \%>\% + layer_threshold(starts_with(".pred")) \%>\% layer_add_target_date() -wf <- epi_workflow(r, parsnip::linear_reg(), f) \%>\% fit(jhu) +wf <- epi_workflow(r, linear_reg(), f) \%>\% fit(jhu) autoplot(wf) -latest <- jhu \%>\% dplyr::filter(time_value >= max(time_value) - 14) +latest <- jhu \%>\% filter(time_value >= max(time_value) - 14) preds <- predict(wf, latest) autoplot(wf, preds, .max_facets = 4) # ------- Show multiple horizons -p <- lapply(c(7, 14, 21, 28), \(h) { +p <- lapply(c(7, 14, 21, 28), function(h) { r <- epi_recipe(jhu) \%>\% step_epi_lag(death_rate, lag = c(0, 7, 14)) \%>\% step_epi_ahead(death_rate, ahead = h) \%>\% step_epi_lag(case_rate, lag = c(0, 7, 14)) \%>\% step_epi_naomit() - ewf <- epi_workflow(r, parsnip::linear_reg(), f) \%>\% fit(jhu) + ewf <- epi_workflow(r, linear_reg(), f) \%>\% fit(jhu) forecast(ewf) }) @@ -111,7 +112,8 @@ autoplot(wf, p, .max_facets = 4) # ------- Plotting canned forecaster output -jhu <- case_death_rate_subset \%>\% filter(time_value >= as.Date("2021-11-01")) +jhu <- case_death_rate_subset \%>\% + filter(time_value >= as.Date("2021-11-01")) flat <- flatline_forecaster(jhu, "death_rate") autoplot(flat, .max_facets = 4) diff --git a/man/epi_recipe.Rd b/man/epi_recipe.Rd index 1c9048a36..d0105d1ec 100644 --- a/man/epi_recipe.Rd +++ b/man/epi_recipe.Rd @@ -57,17 +57,19 @@ around \code{\link[recipes:recipe]{recipes::recipe()}} to properly handle the ad columns present in an \code{epi_df} } \examples{ +library(dplyr) +library(recipes) jhu <- case_death_rate_subset \%>\% - dplyr::filter(time_value > "2021-08-01") \%>\% - dplyr::arrange(geo_value, time_value) + filter(time_value > "2021-08-01") \%>\% + arrange(geo_value, time_value) r <- epi_recipe(jhu) \%>\% step_epi_lag(death_rate, lag = c(0, 7, 14)) \%>\% step_epi_ahead(death_rate, ahead = 7) \%>\% step_epi_lag(case_rate, lag = c(0, 7, 14)) \%>\% - recipes::step_naomit(recipes::all_predictors()) \%>\% + step_naomit(all_predictors()) \%>\% # below, `skip` means we don't do this at predict time - recipes::step_naomit(recipes::all_outcomes(), skip = TRUE) + step_naomit(all_outcomes(), skip = TRUE) r } diff --git a/man/flusight_hub_formatter.Rd b/man/flusight_hub_formatter.Rd index 87cfc7e4f..b43bc0ac2 100644 --- a/man/flusight_hub_formatter.Rd +++ b/man/flusight_hub_formatter.Rd @@ -41,25 +41,24 @@ be done via the \code{...} argument. See the examples below. The specific requir format for this forecast task is \href{https://github.com/cdcepi/FluSight-forecast-hub/blob/main/model-output/README.md}{here}. } \examples{ -if (require(dplyr)) { - weekly_deaths <- case_death_rate_subset \%>\% - filter( - time_value >= as.Date("2021-09-01"), - geo_value \%in\% c("ca", "ny", "dc", "ga", "vt") - ) \%>\% - select(geo_value, time_value, death_rate) \%>\% - left_join(state_census \%>\% select(pop, abbr), by = c("geo_value" = "abbr")) \%>\% - mutate(deaths = pmax(death_rate / 1e5 * pop * 7, 0)) \%>\% - select(-pop, -death_rate) \%>\% - group_by(geo_value) \%>\% - epi_slide(~ sum(.$deaths), before = 6, new_col_name = "deaths") \%>\% - ungroup() \%>\% - filter(weekdays(time_value) == "Saturday") +library(dplyr) +weekly_deaths <- case_death_rate_subset \%>\% + filter( + time_value >= as.Date("2021-09-01"), + geo_value \%in\% c("ca", "ny", "dc", "ga", "vt") + ) \%>\% + select(geo_value, time_value, death_rate) \%>\% + left_join(state_census \%>\% select(pop, abbr), by = c("geo_value" = "abbr")) \%>\% + mutate(deaths = pmax(death_rate / 1e5 * pop * 7, 0)) \%>\% + select(-pop, -death_rate) \%>\% + group_by(geo_value) \%>\% + epi_slide(~ sum(.$deaths), before = 6, new_col_name = "deaths") \%>\% + ungroup() \%>\% + filter(weekdays(time_value) == "Saturday") - cdc <- cdc_baseline_forecaster(weekly_deaths, "deaths") - flusight_hub_formatter(cdc) - flusight_hub_formatter(cdc, target = "wk inc covid deaths") - flusight_hub_formatter(cdc, target = paste(horizon, "wk inc covid deaths")) - flusight_hub_formatter(cdc, target = "wk inc covid deaths", output_type = "quantile") -} +cdc <- cdc_baseline_forecaster(weekly_deaths, "deaths") +flusight_hub_formatter(cdc) +flusight_hub_formatter(cdc, target = "wk inc covid deaths") +flusight_hub_formatter(cdc, target = paste(horizon, "wk inc covid deaths")) +flusight_hub_formatter(cdc, target = "wk inc covid deaths", output_type = "quantile") } diff --git a/man/frosting.Rd b/man/frosting.Rd index 367d132ec..a75f21b61 100644 --- a/man/frosting.Rd +++ b/man/frosting.Rd @@ -22,14 +22,14 @@ to hold steps for postprocessing predictions. The arguments are currently placeholders and must be NULL } \examples{ - +library(dplyr) # Toy example to show that frosting can be created and added for postprocessing f <- frosting() wf <- epi_workflow() \%>\% add_frosting(f) # A more realistic example jhu <- case_death_rate_subset \%>\% - dplyr::filter(time_value > "2021-11-01", geo_value \%in\% c("ak", "ca", "ny")) + filter(time_value > "2021-11-01", geo_value \%in\% c("ak", "ca", "ny")) r <- epi_recipe(jhu) \%>\% step_epi_lag(death_rate, lag = c(0, 7, 14)) \%>\% From cf3e89a0ce497693f81b71f989f14d571d1d0808 Mon Sep 17 00:00:00 2001 From: "Daniel J. McDonald" Date: Tue, 27 Aug 2024 11:05:46 -0700 Subject: [PATCH 04/16] remove unused function --- R/grab_names.R | 23 ----------------------- tests/testthat/test-grab_names.R | 8 -------- 2 files changed, 31 deletions(-) delete mode 100644 R/grab_names.R delete mode 100644 tests/testthat/test-grab_names.R diff --git a/R/grab_names.R b/R/grab_names.R deleted file mode 100644 index 7ff3ac77e..000000000 --- a/R/grab_names.R +++ /dev/null @@ -1,23 +0,0 @@ -#' Get the names from a data frame via tidy select -#' -#' Given a data.frame, use `` syntax to choose -#' some variables. Return the names of those variables -#' -#' As this is an internal function, no checks are performed. -#' -#' @param dat a data.frame -#' @param ... <[`tidy-select`][dplyr::dplyr_tidy_select]> One or more unquoted -#' expressions separated by commas. Variable names can be used as if they -#' were positions in the data frame, so expressions like `x:y` can -#' be used to select a range of variables. -#' -#' @export -#' @keywords internal -#' @return a character vector -#' @examples -#' df <- data.frame(a = 1, b = 2, cc = rep(NA, 3)) -#' grab_names(df, dplyr::starts_with("c")) -grab_names <- function(dat, ...) { - x <- rlang::expr(c(...)) - names(tidyselect::eval_select(x, dat)) -} diff --git a/tests/testthat/test-grab_names.R b/tests/testthat/test-grab_names.R deleted file mode 100644 index 6e0376f5a..000000000 --- a/tests/testthat/test-grab_names.R +++ /dev/null @@ -1,8 +0,0 @@ -df <- data.frame(b = 1, c = 2, ca = 3, cat = 4) - -test_that("Names are grabbed properly", { - expect_identical( - grab_names(df, dplyr::starts_with("ca")), - subset(names(df), startsWith(names(df), "ca")) - ) -}) From 78cd65e2d27ea284a682f271ba36db7b1f2fedfe Mon Sep 17 00:00:00 2001 From: "Daniel J. McDonald" Date: Tue, 27 Aug 2024 14:13:09 -0700 Subject: [PATCH 05/16] remove unused funs --- R/utils-cli.R | 23 ----------------------- man/grab_names.Rd | 31 ------------------------------- 2 files changed, 54 deletions(-) delete mode 100644 R/utils-cli.R delete mode 100644 man/grab_names.Rd diff --git a/R/utils-cli.R b/R/utils-cli.R deleted file mode 100644 index 3b1555941..000000000 --- a/R/utils-cli.R +++ /dev/null @@ -1,23 +0,0 @@ -# Modeled after / copied from rundel/ghclass -cli_glue <- function(..., .envir = parent.frame()) { - txt <- cli::cli_format_method(cli::cli_text(..., .envir = .envir)) - - # cli_format_method does wrapping which we dont want at this stage - # so glue things back together. - paste(txt, collapse = " ") -} - -cli_stop <- function(..., .envir = parent.frame()) { - text <- cli_glue(..., .envir = .envir) - stop(paste(text, collapse = "\n"), call. = FALSE) -} - -cli_warn <- function(..., .envir = parent.frame()) { - text <- cli_glue(..., .envir = .envir) - warning(paste(text, collapse = "\n"), call. = FALSE) -} - -#' @importFrom rlang caller_env -cat_line <- function(...) { - cat(paste0(..., collapse = "\n"), "\n", sep = "") -} diff --git a/man/grab_names.Rd b/man/grab_names.Rd deleted file mode 100644 index cee6b19dc..000000000 --- a/man/grab_names.Rd +++ /dev/null @@ -1,31 +0,0 @@ -% Generated by roxygen2: do not edit by hand -% Please edit documentation in R/grab_names.R -\name{grab_names} -\alias{grab_names} -\title{Get the names from a data frame via tidy select} -\usage{ -grab_names(dat, ...) -} -\arguments{ -\item{dat}{a data.frame} - -\item{...}{<\code{\link[dplyr:dplyr_tidy_select]{tidy-select}}> One or more unquoted -expressions separated by commas. Variable names can be used as if they -were positions in the data frame, so expressions like \code{x:y} can -be used to select a range of variables.} -} -\value{ -a character vector -} -\description{ -Given a data.frame, use \verb{} syntax to choose -some variables. Return the names of those variables -} -\details{ -As this is an internal function, no checks are performed. -} -\examples{ -df <- data.frame(a = 1, b = 2, cc = rep(NA, 3)) -grab_names(df, dplyr::starts_with("c")) -} -\keyword{internal} From 8d24aa17a1666b93638e8ee972bff7a5196bc5a1 Mon Sep 17 00:00:00 2001 From: "Daniel J. McDonald" Date: Tue, 27 Aug 2024 14:13:33 -0700 Subject: [PATCH 06/16] remove prefixed funs, simplify some redundancies --- NAMESPACE | 8 ++- R/cdc_baseline_forecaster.R | 2 +- R/epipredict-package.R | 7 +- R/flatline_forecaster.R | 2 +- R/get_test_data.R | 45 ++++++------ R/layer_add_forecast_date.R | 15 ++-- R/layer_add_target_date.R | 17 +++-- R/layer_cdc_flatline_quantiles.R | 39 +++++----- R/layer_naomit.R | 9 +-- R/layer_point_from_distn.R | 10 +-- R/layer_population_scaling.R | 41 ++++------- R/layer_predict.R | 7 +- R/layer_predictive_distn.R | 9 +-- R/layer_quantile_distn.R | 7 +- R/layer_residual_quantiles.R | 31 ++++---- R/layer_threshold_preds.R | 16 ++--- R/layer_unnest.R | 2 +- R/layers.R | 8 +-- R/model-methods.R | 8 +-- R/pivot_quantiles.R | 49 ++++++------- R/reexports-tidymodels.R | 12 ++++ R/step_epi_shift.R | 108 +++++++++++----------------- R/step_epi_slide.R | 11 ++- R/step_growth_rate.R | 85 +++++++++------------- R/step_lag_difference.R | 59 +++++++-------- R/step_population_scaling.R | 88 ++++++++--------------- R/step_training_window.R | 33 +++------ R/tidy.R | 25 +++---- R/time_types.R | 2 +- R/utils-misc.R | 15 ++-- man/Add_model.Rd | 8 +-- man/layer_add_forecast_date.Rd | 7 +- man/layer_add_target_date.Rd | 5 +- man/layer_cdc_flatline_quantiles.Rd | 9 +-- man/layer_naomit.Rd | 5 +- man/layer_point_from_distn.Rd | 6 +- man/layer_population_scaling.Rd | 9 +-- man/layer_predict.Rd | 3 +- man/layer_predictive_distn.Rd | 5 +- man/layer_quantile_distn.Rd | 3 +- man/layer_residual_quantiles.Rd | 15 ++-- man/layer_threshold.Rd | 6 +- man/nested_quantiles.Rd | 6 +- man/pivot_quantiles_longer.Rd | 4 +- man/pivot_quantiles_wider.Rd | 2 +- man/reexports.Rd | 9 ++- man/step_epi_shift.Rd | 12 +--- man/step_growth_rate.Rd | 22 ++---- man/step_lag_difference.Rd | 9 +-- man/step_population_scaling.Rd | 32 +++------ man/step_training_window.Rd | 18 ++--- man/tidy.frosting.Rd | 3 +- man/update.layer.Rd | 8 +-- 53 files changed, 429 insertions(+), 547 deletions(-) diff --git a/NAMESPACE b/NAMESPACE index d1bd615b4..8d08ea12e 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -167,7 +167,6 @@ export(flusight_hub_formatter) export(forecast) export(frosting) export(get_test_data) -export(grab_names) export(is_epi_recipe) export(is_epi_workflow) export(is_layer) @@ -191,6 +190,7 @@ export(pivot_quantiles_longer) export(pivot_quantiles_wider) export(prep) export(quantile_reg) +export(rand_id) export(remove_epi_recipe) export(remove_frosting) export(remove_model) @@ -204,6 +204,8 @@ export(step_growth_rate) export(step_lag_difference) export(step_population_scaling) export(step_training_window) +export(tibble) +export(tidy) export(update_epi_recipe) export(update_frosting) export(update_model) @@ -233,6 +235,7 @@ importFrom(dplyr,any_of) importFrom(dplyr,arrange) importFrom(dplyr,bind_cols) importFrom(dplyr,bind_rows) +importFrom(dplyr,everything) importFrom(dplyr,filter) importFrom(dplyr,full_join) importFrom(dplyr,group_by) @@ -242,12 +245,14 @@ importFrom(dplyr,relocate) importFrom(dplyr,rename) importFrom(dplyr,select) importFrom(dplyr,summarise) +importFrom(dplyr,summarize) importFrom(dplyr,ungroup) importFrom(epiprocess,epi_slide) importFrom(epiprocess,growth_rate) importFrom(generics,augment) importFrom(generics,fit) importFrom(generics,forecast) +importFrom(generics,tidy) importFrom(ggplot2,aes) importFrom(ggplot2,autoplot) importFrom(ggplot2,geom_line) @@ -259,6 +264,7 @@ importFrom(hardhat,run_mold) importFrom(magrittr,"%>%") importFrom(recipes,bake) importFrom(recipes,prep) +importFrom(recipes,rand_id) importFrom(rlang,"!!!") importFrom(rlang,"!!") importFrom(rlang,"%@%") diff --git a/R/cdc_baseline_forecaster.R b/R/cdc_baseline_forecaster.R index 21b5e8ece..74af5e443 100644 --- a/R/cdc_baseline_forecaster.R +++ b/R/cdc_baseline_forecaster.R @@ -61,7 +61,7 @@ cdc_baseline_forecaster <- function( args_list = cdc_baseline_args_list()) { validate_forecaster_inputs(epi_data, outcome, "time_value") if (!inherits(args_list, c("cdc_flat_fcast", "alist"))) { - cli_stop("args_list was not created using `cdc_baseline_args_list().") + cli_abort("`args_list` was not created using `cdc_baseline_args_list().") } keys <- key_colnames(epi_data) ek <- kill_time_value(keys) diff --git a/R/epipredict-package.R b/R/epipredict-package.R index 785c183b7..6460b65e4 100644 --- a/R/epipredict-package.R +++ b/R/epipredict-package.R @@ -1,16 +1,15 @@ ## usethis namespace: start -#' @importFrom tibble tibble as_tibble +#' @importFrom tibble as_tibble #' @importFrom rlang := !! %||% as_function global_env set_names !!! #' is_logical is_true inject enquo enquos expr sym arg_match #' @importFrom stats poly predict lm residuals quantile #' @importFrom dplyr arrange across all_of any_of bind_cols bind_rows group_by -#' summarise filter mutate select left_join rename ungroup full_join -#' relocate +#' summarize filter mutate select left_join rename ungroup full_join +#' relocate summarise everything #' @importFrom cli cli_abort cli_warn #' @importFrom checkmate assert assert_character assert_int assert_scalar #' assert_logical assert_numeric assert_number assert_integer #' assert_integerish assert_date assert_function assert_class -#' @importFrom generics fit #' @import epiprocess parsnip ## usethis namespace: end NULL diff --git a/R/flatline_forecaster.R b/R/flatline_forecaster.R index 50cccc908..55808b803 100644 --- a/R/flatline_forecaster.R +++ b/R/flatline_forecaster.R @@ -34,7 +34,7 @@ flatline_forecaster <- function( args_list = flatline_args_list()) { validate_forecaster_inputs(epi_data, outcome, "time_value") if (!inherits(args_list, c("flat_fcast", "alist"))) { - cli_stop("args_list was not created using `flatline_args_list().") + cli_abort("`args_list` was not created using `flatline_args_list().") } keys <- key_colnames(epi_data) ek <- kill_time_value(keys) diff --git a/R/get_test_data.R b/R/get_test_data.R index 88ecf4054..ff3a146ef 100644 --- a/R/get_test_data.R +++ b/R/get_test_data.R @@ -42,14 +42,13 @@ #' get_test_data(recipe = rec, x = case_death_rate_subset) #' @importFrom rlang %@% #' @export - get_test_data <- function( recipe, x, fill_locf = FALSE, n_recent = NULL, forecast_date = max(x$time_value)) { - if (!is_epi_df(x)) cli::cli_abort("`x` must be an `epi_df`.") + if (!is_epi_df(x)) cli_abort("`x` must be an `epi_df`.") arg_is_lgl(fill_locf) arg_is_scalar(fill_locf) arg_is_scalar(n_recent, allow_null = TRUE) @@ -60,16 +59,16 @@ get_test_data <- function( check <- hardhat::check_column_names(x, colnames(recipe$template)) if (!check$ok) { - cli::cli_abort(c( + cli_abort(c( "Some variables used for training are not available in {.arg x}.", i = "The following required columns are missing: {check$missing_names}" )) } if (class(forecast_date) != class(x$time_value)) { - cli::cli_abort("`forecast_date` must be the same class as `x$time_value`.") + cli_abort("`forecast_date` must be the same class as `x$time_value`.") } if (forecast_date < max(x$time_value)) { - cli::cli_abort("`forecast_date` must be no earlier than `max(x$time_value)`") + cli_abort("`forecast_date` must be no earlier than `max(x$time_value)`") } min_lags <- min(map_dbl(recipe$steps, ~ min(.x$lag %||% Inf)), Inf) @@ -84,7 +83,7 @@ get_test_data <- function( # Probably needs a fix based on the time_type of the epi_df avail_recent <- diff(range(x$time_value)) if (avail_recent < min_required) { - cli::cli_abort(c( + cli_abort(c( "You supplied insufficient recent data for this recipe. ", "!" = "You need at least {min_required} days of data,", "!" = "but `x` contains only {avail_recent}." @@ -97,39 +96,37 @@ get_test_data <- function( # If we skip NA completion, we remove undesirably early time values # Happens globally, over all groups keep <- max(n_recent, min_required + 1) - x <- dplyr::filter(x, forecast_date - time_value <= keep) + x <- filter(x, forecast_date - time_value <= keep) # Pad with explicit missing values up to and including the forecast_date # x is grouped here x <- pad_to_end(x, groups, forecast_date) %>% - group_by(dplyr::across(dplyr::all_of(groups))) + group_by(across(all_of(groups))) # If all(lags > 0), then we get rid of recent data if (min_lags > 0 && min_lags < Inf) { - x <- dplyr::filter(x, forecast_date - time_value >= min_lags) + x <- filter(x, forecast_date - time_value >= min_lags) } # Now, fill forward missing data if requested if (fill_locf) { cannot_be_used <- x %>% - dplyr::filter(forecast_date - time_value <= n_recent) %>% - dplyr::mutate(fillers = forecast_date - time_value > min_required) %>% - dplyr::summarize( - dplyr::across( - -dplyr::any_of(key_colnames(recipe)), + filter(forecast_date - time_value <= n_recent) %>% + mutate(fillers = forecast_date - time_value > min_required) %>% + summarize( + across( + -any_of(key_colnames(recipe)), ~ all(is.na(.x[fillers])) & is.na(head(.x[!fillers], 1)) ), .groups = "drop" ) %>% - dplyr::select(-fillers) %>% - dplyr::summarise(dplyr::across( - -dplyr::any_of(key_colnames(recipe)), ~ any(.x) - )) %>% + select(-fillers) %>% + summarise(across(-any_of(key_colnames(recipe)), ~ any(.x))) %>% unlist() if (any(cannot_be_used)) { bad_vars <- names(cannot_be_used)[cannot_be_used] if (recipes::is_trained(recipe)) { - cli::cli_abort(c( + cli_abort(c( "The variables {.var {bad_vars}} have too many recent missing", `!` = "values to be filled automatically. ", i = "You should either choose `n_recent` larger than its current ", @@ -141,15 +138,15 @@ get_test_data <- function( x <- tidyr::fill(x, !time_value) } - dplyr::filter(x, forecast_date - time_value <= min_required) %>% + filter(x, forecast_date - time_value <= min_required) %>% ungroup() } pad_to_end <- function(x, groups, end_date) { itval <- guess_period(c(x$time_value, end_date), "time_value") completed_time_values <- x %>% - dplyr::group_by(dplyr::across(dplyr::all_of(groups))) %>% - dplyr::summarise( + group_by(across(all_of(groups))) %>% + summarise( time_value = rlang::list2( time_value = Seq(max(time_value) + itval, end_date, itval) ) @@ -157,8 +154,8 @@ pad_to_end <- function(x, groups, end_date) { unnest("time_value") %>% mutate(time_value = vctrs::vec_cast(time_value, x$time_value)) - dplyr::bind_rows(x, completed_time_values) %>% - dplyr::arrange(dplyr::across(dplyr::all_of(c("time_value", groups)))) + bind_rows(x, completed_time_values) %>% + arrange(across(all_of(c("time_value", groups)))) } Seq <- function(from, to, by) { diff --git a/R/layer_add_forecast_date.R b/R/layer_add_forecast_date.R index c4bb7d483..02395f960 100644 --- a/R/layer_add_forecast_date.R +++ b/R/layer_add_forecast_date.R @@ -19,15 +19,16 @@ #' #' @export #' @examples +#' library(dplyr) #' jhu <- case_death_rate_subset %>% -#' dplyr::filter(time_value > "2021-11-01", geo_value %in% c("ak", "ca", "ny")) +#' filter(time_value > "2021-11-01", geo_value %in% c("ak", "ca", "ny")) #' r <- epi_recipe(jhu) %>% #' step_epi_lag(death_rate, lag = c(0, 7, 14)) %>% #' step_epi_ahead(death_rate, ahead = 7) %>% #' step_epi_naomit() -#' wf <- epi_workflow(r, parsnip::linear_reg()) %>% fit(jhu) +#' wf <- epi_workflow(r, linear_reg()) %>% fit(jhu) #' latest <- jhu %>% -#' dplyr::filter(time_value >= max(time_value) - 14) +#' filter(time_value >= max(time_value) - 14) #' #' # Don't specify `forecast_date` (by default, this should be last date in latest) #' f <- frosting() %>% @@ -85,7 +86,8 @@ layer_add_forecast_date_new <- function(forecast_date, id) { } #' @export -slather.layer_add_forecast_date <- function(object, components, workflow, new_data, ...) { +slather.layer_add_forecast_date <- function(object, components, workflow, + new_data, ...) { rlang::check_dots_empty() if (is.null(object$forecast_date)) { max_time_value <- as.Date(max( @@ -102,12 +104,13 @@ slather.layer_add_forecast_date <- function(object, components, workflow, new_da workflows::extract_preprocessor(workflow)$template, "metadata" )$time_type if (expected_time_type == "week") expected_time_type <- "day" - validate_date(forecast_date, expected_time_type, + validate_date( + forecast_date, expected_time_type, call = rlang::expr(layer_add_forecast_date()) ) forecast_date <- coerce_time_type(forecast_date, expected_time_type) object$forecast_date <- forecast_date - components$predictions <- dplyr::bind_cols( + components$predictions <- bind_cols( components$predictions, forecast_date = forecast_date ) diff --git a/R/layer_add_target_date.R b/R/layer_add_target_date.R index 23aeb4091..834deb82b 100644 --- a/R/layer_add_target_date.R +++ b/R/layer_add_target_date.R @@ -20,14 +20,15 @@ #' #' @export #' @examples +#' library(dplyr) #' jhu <- case_death_rate_subset %>% -#' dplyr::filter(time_value > "2021-11-01", geo_value %in% c("ak", "ca", "ny")) +#' dfilter(time_value > "2021-11-01", geo_value %in% c("ak", "ca", "ny")) #' r <- epi_recipe(jhu) %>% #' step_epi_lag(death_rate, lag = c(0, 7, 14)) %>% #' step_epi_ahead(death_rate, ahead = 7) %>% #' step_epi_naomit() #' -#' wf <- epi_workflow(r, parsnip::linear_reg()) %>% fit(jhu) +#' wf <- epi_workflow(r, linear_reg()) %>% fit(jhu) #' #' # Use ahead + forecast date #' f <- frosting() %>% @@ -79,7 +80,8 @@ layer_add_target_date_new <- function(id = id, target_date = target_date) { } #' @export -slather.layer_add_target_date <- function(object, components, workflow, new_data, ...) { +slather.layer_add_target_date <- function(object, components, workflow, + new_data, ...) { rlang::check_dots_empty() the_recipe <- workflows::extract_recipe(workflow) the_frosting <- extract_frosting(workflow) @@ -91,7 +93,8 @@ slather.layer_add_target_date <- function(object, components, workflow, new_data if (!is.null(object$target_date)) { target_date <- object$target_date - validate_date(target_date, expected_time_type, + validate_date( + target_date, expected_time_type, call = expr(layer_add_target_date()) ) target_date <- coerce_time_type(target_date, expected_time_type) @@ -100,7 +103,8 @@ slather.layer_add_target_date <- function(object, components, workflow, new_data !is.null(forecast_date <- extract_argument( the_frosting, "layer_add_forecast_date", "forecast_date" ))) { - validate_date(forecast_date, expected_time_type, + validate_date( + forecast_date, expected_time_type, call = rlang::expr(layer_add_forecast_date()) ) forecast_date <- coerce_time_type(forecast_date, expected_time_type) @@ -117,7 +121,8 @@ slather.layer_add_target_date <- function(object, components, workflow, new_data } object$target_date <- target_date - components$predictions <- dplyr::bind_cols(components$predictions, + components$predictions <- bind_cols( + components$predictions, target_date = target_date ) components diff --git a/R/layer_cdc_flatline_quantiles.R b/R/layer_cdc_flatline_quantiles.R index db1440b03..9166d3469 100644 --- a/R/layer_cdc_flatline_quantiles.R +++ b/R/layer_cdc_flatline_quantiles.R @@ -55,6 +55,7 @@ #' @export #' #' @examples +#' library(dplyr) #' r <- epi_recipe(case_death_rate_subset) %>% #' # data is "daily", so we fit this to 1 ahead, the result will contain #' # 1 day ahead residuals @@ -68,16 +69,16 @@ #' layer_predict() %>% #' layer_cdc_flatline_quantiles(aheads = c(7, 14, 21, 28), symmetrize = TRUE) #' -#' eng <- parsnip::linear_reg() %>% parsnip::set_engine("flatline") +#' eng <- linear_reg(engine = "flatline") #' #' wf <- epi_workflow(r, eng, f) %>% fit(case_death_rate_subset) #' preds <- forecast(wf) %>% -#' dplyr::select(-time_value) %>% -#' dplyr::mutate(forecast_date = forecast_date) +#' select(-time_value) %>% +#' mutate(forecast_date = forecast_date) #' preds #' #' preds <- preds %>% -#' unnest(.pred_distn_all) %>% +#' tidyr::unnest(.pred_distn_all) %>% #' pivot_quantiles_wider(.pred_distn) %>% #' mutate(target_date = forecast_date + ahead) #' @@ -162,12 +163,10 @@ slather.layer_cdc_flatline_quantiles <- } the_fit <- workflows::extract_fit_parsnip(workflow) if (!inherits(the_fit, "_flatline")) { - cli::cli_warn( - c( - "Predictions for this workflow were not produced by the {.cls flatline}", - "{.pkg parsnip} engine. Results may be unexpected. See {.fn epipredict::flatline}." - ) - ) + cli::cli_warn(c( + "Predictions for this workflow were not produced by the {.cls flatline}", + "{.pkg parsnip} engine. Results may be unexpected. See {.fn epipredict::flatline}." + )) } p <- components$predictions ek <- kill_time_value(key_colnames(components$mold)) @@ -196,7 +195,7 @@ slather.layer_cdc_flatline_quantiles <- c(cols_in_preds$missing_names, cols_in_resids$missing_names) )) } else { # not flatline, but we'll try - key_cols <- dplyr::bind_cols( + key_cols <- bind_cols( geo_value = components$mold$extras$roles$geo_value, components$mold$extras$roles$key ) @@ -211,26 +210,26 @@ slather.layer_cdc_flatline_quantiles <- object$by_key, c(cols_in_preds$missing_names, cols_in_resids$missing_names) )) - r <- dplyr::bind_cols(key_cols, r) + r <- bind_cols(key_cols, r) } } r <- r %>% - dplyr::select(tidyselect::all_of(c(avail_grps, ".resid"))) %>% - dplyr::group_by(!!!rlang::syms(avail_grps)) %>% - dplyr::summarise(.resid = list(.resid), .groups = "drop") + select(all_of(c(avail_grps, ".resid"))) %>% + group_by(!!!rlang::syms(avail_grps)) %>% + summarise(.resid = list(.resid), .groups = "drop") - res <- dplyr::left_join(p, r, by = avail_grps) %>% + res <- left_join(p, r, by = avail_grps) %>% dplyr::rowwise() %>% - dplyr::mutate( + mutate( .pred_distn_all = propagate_samples( .resid, .pred, object$quantile_levels, object$aheads, object$nsim, object$symmetrize, object$nonneg ) ) %>% - dplyr::select(tidyselect::all_of(c(avail_grps, ".pred_distn_all"))) + select(all_of(c(avail_grps, ".pred_distn_all"))) # res <- check_pname(res, components$predictions, object) - components$predictions <- dplyr::left_join( + components$predictions <- left_join( components$predictions, res, by = avail_grps @@ -267,7 +266,7 @@ propagate_samples <- function( } } res <- res[aheads] - list(tibble::tibble( + list(tibble( ahead = aheads, .pred_distn = map_vec( res, ~ dist_quantiles(quantile(.x, quantile_levels), quantile_levels) diff --git a/R/layer_naomit.R b/R/layer_naomit.R index 85842bfdf..209a663b4 100644 --- a/R/layer_naomit.R +++ b/R/layer_naomit.R @@ -11,14 +11,15 @@ #' @return an updated `frosting` postprocessor #' @export #' @examples +#' library(dplyr) #' jhu <- case_death_rate_subset %>% -#' dplyr::filter(time_value > "2021-11-01", geo_value %in% c("ak", "ca", "ny")) +#' filter(time_value > "2021-11-01", geo_value %in% c("ak", "ca", "ny")) #' #' r <- epi_recipe(jhu) %>% #' step_epi_lag(death_rate, lag = c(0, 7, 14)) %>% #' step_epi_ahead(death_rate, ahead = 7) #' -#' wf <- epi_workflow(r, parsnip::linear_reg()) %>% fit(jhu) +#' wf <- epi_workflow(r, linear_reg()) %>% fit(jhu) #' #' f <- frosting() %>% #' layer_predict() %>% @@ -33,7 +34,7 @@ layer_naomit <- function(frosting, ..., id = rand_id("naomit")) { add_layer( frosting, layer_naomit_new( - terms = dplyr::enquos(...), + terms = enquos(...), id = id ) ) @@ -50,7 +51,7 @@ slather.layer_naomit <- function(object, components, workflow, new_data, ...) { pos <- tidyselect::eval_select(exprs, components$predictions) col_names <- names(pos) components$predictions <- components$predictions %>% - dplyr::filter(dplyr::if_any(dplyr::all_of(col_names), ~ !is.na(.x))) + filter(dplyr::if_any(all_of(col_names), ~ !is.na(.x))) components } diff --git a/R/layer_point_from_distn.R b/R/layer_point_from_distn.R index f415e7bd4..f14008748 100644 --- a/R/layer_point_from_distn.R +++ b/R/layer_point_from_distn.R @@ -16,15 +16,17 @@ #' @export #' #' @examples +#' library(dplyr) #' jhu <- case_death_rate_subset %>% -#' dplyr::filter(time_value > "2021-11-01", geo_value %in% c("ak", "ca", "ny")) +#' filter(time_value > "2021-11-01", geo_value %in% c("ak", "ca", "ny")) #' #' r <- epi_recipe(jhu) %>% #' step_epi_lag(death_rate, lag = c(0, 7, 14)) %>% #' step_epi_ahead(death_rate, ahead = 7) %>% #' step_epi_naomit() #' -#' wf <- epi_workflow(r, quantile_reg(quantile_levels = c(.25, .5, .75))) %>% fit(jhu) +#' wf <- epi_workflow(r, quantile_reg(quantile_levels = c(.25, .5, .75))) %>% +#' fit(jhu) #' #' f1 <- frosting() %>% #' layer_predict() %>% @@ -91,9 +93,9 @@ slather.layer_point_from_distn <- if (is.null(object$name)) { components$predictions$.pred <- dstn } else { - dstn <- tibble::tibble(dstn = dstn) + dstn <- tibble(dstn = dstn) dstn <- check_pname(dstn, components$predictions, object) - components$predictions <- dplyr::mutate(components$predictions, !!!dstn) + components$predictions <- mutate(components$predictions, !!!dstn) } components } diff --git a/R/layer_population_scaling.R b/R/layer_population_scaling.R index 3829eec5a..1b940b804 100644 --- a/R/layer_population_scaling.R +++ b/R/layer_population_scaling.R @@ -47,9 +47,10 @@ #' @return an updated `frosting` postprocessor #' @export #' @examples -#' jhu <- epiprocess::jhu_csse_daily_subset %>% -#' dplyr::filter(time_value > "2021-11-01", geo_value %in% c("ca", "ny")) %>% -#' dplyr::select(geo_value, time_value, cases) +#' library(dplyr) +#' jhu <- jhu_csse_daily_subset %>% +#' filter(time_value > "2021-11-01", geo_value %in% c("ca", "ny")) %>% +#' select(geo_value, time_value, cases) #' #' pop_data <- data.frame(states = c("ca", "ny"), value = c(20000, 30000)) #' @@ -74,7 +75,7 @@ #' df_pop_col = "value" #' ) #' -#' wf <- epi_workflow(r, parsnip::linear_reg()) %>% +#' wf <- epi_workflow(r, linear_reg()) %>% #' fit(jhu) %>% #' add_frosting(f) #' @@ -93,7 +94,7 @@ layer_population_scaling <- function(frosting, arg_is_chr(df_pop_col, suffix, id) arg_is_chr(by, allow_null = TRUE) if (rate_rescaling <= 0) { - cli_stop("`rate_rescaling` should be a positive number") + cli_abort("`rate_rescaling` must be a positive number.") } add_layer( @@ -134,24 +135,12 @@ slather.layer_population_scaling <- ) rlang::check_dots_empty() - if (is.null(object$by)) { - object$by <- intersect( - kill_time_value(key_colnames(components$predictions)), - colnames(dplyr::select(object$df, !object$df_pop_col)) - ) - } - try_join <- try( - dplyr::left_join(components$predictions, object$df, - by = object$by - ), - silent = TRUE + object$by <- object$by %||% intersect( + kill_time_value(key_colnames(components$predictions)), + colnames(select(object$df, !object$df_pop_col)) ) - if (any(grepl("Join columns must be present in data", unlist(try_join)))) { - cli_stop(c( - "columns in `by` selectors of `layer_population_scaling` ", - "must be present in data and match" - )) - } + hardhat::validate_column_names(components$predictions, object$by) + hardhat::validate_column_names(object$df, object$by) # object$df <- object$df %>% # dplyr::mutate(dplyr::across(tidyselect::where(is.character), tolower)) @@ -162,18 +151,18 @@ slather.layer_population_scaling <- suffix <- ifelse(object$create_new, object$suffix, "") col_to_remove <- setdiff(colnames(object$df), colnames(components$predictions)) - components$predictions <- dplyr::left_join( + components$predictions <- left_join( components$predictions, object$df, by = object$by, suffix = c("", ".df") ) %>% - dplyr::mutate(dplyr::across( - dplyr::all_of(col_names), + mutate(across( + all_of(col_names), ~ .x * !!pop_col / object$rate_rescaling, .names = "{.col}{suffix}" )) %>% - dplyr::select(-dplyr::any_of(col_to_remove)) + select(-any_of(col_to_remove)) components } diff --git a/R/layer_predict.R b/R/layer_predict.R index 46d81be18..6ca17ac24 100644 --- a/R/layer_predict.R +++ b/R/layer_predict.R @@ -16,6 +16,7 @@ #' @export #' #' @examples +#' library(dplyr) #' jhu <- case_death_rate_subset %>% #' filter(time_value > "2021-11-01", geo_value %in% c("ak", "ca", "ny")) #' @@ -24,7 +25,7 @@ #' step_epi_ahead(death_rate, ahead = 7) %>% #' step_epi_naomit() #' -#' wf <- epi_workflow(r, parsnip::linear_reg()) %>% fit(jhu) +#' wf <- epi_workflow(r, linear_reg()) %>% fit(jhu) #' latest <- jhu %>% filter(time_value >= max(time_value) - 14) #' #' # Predict layer alone @@ -90,9 +91,7 @@ slather.layer_predict <- function(object, components, workflow, new_data, type = opts = c(object$opts, opts), !!!object$dots_list, ... )) - components$predictions <- dplyr::bind_cols( - components$keys, components$predictions - ) + components$predictions <- bind_cols(components$keys, components$predictions) components } diff --git a/R/layer_predictive_distn.R b/R/layer_predictive_distn.R index 9b1a160e1..b28e0c765 100644 --- a/R/layer_predictive_distn.R +++ b/R/layer_predictive_distn.R @@ -20,15 +20,16 @@ #' @export #' #' @examples +#' library(dplyr) #' jhu <- case_death_rate_subset %>% -#' dplyr::filter(time_value > "2021-11-01", geo_value %in% c("ak", "ca", "ny")) +#' filter(time_value > "2021-11-01", geo_value %in% c("ak", "ca", "ny")) #' #' r <- epi_recipe(jhu) %>% #' step_epi_lag(death_rate, lag = c(0, 7, 14)) %>% #' step_epi_ahead(death_rate, ahead = 7) %>% #' step_epi_naomit() #' -#' wf <- epi_workflow(r, parsnip::linear_reg()) %>% fit(jhu) +#' wf <- epi_workflow(r, linear_reg()) %>% fit(jhu) #' #' f <- frosting() %>% #' layer_predict() %>% @@ -91,9 +92,9 @@ slather.layer_predictive_distn <- if (!all(is.infinite(truncate))) { dstn <- distributional::dist_truncated(dstn, truncate[1], truncate[2]) } - dstn <- tibble::tibble(dstn = dstn) + dstn <- tibble(dstn = dstn) dstn <- check_pname(dstn, components$predictions, object) - components$predictions <- dplyr::mutate(components$predictions, !!!dstn) + components$predictions <- mutate(components$predictions, !!!dstn) components } diff --git a/R/layer_quantile_distn.R b/R/layer_quantile_distn.R index ea96969da..5f87ded29 100644 --- a/R/layer_quantile_distn.R +++ b/R/layer_quantile_distn.R @@ -22,8 +22,9 @@ #' @export #' #' @examples +#' library(dplyr) #' jhu <- case_death_rate_subset %>% -#' dplyr::filter(time_value > "2021-11-01", geo_value %in% c("ak", "ca", "ny")) +#' filter(time_value > "2021-11-01", geo_value %in% c("ak", "ca", "ny")) #' #' r <- epi_recipe(jhu) %>% #' step_epi_lag(death_rate, lag = c(0, 7, 14)) %>% @@ -95,9 +96,9 @@ slather.layer_quantile_distn <- if (!all(is.infinite(truncate))) { dstn <- snap(dstn, truncate[1], truncate[2]) } - dstn <- tibble::tibble(dstn = dstn) + dstn <- tibble(dstn = dstn) dstn <- check_pname(dstn, components$predictions, object) - components$predictions <- dplyr::mutate(components$predictions, !!!dstn) + components$predictions <- mutate(components$predictions, !!!dstn) components } diff --git a/R/layer_residual_quantiles.R b/R/layer_residual_quantiles.R index 85c1c6ed0..dc53f0f60 100644 --- a/R/layer_residual_quantiles.R +++ b/R/layer_residual_quantiles.R @@ -14,19 +14,23 @@ #' residual quantiles added to the prediction #' @export #' @examples +#' library(dplyr) #' jhu <- case_death_rate_subset %>% -#' dplyr::filter(time_value > "2021-11-01", geo_value %in% c("ak", "ca", "ny")) +#' filter(time_value > "2021-11-01", geo_value %in% c("ak", "ca", "ny")) #' #' r <- epi_recipe(jhu) %>% #' step_epi_lag(death_rate, lag = c(0, 7, 14)) %>% #' step_epi_ahead(death_rate, ahead = 7) %>% #' step_epi_naomit() #' -#' wf <- epi_workflow(r, parsnip::linear_reg()) %>% fit(jhu) +#' wf <- epi_workflow(r, linear_reg()) %>% fit(jhu) #' #' f <- frosting() %>% #' layer_predict() %>% -#' layer_residual_quantiles(quantile_levels = c(0.0275, 0.975), symmetrize = FALSE) %>% +#' layer_residual_quantiles( +#' quantile_levels = c(0.0275, 0.975), +#' symmetrize = FALSE +#' ) %>% #' layer_naomit(.pred) #' wf1 <- wf %>% add_frosting(f) #' @@ -34,7 +38,10 @@ #' #' f2 <- frosting() %>% #' layer_predict() %>% -#' layer_residual_quantiles(quantile_levels = c(0.3, 0.7), by_key = "geo_value") %>% +#' layer_residual_quantiles( +#' quantile_levels = c(0.3, 0.7), +#' by_key = "geo_value" +#' ) %>% #' layer_naomit(.pred) #' wf2 <- wf %>% add_frosting(f2) #' @@ -88,7 +95,7 @@ slather.layer_residual_quantiles <- ## Handle any grouping requests if (length(object$by_key) > 0L) { - key_cols <- dplyr::bind_cols( + key_cols <- bind_cols( geo_value = components$mold$extras$roles$geo_value, components$mold$extras$roles$key ) @@ -101,23 +108,23 @@ slather.layer_residual_quantiles <- )) } if (length(common) > 0L) { - r <- r %>% dplyr::select(tidyselect::any_of(c(common, ".resid"))) + r <- r %>% select(any_of(c(common, ".resid"))) common_in_r <- common[common %in% names(r)] if (length(common_in_r) == length(common)) { - r <- dplyr::left_join(key_cols, r, by = common_in_r) + r <- left_join(key_cols, r, by = common_in_r) } else { cli::cli_warn(c( "Some grouping keys are not in data.frame returned by the", "`residuals()` method. Groupings may not be correct." )) - r <- dplyr::bind_cols(key_cols, r %>% dplyr::select(.resid)) %>% - dplyr::group_by(!!!rlang::syms(common)) + r <- bind_cols(key_cols, select(r, .resid)) %>% + group_by(!!!rlang::syms(common)) } } } r <- r %>% - dplyr::summarize( + summarize( dstn = list(quantile( c(.resid, s * .resid), probs = object$quantile_levels, na.rm = TRUE @@ -132,11 +139,11 @@ slather.layer_residual_quantiles <- } estimate <- components$predictions$.pred - res <- tibble::tibble( + res <- tibble( .pred_distn = dist_quantiles(map2(estimate, r$dstn, "+"), object$quantile_levels) ) res <- check_pname(res, components$predictions, object) - components$predictions <- dplyr::mutate(components$predictions, !!!res) + components$predictions <- mutate(components$predictions, !!!res) components } diff --git a/R/layer_threshold_preds.R b/R/layer_threshold_preds.R index 8b2b56d1e..56f8059ab 100644 --- a/R/layer_threshold_preds.R +++ b/R/layer_threshold_preds.R @@ -22,15 +22,14 @@ #' @return an updated `frosting` postprocessor #' @export #' @examples - +#' library(dplyr) #' jhu <- case_death_rate_subset %>% -#' dplyr::filter(time_value < "2021-03-08", -#' geo_value %in% c("ak", "ca", "ar")) +#' filter(time_value < "2021-03-08", geo_value %in% c("ak", "ca", "ar")) #' r <- epi_recipe(jhu) %>% #' step_epi_lag(death_rate, lag = c(0, 7, 14)) %>% #' step_epi_ahead(death_rate, ahead = 7) %>% #' step_epi_naomit() -#' wf <- epi_workflow(r, parsnip::linear_reg()) %>% fit(jhu) +#' wf <- epi_workflow(r, linear_reg()) %>% fit(jhu) #' #' f <- frosting() %>% #' layer_predict() %>% @@ -46,7 +45,7 @@ layer_threshold <- add_layer( frosting, layer_threshold_new( - terms = dplyr::enquos(...), + terms = enquos(...), lower = lower, upper = upper, id = id @@ -103,12 +102,7 @@ slather.layer_threshold <- pos <- tidyselect::eval_select(exprs, components$predictions) col_names <- names(pos) components$predictions <- components$predictions %>% - dplyr::mutate( - dplyr::across( - dplyr::all_of(col_names), - ~ snap(.x, object$lower, object$upper) - ) - ) + mutate(across(all_of(col_names), ~ snap(.x, object$lower, object$upper))) components } diff --git a/R/layer_unnest.R b/R/layer_unnest.R index dfc391942..a6fc9f0af 100644 --- a/R/layer_unnest.R +++ b/R/layer_unnest.R @@ -15,7 +15,7 @@ layer_unnest <- function(frosting, ..., id = rand_id("unnest")) { add_layer( frosting, layer_unnest_new( - terms = dplyr::enquos(...), + terms = enquos(...), id = id ) ) diff --git a/R/layers.R b/R/layers.R index b59e95cdd..aa515a917 100644 --- a/R/layers.R +++ b/R/layers.R @@ -41,15 +41,15 @@ layer <- function(subclass, ..., .prefix = "layer_") { #' in the layer, and the values are the new values to update the layer with. #' #' @examples +#' library(dplyr) #' jhu <- case_death_rate_subset %>% -#' dplyr::filter(time_value > "2021-11-01", geo_value %in% c("ak", "ca", "ny")) +#' filter(time_value > "2021-11-01", geo_value %in% c("ak", "ca", "ny")) #' r <- epi_recipe(jhu) %>% #' step_epi_lag(death_rate, lag = c(0, 7, 14)) %>% #' step_epi_ahead(death_rate, ahead = 7) %>% #' step_epi_naomit() -#' wf <- epi_workflow(r, parsnip::linear_reg()) %>% fit(jhu) -#' latest <- jhu %>% -#' dplyr::filter(time_value >= max(time_value) - 14) +#' wf <- epi_workflow(r, linear_reg()) %>% fit(jhu) +#' latest <- jhu %>% filter(time_value >= max(time_value) - 14) #' #' # Specify a `forecast_date` that is greater than or equal to `as_of` date #' f <- frosting() %>% diff --git a/R/model-methods.R b/R/model-methods.R index 607b04234..f3b374879 100644 --- a/R/model-methods.R +++ b/R/model-methods.R @@ -32,11 +32,9 @@ #' #' @export #' @examples +#' library(dplyr) #' jhu <- case_death_rate_subset %>% -#' dplyr::filter( -#' time_value > "2021-11-01", -#' geo_value %in% c("ak", "ca", "ny") -#' ) +#' filter(time_value > "2021-11-01", geo_value %in% c("ak", "ca", "ny")) #' #' r <- epi_recipe(jhu) %>% #' step_epi_lag(death_rate, lag = c(0, 7, 14)) %>% @@ -49,7 +47,7 @@ #' wf <- wf %>% Add_model(rf_model) #' wf #' -#' lm_model <- parsnip::linear_reg() +#' lm_model <- linear_reg() #' #' wf <- Update_model(wf, lm_model) #' wf diff --git a/R/pivot_quantiles.R b/R/pivot_quantiles.R index 70e51da8e..c8601b4f6 100644 --- a/R/pivot_quantiles.R +++ b/R/pivot_quantiles.R @@ -6,16 +6,18 @@ #' @export #' #' @examples +#' library(dplyr) +#' library(tidyr) #' edf <- case_death_rate_subset[1:3, ] #' edf$q <- dist_quantiles(list(1:5, 2:4, 3:10), list(1:5 / 6, 2:4 / 5, 3:10 / 11)) #' -#' edf_nested <- edf %>% dplyr::mutate(q = nested_quantiles(q)) -#' edf_nested %>% tidyr::unnest(q) +#' edf_nested <- edf %>% mutate(q = nested_quantiles(q)) +#' edf_nested %>% unnest(q) nested_quantiles <- function(x) { stopifnot(is_dist_quantiles(x)) distributional:::dist_apply(x, .f = function(z) { - tibble::as_tibble(vec_data(z)) %>% - dplyr::mutate(dplyr::across(tidyselect::everything(), as.double)) %>% + as_tibble(vec_data(z)) %>% + mutate(across(everything(), as.double)) %>% vctrs::list_of() }) } @@ -47,31 +49,26 @@ nested_quantiles <- function(x) { #' @examples #' d1 <- c(dist_quantiles(1:3, 1:3 / 4), dist_quantiles(2:4, 1:3 / 4)) #' d2 <- c(dist_quantiles(2:4, 2:4 / 5), dist_quantiles(3:5, 2:4 / 5)) -#' tib <- tibble::tibble(g = c("a", "b"), d1 = d1, d2 = d2) +#' tib <- tibble(g = c("a", "b"), d1 = d1, d2 = d2) #' #' pivot_quantiles_longer(tib, "d1") -#' pivot_quantiles_longer(tib, tidyselect::ends_with("1")) +#' pivot_quantiles_longer(tib, dplyr::ends_with("1")) #' pivot_quantiles_longer(tib, d1, d2) pivot_quantiles_longer <- function(.data, ..., .ignore_length_check = FALSE) { cols <- validate_pivot_quantiles(.data, ...) - .data <- .data %>% - dplyr::mutate(dplyr::across(tidyselect::all_of(cols), nested_quantiles)) + .data <- .data %>% mutate(across(all_of(cols), nested_quantiles)) if (length(cols) > 1L) { lengths_check <- .data %>% - dplyr::transmute(dplyr::across( - tidyselect::all_of(cols), - ~ map_int(.x, vctrs::vec_size) - )) %>% + dplyr::transmute(across(all_of(cols), ~ map_int(.x, vctrs::vec_size))) %>% as.matrix() %>% apply(1, function(x) dplyr::n_distinct(x) == 1L) %>% all() if (lengths_check) { - .data <- tidyr::unnest(.data, tidyselect::all_of(cols), names_sep = "_") + .data <- tidyr::unnest(.data, all_of(cols), names_sep = "_") } else { if (.ignore_length_check) { for (col in cols) { - .data <- .data %>% - tidyr::unnest(tidyselect::all_of(col), names_sep = "_") + .data <- .data %>% tidyr::unnest(all_of(col), names_sep = "_") } } else { cli::cli_abort(c( @@ -82,7 +79,7 @@ pivot_quantiles_longer <- function(.data, ..., .ignore_length_check = FALSE) { } } } else { - .data <- .data %>% tidyr::unnest(tidyselect::all_of(cols)) + .data <- .data %>% tidyr::unnest(all_of(cols)) } .data } @@ -110,20 +107,18 @@ pivot_quantiles_longer <- function(.data, ..., .ignore_length_check = FALSE) { #' tib <- tibble::tibble(g = c("a", "b"), d1 = d1, d2 = d2) #' #' pivot_quantiles_wider(tib, c("d1", "d2")) -#' pivot_quantiles_wider(tib, tidyselect::starts_with("d")) +#' pivot_quantiles_wider(tib, dplyr::starts_with("d")) #' pivot_quantiles_wider(tib, d2) pivot_quantiles_wider <- function(.data, ...) { cols <- validate_pivot_quantiles(.data, ...) - .data <- .data %>% - dplyr::mutate(dplyr::across(tidyselect::all_of(cols), nested_quantiles)) + .data <- .data %>% mutate(across(all_of(cols), nested_quantiles)) checks <- map_lgl(cols, ~ diff(range(vctrs::list_sizes(.data[[.x]]))) == 0L) if (!all(checks)) { nms <- cols[!checks] - cli::cli_abort( - c("Quantiles must be the same length and have the same set of taus.", - i = "Check failed for variables(s) {.var {nms}}." - ) - ) + cli::cli_abort(c( + "Quantiles must be the same length and have the same set of taus.", + i = "Check failed for variables(s) {.var {nms}}." + )) } # tidyr::pivot_wider can crash if there are duplicates, this generally won't @@ -134,7 +129,7 @@ pivot_quantiles_wider <- function(.data, ...) { if (length(cols) > 1L) { for (col in cols) { .data <- .data %>% - tidyr::unnest(tidyselect::all_of(col)) %>% + tidyr::unnest(all_of(col)) %>% tidyr::pivot_wider( names_from = "quantile_levels", values_from = "values", names_prefix = paste0(col, "_") @@ -142,10 +137,10 @@ pivot_quantiles_wider <- function(.data, ...) { } } else { .data <- .data %>% - tidyr::unnest(tidyselect::all_of(cols)) %>% + tidyr::unnest(all_of(cols)) %>% tidyr::pivot_wider(names_from = "quantile_levels", values_from = "values") } - dplyr::select(.data, -.hidden_index) + select(.data, -.hidden_index) } pivot_quantiles <- function(.data, ...) { diff --git a/R/reexports-tidymodels.R b/R/reexports-tidymodels.R index 2c69139a2..3b28ac5c5 100644 --- a/R/reexports-tidymodels.R +++ b/R/reexports-tidymodels.R @@ -13,3 +13,15 @@ recipes::prep #' @importFrom recipes bake #' @export recipes::bake + +#' @importFrom recipes rand_id +#' @export +recipes::rand_id + +#' @importFrom tibble tibble +#' @export +tibble::tibble + +#' @importFrom generics tidy +#' @export +generics::tidy diff --git a/R/step_epi_shift.R b/R/step_epi_shift.R index f45f5d8f4..616d98f03 100644 --- a/R/step_epi_shift.R +++ b/R/step_epi_shift.R @@ -15,16 +15,12 @@ #' for this step. See [recipes::selections()] for more details. #' @param role For model terms created by this step, what analysis role should #' they be assigned? `lag` is default a predictor while `ahead` is an outcome. -#' @param trained A logical to indicate if the quantities for -#' preprocessing have been estimated. #' @param lag,ahead A vector of integers. Each specified column will #' be the lag or lead for each value in the vector. Lag integers must be #' nonnegative, while ahead integers must be positive. -#' @param prefix A prefix to indicate what type of variable this is +#' @param prefix A character string that will be prefixed to the new column. #' @param default Determines what fills empty rows #' left by leading/lagging (defaults to NA). -#' @param columns A character string of variable names that will -#' be populated (eventually) by the `terms` argument. #' @param skip A logical. Should the step be skipped when the #' recipe is baked by [bake()]? While all operations are baked #' when [prep()] is run, some operations may not be able to be @@ -55,42 +51,34 @@ step_epi_lag <- ..., lag, role = "predictor", - trained = FALSE, prefix = "lag_", default = NA, - columns = NULL, skip = FALSE, id = rand_id("epi_lag")) { if (!is_epi_recipe(recipe)) { - cli::cli_abort("This step can only operate on an `epi_recipe`.") + cli_abort("This step can only operate on an `epi_recipe`.") } if (missing(lag)) { - cli::cli_abort( - c("The `lag` argument must not be empty.", - i = "Did you perhaps pass an integer in `...` accidentally?" - ) - ) + cli_abort(c( + "The `lag` argument must not be empty.", + i = "Did you perhaps pass an integer in `...` accidentally?" + )) } arg_is_nonneg_int(lag) arg_is_chr_scalar(prefix, id) - if (!is.null(columns)) { - cli::cli_abort(c( - "The `columns` argument must be `NULL.", - i = "Use `tidyselect` methods to choose columns to lag." - )) - } - add_step( + + recipes::add_step( recipe, step_epi_lag_new( - terms = dplyr::enquos(...), + terms = enquos(...), role = role, - trained = trained, + trained = FALSE, lag = as.integer(lag), prefix = prefix, default = default, keys = key_colnames(recipe), - columns = columns, + columns = NULL, skip = skip, id = id ) @@ -107,40 +95,34 @@ step_epi_ahead <- ..., ahead, role = "outcome", - trained = FALSE, prefix = "ahead_", default = NA, - columns = NULL, skip = FALSE, id = rand_id("epi_ahead")) { if (!is_epi_recipe(recipe)) { - cli::cli_abort("This step can only operate on an `epi_recipe`.") + cli_abort("This step can only operate on an `epi_recipe`.") } if (missing(ahead)) { - cli::cli_abort(c( + cli_abort(c( "The `ahead` argument must not be empty.", i = "Did you perhaps pass an integer in `...` accidentally?" )) } arg_is_nonneg_int(ahead) arg_is_chr_scalar(prefix, id) - if (!is.null(columns)) { - rlang::abort(c("The `columns` argument must be `NULL.", - i = "Use `tidyselect` methods to choose columns to lead." - )) - } - add_step( + + recipes::add_step( recipe, step_epi_ahead_new( terms = enquos(...), role = role, - trained = trained, + trained = FALSE, ahead = as.integer(ahead), prefix = prefix, default = default, keys = key_colnames(recipe), - columns = columns, + columns = NULL, skip = skip, id = id ) @@ -151,7 +133,7 @@ step_epi_ahead <- step_epi_lag_new <- function(terms, role, trained, lag, prefix, default, keys, columns, skip, id) { - step( + recipes::step( subclass = "epi_lag", terms = terms, role = role, @@ -169,7 +151,7 @@ step_epi_lag_new <- step_epi_ahead_new <- function(terms, role, trained, ahead, prefix, default, keys, columns, skip, id) { - step( + recipes::step( subclass = "epi_ahead", terms = terms, role = role, @@ -196,7 +178,7 @@ prep.step_epi_lag <- function(x, training, info = NULL, ...) { prefix = x$prefix, default = x$default, keys = x$keys, - columns = recipes_eval_select(x$terms, training, info), + columns = recipes::recipes_eval_select(x$terms, training, info), skip = x$skip, id = x$id ) @@ -212,7 +194,7 @@ prep.step_epi_ahead <- function(x, training, info = NULL, ...) { prefix = x$prefix, default = x$default, keys = x$keys, - columns = recipes_eval_select(x$terms, training, info), + columns = recipes::recipes_eval_select(x$terms, training, info), skip = x$skip, id = x$id ) @@ -223,7 +205,7 @@ prep.step_epi_ahead <- function(x, training, info = NULL, ...) { #' @export bake.step_epi_lag <- function(object, new_data, ...) { grid <- tidyr::expand_grid(col = object$columns, lag = object$lag) %>% - dplyr::mutate( + mutate( newname = glue::glue("{object$prefix}{lag}_{col}"), shift_val = lag, lag = NULL @@ -233,32 +215,28 @@ bake.step_epi_lag <- function(object, new_data, ...) { new_data_names <- colnames(new_data) intersection <- new_data_names %in% grid$newname if (any(intersection)) { - rlang::abort( - paste0( - "Name collision occured in `", class(object)[1], - "`. The following variable names already exists: ", - paste0(new_data_names[intersection], collapse = ", "), - "." - ) - ) + cli_abort(c( + "Name collision occured in {.cls {class(object)[1]}}", + "The following variable name{?s} already exist{?s/}: {.val {new_data_names[intersection]}}." + )) } ok <- object$keys shifted <- reduce( pmap(grid, epi_shift_single, x = new_data, key_cols = ok), - dplyr::full_join, + full_join, by = ok ) - dplyr::full_join(new_data, shifted, by = ok) %>% - dplyr::group_by(dplyr::across(dplyr::all_of(ok[-1]))) %>% - dplyr::arrange(time_value) %>% - dplyr::ungroup() + full_join(new_data, shifted, by = ok) %>% + group_by(across(all_of(ok[-1]))) %>% + arrange(time_value) %>% + ungroup() } #' @export bake.step_epi_ahead <- function(object, new_data, ...) { grid <- tidyr::expand_grid(col = object$columns, ahead = object$ahead) %>% - dplyr::mutate( + mutate( newname = glue::glue("{object$prefix}{ahead}_{col}"), shift_val = -ahead, ahead = NULL @@ -268,26 +246,22 @@ bake.step_epi_ahead <- function(object, new_data, ...) { new_data_names <- colnames(new_data) intersection <- new_data_names %in% grid$newname if (any(intersection)) { - rlang::abort( - paste0( - "Name collision occured in `", class(object)[1], - "`. The following variable names already exists: ", - paste0(new_data_names[intersection], collapse = ", "), - "." - ) - ) + cli_abort(c( + "Name collision occured in {.cls {class(object)[1]}}", + "The following variable name{?s} already exist{?s/}: {.val {new_data_names[intersection]}}." + )) } ok <- object$keys shifted <- reduce( pmap(grid, epi_shift_single, x = new_data, key_cols = ok), - dplyr::full_join, + full_join, by = ok ) - dplyr::full_join(new_data, shifted, by = ok) %>% - dplyr::group_by(dplyr::across(dplyr::all_of(ok[-1]))) %>% - dplyr::arrange(time_value) %>% - dplyr::ungroup() + full_join(new_data, shifted, by = ok) %>% + group_by(across(all_of(ok[-1]))) %>% + arrange(time_value) %>% + ungroup() } diff --git a/R/step_epi_slide.R b/R/step_epi_slide.R index a8ad66c85..180be3d51 100644 --- a/R/step_epi_slide.R +++ b/R/step_epi_slide.R @@ -22,7 +22,6 @@ #' @param before,after the size of the sliding window on the left and the right #' of the center. Usually non-negative integers for data indexed by date, but #' more restrictive in other cases (see [epiprocess::epi_slide()] for details). -#' @param prefix A character string that will be prefixed to the new column. #' @param f_name a character string of at most 20 characters that describes #' the function. This will be combined with `prefix` and the columns in `...` #' to name the result using `{prefix}{f_name}_{column}`. By default it will be determined @@ -53,7 +52,7 @@ step_epi_slide <- skip = FALSE, id = rand_id("epi_slide")) { if (!is_epi_recipe(recipe)) { - rlang::abort("This recipe step can only operate on an `epi_recipe`.") + cli_abort("This recipe step can only operate on an {.cls epi_recipe}.") } .f <- validate_slide_fun(.f) epiprocess:::validate_slide_window_arg(before, attributes(recipe$template)$metadata$time_type) @@ -61,7 +60,7 @@ step_epi_slide <- arg_is_chr_scalar(role, prefix, id) arg_is_lgl_scalar(skip) - add_step( + recipes::add_step( recipe, step_epi_slide_new( terms = enquos(...), @@ -94,7 +93,7 @@ step_epi_slide_new <- columns, skip, id) { - step( + recipes::step( subclass = "epi_slide", terms = terms, before = before, @@ -116,7 +115,7 @@ step_epi_slide_new <- prep.step_epi_slide <- function(x, training, info = NULL, ...) { col_names <- recipes::recipes_eval_select(x$terms, data = training, info = info) - check_type(training[, col_names], types = c("double", "integer")) + recipes::check_type(training[, col_names], types = c("double", "integer")) step_epi_slide_new( terms = x$terms, @@ -147,7 +146,7 @@ bake.step_epi_slide <- function(object, new_data, ...) { if (any(intersection)) { nms <- new_data_names[intersection] cli_abort( - c("In `step_epi_slide()` a name collision occurred. The following variable names already exist:", + c("In `step_epi_slide()` a name collision occurred. The following variable name{?s} already exist{?/s}:", `*` = "{.var {nms}}" ), call = caller_env(), diff --git a/R/step_growth_rate.R b/R/step_growth_rate.R index 48d8b4394..e1950d208 100644 --- a/R/step_growth_rate.R +++ b/R/step_growth_rate.R @@ -8,13 +8,11 @@ #' @param horizon Bandwidth for the sliding window, when `method` is #' "rel_change" or "linear_reg". See [epiprocess::growth_rate()] for more #' details. -#' @param method Either "rel_change", "linear_reg", "smooth_spline", or -#' "trend_filter", indicating the method to use for the growth rate -#' calculation. The first two are local methods: they are run in a sliding +#' @param method Either "rel_change" or "linear_reg", +#' indicating the method to use for the growth rate +#' calculation. These are local methods: they are run in a sliding #' fashion over the sequence (in order to estimate derivatives and hence -#' growth rates); the latter two are global methods: they are run once over -#' the entire sequence. See [epiprocess::growth_rate()] for more -#' details. +#' growth rates). See [epiprocess::growth_rate()] for more details. #' @param log_scale Should growth rates be estimated using the parameterization #' on the log scale? See details for an explanation. Default is `FALSE`. #' @param replace_Inf Sometimes, the growth rate calculation can result in @@ -45,58 +43,49 @@ step_growth_rate <- function(recipe, ..., role = "predictor", - trained = FALSE, horizon = 7, - method = c("rel_change", "linear_reg", "smooth_spline", "trend_filter"), + method = c("rel_change", "linear_reg"), log_scale = FALSE, replace_Inf = NA, prefix = "gr_", - columns = NULL, skip = FALSE, id = rand_id("growth_rate"), additional_gr_args_list = list()) { if (!is_epi_recipe(recipe)) { - rlang::abort("This recipe step can only operate on an `epi_recipe`.") + cli_abort("This recipe step can only operate on an {.cls epi_recipe}.") } - method <- match.arg(method) + method <- rlang::arg_match(method) arg_is_pos_int(horizon) arg_is_scalar(horizon) if (!is.null(replace_Inf)) { - if (length(replace_Inf) != 1L) rlang::abort("replace_Inf must be a scalar.") + if (length(replace_Inf) != 1L) cli_abort("replace_Inf must be a scalar.") if (!is.na(replace_Inf)) arg_is_numeric(replace_Inf) } arg_is_chr(role) arg_is_chr_scalar(prefix, id) - arg_is_lgl_scalar(trained, log_scale, skip) + arg_is_lgl_scalar(log_scale, skip) if (!is.list(additional_gr_args_list)) { - rlang::abort( - c("`additional_gr_args_list` must be a list.", - i = "See `?epiprocess::growth_rate` for available options." - ) - ) - } - - if (!is.null(columns)) { - rlang::abort(c("The `columns` argument must be `NULL.", - i = "Use `tidyselect` methods to choose columns to use." + cli_abort(c( + "`additional_gr_args_list` must be a {.cls list}.", + i = "See `?epiprocess::growth_rate` for available options." )) } - add_step( + recipes::add_step( recipe, step_growth_rate_new( terms = enquos(...), role = role, - trained = trained, + trained = FALSE, horizon = horizon, method = method, log_scale = log_scale, replace_Inf = replace_Inf, prefix = prefix, keys = key_colnames(recipe), - columns = columns, + columns = NULL, skip = skip, id = id, additional_gr_args_list = additional_gr_args_list @@ -119,7 +108,7 @@ step_growth_rate_new <- skip, id, additional_gr_args_list) { - step( + recipes::step( subclass = "growth_rate", terms = terms, role = role, @@ -151,7 +140,7 @@ prep.step_growth_rate <- function(x, training, info = NULL, ...) { replace_Inf = x$replace_Inf, prefix = x$prefix, keys = x$keys, - columns = recipes_eval_select(x$terms, training, info), + columns = recipes::recipes_eval_select(x$terms, training, info), skip = x$skip, id = x$id, additional_gr_args_list = x$additional_gr_args_list @@ -170,24 +159,23 @@ bake.step_growth_rate <- function(object, new_data, ...) { new_data_names <- colnames(new_data) intersection <- new_data_names %in% newnames if (any(intersection)) { - rlang::abort( - c(paste0("Name collision occured in `", class(object)[1], "`."), - i = paste( - "The following variable names already exists: ", - paste0(new_data_names[intersection], collapse = ", "), - "." - ) - ) + nms <- new_data_names[intersection] + cli_abort( + c("In `step_growth_rate()` a name collision occurred. The following variable name{?s} already exist{?/s}:", + `*` = "{.var {nms}}" + ), + call = caller_env(), + class = "epipredict__step__name_collision_error" ) } ok <- object$keys gr <- new_data %>% - dplyr::group_by(dplyr::across(dplyr::all_of(ok[-1]))) %>% + group_by(across(all_of(ok[-1]))) %>% dplyr::transmute( time_value = time_value, - dplyr::across( - dplyr::all_of(object$columns), + across( + all_of(object$columns), ~ epiprocess::growth_rate( time_value, .x, method = object$method, @@ -197,23 +185,18 @@ bake.step_growth_rate <- function(object, new_data, ...) { .names = "{object$prefix}{object$horizon}_{object$method}_{.col}" ) ) %>% - dplyr::ungroup() %>% - dplyr::mutate(time_value = time_value + object$horizon) # shift x0 right + ungroup() %>% + mutate(time_value = time_value + object$horizon) # shift x0 right if (!is.null(object$replace_Inf)) { gr <- gr %>% - dplyr::mutate( - dplyr::across( - !dplyr::all_of(ok), - ~ vec_replace_inf(.x, object$replace_Inf) - ) - ) + mutate(across(all_of(ok), ~ vec_replace_inf(.x, object$replace_Inf))) } - dplyr::left_join(new_data, gr, by = ok) %>% - dplyr::group_by(dplyr::across(dplyr::all_of(ok[-1]))) %>% - dplyr::arrange(time_value) %>% - dplyr::ungroup() + left_join(new_data, gr, by = ok) %>% + group_by(across(all_of(ok[-1]))) %>% + arrange(time_value) %>% + ungroup() } diff --git a/R/step_lag_difference.R b/R/step_lag_difference.R index 87852be2d..009ebe4f5 100644 --- a/R/step_lag_difference.R +++ b/R/step_lag_difference.R @@ -30,35 +30,27 @@ step_lag_difference <- trained = FALSE, horizon = 7, prefix = "lag_diff_", - columns = NULL, skip = FALSE, id = rand_id("lag_diff")) { if (!is_epi_recipe(recipe)) { - rlang::abort("This recipe step can only operate on an `epi_recipe`.") + cli_abort("This recipe step can only operate on an {.cls epi_recipe}.") } arg_is_pos_int(horizon) arg_is_chr(role) arg_is_chr_scalar(prefix, id) - arg_is_lgl_scalar(trained, skip) + arg_is_lgl_scalar(skip) - if (!is.null(columns)) { - rlang::abort( - c("The `columns` argument must be `NULL.", - i = "Use `tidyselect` methods to choose columns to use." - ) - ) - } - add_step( + recipes::add_step( recipe, step_lag_difference_new( terms = enquos(...), role = role, - trained = trained, + trained = FALSE, horizon = horizon, prefix = prefix, keys = key_colnames(recipe), - columns = columns, + columns = NULL, skip = skip, id = id ) @@ -76,7 +68,7 @@ step_lag_difference_new <- columns, skip, id) { - step( + recipes::step( subclass = "lag_difference", terms = terms, role = role, @@ -101,7 +93,7 @@ prep.step_lag_difference <- function(x, training, info = NULL, ...) { horizon = x$horizon, prefix = x$prefix, keys = x$keys, - columns = recipes_eval_select(x$terms, training, info), + columns = recipes::recipes_eval_select(x$terms, training, info), skip = x$skip, id = x$id ) @@ -109,47 +101,46 @@ prep.step_lag_difference <- function(x, training, info = NULL, ...) { epi_shift_single_diff <- function(x, col, horizon, newname, key_cols) { - x <- x %>% dplyr::select(tidyselect::all_of(c(key_cols, col))) + x <- x %>% select(all_of(c(key_cols, col))) y <- x %>% - dplyr::mutate(time_value = time_value + horizon) %>% - dplyr::rename(!!newname := {{ col }}) - x <- dplyr::left_join(x, y, by = key_cols) + mutate(time_value = time_value + horizon) %>% + rename(!!newname := {{ col }}) + x <- left_join(x, y, by = key_cols) x[, newname] <- x[, col] - x[, newname] - x %>% dplyr::select(tidyselect::all_of(c(key_cols, newname))) + x %>% select(all_of(c(key_cols, newname))) } #' @export bake.step_lag_difference <- function(object, new_data, ...) { grid <- tidyr::expand_grid(col = object$columns, horizon = object$horizon) %>% - dplyr::mutate(newname = glue::glue("{object$prefix}{horizon}_{col}")) + mutate(newname = glue::glue("{object$prefix}{horizon}_{col}")) ## ensure no name clashes new_data_names <- colnames(new_data) intersection <- new_data_names %in% grid$newname if (any(intersection)) { - rlang::abort( - c(paste0("Name collision occured in `", class(object)[1], "`."), - i = paste( - "The following variable names already exists: ", - paste0(new_data_names[intersection], collapse = ", "), - "." - ) - ) + nms <- new_data_names[intersection] + cli_abort( + c("In `step_lag_difference()` a name collision occurred. The following variable name{?s} already exist{?/s}:", + `*` = "{.var {nms}}" + ), + call = caller_env(), + class = "epipredict__step__name_collision_error" ) } ok <- object$keys shifted <- reduce( pmap(grid, epi_shift_single_diff, x = new_data, key_cols = ok), - dplyr::full_join, + full_join, by = ok ) - dplyr::left_join(new_data, shifted, by = ok) %>% - dplyr::group_by(dplyr::across(tidyselect::all_of(ok[-1]))) %>% - dplyr::arrange(time_value) %>% - dplyr::ungroup() + left_join(new_data, shifted, by = ok) %>% + group_by(across(all_of(ok[-1]))) %>% + arrange(time_value) %>% + ungroup() } diff --git a/R/step_population_scaling.R b/R/step_population_scaling.R index 7f2d44ab9..e5baa837b 100644 --- a/R/step_population_scaling.R +++ b/R/step_population_scaling.R @@ -11,17 +11,7 @@ #' passed will *divide* the selected variables while the `rate_rescaling` #' argument is a common *multiplier* of the selected variables. #' -#' @param recipe A recipe object. The step will be added to the sequence of -#' operations for this recipe. The recipe should contain information about the -#' `epi_df` such as column names. -#' @param ... One or more selector functions to scale variables -#' for this step. See [recipes::selections()] for more details. -#' @param role For model terms created by this step, what analysis role should -#' they be assigned? By default, the new columns created by this step from the -#' original variables will be used as predictors in a model. Other options can -#' be ard are not limited to "outcome". -#' @param trained A logical to indicate if the quantities for preprocessing -#' have been estimated. +#' @inheritParams step_epi_lag #' @param df a data frame that contains the population data to be used for #' inverting the existing scaling. #' @param by A (possibly named) character vector of variables to join by. @@ -49,25 +39,15 @@ #' @param create_new TRUE to create a new column and keep the original column #' in the `epi_df` #' @param suffix a character. The suffix added to the column name if -#' `crete_new = TRUE`. Default to "_scaled". -#' @param columns A character string of variable names that will -#' be populated (eventually) by the `terms` argument. -#' @param skip A logical. Should the step be skipped when the -#' recipe is baked by [bake()]? While all operations are baked -#' when [prep()] is run, some operations may not be able to be -#' conducted on new data (e.g. processing the outcome variable(s)). -#' Care should be taken when using `skip = TRUE` as it may affect -#' the computations for subsequent operations. -#' @param id A unique identifier for the step +#' `create_new = TRUE`. Default to "_scaled". #' #' @return Scales raw data by the population #' @export #' @examples -#' library(epiprocess) -#' library(epipredict) -#' jhu <- epiprocess::jhu_csse_daily_subset %>% -#' dplyr::filter(time_value > "2021-11-01", geo_value %in% c("ca", "ny")) %>% -#' dplyr::select(geo_value, time_value, cases) +#' library(dplyr) +#' jhu <- jhu_csse_daily_subset %>% +#' filter(time_value > "2021-11-01", geo_value %in% c("ca", "ny")) %>% +#' select(geo_value, time_value, cases) #' #' pop_data <- data.frame(states = c("ca", "ny"), value = c(20000, 30000)) #' @@ -92,7 +72,7 @@ #' df_pop_col = "value" #' ) #' -#' wf <- epi_workflow(r, parsnip::linear_reg()) %>% +#' wf <- epi_workflow(r, linear_reg()) %>% #' fit(jhu) %>% #' add_frosting(f) #' @@ -101,37 +81,35 @@ step_population_scaling <- function(recipe, ..., role = "raw", - trained = FALSE, df, by = NULL, df_pop_col, rate_rescaling = 1, create_new = TRUE, suffix = "_scaled", - columns = NULL, skip = FALSE, id = rand_id("population_scaling")) { - arg_is_scalar(role, trained, df_pop_col, rate_rescaling, create_new, suffix, id) + arg_is_scalar(role, df_pop_col, rate_rescaling, create_new, suffix, id) arg_is_lgl(create_new, skip) arg_is_chr(df_pop_col, suffix, id) - arg_is_chr(by, columns, allow_null = TRUE) + arg_is_chr(by, allow_null = TRUE) if (rate_rescaling <= 0) { - cli_stop("`rate_rescaling` should be a positive number") + cli_abort("`rate_rescaling` must be a positive number.") } - add_step( + recipes::add_step( recipe, step_population_scaling_new( - terms = dplyr::enquos(...), + terms = enquos(...), role = role, - trained = trained, + trained = FALSE, df = df, by = by, df_pop_col = df_pop_col, rate_rescaling = rate_rescaling, create_new = create_new, suffix = suffix, - columns = columns, + columns = NULL, skip = skip, id = id ) @@ -141,7 +119,7 @@ step_population_scaling <- step_population_scaling_new <- function(role, trained, df, by, df_pop_col, rate_rescaling, terms, create_new, suffix, columns, skip, id) { - step( + recipes::step( subclass = "population_scaling", terms = terms, role = role, @@ -170,7 +148,7 @@ prep.step_population_scaling <- function(x, training, info = NULL, ...) { rate_rescaling = x$rate_rescaling, create_new = x$create_new, suffix = x$suffix, - columns = recipes_eval_select(x$terms, training, info), + columns = recipes::recipes_eval_select(x$terms, training, info), skip = x$skip, id = x$id ) @@ -185,15 +163,10 @@ bake.step_population_scaling <- function(object, length(object$df_pop_col) == 1 ) - try_join <- try(dplyr::left_join(new_data, object$df, by = object$by), - silent = TRUE - ) - if (any(grepl("Join columns must be present in data", unlist(try_join)))) { - cli_stop(c( - "columns in `by` selectors of `step_population_scaling` ", - "must be present in data and match" - )) - } + + hardhat::validate_column_names(new_data, object$by) + hardhat::validate_column_names(object$df, object$by) + if (object$suffix != "_scaled" && object$create_new == FALSE) { cli::cli_warn(c( @@ -202,23 +175,22 @@ bake.step_population_scaling <- function(object, )) } - object$df <- object$df %>% - dplyr::mutate(dplyr::across(tidyselect::where(is.character), tolower)) + object$df <- mutate(object$df, across(dplyr::where(is.character), tolower)) pop_col <- rlang::sym(object$df_pop_col) suffix <- ifelse(object$create_new, object$suffix, "") col_to_remove <- setdiff(colnames(object$df), colnames(new_data)) - dplyr::left_join(new_data, - object$df, - by = object$by, suffix = c("", ".df") - ) %>% - dplyr::mutate(dplyr::across(dplyr::all_of(object$columns), - ~ .x * object$rate_rescaling / !!pop_col, - .names = "{.col}{suffix}" - )) %>% + left_join(new_data, object$df, by = object$by, suffix = c("", ".df")) %>% + mutate( + across( + all_of(object$columns), + ~ .x * object$rate_rescaling / !!pop_col, + .names = "{.col}{suffix}" + ) + ) %>% # removed so the models do not use the population column - dplyr::select(-dplyr::any_of(col_to_remove)) + select(any_of(col_to_remove)) } #' @export diff --git a/R/step_training_window.R b/R/step_training_window.R index 90de468ce..e66b08365 100644 --- a/R/step_training_window.R +++ b/R/step_training_window.R @@ -5,18 +5,13 @@ #' observations in `time_value` per group, where the groups are formed #' based on the remaining `epi_keys`. #' -#' @param recipe A recipe object. The step will be added to the -#' sequence of operations for this recipe. -#' @param role Not used by this step since no new variables are created. -#' @param trained A logical to indicate if the quantities for -#' preprocessing have been estimated. #' @param n_recent An integer value that represents the number of most recent #' observations that are to be kept in the training window per key #' The default value is 50. #' @param epi_keys An optional character vector for specifying "key" variables #' to group on. The default, `NULL`, ensures that every key combination is #' limited. -#' @param id A character string that is unique to this step to identify it. +#' @inheritParams step_epi_lag #' @template step-return #' #' @details Note that `step_epi_lead()` and `step_epi_lag()` should come @@ -25,13 +20,10 @@ #' @export #' #' @examples -#' tib <- tibble::tibble( +#' tib <- tibble( #' x = 1:10, #' y = 1:10, -#' time_value = rep(seq(as.Date("2020-01-01"), -#' by = 1, -#' length.out = 5 -#' ), times = 2), +#' time_value = rep(seq(as.Date("2020-01-01"), by = 1, length.out = 5), 2), #' geo_value = rep(c("ca", "hi"), each = 5) #' ) %>% #' as_epi_df() @@ -42,18 +34,16 @@ #' bake(new_data = NULL) #' #' epi_recipe(y ~ x, data = tib) %>% -#' recipes::step_naomit() %>% +#' step_epi_naomit() %>% #' step_training_window(n_recent = 3) %>% #' prep(tib) %>% #' bake(new_data = NULL) step_training_window <- function(recipe, role = NA, - trained = FALSE, n_recent = 50, epi_keys = NULL, id = rand_id("training_window")) { - arg_is_lgl_scalar(trained) arg_is_scalar(n_recent, id) arg_is_pos(n_recent) if (is.finite(n_recent)) arg_is_pos_int(n_recent) @@ -63,7 +53,7 @@ step_training_window <- recipe, step_training_window_new( role = role, - trained = trained, + trained = FALSE, n_recent = n_recent, epi_keys = epi_keys, skip = TRUE, @@ -108,10 +98,10 @@ bake.step_training_window <- function(object, new_data, ...) { if (object$n_recent < Inf) { new_data <- new_data %>% - dplyr::group_by(dplyr::across(dplyr::all_of(object$epi_keys))) %>% - dplyr::arrange(time_value) %>% + group_by(across(all_of(object$epi_keys))) %>% + arrange(time_value) %>% dplyr::slice_tail(n = object$n_recent) %>% - dplyr::ungroup() + ungroup() } new_data @@ -122,10 +112,7 @@ print.step_training_window <- function(x, width = max(20, options()$width - 30), ...) { title <- "# of recent observations per key limited to:" n_recent <- x$n_recent - tr_obj <- format_selectors(rlang::enquos(n_recent), width) - recipes::print_step( - tr_obj, rlang::enquos(n_recent), - x$trained, title, width - ) + tr_obj <- recipes::format_selectors(rlang::enquos(n_recent), width) + recipes::print_step(tr_obj, rlang::enquos(n_recent), x$trained, title, width) invisible(x) } diff --git a/R/tidy.R b/R/tidy.R index 06835eff0..a239a8121 100644 --- a/R/tidy.R +++ b/R/tidy.R @@ -26,8 +26,9 @@ #' `type` (the method, e.g. "predict", "naomit"), and a character column `id`. #' #' @examples +#' library(dplyr) #' jhu <- case_death_rate_subset %>% -#' dplyr::filter(time_value > "2021-11-01", geo_value %in% c("ak", "ca", "ny")) +#' filter(time_value > "2021-11-01", geo_value %in% c("ak", "ca", "ny")) #' #' r <- epi_recipe(jhu) %>% #' step_epi_lag(death_rate, lag = c(0, 7, 14)) %>% @@ -52,21 +53,17 @@ tidy.frosting <- function(x, number = NA, id = NA, ...) { num_oper <- length(x$layers) pattern <- "^layer_" - if (length(id) != 1L) { - rlang::abort("If `id` is provided, it must be a length 1 character vector.") - } - - if (length(number) != 1L) { - rlang::abort("If `number` is provided, it must be a length 1 integer vector.") - } + arg_is_chr_scalar(id, allow_na = TRUE) + arg_is_scalar(number, allow_na = TRUE) + arg_is_int(number, allow_na = TRUE) if (!is.na(id)) { if (!is.na(number)) { - rlang::abort("You may specify `number` or `id`, but not both.") + cli_abort("You may specify `number` or `id`, but not both.") } layer_ids <- vapply(x$layers, function(x) x$id, character(1)) if (!(id %in% layer_ids)) { - rlang::abort("Supplied `id` not found in the frosting.") + cli_abort("Supplied `id` not found in the frosting.") } number <- which(id == layer_ids) } @@ -89,13 +86,7 @@ tidy.frosting <- function(x, number = NA, id = NA, ...) { ) } else { if (number > num_oper || length(number) > 1) { - rlang::abort( - paste0( - "`number` should be a single value between 1 and ", - num_oper, - "." - ) - ) + cli_abort("`number` should be a single value between 1 and {num_oper}.") } res <- tidy(x$layers[[number]], ...) diff --git a/R/time_types.R b/R/time_types.R index 7fe3e47b4..f33974833 100644 --- a/R/time_types.R +++ b/R/time_types.R @@ -64,7 +64,7 @@ validate_date <- function(x, expected, arg = rlang::caller_arg(x), time_type_x <- guess_time_type(x) ok <- time_type_x == expected if (!ok) { - cli::cli_abort(c( + cli_abort(c( "The {.arg {arg}} was given as a {.val {time_type_x}} while the", `!` = "`time_type` of the training data was {.val {expected}}.", i = "See {.topic epiprocess::epi_df} for descriptions of these are determined." diff --git a/R/utils-misc.R b/R/utils-misc.R index 231a8f60f..c59afc19a 100644 --- a/R/utils-misc.R +++ b/R/utils-misc.R @@ -48,13 +48,13 @@ grab_forged_keys <- function(forged, mold, new_data) { "in `new_data`. Predictions will have only the available keys." )) } - if (epiprocess::is_epi_df(new_data)) { - extras <- epiprocess::as_epi_df(extras) + if (is_epi_df(new_data)) { + extras <- as_epi_df(extras) attr(extras, "metadata") <- attr(new_data, "metadata") } else if (all(keys[1:2] %in% new_keys)) { l <- list() if (length(new_keys) > 2) l <- list(other_keys = new_keys[-c(1:2)]) - extras <- epiprocess::as_epi_df(extras, additional_metadata = l) + extras <- as_epi_df(extras, additional_metadata = l) } extras } @@ -64,11 +64,10 @@ get_parsnip_mode <- function(trainer) { return(trainer$mode) } cc <- class(trainer) - cli::cli_abort( - c("`trainer` must be a `parsnip` model.", - i = "This trainer has class(s) {.cls {cc}}." - ) - ) + cli_abort(c( + "`trainer` must be a `parsnip` model.", + i = "This trainer has class{?s}: {.cls {cc}}." + )) } is_classification <- function(trainer) { diff --git a/man/Add_model.Rd b/man/Add_model.Rd index 6bf6b6b02..17b65793c 100644 --- a/man/Add_model.Rd +++ b/man/Add_model.Rd @@ -71,11 +71,9 @@ aliases with the lower-case names. However, in the event that properly. } \examples{ +library(dplyr) jhu <- case_death_rate_subset \%>\% - dplyr::filter( - time_value > "2021-11-01", - geo_value \%in\% c("ak", "ca", "ny") - ) + filter(time_value > "2021-11-01", geo_value \%in\% c("ak", "ca", "ny")) r <- epi_recipe(jhu) \%>\% step_epi_lag(death_rate, lag = c(0, 7, 14)) \%>\% @@ -88,7 +86,7 @@ wf <- epi_workflow(r) wf <- wf \%>\% Add_model(rf_model) wf -lm_model <- parsnip::linear_reg() +lm_model <- linear_reg() wf <- Update_model(wf, lm_model) wf diff --git a/man/layer_add_forecast_date.Rd b/man/layer_add_forecast_date.Rd index 4e173d662..e27f2bacd 100644 --- a/man/layer_add_forecast_date.Rd +++ b/man/layer_add_forecast_date.Rd @@ -36,15 +36,16 @@ less than the maximum \code{as_of} value (from the data used pre-processing, model fitting, and postprocessing), an appropriate warning will be thrown. } \examples{ +library(dplyr) jhu <- case_death_rate_subset \%>\% - dplyr::filter(time_value > "2021-11-01", geo_value \%in\% c("ak", "ca", "ny")) + filter(time_value > "2021-11-01", geo_value \%in\% c("ak", "ca", "ny")) r <- epi_recipe(jhu) \%>\% step_epi_lag(death_rate, lag = c(0, 7, 14)) \%>\% step_epi_ahead(death_rate, ahead = 7) \%>\% step_epi_naomit() -wf <- epi_workflow(r, parsnip::linear_reg()) \%>\% fit(jhu) +wf <- epi_workflow(r, linear_reg()) \%>\% fit(jhu) latest <- jhu \%>\% - dplyr::filter(time_value >= max(time_value) - 14) + filter(time_value >= max(time_value) - 14) # Don't specify `forecast_date` (by default, this should be last date in latest) f <- frosting() \%>\% diff --git a/man/layer_add_target_date.Rd b/man/layer_add_target_date.Rd index 5b32002d1..9dc6abbdd 100644 --- a/man/layer_add_target_date.Rd +++ b/man/layer_add_target_date.Rd @@ -37,14 +37,15 @@ has been specified in a preprocessing step (most likely in in the test data to get the target date. } \examples{ +library(dplyr) jhu <- case_death_rate_subset \%>\% - dplyr::filter(time_value > "2021-11-01", geo_value \%in\% c("ak", "ca", "ny")) + dfilter(time_value > "2021-11-01", geo_value \%in\% c("ak", "ca", "ny")) r <- epi_recipe(jhu) \%>\% step_epi_lag(death_rate, lag = c(0, 7, 14)) \%>\% step_epi_ahead(death_rate, ahead = 7) \%>\% step_epi_naomit() -wf <- epi_workflow(r, parsnip::linear_reg()) \%>\% fit(jhu) +wf <- epi_workflow(r, linear_reg()) \%>\% fit(jhu) # Use ahead + forecast date f <- frosting() \%>\% diff --git a/man/layer_cdc_flatline_quantiles.Rd b/man/layer_cdc_flatline_quantiles.Rd index 5653f9691..c3bc4f257 100644 --- a/man/layer_cdc_flatline_quantiles.Rd +++ b/man/layer_cdc_flatline_quantiles.Rd @@ -84,6 +84,7 @@ the future. This version continues to use the same set of residuals, and adds them on to produce wider intervals as \code{ahead} increases. } \examples{ +library(dplyr) r <- epi_recipe(case_death_rate_subset) \%>\% # data is "daily", so we fit this to 1 ahead, the result will contain # 1 day ahead residuals @@ -97,16 +98,16 @@ f <- frosting() \%>\% layer_predict() \%>\% layer_cdc_flatline_quantiles(aheads = c(7, 14, 21, 28), symmetrize = TRUE) -eng <- parsnip::linear_reg() \%>\% parsnip::set_engine("flatline") +eng <- linear_reg(engine = "flatline") wf <- epi_workflow(r, eng, f) \%>\% fit(case_death_rate_subset) preds <- forecast(wf) \%>\% - dplyr::select(-time_value) \%>\% - dplyr::mutate(forecast_date = forecast_date) + select(-time_value) \%>\% + mutate(forecast_date = forecast_date) preds preds <- preds \%>\% - unnest(.pred_distn_all) \%>\% + tidyr::unnest(.pred_distn_all) \%>\% pivot_quantiles_wider(.pred_distn) \%>\% mutate(target_date = forecast_date + ahead) diff --git a/man/layer_naomit.Rd b/man/layer_naomit.Rd index e3325fe7c..d77112f95 100644 --- a/man/layer_naomit.Rd +++ b/man/layer_naomit.Rd @@ -24,14 +24,15 @@ an updated \code{frosting} postprocessor Omit \code{NA}s from predictions or other columns } \examples{ +library(dplyr) jhu <- case_death_rate_subset \%>\% - dplyr::filter(time_value > "2021-11-01", geo_value \%in\% c("ak", "ca", "ny")) + filter(time_value > "2021-11-01", geo_value \%in\% c("ak", "ca", "ny")) r <- epi_recipe(jhu) \%>\% step_epi_lag(death_rate, lag = c(0, 7, 14)) \%>\% step_epi_ahead(death_rate, ahead = 7) -wf <- epi_workflow(r, parsnip::linear_reg()) \%>\% fit(jhu) +wf <- epi_workflow(r, linear_reg()) \%>\% fit(jhu) f <- frosting() \%>\% layer_predict() \%>\% diff --git a/man/layer_point_from_distn.Rd b/man/layer_point_from_distn.Rd index 58d8add8b..276f7cb17 100644 --- a/man/layer_point_from_distn.Rd +++ b/man/layer_point_from_distn.Rd @@ -34,15 +34,17 @@ information, so one should usually call this AFTER \code{layer_quantile_distn()} or set the \code{name} argument to something specific. } \examples{ +library(dplyr) jhu <- case_death_rate_subset \%>\% - dplyr::filter(time_value > "2021-11-01", geo_value \%in\% c("ak", "ca", "ny")) + filter(time_value > "2021-11-01", geo_value \%in\% c("ak", "ca", "ny")) r <- epi_recipe(jhu) \%>\% step_epi_lag(death_rate, lag = c(0, 7, 14)) \%>\% step_epi_ahead(death_rate, ahead = 7) \%>\% step_epi_naomit() -wf <- epi_workflow(r, quantile_reg(quantile_levels = c(.25, .5, .75))) \%>\% fit(jhu) +wf <- epi_workflow(r, quantile_reg(quantile_levels = c(.25, .5, .75))) \%>\% + fit(jhu) f1 <- frosting() \%>\% layer_predict() \%>\% diff --git a/man/layer_population_scaling.Rd b/man/layer_population_scaling.Rd index cf8dfcc1a..5a105f208 100644 --- a/man/layer_population_scaling.Rd +++ b/man/layer_population_scaling.Rd @@ -74,9 +74,10 @@ passed will \emph{multiply} the selected variables while the \code{rate_rescalin argument is a common \emph{divisor} of the selected variables. } \examples{ -jhu <- epiprocess::jhu_csse_daily_subset \%>\% - dplyr::filter(time_value > "2021-11-01", geo_value \%in\% c("ca", "ny")) \%>\% - dplyr::select(geo_value, time_value, cases) +library(dplyr) +jhu <- jhu_csse_daily_subset \%>\% + filter(time_value > "2021-11-01", geo_value \%in\% c("ca", "ny")) \%>\% + select(geo_value, time_value, cases) pop_data <- data.frame(states = c("ca", "ny"), value = c(20000, 30000)) @@ -101,7 +102,7 @@ f <- frosting() \%>\% df_pop_col = "value" ) -wf <- epi_workflow(r, parsnip::linear_reg()) \%>\% +wf <- epi_workflow(r, linear_reg()) \%>\% fit(jhu) \%>\% add_frosting(f) diff --git a/man/layer_predict.Rd b/man/layer_predict.Rd index 03473053f..8ae92f4c8 100644 --- a/man/layer_predict.Rd +++ b/man/layer_predict.Rd @@ -58,6 +58,7 @@ postprocessing. This would typically be the first layer in a \code{frosting} postprocessor. } \examples{ +library(dplyr) jhu <- case_death_rate_subset \%>\% filter(time_value > "2021-11-01", geo_value \%in\% c("ak", "ca", "ny")) @@ -66,7 +67,7 @@ r <- epi_recipe(jhu) \%>\% step_epi_ahead(death_rate, ahead = 7) \%>\% step_epi_naomit() -wf <- epi_workflow(r, parsnip::linear_reg()) \%>\% fit(jhu) +wf <- epi_workflow(r, linear_reg()) \%>\% fit(jhu) latest <- jhu \%>\% filter(time_value >= max(time_value) - 14) # Predict layer alone diff --git a/man/layer_predictive_distn.Rd b/man/layer_predictive_distn.Rd index 7cd4e4efc..240db5f5b 100644 --- a/man/layer_predictive_distn.Rd +++ b/man/layer_predictive_distn.Rd @@ -39,15 +39,16 @@ should be reasonably accurate for models fit using \code{lm} when the new point \verb{x*} isn't too far from the bulk of the data. } \examples{ +library(dplyr) jhu <- case_death_rate_subset \%>\% - dplyr::filter(time_value > "2021-11-01", geo_value \%in\% c("ak", "ca", "ny")) + filter(time_value > "2021-11-01", geo_value \%in\% c("ak", "ca", "ny")) r <- epi_recipe(jhu) \%>\% step_epi_lag(death_rate, lag = c(0, 7, 14)) \%>\% step_epi_ahead(death_rate, ahead = 7) \%>\% step_epi_naomit() -wf <- epi_workflow(r, parsnip::linear_reg()) \%>\% fit(jhu) +wf <- epi_workflow(r, linear_reg()) \%>\% fit(jhu) f <- frosting() \%>\% layer_predict() \%>\% diff --git a/man/layer_quantile_distn.Rd b/man/layer_quantile_distn.Rd index f5de4aa19..68192deee 100644 --- a/man/layer_quantile_distn.Rd +++ b/man/layer_quantile_distn.Rd @@ -45,8 +45,9 @@ If these engines were used, then this layer will grab out estimated (or extrapolated) quantiles at the requested quantile values. } \examples{ +library(dplyr) jhu <- case_death_rate_subset \%>\% - dplyr::filter(time_value > "2021-11-01", geo_value \%in\% c("ak", "ca", "ny")) + filter(time_value > "2021-11-01", geo_value \%in\% c("ak", "ca", "ny")) r <- epi_recipe(jhu) \%>\% step_epi_lag(death_rate, lag = c(0, 7, 14)) \%>\% diff --git a/man/layer_residual_quantiles.Rd b/man/layer_residual_quantiles.Rd index dd576aa5e..d300241d3 100644 --- a/man/layer_residual_quantiles.Rd +++ b/man/layer_residual_quantiles.Rd @@ -39,19 +39,23 @@ residual quantiles added to the prediction Creates predictions based on residual quantiles } \examples{ +library(dplyr) jhu <- case_death_rate_subset \%>\% - dplyr::filter(time_value > "2021-11-01", geo_value \%in\% c("ak", "ca", "ny")) + filter(time_value > "2021-11-01", geo_value \%in\% c("ak", "ca", "ny")) r <- epi_recipe(jhu) \%>\% step_epi_lag(death_rate, lag = c(0, 7, 14)) \%>\% step_epi_ahead(death_rate, ahead = 7) \%>\% step_epi_naomit() -wf <- epi_workflow(r, parsnip::linear_reg()) \%>\% fit(jhu) +wf <- epi_workflow(r, linear_reg()) \%>\% fit(jhu) f <- frosting() \%>\% layer_predict() \%>\% - layer_residual_quantiles(quantile_levels = c(0.0275, 0.975), symmetrize = FALSE) \%>\% + layer_residual_quantiles( + quantile_levels = c(0.0275, 0.975), + symmetrize = FALSE + ) \%>\% layer_naomit(.pred) wf1 <- wf \%>\% add_frosting(f) @@ -59,7 +63,10 @@ p <- forecast(wf1) f2 <- frosting() \%>\% layer_predict() \%>\% - layer_residual_quantiles(quantile_levels = c(0.3, 0.7), by_key = "geo_value") \%>\% + layer_residual_quantiles( + quantile_levels = c(0.3, 0.7), + by_key = "geo_value" + ) \%>\% layer_naomit(.pred) wf2 <- wf \%>\% add_frosting(f2) diff --git a/man/layer_threshold.Rd b/man/layer_threshold.Rd index dbd7e6669..0f4b1dfb7 100644 --- a/man/layer_threshold.Rd +++ b/man/layer_threshold.Rd @@ -40,14 +40,14 @@ smaller than the lower threshold or higher than the upper threshold equal to the threshold values. } \examples{ +library(dplyr) jhu <- case_death_rate_subset \%>\% - dplyr::filter(time_value < "2021-03-08", - geo_value \%in\% c("ak", "ca", "ar")) + filter(time_value < "2021-03-08", geo_value \%in\% c("ak", "ca", "ar")) r <- epi_recipe(jhu) \%>\% step_epi_lag(death_rate, lag = c(0, 7, 14)) \%>\% step_epi_ahead(death_rate, ahead = 7) \%>\% step_epi_naomit() -wf <- epi_workflow(r, parsnip::linear_reg()) \%>\% fit(jhu) +wf <- epi_workflow(r, linear_reg()) \%>\% fit(jhu) f <- frosting() \%>\% layer_predict() \%>\% diff --git a/man/nested_quantiles.Rd b/man/nested_quantiles.Rd index 143532650..b34b718ca 100644 --- a/man/nested_quantiles.Rd +++ b/man/nested_quantiles.Rd @@ -16,9 +16,11 @@ a list-col Turn a vector of quantile distributions into a list-col } \examples{ +library(dplyr) +library(tidyr) edf <- case_death_rate_subset[1:3, ] edf$q <- dist_quantiles(list(1:5, 2:4, 3:10), list(1:5 / 6, 2:4 / 5, 3:10 / 11)) -edf_nested <- edf \%>\% dplyr::mutate(q = nested_quantiles(q)) -edf_nested \%>\% tidyr::unnest(q) +edf_nested <- edf \%>\% mutate(q = nested_quantiles(q)) +edf_nested \%>\% unnest(q) } diff --git a/man/pivot_quantiles_longer.Rd b/man/pivot_quantiles_longer.Rd index f73e6deaf..9879d5d07 100644 --- a/man/pivot_quantiles_longer.Rd +++ b/man/pivot_quantiles_longer.Rd @@ -34,9 +34,9 @@ multiple columns are selected, these will be prefixed with the column name. \examples{ d1 <- c(dist_quantiles(1:3, 1:3 / 4), dist_quantiles(2:4, 1:3 / 4)) d2 <- c(dist_quantiles(2:4, 2:4 / 5), dist_quantiles(3:5, 2:4 / 5)) -tib <- tibble::tibble(g = c("a", "b"), d1 = d1, d2 = d2) +tib <- tibble(g = c("a", "b"), d1 = d1, d2 = d2) pivot_quantiles_longer(tib, "d1") -pivot_quantiles_longer(tib, tidyselect::ends_with("1")) +pivot_quantiles_longer(tib, dplyr::ends_with("1")) pivot_quantiles_longer(tib, d1, d2) } diff --git a/man/pivot_quantiles_wider.Rd b/man/pivot_quantiles_wider.Rd index 02a33bb2f..e477777ca 100644 --- a/man/pivot_quantiles_wider.Rd +++ b/man/pivot_quantiles_wider.Rd @@ -30,6 +30,6 @@ d2 <- c(dist_quantiles(2:4, 2:4 / 5), dist_quantiles(3:5, 2:4 / 5)) tib <- tibble::tibble(g = c("a", "b"), d1 = d1, d2 = d2) pivot_quantiles_wider(tib, c("d1", "d2")) -pivot_quantiles_wider(tib, tidyselect::starts_with("d")) +pivot_quantiles_wider(tib, dplyr::starts_with("d")) pivot_quantiles_wider(tib, d2) } diff --git a/man/reexports.Rd b/man/reexports.Rd index 1ac328b2c..f6849a53c 100644 --- a/man/reexports.Rd +++ b/man/reexports.Rd @@ -8,6 +8,9 @@ \alias{forecast} \alias{prep} \alias{bake} +\alias{rand_id} +\alias{tibble} +\alias{tidy} \title{Objects exported from other packages} \keyword{internal} \description{ @@ -15,10 +18,12 @@ These objects are imported from other packages. Follow the links below to see their documentation. \describe{ - \item{generics}{\code{\link[generics]{fit}}, \code{\link[generics]{forecast}}} + \item{generics}{\code{\link[generics]{fit}}, \code{\link[generics]{forecast}}, \code{\link[generics]{tidy}}} \item{ggplot2}{\code{\link[ggplot2]{autoplot}}} - \item{recipes}{\code{\link[recipes]{bake}}, \code{\link[recipes]{prep}}} + \item{recipes}{\code{\link[recipes]{bake}}, \code{\link[recipes]{prep}}, \code{\link[recipes]{rand_id}}} + + \item{tibble}{\code{\link[tibble]{tibble}}} }} diff --git a/man/step_epi_shift.Rd b/man/step_epi_shift.Rd index f4419b831..2bf22c15d 100644 --- a/man/step_epi_shift.Rd +++ b/man/step_epi_shift.Rd @@ -10,10 +10,8 @@ step_epi_lag( ..., lag, role = "predictor", - trained = FALSE, prefix = "lag_", default = NA, - columns = NULL, skip = FALSE, id = rand_id("epi_lag") ) @@ -23,10 +21,8 @@ step_epi_ahead( ..., ahead, role = "outcome", - trained = FALSE, prefix = "ahead_", default = NA, - columns = NULL, skip = FALSE, id = rand_id("epi_ahead") ) @@ -45,17 +41,11 @@ nonnegative, while ahead integers must be positive.} \item{role}{For model terms created by this step, what analysis role should they be assigned? \code{lag} is default a predictor while \code{ahead} is an outcome.} -\item{trained}{A logical to indicate if the quantities for -preprocessing have been estimated.} - -\item{prefix}{A prefix to indicate what type of variable this is} +\item{prefix}{A character string that will be prefixed to the new column.} \item{default}{Determines what fills empty rows left by leading/lagging (defaults to NA).} -\item{columns}{A character string of variable names that will -be populated (eventually) by the \code{terms} argument.} - \item{skip}{A logical. Should the step be skipped when the recipe is baked by \code{\link[=bake]{bake()}}? While all operations are baked when \code{\link[=prep]{prep()}} is run, some operations may not be able to be diff --git a/man/step_growth_rate.Rd b/man/step_growth_rate.Rd index 46d8b92f6..bc6da0bef 100644 --- a/man/step_growth_rate.Rd +++ b/man/step_growth_rate.Rd @@ -8,13 +8,11 @@ step_growth_rate( recipe, ..., role = "predictor", - trained = FALSE, horizon = 7, - method = c("rel_change", "linear_reg", "smooth_spline", "trend_filter"), + method = c("rel_change", "linear_reg"), log_scale = FALSE, replace_Inf = NA, prefix = "gr_", - columns = NULL, skip = FALSE, id = rand_id("growth_rate"), additional_gr_args_list = list() @@ -30,20 +28,15 @@ for this step. See \code{\link[recipes:selections]{recipes::selections()}} for m \item{role}{For model terms created by this step, what analysis role should they be assigned? \code{lag} is default a predictor while \code{ahead} is an outcome.} -\item{trained}{A logical to indicate if the quantities for -preprocessing have been estimated.} - \item{horizon}{Bandwidth for the sliding window, when \code{method} is "rel_change" or "linear_reg". See \code{\link[epiprocess:growth_rate]{epiprocess::growth_rate()}} for more details.} -\item{method}{Either "rel_change", "linear_reg", "smooth_spline", or -"trend_filter", indicating the method to use for the growth rate -calculation. The first two are local methods: they are run in a sliding +\item{method}{Either "rel_change" or "linear_reg", +indicating the method to use for the growth rate +calculation. These are local methods: they are run in a sliding fashion over the sequence (in order to estimate derivatives and hence -growth rates); the latter two are global methods: they are run once over -the entire sequence. See \code{\link[epiprocess:growth_rate]{epiprocess::growth_rate()}} for more -details.} +growth rates). See \code{\link[epiprocess:growth_rate]{epiprocess::growth_rate()}} for more details.} \item{log_scale}{Should growth rates be estimated using the parameterization on the log scale? See details for an explanation. Default is \code{FALSE}.} @@ -56,10 +49,7 @@ being removed from the data. Alternatively, you could specify arbitrary large values, or perhaps zero. Setting this argument to \code{NULL} will result in no replacement.} -\item{prefix}{A prefix to indicate what type of variable this is} - -\item{columns}{A character string of variable names that will -be populated (eventually) by the \code{terms} argument.} +\item{prefix}{A character string that will be prefixed to the new column.} \item{skip}{A logical. Should the step be skipped when the recipe is baked by \code{\link[=bake]{bake()}}? While all operations are baked diff --git a/man/step_lag_difference.Rd b/man/step_lag_difference.Rd index 123265ea6..d6bafc4c7 100644 --- a/man/step_lag_difference.Rd +++ b/man/step_lag_difference.Rd @@ -11,7 +11,6 @@ step_lag_difference( trained = FALSE, horizon = 7, prefix = "lag_diff_", - columns = NULL, skip = FALSE, id = rand_id("lag_diff") ) @@ -26,16 +25,10 @@ for this step. See \code{\link[recipes:selections]{recipes::selections()}} for m \item{role}{For model terms created by this step, what analysis role should they be assigned? \code{lag} is default a predictor while \code{ahead} is an outcome.} -\item{trained}{A logical to indicate if the quantities for -preprocessing have been estimated.} - \item{horizon}{Scalar or vector. Time period(s) over which to calculate differences.} -\item{prefix}{A prefix to indicate what type of variable this is} - -\item{columns}{A character string of variable names that will -be populated (eventually) by the \code{terms} argument.} +\item{prefix}{A character string that will be prefixed to the new column.} \item{skip}{A logical. Should the step be skipped when the recipe is baked by \code{\link[=bake]{bake()}}? While all operations are baked diff --git a/man/step_population_scaling.Rd b/man/step_population_scaling.Rd index 2af3c245b..294f27f61 100644 --- a/man/step_population_scaling.Rd +++ b/man/step_population_scaling.Rd @@ -8,33 +8,25 @@ step_population_scaling( recipe, ..., role = "raw", - trained = FALSE, df, by = NULL, df_pop_col, rate_rescaling = 1, create_new = TRUE, suffix = "_scaled", - columns = NULL, skip = FALSE, id = rand_id("population_scaling") ) } \arguments{ -\item{recipe}{A recipe object. The step will be added to the sequence of -operations for this recipe. The recipe should contain information about the -\code{epi_df} such as column names.} +\item{recipe}{A recipe object. The step will be added to the +sequence of operations for this recipe.} -\item{...}{One or more selector functions to scale variables +\item{...}{One or more selector functions to choose variables for this step. See \code{\link[recipes:selections]{recipes::selections()}} for more details.} \item{role}{For model terms created by this step, what analysis role should -they be assigned? By default, the new columns created by this step from the -original variables will be used as predictors in a model. Other options can -be ard are not limited to "outcome".} - -\item{trained}{A logical to indicate if the quantities for preprocessing -have been estimated.} +they be assigned? \code{lag} is default a predictor while \code{ahead} is an outcome.} \item{df}{a data frame that contains the population data to be used for inverting the existing scaling.} @@ -68,10 +60,7 @@ scale is "per 100K", then set \code{rate_rescaling = 1e5} to get rates.} in the \code{epi_df}} \item{suffix}{a character. The suffix added to the column name if -\code{crete_new = TRUE}. Default to "_scaled".} - -\item{columns}{A character string of variable names that will -be populated (eventually) by the \code{terms} argument.} +\code{create_new = TRUE}. Default to "_scaled".} \item{skip}{A logical. Should the step be skipped when the recipe is baked by \code{\link[=bake]{bake()}}? While all operations are baked @@ -98,11 +87,10 @@ passed will \emph{divide} the selected variables while the \code{rate_rescaling} argument is a common \emph{multiplier} of the selected variables. } \examples{ -library(epiprocess) -library(epipredict) -jhu <- epiprocess::jhu_csse_daily_subset \%>\% - dplyr::filter(time_value > "2021-11-01", geo_value \%in\% c("ca", "ny")) \%>\% - dplyr::select(geo_value, time_value, cases) +library(dplyr) +jhu <- jhu_csse_daily_subset \%>\% + filter(time_value > "2021-11-01", geo_value \%in\% c("ca", "ny")) \%>\% + select(geo_value, time_value, cases) pop_data <- data.frame(states = c("ca", "ny"), value = c(20000, 30000)) @@ -127,7 +115,7 @@ f <- frosting() \%>\% df_pop_col = "value" ) -wf <- epi_workflow(r, parsnip::linear_reg()) \%>\% +wf <- epi_workflow(r, linear_reg()) \%>\% fit(jhu) \%>\% add_frosting(f) diff --git a/man/step_training_window.Rd b/man/step_training_window.Rd index ce7c0fc74..42f6b9a95 100644 --- a/man/step_training_window.Rd +++ b/man/step_training_window.Rd @@ -7,7 +7,6 @@ step_training_window( recipe, role = NA, - trained = FALSE, n_recent = 50, epi_keys = NULL, id = rand_id("training_window") @@ -17,10 +16,8 @@ step_training_window( \item{recipe}{A recipe object. The step will be added to the sequence of operations for this recipe.} -\item{role}{Not used by this step since no new variables are created.} - -\item{trained}{A logical to indicate if the quantities for -preprocessing have been estimated.} +\item{role}{For model terms created by this step, what analysis role should +they be assigned? \code{lag} is default a predictor while \code{ahead} is an outcome.} \item{n_recent}{An integer value that represents the number of most recent observations that are to be kept in the training window per key @@ -30,7 +27,7 @@ The default value is 50.} to group on. The default, \code{NULL}, ensures that every key combination is limited.} -\item{id}{A character string that is unique to this step to identify it.} +\item{id}{A unique identifier for the step} } \value{ An updated version of \code{recipe} with the new step added to the @@ -47,13 +44,10 @@ Note that \code{step_epi_lead()} and \code{step_epi_lag()} should come after any filtering step. } \examples{ -tib <- tibble::tibble( +tib <- tibble( x = 1:10, y = 1:10, - time_value = rep(seq(as.Date("2020-01-01"), - by = 1, - length.out = 5 - ), times = 2), + time_value = rep(seq(as.Date("2020-01-01"), by = 1, length.out = 5), 2), geo_value = rep(c("ca", "hi"), each = 5) ) \%>\% as_epi_df() @@ -64,7 +58,7 @@ epi_recipe(y ~ x, data = tib) \%>\% bake(new_data = NULL) epi_recipe(y ~ x, data = tib) \%>\% - recipes::step_naomit() \%>\% + step_epi_naomit() \%>\% step_training_window(n_recent = 3) \%>\% prep(tib) \%>\% bake(new_data = NULL) diff --git a/man/tidy.frosting.Rd b/man/tidy.frosting.Rd index 6b28461b4..ba3c0f3d5 100644 --- a/man/tidy.frosting.Rd +++ b/man/tidy.frosting.Rd @@ -37,8 +37,9 @@ method for the operation exists). Note that this is a modified version of the \code{tidy} method for a recipe. } \examples{ +library(dplyr) jhu <- case_death_rate_subset \%>\% - dplyr::filter(time_value > "2021-11-01", geo_value \%in\% c("ak", "ca", "ny")) + filter(time_value > "2021-11-01", geo_value \%in\% c("ak", "ca", "ny")) r <- epi_recipe(jhu) \%>\% step_epi_lag(death_rate, lag = c(0, 7, 14)) \%>\% diff --git a/man/update.layer.Rd b/man/update.layer.Rd index 0f1fe9c22..9604992e1 100644 --- a/man/update.layer.Rd +++ b/man/update.layer.Rd @@ -18,15 +18,15 @@ will replace the elements of the same name in the actual post-processing layer. Analogous to \code{update.step()} from the \code{recipes} package. } \examples{ +library(dplyr) jhu <- case_death_rate_subset \%>\% - dplyr::filter(time_value > "2021-11-01", geo_value \%in\% c("ak", "ca", "ny")) + filter(time_value > "2021-11-01", geo_value \%in\% c("ak", "ca", "ny")) r <- epi_recipe(jhu) \%>\% step_epi_lag(death_rate, lag = c(0, 7, 14)) \%>\% step_epi_ahead(death_rate, ahead = 7) \%>\% step_epi_naomit() -wf <- epi_workflow(r, parsnip::linear_reg()) \%>\% fit(jhu) -latest <- jhu \%>\% - dplyr::filter(time_value >= max(time_value) - 14) +wf <- epi_workflow(r, linear_reg()) \%>\% fit(jhu) +latest <- jhu \%>\% filter(time_value >= max(time_value) - 14) # Specify a `forecast_date` that is greater than or equal to `as_of` date f <- frosting() \%>\% From d6dff90684a386a807088d72a9e986532e548839 Mon Sep 17 00:00:00 2001 From: "Daniel J. McDonald" Date: Tue, 27 Aug 2024 17:20:31 -0700 Subject: [PATCH 07/16] all checks pass --- R/arx_classifier.R | 27 ++++++++++++----------- R/layer_add_target_date.R | 2 +- R/layer_population_scaling.R | 5 +++-- R/step_growth_rate.R | 2 +- R/step_lag_difference.R | 1 - R/step_population_scaling.R | 21 ++++++++---------- R/tidy.R | 2 +- man/arx_class_args_list.Rd | 2 +- man/arx_classifier.Rd | 3 ++- man/layer_add_target_date.Rd | 2 +- man/step_lag_difference.Rd | 1 - tests/testthat/test-population_scaling.R | 9 ++++---- tests/testthat/test-step_growth_rate.R | 2 -- tests/testthat/test-step_lag_difference.R | 2 -- 14 files changed, 37 insertions(+), 44 deletions(-) diff --git a/R/arx_classifier.R b/R/arx_classifier.R index d5f5bf05d..ca6a3537b 100644 --- a/R/arx_classifier.R +++ b/R/arx_classifier.R @@ -26,8 +26,9 @@ #' @seealso [arx_class_epi_workflow()], [arx_class_args_list()] #' #' @examples +#' library(dplyr) #' jhu <- case_death_rate_subset %>% -#' dplyr::filter(time_value >= as.Date("2021-11-01")) +#' filter(time_value >= as.Date("2021-11-01")) #' #' out <- arx_classifier(jhu, "death_rate", c("case_rate", "death_rate")) #' @@ -130,7 +131,7 @@ arx_class_epi_workflow <- function( # ------- predictors r <- epi_recipe(epi_data) %>% step_growth_rate( - all_of(predictors), + dplyr::all_of(predictors), role = "grp", horizon = args_list$horizon, method = args_list$method, @@ -183,14 +184,15 @@ arx_class_epi_workflow <- function( if (!is.null(args_list$check_enough_data_n)) { r <- check_enough_train_data( r, - all_predictors(), - !!outcome, + recipes::all_predictors(), + recipes::all_outcomes(), n = args_list$check_enough_data_n, epi_keys = args_list$check_enough_data_epi_keys, drop_na = FALSE ) } + forecast_date <- args_list$forecast_date %||% max(epi_data$time_value) target_date <- args_list$target_date %||% (forecast_date + args_list$ahead) @@ -261,7 +263,7 @@ arx_class_args_list <- function( outcome_transform = c("growth_rate", "lag_difference"), breaks = 0.25, horizon = 7L, - method = c("rel_change", "linear_reg", "smooth_spline", "trend_filter"), + method = c("rel_change", "linear_reg"), log_scale = FALSE, additional_gr_args = list(), nafill_buffer = Inf, @@ -271,8 +273,8 @@ arx_class_args_list <- function( rlang::check_dots_empty() .lags <- lags if (is.list(lags)) lags <- unlist(lags) - method <- match.arg(method) - outcome_transform <- match.arg(outcome_transform) + method <- rlang::arg_match(method) + outcome_transform <- rlang::arg_match(outcome_transform) arg_is_scalar(ahead, n_training, horizon, log_scale) arg_is_scalar(forecast_date, target_date, allow_null = TRUE) @@ -284,12 +286,11 @@ arx_class_args_list <- function( if (is.finite(n_training)) arg_is_pos_int(n_training) if (is.finite(nafill_buffer)) arg_is_pos_int(nafill_buffer, allow_null = TRUE) if (!is.list(additional_gr_args)) { - cli::cli_abort( - c("`additional_gr_args` must be a {.cls list}.", - "!" = "This is a {.cls {class(additional_gr_args)}}.", - i = "See `?epiprocess::growth_rate` for available arguments." - ) - ) + cli_abort(c( + "`additional_gr_args` must be a {.cls list}.", + "!" = "This is a {.cls {class(additional_gr_args)}}.", + i = "See `?epiprocess::growth_rate` for available arguments." + )) } arg_is_pos(check_enough_data_n, allow_null = TRUE) arg_is_chr(check_enough_data_epi_keys, allow_null = TRUE) diff --git a/R/layer_add_target_date.R b/R/layer_add_target_date.R index 834deb82b..9176fb593 100644 --- a/R/layer_add_target_date.R +++ b/R/layer_add_target_date.R @@ -22,7 +22,7 @@ #' @examples #' library(dplyr) #' jhu <- case_death_rate_subset %>% -#' dfilter(time_value > "2021-11-01", geo_value %in% c("ak", "ca", "ny")) +#' filter(time_value > "2021-11-01", geo_value %in% c("ak", "ca", "ny")) #' r <- epi_recipe(jhu) %>% #' step_epi_lag(death_rate, lag = c(0, 7, 14)) %>% #' step_epi_ahead(death_rate, ahead = 7) %>% diff --git a/R/layer_population_scaling.R b/R/layer_population_scaling.R index 1b940b804..a2c56233f 100644 --- a/R/layer_population_scaling.R +++ b/R/layer_population_scaling.R @@ -139,8 +139,9 @@ slather.layer_population_scaling <- kill_time_value(key_colnames(components$predictions)), colnames(select(object$df, !object$df_pop_col)) ) - hardhat::validate_column_names(components$predictions, object$by) - hardhat::validate_column_names(object$df, object$by) + joinby <- list(x = names(object$by) %||% object$by, y = object$by) + hardhat::validate_column_names(components$predictions, joinby$x) + hardhat::validate_column_names(object$df, joinby$y) # object$df <- object$df %>% # dplyr::mutate(dplyr::across(tidyselect::where(is.character), tolower)) diff --git a/R/step_growth_rate.R b/R/step_growth_rate.R index e1950d208..0a8c4118e 100644 --- a/R/step_growth_rate.R +++ b/R/step_growth_rate.R @@ -190,7 +190,7 @@ bake.step_growth_rate <- function(object, new_data, ...) { if (!is.null(object$replace_Inf)) { gr <- gr %>% - mutate(across(all_of(ok), ~ vec_replace_inf(.x, object$replace_Inf))) + mutate(across(!all_of(ok), ~ vec_replace_inf(.x, object$replace_Inf))) } left_join(new_data, gr, by = ok) %>% diff --git a/R/step_lag_difference.R b/R/step_lag_difference.R index 009ebe4f5..f3a470c7f 100644 --- a/R/step_lag_difference.R +++ b/R/step_lag_difference.R @@ -27,7 +27,6 @@ step_lag_difference <- function(recipe, ..., role = "predictor", - trained = FALSE, horizon = 7, prefix = "lag_diff_", skip = FALSE, diff --git a/R/step_population_scaling.R b/R/step_population_scaling.R index e5baa837b..f5373647e 100644 --- a/R/step_population_scaling.R +++ b/R/step_population_scaling.R @@ -155,18 +155,15 @@ prep.step_population_scaling <- function(x, training, info = NULL, ...) { } #' @export -bake.step_population_scaling <- function(object, - new_data, - ...) { - stopifnot( - "Only one population column allowed for scaling" = - length(object$df_pop_col) == 1 - ) - - - hardhat::validate_column_names(new_data, object$by) - hardhat::validate_column_names(object$df, object$by) +bake.step_population_scaling <- function(object, new_data, ...) { + object$by <- object$by %||% intersect( + kill_time_value(key_colnames(new_data)), + colnames(select(object$df, !object$df_pop_col)) + ) + joinby <- list(x = names(object$by) %||% object$by, y = object$by) + hardhat::validate_column_names(new_data, joinby$x) + hardhat::validate_column_names(object$df, joinby$y) if (object$suffix != "_scaled" && object$create_new == FALSE) { cli::cli_warn(c( @@ -190,7 +187,7 @@ bake.step_population_scaling <- function(object, ) ) %>% # removed so the models do not use the population column - select(any_of(col_to_remove)) + select(!any_of(col_to_remove)) } #' @export diff --git a/R/tidy.R b/R/tidy.R index a239a8121..61b298411 100644 --- a/R/tidy.R +++ b/R/tidy.R @@ -55,7 +55,7 @@ tidy.frosting <- function(x, number = NA, id = NA, ...) { arg_is_chr_scalar(id, allow_na = TRUE) arg_is_scalar(number, allow_na = TRUE) - arg_is_int(number, allow_na = TRUE) + if (!is.na(number)) arg_is_int(number) if (!is.na(id)) { if (!is.na(number)) { diff --git a/man/arx_class_args_list.Rd b/man/arx_class_args_list.Rd index a1205c71a..311950d62 100644 --- a/man/arx_class_args_list.Rd +++ b/man/arx_class_args_list.Rd @@ -13,7 +13,7 @@ arx_class_args_list( outcome_transform = c("growth_rate", "lag_difference"), breaks = 0.25, horizon = 7L, - method = c("rel_change", "linear_reg", "smooth_spline", "trend_filter"), + method = c("rel_change", "linear_reg"), log_scale = FALSE, additional_gr_args = list(), nafill_buffer = Inf, diff --git a/man/arx_classifier.Rd b/man/arx_classifier.Rd index 36297b00c..c7c2cf059 100644 --- a/man/arx_classifier.Rd +++ b/man/arx_classifier.Rd @@ -48,8 +48,9 @@ This is an autoregressive classification model for that it estimates a class at a particular target horizon. } \examples{ +library(dplyr) jhu <- case_death_rate_subset \%>\% - dplyr::filter(time_value >= as.Date("2021-11-01")) + filter(time_value >= as.Date("2021-11-01")) out <- arx_classifier(jhu, "death_rate", c("case_rate", "death_rate")) diff --git a/man/layer_add_target_date.Rd b/man/layer_add_target_date.Rd index 9dc6abbdd..dc0d2f190 100644 --- a/man/layer_add_target_date.Rd +++ b/man/layer_add_target_date.Rd @@ -39,7 +39,7 @@ in the test data to get the target date. \examples{ library(dplyr) jhu <- case_death_rate_subset \%>\% - dfilter(time_value > "2021-11-01", geo_value \%in\% c("ak", "ca", "ny")) + filter(time_value > "2021-11-01", geo_value \%in\% c("ak", "ca", "ny")) r <- epi_recipe(jhu) \%>\% step_epi_lag(death_rate, lag = c(0, 7, 14)) \%>\% step_epi_ahead(death_rate, ahead = 7) \%>\% diff --git a/man/step_lag_difference.Rd b/man/step_lag_difference.Rd index d6bafc4c7..7969ea3a7 100644 --- a/man/step_lag_difference.Rd +++ b/man/step_lag_difference.Rd @@ -8,7 +8,6 @@ step_lag_difference( recipe, ..., role = "predictor", - trained = FALSE, horizon = 7, prefix = "lag_diff_", skip = FALSE, diff --git a/tests/testthat/test-population_scaling.R b/tests/testthat/test-population_scaling.R index b66bb08c3..d04a6bb43 100644 --- a/tests/testthat/test-population_scaling.R +++ b/tests/testthat/test-population_scaling.R @@ -186,7 +186,6 @@ test_that("Postprocessing to get cases from case rate", { test_that("test joining by default columns", { - skip() jhu <- case_death_rate_subset %>% dplyr::filter(time_value > "2021-11-01", geo_value %in% c("ca", "ny")) %>% dplyr::select(geo_value, time_value, case_rate) @@ -208,9 +207,9 @@ test_that("test joining by default columns", { recipes::step_naomit(recipes::all_predictors()) %>% recipes::step_naomit(recipes::all_outcomes(), skip = TRUE) - suppressMessages(prep <- prep(r, jhu)) + expect_silent(prep(r, jhu)) - suppressMessages(b <- bake(prep, jhu)) + expect_silent(bake(prep(r, jhu), new_data = NULL)) f <- frosting() %>% layer_predict() %>% @@ -222,13 +221,13 @@ test_that("test joining by default columns", { df_pop_col = "values" ) - suppressMessages( + expect_silent( wf <- epi_workflow(r, parsnip::linear_reg()) %>% fit(jhu) %>% add_frosting(f) ) - suppressMessages(p <- forecast(wf)) + expect_silent(forecast(wf)) }) diff --git a/tests/testthat/test-step_growth_rate.R b/tests/testthat/test-step_growth_rate.R index 052141710..29a2fc2f5 100644 --- a/tests/testthat/test-step_growth_rate.R +++ b/tests/testthat/test-step_growth_rate.R @@ -15,11 +15,9 @@ test_that("step_growth_rate validates arguments", { expect_error(step_growth_rate(r, value, prefix = letters[1:2])) expect_error(step_growth_rate(r, value, prefix = 1)) expect_error(step_growth_rate(r, value, id = 1)) - expect_error(step_growth_rate(r, value, trained = 1)) expect_error(step_growth_rate(r, value, log_scale = 1)) expect_error(step_growth_rate(r, value, skip = 1)) expect_error(step_growth_rate(r, value, additional_gr_args_list = 1:5)) - expect_error(step_growth_rate(r, value, columns = letters[1:5])) expect_error(step_growth_rate(r, value, replace_Inf = "c")) expect_error(step_growth_rate(r, value, replace_Inf = c(1, 2))) expect_silent(step_growth_rate(r, value, replace_Inf = NULL)) diff --git a/tests/testthat/test-step_lag_difference.R b/tests/testthat/test-step_lag_difference.R index c0fd377e6..cd92da1fb 100644 --- a/tests/testthat/test-step_lag_difference.R +++ b/tests/testthat/test-step_lag_difference.R @@ -14,9 +14,7 @@ test_that("step_lag_difference validates arguments", { expect_error(step_lag_difference(r, value, prefix = letters[1:2])) expect_error(step_lag_difference(r, value, prefix = 1)) expect_error(step_lag_difference(r, value, id = 1)) - expect_error(step_lag_difference(r, value, trained = 1)) expect_error(step_lag_difference(r, value, skip = 1)) - expect_error(step_lag_difference(r, value, columns = letters[1:5])) }) From 2dbc21288887b40db7947d88e784a7ce5b89c4fd Mon Sep 17 00:00:00 2001 From: "Daniel J. McDonald" Date: Tue, 27 Aug 2024 17:29:26 -0700 Subject: [PATCH 08/16] styler --- R/layer_residual_quantiles.R | 2 +- R/step_population_scaling.R | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/R/layer_residual_quantiles.R b/R/layer_residual_quantiles.R index dc53f0f60..eae151905 100644 --- a/R/layer_residual_quantiles.R +++ b/R/layer_residual_quantiles.R @@ -30,7 +30,7 @@ #' layer_residual_quantiles( #' quantile_levels = c(0.0275, 0.975), #' symmetrize = FALSE -#' ) %>% +#' ) %>% #' layer_naomit(.pred) #' wf1 <- wf %>% add_frosting(f) #' diff --git a/R/step_population_scaling.R b/R/step_population_scaling.R index f5373647e..5555bc65c 100644 --- a/R/step_population_scaling.R +++ b/R/step_population_scaling.R @@ -156,7 +156,6 @@ prep.step_population_scaling <- function(x, training, info = NULL, ...) { #' @export bake.step_population_scaling <- function(object, new_data, ...) { - object$by <- object$by %||% intersect( kill_time_value(key_colnames(new_data)), colnames(select(object$df, !object$df_pop_col)) From f394fde1ea3d9466f6680a78cc32e3ccc4fa320c Mon Sep 17 00:00:00 2001 From: "Daniel J. McDonald" Date: Tue, 27 Aug 2024 17:34:53 -0700 Subject: [PATCH 09/16] fix vignette naming error --- man/layer_residual_quantiles.Rd | 2 +- vignettes/preprocessing-and-models.Rmd | 2 +- vignettes/update.Rmd | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/man/layer_residual_quantiles.Rd b/man/layer_residual_quantiles.Rd index d300241d3..39e1ecfbe 100644 --- a/man/layer_residual_quantiles.Rd +++ b/man/layer_residual_quantiles.Rd @@ -55,7 +55,7 @@ f <- frosting() \%>\% layer_residual_quantiles( quantile_levels = c(0.0275, 0.975), symmetrize = FALSE - ) \%>\% + ) \%>\% layer_naomit(.pred) wf1 <- wf \%>\% add_frosting(f) diff --git a/vignettes/preprocessing-and-models.Rmd b/vignettes/preprocessing-and-models.Rmd index d557ed1f7..63a27bd55 100644 --- a/vignettes/preprocessing-and-models.Rmd +++ b/vignettes/preprocessing-and-models.Rmd @@ -478,7 +478,7 @@ r <- epi_recipe(jhu) %>% We will fit the multinomial regression and examine the predictions: ```{r, warning=FALSE} -wf <- epi_workflow(r, parsnip::multinom_reg()) %>% +wf <- epi_workflow(r, multinom_reg()) %>% fit(jhu) forecast(wf) %>% filter(!is.na(.pred_class)) diff --git a/vignettes/update.Rmd b/vignettes/update.Rmd index cb19ce192..fcd3653ca 100644 --- a/vignettes/update.Rmd +++ b/vignettes/update.Rmd @@ -2,7 +2,7 @@ title: "Using the add/update/remove and adjust functions" output: rmarkdown::html_vignette vignette: > - %\VignetteIndexEntry{Using the update and adjust functions} + %\VignetteIndexEntry{Using the add/update/remove and adjust functions} %\VignetteEngine{knitr::rmarkdown} %\VignetteEncoding{UTF-8} --- From 9a7d4336d03645750d585f6def22395f1c5ab738 Mon Sep 17 00:00:00 2001 From: "Daniel J. McDonald" Date: Tue, 27 Aug 2024 17:54:53 -0700 Subject: [PATCH 10/16] safely remove time value from epi_keys --- R/step_epi_shift.R | 4 ++-- R/step_epi_slide.R | 4 ++-- R/step_growth_rate.R | 6 ++++-- R/step_lag_difference.R | 2 +- vignettes/arx-classifier.Rmd | 2 +- 5 files changed, 10 insertions(+), 8 deletions(-) diff --git a/R/step_epi_shift.R b/R/step_epi_shift.R index 616d98f03..465d64e7f 100644 --- a/R/step_epi_shift.R +++ b/R/step_epi_shift.R @@ -228,7 +228,7 @@ bake.step_epi_lag <- function(object, new_data, ...) { ) full_join(new_data, shifted, by = ok) %>% - group_by(across(all_of(ok[-1]))) %>% + group_by(across(all_of(kill_time_value(ok)))) %>% arrange(time_value) %>% ungroup() } @@ -259,7 +259,7 @@ bake.step_epi_ahead <- function(object, new_data, ...) { ) full_join(new_data, shifted, by = ok) %>% - group_by(across(all_of(ok[-1]))) %>% + group_by(across(all_of(kill_time_value(ok)))) %>% arrange(time_value) %>% ungroup() } diff --git a/R/step_epi_slide.R b/R/step_epi_slide.R index 180be3d51..9714971fa 100644 --- a/R/step_epi_slide.R +++ b/R/step_epi_slide.R @@ -170,7 +170,7 @@ bake.step_epi_slide <- function(object, new_data, ...) { object$columns, c(object$.f), object$f_name, - object$keys[-1], + kill_time_value(object$keys), object$prefix ) } @@ -184,7 +184,7 @@ bake.step_epi_slide <- function(object, new_data, ...) { #' using roughly equivalent tidy select style. #' #' @param fns vector of functions, even if it's length 1. -#' @param group_keys the keys to group by. likely `epi_keys[-1]` (to remove time_value) +#' @param group_keys the keys to group by. likely `epi_keys` (without `time_value`) #' #' @importFrom tidyr crossing #' @importFrom dplyr bind_cols group_by ungroup diff --git a/R/step_growth_rate.R b/R/step_growth_rate.R index 0a8c4118e..c7798cee3 100644 --- a/R/step_growth_rate.R +++ b/R/step_growth_rate.R @@ -171,7 +171,7 @@ bake.step_growth_rate <- function(object, new_data, ...) { ok <- object$keys gr <- new_data %>% - group_by(across(all_of(ok[-1]))) %>% + group_by(across(all_of(kill_time_value(ok)))) %>% dplyr::transmute( time_value = time_value, across( @@ -188,13 +188,15 @@ bake.step_growth_rate <- function(object, new_data, ...) { ungroup() %>% mutate(time_value = time_value + object$horizon) # shift x0 right + if (!is.null(object$replace_Inf)) { + browser() gr <- gr %>% mutate(across(!all_of(ok), ~ vec_replace_inf(.x, object$replace_Inf))) } left_join(new_data, gr, by = ok) %>% - group_by(across(all_of(ok[-1]))) %>% + group_by(across(all_of(kill_time_value(ok)))) %>% arrange(time_value) %>% ungroup() } diff --git a/R/step_lag_difference.R b/R/step_lag_difference.R index f3a470c7f..39ae1ba59 100644 --- a/R/step_lag_difference.R +++ b/R/step_lag_difference.R @@ -137,7 +137,7 @@ bake.step_lag_difference <- function(object, new_data, ...) { ) left_join(new_data, shifted, by = ok) %>% - group_by(across(all_of(ok[-1]))) %>% + group_by(across(all_of(kill_time_value(ok)))) %>% arrange(time_value) %>% ungroup() } diff --git a/vignettes/arx-classifier.Rmd b/vignettes/arx-classifier.Rmd index 089f000df..ae1641cce 100644 --- a/vignettes/arx-classifier.Rmd +++ b/vignettes/arx-classifier.Rmd @@ -44,7 +44,7 @@ the 0, 7 and 14 day case rates: ```{r} jhu <- case_death_rate_subset %>% - dplyr::filter( + filter( time_value >= "2021-06-04", time_value <= "2021-12-31", geo_value %in% c("ca", "fl", "tx", "ny", "nj") From aba9fa481445b01fb79ea51fddcf964f2650beaa Mon Sep 17 00:00:00 2001 From: "Daniel J. McDonald" Date: Tue, 27 Aug 2024 17:55:38 -0700 Subject: [PATCH 11/16] redocument, remove browser() --- R/step_growth_rate.R | 1 - man/epi_slide_wrapper.Rd | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/R/step_growth_rate.R b/R/step_growth_rate.R index c7798cee3..06f8da4cf 100644 --- a/R/step_growth_rate.R +++ b/R/step_growth_rate.R @@ -190,7 +190,6 @@ bake.step_growth_rate <- function(object, new_data, ...) { if (!is.null(object$replace_Inf)) { - browser() gr <- gr %>% mutate(across(!all_of(ok), ~ vec_replace_inf(.x, object$replace_Inf))) } diff --git a/man/epi_slide_wrapper.Rd b/man/epi_slide_wrapper.Rd index 583d2eacf..0c05b7650 100644 --- a/man/epi_slide_wrapper.Rd +++ b/man/epi_slide_wrapper.Rd @@ -18,7 +18,7 @@ epi_slide_wrapper( \arguments{ \item{fns}{vector of functions, even if it's length 1.} -\item{group_keys}{the keys to group by. likely \code{epi_keys[-1]} (to remove time_value)} +\item{group_keys}{the keys to group by. likely \code{epi_keys} (without \code{time_value})} } \description{ This should simplify somewhat in the future when we can run \code{epi_slide} on From 0976e2778ffb95b599fc2ab7397c0a2ba4c38098 Mon Sep 17 00:00:00 2001 From: "Daniel J. McDonald" Date: Tue, 27 Aug 2024 18:07:14 -0700 Subject: [PATCH 12/16] geo then time, no more dropping the first element to remove time_value --- R/epi_recipe.R | 2 +- tests/testthat/test-epi_shift.R | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/R/epi_recipe.R b/R/epi_recipe.R index c40fda019..88ba605cd 100644 --- a/R/epi_recipe.R +++ b/R/epi_recipe.R @@ -95,7 +95,7 @@ epi_recipe.epi_df <- keys <- key_colnames(x) # we know x is an epi_df var_info <- tibble(variable = vars) - key_roles <- c("time_value", "geo_value", rep("key", length(keys) - 2)) + key_roles <- c("geo_value", "time_value", rep("key", length(keys) - 2)) ## Check and add roles when available if (!is.null(roles)) { diff --git a/tests/testthat/test-epi_shift.R b/tests/testthat/test-epi_shift.R index 245a39c0d..78c9384f1 100644 --- a/tests/testthat/test-epi_shift.R +++ b/tests/testthat/test-epi_shift.R @@ -25,6 +25,6 @@ test_that("epi shift single works, renames", { geo_value = "ca" ) %>% epiprocess::as_epi_df() ess <- epi_shift_single(tib, "x", 1, "test", key_colnames(tib)) - expect_named(ess, c("time_value", "geo_value", "test")) + expect_named(ess, c("geo_value", "time_value", "test")) expect_equal(ess$time_value, tib$time_value + 1) }) From a99a5530ee5cf2ad8949c7a61218e65d07ce5321 Mon Sep 17 00:00:00 2001 From: "Daniel J. McDonald" Date: Thu, 29 Aug 2024 09:53:15 -0700 Subject: [PATCH 13/16] remove key_colnames.list() --- R/autoplot.R | 6 +++--- R/epi_workflow.R | 5 +---- R/get_test_data.R | 2 +- R/key_colnames.R | 17 ++++++----------- R/layer_cdc_flatline_quantiles.R | 2 +- R/layer_population_scaling.R | 2 +- R/step_population_scaling.R | 2 +- R/step_training_window.R | 2 +- R/utils-misc.R | 6 +++--- 9 files changed, 18 insertions(+), 26 deletions(-) diff --git a/R/autoplot.R b/R/autoplot.R index 1be30dd29..d35850fd6 100644 --- a/R/autoplot.R +++ b/R/autoplot.R @@ -109,7 +109,7 @@ autoplot.epi_workflow <- function( y <- y[, 1] cli_warn("Multiple outcome variables were detected. Displaying only 1.") } - keys <- c("time_value", "geo_value", "key") + keys <- c("geo_value", "time_value", "key") mold_roles <- names(mold$extras$roles) edf <- bind_cols(mold$extras$roles[mold_roles %in% keys], y) if (starts_with_impl("ahead_", names(y))) { @@ -127,7 +127,7 @@ autoplot.epi_workflow <- function( if (!is.null(shift)) { edf <- mutate(edf, time_value = time_value + shift) } - extra_keys <- setdiff(key_colnames(mold), c("time_value", "geo_value")) + extra_keys <- setdiff(key_colnames(object), c("geo_value", "time_value")) if (length(extra_keys) == 0L) extra_keys <- NULL edf <- as_epi_df(edf, as_of = object$fit$meta$as_of, @@ -167,7 +167,7 @@ autoplot.epi_workflow <- function( ) # Now, prepare matching facets in the predictions - ek <- kill_time_value(key_colnames(edf)) + ek <- epi_keys_only(edf) predictions <- predictions %>% mutate( .facets = interaction(!!!rlang::syms(as.list(ek)), sep = "/"), diff --git a/R/epi_workflow.R b/R/epi_workflow.R index f715dc9b0..b059a81d0 100644 --- a/R/epi_workflow.R +++ b/R/epi_workflow.R @@ -164,10 +164,7 @@ predict.epi_workflow <- function(object, new_data, type = NULL, opts = list(), . components$forged <- hardhat::forge(new_data, blueprint = components$mold$blueprint ) - components$keys <- grab_forged_keys( - components$forged, - components$mold, new_data - ) + components$keys <- grab_forged_keys(components$forged, object, new_data) components <- apply_frosting(object, components, new_data, type = type, opts = opts, ...) components$predictions } diff --git a/R/get_test_data.R b/R/get_test_data.R index ff3a146ef..694e73b06 100644 --- a/R/get_test_data.R +++ b/R/get_test_data.R @@ -91,7 +91,7 @@ get_test_data <- function( } x <- arrange(x, time_value) - groups <- kill_time_value(key_colnames(recipe)) + groups <- epi_keys_only(recipe) # If we skip NA completion, we remove undesirably early time values # Happens globally, over all groups diff --git a/R/key_colnames.R b/R/key_colnames.R index e16ce1f42..71ab0fd05 100644 --- a/R/key_colnames.R +++ b/R/key_colnames.R @@ -1,20 +1,15 @@ #' @export key_colnames.recipe <- function(x, ...) { - x$var_info$variable[x$var_info$role %in% c("time_value", "geo_value", "key")] + x$var_info$variable[x$var_info$role %in% c("geo_value", "time_value", "key")] } #' @export key_colnames.epi_workflow <- function(x, ...) { - NextMethod(hardhat::extract_mold(x)) -} - -# a mold is a list extracted from a fitted workflow, gives info about -# training data. But it doesn't have a class -#' @export -key_colnames.list <- function(x, ...) { - keys <- c("time_value", "geo_value", "key") - molded_names <- names(x$extras$roles) - mold_keys <- map(x$extras$roles[molded_names %in% keys], names) + # safer to look at the mold than the preprocessor + mold <- hardhat::extract_mold(x) + keys <- c("geo_value", "time_value", "key") + molded_names <- names(mold$extras$roles) + mold_keys <- map(mold$extras$roles[molded_names %in% keys], names) unname(unlist(mold_keys)) %||% character(0L) } diff --git a/R/layer_cdc_flatline_quantiles.R b/R/layer_cdc_flatline_quantiles.R index 9166d3469..8d16ba32f 100644 --- a/R/layer_cdc_flatline_quantiles.R +++ b/R/layer_cdc_flatline_quantiles.R @@ -169,7 +169,7 @@ slather.layer_cdc_flatline_quantiles <- )) } p <- components$predictions - ek <- kill_time_value(key_colnames(components$mold)) + ek <- epi_keys_only(workflow) r <- grab_residuals(the_fit, components) avail_grps <- character(0L) diff --git a/R/layer_population_scaling.R b/R/layer_population_scaling.R index a2c56233f..9275d910c 100644 --- a/R/layer_population_scaling.R +++ b/R/layer_population_scaling.R @@ -136,7 +136,7 @@ slather.layer_population_scaling <- rlang::check_dots_empty() object$by <- object$by %||% intersect( - kill_time_value(key_colnames(components$predictions)), + epi_keys_only(components$predictions), colnames(select(object$df, !object$df_pop_col)) ) joinby <- list(x = names(object$by) %||% object$by, y = object$by) diff --git a/R/step_population_scaling.R b/R/step_population_scaling.R index 5555bc65c..4e4d3aa26 100644 --- a/R/step_population_scaling.R +++ b/R/step_population_scaling.R @@ -157,7 +157,7 @@ prep.step_population_scaling <- function(x, training, info = NULL, ...) { #' @export bake.step_population_scaling <- function(object, new_data, ...) { object$by <- object$by %||% intersect( - kill_time_value(key_colnames(new_data)), + epi_keys_only(new_data), colnames(select(object$df, !object$df_pop_col)) ) joinby <- list(x = names(object$by) %||% object$by, y = object$by) diff --git a/R/step_training_window.R b/R/step_training_window.R index e66b08365..eafc076c7 100644 --- a/R/step_training_window.R +++ b/R/step_training_window.R @@ -77,7 +77,7 @@ step_training_window_new <- #' @export prep.step_training_window <- function(x, training, info = NULL, ...) { - ekt <- kill_time_value(key_colnames(training)) + ekt <- epi_keys_only(training) ek <- x$epi_keys %||% ekt %||% character(0L) hardhat::validate_column_names(training, ek) diff --git a/R/utils-misc.R b/R/utils-misc.R index c59afc19a..af064b37c 100644 --- a/R/utils-misc.R +++ b/R/utils-misc.R @@ -32,14 +32,14 @@ check_pname <- function(res, preds, object, newname = NULL) { } -grab_forged_keys <- function(forged, mold, new_data) { - keys <- c("time_value", "geo_value", "key") +grab_forged_keys <- function(forged, workflow, new_data) { + keys <- c("geo_value", "time_value", "key") forged_roles <- names(forged$extras$roles) extras <- dplyr::bind_cols(forged$extras$roles[forged_roles %in% keys]) # 1. these are the keys in the test data after prep/bake new_keys <- names(extras) # 2. these are the keys in the training data - old_keys <- key_colnames(mold) + old_keys <- key_colnames(workflow) # 3. these are the keys in the test data as input new_df_keys <- key_colnames(new_data, extra_keys = setdiff(new_keys, keys[1:2])) if (!(setequal(old_keys, new_df_keys) && setequal(new_keys, new_df_keys))) { From dcf778cd4313ec1c22d94cbc57c7504a36256ef7 Mon Sep 17 00:00:00 2001 From: "Daniel J. McDonald" Date: Thu, 29 Aug 2024 10:19:49 -0700 Subject: [PATCH 14/16] update tests in re @dshemetov comments --- NAMESPACE | 1 - R/key_colnames.R | 11 +++++++---- tests/testthat/test-key_colnames.R | 25 ++++++++++++++++++------- 3 files changed, 25 insertions(+), 12 deletions(-) diff --git a/NAMESPACE b/NAMESPACE index 8d08ea12e..23c5adeaf 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -51,7 +51,6 @@ S3method(format,dist_quantiles) S3method(is.na,dist_quantiles) S3method(is.na,distribution) S3method(key_colnames,epi_workflow) -S3method(key_colnames,list) S3method(key_colnames,recipe) S3method(mean,dist_quantiles) S3method(median,dist_quantiles) diff --git a/R/key_colnames.R b/R/key_colnames.R index 71ab0fd05..c69d1a628 100644 --- a/R/key_colnames.R +++ b/R/key_colnames.R @@ -1,16 +1,19 @@ #' @export key_colnames.recipe <- function(x, ...) { - x$var_info$variable[x$var_info$role %in% c("geo_value", "time_value", "key")] + possible_keys <- c("geo_value", "time_value", "key") + keys <- x$var_info$variable[x$var_info$role %in% possible_keys] + keys[order(match(keys, possible_keys))] %||% character(0L) } #' @export key_colnames.epi_workflow <- function(x, ...) { # safer to look at the mold than the preprocessor mold <- hardhat::extract_mold(x) - keys <- c("geo_value", "time_value", "key") + possible_keys <- c("geo_value", "time_value", "key") molded_names <- names(mold$extras$roles) - mold_keys <- map(mold$extras$roles[molded_names %in% keys], names) - unname(unlist(mold_keys)) %||% character(0L) + keys <- map(mold$extras$roles[molded_names %in% possible_keys], names) + keys <- unname(unlist(keys)) + keys[order(match(keys, possible_keys))] %||% character(0L) } kill_time_value <- function(v) { diff --git a/tests/testthat/test-key_colnames.R b/tests/testthat/test-key_colnames.R index 3fecd9e44..d55a515ca 100644 --- a/tests/testthat/test-key_colnames.R +++ b/tests/testthat/test-key_colnames.R @@ -2,22 +2,21 @@ test_that("Extracts keys from a recipe; roles are NA, giving an empty vector", { expect_equal(key_colnames(recipe(case_death_rate_subset)), character(0L)) }) -test_that("epi_keys_mold extracts time_value and geo_value, but not raw", { +test_that("key_colnames extracts time_value and geo_value, but not raw", { my_recipe <- epi_recipe(case_death_rate_subset) %>% step_epi_ahead(death_rate, ahead = 7) %>% step_epi_lag(death_rate, lag = c(0, 7, 14)) %>% step_epi_lag(case_rate, lag = c(0, 7, 14)) %>% step_epi_naomit() + expect_identical(key_colnames(my_recipe), c("geo_value", "time_value")) + my_workflow <- epi_workflow() %>% add_epi_recipe(my_recipe) %>% add_model(linear_reg()) %>% fit(data = case_death_rate_subset) - expect_setequal( - key_colnames(my_workflow$pre$mold), - c("time_value", "geo_value") - ) + expect_identical(key_colnames(my_workflow), c("geo_value", "time_value")) }) test_that("key_colnames extracts additional keys when they are present", { @@ -34,14 +33,26 @@ test_that("key_colnames extracts additional keys when they are present", { additional_metadata = list(other_keys = c("state", "pol")) ) + expect_identical( + key_colnames(my_data), + c("geo_value", "time_value", "state", "pol") + ) + my_recipe <- epi_recipe(my_data) %>% step_epi_ahead(value, ahead = 7) %>% step_epi_naomit() + # order of the additional keys may be different + expect_setequal( + key_colnames(my_recipe), + c("geo_value", "time_value", "state", "pol") + ) + my_workflow <- epi_workflow(my_recipe, linear_reg()) %>% fit(my_data) + # order of the additional keys may be different expect_setequal( - key_colnames(my_workflow$pre$mold), - c("time_value", "geo_value", "state", "pol") + key_colnames(my_workflow), + c("geo_value", "time_value", "state", "pol") ) }) From bec78a280bb35eeefbf98ee1852bccf1281ff598 Mon Sep 17 00:00:00 2001 From: "Daniel J. McDonald" Date: Thu, 29 Aug 2024 10:35:58 -0700 Subject: [PATCH 15/16] remove / revise some silent tests --- tests/testthat/test-population_scaling.R | 53 +++++++++++++++--------- 1 file changed, 34 insertions(+), 19 deletions(-) diff --git a/tests/testthat/test-population_scaling.R b/tests/testthat/test-population_scaling.R index d04a6bb43..9860ab08c 100644 --- a/tests/testthat/test-population_scaling.R +++ b/tests/testthat/test-population_scaling.R @@ -5,18 +5,25 @@ test_that("Column names can be passed with and without the tidy way", { value = c(1000, 2000, 3000, 4000, 5000, 6000) ) - newdata <- case_death_rate_subset %>% filter(geo_value %in% c("ak", "al", "ar", "as", "az", "ca")) + pop_data2 <- pop_data %>% dplyr::rename(geo_value = states) + + newdata <- case_death_rate_subset %>% + filter(geo_value %in% c("ak", "al", "ar", "as", "az", "ca")) r1 <- epi_recipe(newdata) %>% - step_population_scaling(c("case_rate", "death_rate"), + step_population_scaling( + case_rate, death_rate, df = pop_data, - df_pop_col = "value", by = c("geo_value" = "states") + df_pop_col = "value", + by = c("geo_value" = "states") ) r2 <- epi_recipe(newdata) %>% - step_population_scaling(case_rate, death_rate, - df = pop_data, - df_pop_col = "value", by = c("geo_value" = "states") + step_population_scaling( + case_rate, death_rate, + df = pop_data2, + df_pop_col = "value", + by = "geo_value" ) prep1 <- prep(r1, newdata) @@ -56,9 +63,9 @@ test_that("Number of columns and column names returned correctly, Upper and lowe suffix = "_rate" ) - prep <- prep(r, newdata) + p <- prep(r, newdata) - expect_silent(b <- bake(prep, newdata)) + b <- bake(p, newdata) expect_equal(ncol(b), 7L) expect_true("case_rate" %in% colnames(b)) expect_true("death_rate" %in% colnames(b)) @@ -75,15 +82,15 @@ test_that("Number of columns and column names returned correctly, Upper and lowe create_new = FALSE ) - expect_warning(prep <- prep(r, newdata)) + expect_warning(p <- prep(r, newdata)) - expect_warning(b <- bake(prep, newdata)) + expect_warning(b <- bake(p, newdata)) expect_equal(ncol(b), 5L) }) ## Postprocessing test_that("Postprocessing workflow works and values correct", { - jhu <- epiprocess::jhu_csse_daily_subset %>% + jhu <- jhu_csse_daily_subset %>% dplyr::filter(time_value > "2021-11-01", geo_value %in% c("ca", "ny")) %>% dplyr::select(geo_value, time_value, cases) @@ -207,9 +214,17 @@ test_that("test joining by default columns", { recipes::step_naomit(recipes::all_predictors()) %>% recipes::step_naomit(recipes::all_outcomes(), skip = TRUE) - expect_silent(prep(r, jhu)) + p <- prep(r, jhu) + b <- bake(p, new_data = NULL) + expect_named( + b, + c("geo_value", "time_value", "case_rate", "case_rate_scaled", + paste0("lag_", c(0,7,14), "_case_rate_scaled"), + "ahead_7_case_rate_scaled" + ) + ) + - expect_silent(bake(prep(r, jhu), new_data = NULL)) f <- frosting() %>% layer_predict() %>% @@ -221,13 +236,13 @@ test_that("test joining by default columns", { df_pop_col = "values" ) - expect_silent( - wf <- epi_workflow(r, parsnip::linear_reg()) %>% - fit(jhu) %>% - add_frosting(f) - ) + wf <- epi_workflow(r, parsnip::linear_reg()) %>% + fit(jhu) %>% + add_frosting(f) - expect_silent(forecast(wf)) + fc <- forecast(wf) + expect_named(fc, c("geo_value", "time_value", ".pred", ".pred_scaled")) + expect_equal(fc$.pred_scaled, fc$.pred * c(1 / 20000, 1 / 30000)) }) From 6d8edc03a70c223d3b58dc9343b9c8cedbf7788c Mon Sep 17 00:00:00 2001 From: "Daniel J. McDonald" Date: Thu, 29 Aug 2024 12:50:09 -0700 Subject: [PATCH 16/16] styler --- tests/testthat/test-population_scaling.R | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/testthat/test-population_scaling.R b/tests/testthat/test-population_scaling.R index 9860ab08c..a94b40b82 100644 --- a/tests/testthat/test-population_scaling.R +++ b/tests/testthat/test-population_scaling.R @@ -218,8 +218,9 @@ test_that("test joining by default columns", { b <- bake(p, new_data = NULL) expect_named( b, - c("geo_value", "time_value", "case_rate", "case_rate_scaled", - paste0("lag_", c(0,7,14), "_case_rate_scaled"), + c( + "geo_value", "time_value", "case_rate", "case_rate_scaled", + paste0("lag_", c(0, 7, 14), "_case_rate_scaled"), "ahead_7_case_rate_scaled" ) )