Skip to content

Commit

Permalink
Merge pull request #241 from cmu-delphi/240-quantile-pivot
Browse files Browse the repository at this point in the history
240 quantile pivot
  • Loading branch information
dajmcdon authored Oct 5, 2023
2 parents 015b0ea + 8d1e47d commit b2d1e11
Show file tree
Hide file tree
Showing 58 changed files with 1,451 additions and 223 deletions.
1 change: 1 addition & 0 deletions .Rbuildignore
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,4 @@
^musings$
^data-raw$
^vignettes/articles$
^.git-blame-ignore-revs$
4 changes: 2 additions & 2 deletions .github/workflows/R-CMD-check.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
# Created with usethis + edited to use API key.
on:
push:
branches: [main, master]
branches: [main, master, v0.0.6]
pull_request:
branches: [main, master]
branches: [main, master, v0.0.6]

name: R-CMD-check

Expand Down
1 change: 1 addition & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ Imports:
generics,
glue,
hardhat (>= 1.3.0),
lifecycle,
magrittr,
methods,
quantreg,
Expand Down
10 changes: 9 additions & 1 deletion NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ S3method(print,alist)
S3method(print,arx_class)
S3method(print,arx_fcast)
S3method(print,canned_epipred)
S3method(print,cdc_baseline_fcast)
S3method(print,epi_workflow)
S3method(print,flat_fcast)
S3method(print,flatline)
Expand Down Expand Up @@ -79,6 +80,7 @@ S3method(residuals,flatline)
S3method(run_mold,default_epi_recipe_blueprint)
S3method(slather,layer_add_forecast_date)
S3method(slather,layer_add_target_date)
S3method(slather,layer_cdc_flatline_quantiles)
S3method(slather,layer_naomit)
S3method(slather,layer_point_from_distn)
S3method(slather,layer_population_scaling)
Expand Down Expand Up @@ -106,6 +108,8 @@ export(arx_classifier)
export(arx_fcast_epi_workflow)
export(arx_forecaster)
export(bake)
export(cdc_baseline_args_list)
export(cdc_baseline_forecaster)
export(create_layer)
export(default_epi_recipe_blueprint)
export(detect_layer)
Expand All @@ -131,6 +135,7 @@ export(is_layer)
export(layer)
export(layer_add_forecast_date)
export(layer_add_target_date)
export(layer_cdc_flatline_quantiles)
export(layer_naomit)
export(layer_point_from_distn)
export(layer_population_scaling)
Expand All @@ -143,7 +148,8 @@ export(layer_unnest)
export(nested_quantiles)
export(new_default_epi_recipe_blueprint)
export(new_epi_recipe_blueprint)
export(pivot_quantiles)
export(pivot_quantiles_longer)
export(pivot_quantiles_wider)
export(prep)
export(quantile_reg)
export(remove_frosting)
Expand All @@ -167,6 +173,7 @@ importFrom(generics,augment)
importFrom(generics,fit)
importFrom(hardhat,refresh_blueprint)
importFrom(hardhat,run_mold)
importFrom(lifecycle,deprecated)
importFrom(magrittr,"%>%")
importFrom(methods,is)
importFrom(quantreg,rq)
Expand All @@ -181,6 +188,7 @@ importFrom(rlang,caller_env)
importFrom(rlang,is_empty)
importFrom(rlang,is_null)
importFrom(rlang,quos)
importFrom(smoothqr,smooth_qr)
importFrom(stats,as.formula)
importFrom(stats,family)
importFrom(stats,lm)
Expand Down
3 changes: 2 additions & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
* canned forecasters get a class
* fixed quantile bug in `flatline_forecaster()`
* add functionality to output the unfit workflow from the canned forecasters
* add `pivot_quantiles()` for easier plotting
* add `pivot_quantiles_wider()` for easier plotting
* add complement `pivot_quantiles_longer()`


# epipredict 0.0.4
Expand Down
228 changes: 228 additions & 0 deletions R/cdc_baseline_forecaster.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,228 @@
#' Predict the future with the most recent value
#'
#' This is a simple forecasting model for
#' [epiprocess::epi_df] data. It uses the most recent observation as the
#' forecast for any future date, and produces intervals by shuffling the quantiles
#' of the residuals of such a "flatline" forecast and incrementing these
#' forward over all available training data.
#'
#' By default, the predictive intervals are computed separately for each
#' combination of `geo_value` in the `epi_data` argument.
#'
#' This forecaster is meant to produce exactly the CDC Baseline used for
#' [COVID19ForecastHub](https://covid19forecasthub.org)
#'
#' @param epi_data An [epiprocess::epi_df]
#' @param outcome A scalar character for the column name we wish to predict.
#' @param args_list A list of additional arguments as created by the
#' [cdc_baseline_args_list()] constructor function.
#'
#' @return A data frame of point and interval forecasts at for all
#' aheads (unique horizons) for each unique combination of `key_vars`.
#' @export
#'
#' @examples
#' library(dplyr)
#' weekly_deaths <- case_death_rate_subset %>%
#' select(geo_value, time_value, death_rate) %>%
#' left_join(state_census %>% select(pop, abbr), by = c("geo_value" = "abbr")) %>%
#' mutate(deaths = pmax(death_rate / 1e5 * pop, 0)) %>%
#' select(-pop, -death_rate) %>%
#' group_by(geo_value) %>%
#' epi_slide(~ sum(.$deaths), before = 6, new_col_name = "deaths") %>%
#' ungroup() %>%
#' filter(weekdays(time_value) == "Saturday")
#'
#' cdc <- cdc_baseline_forecaster(weekly_deaths, "deaths")
#' preds <- pivot_quantiles_wider(cdc$predictions, .pred_distn)
#'
#' if (require(ggplot2)) {
#' forecast_date <- unique(preds$forecast_date)
#' four_states <- c("ca", "pa", "wa", "ny")
#' preds %>%
#' filter(geo_value %in% four_states) %>%
#' ggplot(aes(target_date)) +
#' geom_ribbon(aes(ymin = `0.1`, ymax = `0.9`), fill = blues9[3]) +
#' geom_ribbon(aes(ymin = `0.25`, ymax = `0.75`), fill = blues9[6]) +
#' geom_line(aes(y = .pred), color = "orange") +
#' geom_line(
#' data = weekly_deaths %>% filter(geo_value %in% four_states),
#' aes(x = time_value, y = deaths)
#' ) +
#' scale_x_date(limits = c(forecast_date - 90, forecast_date + 30)) +
#' labs(x = "Date", y = "Weekly deaths") +
#' facet_wrap(~geo_value, scales = "free_y") +
#' theme_bw() +
#' geom_vline(xintercept = forecast_date)
#' }
cdc_baseline_forecaster <- function(
epi_data,
outcome,
args_list = cdc_baseline_args_list()) {
validate_forecaster_inputs(epi_data, outcome, "time_value")
if (!inherits(args_list, c("cdc_flat_fcast", "alist"))) {
cli_stop("args_list was not created using `cdc_baseline_args_list().")
}
keys <- epi_keys(epi_data)
ek <- kill_time_value(keys)
outcome <- rlang::sym(outcome)


r <- epi_recipe(epi_data) %>%
step_epi_ahead(!!outcome, ahead = args_list$data_frequency, skip = TRUE) %>%
recipes::update_role(!!outcome, new_role = "predictor") %>%
recipes::add_role(tidyselect::all_of(keys), new_role = "predictor") %>%
step_training_window(n_recent = args_list$n_training)

forecast_date <- args_list$forecast_date %||% max(epi_data$time_value)
# target_date <- args_list$target_date %||% forecast_date + args_list$ahead


latest <- get_test_data(
epi_recipe(epi_data), epi_data, TRUE, args_list$nafill_buffer,
forecast_date
)

f <- frosting() %>%
layer_predict() %>%
layer_cdc_flatline_quantiles(
aheads = args_list$aheads,
quantile_levels = args_list$quantile_levels,
nsims = args_list$nsims,
by_key = args_list$quantile_by_key,
symmetrize = args_list$symmetrize,
nonneg = args_list$nonneg
) %>%
layer_add_forecast_date(forecast_date = forecast_date) %>%
layer_unnest(.pred_distn_all)
# layer_add_target_date(target_date = target_date)
if (args_list$nonneg) f <- layer_threshold(f, ".pred")

eng <- parsnip::linear_reg() %>% parsnip::set_engine("flatline")

wf <- epi_workflow(r, eng, f)
wf <- generics::fit(wf, epi_data)
preds <- suppressWarnings(predict(wf, new_data = latest)) %>%
tibble::as_tibble() %>%
dplyr::select(-time_value) %>%
dplyr::mutate(target_date = forecast_date + ahead * args_list$data_frequency)

structure(
list(
predictions = preds,
epi_workflow = wf,
metadata = list(
training = attr(epi_data, "metadata"),
forecast_created = Sys.time()
)
),
class = c("cdc_baseline_fcast", "canned_epipred")
)
}



#' CDC baseline forecaster argument constructor
#'
#' Constructs a list of arguments for [cdc_baseline_forecaster()].
#'
#' @inheritParams arx_args_list
#' @param data_frequency Integer or string. This describes the frequency of the
#' input `epi_df`. For typical FluSight forecasts, this would be `"1 week"`.
#' Allowable arguments are integers (taken to mean numbers of days) or a
#' string like `"7 days"` or `"2 weeks"`. Currently, all other periods
#' (other than days or weeks) result in an error.
#' @param aheads Integer vector. Unlike [arx_forecaster()], this doesn't have
#' any effect on the predicted values.
#' Predictions are always the most recent observation. This determines the
#' set of prediction horizons for [layer_cdc_flatline_quantiles()]`. It interacts
#' with the `data_frequency` argument. So, for example, if the data is daily
#' and you want forecasts for 1:4 days ahead, then you would use `1:4`. However,
#' if you want one-week predictions, you would set this as `c(7, 14, 21, 28)`.
#' But if `data_frequency` is `"1 week"`, then you would set it as `1:4`.
#' @param quantile_levels Vector or `NULL`. A vector of probabilities to produce
#' prediction intervals. These are created by computing the quantiles of
#' training residuals. A `NULL` value will result in point forecasts only.
#' @param nsims Positive integer. The number of draws from the empirical CDF.
#' These samples are spaced evenly on the (0, 1) scale, F_X(x) resulting
#' in linear interpolation on the X scale. This is achieved with
#' [stats::quantile()] Type 7 (the default for that function).
#' @param nonneg Logical. Force all predictive intervals be non-negative.
#' Because non-negativity is forced _before_ propagating forward, this
#' has slightly different behaviour than would occur if using
#' [layer_threshold()].
#'
#' @return A list containing updated parameter choices with class `cdc_flat_fcast`.
#' @export
#'
#' @examples
#' cdc_baseline_args_list()
#' cdc_baseline_args_list(symmetrize = FALSE)
#' cdc_baseline_args_list(quantile_levels = c(.1, .3, .7, .9), n_training = 120)
cdc_baseline_args_list <- function(
data_frequency = "1 week",
aheads = 1:4,
n_training = Inf,
forecast_date = NULL,
quantile_levels = c(.01, .025, 1:19 / 20, .975, .99),
nsims = 1e3L,
symmetrize = TRUE,
nonneg = TRUE,
quantile_by_key = "geo_value",
nafill_buffer = Inf) {
arg_is_scalar(n_training, nsims, data_frequency)
data_frequency <- parse_period(data_frequency)
arg_is_pos_int(data_frequency)
arg_is_chr(quantile_by_key, allow_empty = TRUE)
arg_is_scalar(forecast_date, allow_null = TRUE)
arg_is_date(forecast_date, allow_null = TRUE)
arg_is_nonneg_int(aheads, nsims)
arg_is_lgl(symmetrize, nonneg)
arg_is_probabilities(quantile_levels, allow_null = TRUE)
arg_is_pos(n_training)
if (is.finite(n_training)) arg_is_pos_int(n_training)
if (is.finite(nafill_buffer)) arg_is_pos_int(nafill_buffer, allow_null = TRUE)

structure(
enlist(
data_frequency,
aheads,
n_training,
forecast_date,
quantile_levels,
nsims,
symmetrize,
nonneg,
quantile_by_key,
nafill_buffer
),
class = c("cdc_baseline_fcast", "alist")
)
}

#' @export
print.cdc_baseline_fcast <- function(x, ...) {
name <- "CDC Baseline"
NextMethod(name = name, ...)
}

parse_period <- function(x) {
arg_is_scalar(x)
if (is.character(x)) {
x <- unlist(strsplit(x, " "))
if (length(x) == 1L) x <- as.numeric(x)
if (length(x) == 2L) {
mult <- substr(x[2], 1, 3)
mult <- switch(
mult,
day = 1L,
wee = 7L,
cli::cli_abort("incompatible timespan in `aheads`.")
)
x <- as.numeric(x[1]) * mult
}
if (length(x) > 2L) cli::cli_abort("incompatible timespan in `aheads`.")
}
stopifnot(rlang::is_integerish(x))
as.integer(x)
}
5 changes: 5 additions & 0 deletions R/compat-purrr.R
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,11 @@ map_chr <- function(.x, .f, ...) {
.rlang_purrr_map_mold(.x, .f, character(1), ...)
}

map_vec <- function(.x, .f, ...) {
out <- map(.x, .f, ...)
vctrs::list_unchop(out)
}

map_dfr <- function(.x, .f, ..., .id = NULL) {
.f <- rlang::as_function(.f, env = rlang::global_env())
res <- map(.x, .f, ...)
Expand Down
Loading

0 comments on commit b2d1e11

Please sign in to comment.