Skip to content

Commit

Permalink
Merge pull request #245 from cmu-delphi/cdc-baseline
Browse files Browse the repository at this point in the history
Cdc baseline
  • Loading branch information
dajmcdon authored Oct 6, 2023
2 parents dacf6e7 + bdbd3ee commit 3ef79b9
Show file tree
Hide file tree
Showing 8 changed files with 248 additions and 31 deletions.
3 changes: 3 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ S3method(extrapolate_quantiles,dist_default)
S3method(extrapolate_quantiles,dist_quantiles)
S3method(extrapolate_quantiles,distribution)
S3method(fit,epi_workflow)
S3method(flusight_hub_formatter,canned_epipred)
S3method(flusight_hub_formatter,data.frame)
S3method(format,dist_quantiles)
S3method(is.na,dist_quantiles)
S3method(is.na,distribution)
Expand Down Expand Up @@ -126,6 +128,7 @@ export(fit)
export(flatline)
export(flatline_args_list)
export(flatline_forecaster)
export(flusight_hub_formatter)
export(frosting)
export(get_test_data)
export(grab_names)
Expand Down
8 changes: 4 additions & 4 deletions R/cdc_baseline_forecaster.R
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,21 @@
#' This forecaster is meant to produce exactly the CDC Baseline used for
#' [COVID19ForecastHub](https://covid19forecasthub.org)
#'
#' @param epi_data An [epiprocess::epi_df]
#' @param epi_data An [`epiprocess::epi_df`]
#' @param outcome A scalar character for the column name we wish to predict.
#' @param args_list A list of additional arguments as created by the
#' [cdc_baseline_args_list()] constructor function.
#'
#' @return A data frame of point and interval forecasts at for all
#' aheads (unique horizons) for each unique combination of `key_vars`.
#' @return A data frame of point and interval forecasts for all aheads (unique
#' horizons) for each unique combination of `key_vars`.
#' @export
#'
#' @examples
#' library(dplyr)
#' weekly_deaths <- case_death_rate_subset %>%
#' select(geo_value, time_value, death_rate) %>%
#' left_join(state_census %>% select(pop, abbr), by = c("geo_value" = "abbr")) %>%
#' mutate(deaths = pmax(death_rate / 1e5 * pop, 0)) %>%
#' mutate(deaths = pmax(death_rate / 1e5 * pop * 7, 0)) %>%
#' select(-pop, -death_rate) %>%
#' group_by(geo_value) %>%
#' epi_slide(~ sum(.$deaths), before = 6, new_col_name = "deaths") %>%
Expand Down
131 changes: 131 additions & 0 deletions R/flusight_hub_formatter.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
# Translate state/territory postal abbreviations into FIPS code strings.
#
# `abbr` is lower-cased and joined against the `abbr` column of
# `state_census`. The joined `fips` values are converted to character, with
# "0" reported as "US" and one-character codes left-padded with a zero.
# Returns an unnamed character vector.
abbr_to_fips <- function(abbr) {
  lookup <- tibble::tibble(abbr = tolower(abbr))
  matched <- dplyr::left_join(lookup, state_census, by = "abbr")

  codes <- as.character(dplyr::pull(matched, "fips"))
  codes <- dplyr::case_when(
    codes == "0" ~ "US",
    nchar(codes) < 2L ~ paste0("0", codes),
    TRUE ~ codes
  )

  unname(codes)
}

#' Format predictions for submission to FluSight forecast Hub
#'
#' This function converts predictions from any of the included forecasters into
#' a format (nearly) ready for submission to the 2023-24
#' [FluSight-forecast-hub](https://github.com/cdcepi/FluSight-forecast-hub).
#' See there for documentation of the required columns. Currently, only
#' "quantile" forecasts are supported, but the intention is to support both
#' "quantile" and "pmf". For this reason, adding the `output_type` column should
#' be done via the `...` argument. See the examples below. The specific required
#' format for this forecast task is [here](https://github.com/cdcepi/FluSight-forecast-hub/blob/main/model-output/README.md).
#'
#' @param object a data.frame of predictions or an object of class
#'   `canned_epipred` as created by, e.g., [arx_forecaster()]
#' @param ... <[`dynamic-dots`][rlang::dyn-dots]> Name = value pairs of constant
#'   columns (or mutations) to perform to the results. See examples.
#' @param .fcast_period Control whether the `horizon` should represent days or
#'   weeks; one of `"daily"` (the default) or `"weekly"`. Depending on whether
#'   the forecaster output has target dates from [layer_add_target_date()] or
#'   not, we may need to compute the `horizon` and/or the `target_end_date`
#'   from the other available columns in the predictions.
#'   * When both `ahead` and `target_date` are available, this argument is
#'     ignored.
#'   * If only an `ahead` column exists, `target_end_date` is computed as
#'     `forecast_date + ahead` for `"daily"` and as `forecast_date + 7 * ahead`
#'     for `"weekly"`.
#'   * If only `target_date` is available, `horizon` is computed as
#'     `target_date - forecast_date` (converted to an integer), and is divided
#'     by 7 when this argument is `"weekly"`.
#'
#'   Note that these can be adjusted later by the `...` argument.
#'
#' @return A [tibble::tibble]. If `...` is empty, the result will contain the
#'   columns `reference_date`, `horizon`, `target_end_date`, `location`,
#'   `output_type_id`, and `value`, placed ahead of any other columns that
#'   remain from the input. The point prediction (`.pred`) is included as an
#'   additional row whose `output_type_id` is `NA`. The `...` can perform
#'   mutations on any of these.
#' @export
#'
#' @examples
#' library(dplyr)
#' weekly_deaths <- case_death_rate_subset %>%
#'   select(geo_value, time_value, death_rate) %>%
#'   left_join(state_census %>% select(pop, abbr), by = c("geo_value" = "abbr")) %>%
#'   mutate(deaths = pmax(death_rate / 1e5 * pop * 7, 0)) %>%
#'   select(-pop, -death_rate) %>%
#'   group_by(geo_value) %>%
#'   epi_slide(~ sum(.$deaths), before = 6, new_col_name = "deaths") %>%
#'   ungroup() %>%
#'   filter(weekdays(time_value) == "Saturday")
#'
#' cdc <- cdc_baseline_forecaster(weekly_deaths, "deaths")
#' flusight_hub_formatter(cdc)
#' flusight_hub_formatter(cdc, target = "wk inc covid deaths")
#' flusight_hub_formatter(cdc, target = paste(horizon, "wk inc covid deaths"))
#' flusight_hub_formatter(cdc, target = "wk inc covid deaths", output_type = "quantile")
flusight_hub_formatter <- function(
object, ...,
.fcast_period = c("daily", "weekly")) {
UseMethod("flusight_hub_formatter")
}

#' @export
flusight_hub_formatter.canned_epipred <- function(
    object, ...,
    .fcast_period = c("daily", "weekly")) {
  # Pull out the `predictions` element and re-dispatch the generic on it,
  # forwarding every other argument untouched.
  predictions <- object$predictions
  flusight_hub_formatter(predictions, ..., .fcast_period = .fcast_period)
}

#' @export
flusight_hub_formatter.data.frame <- function(
    object, ...,
    .fcast_period = c("daily", "weekly")) {
  # Reshape a data frame of predictions into the FluSight hub column layout.
  #
  # `object` must contain `.pred`, `.pred_distn`, `forecast_date` and
  # `geo_value`, plus a horizon column (`ahead`, or a name uniquely starting
  # with "ahead" such as `aheads`) and/or `target_date`. `...` holds mutations
  # applied to the final result. `.fcast_period` selects whether one unit of
  # horizon is 1 day ("daily") or 7 days ("weekly").
  required_names <- c(".pred", ".pred_distn", "forecast_date", "geo_value")
  optional_names <- c("ahead", "target_date")
  hardhat::validate_column_names(object, required_names)

  # Locate the horizon column once, up front, so that validation and the later
  # renaming agree. `charmatch()` returns NA when no column name matches and 0
  # when several names match "ahead" ambiguously.
  ahead_idx <- charmatch("ahead", names(object))
  if (identical(ahead_idx, 0L)) {
    cli::cli_abort(
      "More than one column name begins with {.val ahead}; only one is allowed."
    )
  }
  has_ahead <- !is.na(ahead_idx)
  has_target_date <- "target_date" %in% names(object)
  if (!has_ahead && !has_target_date) {
    cli::cli_abort("At least one of {.val {optional_names}} must be present.")
  }
  ahead_col <- if (has_ahead) names(object)[ahead_idx] else NULL

  # Number of days represented by one unit of `horizon`.
  days_per_period <- if (match.arg(.fcast_period) == "daily") 1L else 7L

  # Captured now, applied as the very last step.
  dots <- enquos(..., .named = TRUE)

  object <- object %>%
    # combine the predictions and the distribution: the point prediction
    # becomes one extra row per forecast with a missing quantile level
    dplyr::mutate(.pred_distn = nested_quantiles(.pred_distn)) %>%
    dplyr::rowwise() %>%
    dplyr::mutate(
      .pred_distn = list(add_row(.pred_distn, q = .pred, tau = NA)),
      .pred = NULL
    ) %>%
    tidyr::unnest(.pred_distn) %>%
    # now we create the correct column names
    dplyr::rename(
      value = q,
      output_type_id = tau,
      reference_date = forecast_date
    ) %>%
    # convert to fips codes (abbr_to_fips() lower-cases its input itself)
    dplyr::mutate(location = abbr_to_fips(geo_value), geo_value = NULL)

  # create target_end_date / horizon, depending on what is available
  if (has_ahead && has_target_date) {
    # both supplied: only rename, `.fcast_period` is not needed
    object <- object %>%
      dplyr::rename(target_end_date = target_date, horizon = !!ahead_col)
  } else if (has_ahead) {
    # ahead present, not target date
    object <- object %>%
      dplyr::rename(horizon = !!ahead_col) %>%
      dplyr::mutate(
        target_end_date = horizon * days_per_period + reference_date
      )
  } else {
    # target_date present, not ahead
    object <- object %>%
      dplyr::rename(target_end_date = target_date) %>%
      dplyr::mutate(
        horizon = as.integer(target_end_date - reference_date) / days_per_period
      )
  }

  object %>%
    dplyr::relocate(
      reference_date, horizon, target_end_date, location, output_type_id, value
    ) %>%
    # add any constant columns / mutations passed in `...`
    dplyr::mutate(!!!dots)
}
38 changes: 26 additions & 12 deletions R/layer_cdc_flatline_quantiles.R
Original file line number Diff line number Diff line change
Expand Up @@ -22,20 +22,30 @@
#' @inheritParams layer_residual_quantiles
#' @param aheads Numeric vector of desired forecast horizons. These should be
#' given in the "units of the training data". So, for example, for data
#' typically observed daily (possibly with missing values), but
#' with weekly forecast targets, you would use `c(7, 14, 21, 28)`. But with
#' weekly data, you would use `1:4`.
#' typically observed daily (possibly with missing values), but with weekly
#' forecast targets, you would use `c(7, 14, 21, 28)`. But with weekly data,
#' you would use `1:4`.
#' @param quantile_levels Numeric vector of probabilities with values in (0,1)
#' referring to the desired predictive intervals. The default is the standard
#' set for the COVID Forecast Hub.
#' @param nsims Positive integer. The number of draws from the empirical CDF.
#' These samples are spaced evenly on the (0, 1) scale, F_X(x) resulting
#' in linear interpolation on the X scale. This is achieved with
#' These samples are spaced evenly on the (0, 1) scale, F_X(x) resulting in
#' linear interpolation on the X scale. This is achieved with
#' [stats::quantile()] Type 7 (the default for that function).
#' @param nonneg Logical. Force all predictive intervals be non-negative.
#' Because non-negativity is forced _before_ propagating forward, this
#' has slightly different behaviour than would occur if using
#' [layer_threshold()].
#' @param symmetrize Scalar logical. If `TRUE`, does two things: (i) forces the
#' "empirical" CDF of residuals to be symmetric by pretending that for every
#' actually-observed residual X we also observed another residual -X, and (ii)
#' at each ahead, forces the median simulated value to be equal to the point
#' prediction by adding or subtracting the same amount to every simulated
#' value. Adjustments in (ii) take place before propagating forward and
#' simulating the next ahead. This forces any 1-ahead predictive intervals to
#' be symmetric about the point prediction, and encourages larger aheads to be
#' more symmetric.
#' @param nonneg Scalar logical. Force all predictive intervals to be
#' non-negative. Because non-negativity is forced _before_ propagating
#' forward, this has slightly different behaviour than would occur if using
#' [layer_threshold()]. Thresholding at each ahead takes place after any
#' shifting from `symmetrize`.
#'
#' @return an updated `frosting` postprocessor. Calling [predict()] will result
#' in an additional `<list-col>` named `.pred_distn_all` containing 2-column
Expand Down Expand Up @@ -213,7 +223,7 @@ slather.layer_cdc_flatline_quantiles <-
res <- dplyr::left_join(p, r, by = avail_grps) %>%
dplyr::rowwise() %>%
dplyr::mutate(
.pred_distn_all = propogate_samples(
.pred_distn_all = propagate_samples(
.resid, .pred, object$quantile_levels,
object$aheads, object$nsim, object$symmetrize, object$nonneg
)
Expand All @@ -229,10 +239,14 @@ slather.layer_cdc_flatline_quantiles <-
components
}

propogate_samples <- function(
propagate_samples <- function(
r, p, quantile_levels, aheads, nsim, symmetrize, nonneg) {
max_ahead <- max(aheads)
samp <- quantile(r, probs = c(0, seq_len(nsim - 1)) / (nsim - 1), na.rm = TRUE)
if (symmetrize) {
r <- c(r, -r)
}
samp <- quantile(r, probs = c(0, seq_len(nsim - 1)) / (nsim - 1),
na.rm = TRUE, names = FALSE)
res <- list()

raw <- samp + p
Expand Down
8 changes: 4 additions & 4 deletions man/cdc_baseline_forecaster.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

60 changes: 60 additions & 0 deletions man/flusight_hub_formatter.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

29 changes: 19 additions & 10 deletions man/layer_cdc_flatline_quantiles.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
test_that("propogate_samples", {
test_that("propagate_samples", {
r <- -30:50
p <- 40
quantiles <- 1:9 / 10
Expand Down

0 comments on commit 3ef79b9

Please sign in to comment.