Skip to content

Commit

Permalink
Merge pull request #247 from cmu-delphi/v0.0.6
Browse files Browse the repository at this point in the history
v0.0.6 to main
  • Loading branch information
dajmcdon authored Oct 19, 2023
2 parents 96591a1 + c0d9e9e commit 0f4f2f9
Show file tree
Hide file tree
Showing 75 changed files with 1,948 additions and 445 deletions.
1 change: 1 addition & 0 deletions .Rbuildignore
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,4 @@
^musings$
^data-raw$
^vignettes/articles$
^.git-blame-ignore-revs$
2 changes: 1 addition & 1 deletion .github/workflows/styler.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# Need help debugging build failures? Start at https://github.com/r-lib/actions#where-to-find-help
on:
workflow_dispatch:
pullrequest:
pull_request:
paths:
[
"**.[rR]",
Expand Down
11 changes: 7 additions & 4 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
Package: epipredict
Title: Basic epidemiology forecasting methods
Version: 0.0.5
Version: 0.0.6
Authors@R: c(
person("Daniel", "McDonald", , "[email protected]", role = c("aut", "cre")),
person("Ryan", "Tibshirani", , "[email protected]", role = "aut"),
person("Logan", "Brooks", role = "aut"),
person("Rachel", "Lobay", role = "aut"),
person("Maggie", "Liu", role = "aut"),
person("Ken", "Mawer", role = "aut"),
person("Chloe", "You", role = "aut"),
person("Dmitry", "Shemetov", email = "[email protected]", role = "ctb"),
person("David", "Weber", email = "[email protected]", role = "ctb"),
person("Maggie", "Liu", role = "ctb"),
person("Ken", "Mawer", role = "ctb"),
person("Chloe", "You", role = "ctb"),
person("Jacob", "Bien", role = "ctb")
)
Description: A forecasting "framework" for creating epidemiological
Expand All @@ -32,6 +34,7 @@ Imports:
generics,
glue,
hardhat (>= 1.3.0),
lifecycle,
magrittr,
methods,
quantreg,
Expand Down
24 changes: 22 additions & 2 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ S3method(extrapolate_quantiles,dist_default)
S3method(extrapolate_quantiles,dist_quantiles)
S3method(extrapolate_quantiles,distribution)
S3method(fit,epi_workflow)
S3method(flusight_hub_formatter,canned_epipred)
S3method(flusight_hub_formatter,data.frame)
S3method(format,dist_quantiles)
S3method(is.na,dist_quantiles)
S3method(is.na,distribution)
Expand All @@ -52,6 +54,7 @@ S3method(print,alist)
S3method(print,arx_class)
S3method(print,arx_fcast)
S3method(print,canned_epipred)
S3method(print,cdc_baseline_fcast)
S3method(print,epi_recipe)
S3method(print,epi_workflow)
S3method(print,flat_fcast)
Expand Down Expand Up @@ -81,6 +84,7 @@ S3method(residuals,flatline)
S3method(run_mold,default_epi_recipe_blueprint)
S3method(slather,layer_add_forecast_date)
S3method(slather,layer_add_target_date)
S3method(slather,layer_cdc_flatline_quantiles)
S3method(slather,layer_naomit)
S3method(slather,layer_point_from_distn)
S3method(slather,layer_population_scaling)
Expand Down Expand Up @@ -108,6 +112,8 @@ export(arx_classifier)
export(arx_fcast_epi_workflow)
export(arx_forecaster)
export(bake)
export(cdc_baseline_args_list)
export(cdc_baseline_forecaster)
export(create_layer)
export(default_epi_recipe_blueprint)
export(detect_layer)
Expand All @@ -124,6 +130,7 @@ export(fit)
export(flatline)
export(flatline_args_list)
export(flatline_forecaster)
export(flusight_hub_formatter)
export(frosting)
export(get_test_data)
export(grab_names)
Expand All @@ -133,6 +140,7 @@ export(is_layer)
export(layer)
export(layer_add_forecast_date)
export(layer_add_target_date)
export(layer_cdc_flatline_quantiles)
export(layer_naomit)
export(layer_point_from_distn)
export(layer_population_scaling)
Expand All @@ -145,7 +153,8 @@ export(layer_unnest)
export(nested_quantiles)
export(new_default_epi_recipe_blueprint)
export(new_epi_recipe_blueprint)
export(pivot_quantiles)
export(pivot_quantiles_longer)
export(pivot_quantiles_wider)
export(prep)
export(quantile_reg)
export(remove_frosting)
Expand All @@ -163,12 +172,13 @@ import(distributional)
import(epiprocess)
import(parsnip)
import(recipes)
import(vctrs)
importFrom(cli,cli_abort)
importFrom(epiprocess,growth_rate)
importFrom(generics,augment)
importFrom(generics,fit)
importFrom(hardhat,refresh_blueprint)
importFrom(hardhat,run_mold)
importFrom(lifecycle,deprecated)
importFrom(magrittr,"%>%")
importFrom(methods,is)
importFrom(quantreg,rq)
Expand All @@ -183,6 +193,7 @@ importFrom(rlang,caller_env)
importFrom(rlang,is_empty)
importFrom(rlang,is_null)
importFrom(rlang,quos)
importFrom(smoothqr,smooth_qr)
importFrom(stats,as.formula)
importFrom(stats,family)
importFrom(stats,lm)
Expand All @@ -196,3 +207,12 @@ importFrom(stats,residuals)
importFrom(tibble,as_tibble)
importFrom(tibble,is_tibble)
importFrom(tibble,tibble)
importFrom(vctrs,as_list_of)
importFrom(vctrs,field)
importFrom(vctrs,new_rcrd)
importFrom(vctrs,new_vctr)
importFrom(vctrs,vec_cast)
importFrom(vctrs,vec_data)
importFrom(vctrs,vec_ptype_abbr)
importFrom(vctrs,vec_ptype_full)
importFrom(vctrs,vec_recycle_common)
10 changes: 8 additions & 2 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,14 +1,20 @@
# epipredict (development)

# epipredict 0.0.6

* rename the `dist_quantiles()` to be more descriptive, breaking change)
* removes previous `pivot_quantiles()` (now `*_wider()`, breaking change)
* add `pivot_quantiles_wider()` for easier plotting
* add complement `pivot_quantiles_longer()`
* add `cdc_baseline_forecaster()` and `flusight_hub_formatter()`

# epipredict 0.0.5

* add `smooth_quantile_reg()`
* improved printing of various methods / internals
* canned forecasters get a class
* fixed quantile bug in `flatline_forecaster()`
* add functionality to output the unfit workflow from the canned forecasters
* add `pivot_quantiles()` for easier plotting


# epipredict 0.0.4

Expand Down
6 changes: 5 additions & 1 deletion R/arx_classifier.R
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,9 @@ arx_class_args_list <- function(
method = c("rel_change", "linear_reg", "smooth_spline", "trend_filter"),
log_scale = FALSE,
additional_gr_args = list(),
nafill_buffer = Inf) {
nafill_buffer = Inf,
...) {
rlang::check_dots_empty()
.lags <- lags
if (is.list(lags)) lags <- unlist(lags)
method <- match.arg(method)
Expand Down Expand Up @@ -305,3 +307,5 @@ print.arx_class <- function(x, ...) {
name <- "ARX Classifier"
NextMethod(name = name, ...)
}

# this is a trivial change to induce a check
40 changes: 22 additions & 18 deletions R/arx_forecaster.R
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
#' out <- arx_forecaster(jhu, "death_rate",
#' c("case_rate", "death_rate"),
#' trainer = quantile_reg(),
#' args_list = arx_args_list(levels = 1:9 / 10)
#' args_list = arx_args_list(quantile_levels = 1:9 / 10)
#' )
arx_forecaster <- function(epi_data,
outcome,
Expand Down Expand Up @@ -99,7 +99,7 @@ arx_forecaster <- function(epi_data,
#' arx_fcast_epi_workflow(jhu, "death_rate",
#' c("case_rate", "death_rate"),
#' trainer = quantile_reg(),
#' args_list = arx_args_list(levels = 1:9 / 10)
#' args_list = arx_args_list(quantile_levels = 1:9 / 10)
#' )
arx_fcast_epi_workflow <- function(
epi_data,
Expand Down Expand Up @@ -134,18 +134,20 @@ arx_fcast_epi_workflow <- function(
# --- postprocessor
f <- frosting() %>% layer_predict() # %>% layer_naomit()
if (inherits(trainer, "quantile_reg")) {
# add all levels to the forecaster and update postprocessor
tau <- sort(compare_quantile_args(
args_list$levels,
rlang::eval_tidy(trainer$args$tau)
# add all quantile_level to the forecaster and update postprocessor
quantile_levels <- sort(compare_quantile_args(
args_list$quantile_levels,
rlang::eval_tidy(trainer$args$quantile_levels)
))
args_list$levels <- tau
trainer$args$tau <- rlang::enquo(tau)
f <- layer_quantile_distn(f, levels = tau) %>% layer_point_from_distn()
args_list$quantile_levels <- quantile_levels
trainer$args$quantile_levels <- rlang::enquo(quantile_levels)
f <- layer_quantile_distn(f, quantile_levels = quantile_levels) %>%
layer_point_from_distn()
} else {
f <- layer_residual_quantiles(
f,
probs = args_list$levels, symmetrize = args_list$symmetrize,
quantile_levels = args_list$quantile_levels,
symmetrize = args_list$symmetrize,
by_key = args_list$quantile_by_key
)
}
Expand Down Expand Up @@ -173,7 +175,7 @@ arx_fcast_epi_workflow <- function(
#' The default `NULL` will attempt to determine this automatically.
#' @param target_date Date. The date for which the forecast is intended.
#' The default `NULL` will attempt to determine this automatically.
#' @param levels Vector or `NULL`. A vector of probabilities to produce
#' @param quantile_levels Vector or `NULL`. A vector of probabilities to produce
#' prediction intervals. These are created by computing the quantiles of
#' training residuals. A `NULL` value will result in point forecasts only.
#' @param symmetrize Logical. The default `TRUE` calculates
Expand All @@ -197,6 +199,7 @@ arx_fcast_epi_workflow <- function(
#' create a prediction. For this reason, setting `nafill_buffer < min(lags)`
#' will be treated as _additional_ allowed recent data rather than the
#' total amount of recent data to examine.
#' @param ... Space to handle future expansions (unused).
#'
#'
#' @return A list containing updated parameter choices with class `arx_flist`.
Expand All @@ -205,18 +208,19 @@ arx_fcast_epi_workflow <- function(
#' @examples
#' arx_args_list()
#' arx_args_list(symmetrize = FALSE)
#' arx_args_list(levels = c(.1, .3, .7, .9), n_training = 120)
#' arx_args_list(quantile_levels = c(.1, .3, .7, .9), n_training = 120)
arx_args_list <- function(
lags = c(0L, 7L, 14L),
ahead = 7L,
n_training = Inf,
forecast_date = NULL,
target_date = NULL,
levels = c(0.05, 0.95),
quantile_levels = c(0.05, 0.95),
symmetrize = TRUE,
nonneg = TRUE,
quantile_by_key = character(0L),
nafill_buffer = Inf) {
nafill_buffer = Inf,
...) {
# error checking if lags is a list
.lags <- lags
if (is.list(lags)) lags <- unlist(lags)
Expand All @@ -227,7 +231,7 @@ arx_args_list <- function(
arg_is_date(forecast_date, target_date, allow_null = TRUE)
arg_is_nonneg_int(ahead, lags)
arg_is_lgl(symmetrize, nonneg)
arg_is_probabilities(levels, allow_null = TRUE)
arg_is_probabilities(quantile_levels, allow_null = TRUE)
arg_is_pos(n_training)
if (is.finite(n_training)) arg_is_pos_int(n_training)
if (is.finite(nafill_buffer)) arg_is_pos_int(nafill_buffer, allow_null = TRUE)
Expand All @@ -238,7 +242,7 @@ arx_args_list <- function(
lags = .lags,
ahead,
n_training,
levels,
quantile_levels,
forecast_date,
target_date,
symmetrize,
Expand All @@ -259,8 +263,8 @@ print.arx_fcast <- function(x, ...) {
}

compare_quantile_args <- function(alist, tlist) {
default_alist <- eval(formals(arx_args_list)$levels)
default_tlist <- eval(formals(quantile_reg)$tau)
default_alist <- eval(formals(arx_args_list)$quantile_level)
default_tlist <- eval(formals(quantile_reg)$quantile_level)
if (setequal(alist, default_alist)) {
if (setequal(tlist, default_tlist)) {
return(sort(unique(union(alist, tlist))))
Expand Down
8 changes: 4 additions & 4 deletions R/canned-epipred.R
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@ validate_forecaster_inputs <- function(epi_data, outcome, predictors) {
arg_is_chr(predictors)
arg_is_chr_scalar(outcome)
if (!outcome %in% names(epi_data)) {
cli::cli_abort("{outcome} was not found in the training data.")
cli::cli_abort("{.var {outcome}} was not found in the training data.")
}
check <- hardhat::check_column_names(epi_data, predictors)
if (!check$ok) {
cli::cli_abort(c(
"At least one predictor was not found in the training data.",
"!" = "The following required columns are missing: {check$missing_names}."
"!" = "The following required columns are missing: {.val {check$missing_names}}."
))
}
invisible(TRUE)
Expand All @@ -41,8 +41,8 @@ arx_lags_validator <- function(predictors, lags) {
predictors_miss <- setdiff(predictors, names(lags))
cli::cli_abort(c(
"If lags is a named list, then all predictors must be present.",
i = "The predictors are '{predictors}'.",
i = "So lags is missing '{predictors_miss}'."
i = "The predictors are {.var {predictors}}.",
i = "So lags is missing {.var {predictors_miss}}'."
))
}
}
Expand Down
Loading

0 comments on commit 0f4f2f9

Please sign in to comment.