Skip to content

Commit

Permalink
Merge pull request #386 from cmu-delphi/ds/epiprocess-0.9.0
Browse files Browse the repository at this point in the history
fix: update for compatibility with epiprocess==0.9.0
  • Loading branch information
dshemetov authored Sep 27, 2024
2 parents cd12775 + 374cb2f commit 34cb6ed
Show file tree
Hide file tree
Showing 29 changed files with 200 additions and 191 deletions.
3 changes: 2 additions & 1 deletion .Rbuildignore
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,5 @@
^DEVELOPMENT\.md$
^doc$
^Meta$
^.lintr$
^.lintr$
^.venv$
5 changes: 2 additions & 3 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: epipredict
Title: Basic epidemiology forecasting methods
Version: 0.0.20
Version: 0.0.21
Authors@R: c(
person("Daniel", "McDonald", , "[email protected]", role = c("aut", "cre")),
person("Ryan", "Tibshirani", , "[email protected]", role = "aut"),
Expand All @@ -23,8 +23,7 @@ URL: https://github.com/cmu-delphi/epipredict/,
https://cmu-delphi.github.io/epipredict
BugReports: https://github.com/cmu-delphi/epipredict/issues/
Depends:
epiprocess (>= 0.8.0),
epiprocess (< 0.9.0),
epiprocess (>= 0.9.0),
parsnip (>= 1.0.0),
R (>= 3.5.0)
Imports:
Expand Down
2 changes: 1 addition & 1 deletion R/autoplot.R
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ autoplot.epi_workflow <- function(
if (length(extra_keys) == 0L) extra_keys <- NULL
edf <- as_epi_df(edf,
as_of = object$fit$meta$as_of,
additional_metadata = list(other_keys = extra_keys)
other_keys = extra_keys %||% character()
)
if (is.null(predictions)) {
return(autoplot(
Expand Down
6 changes: 3 additions & 3 deletions R/cdc_baseline_forecaster.R
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,11 @@
#' mutate(deaths = pmax(death_rate / 1e5 * pop * 7, 0)) %>%
#' select(-pop, -death_rate) %>%
#' group_by(geo_value) %>%
#' epi_slide(~ sum(.$deaths), before = 6, new_col_name = "deaths") %>%
#' epi_slide(~ sum(.$deaths), .window_size = 7, .new_col_name = "deaths_7dsum") %>%
#' ungroup() %>%
#' filter(weekdays(time_value) == "Saturday")
#'
#' cdc <- cdc_baseline_forecaster(weekly_deaths, "deaths")
#' cdc <- cdc_baseline_forecaster(weekly_deaths, "deaths_7dsum")
#' preds <- pivot_quantiles_wider(cdc$predictions, .pred_distn)
#'
#' if (require(ggplot2)) {
Expand All @@ -47,7 +47,7 @@
#' geom_line(aes(y = .pred), color = "orange") +
#' geom_line(
#' data = weekly_deaths %>% filter(geo_value %in% four_states),
#' aes(x = time_value, y = deaths)
#' aes(x = time_value, y = deaths_7dsum)
#' ) +
#' scale_x_date(limits = c(forecast_date - 90, forecast_date + 30)) +
#' labs(x = "Date", y = "Weekly deaths") +
Expand Down
12 changes: 7 additions & 5 deletions R/epi_recipe.R
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ epi_recipe.epi_df <-
keys <- key_colnames(x) # we know x is an epi_df

var_info <- tibble(variable = vars)
key_roles <- c("geo_value", "time_value", rep("key", length(keys) - 2))
key_roles <- c("geo_value", rep("key", length(keys) - 2), "time_value")

## Check and add roles when available
if (!is.null(roles)) {
Expand Down Expand Up @@ -499,8 +499,11 @@ prep.epi_recipe <- function(
if (!is_epi_df(training)) {
# tidymodels killed our class
# for now, we only allow step_epi_* to alter the metadata
training <- dplyr::dplyr_reconstruct(
as_epi_df(training), before_template
metadata <- attr(before_template, "metadata")
training <- as_epi_df(
training,
as_of = metadata$as_of,
other_keys = metadata$other_keys %||% character()
)
}
training <- dplyr::relocate(training, all_of(key_colnames(training)))
Expand Down Expand Up @@ -579,8 +582,7 @@ bake.epi_recipe <- function(object, new_data, ..., composition = "epi_df") {
new_data <- as_epi_df(
new_data,
as_of = meta$as_of,
# avoid NULL if meta is from saved older epi_df:
additional_metadata = meta$additional_metadata %||% list()
other_keys = meta$other_keys %||% character()
)
}
new_data
Expand Down
3 changes: 2 additions & 1 deletion R/epi_workflow.R
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,8 @@ is_epi_workflow <- function(x) {
fit.epi_workflow <- function(object, data, ..., control = workflows::control_workflow()) {
object$fit$meta <- list(
max_time_value = max(data$time_value),
as_of = attributes(data)$metadata$as_of
as_of = attr(data, "metadata")$as_of,
other_keys = attr(data, "metadata")$other_keys
)
object$original_data <- data

Expand Down
4 changes: 2 additions & 2 deletions R/flusight_hub_formatter.R
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,11 @@ abbr_to_location <- function(abbr) {
#' mutate(deaths = pmax(death_rate / 1e5 * pop * 7, 0)) %>%
#' select(-pop, -death_rate) %>%
#' group_by(geo_value) %>%
#' epi_slide(~ sum(.$deaths), before = 6, new_col_name = "deaths") %>%
#' epi_slide(~ sum(.$deaths), .window_size = 7, .new_col_name = "deaths_7dsum") %>%
#' ungroup() %>%
#' filter(weekdays(time_value) == "Saturday")
#'
#' cdc <- cdc_baseline_forecaster(weekly_deaths, "deaths")
#' cdc <- cdc_baseline_forecaster(weekly_deaths, "deaths_7dsum")
#' flusight_hub_formatter(cdc)
#' flusight_hub_formatter(cdc, target = "wk inc covid deaths")
#' flusight_hub_formatter(cdc, target = paste(horizon, "wk inc covid deaths"))
Expand Down
15 changes: 8 additions & 7 deletions R/key_colnames.R
Original file line number Diff line number Diff line change
@@ -1,19 +1,20 @@
#' @export
key_colnames.recipe <- function(x, ...) {
possible_keys <- c("geo_value", "time_value", "key")
keys <- x$var_info$variable[x$var_info$role %in% possible_keys]
keys[order(match(keys, possible_keys))] %||% character(0L)
geo_key <- x$var_info$variable[x$var_info$role %in% "geo_value"]
time_key <- x$var_info$variable[x$var_info$role %in% "time_value"]
keys <- x$var_info$variable[x$var_info$role %in% "key"]
c(geo_key, keys, time_key) %||% character(0L)
}

#' @export
key_colnames.epi_workflow <- function(x, ...) {
# safer to look at the mold than the preprocessor
mold <- hardhat::extract_mold(x)
possible_keys <- c("geo_value", "time_value", "key")
molded_names <- names(mold$extras$roles)
keys <- map(mold$extras$roles[molded_names %in% possible_keys], names)
keys <- unname(unlist(keys))
keys[order(match(keys, possible_keys))] %||% character(0L)
geo_key <- names(mold$extras$roles[molded_names %in% "geo_value"]$geo_value)
time_key <- names(mold$extras$roles[molded_names %in% "time_value"]$time_value)
keys <- names(mold$extras$roles[molded_names %in% "key"]$key)
c(geo_key, keys, time_key) %||% character(0L)
}

kill_time_value <- function(v) {
Expand Down
1 change: 1 addition & 0 deletions R/layer_add_forecast_date.R
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ slather.layer_add_forecast_date <- function(object, components, workflow,
workflows::extract_preprocessor(workflow)$template, "metadata"
)$time_type
if (expected_time_type == "week") expected_time_type <- "day"
if (expected_time_type == "integer") expected_time_type <- "year"
validate_date(
forecast_date, expected_time_type,
call = rlang::expr(layer_add_forecast_date())
Expand Down
1 change: 1 addition & 0 deletions R/layer_add_target_date.R
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ slather.layer_add_target_date <- function(object, components, workflow,
workflows::extract_preprocessor(workflow)$template, "metadata"
)$time_type
if (expected_time_type == "week") expected_time_type <- "day"
if (expected_time_type == "integer") expected_time_type <- "year"

if (!is.null(object$target_date)) {
target_date <- object$target_date
Expand Down
121 changes: 64 additions & 57 deletions R/step_epi_slide.R
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,18 @@
#' argument must be named `.x`. A common, though very difficult to debug
#' error is using something like `function(x) mean`. This will not work
#' because it returns the function mean, rather than `mean(x)`
#' @param before,after the size of the sliding window on the left and the right
#' of the center. Usually non-negative integers for data indexed by date, but
#' more restrictive in other cases (see [epiprocess::epi_slide()] for details).
#' @param f_name a character string of at most 20 characters that describes
#' the function. This will be combined with `prefix` and the columns in `...`
#' to name the result using `{prefix}{f_name}_{column}`. By default it will be determined
#' automatically using `clean_f_name()`.
#' @param .window_size the size of the sliding window, required. Usually a
#' non-negative integer will suffice (e.g. for data indexed by date, but more
#' restrictive in other time_type cases (see [epiprocess::epi_slide()] for
#' details). For example, set to 7 for a 7-day window.
#' @param .align a character string indicating how the window should be aligned.
#' By default, this is "right", meaning the slide_window will be anchored with
#' its right end point on the reference date. (see [epiprocess::epi_slide()]
#' for details).
#' @param f_name a character string of at most 20 characters that describes the
#' function. This will be combined with `prefix` and the columns in `...` to
#' name the result using `{prefix}{f_name}_{column}`. By default it will be
#' determined automatically using `clean_f_name()`.
#'
#' @template step-return
#'
Expand All @@ -37,53 +42,55 @@
#' rec <- epi_recipe(jhu) %>%
#' step_epi_slide(case_rate, death_rate,
#' .f = \(x) mean(x, na.rm = TRUE),
#' before = 6L
#' .window_size = 7L
#' )
#' bake(prep(rec, jhu), new_data = NULL)
step_epi_slide <-
function(recipe,
...,
.f,
before = 0L,
after = 0L,
role = "predictor",
prefix = "epi_slide_",
f_name = clean_f_name(.f),
skip = FALSE,
id = rand_id("epi_slide")) {
if (!is_epi_recipe(recipe)) {
cli_abort("This recipe step can only operate on an {.cls epi_recipe}.")
}
.f <- validate_slide_fun(.f)
epiprocess:::validate_slide_window_arg(before, attributes(recipe$template)$metadata$time_type)
epiprocess:::validate_slide_window_arg(after, attributes(recipe$template)$metadata$time_type)
arg_is_chr_scalar(role, prefix, id)
arg_is_lgl_scalar(skip)
step_epi_slide <- function(recipe,
...,
.f,
.window_size = NULL,
.align = c("right", "center", "left"),
role = "predictor",
prefix = "epi_slide_",
f_name = clean_f_name(.f),
skip = FALSE,
id = rand_id("epi_slide")) {
if (!is_epi_recipe(recipe)) {
cli_abort("This recipe step can only operate on an {.cls epi_recipe}.")
}
.f <- validate_slide_fun(.f)
if (is.null(.window_size)) {
cli_abort("step_epi_slide: `.window_size` must be specified.")
}
epiprocess:::validate_slide_window_arg(.window_size, attributes(recipe$template)$metadata$time_type)
.align <- rlang::arg_match(.align)
arg_is_chr_scalar(role, prefix, id)
arg_is_lgl_scalar(skip)

recipes::add_step(
recipe,
step_epi_slide_new(
terms = enquos(...),
before = before,
after = after,
.f = .f,
f_name = f_name,
role = role,
trained = FALSE,
prefix = prefix,
keys = key_colnames(recipe),
columns = NULL,
skip = skip,
id = id
)
recipes::add_step(
recipe,
step_epi_slide_new(
terms = enquos(...),
.window_size = .window_size,
.align = .align,
.f = .f,
f_name = f_name,
role = role,
trained = FALSE,
prefix = prefix,
keys = key_colnames(recipe),
columns = NULL,
skip = skip,
id = id
)
}
)
}


step_epi_slide_new <-
function(terms,
before,
after,
.window_size,
.align,
.f,
f_name,
role,
Expand All @@ -96,8 +103,8 @@ step_epi_slide_new <-
recipes::step(
subclass = "epi_slide",
terms = terms,
before = before,
after = after,
.window_size = .window_size,
.align = .align,
.f = .f,
f_name = f_name,
role = role,
Expand All @@ -119,8 +126,8 @@ prep.step_epi_slide <- function(x, training, info = NULL, ...) {

step_epi_slide_new(
terms = x$terms,
before = x$before,
after = x$after,
.window_size = x$.window_size,
.align = x$.align,
.f = x$.f,
f_name = x$f_name,
role = x$role,
Expand Down Expand Up @@ -165,8 +172,8 @@ bake.step_epi_slide <- function(object, new_data, ...) {
# }
epi_slide_wrapper(
new_data,
object$before,
object$after,
object$.window_size,
object$.align,
object$columns,
c(object$.f),
object$f_name,
Expand All @@ -190,7 +197,7 @@ bake.step_epi_slide <- function(object, new_data, ...) {
#' @importFrom dplyr bind_cols group_by ungroup
#' @importFrom epiprocess epi_slide
#' @keywords internal
epi_slide_wrapper <- function(new_data, before, after, columns, fns, fn_names, group_keys, name_prefix) {
epi_slide_wrapper <- function(new_data, .window_size, .align, columns, fns, fn_names, group_keys, name_prefix) {
cols_fns <- tidyr::crossing(col_name = columns, fn_name = fn_names, fn = fns)
# Iterate over the rows of cols_fns. For each row number, we will output a
# transformed column. The first result returns all the original columns along
Expand All @@ -204,10 +211,10 @@ epi_slide_wrapper <- function(new_data, before, after, columns, fns, fn_names, g
result <- new_data %>%
group_by(across(all_of(group_keys))) %>%
epi_slide(
before = before,
after = after,
new_col_name = result_name,
f = function(slice, geo_key, ref_time_value) {
.window_size = .window_size,
.align = .align,
.new_col_name = result_name,
.f = function(slice, geo_key, ref_time_value) {
fn(slice[[col_name]])
}
) %>%
Expand Down
16 changes: 7 additions & 9 deletions R/utils-misc.R
Original file line number Diff line number Diff line change
Expand Up @@ -33,28 +33,26 @@ check_pname <- function(res, preds, object, newname = NULL) {


grab_forged_keys <- function(forged, workflow, new_data) {
keys <- c("geo_value", "time_value", "key")
forged_roles <- names(forged$extras$roles)
extras <- dplyr::bind_cols(forged$extras$roles[forged_roles %in% keys])
extras <- dplyr::bind_cols(forged$extras$roles[forged_roles %in% c("geo_value", "time_value", "key")])
# 1. these are the keys in the test data after prep/bake
new_keys <- names(extras)
# 2. these are the keys in the training data
old_keys <- key_colnames(workflow)
# 3. these are the keys in the test data as input
new_df_keys <- key_colnames(new_data, extra_keys = setdiff(new_keys, keys[1:2]))
new_df_keys <- key_colnames(new_data, extra_keys = setdiff(new_keys, c("geo_value", "time_value")))
if (!(setequal(old_keys, new_df_keys) && setequal(new_keys, new_df_keys))) {
cli::cli_warn(c(
"Not all epi keys that were present in the training data are available",
"in `new_data`. Predictions will have only the available keys."
))
}
if (is_epi_df(new_data)) {
extras <- as_epi_df(extras)
attr(extras, "metadata") <- attr(new_data, "metadata")
} else if (all(keys[1:2] %in% new_keys)) {
l <- list()
if (length(new_keys) > 2) l <- list(other_keys = new_keys[-c(1:2)])
extras <- as_epi_df(extras, additional_metadata = l)
meta <- attr(new_data, "metadata")
extras <- as_epi_df(extras, as_of = meta$as_of, other_keys = meta$other_keys %||% character())
} else if (all(c("geo_value", "time_value") %in% new_keys)) {
if (length(new_keys) > 2) other_keys <- new_keys[!new_keys %in% c("geo_value", "time_value")]
extras <- as_epi_df(extras, other_keys = other_keys %||% character())
}
extras
}
Expand Down
2 changes: 1 addition & 1 deletion data-raw/grad_employ_subset.R
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,6 @@ ncol(gemploy)
grad_employ_subset <- gemploy %>%
as_epi_df(
as_of = "2022-07-19",
additional_metadata = list(other_keys = c("age_group", "edu_qual"))
other_keys = c("age_group", "edu_qual")
)
usethis::use_data(grad_employ_subset, overwrite = TRUE)
Binary file modified data/grad_employ_subset.rda
Binary file not shown.
2 changes: 2 additions & 0 deletions man/autoplot-epipred.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 34cb6ed

Please sign in to comment.