diff --git a/.github/workflows/pkgdown.yaml b/.github/workflows/pkgdown.yaml
index cc940bc8b..d86e1485e 100644
--- a/.github/workflows/pkgdown.yaml
+++ b/.github/workflows/pkgdown.yaml
@@ -1,7 +1,10 @@
 # Workflow derived from https://github.com/r-lib/actions/tree/v2/examples
 # Need help debugging build failures? Start at https://github.com/r-lib/actions#where-to-find-help
 #
-# Created with usethis + edited to run on PRs to dev, use API key.
+# Modifications:
+# * workflow_dispatch added to allow manual triggering of the workflow
+# * trigger branches changed
+# * API key secrets.SECRET_EPIPREDICT_GHACTIONS_DELPHI_EPIDATA_KEY
 on:
   push:
     branches: [main, dev]
@@ -21,8 +24,9 @@ jobs:
       group: pkgdown-${{ github.event_name != 'pull_request' || github.run_id }}
     env:
       GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }}
+      DELPHI_EPIDATA_KEY: ${{ secrets.SECRET_EPIPREDICT_GHACTIONS_DELPHI_EPIDATA_KEY }}
     steps:
-      - uses: actions/checkout@v3
+      - uses: actions/checkout@v4

       - uses: r-lib/actions/setup-pandoc@v2

@@ -32,19 +36,31 @@ jobs:
       - uses: r-lib/actions/setup-r-dependencies@v2
         with:
-          extra-packages: any::pkgdown, local::.
+          extra-packages: any::pkgdown, local::., any::cli
           needs: website

       - name: Build site
-        env:
-          DELPHI_EPIDATA_KEY: ${{ secrets.SECRET_EPIPROCESS_GHACTIONS_DELPHI_EPIDATA_KEY }}
+        # - target_ref gets the ref from a different variable, depending on the event
+        # - override allows us to set the pkgdown mode and version_label
+        # - mode: release is the standard build mode, devel places the site in /dev
+        # - version_label: 'light' and 'success' are CSS labels for Bootswatch: Cosmo
+        #   https://bootswatch.com/cosmo/
+        # - we use pkgdown:::build_github_pages to build the site because of an issue in pkgdown
+        #   https://github.com/r-lib/pkgdown/issues/2257
         run: |
-          if (startsWith("${{ github.event_name }}", "pull_request")) {
-            mode <- ifelse("${{ github.base_ref }}" == "main", "release", "devel")
+          target_ref <- "${{ github.event_name == 'pull_request' && github.base_ref || github.ref }}"
+          override <- if (target_ref == "main" || target_ref == "refs/heads/main") {
+            list(development = list(mode = "release", version_label = "light"))
+          } else if (target_ref == "dev" || target_ref == "refs/heads/dev") {
+            list(development = list(mode = "devel", version_label = "success"))
           } else {
-            mode <- ifelse("${{ github.ref_name }}" == "main", "release", "devel")
+            stop("Unexpected target_ref: ", target_ref)
           }
-          pkgdown::build_site_github_pages(new_process = FALSE, install = FALSE, override=list(PKGDOWN_DEV_MODE=mode))
+          pkg <- pkgdown::as_pkgdown(".", override = override)
+          cli::cli_rule("Cleaning files from old site...")
+          pkgdown::clean_site(pkg)
+          pkgdown::build_site(pkg, preview = FALSE, install = FALSE, new_process = FALSE)
+          pkgdown:::build_github_pages(pkg)
         shell: Rscript {0}

       - name: Deploy to GitHub pages 🚀
diff --git a/NEWS.md b/NEWS.md
index a79fa3a74..db2b43368 100644
--- a/NEWS.md
+++ b/NEWS.md
@@ -14,9 +14,10 @@ Pre-1.0.0 numbering scheme: 0.x will indicate releases, while 0.0.x will indicat

 ## Improvements

 - Add `step_adjust_latency`, which give several methods to adjust the forecast if the `forecast_date` is after the last day of data.
+- (temporary) negative aheads are allowed for `step_epi_ahead` until we have `step_epi_shift`

 ## Bug fixes
-
+- Shifting no columns results in no error for both `step_epi_ahead` and `step_epi_lag`

 # epipredict 0.1
diff --git a/R/epi_shift.R b/R/epi_shift.R
index 367e26285..877f7866c 100644
--- a/R/epi_shift.R
+++ b/R/epi_shift.R
@@ -42,6 +42,10 @@ get_sign <- function(object) {
 add_shifted_columns <- function(new_data, object) {
   grid <- object$shift_grid

+  if (nrow(object$shift_grid) == 0) {
+    # we're not shifting any columns, so this is a no-op
+    return(new_data)
+  }
   ## ensure no name clashes
   new_data_names <- colnames(new_data)
   intersection <- new_data_names %in% grid$newname
diff --git a/R/epi_workflow.R b/R/epi_workflow.R
index 5c64dc38f..81b443e7b 100644
--- a/R/epi_workflow.R
+++ b/R/epi_workflow.R
@@ -167,6 +167,7 @@ predict.epi_workflow <- function(object, new_data, type = NULL, opts = list(), .
   components$forged <- hardhat::forge(new_data,
     blueprint = components$mold$blueprint
   )
+  components$keys <- grab_forged_keys(components$forged, object, new_data)
   components <- apply_frosting(object, components, new_data, type = type, opts = opts, ...)
   components$predictions
diff --git a/R/step_adjust_latency.R b/R/step_adjust_latency.R
index 75560e6de..3d9f19891 100644
--- a/R/step_adjust_latency.R
+++ b/R/step_adjust_latency.R
@@ -152,6 +152,12 @@
 #'
 #'   Note that this is a separate concern from different latencies across
 #'   different *data columns*, which is only handled by the choice of `method`.
+#' @param keys_to_ignore a list of character vectors. Set this to avoid using
+#'   specific key values in the `epi_keys_checked` to set latency. For example,
+#'   say you have two locations `pr` and `gu` which have useful training data,
+#'   but have stopped providing up-to-date information, and so are no longer
+#'   part of the test set. Setting `keys_to_ignore = list(geo_value = c("pr",
+#'   "gu"))` will exclude them from the latency calculation.
 #' @param fixed_latency either a positive integer, or a labeled positive integer
 #'   vector. Cannot be set at the same time as `fixed_forecast_date`. If
 #'   non-`NULL`, the amount to offset the ahead or lag by. If a single integer,
@@ -203,6 +209,7 @@ step_adjust_latency <-
       "extend_lags"
     ),
     epi_keys_checked = NULL,
+    keys_to_ignore = c(),
     fixed_latency = NULL,
     fixed_forecast_date = NULL,
     check_latency_length = TRUE,
@@ -228,6 +235,7 @@ step_adjust_latency <-
       metadata = NULL,
       method = method,
       epi_keys_checked = epi_keys_checked,
+      keys_to_ignore = keys_to_ignore,
       check_latency_length = check_latency_length,
       columns = NULL,
       skip = FALSE,
@@ -239,7 +247,7 @@ step_adjust_latency <-
 step_adjust_latency_new <-
   function(terms, role, trained, fixed_forecast_date, forecast_date, latency,
            latency_table, latency_sign, metadata, method, epi_keys_checked,
-           check_latency_length, columns, skip, id) {
+           keys_to_ignore, check_latency_length, columns, skip, id) {
     step(
       subclass = "adjust_latency",
       terms = terms,
@@ -253,6 +261,7 @@ step_adjust_latency_new <-
       metadata = metadata,
       method = method,
       epi_keys_checked = epi_keys_checked,
+      keys_to_ignore = keys_to_ignore,
       check_latency_length = check_latency_length,
       columns = columns,
       skip = skip,
@@ -271,7 +280,7 @@ prep.step_adjust_latency <- function(x, training, info = NULL, ...) {
   latency_table <- get_latency_table(
     training, NULL, forecast_date, latency,
-    get_sign(x), x$epi_keys_checked, info, x$terms
+    get_sign(x), x$epi_keys_checked, x$keys_to_ignore, info, x$terms
   )
   # get the columns used, even if it's all of them
   terms_used <- x$terms
@@ -293,6 +302,7 @@ prep.step_adjust_latency <- function(x, training, info = NULL, ...) {
     metadata = attributes(training)$metadata,
     method = x$method,
     epi_keys_checked = x$epi_keys_checked,
+    keys_to_ignore = x$keys_to_ignore,
     check_latency_length = x$check_latency_length,
     columns = recipes_eval_select(latency_table$col_name, training, info),
     skip = x$skip,
@@ -305,10 +315,8 @@ prep.step_adjust_latency <- function(x, training, info = NULL, ...) {
 #' @export
 bake.step_adjust_latency <- function(object, new_data, ...) {
   if (!inherits(new_data, "epi_df") || is.null(attributes(new_data)$metadata$as_of)) {
-    new_data <- as_epi_df(new_data)
+    new_data <- as_epi_df(new_data, as_of = object$forecast_date, other_keys = object$metadata$other_keys %||% character())
     attributes(new_data)$metadata <- object$metadata
-    attributes(new_data)$metadata$as_of <- object$forecast_date
-
   } else {
     compare_bake_prep_latencies(object, new_data)
   }
   if (object$method == "locf") {
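For reference, a minimal sketch of how the new `keys_to_ignore` argument might be used in a recipe. This is not part of the diff: the `epi_df` named `jhu`, the `death_rate` column, and the lag/ahead choices are all hypothetical, and it assumes the latency step precedes the shift steps it adjusts:

    library(epipredict)
    r <- epi_recipe(jhu) %>%
      step_adjust_latency(death_rate,
        method = "extend_lags",
        epi_keys_checked = "geo_value",
        # hypothetical: "pr" and "gu" stopped reporting, so they should not
        # drive the latency estimate
        keys_to_ignore = list(geo_value = c("pr", "gu"))
      ) %>%
      step_epi_lag(death_rate, lag = c(0, 7, 14)) %>%
      step_epi_ahead(death_rate, ahead = 7)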
diff --git a/R/step_epi_shift.R b/R/step_epi_shift.R
index 88795ab56..beda182e6 100644
--- a/R/step_epi_shift.R
+++ b/R/step_epi_shift.R
@@ -111,7 +111,6 @@ step_epi_ahead <-
       i = "Did you perhaps pass an integer in `...` accidentally?"
     ))
   }
-  arg_is_nonneg_int(ahead)
   arg_is_chr_scalar(prefix, id)

   recipes::add_step(
diff --git a/R/utils-latency.R b/R/utils-latency.R
index 311656ac1..8bcd2b1e4 100644
--- a/R/utils-latency.R
+++ b/R/utils-latency.R
@@ -50,27 +50,18 @@ get_forecast_date <- function(new_data, info, epi_keys_checked, latency, columns
       )
     }
   }
+  max_time <- get_max_time(new_data, epi_keys_checked, columns)
   # the source data determines the actual time_values
-  # these are the non-na time_values;
-  # get the minimum value across the checked epi_keys' maximum time values
-  max_time <- new_data %>%
-    select(all_of(columns)) %>%
-    drop_na()
-  # null and "" don't work in `group_by`
-  if (!is.null(epi_keys_checked) && (epi_keys_checked != "")) {
-    max_time <- max_time %>% group_by(get(epi_keys_checked))
-  }
-  max_time <- max_time %>%
-    summarise(time_value = max(time_value)) %>%
-    pull(time_value) %>%
-    min()
   if (is.null(latency)) {
     forecast_date <- attributes(new_data)$metadata$as_of
   } else {
+    if (is.null(max_time)) {
+      cli_abort("max_time is NULL; this likely means at least one of {columns} is entirely `NA`")
+    }
     forecast_date <- max_time + latency
   }
   # make sure the as_of is sane
-  if (!inherits(forecast_date, class(max_time)) & !inherits(forecast_date, "POSIXt")) {
+  if (!inherits(forecast_date, class(new_data$time_value)) & !inherits(forecast_date, "POSIXt")) {
     cli_abort(
       paste(
         "the data matrix `forecast_date` value is {forecast_date}, ",
@@ -84,13 +75,13 @@ get_forecast_date <- function(new_data, info, epi_keys_checked, latency, columns
   if (is.null(forecast_date) || is.na(forecast_date)) {
     cli_warn(
       paste(
-        "epi_data's `forecast_date` was {forecast_date}, setting to ",
-        "the latest time value, {max_time}."
+        "epi_data's `forecast_date` was `NA`, setting to ",
+        "the latest non-`NA` time value for these columns, {max_time}."
       ),
       class = "epipredict__get_forecast_date__max_time_warning"
     )
     forecast_date <- max_time
-  } else if (forecast_date < max_time) {
+  } else if (!is.null(max_time) && (forecast_date < max_time)) {
     cli_abort(
       paste(
         "`forecast_date` ({(forecast_date)}) is before the most ",
@@ -101,12 +92,34 @@ get_forecast_date <- function(new_data, info, epi_keys_checked, latency, columns
     )
   }
   # TODO cover the rest of the possible types for as_of and max_time...
-  if (inherits(max_time, "Date")) {
+  if (inherits(new_data$time_value, "Date")) {
     forecast_date <- as.Date(forecast_date)
   }
   return(forecast_date)
 }

+get_max_time <- function(new_data, epi_keys_checked, columns) {
+  # these are the non-na time_values;
+  # get the minimum value across the checked epi_keys' maximum time values
+  max_time <- new_data %>%
+    select(all_of(columns)) %>%
+    drop_na()
+  if (nrow(max_time) == 0) {
+    return(NULL)
+  }
+  # null and "" don't work in `group_by`
+  if (!is.null(epi_keys_checked) && all(epi_keys_checked != "")) {
+    max_time <- max_time %>% group_by(across(all_of(epi_keys_checked)))
+  }
+  max_time <- max_time %>%
+    summarise(time_value = max(time_value)) %>%
+    pull(time_value) %>%
+    min()
+  return(max_time)
+}
+
+
+
 #' the latency is also the amount the shift is off by
 #' @param sign_shift integer. 1 if lag and -1 if ahead. These represent how you
 #'   need to shift the data to bring the 3 day lagged value to today.
@@ -114,9 +127,14 @@ get_forecast_date <- function(new_data, info, epi_keys_checked, latency, columns
 get_latency <- function(new_data, forecast_date, column, sign_shift, epi_keys_checked) {
   shift_max_date <- new_data %>%
     drop_na(all_of(column))
+  if (nrow(shift_max_date) == 0) {
+    # if everything is an NA, there's infinite latency, but shifting by that is
+    # untenable. May as well not shift at all
+    return(0)
+  }
   # null and "" don't work in `group_by`
-  if (!is.null(epi_keys_checked) && epi_keys_checked != "") {
-    shift_max_date <- shift_max_date %>% group_by(get(epi_keys_checked))
+  if (!is.null(epi_keys_checked) && all(epi_keys_checked != "")) {
+    shift_max_date <- shift_max_date %>% group_by(across(all_of(epi_keys_checked)))
   }
   shift_max_date <- shift_max_date %>%
     summarise(time_value = max(time_value)) %>%
@@ -290,7 +308,8 @@ check_interminable_latency <- function(dataset, latency_table, target_columns, f
 #' @keywords internal
 #' @importFrom dplyr rowwise
 get_latency_table <- function(training, columns, forecast_date, latency,
-                              sign_shift, epi_keys_checked, info, terms) {
+                              sign_shift, epi_keys_checked, keys_to_ignore,
+                              info, terms) {
   if (is.null(columns)) {
     columns <- recipes_eval_select(terms, training, info)
   }
@@ -300,12 +319,17 @@ get_latency_table <- function(training, columns, forecast_date, latency,
   if (length(columns) > 0) {
     latency_table <- latency_table %>% filter(col_name %in% columns)
   }
-
+  training_dropped <- training %>%
+    drop_ignored_keys(keys_to_ignore)
   if (is.null(latency)) {
     latency_table <- latency_table %>%
       rowwise() %>%
       mutate(latency = get_latency(
-        training, forecast_date, col_name, sign_shift, epi_keys_checked
+        training_dropped,
+        forecast_date,
+        col_name,
+        sign_shift,
+        epi_keys_checked
       ))
   } else if (length(latency) > 1) {
     # if latency has a length, it must also have named elements.
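To make the `get_max_time` rule concrete ("the minimum value across the checked epi_keys' maximum time values"), a self-contained sketch of the same dplyr pipeline on toy data (toy values only, not from the package):

    library(dplyr)
    toy <- tibble::tribble(
      ~geo_value, ~time_value,
      "ma", as.Date("2015-01-12"),
      "ma", as.Date("2015-01-13"),
      "ca", as.Date("2015-01-12")
    )
    # per-key maxima are 2015-01-13 (ma) and 2015-01-12 (ca); the minimum of
    # those, 2015-01-12, is the last day on which *every* checked key reports
    toy %>%
      group_by(across(all_of("geo_value"))) %>%
      summarise(time_value = max(time_value)) %>%
      pull(time_value) %>%
      min()
    #> [1] "2015-01-12"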
@@ -319,7 +343,7 @@ get_latency_table <- function(training, columns, forecast_date, latency,
     latency_table <- latency_table %>%
       rowwise() %>%
       mutate(latency = get_latency(
-        training, forecast_date, col_name, sign_shift, epi_keys_checked
+        training %>% drop_ignored_keys(keys_to_ignore), forecast_date, col_name, sign_shift, epi_keys_checked
       ))
     if (latency) {
       latency_table <- latency_table %>% mutate(latency = latency)
@@ -328,6 +352,19 @@ get_latency_table <- function(training, columns, forecast_date, latency,
   return(latency_table %>% ungroup())
 }

+#' Given a list named by key columns, remove any matching key values.
+#' keys_to_ignore should have the form list(col_name = c("value_to_ignore", "other_value_to_ignore"))
+#' @keywords internal
+drop_ignored_keys <- function(training, keys_to_ignore) {
+  # note that the extra parenthesis black magic is described here: https://github.com/tidyverse/dplyr/issues/6194
+  # and is needed to bypass an incomplete port of `across` functions to `if_any`
+  training %>%
+    filter((dplyr::if_all(
+      names(keys_to_ignore),
+      ~ . %nin% keys_to_ignore[[cur_column()]]
+    )))
+}
+
 #' checks: the recipe type, whether a previous step is the relevant epi_shift,
 #' that either `fixed_latency` or `fixed_forecast_date` is non-null, and that
@@ -394,7 +431,7 @@ compare_bake_prep_latencies <- function(object, new_data, call = caller_env()) {
   )
   local_latency_table <- get_latency_table(
     new_data, object$columns, current_forecast_date, latency,
-    get_sign(object), object$epi_keys_checked, NULL, NULL
+    get_sign(object), object$epi_keys_checked, object$keys_to_ignore, NULL, NULL
   )
   comparison_table <- local_latency_table %>%
     ungroup() %>%
diff --git a/man/drop_ignored_keys.Rd b/man/drop_ignored_keys.Rd
new file mode 100644
index 000000000..6adeb9983
--- /dev/null
+++ b/man/drop_ignored_keys.Rd
@@ -0,0 +1,14 @@
+% Generated by roxygen2: do not edit by hand
+% Please edit documentation in R/utils-latency.R
+\name{drop_ignored_keys}
+\alias{drop_ignored_keys}
+\title{Given a list named by key columns, remove any matching key values.
+keys_to_ignore should have the form list(col_name = c("value_to_ignore", "other_value_to_ignore"))}
+\usage{
+drop_ignored_keys(training, keys_to_ignore)
+}
+\description{
+Given a list named by key columns, remove any matching key values.
+keys_to_ignore should have the form list(col_name = c("value_to_ignore", "other_value_to_ignore"))
+}
+\keyword{internal}
diff --git a/man/get_latency_table.Rd b/man/get_latency_table.Rd
index ae309c944..853918b23 100644
--- a/man/get_latency_table.Rd
+++ b/man/get_latency_table.Rd
@@ -12,6 +12,7 @@ get_latency_table(
   latency,
   sign_shift,
   epi_keys_checked,
+  keys_to_ignore,
   info,
   terms
 )
diff --git a/man/step_adjust_latency.Rd b/man/step_adjust_latency.Rd
index 59e09f4ff..1a6770428 100644
--- a/man/step_adjust_latency.Rd
+++ b/man/step_adjust_latency.Rd
@@ -9,6 +9,7 @@ step_adjust_latency(
   ...,
   method = c("extend_ahead", "locf", "extend_lags"),
   epi_keys_checked = NULL,
+  keys_to_ignore = c(),
   fixed_latency = NULL,
   fixed_forecast_date = NULL,
   check_latency_length = TRUE,
@@ -50,6 +51,12 @@ it will take the maximum across all values, irrespective of any keys.
 Note that this is a separate concern from different latencies across
 different \emph{data columns}, which is only handled by the choice of \code{method}.}

+\item{keys_to_ignore}{a list of character vectors. Set this to avoid using
+specific key values in the \code{epi_keys_checked} to set latency. For example,
+say you have two locations \code{pr} and \code{gu} which have useful training data,
+but have stopped providing up-to-date information, and so are no longer
+part of the test set. Setting \code{keys_to_ignore = list(geo_value = c("pr", "gu"))} will exclude them from the latency calculation.}
+
 \item{fixed_latency}{either a positive integer, or a labeled positive integer
 vector. Cannot be set at the same time as \code{fixed_forecast_date}. If
 non-\code{NULL}, the amount to offset the ahead or lag by. If a single integer,
diff --git a/tests/testthat/_snaps/population_scaling.md b/tests/testthat/_snaps/population_scaling.md
new file mode 100644
index 000000000..9263e8e1e
--- /dev/null
+++ b/tests/testthat/_snaps/population_scaling.md
@@ -0,0 +1,16 @@
+# expect error if `by` selector does not match
+
+    Code
+      wf <- epi_workflow(r, parsnip::linear_reg()) %>% fit(jhu) %>% add_frosting(f)
+    Condition
+      Error in `hardhat::validate_column_names()`:
+      ! The following required columns are missing: 'a'.
+
+---
+
+    Code
+      forecast(wf)
+    Condition
+      Error in `hardhat::validate_column_names()`:
+      ! The following required columns are missing: 'nothere'.
+
diff --git a/tests/testthat/_snaps/step_epi_shift.md b/tests/testthat/_snaps/step_epi_shift.md
index eaf495995..4c720792c 100644
--- a/tests/testthat/_snaps/step_epi_shift.md
+++ b/tests/testthat/_snaps/step_epi_shift.md
@@ -4,8 +4,8 @@
       r1 <- epi_recipe(x) %>% step_epi_ahead(death_rate, ahead = 3.6) %>% step_epi_lag(
         death_rate, lag = 1.9)
     Condition
-      Error in `step_epi_ahead()`:
-      ! `ahead` must be a non-negative integer.
+      Error in `step_epi_lag()`:
+      ! `lag` must be a non-negative integer.

 # A negative lag value should should throw an error

@@ -21,9 +21,6 @@
     Code
       r3 <- epi_recipe(x) %>% step_epi_ahead(death_rate, ahead = -7) %>% step_epi_lag(
         death_rate, lag = 7)
-    Condition
-      Error in `step_epi_ahead()`:
-      ! `ahead` must be a non-negative integer.

 # Values for ahead and lag cannot be duplicates

diff --git a/tests/testthat/_snaps/step_epi_slide.md b/tests/testthat/_snaps/step_epi_slide.md
index a4b9d64c8..7493a7fea 100644
--- a/tests/testthat/_snaps/step_epi_slide.md
+++ b/tests/testthat/_snaps/step_epi_slide.md
@@ -12,7 +12,7 @@
       r %>% step_epi_slide(value, .f = mean, .window_size = c(3L, 6L))
     Condition
       Error in `epiprocess:::validate_slide_window_arg()`:
-      ! Slide function expected `.window_size` to be a non-null, scalar integer >= 1.
+      ! Slide function expected `.window_size` to be a length-1 difftime with units in days or non-negative integer or Inf.

 ---

@@ -60,7 +60,7 @@
       r %>% step_epi_slide(value, .f = mean, .window_size = 1.5)
     Condition
       Error in `epiprocess:::validate_slide_window_arg()`:
-      ! Slide function expected `.window_size` to be a difftime with units in days or non-negative integer or Inf.
+      ! Slide function expected `.window_size` to be a length-1 difftime with units in days or non-negative integer or Inf.

 ---
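As a standalone illustration of what `drop_ignored_keys` does (toy data; `%nin%` is the package's negated `%in%`, redefined here so the snippet runs on its own):

    library(dplyr)
    `%nin%` <- function(x, table) !(x %in% table)
    training <- tibble::tibble(
      geo_value = c("ma", "pr", "gu"),
      source = c("new", "new", "old"),
      a = 1:3
    )
    keys_to_ignore <- list(geo_value = c("pr", "gu"))
    # keep rows whose value in every listed key column is not ignored;
    # here the "pr" and "gu" rows are dropped, leaving only "ma"
    training %>%
      filter((if_all(
        names(keys_to_ignore),
        ~ . %nin% keys_to_ignore[[cur_column()]]
      )))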
diff --git a/tests/testthat/test-dist_quantiles.R b/tests/testthat/test-dist_quantiles.R
index 9975213c6..ef65c5c11 100644
--- a/tests/testthat/test-dist_quantiles.R
+++ b/tests/testthat/test-dist_quantiles.R
@@ -1,4 +1,4 @@
-library(distributional)
+suppressPackageStartupMessages(library(distributional))

 test_that("constructor returns reasonable quantiles", {
   expect_snapshot(error = TRUE, new_quantiles(rnorm(5), c(-2, -1, 0, 1, 2)))
diff --git a/tests/testthat/test-epi_workflow.R b/tests/testthat/test-epi_workflow.R
index 598436aab..cce68a80f 100644
--- a/tests/testthat/test-epi_workflow.R
+++ b/tests/testthat/test-epi_workflow.R
@@ -103,7 +103,7 @@ test_that("forecast method errors when workflow not fit", {

 test_that("fit method does not silently drop the class", {
   # This is issue #363
-  library(recipes)
+  suppressPackageStartupMessages(library(recipes))
   tbl <- tibble::tibble(
     geo_value = 1,
     time_value = 1:100,
diff --git a/tests/testthat/test-grf_quantiles.R b/tests/testthat/test-grf_quantiles.R
index 2570c247d..57a9d0f98 100644
--- a/tests/testthat/test-grf_quantiles.R
+++ b/tests/testthat/test-grf_quantiles.R
@@ -1,5 +1,5 @@
 set.seed(12345)
-library(grf)
+suppressPackageStartupMessages(library(grf))
 tib <- tibble(
   y = rnorm(100), x = rnorm(100), z = rnorm(100),
   f = factor(sample(letters[1:3], 100, replace = TRUE))
@@ -10,7 +10,7 @@ test_that("quantile_rand_forest defaults work", {
   expect_silent(out <- fit(spec, formula = y ~ x + z, data = tib))
   pars <- parsnip::extract_fit_engine(out)
   manual <- quantile_forest(as.matrix(tib[, 2:3]), tib$y, quantiles = c(0.1, 0.5, 0.9))
-  expect_identical(pars$quantiles.orig, manual$quantiles)
+  expect_identical(pars$quantiles.orig, manual$quantiles.orig)
   expect_identical(pars$`_num_trees`, manual$`_num_trees`)

   fseed <- 12345
diff --git a/tests/testthat/test-snapshots.R b/tests/testthat/test-snapshots.R
index 512303cd1..de28dc9f2 100644
--- a/tests/testthat/test-snapshots.R
+++ b/tests/testthat/test-snapshots.R
@@ -117,7 +117,7 @@ test_that("arx_forecaster output format snapshots", {
     jhu, "death_rate",
     c("case_rate", "death_rate")
   )
-  expect_equal(as.Date(out1$metadata$forecast_created), Sys.Date())
+  expect_equal(as.Date(format(out1$metadata$forecast_created, "%Y-%m-%d")), Sys.Date())
   out1$metadata$forecast_created <- as.Date("0999-01-01")
   expect_snapshot(out1)
   out2 <- arx_forecaster(jhu, "case_rate",
@@ -129,7 +129,7 @@ test_that("arx_forecaster output format snapshots", {
       forecast_date = as.Date("2022-01-03")
     )
   )
-  expect_equal(as.Date(out2$metadata$forecast_created), Sys.Date())
+  expect_equal(as.Date(format(out2$metadata$forecast_created, "%Y-%m-%d")), Sys.Date())
   out2$metadata$forecast_created <- as.Date("0999-01-01")
   expect_snapshot(out2)
   out3 <- arx_forecaster(jhu, "death_rate",
@@ -140,7 +140,7 @@ test_that("arx_forecaster output format snapshots", {
       forecast_date = as.Date("2022-01-03")
     )
   )
-  expect_equal(as.Date(out3$metadata$forecast_created), Sys.Date())
+  expect_equal(as.Date(format(out3$metadata$forecast_created, "%Y-%m-%d")), Sys.Date())
   out3$metadata$forecast_created <- as.Date("0999-01-01")
   expect_snapshot(out3)
 })
diff --git a/tests/testthat/test-step_adjust_latency.R b/tests/testthat/test-step_adjust_latency.R
index 136c45e64..7b1f320e4 100644
--- a/tests/testthat/test-step_adjust_latency.R
+++ b/tests/testthat/test-step_adjust_latency.R
@@ -398,10 +398,6 @@ test_that("epi_adjust_latency correctly extends the lags when there are differen
     names(fit5$pre$mold$outcomes),
     glue::glue("ahead_{ahead}_death_rate")
   )
-  latest <- get_test_data(r5, x)
-  pred <- predict(fit5, latest)
-  actual_solutions <- pred %>% filter(!is.na(.pred))
-  expect_equal(actual_solutions$time_value, testing_as_of + 1)

   # should have four predictors, including the intercept
   expect_equal(length(fit5$fit$fit$fit$coefficients), 6)
diff --git a/tests/testthat/test-step_epi_shift.R b/tests/testthat/test-step_epi_shift.R
index 1f83120b3..2a313b103 100644
--- a/tests/testthat/test-step_epi_shift.R
+++ b/tests/testthat/test-step_epi_shift.R
@@ -39,7 +39,7 @@ test_that("A negative lag value should should throw an error", {

 test_that("A nonpositive ahead value should throw an error", {
   expect_snapshot(
-    error = TRUE,
+    error = FALSE,
     r3 <- epi_recipe(x) %>%
       step_epi_ahead(death_rate, ahead = -7) %>%
       step_epi_lag(death_rate, lag = 7)
@@ -66,3 +66,8 @@ test_that("Check that epi_lag shifts applies the shift", {
   # Should have four predictors, including the intercept
   expect_equal(length(fit5$fit$fit$fit$coefficients), 4)
 })
+
+test_that("Shifting nothing is a no-op", {
+  expect_no_error(noop <- epi_recipe(x) %>% step_epi_ahead(ahead = 3) %>% prep(x) %>% bake(x))
+  expect_equal(noop, x)
+})
diff --git a/tests/testthat/test-utils_latency.R b/tests/testthat/test-utils_latency.R
index d0fd429a9..2ac32fc9f 100644
--- a/tests/testthat/test-utils_latency.R
+++ b/tests/testthat/test-utils_latency.R
@@ -62,6 +62,20 @@ toy_df <- tribble(
   "ca", as.Date("2015-01-12"), 103, 10,
 ) %>% as_epi_df(as_of = as.Date("2015-01-14"))
+toy_df_src <- tribble(
+  ~geo_value, ~source, ~time_value, ~a, ~b,
+  "ma", "new", as.Date("2015-01-11"), 20, 6,
+  "ma", "new", as.Date("2015-01-12"), 23, NA,
+  "ma", "new", as.Date("2015-01-13"), 25, NA,
+  "ca", "new", as.Date("2015-01-11"), 100, 5,
+  "ca", "new", as.Date("2015-01-12"), 103, 10,
+  "ma", "old", as.Date("2013-01-01"), 19, 4,
+  "ma", "old", as.Date("2013-01-02"), 20, 2,
+  "ca", "old", as.Date("2013-01-03"), 28, 11,
+  "na", "new", as.Date("2013-01-05"), 28, 11,
+  "ma", "older", as.Date("2010-01-05"), 28, 11,
+) %>%
+  as_epi_df(as_of = as.Date("2015-01-14"), other_keys = "source")

 test_that("get_latency works", {
   expect_equal(get_latency(modified_data, as_of, "case_rate", 1, "geo_value"), 5)
@@ -76,6 +90,27 @@
   expect_equal(get_latency(toy_df, as.Date("2015-01-14"), "b", -1, "geo_value"), -3)
 })

+test_that("get_latency ignores keys it's supposed to", {
+  keys_to_ignore <- list(geo_value = c("na"), source = c("old", "older"))
+  expected_df <- tribble(
+    ~geo_value, ~source, ~time_value, ~a, ~b,
+    "ma", "new", as.Date("2015-01-11"), 20, 6,
+    "ma", "new", as.Date("2015-01-12"), 23, NA,
+    "ma", "new", as.Date("2015-01-13"), 25, NA,
+    "ca", "new", as.Date("2015-01-11"), 100, 5,
+    "ca", "new", as.Date("2015-01-12"), 103, 10,
+  )
+  expect_equal(
+    toy_df_src %>% drop_ignored_keys(keys_to_ignore) %>% as_tibble(),
+    expected_df
+  )
+
+  expect_equal(
+    get_latency_table(toy_df_src, c("a", "b"), as.Date("2015-01-14"), NULL, -1, c("geo_value", "source"), keys_to_ignore),
+    tibble(col_name = c("a", "b"), latency = c(-2, -3))
+  )
+})
+
 test_that("get_latency infers max_time to be the minimum `max time` across grouping the specified keys", {
   # place 2 is already 1 day less latent than place 1, so decreasing it's
   # latency it should have no effect
@@ -100,6 +135,17 @@ test_that("get_forecast_date works", {
   expect_equal(get_forecast_date(modified_data, info, "", NULL), as_of)
   expect_equal(get_forecast_date(modified_data, info, NULL, NULL), as_of)
 })
+test_that("get_forecast_date works for multiple key columns", {
+  info <- tribble(
+    ~variable, ~type, ~role, ~source,
+    "time_value", "date", "time_value", "original",
+    "geo_value", "nominal", "geo_value", "original",
+    "source", "nominal", "other_key", "original",
+    "a", "numeric", "raw", "original",
+    "b", "numeric", "raw", "original",
+  )
+  expect_equal(get_forecast_date(toy_df_src, info, c("geo_value", "source"), NULL), attributes(toy_df_src)$metadata$as_of)
+})

 test_that("pad_to_end works correctly", {
   single_ex <- tribble(