Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Deprecate epi_recipe() in favour of recipe() #370

Merged
merged 16 commits into from
Sep 20, 2024
4 changes: 2 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: epipredict
Title: Basic epidemiology forecasting methods
Version: 0.0.20
Version: 0.1.0
Authors@R: c(
person("Daniel", "McDonald", , "[email protected]", role = c("aut", "cre")),
person("Ryan", "Tibshirani", , "[email protected]", role = "aut"),
Expand All @@ -23,7 +23,7 @@ URL: https://github.com/cmu-delphi/epipredict/,
https://cmu-delphi.github.io/epipredict
BugReports: https://github.com/cmu-delphi/epipredict/issues/
Depends:
epiprocess (>= 0.7.5),
epiprocess (>= 0.7.12),
parsnip (>= 1.0.0),
R (>= 3.5.0)
Imports:
Expand Down
9 changes: 4 additions & 5 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,6 @@ S3method(bake,step_training_window)
S3method(detect_layer,frosting)
S3method(detect_layer,workflow)
S3method(epi_recipe,default)
S3method(epi_recipe,epi_df)
S3method(epi_recipe,formula)
S3method(extract_argument,epi_workflow)
S3method(extract_argument,frosting)
S3method(extract_argument,layer)
Expand Down Expand Up @@ -96,6 +94,8 @@ S3method(print,step_naomit)
S3method(print,step_population_scaling)
S3method(print,step_training_window)
S3method(quantile,dist_quantiles)
S3method(recipe,epi_df)
S3method(recipes::recipe,formula)
S3method(refresh_blueprint,default_epi_recipe_blueprint)
S3method(residuals,flatline)
S3method(run_mold,default_epi_recipe_blueprint)
Expand Down Expand Up @@ -152,7 +152,6 @@ export(default_epi_recipe_blueprint)
export(detect_layer)
export(dist_quantiles)
export(epi_recipe)
export(epi_recipe_blueprint)
export(epi_workflow)
export(extract_argument)
export(extract_frosting)
Expand Down Expand Up @@ -183,13 +182,12 @@ export(layer_residual_quantiles)
export(layer_threshold)
export(layer_unnest)
export(nested_quantiles)
export(new_default_epi_recipe_blueprint)
export(new_epi_recipe_blueprint)
export(pivot_quantiles_longer)
export(pivot_quantiles_wider)
export(prep)
export(quantile_reg)
export(rand_id)
export(recipe)
export(remove_epi_recipe)
export(remove_frosting)
export(remove_model)
Expand Down Expand Up @@ -264,6 +262,7 @@ importFrom(magrittr,"%>%")
importFrom(recipes,bake)
importFrom(recipes,prep)
importFrom(recipes,rand_id)
importFrom(recipes,recipe)
importFrom(rlang,"!!!")
importFrom(rlang,"!!")
importFrom(rlang,"%@%")
Expand Down
4 changes: 2 additions & 2 deletions R/arx_classifier.R
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
#' be real-valued. Conversion of this data to unordered classes is handled
#' internally based on the `breaks` argument to [arx_class_args_list()].
#' If discrete classes are already in the `epi_df`, it is recommended to
#' code up a classifier from scratch using [epi_recipe()].
#' code up a classifier from scratch using [recipe()].
#' @param trainer A `{parsnip}` model describing the type of estimation.
#' For now, we enforce `mode = "classification"`. Typical values are
#' [parsnip::logistic_reg()] or [parsnip::multinom_reg()]. More complicated
Expand Down Expand Up @@ -129,7 +129,7 @@ arx_class_epi_workflow <- function(

# --- preprocessor
# ------- predictors
r <- epi_recipe(epi_data) %>%
r <- recipe(epi_data) %>%
step_growth_rate(
dplyr::all_of(predictors),
role = "grp",
Expand Down
2 changes: 1 addition & 1 deletion R/arx_forecaster.R
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ arx_fcast_epi_workflow <- function(
lags <- arx_lags_validator(predictors, args_list$lags)

# --- preprocessor
r <- epi_recipe(epi_data)
r <- recipe(epi_data)
for (l in seq_along(lags)) {
p <- predictors[l]
r <- step_epi_lag(r, !!p, lag = lags[[l]])
Expand Down
9 changes: 6 additions & 3 deletions R/autoplot.R
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ ggplot2::autoplot
#' jhu <- case_death_rate_subset %>%
#' filter(time_value >= as.Date("2021-11-01"))
#'
#' r <- epi_recipe(jhu) %>%
#' r <- recipe(jhu) %>%
#' step_epi_lag(death_rate, lag = c(0, 7, 14)) %>%
#' step_epi_ahead(death_rate, ahead = 7) %>%
#' step_epi_lag(case_rate, lag = c(0, 7, 14)) %>%
Expand All @@ -56,7 +56,7 @@ ggplot2::autoplot
#' # ------- Show multiple horizons
#'
#' p <- lapply(c(7, 14, 21, 28), function(h) {
#' r <- epi_recipe(jhu) %>%
#' r <- recipe(jhu) %>%
#' step_epi_lag(death_rate, lag = c(0, 7, 14)) %>%
#' step_epi_ahead(death_rate, ahead = h) %>%
#' step_epi_lag(case_rate, lag = c(0, 7, 14)) %>%
Expand Down Expand Up @@ -184,7 +184,10 @@ autoplot.epi_workflow <- function(
}

if (".pred" %in% names(predictions)) {
ntarget_dates <- n_distinct(predictions$time_value)
ntarget_dates <- dplyr::n_distinct(predictions$time_value)
if (distributional::is_distribution(predictions$.pred)) {
predictions <- dplyr::mutate(predictions, .pred = median(.pred))
}
if (ntarget_dates > 1L) {
bp <- bp +
geom_line(
Expand Down
144 changes: 51 additions & 93 deletions R/blueprint-epi_recipe-default.R
Original file line number Diff line number Diff line change
@@ -1,111 +1,69 @@
#' Recipe blueprint that accounts for `epi_df` panel data
#'
#' Used for simplicity. See [hardhat::new_recipe_blueprint()] or
#' [hardhat::default_recipe_blueprint()] for more details.
#'
#' @inheritParams hardhat::new_recipe_blueprint
#' Default epi_recipe blueprint
#'
#' @details The `bake_dependent_roles` are automatically set to `epi_df` defaults.
#' @return A recipe blueprint.
#' Recipe blueprint that accounts for `epi_df` panel data
#' Used for simplicity. See [hardhat::default_recipe_blueprint()] for more
#' details. This subclass is nearly the same, except it ensures that
#' downstream processing doesn't drop the epi_df class from the data.
#'
#' @keywords internal
#' @inheritParams hardhat::default_recipe_blueprint
#' @return A `epi_recipe` blueprint.
#' @export
new_epi_recipe_blueprint <-
function(intercept = FALSE, allow_novel_levels = FALSE, fresh = TRUE,
composition = "tibble",
ptypes = NULL, recipe = NULL, ..., subclass = character()) {
hardhat::new_recipe_blueprint(
intercept = intercept,
allow_novel_levels = allow_novel_levels,
fresh = fresh,
composition = composition,
ptypes = ptypes,
recipe = recipe,
...,
subclass = c(subclass, "epi_recipe_blueprint")
)
}


#' @rdname new_epi_recipe_blueprint
#' @export
epi_recipe_blueprint <-
function(intercept = FALSE, allow_novel_levels = FALSE,
fresh = TRUE,
composition = "tibble") {
new_epi_recipe_blueprint(
intercept = intercept,
allow_novel_levels = allow_novel_levels,
fresh = fresh,
composition = composition
)
}
#' @keywords internal
default_epi_recipe_blueprint <- function(intercept = FALSE,
allow_novel_levels = FALSE,
fresh = TRUE,
strings_as_factors = FALSE,
composition = "tibble") {
new_default_epi_recipe_blueprint(
intercept = intercept,
allow_novel_levels = allow_novel_levels,
fresh = fresh,
strings_as_factors = strings_as_factors,
composition = composition
)
}

#' @rdname new_epi_recipe_blueprint
#' @export
default_epi_recipe_blueprint <-
function(intercept = FALSE, allow_novel_levels = FALSE, fresh = TRUE,
composition = "tibble") {
new_default_epi_recipe_blueprint(
intercept = intercept,
allow_novel_levels = allow_novel_levels,
fresh = fresh,
composition = composition
)
}
new_default_epi_recipe_blueprint <- function(intercept = FALSE,
allow_novel_levels = TRUE,
fresh = TRUE,
strings_as_factors = FALSE,
composition = "tibble",
ptypes = NULL,
recipe = NULL,
extra_role_ptypes = NULL,
...,
subclass = character()) {
hardhat::new_recipe_blueprint(
intercept = intercept,
allow_novel_levels = allow_novel_levels,
fresh = fresh,
strings_as_factors = strings_as_factors,
composition = composition,
ptypes = ptypes,
recipe = recipe,
extra_role_ptypes = extra_role_ptypes,
...,
subclass = c(subclass, "default_epi_recipe_blueprint", "default_recipe_blueprint")
)
}

#' @rdname new_epi_recipe_blueprint
#' @inheritParams hardhat::new_default_recipe_blueprint
#' @export
new_default_epi_recipe_blueprint <-
function(intercept = FALSE, allow_novel_levels = FALSE,
fresh = TRUE,
composition = "tibble", ptypes = NULL, recipe = NULL,
extra_role_ptypes = NULL, ..., subclass = character()) {
new_epi_recipe_blueprint(
intercept = intercept,
allow_novel_levels = allow_novel_levels,
fresh = fresh,
composition = composition,
ptypes = ptypes,
recipe = recipe,
extra_role_ptypes = extra_role_ptypes,
...,
subclass = c(subclass, "default_epi_recipe_blueprint", "default_recipe_blueprint")
)
}

#' @importFrom hardhat run_mold
#' @export
run_mold.default_epi_recipe_blueprint <- function(blueprint, ..., data) {
rlang::check_dots_empty0(...)
# blueprint <- hardhat:::patch_recipe_default_blueprint(blueprint)
cleaned <- mold_epi_recipe_default_clean(blueprint = blueprint, data = data)
blueprint <- cleaned$blueprint
data <- cleaned$data
# we don't do the "cleaning" in `hardhat:::run_mold.default_recipe_blueprint`
# That function drops the epi_df class without any recourse.
# The only way we should be here at all is if `data` is an epi_df, but just
# in case...
if (!is_epi_df(data)) {
cli_warn("`data` is not an {.cls epi_df}. It has class {.cls {class(data)}}.")
}
hardhat:::mold_recipe_default_process(blueprint = blueprint, data = data)
}

mold_epi_recipe_default_clean <- function(blueprint, data) {
hardhat:::check_data_frame_or_matrix(data)
if (!is_epi_df(data)) data <- hardhat:::coerce_to_tibble(data)
hardhat:::new_mold_clean(blueprint, data)
}

#' @importFrom hardhat refresh_blueprint
#' @export
refresh_blueprint.default_epi_recipe_blueprint <- function(blueprint) {
do.call(new_default_epi_recipe_blueprint, as.list(blueprint))
}


## removing this function?
# er_check_is_data_like <- function(.x, .x_nm) {
# if (rlang::is_missing(.x_nm)) {
# .x_nm <- rlang::as_label(rlang::enexpr(.x))
# }
# if (!hardhat:::is_new_data_like(.x)) {
# hardhat:::glubort("`{.x_nm}` must be a data.frame or a matrix, not a {class1(.x)}.")
# }
# .x
# }
42 changes: 21 additions & 21 deletions R/cdc_baseline_forecaster.R
Original file line number Diff line number Diff line change
Expand Up @@ -36,25 +36,25 @@
#' cdc <- cdc_baseline_forecaster(weekly_deaths, "deaths")
#' preds <- pivot_quantiles_wider(cdc$predictions, .pred_distn)
#'
#' if (require(ggplot2)) {
#' forecast_date <- unique(preds$forecast_date)
#' four_states <- c("ca", "pa", "wa", "ny")
#' preds %>%
#' filter(geo_value %in% four_states) %>%
#' ggplot(aes(target_date)) +
#' geom_ribbon(aes(ymin = `0.1`, ymax = `0.9`), fill = blues9[3]) +
#' geom_ribbon(aes(ymin = `0.25`, ymax = `0.75`), fill = blues9[6]) +
#' geom_line(aes(y = .pred), color = "orange") +
#' geom_line(
#' data = weekly_deaths %>% filter(geo_value %in% four_states),
#' aes(x = time_value, y = deaths)
#' ) +
#' scale_x_date(limits = c(forecast_date - 90, forecast_date + 30)) +
#' labs(x = "Date", y = "Weekly deaths") +
#' facet_wrap(~geo_value, scales = "free_y") +
#' theme_bw() +
#' geom_vline(xintercept = forecast_date)
#' }
#' library(ggplot2)
#' forecast_date <- unique(preds$forecast_date)
#' four_states <- c("ca", "pa", "wa", "ny")
#' preds %>%
#' filter(geo_value %in% four_states) %>%
#' ggplot(aes(target_date)) +
#' geom_ribbon(aes(ymin = `0.1`, ymax = `0.9`), fill = blues9[3]) +
#' geom_ribbon(aes(ymin = `0.25`, ymax = `0.75`), fill = blues9[6]) +
#' geom_line(aes(y = .pred), color = "orange") +
#' geom_line(
#' data = weekly_deaths %>% filter(geo_value %in% four_states),
#' aes(x = time_value, y = deaths)
#' ) +
#' scale_x_date(limits = c(forecast_date - 90, forecast_date + 30)) +
#' labs(x = "Date", y = "Weekly deaths") +
#' facet_wrap(~geo_value, scales = "free_y") +
#' theme_bw() +
#' geom_vline(xintercept = forecast_date)
#'
cdc_baseline_forecaster <- function(
epi_data,
outcome,
Expand All @@ -68,7 +68,7 @@ cdc_baseline_forecaster <- function(
outcome <- rlang::sym(outcome)


r <- epi_recipe(epi_data) %>%
r <- recipe(epi_data) %>%
step_epi_ahead(!!outcome, ahead = args_list$data_frequency, skip = TRUE) %>%
recipes::update_role(!!outcome, new_role = "predictor") %>%
recipes::add_role(tidyselect::all_of(keys), new_role = "predictor") %>%
Expand All @@ -79,7 +79,7 @@ cdc_baseline_forecaster <- function(


latest <- get_test_data(
epi_recipe(epi_data), epi_data, TRUE, args_list$nafill_buffer,
recipe(epi_data), epi_data, TRUE, args_list$nafill_buffer,
forecast_date
)

Expand Down
Loading
Loading