Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

352 remove all instances of epi keys #373

Merged
merged 16 commits into from
Aug 29, 2024
Merged
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: epipredict
Title: Basic epidemiology forecasting methods
Version: 0.0.19
Version: 0.0.20
Authors@R: c(
person("Daniel", "McDonald", , "[email protected]", role = c("aut", "cre")),
person("Ryan", "Tibshirani", , "[email protected]", role = "aut"),
Expand Down
37 changes: 28 additions & 9 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,6 @@ S3method(bake,step_population_scaling)
S3method(bake,step_training_window)
S3method(detect_layer,frosting)
S3method(detect_layer,workflow)
S3method(epi_keys,data.frame)
S3method(epi_keys,default)
S3method(epi_keys,epi_df)
S3method(epi_keys,epi_workflow)
S3method(epi_keys,recipe)
S3method(epi_recipe,default)
S3method(epi_recipe,epi_df)
S3method(epi_recipe,formula)
Expand All @@ -55,6 +50,8 @@ S3method(forecast,epi_workflow)
S3method(format,dist_quantiles)
S3method(is.na,dist_quantiles)
S3method(is.na,distribution)
S3method(key_colnames,epi_workflow)
S3method(key_colnames,recipe)
S3method(mean,dist_quantiles)
S3method(median,dist_quantiles)
S3method(predict,epi_workflow)
Expand Down Expand Up @@ -154,7 +151,6 @@ export(clean_f_name)
export(default_epi_recipe_blueprint)
export(detect_layer)
export(dist_quantiles)
export(epi_keys)
export(epi_recipe)
export(epi_recipe_blueprint)
export(epi_workflow)
Expand All @@ -170,7 +166,6 @@ export(flusight_hub_formatter)
export(forecast)
export(frosting)
export(get_test_data)
export(grab_names)
export(is_epi_recipe)
export(is_epi_workflow)
export(is_layer)
Expand All @@ -194,6 +189,7 @@ export(pivot_quantiles_longer)
export(pivot_quantiles_wider)
export(prep)
export(quantile_reg)
export(rand_id)
export(remove_epi_recipe)
export(remove_frosting)
export(remove_model)
Expand All @@ -207,6 +203,8 @@ export(step_growth_rate)
export(step_lag_difference)
export(step_population_scaling)
export(step_training_window)
export(tibble)
export(tidy)
export(update_epi_recipe)
export(update_frosting)
export(update_model)
Expand All @@ -229,30 +227,50 @@ importFrom(checkmate,assert_number)
importFrom(checkmate,assert_numeric)
importFrom(checkmate,assert_scalar)
importFrom(cli,cli_abort)
importFrom(cli,cli_warn)
importFrom(dplyr,across)
importFrom(dplyr,all_of)
importFrom(dplyr,any_of)
importFrom(dplyr,arrange)
importFrom(dplyr,bind_cols)
importFrom(dplyr,bind_rows)
importFrom(dplyr,everything)
importFrom(dplyr,filter)
importFrom(dplyr,full_join)
importFrom(dplyr,group_by)
importFrom(dplyr,n)
importFrom(dplyr,left_join)
importFrom(dplyr,mutate)
importFrom(dplyr,relocate)
importFrom(dplyr,rename)
importFrom(dplyr,select)
importFrom(dplyr,summarise)
importFrom(dplyr,summarize)
importFrom(dplyr,ungroup)
importFrom(epiprocess,epi_slide)
importFrom(epiprocess,growth_rate)
importFrom(generics,augment)
importFrom(generics,fit)
importFrom(generics,forecast)
importFrom(generics,tidy)
importFrom(ggplot2,aes)
importFrom(ggplot2,autoplot)
importFrom(ggplot2,geom_line)
importFrom(ggplot2,geom_linerange)
importFrom(ggplot2,geom_point)
importFrom(ggplot2,geom_ribbon)
importFrom(hardhat,refresh_blueprint)
importFrom(hardhat,run_mold)
importFrom(magrittr,"%>%")
importFrom(recipes,bake)
importFrom(recipes,prep)
importFrom(recipes,rand_id)
importFrom(rlang,"!!!")
importFrom(rlang,"!!")
importFrom(rlang,"%@%")
importFrom(rlang,"%||%")
importFrom(rlang,":=")
importFrom(rlang,abort)
importFrom(rlang,arg_match)
importFrom(rlang,as_function)
importFrom(rlang,caller_env)
importFrom(rlang,enquo)
Expand All @@ -264,6 +282,7 @@ importFrom(rlang,is_logical)
importFrom(rlang,is_null)
importFrom(rlang,is_true)
importFrom(rlang,set_names)
importFrom(rlang,sym)
importFrom(stats,as.formula)
importFrom(stats,family)
importFrom(stats,lm)
Expand All @@ -274,9 +293,9 @@ importFrom(stats,predict)
importFrom(stats,qnorm)
importFrom(stats,quantile)
importFrom(stats,residuals)
importFrom(tibble,as_tibble)
importFrom(tibble,tibble)
importFrom(tidyr,crossing)
importFrom(tidyr,drop_na)
importFrom(vctrs,as_list_of)
importFrom(vctrs,field)
importFrom(vctrs,new_rcrd)
Expand Down
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,3 +56,4 @@ Pre-1.0.0 numbering scheme: 0.x will indicate releases, while 0.0.x will indicat
- add functionality to calculate weighted interval scores for `dist_quantiles()`
- Add `step_epi_slide` to produce generic sliding computations over an `epi_df`
- Add quantile random forests (via `{grf}`) as a parsnip engine
- Replace `epi_keys()` with `epiprocess::key_colnames()`, #352
72 changes: 35 additions & 37 deletions R/arx_classifier.R
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,9 @@
#' @seealso [arx_class_epi_workflow()], [arx_class_args_list()]
#'
#' @examples
#' library(dplyr)
#' jhu <- case_death_rate_subset %>%
#' dplyr::filter(time_value >= as.Date("2021-11-01"))
#' filter(time_value >= as.Date("2021-11-01"))
#'
#' out <- arx_classifier(jhu, "death_rate", c("case_rate", "death_rate"))
#'
Expand All @@ -45,23 +46,23 @@ arx_classifier <- function(
epi_data,
outcome,
predictors,
trainer = parsnip::logistic_reg(),
trainer = logistic_reg(),
args_list = arx_class_args_list()) {
if (!is_classification(trainer)) {
cli::cli_abort("`trainer` must be a {.pkg parsnip} model of mode 'classification'.")
cli_abort("`trainer` must be a {.pkg parsnip} model of mode 'classification'.")
}

wf <- arx_class_epi_workflow(epi_data, outcome, predictors, trainer, args_list)
wf <- generics::fit(wf, epi_data)
wf <- fit(wf, epi_data)

preds <- forecast(
wf,
fill_locf = TRUE,
n_recent = args_list$nafill_buffer,
forecast_date = args_list$forecast_date %||% max(epi_data$time_value)
) %>%
tibble::as_tibble() %>%
dplyr::select(-time_value)
as_tibble() %>%
select(-time_value)

structure(
list(
Expand Down Expand Up @@ -95,17 +96,17 @@ arx_classifier <- function(
#' @export
#' @seealso [arx_classifier()]
#' @examples
#'
#' library(dplyr)
#' jhu <- case_death_rate_subset %>%
#' dplyr::filter(time_value >= as.Date("2021-11-01"))
#' filter(time_value >= as.Date("2021-11-01"))
#'
#' arx_class_epi_workflow(jhu, "death_rate", c("case_rate", "death_rate"))
#'
#' arx_class_epi_workflow(
#' jhu,
#' "death_rate",
#' c("case_rate", "death_rate"),
#' trainer = parsnip::multinom_reg(),
#' trainer = multinom_reg(),
#' args_list = arx_class_args_list(
#' breaks = c(-.05, .1), ahead = 14,
#' horizon = 14, method = "linear_reg"
Expand All @@ -119,18 +120,18 @@ arx_class_epi_workflow <- function(
args_list = arx_class_args_list()) {
validate_forecaster_inputs(epi_data, outcome, predictors)
if (!inherits(args_list, c("arx_class", "alist"))) {
rlang::abort("args_list was not created using `arx_class_args_list().")
cli_abort("`args_list` was not created using `arx_class_args_list()`.")
}
if (!(is.null(trainer) || is_classification(trainer))) {
rlang::abort("`trainer` must be a `{parsnip}` model of mode 'classification'.")
cli_abort("`trainer` must be a {.pkg parsnip} model of mode 'classification'.")
}
lags <- arx_lags_validator(predictors, args_list$lags)

# --- preprocessor
# ------- predictors
r <- epi_recipe(epi_data) %>%
step_growth_rate(
tidyselect::all_of(predictors),
dplyr::all_of(predictors),
role = "grp",
horizon = args_list$horizon,
method = args_list$method,
Expand Down Expand Up @@ -173,26 +174,24 @@ arx_class_epi_workflow <- function(
o2 <- rlang::sym(paste0("ahead_", args_list$ahead, "_", o))
r <- r %>%
step_epi_ahead(!!o, ahead = args_list$ahead, role = "pre-outcome") %>%
step_mutate(
recipes::step_mutate(
outcome_class = cut(!!o2, breaks = args_list$breaks),
role = "outcome"
) %>%
step_epi_naomit() %>%
step_training_window(n_recent = args_list$n_training) %>%
{
if (!is.null(args_list$check_enough_data_n)) {
check_enough_train_data(
.,
all_predictors(),
!!outcome,
n = args_list$check_enough_data_n,
epi_keys = args_list$check_enough_data_epi_keys,
drop_na = FALSE
)
} else {
.
}
}
step_training_window(n_recent = args_list$n_training)

if (!is.null(args_list$check_enough_data_n)) {
r <- check_enough_train_data(
r,
recipes::all_predictors(),
recipes::all_outcomes(),
n = args_list$check_enough_data_n,
epi_keys = args_list$check_enough_data_epi_keys,
drop_na = FALSE
)
}


forecast_date <- args_list$forecast_date %||% max(epi_data$time_value)
target_date <- args_list$target_date %||% (forecast_date + args_list$ahead)
Expand Down Expand Up @@ -264,7 +263,7 @@ arx_class_args_list <- function(
outcome_transform = c("growth_rate", "lag_difference"),
breaks = 0.25,
horizon = 7L,
method = c("rel_change", "linear_reg", "smooth_spline", "trend_filter"),
method = c("rel_change", "linear_reg"),
log_scale = FALSE,
additional_gr_args = list(),
nafill_buffer = Inf,
Expand All @@ -274,8 +273,8 @@ arx_class_args_list <- function(
rlang::check_dots_empty()
.lags <- lags
if (is.list(lags)) lags <- unlist(lags)
method <- match.arg(method)
outcome_transform <- match.arg(outcome_transform)
method <- rlang::arg_match(method)
outcome_transform <- rlang::arg_match(outcome_transform)

arg_is_scalar(ahead, n_training, horizon, log_scale)
arg_is_scalar(forecast_date, target_date, allow_null = TRUE)
Expand All @@ -287,12 +286,11 @@ arx_class_args_list <- function(
if (is.finite(n_training)) arg_is_pos_int(n_training)
if (is.finite(nafill_buffer)) arg_is_pos_int(nafill_buffer, allow_null = TRUE)
if (!is.list(additional_gr_args)) {
cli::cli_abort(
c("`additional_gr_args` must be a {.cls list}.",
"!" = "This is a {.cls {class(additional_gr_args)}}.",
i = "See `?epiprocess::growth_rate` for available arguments."
)
)
cli_abort(c(
"`additional_gr_args` must be a {.cls list}.",
"!" = "This is a {.cls {class(additional_gr_args)}}.",
i = "See `?epiprocess::growth_rate` for available arguments."
))
}
arg_is_pos(check_enough_data_n, allow_null = TRUE)
arg_is_chr(check_enough_data_epi_keys, allow_null = TRUE)
Expand Down
Loading
Loading