Skip to content

Commit

Permalink
refactor: epi_recipe only accepts epi_df
Browse files Browse the repository at this point in the history
  • Loading branch information
dshemetov committed Jun 28, 2024
1 parent 14e3708 commit a5bf7b6
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 30 deletions.
20 changes: 5 additions & 15 deletions R/epi_recipe.R
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,7 @@ epi_recipe <- function(x, ...) {
#' @rdname epi_recipe
#' @export
epi_recipe.default <- function(x, ...) {
## if not a formula or an epi_df, we just pass to recipes::recipe
if (is.matrix(x) || is.data.frame(x) || tibble::is_tibble(x)) {
x <- x[1, , drop = FALSE]
}
recipes::recipe(x, ...)
cli_abort("epi_recipe requires an epi_df")
}

#' @rdname epi_recipe
Expand Down Expand Up @@ -147,7 +143,7 @@ epi_recipe.formula <- function(formula, data, ...) {
data <- data[1, ]
# check for minus:
if (!epiprocess::is_epi_df(data)) {
return(recipes::recipe(formula, data, ...))
cli_abort("epi_recipe requires an epi_df")
}

f_funcs <- recipes:::fun_calls(formula)
Expand Down Expand Up @@ -333,15 +329,11 @@ update_epi_recipe <- function(x, recipe, ..., blueprint = default_epi_recipe_blu
#' illustrations of the different types of updates.
#'
#' @param x A `epi_workflow` or `epi_recipe` object
#'
#' @param which_step the number or name of the step to adjust
#'
#' @param ... Used to input a parameter adjustment
#'
#' @param blueprint A hardhat blueprint used for fine tuning the preprocessing.
#'
#' @return
#' `x`, updated with the adjustment to the specified `epi_recipe` step.
#' @return `x`, updated with the adjustment to the specified `epi_recipe` step.
#'
#' @export
#' @examples
Expand Down Expand Up @@ -383,17 +375,15 @@ adjust_epi_recipe <- function(x, which_step, ..., blueprint = default_epi_recipe

#' @rdname adjust_epi_recipe
#' @export
adjust_epi_recipe.epi_workflow <- function(
x, which_step, ..., blueprint = default_epi_recipe_blueprint()) {
adjust_epi_recipe.epi_workflow <- function(x, which_step, ..., blueprint = default_epi_recipe_blueprint()) {
recipe <- adjust_epi_recipe(workflows::extract_preprocessor(x), which_step, ...)

update_epi_recipe(x, recipe, blueprint = blueprint)
}

#' @rdname adjust_epi_recipe
#' @export
adjust_epi_recipe.epi_recipe <- function(
x, which_step, ..., blueprint = default_epi_recipe_blueprint()) {
adjust_epi_recipe.epi_recipe <- function(x, which_step, ..., blueprint = default_epi_recipe_blueprint()) {
if (!(is.numeric(which_step) || is.character(which_step))) {
cli::cli_abort(
c("`which_step` must be a number or a character.",
Expand Down
19 changes: 4 additions & 15 deletions tests/testthat/test-epi_recipe.R
Original file line number Diff line number Diff line change
@@ -1,24 +1,13 @@
test_that("epi_recipe produces default recipe", {
test_that("epi_recipe errors when given non-epi_df", {
# these all call recipes::recipe(), but the template will always have 1 row
tib <- tibble(
x = 1:5, y = 1:5,
time_value = seq(as.Date("2020-01-01"), by = 1, length.out = 5)
)
rec <- recipes::recipe(tib)
rec$template <- rec$template[1, ]
expect_identical(rec, epi_recipe(tib))
expect_equal(nrow(rec$template), 1L)

rec <- recipes::recipe(y ~ x, tib)
rec$template <- rec$template[1, ]
expect_identical(rec, epi_recipe(y ~ x, tib))
expect_equal(nrow(rec$template), 1L)

expect_error(epi_recipe(tib), regexp = "epi_recipe requires an epi_df")
expect_error(epi_recipe(y ~ x, tib), regexp = "epi_recipe requires an epi_df")
m <- as.matrix(tib)
rec <- recipes::recipe(m)
rec$template <- rec$template[1, ]
expect_identical(rec, epi_recipe(m))
expect_equal(nrow(rec$template), 1L)
expect_error(epi_recipe(m), regexp = "epi_recipe requires an epi_df")
})

test_that("epi_recipe formula works", {
Expand Down

0 comments on commit a5bf7b6

Please sign in to comment.