Skip to content

Commit

Permalink
final requests
Browse files Browse the repository at this point in the history
  • Loading branch information
dsweber2 committed Oct 1, 2024
1 parent 86c46a4 commit 90edb46
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 2 deletions.
9 changes: 7 additions & 2 deletions R/arx_classifier.R
Original file line number Diff line number Diff line change
Expand Up @@ -191,9 +191,14 @@ arx_class_epi_workflow <- function(
}
}
# regex that will match any amount of adjustment for the ahead
ahead_out_name <- glue::glue("ahead_[0-9]*_{pre_out_name}")
ahead_out_name_regex <- glue::glue("ahead_[0-9]*_{pre_out_name}")
method_adjust_latency <- args_list$adjust_latency
if (method_adjust_latency != "none") {
if (method_adjust_latency != "extend_ahead") {
cli_abort("only extend_ahead is currently supported",
class = "epipredict__arx_classifier__adjust_latency_unsupported_method"
)
}
r <- r %>% step_adjust_latency(!!pre_out_name,
fixed_forecast_date = forecast_date,
method = method_adjust_latency
Expand All @@ -204,7 +209,7 @@ arx_class_epi_workflow <- function(
r <- r %>%
step_mutate(
across(
matches(ahead_out_name),
matches(ahead_out_name_regex),
~ cut(.x, breaks = args_list$breaks),
.names = "outcome_class",
.unpack = TRUE
Expand Down
20 changes: 20 additions & 0 deletions tests/testthat/test-snapshots.R
Original file line number Diff line number Diff line change
Expand Up @@ -162,4 +162,24 @@ test_that("arx_classifier snapshots", {
args_list = arx_class_args_list(adjust_latency = "extend_ahead", forecast_date = max_date + 2)
)
expect_snapshot_tibble(arc2$predictions)
expect_error(
arc3 <- arx_classifier(
case_death_rate_subset %>%
dplyr::filter(time_value >= as.Date("2021-11-01")),
"death_rate",
c("case_rate", "death_rate"),
args_list = arx_class_args_list(adjust_latency = "extend_lags", forecast_date = max_date + 2)
),
class = "epipredict__arx_classifier__adjust_latency_unsupported_method"
)
expect_error(
arc4 <- arx_classifier(
case_death_rate_subset %>%
dplyr::filter(time_value >= as.Date("2021-11-01")),
"death_rate",
c("case_rate", "death_rate"),
args_list = arx_class_args_list(adjust_latency = "locf", forecast_date = max_date + 2)
),
class = "epipredict__arx_classifier__adjust_latency_unsupported_method"
)
})

0 comments on commit 90edb46

Please sign in to comment.