Skip to content

Commit

Permalink
refactor: epi_recipe warns when given non-epi_df
Browse files Browse the repository at this point in the history
  • Loading branch information
dshemetov committed Jun 28, 2024
1 parent 14e3708 commit 7c888f1
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 18 deletions.
20 changes: 11 additions & 9 deletions R/epi_recipe.R
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@ epi_recipe.default <- function(x, ...) {
if (is.matrix(x) || is.data.frame(x) || tibble::is_tibble(x)) {
x <- x[1, , drop = FALSE]
}
cli_warn(
"epi_recipe has been called with a non-epi_df object, returning a regular recipe. Various
step_epi_* functions will not work."
)
recipes::recipe(x, ...)
}

Expand Down Expand Up @@ -147,6 +151,10 @@ epi_recipe.formula <- function(formula, data, ...) {
data <- data[1, ]
# check for minus:
if (!epiprocess::is_epi_df(data)) {
cli_warn(
"epi_recipe has been called with a non-epi_df object, returning a regular recipe. Various
step_epi_* functions will not work."
)
return(recipes::recipe(formula, data, ...))
}

Expand Down Expand Up @@ -333,15 +341,11 @@ update_epi_recipe <- function(x, recipe, ..., blueprint = default_epi_recipe_blu
#' illustrations of the different types of updates.
#'
#' @param x A `epi_workflow` or `epi_recipe` object
#'
#' @param which_step the number or name of the step to adjust
#'
#' @param ... Used to input a parameter adjustment
#'
#' @param blueprint A hardhat blueprint used for fine tuning the preprocessing.
#'
#' @return
#' `x`, updated with the adjustment to the specified `epi_recipe` step.
#' @return `x`, updated with the adjustment to the specified `epi_recipe` step.
#'
#' @export
#' @examples
Expand Down Expand Up @@ -383,17 +387,15 @@ adjust_epi_recipe <- function(x, which_step, ..., blueprint = default_epi_recipe

#' @rdname adjust_epi_recipe
#' @export
adjust_epi_recipe.epi_workflow <- function(
x, which_step, ..., blueprint = default_epi_recipe_blueprint()) {
adjust_epi_recipe.epi_workflow <- function(x, which_step, ..., blueprint = default_epi_recipe_blueprint()) {
recipe <- adjust_epi_recipe(workflows::extract_preprocessor(x), which_step, ...)

update_epi_recipe(x, recipe, blueprint = blueprint)
}

#' @rdname adjust_epi_recipe
#' @export
adjust_epi_recipe.epi_recipe <- function(
x, which_step, ..., blueprint = default_epi_recipe_blueprint()) {
adjust_epi_recipe.epi_recipe <- function(x, which_step, ..., blueprint = default_epi_recipe_blueprint()) {
if (!(is.numeric(which_step) || is.character(which_step))) {
cli::cli_abort(
c("`which_step` must be a number or a character.",
Expand Down
21 changes: 12 additions & 9 deletions tests/testthat/test-epi_recipe.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,23 @@ test_that("epi_recipe produces default recipe", {
x = 1:5, y = 1:5,
time_value = seq(as.Date("2020-01-01"), by = 1, length.out = 5)
)
rec <- recipes::recipe(tib)
rec$template <- rec$template[1, ]
expect_identical(rec, epi_recipe(tib))
expected_rec <- recipes::recipe(tib)
expected_rec$template <- expected_rec$template[1, ]
expect_warning(rec <- epi_recipe(tib), regexp = "epi_recipe has been called with a non-epi_df object")
expect_identical(expected_rec, rec)
expect_equal(nrow(rec$template), 1L)

rec <- recipes::recipe(y ~ x, tib)
rec$template <- rec$template[1, ]
expect_identical(rec, epi_recipe(y ~ x, tib))
expected_rec <- recipes::recipe(y ~ x, tib)
expected_rec$template <- expected_rec$template[1, ]
expect_warning(rec <- epi_recipe(y ~ x, tib), regexp = "epi_recipe has been called with a non-epi_df object")
expect_identical(expected_rec, rec)
expect_equal(nrow(rec$template), 1L)

m <- as.matrix(tib)
rec <- recipes::recipe(m)
rec$template <- rec$template[1, ]
expect_identical(rec, epi_recipe(m))
expected_rec <- recipes::recipe(m)
expected_rec$template <- expected_rec$template[1, ]
expect_warning(rec <- epi_recipe(m), regexp = "epi_recipe has been called with a non-epi_df object")
expect_identical(expected_rec, rec)
expect_equal(nrow(rec$template), 1L)
})

Expand Down

0 comments on commit 7c888f1

Please sign in to comment.