Skip to content

Commit

Permalink
Merge pull request #377 from cmu-delphi/363-no-drop-epiwf-class
Browse files Browse the repository at this point in the history
fit() no longer drops the epi_workflow class
  • Loading branch information
dshemetov authored Sep 30, 2024
2 parents db2cfee + de0add1 commit 1c01028
Show file tree
Hide file tree
Showing 10 changed files with 103 additions and 42 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: epipredict
Title: Basic epidemiology forecasting methods
Version: 0.0.22
Version: 0.0.23
Authors@R: c(
person("Daniel", "McDonald", , "[email protected]", role = c("aut", "cre")),
person("Ryan", "Tibshirani", , "[email protected]", role = "aut"),
Expand Down
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,7 @@ importFrom(rlang,":=")
importFrom(rlang,abort)
importFrom(rlang,arg_match)
importFrom(rlang,as_function)
importFrom(rlang,caller_arg)
importFrom(rlang,caller_env)
importFrom(rlang,enquo)
importFrom(rlang,enquos)
Expand Down
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -57,4 +57,5 @@ Pre-1.0.0 numbering scheme: 0.x will indicate releases, while 0.0.x will indicat
- Add `step_epi_slide` to produce generic sliding computations over an `epi_df`
- Add quantile random forests (via `{grf}`) as a parsnip engine
- Replace `epi_keys()` with `epiprocess::key_colnames()`, #352
- Fix bug where `fit()` drops the `epi_workflow` class, #363
- Try to retain the `epi_df` class during baking to the extent possible, #376
24 changes: 9 additions & 15 deletions R/epi_recipe.R
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,10 @@ epi_recipe <- function(x, ...) {
#' @rdname epi_recipe
#' @export
epi_recipe.default <- function(x, ...) {
## if not a formula or an epi_df, we just pass to recipes::recipe
if (is.matrix(x) || is.data.frame(x) || tibble::is_tibble(x)) {
x <- x[1, , drop = FALSE]
}
cli_warn(
"epi_recipe has been called with a non-epi_df object, returning a regular recipe. Various
step_epi_* functions will not work."
)
recipes::recipe(x, ...)
cli_abort(paste(
"`x` must be an {.cls epi_df} or a {.cls formula},",
"not a {.cls {class(x)[[1]]}}."
))
}

#' @rdname epi_recipe
Expand Down Expand Up @@ -154,17 +149,16 @@ epi_recipe.formula <- function(formula, data, ...) {
data <- data[1, ]
# check for minus:
if (!epiprocess::is_epi_df(data)) {
cli_warn(
"epi_recipe has been called with a non-epi_df object, returning a regular recipe. Various
step_epi_* functions will not work."
)
return(recipes::recipe(formula, data, ...))
cli_abort(paste(
"`epi_recipe()` has been called with a non-{.cls epi_df} object.",
"Use `recipe()` instead."
))
}

attr(data, "decay_to_tibble") <- FALSE
f_funcs <- recipes:::fun_calls(formula, data)
if (any(f_funcs == "-")) {
abort("`-` is not allowed in a recipe formula. Use `step_rm()` instead.")
cli_abort("`-` is not allowed in a recipe formula. Use `step_rm()` instead.")
}

# Check for other in-line functions
Expand Down
4 changes: 3 additions & 1 deletion R/epi_workflow.R
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,9 @@ fit.epi_workflow <- function(object, data, ..., control = workflows::control_wor
)
object$original_data <- data

NextMethod()
res <- NextMethod()
class(res) <- c("epi_workflow", class(res))
res
}

#' Predict from an epi_workflow
Expand Down
13 changes: 7 additions & 6 deletions R/epipredict-package.R
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
## usethis namespace: start
#' @importFrom tibble as_tibble
#' @importFrom rlang := !! %||% as_function global_env set_names !!!
#' is_logical is_true inject enquo enquos expr sym arg_match
#' @importFrom rlang := !! %||% as_function global_env set_names !!! caller_arg
#' @importFrom rlang is_logical is_true inject enquo enquos expr sym arg_match
#' @importFrom stats poly predict lm residuals quantile
#' @importFrom dplyr arrange across all_of any_of bind_cols bind_rows group_by
#' summarize filter mutate select left_join rename ungroup full_join
#' relocate summarise everything
#' @importFrom dplyr summarize filter mutate select left_join rename ungroup
#' @importFrom dplyr full_join relocate summarise everything
#' @importFrom cli cli_abort cli_warn
#' @importFrom checkmate assert assert_character assert_int assert_scalar
#' assert_logical assert_numeric assert_number assert_integer
#' assert_integerish assert_date assert_function assert_class
#' @importFrom checkmate assert_logical assert_numeric assert_number
#' @importFrom checkmate assert_integer assert_integerish
#' @importFrom checkmate assert_date assert_function assert_class
#' @import epiprocess parsnip
## usethis namespace: end
NULL
24 changes: 24 additions & 0 deletions tests/testthat/_snaps/epi_recipe.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# epi_recipe produces error if not an epi_df

Code
epi_recipe(tib)
Condition
Error in `epi_recipe()`:
! `x` must be an <epi_df> or a <formula>, not a <tbl_df>.

---

Code
epi_recipe(y ~ x, tib)
Condition
Error in `epi_recipe()`:
! `epi_recipe()` has been called with a non-<epi_df> object. Use `recipe()` instead.

---

Code
epi_recipe(m)
Condition
Error in `epi_recipe()`:
! `x` must be an <epi_df> or a <formula>, not a <matrix>.

16 changes: 16 additions & 0 deletions tests/testthat/_snaps/epi_workflow.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# fit method does not silently drop the class

Code
epi_recipe(y ~ x, data = tbl)
Condition
Error in `epi_recipe()`:
! `epi_recipe()` has been called with a non-<epi_df> object. Use `recipe()` instead.

---

Code
ewf_erec_edf %>% fit(tbl)
Condition
Error in `if (new_meta != old_meta) ...`:
! argument is of length zero

23 changes: 4 additions & 19 deletions tests/testthat/test-epi_recipe.R
Original file line number Diff line number Diff line change
@@ -1,27 +1,12 @@
test_that("epi_recipe produces default recipe", {
# these all call recipes::recipe(), but the template will always have 1 row
test_that("epi_recipe produces error if not an epi_df", {
tib <- tibble(
x = 1:5, y = 1:5,
time_value = seq(as.Date("2020-01-01"), by = 1, length.out = 5)
)
expected_rec <- recipes::recipe(tib)
expected_rec$template <- expected_rec$template[1, ]
expect_warning(rec <- epi_recipe(tib), regexp = "epi_recipe has been called with a non-epi_df object")
expect_identical(expected_rec, rec)
expect_equal(nrow(rec$template), 1L)

expected_rec <- recipes::recipe(y ~ x, tib)
expected_rec$template <- expected_rec$template[1, ]
expect_warning(rec <- epi_recipe(y ~ x, tib), regexp = "epi_recipe has been called with a non-epi_df object")
expect_identical(expected_rec, rec)
expect_equal(nrow(rec$template), 1L)

expect_snapshot(error = TRUE, epi_recipe(tib))
expect_snapshot(error = TRUE, epi_recipe(y ~ x, tib))
m <- as.matrix(tib)
expected_rec <- recipes::recipe(m)
expected_rec$template <- expected_rec$template[1, ]
expect_warning(rec <- epi_recipe(m), regexp = "epi_recipe has been called with a non-epi_df object")
expect_identical(expected_rec, rec)
expect_equal(nrow(rec$template), 1L)
expect_snapshot(error = TRUE, epi_recipe(m))
})

test_that("epi_recipe formula works", {
Expand Down
37 changes: 37 additions & 0 deletions tests/testthat/test-epi_workflow.R
Original file line number Diff line number Diff line change
Expand Up @@ -105,3 +105,40 @@ test_that("forecast method errors when workflow not fit", {

expect_error(forecast(wf))
})

test_that("fit method does not silently drop the class", {
# This is issue #363

library(recipes)
tbl <- tibble::tibble(
geo_value = 1,
time_value = 1:100,
x = 1:100,
y = x + rnorm(100L)
)
edf <- as_epi_df(tbl)

rec_tbl <- recipe(y ~ x, data = tbl)
rec_edf <- recipe(y ~ x, data = edf)
expect_snapshot(error = TRUE, epi_recipe(y ~ x, data = tbl))
erec_edf <- epi_recipe(y ~ x, data = edf)

ewf_rec_tbl <- epi_workflow(rec_tbl, linear_reg())
ewf_rec_edf <- epi_workflow(rec_edf, linear_reg())
ewf_erec_edf <- epi_workflow(erec_edf, linear_reg())

# above are all epi_workflows:

expect_s3_class(ewf_rec_tbl, "epi_workflow")
expect_s3_class(ewf_rec_edf, "epi_workflow")
expect_s3_class(ewf_erec_edf, "epi_workflow")

# but fitting drops the class or generates errors in many cases:

expect_s3_class(ewf_rec_tbl %>% fit(tbl), "epi_workflow")
expect_s3_class(ewf_rec_tbl %>% fit(edf), "epi_workflow")
expect_s3_class(ewf_rec_edf %>% fit(tbl), "epi_workflow")
expect_s3_class(ewf_rec_edf %>% fit(edf), "epi_workflow")
expect_snapshot(ewf_erec_edf %>% fit(tbl), error = TRUE)
expect_s3_class(ewf_erec_edf %>% fit(edf), "epi_workflow")
})

0 comments on commit 1c01028

Please sign in to comment.