Skip to content

Commit

Permalink
redocument, run styler
Browse files Browse the repository at this point in the history
  • Loading branch information
dajmcdon committed Oct 2, 2023
1 parent cdfd0a8 commit 8463a42
Show file tree
Hide file tree
Showing 10 changed files with 81 additions and 73 deletions.
26 changes: 13 additions & 13 deletions R/arx_forecaster.R
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ arx_forecaster <- function(epi_data,
#' arx_fcast_epi_workflow(jhu, "death_rate",
#' c("case_rate", "death_rate"),
#' trainer = quantile_reg(),
#' args_list = arx_args_list(quantile_level = 1:9 / 10)
#' args_list = arx_args_list(quantile_levels = 1:9 / 10)
#' )
arx_fcast_epi_workflow <- function(
epi_data,
Expand Down Expand Up @@ -135,18 +135,18 @@ arx_fcast_epi_workflow <- function(
f <- frosting() %>% layer_predict() # %>% layer_naomit()
if (inherits(trainer, "quantile_reg")) {
# add all quantile_level to the forecaster and update postprocessor
quantile_level <- sort(compare_quantile_args(
args_list$quantile_level,
rlang::eval_tidy(trainer$args$quantile_level)
quantile_levels <- sort(compare_quantile_args(
args_list$quantile_levels,
rlang::eval_tidy(trainer$args$quantile_levels)
))
args_list$quantile_level <- quantile_level
trainer$args$quantile_level <- rlang::enquo(quantile_level)
f <- layer_quantile_distn(f, quantile_level = quantile_level) %>%
args_list$quantile_levels <- quantile_levels
trainer$args$quantile_levels <- rlang::enquo(quantile_levels)
f <- layer_quantile_distn(f, quantile_levels = quantile_levels) %>%
layer_point_from_distn()
} else {
f <- layer_residual_quantiles(
f,
quantile_level = args_list$quantile_level,
quantile_levels = args_list$quantile_levels,
symmetrize = args_list$symmetrize,
by_key = args_list$quantile_by_key
)
Expand Down Expand Up @@ -175,7 +175,7 @@ arx_fcast_epi_workflow <- function(
#' The default `NULL` will attempt to determine this automatically.
#' @param target_date Date. The date for which the forecast is intended.
#' The default `NULL` will attempt to determine this automatically.
#' @param quantile_level Vector or `NULL`. A vector of probabilities to produce
#' @param quantile_levels Vector or `NULL`. A vector of probabilities to produce
#' prediction intervals. These are created by computing the quantiles of
#' training residuals. A `NULL` value will result in point forecasts only.
#' @param symmetrize Logical. The default `TRUE` calculates
Expand Down Expand Up @@ -208,14 +208,14 @@ arx_fcast_epi_workflow <- function(
#' @examples
#' arx_args_list()
#' arx_args_list(symmetrize = FALSE)
#' arx_args_list(quantile_level = c(.1, .3, .7, .9), n_training = 120)
#' arx_args_list(quantile_levels = c(.1, .3, .7, .9), n_training = 120)
arx_args_list <- function(
lags = c(0L, 7L, 14L),
ahead = 7L,
n_training = Inf,
forecast_date = NULL,
target_date = NULL,
quantile_level = c(0.05, 0.95),
quantile_levels = c(0.05, 0.95),
symmetrize = TRUE,
nonneg = TRUE,
quantile_by_key = character(0L),
Expand All @@ -231,7 +231,7 @@ arx_args_list <- function(
arg_is_date(forecast_date, target_date, allow_null = TRUE)
arg_is_nonneg_int(ahead, lags)
arg_is_lgl(symmetrize, nonneg)
arg_is_probabilities(quantile_level, allow_null = TRUE)
arg_is_probabilities(quantile_levels, allow_null = TRUE)
arg_is_pos(n_training)
if (is.finite(n_training)) arg_is_pos_int(n_training)
if (is.finite(nafill_buffer)) arg_is_pos_int(nafill_buffer, allow_null = TRUE)
Expand All @@ -242,7 +242,7 @@ arx_args_list <- function(
lags = .lags,
ahead,
n_training,
quantile_level,
quantile_levels,
forecast_date,
target_date,
symmetrize,
Expand Down
56 changes: 31 additions & 25 deletions R/dist_quantiles.R
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ new_quantiles <- function(values = double(), quantile_levels = double()) {
}

new_rcrd(list(values = values, quantile_levels = quantile_levels),
class = c("dist_quantiles", "dist_default")
class = c("dist_quantiles", "dist_default")
)
}

Expand All @@ -30,9 +30,8 @@ vec_ptype_full.dist_quantiles <- function(x, ...) "dist_quantiles"

#' @export
format.dist_quantiles <- function(x, digits = 2, ...) {
q <- field(x, "values")
m <- suppressWarnings(median(x))
paste0("quantiles(", round(m, digits), ")[", vctrs::vec_size(q), "]")
paste0("quantiles(", round(m, digits), ")[", vctrs::vec_size(x), "]")
}


Expand Down Expand Up @@ -78,11 +77,11 @@ validate_dist_quantiles <- function(values, quantile_levels) {
i = "Mismatches found at position(s): {.val {which(length_diff)}}."
))
}
tau_duplication <- map_lgl(quantile_levels, vctrs::vec_duplicate_any)
if (any(tau_duplication)) {
level_duplication <- map_lgl(quantile_levels, vctrs::vec_duplicate_any)
if (any(level_duplication)) {
cli::cli_abort(c(
"`quantile_levels` must not be duplicated.",
i = "Duplicates found at position(s): {.val {which(tau_duplication)}}."
i = "Duplicates found at position(s): {.val {which(level_duplication)}}."
))
}
}
Expand Down Expand Up @@ -120,22 +119,25 @@ extrapolate_quantiles <- function(x, probs, ...) {
#' @importFrom vctrs vec_data
extrapolate_quantiles.distribution <- function(x, probs, ...) {
arg_is_probabilities(probs)
dstn <- lapply(vec_data(x), extrapolate_quantiles, p = probs, ...)
dstn <- lapply(vec_data(x), extrapolate_quantiles, probs = probs, ...)
new_vctr(dstn, vars = NULL, class = "distribution")
}

#' @export
extrapolate_quantiles.dist_default <- function(x, probs, ...) {
q <- quantile(x, probs, ...)
new_quantiles(values = q, quantile_levels = probs)
values <- quantile(x, probs, ...)
new_quantiles(values = values, quantile_levels = probs)
}

#' @export
extrapolate_quantiles.dist_quantiles <- function(x, probs, ...) {
q <- quantile(x, probs, ...)
tau <- field(x, "quantile_levels")
qvals <- field(x, "values")
new_quantiles(values = c(qvals, q), quantile_levels = c(tau, probs))
new_values <- quantile(x, probs, ...)
quantile_levels <- field(x, "quantile_levels")
values <- field(x, "values")
new_quantiles(
values = c(values, new_values),
quantile_levels = c(quantile_levels, probs)
)
}

is_dist_quantiles <- function(x) {
Expand All @@ -152,10 +154,10 @@ is_dist_quantiles <- function(x) {
#'
#' @examples
#' edf <- case_death_rate_subset[1:3, ]
#' edf$q <- dist_quantiles(list(1:5, 2:4, 3:10), list(1:5 / 6, 2:4 / 5, 3:10 / 11))
#' edf$dstn <- dist_quantiles(list(1:5, 2:4, 3:10), list(1:5 / 6, 2:4 / 5, 3:10 / 11))
#'
#' edf_nested <- edf %>% dplyr::mutate(q = nested_quantiles(q))
#' edf_nested %>% tidyr::unnest(q)
#' edf_nested <- edf %>% dplyr::mutate(dstn = nested_quantiles(dstn))
#' edf_nested %>% tidyr::unnest(dstn)
nested_quantiles <- function(x) {
stopifnot(is_dist_quantiles(x))
map(
Expand Down Expand Up @@ -236,12 +238,16 @@ pivot_quantiles <- function(.data, ...) {
#' @export
#' @importFrom stats median qnorm family
median.dist_quantiles <- function(x, na.rm = FALSE, ..., middle = c("cubic", "linear")) {
tau <- field(x, "quantile_levels")
qvals <- field(x, "values")
if (0.5 %in% tau) return(qvals[match(0.5, tau)])
if (min(tau) > 0.5 || max(tau) < 0.5 || length(tau) < 2) return(NA)
if (length(tau) < 3 || min(tau) > .25 || max(tau) < .75) {
return(stats::approx(tau, qvals, xout = 0.5)$y)
quantile_levels <- field(x, "quantile_levels")
values <- field(x, "values")
if (0.5 %in% quantile_levels) {
return(values[match(0.5, quantile_levels)])
}
if (length(quantile_levels) < 2 || min(quantile_levels) > 0.5 || max(quantile_levels) < 0.5) {
return(NA)
}
if (length(quantile_levels) < 3 || min(quantile_levels) > .25 || max(quantile_levels) < .75) {
return(stats::approx(quantile_levels, values, xout = 0.5)$y)
}
quantile(x, 0.5, ..., middle = middle)
}
Expand All @@ -256,15 +262,15 @@ mean.dist_quantiles <- function(x, na.rm = FALSE, ..., middle = c("cubic", "line
#' @importFrom stats quantile
#' @import distributional
quantile.dist_quantiles <- function(
x, probs, ...,
x, p, ...,
middle = c("cubic", "linear"),
left_tail = c("normal", "exponential"),
right_tail = c("normal", "exponential")) {
arg_is_probabilities(probs)
arg_is_probabilities(p)
middle <- match.arg(middle)
left_tail <- match.arg(left_tail)
right_tail <- match.arg(right_tail)
quantile_extrapolate(x, probs, middle, left_tail, right_tail)
quantile_extrapolate(x, p, middle, left_tail, right_tail)
}


Expand Down
4 changes: 2 additions & 2 deletions R/layer_quantile_distn.R
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,8 @@ slather.layer_quantile_distn <-
dstn <- components$predictions$.pred
if (!inherits(dstn, "distribution")) {
cli_abort(c(
"`layer_quantile_distn()` requires distributional predictions.",
"These are of class {.cls {class(dstn)}}."
"`layer_quantile_distn()` requires distributional predictions.",
"These are of class {.cls {class(dstn)}}."
))
}
dstn <- dist_quantiles(
Expand Down
4 changes: 2 additions & 2 deletions R/layer_residual_quantiles.R
Original file line number Diff line number Diff line change
Expand Up @@ -116,15 +116,15 @@ slather.layer_residual_quantiles <-

r <- r %>%
dplyr::summarize(
q = list(quantile(
dstn = list(quantile(
c(.resid, s * .resid),
probs = object$quantile_levels, na.rm = TRUE
))
)

estimate <- components$predictions$.pred
res <- tibble::tibble(
.pred_distn = dist_quantiles(map2(estimate, r$q, "+"), object$quantile_levels)
.pred_distn = dist_quantiles(map2(estimate, r$dstn, "+"), object$quantile_levels)
)
res <- check_pname(res, components$predictions, object)
components$predictions <- dplyr::mutate(components$predictions, !!!res)
Expand Down
27 changes: 13 additions & 14 deletions R/step_growth_rate.R
Original file line number Diff line number Diff line change
Expand Up @@ -42,20 +42,19 @@
#' recipes::prep() %>%
#' recipes::bake(case_death_rate_subset)
step_growth_rate <-
function(
recipe,
...,
role = "predictor",
trained = FALSE,
horizon = 7,
method = c("rel_change", "linear_reg", "smooth_spline", "trend_filter"),
log_scale = FALSE,
replace_Inf = NA,
prefix = "gr_",
columns = NULL,
skip = FALSE,
id = rand_id("growth_rate"),
additional_gr_args_list = list()) {
function(recipe,
...,
role = "predictor",
trained = FALSE,
horizon = 7,
method = c("rel_change", "linear_reg", "smooth_spline", "trend_filter"),
log_scale = FALSE,
replace_Inf = NA,
prefix = "gr_",
columns = NULL,
skip = FALSE,
id = rand_id("growth_rate"),
additional_gr_args_list = list()) {
if (!is_epi_recipe(recipe)) {
rlang::abort("This recipe step can only operate on an `epi_recipe`.")
}
Expand Down
19 changes: 9 additions & 10 deletions R/step_lag_difference.R
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,15 @@
#' recipes::prep() %>%
#' recipes::bake(case_death_rate_subset)
step_lag_difference <-
function(
recipe,
...,
role = "predictor",
trained = FALSE,
horizon = 7,
prefix = "lag_diff_",
columns = NULL,
skip = FALSE,
id = rand_id("lag_diff")) {
function(recipe,
...,
role = "predictor",
trained = FALSE,
horizon = 7,
prefix = "lag_diff_",
columns = NULL,
skip = FALSE,
id = rand_id("lag_diff")) {
if (!is_epi_recipe(recipe)) {
rlang::abort("This recipe step can only operate on an `epi_recipe`.")
}
Expand Down
6 changes: 3 additions & 3 deletions man/arx_args_list.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/arx_fcast_epi_workflow.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 4 additions & 0 deletions man/flatline_args_list.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 3 additions & 3 deletions man/nested_quantiles.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 8463a42

Please sign in to comment.