From 7c888f13828d79af6b961e2923706089cb33c7f1 Mon Sep 17 00:00:00 2001 From: Dmitry Shemetov Date: Thu, 27 Jun 2024 17:12:34 -0700 Subject: [PATCH] refactor: epi_recipe warns when given non-epi_df --- R/epi_recipe.R | 20 +++++++++++--------- tests/testthat/test-epi_recipe.R | 21 ++++++++++++--------- 2 files changed, 23 insertions(+), 18 deletions(-) diff --git a/R/epi_recipe.R b/R/epi_recipe.R index e5182b99b..b40d7e510 100644 --- a/R/epi_recipe.R +++ b/R/epi_recipe.R @@ -20,6 +20,10 @@ epi_recipe.default <- function(x, ...) { if (is.matrix(x) || is.data.frame(x) || tibble::is_tibble(x)) { x <- x[1, , drop = FALSE] } + cli_warn( + "epi_recipe has been called with a non-epi_df object, returning a regular recipe. Various + step_epi_* functions will not work." + ) recipes::recipe(x, ...) } @@ -147,6 +151,10 @@ epi_recipe.formula <- function(formula, data, ...) { data <- data[1, ] # check for minus: if (!epiprocess::is_epi_df(data)) { + cli_warn( + "epi_recipe has been called with a non-epi_df object, returning a regular recipe. Various + step_epi_* functions will not work." + ) return(recipes::recipe(formula, data, ...)) } @@ -333,15 +341,11 @@ update_epi_recipe <- function(x, recipe, ..., blueprint = default_epi_recipe_blu #' illustrations of the different types of updates. #' #' @param x A `epi_workflow` or `epi_recipe` object -#' #' @param which_step the number or name of the step to adjust -#' #' @param ... Used to input a parameter adjustment -#' #' @param blueprint A hardhat blueprint used for fine tuning the preprocessing. #' -#' @return -#' `x`, updated with the adjustment to the specified `epi_recipe` step. +#' @return `x`, updated with the adjustment to the specified `epi_recipe` step. #' #' @export #' @examples @@ -383,8 +387,7 @@ adjust_epi_recipe <- function(x, which_step, ..., blueprint = default_epi_recipe #' @rdname adjust_epi_recipe #' @export -adjust_epi_recipe.epi_workflow <- function( - x, which_step, ..., blueprint = default_epi_recipe_blueprint()) { +adjust_epi_recipe.epi_workflow <- function(x, which_step, ..., blueprint = default_epi_recipe_blueprint()) { recipe <- adjust_epi_recipe(workflows::extract_preprocessor(x), which_step, ...) update_epi_recipe(x, recipe, blueprint = blueprint) @@ -392,8 +395,7 @@ adjust_epi_recipe.epi_workflow <- function( #' @rdname adjust_epi_recipe #' @export -adjust_epi_recipe.epi_recipe <- function( - x, which_step, ..., blueprint = default_epi_recipe_blueprint()) { +adjust_epi_recipe.epi_recipe <- function(x, which_step, ..., blueprint = default_epi_recipe_blueprint()) { if (!(is.numeric(which_step) || is.character(which_step))) { cli::cli_abort( c("`which_step` must be a number or a character.", diff --git a/tests/testthat/test-epi_recipe.R b/tests/testthat/test-epi_recipe.R index d288ec058..75726652d 100644 --- a/tests/testthat/test-epi_recipe.R +++ b/tests/testthat/test-epi_recipe.R @@ -4,20 +4,23 @@ test_that("epi_recipe produces default recipe", { x = 1:5, y = 1:5, time_value = seq(as.Date("2020-01-01"), by = 1, length.out = 5) ) - rec <- recipes::recipe(tib) - rec$template <- rec$template[1, ] - expect_identical(rec, epi_recipe(tib)) + expected_rec <- recipes::recipe(tib) + expected_rec$template <- expected_rec$template[1, ] + expect_warning(rec <- epi_recipe(tib), regexp = "epi_recipe has been called with a non-epi_df object") + expect_identical(expected_rec, rec) expect_equal(nrow(rec$template), 1L) - rec <- recipes::recipe(y ~ x, tib) - rec$template <- rec$template[1, ] - expect_identical(rec, epi_recipe(y ~ x, tib)) + expected_rec <- recipes::recipe(y ~ x, tib) + expected_rec$template <- expected_rec$template[1, ] + expect_warning(rec <- epi_recipe(y ~ x, tib), regexp = "epi_recipe has been called with a non-epi_df object") + expect_identical(expected_rec, rec) expect_equal(nrow(rec$template), 1L) m <- as.matrix(tib) - rec <- recipes::recipe(m) - rec$template <- rec$template[1, ] - expect_identical(rec, epi_recipe(m)) + expected_rec <- recipes::recipe(m) + expected_rec$template <- expected_rec$template[1, ] + expect_warning(rec <- epi_recipe(m), regexp = "epi_recipe has been called with a non-epi_df object") + expect_identical(expected_rec, rec) expect_equal(nrow(rec$template), 1L) })