Skip to content

Commit

Permalink
all tests pass
Browse files Browse the repository at this point in the history
  • Loading branch information
dajmcdon committed Oct 26, 2024
1 parent d36760b commit b929e47
Show file tree
Hide file tree
Showing 5 changed files with 18 additions and 7 deletions.
7 changes: 4 additions & 3 deletions R/arx_forecaster.R
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ arx_fcast_epi_workflow <- function(
quantile_levels <- sort(compare_quantile_args(
args_list$quantile_levels,
rlang::eval_tidy(trainer$args$quantile_levels),
train_type
"qr"
))
trainer$args$quantile_levels <- rlang::enquo(quantile_levels)
} else {
Expand Down Expand Up @@ -357,10 +357,11 @@ print.arx_fcast <- function(x, ...) {
NextMethod(name = name, ...)
}

compare_quantile_args <- function(alist, tlist, trainer = "qr") {
compare_quantile_args <- function(alist, tlist, train_method = c("qr", "grf")) {
train_method <- rlang::arg_match(train_method)
default_alist <- eval(formals(arx_args_list)$quantile_levels)
default_tlist <- switch(
trainer,
train_method,
"qr" = eval(formals(quantile_reg)$quantile_levels),
"grf" = c(.1, .5, .9)
)
Expand Down
5 changes: 3 additions & 2 deletions R/canned-epipred.R
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,9 @@ print.canned_epipred <- function(x, name, ...) {
fn_meta <- function() {
cli::cli_ul()
cli::cli_li("Geography: {.field {x$metadata$training$geo_type}},")
if (!is.null(x$metadata$training$other_keys)) {
cli::cli_li("Other keys: {.field {x$metadata$training$other_keys}},")
other_keys <- x$metadata$training$other_keys
if (!is.null(other_keys) && length(other_keys) > 0L) {
cli::cli_li("Other keys: {.field {other_keys}},")
}
cli::cli_li("Time type: {.field {x$metadata$training$time_type}},")
cli::cli_li("Using data up-to-date as of: {.field {format(x$metadata$training$as_of)}}.")
Expand Down
9 changes: 9 additions & 0 deletions tests/testthat/_snaps/arx_args_list.md
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,15 @@

# arx forecaster disambiguates quantiles

Code
compare_quantile_args(alist / 10, 1:9 / 10, "grf")
Condition
Error in `compare_quantile_args()`:
! You have specified different, non-default, quantiles in the trainier and `arx_args` options.
i Please only specify quantiles in one location.

---

Code
compare_quantile_args(alist, tlist)
Condition
Expand Down
2 changes: 1 addition & 1 deletion tests/testthat/test-arx_args_list.R
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ test_that("arx forecaster disambiguates quantiles", {
)
expect_snapshot(
error = TRUE,
compare_quantile_args(1:9 / 10, 1:9 / 10, "grf")
compare_quantile_args(alist / 10, 1:9 / 10, "grf")
)
expect_identical(compare_quantile_args(alist, 1:9 / 10, "grf"), 1:9 / 10)
alist <- c(.5, alist)
Expand Down
2 changes: 1 addition & 1 deletion tests/testthat/test-snapshots.R
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ test_that("arx_forecaster snapshots", {
})

test_that("arx_forecaster output format snapshots", {
jhu <- case_death_rate_subset %>%
jhu <- epidatasets::covid_case_death_rates %>%
dplyr::filter(time_value >= as.Date("2021-12-01"))
attributes(jhu)$metadata$as_of <- as.Date(attributes(jhu)$metadata$as_of)
out1 <- arx_forecaster(
Expand Down

0 comments on commit b929e47

Please sign in to comment.