Skip to content

Commit

Permalink
Sometimes allow passing type, opts, ... via predict.epi_workflow()
Browse files Browse the repository at this point in the history
  • Loading branch information
brookslogan committed Jul 18, 2024
1 parent 58a3674 commit 1c9b308
Show file tree
Hide file tree
Showing 19 changed files with 127 additions and 26 deletions.
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,7 @@ import(parsnip)
import(recipes)
importFrom(checkmate,assert)
importFrom(checkmate,assert_character)
importFrom(checkmate,assert_class)
importFrom(checkmate,assert_date)
importFrom(checkmate,assert_function)
importFrom(checkmate,assert_int)
Expand Down
4 changes: 2 additions & 2 deletions R/epi_workflow.R
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ fit.epi_workflow <- function(object, data, ..., control = workflows::control_wor
#'
#' preds <- predict(wf, latest)
#' preds
predict.epi_workflow <- function(object, new_data, ...) {
predict.epi_workflow <- function(object, new_data, type = NULL, opts = list(), ...) {
if (!workflows::is_trained_workflow(object)) {
cli::cli_abort(c(
"Can't predict on an untrained epi_workflow.",
Expand All @@ -168,7 +168,7 @@ predict.epi_workflow <- function(object, new_data, ...) {
components$forged,
components$mold, new_data
)
components <- apply_frosting(object, components, new_data, ...)
components <- apply_frosting(object, components, new_data, type = type, opts = opts, ...)
components$predictions
}

Expand Down
2 changes: 1 addition & 1 deletion R/epipredict-package.R
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
#' @importFrom cli cli_abort
#' @importFrom checkmate assert assert_character assert_int assert_scalar
#' assert_logical assert_numeric assert_number assert_integer
#' assert_integerish assert_date assert_function
#' assert_integerish assert_date assert_function assert_class
#' @import epiprocess parsnip
## usethis namespace: end
NULL
22 changes: 20 additions & 2 deletions R/frosting.R
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,7 @@ apply_frosting.default <- function(workflow, components, ...) {
#' @importFrom rlang abort
#' @export
apply_frosting.epi_workflow <-
function(workflow, components, new_data, ...) {
function(workflow, components, new_data, type = NULL, opts = list(), ...) {
the_fit <- workflows::extract_fit_parsnip(workflow)

if (!has_postprocessor(workflow)) {
Expand Down Expand Up @@ -397,10 +397,28 @@ apply_frosting.epi_workflow <-
layers
)
}
if (length(layers) > 1L &&
(!is.null(type) || !identical(opts, list()) || rlang::dots_n(...) > 0L)) {
cli_abort("
Passing `type`, `opts`, or `...` into `predict.epi_workflow()` is not
supported if you have frosting layers other than `layer_predict`. Please
provide these arguments earlier (i.e. while constructing the frosting
object) by passing them into an explicit call to `layer_predict(), and
adjust the remaining layers to account for resulting differences in
output format from these settings.
", class = "epipredict__apply_frosting__predict_settings_with_unsupported_layers")
}

for (l in seq_along(layers)) {
la <- layers[[l]]
components <- slather(la, components, workflow, new_data)
if (inherits(la, "layer_predict")) {
components <- slather(la, components, workflow, new_data, type = type, opts = opts, ...)
} else {
# The check above should ensure we have default `type` and `opts` and
# empty `...`; don't forward these default `type` and `opts`, to avoid
# upsetting some slather method validation.
components <- slather(la, components, workflow, new_data)
}
}

return(components)
Expand Down
1 change: 1 addition & 0 deletions R/layer_add_forecast_date.R
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ layer_add_forecast_date_new <- function(forecast_date, id) {

#' @export
slather.layer_add_forecast_date <- function(object, components, workflow, new_data, ...) {
rlang::check_dots_empty()
if (is.null(object$forecast_date)) {
max_time_value <- as.Date(max(
workflows::extract_preprocessor(workflow)$max_time_value,
Expand Down
1 change: 1 addition & 0 deletions R/layer_naomit.R
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ layer_naomit_new <- function(terms, id) {

#' @export
slather.layer_naomit <- function(object, components, workflow, new_data, ...) {
rlang::check_dots_empty()
exprs <- rlang::expr(c(!!!object$terms))
pos <- tidyselect::eval_select(exprs, components$predictions)
col_names <- names(pos)
Expand Down
4 changes: 2 additions & 2 deletions R/layer_point_from_distn.R
Original file line number Diff line number Diff line change
Expand Up @@ -76,16 +76,16 @@ layer_point_from_distn_new <- function(type, name, id) {
#' @export
slather.layer_point_from_distn <-
function(object, components, workflow, new_data, ...) {
rlang::check_dots_empty()
dstn <- components$predictions$.pred
if (!inherits(dstn, "distribution")) {
rlang::warn(
c("`layer_point_from_distn` requires distributional predictions.",
i = "These are of class {class(dstn)}. Ignoring this layer."
)
)
)
return(components)
}
rlang::check_dots_empty()

dstn <- match.fun(object$type)(dstn)
if (is.null(object$name)) {
Expand Down
2 changes: 1 addition & 1 deletion R/layer_population_scaling.R
Original file line number Diff line number Diff line change
Expand Up @@ -128,11 +128,11 @@ layer_population_scaling_new <-
#' @export
slather.layer_population_scaling <-
function(object, components, workflow, new_data, ...) {
rlang::check_dots_empty()
stopifnot(
"Only one population column allowed for scaling" =
length(object$df_pop_col) == 1
)
rlang::check_dots_empty()

if (is.null(object$by)) {
object$by <- intersect(
Expand Down
15 changes: 9 additions & 6 deletions R/layer_predict.R
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,12 @@ layer_predict <-
id = rand_id("predict_default")) {
arg_is_chr_scalar(id)
arg_is_chr_scalar(type, allow_null = TRUE)
assert_class(opts, "list")
dots_list <- rlang::dots_list(..., .homonyms = "error", .check_assign = TRUE)
if (any(rlang::names2(dots_list) == "")) {
cli_abort("All `...` arguments must be named.",
class = "epipredict__layer_predict__unnamed_dot"
)
class = "epipredict__layer_predict__unnamed_dot"
)
}
add_layer(
frosting,
Expand All @@ -68,16 +69,18 @@ layer_predict_new <- function(type, opts, dots_list, id) {
}

#' @export
slather.layer_predict <- function(object, components, workflow, new_data, ...) {
rlang::check_dots_empty()
slather.layer_predict <- function(object, components, workflow, new_data, type = NULL, opts = list(), ...) {
arg_is_chr_scalar(type, allow_null = TRUE)
assert_class(opts, "list")

the_fit <- workflows::extract_fit_parsnip(workflow)

components$predictions <- rlang::inject(predict(
the_fit,
components$forged$predictors,
type = object$type, opts = object$opts,
!!!object$dots_list
type = object$type %||% type,
opts = c(object$opts, opts),
!!!object$dots_list, ...
))
components$predictions <- dplyr::bind_cols(
components$keys, components$predictions
Expand Down
1 change: 1 addition & 0 deletions R/layer_predictive_distn.R
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ layer_predictive_distn_new <- function(dist_type, truncate, name, id) {
slather.layer_predictive_distn <-
function(object, components, workflow, new_data, ...) {
the_fit <- workflows::extract_fit_parsnip(workflow)
rlang::check_dots_empty()

m <- components$predictions$.pred
r <- grab_residuals(the_fit, components)
Expand Down
2 changes: 2 additions & 0 deletions R/layer_quantile_distn.R
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@ slather.layer_quantile_distn <-
"These are of class {.cls {class(dstn)}}."
))
}
rlang::check_dots_empty()

dstn <- dist_quantiles(
quantile(dstn, object$quantile_levels),
object$quantile_levels
Expand Down
2 changes: 2 additions & 0 deletions R/layer_residual_quantiles.R
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ layer_residual_quantiles_new <- function(
#' @export
slather.layer_residual_quantiles <-
function(object, components, workflow, new_data, ...) {
rlang::check_dots_empty()

the_fit <- workflows::extract_fit_parsnip(workflow)

if (is.null(object$quantile_levels)) {
Expand Down
1 change: 1 addition & 0 deletions R/layer_threshold_preds.R
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ snap.dist_quantiles <- function(x, lower, upper, ...) {
#' @export
slather.layer_threshold <-
function(object, components, workflow, new_data, ...) {
rlang::check_dots_empty()
exprs <- rlang::expr(c(!!!object$terms))
pos <- tidyselect::eval_select(exprs, components$predictions)
col_names <- names(pos)
Expand Down
1 change: 1 addition & 0 deletions R/layer_unnest.R
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ layer_unnest_new <- function(terms, id) {
#' @export
slather.layer_unnest <-
function(object, components, workflow, new_data, ...) {
rlang::check_dots_empty()
exprs <- rlang::expr(c(!!!object$terms))
pos <- tidyselect::eval_select(exprs, components$predictions)
col_names <- names(pos)
Expand Down
1 change: 1 addition & 0 deletions inst/templates/layer.R
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ layer_{{{ name }}}_new <- function(terms, args, more_args, id) {
#' @export
slather.layer_{{{ name }}} <-
function(object, components, workflow, new_data, ...) {
rlang::check_dots_empty()

# if layer_ used ... in tidyselect, we need to evaluate it now
exprs <- rlang::expr(c(!!!object$terms))
Expand Down
2 changes: 1 addition & 1 deletion man/apply_frosting.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

12 changes: 11 additions & 1 deletion man/predict-epi_workflow.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

43 changes: 43 additions & 0 deletions tests/testthat/test-frosting.R
Original file line number Diff line number Diff line change
Expand Up @@ -86,3 +86,46 @@ test_that("layer_predict is added by default if missing", {

expect_equal(forecast(wf1), forecast(wf2))
})


test_that("parsnip settings can be passed through predict.epi_workflow", {
jhu <- case_death_rate_subset %>%
dplyr::filter(time_value > "2021-11-01", geo_value %in% c("ak", "ca", "ny"))

r <- epi_recipe(jhu) %>%
step_epi_lag(death_rate, lag = c(0, 7, 14)) %>%
step_epi_ahead(death_rate, ahead = 7) %>%
step_epi_naomit()

wf <- epi_workflow(r, parsnip::linear_reg()) %>% fit(jhu)

latest <- get_test_data(r, jhu)

f1 <- frosting() %>% layer_predict()
f2 <- frosting() %>% layer_predict(type = "pred_int")
f3 <- frosting() %>% layer_predict(type = "pred_int", level = 0.6)

pred2 <- wf %>% add_frosting(f2) %>% predict(latest)
pred3 <- wf %>% add_frosting(f3) %>% predict(latest)

pred2_re <- wf %>% add_frosting(f1) %>% predict(latest, type = "pred_int")
pred3_re <- wf %>% add_frosting(f1) %>% predict(latest, type = "pred_int", level = 0.6)

expect_identical(pred2, pred2_re)
expect_identical(pred3, pred3_re)

f4 <- frosting() %>%
layer_predict() %>%
layer_threshold(.pred, lower = 0)

expect_error(wf %>% add_frosting(f4) %>% predict(latest, type = "pred_int"),
class = "epipredict__apply_frosting__predict_settings_with_unsupported_layers")

# We also refuse to continue when just passing the level, which might not be ideal:
f5 <- frosting() %>%
layer_predict(type = "pred_int") %>%
layer_threshold(.pred_lower, .pred_upper, lower = 0)

expect_error(wf %>% add_frosting(f5) %>% predict(latest, level = 0.6),
class = "epipredict__apply_frosting__predict_settings_with_unsupported_layers")
})
36 changes: 26 additions & 10 deletions tests/testthat/test-layer_predict.R
Original file line number Diff line number Diff line change
Expand Up @@ -61,17 +61,33 @@ test_that("layer_predict dots validation", {
})

test_that("layer_predict dots are forwarded", {
f_lm_int_level <- frosting() %>%
f_lm_int_level_95 <- frosting() %>%
layer_predict(type = "pred_int")
f_lm_int_level_80 <- frosting() %>%
layer_predict(type = "pred_int", level = 0.8)
wf_lm_int_level <- wf %>% add_frosting(f_lm_int_level)
wf_lm_int_level_95 <- wf %>% add_frosting(f_lm_int_level_95)
wf_lm_int_level_80 <- wf %>% add_frosting(f_lm_int_level_80)
p <- predict(wf, latest)
p_lm_int_level <- predict(wf_lm_int_level, latest)
expect_contains(names(p_lm_int_level), c(".pred_lower", ".pred_upper"))
expect_equal(nrow(na.omit(p)), nrow(na.omit(p_lm_int_level)))
expect_true(cbind(p, p_lm_int_level[c(".pred_lower", ".pred_upper")]) %>%
na.omit() %>%
mutate(sandwiched = .pred_lower <= .pred & .pred <= .pred_upper) %>%
`[[`("sandwiched") %>%
all())
p_lm_int_level_95 <- predict(wf_lm_int_level_95, latest)
p_lm_int_level_80 <- predict(wf_lm_int_level_80, latest)
expect_contains(names(p_lm_int_level_95), c(".pred_lower", ".pred_upper"))
expect_contains(names(p_lm_int_level_80), c(".pred_lower", ".pred_upper"))
expect_equal(nrow(na.omit(p)), nrow(na.omit(p_lm_int_level_95)))
expect_equal(nrow(na.omit(p)), nrow(na.omit(p_lm_int_level_80)))
expect_true(
cbind(
p,
p_lm_int_level_95 %>% dplyr::select(.pred_lower_95 = .pred_lower, .pred_upper_95 = .pred_upper),
p_lm_int_level_80 %>% dplyr::select(.pred_lower_80 = .pred_lower, .pred_upper_80 = .pred_upper)
) %>%
na.omit() %>%
mutate(sandwiched =
.pred_lower_95 <= .pred_lower_80 &
.pred_lower_80 <= .pred &
.pred <= .pred_upper_80 &
.pred_upper_80 <= .pred_upper_95) %>%
`[[`("sandwiched") %>%
all()
)
# There are many possible other valid configurations that aren't tested here.
})

0 comments on commit 1c9b308

Please sign in to comment.