diff --git a/DESCRIPTION b/DESCRIPTION index f70ebde14..77118287f 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,6 +1,6 @@ Package: epipredict Title: Basic epidemiology forecasting methods -Version: 0.0.22 +Version: 0.0.23 Authors@R: c( person("Daniel", "McDonald", , "daniel@stat.ubc.ca", role = c("aut", "cre")), person("Ryan", "Tibshirani", , "ryantibs@cmu.edu", role = "aut"), diff --git a/NAMESPACE b/NAMESPACE index 23c5adeaf..5dea128ac 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -272,6 +272,7 @@ importFrom(rlang,":=") importFrom(rlang,abort) importFrom(rlang,arg_match) importFrom(rlang,as_function) +importFrom(rlang,caller_arg) importFrom(rlang,caller_env) importFrom(rlang,enquo) importFrom(rlang,enquos) diff --git a/NEWS.md b/NEWS.md index 5d1082c2b..7b28de6e5 100644 --- a/NEWS.md +++ b/NEWS.md @@ -57,4 +57,5 @@ Pre-1.0.0 numbering scheme: 0.x will indicate releases, while 0.0.x will indicat - Add `step_epi_slide` to produce generic sliding computations over an `epi_df` - Add quantile random forests (via `{grf}`) as a parsnip engine - Replace `epi_keys()` with `epiprocess::key_colnames()`, #352 +- Fix bug where `fit()` drops the `epi_workflow` class, #363 - Try to retain the `epi_df` class during baking to the extent possible, #376 diff --git a/R/epi_recipe.R b/R/epi_recipe.R index 2c1ffffa1..f8216c2af 100644 --- a/R/epi_recipe.R +++ b/R/epi_recipe.R @@ -16,15 +16,10 @@ epi_recipe <- function(x, ...) { #' @rdname epi_recipe #' @export epi_recipe.default <- function(x, ...) { - ## if not a formula or an epi_df, we just pass to recipes::recipe - if (is.matrix(x) || is.data.frame(x) || tibble::is_tibble(x)) { - x <- x[1, , drop = FALSE] - } - cli_warn( - "epi_recipe has been called with a non-epi_df object, returning a regular recipe. Various - step_epi_* functions will not work." - ) - recipes::recipe(x, ...) + cli_abort(paste( + "`x` must be an {.cls epi_df} or a {.cls formula},", + "not a {.cls {class(x)[[1]]}}." + )) } #' @rdname epi_recipe @@ -154,17 +149,16 @@ epi_recipe.formula <- function(formula, data, ...) { data <- data[1, ] # check for minus: if (!epiprocess::is_epi_df(data)) { - cli_warn( - "epi_recipe has been called with a non-epi_df object, returning a regular recipe. Various - step_epi_* functions will not work." - ) - return(recipes::recipe(formula, data, ...)) + cli_abort(paste( + "`epi_recipe()` has been called with a non-{.cls epi_df} object.", + "Use `recipe()` instead." + )) } attr(data, "decay_to_tibble") <- FALSE f_funcs <- recipes:::fun_calls(formula, data) if (any(f_funcs == "-")) { - abort("`-` is not allowed in a recipe formula. Use `step_rm()` instead.") + cli_abort("`-` is not allowed in a recipe formula. Use `step_rm()` instead.") } # Check for other in-line functions diff --git a/R/epi_workflow.R b/R/epi_workflow.R index f448f4aff..af4555303 100644 --- a/R/epi_workflow.R +++ b/R/epi_workflow.R @@ -103,7 +103,9 @@ fit.epi_workflow <- function(object, data, ..., control = workflows::control_wor ) object$original_data <- data - NextMethod() + res <- NextMethod() + class(res) <- c("epi_workflow", class(res)) + res } #' Predict from an epi_workflow diff --git a/R/epipredict-package.R b/R/epipredict-package.R index 6460b65e4..733ab9755 100644 --- a/R/epipredict-package.R +++ b/R/epipredict-package.R @@ -1,15 +1,16 @@ ## usethis namespace: start #' @importFrom tibble as_tibble -#' @importFrom rlang := !! %||% as_function global_env set_names !!! -#' is_logical is_true inject enquo enquos expr sym arg_match +#' @importFrom rlang := !! %||% as_function global_env set_names !!! caller_arg +#' @importFrom rlang is_logical is_true inject enquo enquos expr sym arg_match #' @importFrom stats poly predict lm residuals quantile #' @importFrom dplyr arrange across all_of any_of bind_cols bind_rows group_by -#' summarize filter mutate select left_join rename ungroup full_join -#' relocate summarise everything +#' @importFrom dplyr summarize filter mutate select left_join rename ungroup +#' @importFrom dplyr full_join relocate summarise everything #' @importFrom cli cli_abort cli_warn #' @importFrom checkmate assert assert_character assert_int assert_scalar -#' assert_logical assert_numeric assert_number assert_integer -#' assert_integerish assert_date assert_function assert_class +#' @importFrom checkmate assert_logical assert_numeric assert_number +#' @importFrom checkmate assert_integer assert_integerish +#' @importFrom checkmate assert_date assert_function assert_class #' @import epiprocess parsnip ## usethis namespace: end NULL diff --git a/tests/testthat/_snaps/epi_recipe.md b/tests/testthat/_snaps/epi_recipe.md new file mode 100644 index 000000000..3d797461d --- /dev/null +++ b/tests/testthat/_snaps/epi_recipe.md @@ -0,0 +1,24 @@ +# epi_recipe produces error if not an epi_df + + Code + epi_recipe(tib) + Condition + Error in `epi_recipe()`: + ! `x` must be an or a , not a . + +--- + + Code + epi_recipe(y ~ x, tib) + Condition + Error in `epi_recipe()`: + ! `epi_recipe()` has been called with a non- object. Use `recipe()` instead. + +--- + + Code + epi_recipe(m) + Condition + Error in `epi_recipe()`: + ! `x` must be an or a , not a . + diff --git a/tests/testthat/_snaps/epi_workflow.md b/tests/testthat/_snaps/epi_workflow.md new file mode 100644 index 000000000..d46dad6c1 --- /dev/null +++ b/tests/testthat/_snaps/epi_workflow.md @@ -0,0 +1,16 @@ +# fit method does not silently drop the class + + Code + epi_recipe(y ~ x, data = tbl) + Condition + Error in `epi_recipe()`: + ! `epi_recipe()` has been called with a non- object. Use `recipe()` instead. + +--- + + Code + ewf_erec_edf %>% fit(tbl) + Condition + Error in `if (new_meta != old_meta) ...`: + ! argument is of length zero + diff --git a/tests/testthat/test-epi_recipe.R b/tests/testthat/test-epi_recipe.R index a4cbb00b4..f8933b018 100644 --- a/tests/testthat/test-epi_recipe.R +++ b/tests/testthat/test-epi_recipe.R @@ -1,27 +1,12 @@ -test_that("epi_recipe produces default recipe", { - # these all call recipes::recipe(), but the template will always have 1 row +test_that("epi_recipe produces error if not an epi_df", { tib <- tibble( x = 1:5, y = 1:5, time_value = seq(as.Date("2020-01-01"), by = 1, length.out = 5) ) - expected_rec <- recipes::recipe(tib) - expected_rec$template <- expected_rec$template[1, ] - expect_warning(rec <- epi_recipe(tib), regexp = "epi_recipe has been called with a non-epi_df object") - expect_identical(expected_rec, rec) - expect_equal(nrow(rec$template), 1L) - - expected_rec <- recipes::recipe(y ~ x, tib) - expected_rec$template <- expected_rec$template[1, ] - expect_warning(rec <- epi_recipe(y ~ x, tib), regexp = "epi_recipe has been called with a non-epi_df object") - expect_identical(expected_rec, rec) - expect_equal(nrow(rec$template), 1L) - + expect_snapshot(error = TRUE, epi_recipe(tib)) + expect_snapshot(error = TRUE, epi_recipe(y ~ x, tib)) m <- as.matrix(tib) - expected_rec <- recipes::recipe(m) - expected_rec$template <- expected_rec$template[1, ] - expect_warning(rec <- epi_recipe(m), regexp = "epi_recipe has been called with a non-epi_df object") - expect_identical(expected_rec, rec) - expect_equal(nrow(rec$template), 1L) + expect_snapshot(error = TRUE, epi_recipe(m)) }) test_that("epi_recipe formula works", { diff --git a/tests/testthat/test-epi_workflow.R b/tests/testthat/test-epi_workflow.R index 09dd6fe82..01eff4209 100644 --- a/tests/testthat/test-epi_workflow.R +++ b/tests/testthat/test-epi_workflow.R @@ -105,3 +105,40 @@ test_that("forecast method errors when workflow not fit", { expect_error(forecast(wf)) }) + +test_that("fit method does not silently drop the class", { + # This is issue #363 + + library(recipes) + tbl <- tibble::tibble( + geo_value = 1, + time_value = 1:100, + x = 1:100, + y = x + rnorm(100L) + ) + edf <- as_epi_df(tbl) + + rec_tbl <- recipe(y ~ x, data = tbl) + rec_edf <- recipe(y ~ x, data = edf) + expect_snapshot(error = TRUE, epi_recipe(y ~ x, data = tbl)) + erec_edf <- epi_recipe(y ~ x, data = edf) + + ewf_rec_tbl <- epi_workflow(rec_tbl, linear_reg()) + ewf_rec_edf <- epi_workflow(rec_edf, linear_reg()) + ewf_erec_edf <- epi_workflow(erec_edf, linear_reg()) + + # above are all epi_workflows: + + expect_s3_class(ewf_rec_tbl, "epi_workflow") + expect_s3_class(ewf_rec_edf, "epi_workflow") + expect_s3_class(ewf_erec_edf, "epi_workflow") + + # but fitting drops the class or generates errors in many cases: + + expect_s3_class(ewf_rec_tbl %>% fit(tbl), "epi_workflow") + expect_s3_class(ewf_rec_tbl %>% fit(edf), "epi_workflow") + expect_s3_class(ewf_rec_edf %>% fit(tbl), "epi_workflow") + expect_s3_class(ewf_rec_edf %>% fit(edf), "epi_workflow") + expect_snapshot(ewf_erec_edf %>% fit(tbl), error = TRUE) + expect_s3_class(ewf_erec_edf %>% fit(edf), "epi_workflow") +})