Merge pull request #408 from cmu-delphi/prodFixes
Prod fixes
dsweber2 authored Oct 9, 2024
2 parents 4b9fc72 + fb7d6ba commit a415100
Showing 17 changed files with 162 additions and 45 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
@@ -1,6 +1,6 @@
Package: epipredict
Title: Basic epidemiology forecasting methods
Version: 0.1.1
Version: 0.1.2
Authors@R: c(
person("Daniel J.", "McDonald", , "[email protected]", role = c("aut", "cre")),
person("Ryan", "Tibshirani", , "[email protected]", role = "aut"),
2 changes: 2 additions & 0 deletions NEWS.md
@@ -6,8 +6,10 @@ Pre-1.0.0 numbering scheme: 0.x will indicate releases, while 0.0.x will indicat

## features
- Add `step_adjust_latency`, which gives several methods to adjust the forecast if the `forecast_date` is after the last day of data.
- (temporary) a negative `ahead` is allowed for `step_epi_ahead` until we have `step_epi_shift`

## bugfixes
- shifting no columns results in no error for either `step_epi_ahead` or `step_epi_lag`

# epipredict 0.1

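To make the NEWS entries above concrete, here is a minimal sketch of how `step_adjust_latency` might sit in a recipe. It assumes a hypothetical `epi_df` called `jhu` with a `death_rate` column whose `as_of` date is later than its last observed `time_value`; the data and column names are illustrative, not taken from this diff.

library(epipredict)

# jhu: hypothetical epi_df (geo_value, time_value, death_rate) whose as_of
# is a few days past the last reported time_value.
r <- epi_recipe(jhu) %>%
  # placed before the shifting steps so it can adjust them for latency
  step_adjust_latency(death_rate, method = "extend_lags") %>%
  step_epi_lag(death_rate, lag = c(0, 7, 14)) %>%
  step_epi_ahead(death_rate, ahead = 7)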
4 changes: 4 additions & 0 deletions R/epi_shift.R
@@ -42,6 +42,10 @@ get_sign <- function(object) {
add_shifted_columns <- function(new_data, object) {
grid <- object$shift_grid

if (nrow(object$shift_grid) == 0) {
# we're not shifting any rows, so this is a no-op
return(new_data)
}
## ensure no name clashes
new_data_names <- colnames(new_data)
intersection <- new_data_names %in% grid$newname
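The new guard returns `new_data` untouched when the shift grid is empty, which is the "shifting no columns" bugfix from NEWS.md. A hedged sketch of the situation it covers, assuming the same hypothetical `jhu` data and a tidyselect that matches no columns:

# starts_with("not_a_column") selects nothing, so the lag step has an
# empty shift grid; with the guard above this is now a no-op rather
# than an error.
r <- epi_recipe(jhu) %>%
  step_epi_lag(starts_with("not_a_column"), lag = 7) %>%
  step_epi_ahead(death_rate, ahead = 7)
baked <- prep(r, jhu) %>% bake(new_data = NULL)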
1 change: 1 addition & 0 deletions R/epi_workflow.R
@@ -167,6 +167,7 @@ predict.epi_workflow <- function(object, new_data, type = NULL, opts = list(), .
components$forged <- hardhat::forge(new_data,
blueprint = components$mold$blueprint
)

components$keys <- grab_forged_keys(components$forged, object, new_data)
components <- apply_frosting(object, components, new_data, type = type, opts = opts, ...)
components$predictions
18 changes: 13 additions & 5 deletions R/step_adjust_latency.R
@@ -152,6 +152,12 @@
#'
#' Note that this is a separate concern from different latencies across
#' different *data columns*, which is only handled by the choice of `method`.
#' @param keys_to_ignore a list of character vectors. Set this to avoid using
#' specific key values in the `epi_keys_checked` to set latency. For example,
#' say you have two locations `pr` and `gu` which have useful training data,
#' but have stopped providing up-to-date information, and so are no longer
#' part of the test set. Setting `keys_to_ignore = list(geo_value = c("pr",
#' "gu"))` will exclude them from the latency calculation.
#' @param fixed_latency either a positive integer, or a labeled positive integer
#' vector. Cannot be set at the same time as `fixed_forecast_date`. If
#' non-`NULL`, the amount to offset the ahead or lag by. If a single integer,
@@ -203,6 +209,7 @@ step_adjust_latency <-
"extend_lags"
),
epi_keys_checked = NULL,
keys_to_ignore = c(),
fixed_latency = NULL,
fixed_forecast_date = NULL,
check_latency_length = TRUE,
@@ -228,6 +235,7 @@ step_adjust_latency <-
metadata = NULL,
method = method,
epi_keys_checked = epi_keys_checked,
keys_to_ignore = keys_to_ignore,
check_latency_length = check_latency_length,
columns = NULL,
skip = FALSE,
@@ -239,7 +247,7 @@ step_adjust_latency <-
step_adjust_latency_new <-
function(terms, role, trained, fixed_forecast_date, forecast_date, latency,
latency_table, latency_sign, metadata, method, epi_keys_checked,
check_latency_length, columns, skip, id) {
keys_to_ignore, check_latency_length, columns, skip, id) {
step(
subclass = "adjust_latency",
terms = terms,
@@ -253,6 +261,7 @@ step_adjust_latency_new <-
metadata = metadata,
method = method,
epi_keys_checked = epi_keys_checked,
keys_to_ignore = keys_to_ignore,
check_latency_length = check_latency_length,
columns = columns,
skip = skip,
@@ -271,7 +280,7 @@ prep.step_adjust_latency <- function(x, training, info = NULL, ...) {

latency_table <- get_latency_table(
training, NULL, forecast_date, latency,
get_sign(x), x$epi_keys_checked, info, x$terms
get_sign(x), x$epi_keys_checked, x$keys_to_ignore, info, x$terms
)
# get the columns used, even if it's all of them
terms_used <- x$terms
@@ -293,6 +302,7 @@ prep.step_adjust_latency <- function(x, training, info = NULL, ...) {
metadata = attributes(training)$metadata,
method = x$method,
epi_keys_checked = x$epi_keys_checked,
keys_to_ignore = x$keys_to_ignore,
check_latency_length = x$check_latency_length,
columns = recipes_eval_select(latency_table$col_name, training, info),
skip = x$skip,
@@ -305,10 +315,8 @@ prep.step_adjust_latency <- function(x, training, info = NULL, ...) {
#' @export
bake.step_adjust_latency <- function(object, new_data, ...) {
if (!inherits(new_data, "epi_df") || is.null(attributes(new_data)$metadata$as_of)) {
new_data <- as_epi_df(new_data)
new_data <- as_epi_df(new_data, as_of = object$forecast_date, other_keys = object$metadata$other_keys %||% character())
attributes(new_data)$metadata <- object$metadata
attributes(new_data)$metadata$as_of <- object$forecast_date
} else {
compare_bake_prep_latencies(object, new_data)
}
if (object$method == "locf") {
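A sketch of the new `keys_to_ignore` argument documented above, again on the hypothetical `jhu` data; here `pr` and `gu` stand in for locations with usable history that no longer report, so they should not drive the latency estimate:

r <- epi_recipe(jhu) %>%
  step_adjust_latency(
    death_rate,
    method = "extend_lags",
    # these geo_values are dropped before computing how stale the data is
    keys_to_ignore = list(geo_value = c("pr", "gu"))
  ) %>%
  step_epi_lag(death_rate, lag = c(0, 7, 14)) %>%
  step_epi_ahead(death_rate, ahead = 7)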
1 change: 0 additions & 1 deletion R/step_epi_shift.R
@@ -111,7 +111,6 @@ step_epi_ahead <-
i = "Did you perhaps pass an integer in `...` accidentally?"
))
}
arg_is_nonneg_int(ahead)
arg_is_chr_scalar(prefix, id)

recipes::add_step(
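Removing `arg_is_nonneg_int(ahead)` is what allows the temporary negative-`ahead` feature noted in NEWS.md: a negative shift targets a time at or before the forecast date instead of after it. A minimal, hedged sketch (names illustrative):

# ahead = -7 now passes validation; before this change the call errored
# with "`ahead` must be a non-negative integer."
r <- epi_recipe(jhu) %>%
  step_epi_ahead(death_rate, ahead = -7) %>%
  step_epi_lag(death_rate, lag = c(0, 7))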
87 changes: 62 additions & 25 deletions R/utils-latency.R
@@ -50,27 +50,18 @@ get_forecast_date <- function(new_data, info, epi_keys_checked, latency, columns
)
}
}
max_time <- get_max_time(new_data, epi_keys_checked, columns)
# the source data determines the actual time_values
# these are the non-na time_values;
# get the minimum value across the checked epi_keys' maximum time values
max_time <- new_data %>%
select(all_of(columns)) %>%
drop_na()
# null and "" don't work in `group_by`
if (!is.null(epi_keys_checked) && (epi_keys_checked != "")) {
max_time <- max_time %>% group_by(get(epi_keys_checked))
}
max_time <- max_time %>%
summarise(time_value = max(time_value)) %>%
pull(time_value) %>%
min()
if (is.null(latency)) {
forecast_date <- attributes(new_data)$metadata$as_of
} else {
if (is.null(max_time)) {
cli_abort("max_time is null. This likely means there is one of {columns} that is all `NA`")
}
forecast_date <- max_time + latency
}
# make sure the as_of is sane
if (!inherits(forecast_date, class(max_time)) & !inherits(forecast_date, "POSIXt")) {
if (!inherits(forecast_date, class(new_data$time_value)) & !inherits(forecast_date, "POSIXt")) {
cli_abort(
paste(
"the data matrix `forecast_date` value is {forecast_date}, ",
@@ -84,13 +75,13 @@ get_forecast_date <- function(new_data, info, epi_keys_checked, latency, columns
if (is.null(forecast_date) || is.na(forecast_date)) {
cli_warn(
paste(
"epi_data's `forecast_date` was {forecast_date}, setting to ",
"the latest time value, {max_time}."
"epi_data's `forecast_date` was `NA`, setting to ",
"the latest non-`NA` time value for these columns, {max_time}."
),
class = "epipredict__get_forecast_date__max_time_warning"
)
forecast_date <- max_time
} else if (forecast_date < max_time) {
} else if (!is.null(max_time) && (forecast_date < max_time)) {
cli_abort(
paste(
"`forecast_date` ({(forecast_date)}) is before the most ",
@@ -101,22 +92,49 @@ get_forecast_date <- function(new_data, info, epi_keys_checked, latency, columns
)
}
# TODO cover the rest of the possible types for as_of and max_time...
if (inherits(max_time, "Date")) {
if (inherits(new_data$time_value, "Date")) {
forecast_date <- as.Date(forecast_date)
}
return(forecast_date)
}

get_max_time <- function(new_data, epi_keys_checked, columns) {
# these are the non-na time_values;
# get the minimum value across the checked epi_keys' maximum time values
max_time <- new_data %>%
select(all_of(columns)) %>%
drop_na()
if (nrow(max_time) == 0) {
return(NULL)
}
# null and "" don't work in `group_by`
if (!is.null(epi_keys_checked) && all(epi_keys_checked != "")) {
max_time <- max_time %>% group_by(across(all_of(epi_keys_checked)))
}
max_time <- max_time %>%
summarise(time_value = max(time_value)) %>%
pull(time_value) %>%
min()
return(max_time)
}



#' the latency is also the amount the shift is off by
#' @param sign_shift integer. 1 if lag and -1 if ahead. These represent how you
#' need to shift the data to bring the 3 day lagged value to today.
#' @keywords internal
get_latency <- function(new_data, forecast_date, column, sign_shift, epi_keys_checked) {
shift_max_date <- new_data %>%
drop_na(all_of(column))
if (nrow(shift_max_date) == 0) {
# if everything is an NA, there's infinite latency, but shifting by that is
# untenable. May as well not shift at all
return(0)
}
# null and "" don't work in `group_by`
if (!is.null(epi_keys_checked) && epi_keys_checked != "") {
shift_max_date <- shift_max_date %>% group_by(get(epi_keys_checked))
if (!is.null(epi_keys_checked) && all(epi_keys_checked != "")) {
shift_max_date <- shift_max_date %>% group_by(across(all_of(epi_keys_checked)))
}
shift_max_date <- shift_max_date %>%
summarise(time_value = max(time_value)) %>%
Expand Down Expand Up @@ -290,7 +308,8 @@ check_interminable_latency <- function(dataset, latency_table, target_columns, f
#' @keywords internal
#' @importFrom dplyr rowwise
get_latency_table <- function(training, columns, forecast_date, latency,
sign_shift, epi_keys_checked, info, terms) {
sign_shift, epi_keys_checked, keys_to_ignore,
info, terms) {
if (is.null(columns)) {
columns <- recipes_eval_select(terms, training, info)
}
@@ -300,12 +319,17 @@ get_latency_table <- function(training, columns, forecast_date, latency,
if (length(columns) > 0) {
latency_table <- latency_table %>% filter(col_name %in% columns)
}

training_dropped <- training %>%
drop_ignored_keys(keys_to_ignore)
if (is.null(latency)) {
latency_table <- latency_table %>%
rowwise() %>%
mutate(latency = get_latency(
training, forecast_date, col_name, sign_shift, epi_keys_checked
training_dropped,
forecast_date,
col_name,
sign_shift,
epi_keys_checked
))
} else if (length(latency) > 1) {
# if latency has a length, it must also have named elements.
@@ -319,7 +343,7 @@ get_latency_table <- function(training, columns, forecast_date, latency,
latency_table <- latency_table %>%
rowwise() %>%
mutate(latency = get_latency(
training, forecast_date, col_name, sign_shift, epi_keys_checked
training %>% drop_ignored_keys(keys_to_ignore), forecast_date, col_name, sign_shift, epi_keys_checked
))
if (latency) {
latency_table <- latency_table %>% mutate(latency = latency)
@@ -328,6 +352,19 @@ get_latency_table <- function(training, columns, forecast_date, latency,
return(latency_table %>% ungroup())
}

#' given a list named by key columns, remove any matching key values
#' keys_to_ignore should have the form list(col_name = c("value_to_ignore", "other_value_to_ignore"))
#' @keywords internal
drop_ignored_keys <- function(training, keys_to_ignore) {
# note that the extra parenthesis black magic is described here: https://github.com/tidyverse/dplyr/issues/6194
# and is needed to bypass an incomplete port of `across` functions to `if_any`
training %>%
filter((dplyr::if_all(
names(keys_to_ignore),
~ . %nin% keys_to_ignore[[cur_column()]]
)))
}


#' checks: the recipe type, whether a previous step is the relevant epi_shift,
#' that either `fixed_latency` or `fixed_forecast_date` is non-null, and that
@@ -394,7 +431,7 @@ compare_bake_prep_latencies <- function(object, new_data, call = caller_env()) {
)
local_latency_table <- get_latency_table(
new_data, object$columns, current_forecast_date, latency,
get_sign(object), object$epi_keys_checked, NULL, NULL
get_sign(object), object$epi_keys_checked, object$keys_to_ignore, NULL, NULL
)
comparison_table <- local_latency_table %>%
ungroup() %>%
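The new `drop_ignored_keys()` helper filters out every row whose value in a listed key column appears in `keys_to_ignore`, using `dplyr::if_all()` so one filter covers any number of key columns. A standalone sketch of that pattern on a toy tibble; the `%nin%` helper is assumed here to be a simple negated `%in%`:

library(dplyr)

`%nin%` <- function(x, table) !(x %in% table) # assumed helper, negated %in%

toy <- tibble(
  geo_value  = c("ca", "pr", "gu", "ca"),
  time_value = as.Date("2024-10-01") + 0:3,
  death_rate = c(0.1, NA, NA, 0.2)
)
keys_to_ignore <- list(geo_value = c("pr", "gu"))

# keep rows whose value in every listed key column is not an ignored value
toy %>%
  filter(if_all(
    all_of(names(keys_to_ignore)),
    ~ .x %nin% keys_to_ignore[[cur_column()]]
  ))
# leaves only the two "ca" rows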
14 changes: 14 additions & 0 deletions man/drop_ignored_keys.Rd


1 change: 1 addition & 0 deletions man/get_latency_table.Rd


7 changes: 7 additions & 0 deletions man/step_adjust_latency.Rd


7 changes: 2 additions & 5 deletions tests/testthat/_snaps/step_epi_shift.md
@@ -4,8 +4,8 @@
r1 <- epi_recipe(x) %>% step_epi_ahead(death_rate, ahead = 3.6) %>%
step_epi_lag(death_rate, lag = 1.9)
Condition
Error in `step_epi_ahead()`:
! `ahead` must be a non-negative integer.
Error in `step_epi_lag()`:
! `lag` must be a non-negative integer.

# A negative lag value should throw an error

@@ -21,9 +21,6 @@
Code
r3 <- epi_recipe(x) %>% step_epi_ahead(death_rate, ahead = -7) %>% step_epi_lag(
death_rate, lag = 7)
Condition
Error in `step_epi_ahead()`:
! `ahead` must be a non-negative integer.

# Values for ahead and lag cannot be duplicates

2 changes: 1 addition & 1 deletion tests/testthat/test-dist_quantiles.R
@@ -1,4 +1,4 @@
library(distributional)
suppressPackageStartupMessages(library(distributional))

test_that("constructor returns reasonable quantiles", {
expect_snapshot(error = TRUE, new_quantiles(rnorm(5), c(-2, -1, 0, 1, 2)))
2 changes: 1 addition & 1 deletion tests/testthat/test-epi_workflow.R
@@ -103,7 +103,7 @@ test_that("forecast method errors when workflow not fit", {
test_that("fit method does not silently drop the class", {
# This is issue #363

library(recipes)
suppressPackageStartupMessages(library(recipes))
tbl <- tibble::tibble(
geo_value = 1,
time_value = 1:100,
2 changes: 1 addition & 1 deletion tests/testthat/test-grf_quantiles.R
@@ -1,5 +1,5 @@
set.seed(12345)
library(grf)
suppressPackageStartupMessages(library(grf))
tib <- tibble(
y = rnorm(100), x = rnorm(100), z = rnorm(100),
f = factor(sample(letters[1:3], 100, replace = TRUE))
4 changes: 0 additions & 4 deletions tests/testthat/test-step_adjust_latency.R
@@ -398,10 +398,6 @@ test_that("epi_adjust_latency correctly extends the lags when there are differen
names(fit5$pre$mold$outcomes),
glue::glue("ahead_{ahead}_death_rate")
)
latest <- get_test_data(r5, x)
pred <- predict(fit5, latest)
actual_solutions <- pred %>% filter(!is.na(.pred))
expect_equal(actual_solutions$time_value, testing_as_of + 1)

# should have four predictors, including the intercept
expect_equal(length(fit5$fit$fit$fit$coefficients), 6)