diff --git a/DESCRIPTION b/DESCRIPTION index aafeb2e42..12d86602b 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,6 +1,6 @@ Package: epipredict Title: Basic epidemiology forecasting methods -Version: 0.0.14 +Version: 0.0.15 Authors@R: c( person("Daniel", "McDonald", , "daniel@stat.ubc.ca", role = c("aut", "cre")), person("Ryan", "Tibshirani", , "ryantibs@cmu.edu", role = "aut"), @@ -38,7 +38,7 @@ Imports: magrittr, quantreg, recipes (>= 1.0.4), - rlang, + rlang (>= 1.0.0), smoothqr, stats, tibble, diff --git a/NAMESPACE b/NAMESPACE index 106dfe5b0..abc7e99d6 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -240,10 +240,11 @@ importFrom(rlang,"%@%") importFrom(rlang,"%||%") importFrom(rlang,":=") importFrom(rlang,abort) +importFrom(rlang,as_function) importFrom(rlang,caller_env) -importFrom(rlang,is_empty) +importFrom(rlang,global_env) importFrom(rlang,is_null) -importFrom(rlang,quos) +importFrom(rlang,set_names) importFrom(smoothqr,smooth_qr) importFrom(stats,as.formula) importFrom(stats,family) @@ -255,8 +256,6 @@ importFrom(stats,predict) importFrom(stats,qnorm) importFrom(stats,quantile) importFrom(stats,residuals) -importFrom(tibble,as_tibble) -importFrom(tibble,is_tibble) importFrom(tibble,tibble) importFrom(tidyr,drop_na) importFrom(vctrs,as_list_of) diff --git a/NEWS.md b/NEWS.md index 8cf0b028a..c756cb34d 100644 --- a/NEWS.md +++ b/NEWS.md @@ -43,3 +43,6 @@ Pre-1.0.0 numbering scheme: 0.x will indicate releases, while 0.0.x will indicat - `arx_fcast_epi_workflow()` and `arx_class_epi_workflow()` now default to `trainer = parsnip::logistic_reg()` to match their more canned versions. - add a `forecast()` method simplify generating forecasts +- refactor `bake.epi_recipe()` and remove `epi_juice()`. +- Revise `compat-purrr` to use the r-lang `standalone-*` version (via + `{usethis}`) diff --git a/R/bake.epi_recipe.R b/R/bake.epi_recipe.R deleted file mode 100644 index 4a083ee19..000000000 --- a/R/bake.epi_recipe.R +++ /dev/null @@ -1,104 +0,0 @@ -#' Bake an epi_recipe -#' -#' @param object A trained object such as a [recipe()] with at least -#' one preprocessing operation. -#' @param new_data An `epi_df`, data frame or tibble for whom the -#' preprocessing will be applied. If `NULL` is given to `new_data`, -#' the pre-processed _training data_ will be returned. -#' @param ... One or more selector functions to choose which variables will be -#' returned by the function. See [recipes::selections()] for -#' more details. If no selectors are given, the default is to -#' use [tidyselect::everything()]. -#' @return An `epi_df` that may have different columns than the -#' original columns in `new_data`. -#' @importFrom rlang is_empty quos -#' @importFrom tibble is_tibble as_tibble -#' @rdname bake -#' @export -bake.epi_recipe <- function(object, new_data, ...) { - if (rlang::is_missing(new_data)) { - rlang::abort("'new_data' must be either an epi_df or NULL. No value is not allowed.") - } - - if (is.null(new_data)) { - return(epi_juice(object, ...)) - } - - if (!fully_trained(object)) { - rlang::abort("At least one step has not been trained. Please run `prep`.") - } - - terms <- quos(...) - if (is_empty(terms)) { - terms <- quos(tidyselect::everything()) - } - - # In case someone used the deprecated `newdata`: - if (is.null(new_data) || is.null(ncol(new_data))) { - if (any(names(terms) == "newdata")) { - rlang::abort("Please use `new_data` instead of `newdata` with `bake`.") - } else { - rlang::abort("Please pass a data set to `new_data`.") - } - } - - if (!is_tibble(new_data)) { - new_data <- as_tibble(new_data) - } - - recipes:::check_role_requirements(object, new_data) - - recipes:::check_nominal_type(new_data, object$orig_lvls) - - # Drop completely new columns from `new_data` and reorder columns that do - # still exist to match the ordering used when training - original_names <- names(new_data) - original_training_names <- unique(object$var_info$variable) - bakeable_names <- intersect(original_training_names, original_names) - new_data <- new_data[, bakeable_names] - - n_steps <- length(object$steps) - - for (i in seq_len(n_steps)) { - step <- object$steps[[i]] - - if (recipes:::is_skipable(step)) { - next - } - - new_data <- bake(step, new_data = new_data) - - if (!is_tibble(new_data)) { - abort("bake() methods should always return tibbles") - } - } - - # Use `last_term_info`, which maintains info on all columns that got added - # and removed from the training data. This is important for skipped steps - # which might have resulted in columns not being added/removed in the test - # set. - info <- object$last_term_info - - # Now reduce to only user selected columns - out_names <- recipes_eval_select(terms, new_data, info, - check_case_weights = FALSE - ) - new_data <- new_data[, out_names] - - # The levels are not null when no nominal data are present or - # if strings_as_factors = FALSE in `prep` - if (!is.null(object$levels)) { - var_levels <- object$levels - var_levels <- var_levels[out_names] - check_values <- - vapply(var_levels, function(x) { - (!all(is.na(x))) - }, c(all = TRUE)) - var_levels <- var_levels[check_values] - if (length(var_levels) > 0) { - new_data <- recipes:::strings2factors(new_data, var_levels) - } - } - - new_data -} diff --git a/R/compat-purrr.R b/R/compat-purrr.R index 712926f73..e06038e44 100644 --- a/R/compat-purrr.R +++ b/R/compat-purrr.R @@ -1,37 +1,8 @@ -# See https://github.com/r-lib/rlang/blob/main/R/compat-purrr.R - - -map <- function(.x, .f, ...) { - .f <- rlang::as_function(.f, env = rlang::global_env()) - lapply(.x, .f, ...) -} - -walk <- function(.x, .f, ...) { - map(.x, .f, ...) - invisible(.x) -} - walk2 <- function(.x, .y, .f, ...) { map2(.x, .y, .f, ...) invisible(.x) } -map_lgl <- function(.x, .f, ...) { - .rlang_purrr_map_mold(.x, .f, logical(1), ...) -} - -map_int <- function(.x, .f, ...) { - .rlang_purrr_map_mold(.x, .f, integer(1), ...) -} - -map_dbl <- function(.x, .f, ...) { - .rlang_purrr_map_mold(.x, .f, double(1), ...) -} - -map_chr <- function(.x, .f, ...) { - .rlang_purrr_map_mold(.x, .f, character(1), ...) -} - map_vec <- function(.x, .f, ...) { out <- map(.x, .f, ...) vctrs::list_unchop(out) @@ -48,61 +19,3 @@ map2_dfr <- function(.x, .y, .f, ..., .id = NULL) { res <- map2(.x, .y, .f, ...) dplyr::bind_rows(res, .id = .id) } - -.rlang_purrr_map_mold <- function(.x, .f, .mold, ...) { - .f <- rlang::as_function(.f, env = rlang::global_env()) - out <- vapply(.x, .f, .mold, ..., USE.NAMES = FALSE) - names(out) <- names(.x) - out -} - -.rlang_purrr_args_recycle <- function(args) { - lengths <- map_int(args, length) - n <- max(lengths) - - stopifnot(all(lengths == 1L | lengths == n)) - to_recycle <- lengths == 1L - args[to_recycle] <- map(args[to_recycle], function(x) rep.int(x, n)) - - args -} - -map2 <- function(.x, .y, .f, ...) { - .f <- rlang::as_function(.f, env = rlang::global_env()) - out <- mapply(.f, .x, .y, MoreArgs = list(...), SIMPLIFY = FALSE) - if (length(out) == length(.x)) { - rlang::set_names(out, names(.x)) - } else { - rlang::set_names(out, NULL) - } -} -map2_lgl <- function(.x, .y, .f, ...) { - as.vector(map2(.x, .y, .f, ...), "logical") -} -map2_int <- function(.x, .y, .f, ...) { - as.vector(map2(.x, .y, .f, ...), "integer") -} -map2_dbl <- function(.x, .y, .f, ...) { - as.vector(map2(.x, .y, .f, ...), "double") -} -map2_chr <- function(.x, .y, .f, ...) { - as.vector(map2(.x, .y, .f, ...), "character") -} -imap <- function(.x, .f, ...) { - map2(.x, names(.x) %||% seq_along(.x), .f, ...) -} - -pmap <- function(.l, .f, ...) { - .f <- as.function(.f) - args <- .rlang_purrr_args_recycle(.l) - do.call("mapply", c( - FUN = list(quote(.f)), - args, MoreArgs = quote(list(...)), - SIMPLIFY = FALSE, USE.NAMES = FALSE - )) -} - -reduce <- function(.x, .f, ..., .init) { - f <- function(x, y) .f(x, y, ...) - Reduce(f, .x, init = .init) -} diff --git a/R/epi_juice.R b/R/epi_juice.R deleted file mode 100644 index d9d23df97..000000000 --- a/R/epi_juice.R +++ /dev/null @@ -1,43 +0,0 @@ -#' Extract transformed training set -#' -#' @inheritParams bake.epi_recipe -epi_juice <- function(object, ...) { - if (!fully_trained(object)) { - rlang::abort("At least one step has not been trained. Please run `prep()`.") - } - - if (!isTRUE(object$retained)) { - rlang::abort(paste0( - "Use `retain = TRUE` in `prep()` to be able ", - "to extract the training set" - )) - } - - terms <- quos(...) - if (is_empty(terms)) { - terms <- quos(dplyr::everything()) - } - - # Get user requested columns - new_data <- object$template - out_names <- recipes_eval_select(terms, new_data, object$term_info, - check_case_weights = FALSE - ) - new_data <- new_data[, out_names] - - # Since most models require factors, do the conversion from character - if (!is.null(object$levels)) { - var_levels <- object$levels - var_levels <- var_levels[out_names] - check_values <- - vapply(var_levels, function(x) { - (!all(is.na(x))) - }, c(all = TRUE)) - var_levels <- var_levels[check_values] - if (length(var_levels) > 0) { - new_data <- recipes:::strings2factors(new_data, var_levels) - } - } - - new_data -} diff --git a/R/epi_recipe.R b/R/epi_recipe.R index 3e5607dbb..e5182b99b 100644 --- a/R/epi_recipe.R +++ b/R/epi_recipe.R @@ -557,6 +557,28 @@ prep.epi_recipe <- function( x } +#' @export +bake.epi_recipe <- function(object, new_data, ..., composition = "epi_df") { + meta <- NULL + if (composition == "epi_df") { + if (is_epi_df(new_data)) { + meta <- attr(new_data, "metadata") + } else if (is_epi_df(object$template)) { + meta <- attr(object$template, "metadata") + } + composition <- "tibble" + } + new_data <- NextMethod("bake") + if (!is.null(meta)) { + new_data <- as_epi_df( + new_data, meta$geo_type, meta$time_type, meta$as_of, + meta$additional_metadata %||% list() + ) + } + new_data +} + + kill_levels <- function(x, keys) { for (i in which(names(x) %in% keys)) x[[i]] <- list(values = NA, ordered = NA) x diff --git a/R/epipredict-package.R b/R/epipredict-package.R index a3c6a208a..b6681f982 100644 --- a/R/epipredict-package.R +++ b/R/epipredict-package.R @@ -1,6 +1,6 @@ ## usethis namespace: start #' @importFrom tibble tibble -#' @importFrom rlang := !! %||% +#' @importFrom rlang := !! %||% as_function global_env set_names #' @importFrom stats poly predict lm residuals quantile #' @importFrom cli cli_abort #' @importFrom checkmate assert assert_character assert_int assert_scalar diff --git a/R/standalone-lifecycle.R b/R/import-standalone-lifecycle.R similarity index 96% rename from R/standalone-lifecycle.R rename to R/import-standalone-lifecycle.R index e1b812a6d..a1be17134 100644 --- a/R/standalone-lifecycle.R +++ b/R/import-standalone-lifecycle.R @@ -1,3 +1,7 @@ +# Standalone file: do not edit by hand +# Source: +# ---------------------------------------------------------------------- +# # --- # repo: r-lib/rlang # file: standalone-lifecycle.R @@ -94,7 +98,8 @@ deprecate_soft <- function(msg, id <- paste(id, collapse = "\n") verbosity <- .rlang_lifecycle_verbosity() - invisible(switch(verbosity, + invisible(switch( + verbosity, quiet = NULL, warning = , default = @@ -121,7 +126,8 @@ deprecate_warn <- function(msg, id <- paste(id, collapse = "\n") verbosity <- .rlang_lifecycle_verbosity() - invisible(switch(verbosity, + invisible(switch( + verbosity, quiet = NULL, warning = , default = { diff --git a/R/import-standalone-purrr.R b/R/import-standalone-purrr.R new file mode 100644 index 000000000..623142a0e --- /dev/null +++ b/R/import-standalone-purrr.R @@ -0,0 +1,240 @@ +# Standalone file: do not edit by hand +# Source: +# ---------------------------------------------------------------------- +# +# --- +# repo: r-lib/rlang +# file: standalone-purrr.R +# last-updated: 2023-02-23 +# license: https://unlicense.org +# imports: rlang +# --- +# +# This file provides a minimal shim to provide a purrr-like API on top of +# base R functions. They are not drop-in replacements but allow a similar style +# of programming. +# +# ## Changelog +# +# 2023-02-23: +# * Added `list_c()` +# +# 2022-06-07: +# * `transpose()` is now more consistent with purrr when inner names +# are not congruent (#1346). +# +# 2021-12-15: +# * `transpose()` now supports empty lists. +# +# 2021-05-21: +# * Fixed "object `x` not found" error in `imap()` (@mgirlich) +# +# 2020-04-14: +# * Removed `pluck*()` functions +# * Removed `*_cpl()` functions +# * Used `as_function()` to allow use of `~` +# * Used `.` prefix for helpers +# +# nocov start + +map <- function(.x, .f, ...) { + .f <- as_function(.f, env = global_env()) + lapply(.x, .f, ...) +} +walk <- function(.x, .f, ...) { + map(.x, .f, ...) + invisible(.x) +} + +map_lgl <- function(.x, .f, ...) { + .rlang_purrr_map_mold(.x, .f, logical(1), ...) +} +map_int <- function(.x, .f, ...) { + .rlang_purrr_map_mold(.x, .f, integer(1), ...) +} +map_dbl <- function(.x, .f, ...) { + .rlang_purrr_map_mold(.x, .f, double(1), ...) +} +map_chr <- function(.x, .f, ...) { + .rlang_purrr_map_mold(.x, .f, character(1), ...) +} +.rlang_purrr_map_mold <- function(.x, .f, .mold, ...) { + .f <- as_function(.f, env = global_env()) + out <- vapply(.x, .f, .mold, ..., USE.NAMES = FALSE) + names(out) <- names(.x) + out +} + +map2 <- function(.x, .y, .f, ...) { + .f <- as_function(.f, env = global_env()) + out <- mapply(.f, .x, .y, MoreArgs = list(...), SIMPLIFY = FALSE) + if (length(out) == length(.x)) { + set_names(out, names(.x)) + } else { + set_names(out, NULL) + } +} +map2_lgl <- function(.x, .y, .f, ...) { + as.vector(map2(.x, .y, .f, ...), "logical") +} +map2_int <- function(.x, .y, .f, ...) { + as.vector(map2(.x, .y, .f, ...), "integer") +} +map2_dbl <- function(.x, .y, .f, ...) { + as.vector(map2(.x, .y, .f, ...), "double") +} +map2_chr <- function(.x, .y, .f, ...) { + as.vector(map2(.x, .y, .f, ...), "character") +} +imap <- function(.x, .f, ...) { + map2(.x, names(.x) %||% seq_along(.x), .f, ...) +} + +pmap <- function(.l, .f, ...) { + .f <- as.function(.f) + args <- .rlang_purrr_args_recycle(.l) + do.call("mapply", c( + FUN = list(quote(.f)), + args, MoreArgs = quote(list(...)), + SIMPLIFY = FALSE, USE.NAMES = FALSE + )) +} +.rlang_purrr_args_recycle <- function(args) { + lengths <- map_int(args, length) + n <- max(lengths) + + stopifnot(all(lengths == 1L | lengths == n)) + to_recycle <- lengths == 1L + args[to_recycle] <- map(args[to_recycle], function(x) rep.int(x, n)) + + args +} + +keep <- function(.x, .f, ...) { + .x[.rlang_purrr_probe(.x, .f, ...)] +} +discard <- function(.x, .p, ...) { + sel <- .rlang_purrr_probe(.x, .p, ...) + .x[is.na(sel) | !sel] +} +map_if <- function(.x, .p, .f, ...) { + matches <- .rlang_purrr_probe(.x, .p) + .x[matches] <- map(.x[matches], .f, ...) + .x +} +.rlang_purrr_probe <- function(.x, .p, ...) { + if (is_logical(.p)) { + stopifnot(length(.p) == length(.x)) + .p + } else { + .p <- as_function(.p, env = global_env()) + map_lgl(.x, .p, ...) + } +} + +compact <- function(.x) { + Filter(length, .x) +} + +transpose <- function(.l) { + if (!length(.l)) { + return(.l) + } + + inner_names <- names(.l[[1]]) + + if (is.null(inner_names)) { + fields <- seq_along(.l[[1]]) + } else { + fields <- set_names(inner_names) + .l <- map(.l, function(x) { + if (is.null(names(x))) { + set_names(x, inner_names) + } else { + x + } + }) + } + + # This way missing fields are subsetted as `NULL` instead of causing + # an error + .l <- map(.l, as.list) + + map(fields, function(i) { + map(.l, .subset2, i) + }) +} + +every <- function(.x, .p, ...) { + .p <- as_function(.p, env = global_env()) + + for (i in seq_along(.x)) { + if (!rlang::is_true(.p(.x[[i]], ...))) return(FALSE) + } + TRUE +} +some <- function(.x, .p, ...) { + .p <- as_function(.p, env = global_env()) + + for (i in seq_along(.x)) { + if (rlang::is_true(.p(.x[[i]], ...))) return(TRUE) + } + FALSE +} +negate <- function(.p) { + .p <- as_function(.p, env = global_env()) + function(...) !.p(...) +} + +reduce <- function(.x, .f, ..., .init) { + f <- function(x, y) .f(x, y, ...) + Reduce(f, .x, init = .init) +} +reduce_right <- function(.x, .f, ..., .init) { + f <- function(x, y) .f(y, x, ...) + Reduce(f, .x, init = .init, right = TRUE) +} +accumulate <- function(.x, .f, ..., .init) { + f <- function(x, y) .f(x, y, ...) + Reduce(f, .x, init = .init, accumulate = TRUE) +} +accumulate_right <- function(.x, .f, ..., .init) { + f <- function(x, y) .f(y, x, ...) + Reduce(f, .x, init = .init, right = TRUE, accumulate = TRUE) +} + +detect <- function(.x, .f, ..., .right = FALSE, .p = is_true) { + .p <- as_function(.p, env = global_env()) + .f <- as_function(.f, env = global_env()) + + for (i in .rlang_purrr_index(.x, .right)) { + if (.p(.f(.x[[i]], ...))) { + return(.x[[i]]) + } + } + NULL +} +detect_index <- function(.x, .f, ..., .right = FALSE, .p = is_true) { + .p <- as_function(.p, env = global_env()) + .f <- as_function(.f, env = global_env()) + + for (i in .rlang_purrr_index(.x, .right)) { + if (.p(.f(.x[[i]], ...))) { + return(i) + } + } + 0L +} +.rlang_purrr_index <- function(x, right = FALSE) { + idx <- seq_along(x) + if (right) { + idx <- rev(idx) + } + idx +} + +list_c <- function(x) { + inject(c(!!!x)) +} + +# nocov end diff --git a/R/layers.R b/R/layers.R index c93423d32..b59e95cdd 100644 --- a/R/layers.R +++ b/R/layers.R @@ -144,11 +144,12 @@ pull_layer_name <- function(x) { #' @export #' @rdname layer-processors -validate_layer <- function(x, ..., arg = "`x`", call = caller_env()) { +validate_layer <- function(x, ..., arg = rlang::caller_arg(x), + call = caller_env()) { rlang::check_dots_empty() if (!is_layer(x)) { - glubort( - "{arg} must be a frosting layer, not a {class(x)[[1]]}.", + cli::cli_abort( + "{arg} must be a frosting layer, not a {.cls {class(x)[[1]]}}.", .call = call ) } diff --git a/R/utils-cli.R b/R/utils-cli.R index ad43c95eb..3b1555941 100644 --- a/R/utils-cli.R +++ b/R/utils-cli.R @@ -18,11 +18,6 @@ cli_warn <- function(..., .envir = parent.frame()) { } #' @importFrom rlang caller_env -glubort <- - function(..., .sep = "", .envir = caller_env(), .call = .envir) { - rlang::abort(glue::glue(..., .sep = .sep, .envir = .envir), call = .call) - } - cat_line <- function(...) { cat(paste0(..., collapse = "\n"), "\n", sep = "") } diff --git a/_pkgdown.yml b/_pkgdown.yml index a0edb663e..62f1cc25d 100644 --- a/_pkgdown.yml +++ b/_pkgdown.yml @@ -92,7 +92,6 @@ reference: contents: - starts_with("step_") - contains("bake") - - contains("juice") - title: Epi recipe verification checks contents: - check_enough_train_data diff --git a/man/bake.Rd b/man/bake.Rd deleted file mode 100644 index c1c0137c5..000000000 --- a/man/bake.Rd +++ /dev/null @@ -1,28 +0,0 @@ -% Generated by roxygen2: do not edit by hand -% Please edit documentation in R/bake.epi_recipe.R -\name{bake.epi_recipe} -\alias{bake.epi_recipe} -\title{Bake an epi_recipe} -\usage{ -\method{bake}{epi_recipe}(object, new_data, ...) -} -\arguments{ -\item{object}{A trained object such as a \code{\link[=recipe]{recipe()}} with at least -one preprocessing operation.} - -\item{new_data}{An \code{epi_df}, data frame or tibble for whom the -preprocessing will be applied. If \code{NULL} is given to \code{new_data}, -the pre-processed \emph{training data} will be returned.} - -\item{...}{One or more selector functions to choose which variables will be -returned by the function. See \code{\link[recipes:selections]{recipes::selections()}} for -more details. If no selectors are given, the default is to -use \code{\link[tidyselect:everything]{tidyselect::everything()}}.} -} -\value{ -An \code{epi_df} that may have different columns than the -original columns in \code{new_data}. -} -\description{ -Bake an epi_recipe -} diff --git a/man/epi_juice.Rd b/man/epi_juice.Rd deleted file mode 100644 index 38eccb9a9..000000000 --- a/man/epi_juice.Rd +++ /dev/null @@ -1,20 +0,0 @@ -% Generated by roxygen2: do not edit by hand -% Please edit documentation in R/epi_juice.R -\name{epi_juice} -\alias{epi_juice} -\title{Extract transformed training set} -\usage{ -epi_juice(object, ...) -} -\arguments{ -\item{object}{A trained object such as a \code{\link[=recipe]{recipe()}} with at least -one preprocessing operation.} - -\item{...}{One or more selector functions to choose which variables will be -returned by the function. See \code{\link[recipes:selections]{recipes::selections()}} for -more details. If no selectors are given, the default is to -use \code{\link[tidyselect:everything]{tidyselect::everything()}}.} -} -\description{ -Extract transformed training set -} diff --git a/man/layer-processors.Rd b/man/layer-processors.Rd index 0c6df8c5c..76e230a7b 100644 --- a/man/layer-processors.Rd +++ b/man/layer-processors.Rd @@ -20,7 +20,7 @@ extract_layers(x, ...) is_layer(x) -validate_layer(x, ..., arg = "`x`", call = caller_env()) +validate_layer(x, ..., arg = rlang::caller_arg(x), call = caller_env()) detect_layer(x, name, ...) diff --git a/tests/testthat/test-bake-method.R b/tests/testthat/test-bake-method.R new file mode 100644 index 000000000..0e2746cf2 --- /dev/null +++ b/tests/testthat/test-bake-method.R @@ -0,0 +1,29 @@ +test_that("bake method works in all cases", { + edf <- case_death_rate_subset %>% + filter(time_value > "2021-11-01", geo_value %in% c("ak", "ca", "ny")) + r <- epi_recipe(edf) %>% + step_epi_lag(death_rate, lag = c(0, 7, 14)) %>% + step_epi_ahead(death_rate, ahead = 7) + + r2 <- epi_recipe(edf) %>% + step_epi_lag(death_rate, lag = c(0, 7, 14)) %>% + step_epi_ahead(death_rate, ahead = 7) %>% + step_epi_naomit() + + b_null <- bake(prep(r, edf), NULL) + b_train <- bake(prep(r, edf), edf) + expect_s3_class(b_null, "epi_df") + expect_identical(b_null, b_train) + + b_baked <- bake(prep(r2, edf), edf) # leaves rows with NA in the response + # doesn't (because we "juice", so skip doesn't apply) + b_juiced <- bake(prep(r2, edf), NULL) + expect_equal(nrow(b_juiced), sum(complete.cases(b_train))) + expect_equal(nrow(b_baked), sum(complete.cases(b_train)) + 3 * 7) + + # check that the {recipes} behaves + expect_s3_class(bake(prep(r, edf), NULL, composition = "tibble"), "tbl_df") + expect_s3_class(bake(prep(r, edf), NULL, composition = "data.frame"), "data.frame") + # can't be a matrix because time_value/geo_value aren't numeric + expect_error(bake(prep(r, edf), NULL, composition = "matrix")) +}) diff --git a/tests/testthat/test-check_enough_train_data.R b/tests/testthat/test-check_enough_train_data.R index 5eae01bb2..02b9e35e4 100644 --- a/tests/testthat/test-check_enough_train_data.R +++ b/tests/testthat/test-check_enough_train_data.R @@ -75,7 +75,7 @@ test_that("check_enough_train_data outputs the correct recipe values", { expect_equal(nrow(p), 2 * n) expect_equal(ncol(p), 4L) expect_s3_class(p, "epi_df") - expect_named(p, c("time_value", "geo_value", "x", "y")) + expect_named(p, c("geo_value", "time_value", "x", "y")) # order in epiprocess::new_epi_df expect_equal( p$time_value, rep(seq(as.Date("2020-01-01"), by = 1, length.out = n), times = 2) diff --git a/tests/testthat/test-step_training_window.R b/tests/testthat/test-step_training_window.R index c8a17f43f..9ec1e5982 100644 --- a/tests/testthat/test-step_training_window.R +++ b/tests/testthat/test-step_training_window.R @@ -17,7 +17,7 @@ test_that("step_training_window works with default n_recent", { expect_equal(nrow(p), 100L) expect_equal(ncol(p), 4L) expect_s3_class(p, "epi_df") - expect_named(p, c("time_value", "geo_value", "x", "y")) + expect_named(p, c("geo_value", "time_value", "x", "y")) # order in epiprocess::new_epi_df expect_equal( p$time_value, rep(seq(as.Date("2020-02-20"), as.Date("2020-04-09"), by = 1), times = 2) @@ -34,7 +34,7 @@ test_that("step_training_window works with specified n_recent", { expect_equal(nrow(p2), 10L) expect_equal(ncol(p2), 4L) expect_s3_class(p2, "epi_df") - expect_named(p2, c("time_value", "geo_value", "x", "y")) + expect_named(p2, c("geo_value", "time_value", "x", "y")) # order in epiprocess::new_epi_df expect_equal( p2$time_value, rep(seq(as.Date("2020-04-05"), as.Date("2020-04-09"), by = 1), times = 2) @@ -55,7 +55,7 @@ test_that("step_training_window does not proceed with specified new_data", { expect_equal(ncol(p3), 4L) expect_s3_class(p3, "epi_df") # cols will be predictors, outcomes, time_value, geo_value - expect_named(p3, c("x", "y", "time_value", "geo_value")) + expect_named(p3, c("geo_value", "time_value", "x", "y")) # order in epiprocess::new_epi_df expect_equal( p3$time_value, rep(seq(as.Date("2020-01-01"), as.Date("2020-01-10"), by = 1), times = 1) @@ -84,7 +84,7 @@ test_that("step_training_window works with multiple keys", { expect_equal(nrow(p4), 12L) expect_equal(ncol(p4), 5L) expect_s3_class(p4, "epi_df") - expect_named(p4, c("time_value", "geo_value", "additional_key", "x", "y")) + expect_named(p4, c("geo_value", "time_value", "additional_key", "x", "y")) expect_equal( p4$time_value, rep(c(