diff --git a/R/arx_forecaster.R b/R/arx_forecaster.R index ad7e5253e..3dc54286e 100644 --- a/R/arx_forecaster.R +++ b/R/arx_forecaster.R @@ -99,7 +99,7 @@ arx_forecaster <- function(epi_data, #' arx_fcast_epi_workflow(jhu, "death_rate", #' c("case_rate", "death_rate"), #' trainer = quantile_reg(), -#' args_list = arx_args_list(quantile_level = 1:9 / 10) +#' args_list = arx_args_list(quantile_levels = 1:9 / 10) #' ) arx_fcast_epi_workflow <- function( epi_data, @@ -135,18 +135,18 @@ arx_fcast_epi_workflow <- function( f <- frosting() %>% layer_predict() # %>% layer_naomit() if (inherits(trainer, "quantile_reg")) { # add all quantile_level to the forecaster and update postprocessor - quantile_level <- sort(compare_quantile_args( - args_list$quantile_level, - rlang::eval_tidy(trainer$args$quantile_level) + quantile_levels <- sort(compare_quantile_args( + args_list$quantile_levels, + rlang::eval_tidy(trainer$args$quantile_levels) )) - args_list$quantile_level <- quantile_level - trainer$args$quantile_level <- rlang::enquo(quantile_level) - f <- layer_quantile_distn(f, quantile_level = quantile_level) %>% + args_list$quantile_levels <- quantile_levels + trainer$args$quantile_levels <- rlang::enquo(quantile_levels) + f <- layer_quantile_distn(f, quantile_levels = quantile_levels) %>% layer_point_from_distn() } else { f <- layer_residual_quantiles( f, - quantile_level = args_list$quantile_level, + quantile_levels = args_list$quantile_levels, symmetrize = args_list$symmetrize, by_key = args_list$quantile_by_key ) @@ -175,7 +175,7 @@ arx_fcast_epi_workflow <- function( #' The default `NULL` will attempt to determine this automatically. #' @param target_date Date. The date for which the forecast is intended. #' The default `NULL` will attempt to determine this automatically. -#' @param quantile_level Vector or `NULL`. A vector of probabilities to produce +#' @param quantile_levels Vector or `NULL`. A vector of probabilities to produce #' prediction intervals. These are created by computing the quantiles of #' training residuals. A `NULL` value will result in point forecasts only. #' @param symmetrize Logical. The default `TRUE` calculates @@ -208,14 +208,14 @@ arx_fcast_epi_workflow <- function( #' @examples #' arx_args_list() #' arx_args_list(symmetrize = FALSE) -#' arx_args_list(quantile_level = c(.1, .3, .7, .9), n_training = 120) +#' arx_args_list(quantile_levels = c(.1, .3, .7, .9), n_training = 120) arx_args_list <- function( lags = c(0L, 7L, 14L), ahead = 7L, n_training = Inf, forecast_date = NULL, target_date = NULL, - quantile_level = c(0.05, 0.95), + quantile_levels = c(0.05, 0.95), symmetrize = TRUE, nonneg = TRUE, quantile_by_key = character(0L), @@ -231,7 +231,7 @@ arx_args_list <- function( arg_is_date(forecast_date, target_date, allow_null = TRUE) arg_is_nonneg_int(ahead, lags) arg_is_lgl(symmetrize, nonneg) - arg_is_probabilities(quantile_level, allow_null = TRUE) + arg_is_probabilities(quantile_levels, allow_null = TRUE) arg_is_pos(n_training) if (is.finite(n_training)) arg_is_pos_int(n_training) if (is.finite(nafill_buffer)) arg_is_pos_int(nafill_buffer, allow_null = TRUE) @@ -242,7 +242,7 @@ arx_args_list <- function( lags = .lags, ahead, n_training, - quantile_level, + quantile_levels, forecast_date, target_date, symmetrize, diff --git a/R/dist_quantiles.R b/R/dist_quantiles.R index e94773f8e..bb8810902 100644 --- a/R/dist_quantiles.R +++ b/R/dist_quantiles.R @@ -16,7 +16,7 @@ new_quantiles <- function(values = double(), quantile_levels = double()) { } new_rcrd(list(values = values, quantile_levels = quantile_levels), - class = c("dist_quantiles", "dist_default") + class = c("dist_quantiles", "dist_default") ) } @@ -30,9 +30,8 @@ vec_ptype_full.dist_quantiles <- function(x, ...) "dist_quantiles" #' @export format.dist_quantiles <- function(x, digits = 2, ...) { - q <- field(x, "values") m <- suppressWarnings(median(x)) - paste0("quantiles(", round(m, digits), ")[", vctrs::vec_size(q), "]") + paste0("quantiles(", round(m, digits), ")[", vctrs::vec_size(x), "]") } @@ -78,11 +77,11 @@ validate_dist_quantiles <- function(values, quantile_levels) { i = "Mismatches found at position(s): {.val {which(length_diff)}}." )) } - tau_duplication <- map_lgl(quantile_levels, vctrs::vec_duplicate_any) - if (any(tau_duplication)) { + level_duplication <- map_lgl(quantile_levels, vctrs::vec_duplicate_any) + if (any(level_duplication)) { cli::cli_abort(c( "`quantile_levels` must not be duplicated.", - i = "Duplicates found at position(s): {.val {which(tau_duplication)}}." + i = "Duplicates found at position(s): {.val {which(level_duplication)}}." )) } } @@ -120,22 +119,25 @@ extrapolate_quantiles <- function(x, probs, ...) { #' @importFrom vctrs vec_data extrapolate_quantiles.distribution <- function(x, probs, ...) { arg_is_probabilities(probs) - dstn <- lapply(vec_data(x), extrapolate_quantiles, p = probs, ...) + dstn <- lapply(vec_data(x), extrapolate_quantiles, probs = probs, ...) new_vctr(dstn, vars = NULL, class = "distribution") } #' @export extrapolate_quantiles.dist_default <- function(x, probs, ...) { - q <- quantile(x, probs, ...) - new_quantiles(values = q, quantile_levels = probs) + values <- quantile(x, probs, ...) + new_quantiles(values = values, quantile_levels = probs) } #' @export extrapolate_quantiles.dist_quantiles <- function(x, probs, ...) { - q <- quantile(x, probs, ...) - tau <- field(x, "quantile_levels") - qvals <- field(x, "values") - new_quantiles(values = c(qvals, q), quantile_levels = c(tau, probs)) + new_values <- quantile(x, probs, ...) + quantile_levels <- field(x, "quantile_levels") + values <- field(x, "values") + new_quantiles( + values = c(values, new_values), + quantile_levels = c(quantile_levels, probs) + ) } is_dist_quantiles <- function(x) { @@ -152,10 +154,10 @@ is_dist_quantiles <- function(x) { #' #' @examples #' edf <- case_death_rate_subset[1:3, ] -#' edf$q <- dist_quantiles(list(1:5, 2:4, 3:10), list(1:5 / 6, 2:4 / 5, 3:10 / 11)) +#' edf$dstn <- dist_quantiles(list(1:5, 2:4, 3:10), list(1:5 / 6, 2:4 / 5, 3:10 / 11)) #' -#' edf_nested <- edf %>% dplyr::mutate(q = nested_quantiles(q)) -#' edf_nested %>% tidyr::unnest(q) +#' edf_nested <- edf %>% dplyr::mutate(dstn = nested_quantiles(dstn)) +#' edf_nested %>% tidyr::unnest(dstn) nested_quantiles <- function(x) { stopifnot(is_dist_quantiles(x)) map( @@ -236,12 +238,16 @@ pivot_quantiles <- function(.data, ...) { #' @export #' @importFrom stats median qnorm family median.dist_quantiles <- function(x, na.rm = FALSE, ..., middle = c("cubic", "linear")) { - tau <- field(x, "quantile_levels") - qvals <- field(x, "values") - if (0.5 %in% tau) return(qvals[match(0.5, tau)]) - if (min(tau) > 0.5 || max(tau) < 0.5 || length(tau) < 2) return(NA) - if (length(tau) < 3 || min(tau) > .25 || max(tau) < .75) { - return(stats::approx(tau, qvals, xout = 0.5)$y) + quantile_levels <- field(x, "quantile_levels") + values <- field(x, "values") + if (0.5 %in% quantile_levels) { + return(values[match(0.5, quantile_levels)]) + } + if (length(quantile_levels) < 2 || min(quantile_levels) > 0.5 || max(quantile_levels) < 0.5) { + return(NA) + } + if (length(quantile_levels) < 3 || min(quantile_levels) > .25 || max(quantile_levels) < .75) { + return(stats::approx(quantile_levels, values, xout = 0.5)$y) } quantile(x, 0.5, ..., middle = middle) } @@ -256,15 +262,15 @@ mean.dist_quantiles <- function(x, na.rm = FALSE, ..., middle = c("cubic", "line #' @importFrom stats quantile #' @import distributional quantile.dist_quantiles <- function( - x, probs, ..., + x, p, ..., middle = c("cubic", "linear"), left_tail = c("normal", "exponential"), right_tail = c("normal", "exponential")) { - arg_is_probabilities(probs) + arg_is_probabilities(p) middle <- match.arg(middle) left_tail <- match.arg(left_tail) right_tail <- match.arg(right_tail) - quantile_extrapolate(x, probs, middle, left_tail, right_tail) + quantile_extrapolate(x, p, middle, left_tail, right_tail) } diff --git a/R/layer_quantile_distn.R b/R/layer_quantile_distn.R index 6c848231d..a99eed326 100644 --- a/R/layer_quantile_distn.R +++ b/R/layer_quantile_distn.R @@ -77,8 +77,8 @@ slather.layer_quantile_distn <- dstn <- components$predictions$.pred if (!inherits(dstn, "distribution")) { cli_abort(c( - "`layer_quantile_distn()` requires distributional predictions.", - "These are of class {.cls {class(dstn)}}." + "`layer_quantile_distn()` requires distributional predictions.", + "These are of class {.cls {class(dstn)}}." )) } dstn <- dist_quantiles( diff --git a/R/layer_residual_quantiles.R b/R/layer_residual_quantiles.R index 2e7639853..932f73246 100644 --- a/R/layer_residual_quantiles.R +++ b/R/layer_residual_quantiles.R @@ -116,7 +116,7 @@ slather.layer_residual_quantiles <- r <- r %>% dplyr::summarize( - q = list(quantile( + dstn = list(quantile( c(.resid, s * .resid), probs = object$quantile_levels, na.rm = TRUE )) @@ -124,7 +124,7 @@ slather.layer_residual_quantiles <- estimate <- components$predictions$.pred res <- tibble::tibble( - .pred_distn = dist_quantiles(map2(estimate, r$q, "+"), object$quantile_levels) + .pred_distn = dist_quantiles(map2(estimate, r$dstn, "+"), object$quantile_levels) ) res <- check_pname(res, components$predictions, object) components$predictions <- dplyr::mutate(components$predictions, !!!res) diff --git a/R/step_growth_rate.R b/R/step_growth_rate.R index f6ad29a5b..74cfff284 100644 --- a/R/step_growth_rate.R +++ b/R/step_growth_rate.R @@ -42,20 +42,19 @@ #' recipes::prep() %>% #' recipes::bake(case_death_rate_subset) step_growth_rate <- - function( - recipe, - ..., - role = "predictor", - trained = FALSE, - horizon = 7, - method = c("rel_change", "linear_reg", "smooth_spline", "trend_filter"), - log_scale = FALSE, - replace_Inf = NA, - prefix = "gr_", - columns = NULL, - skip = FALSE, - id = rand_id("growth_rate"), - additional_gr_args_list = list()) { + function(recipe, + ..., + role = "predictor", + trained = FALSE, + horizon = 7, + method = c("rel_change", "linear_reg", "smooth_spline", "trend_filter"), + log_scale = FALSE, + replace_Inf = NA, + prefix = "gr_", + columns = NULL, + skip = FALSE, + id = rand_id("growth_rate"), + additional_gr_args_list = list()) { if (!is_epi_recipe(recipe)) { rlang::abort("This recipe step can only operate on an `epi_recipe`.") } diff --git a/R/step_lag_difference.R b/R/step_lag_difference.R index 2482be46a..21878eaa7 100644 --- a/R/step_lag_difference.R +++ b/R/step_lag_difference.R @@ -23,16 +23,15 @@ #' recipes::prep() %>% #' recipes::bake(case_death_rate_subset) step_lag_difference <- - function( - recipe, - ..., - role = "predictor", - trained = FALSE, - horizon = 7, - prefix = "lag_diff_", - columns = NULL, - skip = FALSE, - id = rand_id("lag_diff")) { + function(recipe, + ..., + role = "predictor", + trained = FALSE, + horizon = 7, + prefix = "lag_diff_", + columns = NULL, + skip = FALSE, + id = rand_id("lag_diff")) { if (!is_epi_recipe(recipe)) { rlang::abort("This recipe step can only operate on an `epi_recipe`.") } diff --git a/man/arx_args_list.Rd b/man/arx_args_list.Rd index b4aad6a12..e5d2391c8 100644 --- a/man/arx_args_list.Rd +++ b/man/arx_args_list.Rd @@ -10,7 +10,7 @@ arx_args_list( n_training = Inf, forecast_date = NULL, target_date = NULL, - quantile_level = c(0.05, 0.95), + quantile_levels = c(0.05, 0.95), symmetrize = TRUE, nonneg = TRUE, quantile_by_key = character(0L), @@ -36,7 +36,7 @@ The default \code{NULL} will attempt to determine this automatically.} \item{target_date}{Date. The date for which the forecast is intended. The default \code{NULL} will attempt to determine this automatically.} -\item{quantile_level}{Vector or \code{NULL}. A vector of probabilities to produce +\item{quantile_levels}{Vector or \code{NULL}. A vector of probabilities to produce prediction intervals. These are created by computing the quantiles of training residuals. A \code{NULL} value will result in point forecasts only.} @@ -76,5 +76,5 @@ Constructs a list of arguments for \code{\link[=arx_forecaster]{arx_forecaster() \examples{ arx_args_list() arx_args_list(symmetrize = FALSE) -arx_args_list(quantile_level = c(.1, .3, .7, .9), n_training = 120) +arx_args_list(quantile_levels = c(.1, .3, .7, .9), n_training = 120) } diff --git a/man/arx_fcast_epi_workflow.Rd b/man/arx_fcast_epi_workflow.Rd index 1c7aac02e..8c76bcdd7 100644 --- a/man/arx_fcast_epi_workflow.Rd +++ b/man/arx_fcast_epi_workflow.Rd @@ -49,7 +49,7 @@ arx_fcast_epi_workflow( arx_fcast_epi_workflow(jhu, "death_rate", c("case_rate", "death_rate"), trainer = quantile_reg(), - args_list = arx_args_list(quantile_level = 1:9 / 10) + args_list = arx_args_list(quantile_levels = 1:9 / 10) ) } \seealso{ diff --git a/man/flatline_args_list.Rd b/man/flatline_args_list.Rd index 669cb7a9f..c5a5d9885 100644 --- a/man/flatline_args_list.Rd +++ b/man/flatline_args_list.Rd @@ -31,6 +31,10 @@ The default \code{NULL} will attempt to determine this automatically.} \item{target_date}{Date. The date for which the forecast is intended. The default \code{NULL} will attempt to determine this automatically.} +\item{quantile_levels}{Vector or \code{NULL}. A vector of probabilities to produce +prediction intervals. These are created by computing the quantiles of +training residuals. A \code{NULL} value will result in point forecasts only.} + \item{symmetrize}{Logical. The default \code{TRUE} calculates symmetric prediction intervals. This argument only applies when residual quantiles are used. It is not applicable with diff --git a/man/nested_quantiles.Rd b/man/nested_quantiles.Rd index c4b578c1a..b1a67cffe 100644 --- a/man/nested_quantiles.Rd +++ b/man/nested_quantiles.Rd @@ -17,8 +17,8 @@ Turn a vector of quantile distributions into a list-col } \examples{ edf <- case_death_rate_subset[1:3, ] -edf$q <- dist_quantiles(list(1:5, 2:4, 3:10), list(1:5 / 6, 2:4 / 5, 3:10 / 11)) +edf$dstn <- dist_quantiles(list(1:5, 2:4, 3:10), list(1:5 / 6, 2:4 / 5, 3:10 / 11)) -edf_nested <- edf \%>\% dplyr::mutate(q = nested_quantiles(q)) -edf_nested \%>\% tidyr::unnest(q) +edf_nested <- edf \%>\% dplyr::mutate(dstn = nested_quantiles(dstn)) +edf_nested \%>\% tidyr::unnest(dstn) }