Skip to content

Commit

Permalink
docs+fix: point_and_distn
Browse files Browse the repository at this point in the history
  • Loading branch information
dsweber2 committed Sep 14, 2023
1 parent ef4ea0a commit 6894071
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 10 deletions.
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ export(layer)
export(layer_add_forecast_date)
export(layer_add_target_date)
export(layer_naomit)
export(layer_point_and_distn)
export(layer_point_from_distn)
export(layer_population_scaling)
export(layer_predict)
Expand Down
45 changes: 35 additions & 10 deletions R/layer_point_and_distn.R
Original file line number Diff line number Diff line change
@@ -1,7 +1,31 @@
#' returns both the point estimate and the quantile distribution, regardless of the underlying trainer
#' f
#' returns both the point estimate and the quantile distribution
#' @description
#' This function adds a frosting layer that produces both a point estimate as
#' well as quantile estimates.
#' @param distn_id a random id string for the layer that creates the quantile
#' estimate
#' @param point_id a random id string for the layer that creates the point
#' estimate. Only present for trainers that produce quantiles
#' @param point_type character. Either `mean` or `median`.
#' @param use_predictive_distribution only usable for `linear_reg` type models
#' @param distn_type character. Only used if `use_predictive_distribution=TRUE`,
#' for `linear_reg` type models. Either gaussian or student_t
#' @param distn_name an alternate name for the distribution column; defaults
#' to `.pred_distn`.
#' @param point_name an alternate name for the point estimate column; defaults
#' to `.pred`.
#' @param symmetrize logical. If `TRUE` then interval will be symmetric.
#' Applies for residual quantiles only
#' @param by_key A character vector of keys to group the residuals by before
#' calculating quantiles. The default, `c()` performs no grouping. Only used
#' by `layer_residual_quantiles`
#' @inheritParams layer_quantile_distn
#' @export
#' @return an updated `frosting postprocessor` with an additional prediction
#' column; if the trainer produces a point estimate, it has added a
#' distribution estimate, and vice versa.
layer_point_and_distn <- function(frosting, trainer, ...,
probs = c(0.05, 0.95),
levels = c(0.25, 0.75),
symmetrize = TRUE,
by_key = character(0L),
distn_name = ".pred_distn",
Expand All @@ -10,26 +34,27 @@ layer_point_and_distn <- function(frosting, trainer, ...,
point_id = NULL,
point_type = c("median", "mean"),
truncate = c(-Inf, Inf),
use_predictive_distribution = TRUE,
use_predictive_distribution = FALSE,
dist_type = "gaussian") {
rlang::check_dots_empty()
stopifnot(inherits(recipe, "recipe"))
# not sure what to do about the dots...
levels <- sort(levels)
if (inherits(trainer, "quantile_reg")) {
# sort the probabilities
tau <- sort(compare_quantile_args(
args_list$levels,
levels,
rlang::eval_tidy(trainer$args$tau)
))
args_list$levels <- tau
levels <- tau
trainer$args$tau <- rlang::enquo(tau)
if (is.null(point_id)) {
point_id <- rand_id("point_from_distn")
}
if (is.null(distn_id)) {
distn_id <- rand_id("quantile_distn")
}
frosting %<>% layer_quantile_distn(...,
levels = tau,
frosting %<>% layer_quantile_distn(
levels = levels,
truncate = trucate,
name = distn_name,
id = distn_id
Expand All @@ -44,7 +69,7 @@ layer_point_and_distn <- function(frosting, trainer, ...,
distn_id <- rand_id("residual_quantiles")
}
if (inherits(trainer, "linear_reg") && use_predictive_distribution) {
frosting %<>% layer_residual_quantiles(
frosting %<>% layer_predictive_distn(
dist_type = dist_type,
name = distn_name,
id = distn_id
Expand Down

0 comments on commit 6894071

Please sign in to comment.