diff --git a/NAMESPACE b/NAMESPACE index 4840c2568..0900d3594 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -302,6 +302,7 @@ importFrom(stats,family) importFrom(stats,lm) importFrom(stats,median) importFrom(stats,model.frame) +importFrom(stats,na.omit) importFrom(stats,poly) importFrom(stats,predict) importFrom(stats,qnorm) @@ -315,6 +316,8 @@ importFrom(tidyr,expand_grid) importFrom(tidyr,fill) importFrom(tidyr,unnest) importFrom(tidyselect,all_of) +importFrom(utils,capture.output) +importFrom(utils,head) importFrom(vctrs,as_list_of) importFrom(vctrs,field) importFrom(vctrs,new_rcrd) diff --git a/R/get_test_data.R b/R/get_test_data.R index f1d83aad0..5e74da5a1 100644 --- a/R/get_test_data.R +++ b/R/get_test_data.R @@ -25,8 +25,8 @@ #' step_epi_lag(case_rate, lag = c(0, 7, 14)) #' get_test_data(recipe = rec, x = case_death_rate_subset) #' @importFrom rlang %@% +#' @importFrom stats na.omit #' @export - get_test_data <- function(recipe, x) { if (!is_epi_df(x)) cli::cli_abort("`x` must be an `epi_df`.") diff --git a/R/step_adjust_latency.R b/R/step_adjust_latency.R index cdef74565..1275c3948 100644 --- a/R/step_adjust_latency.R +++ b/R/step_adjust_latency.R @@ -302,7 +302,7 @@ bake.step_adjust_latency <- function(object, new_data, ...) 
{ rel_keys <- setdiff(key_colnames(new_data), "time_value") unnamed_columns <- object$columns %>% unname() new_data <- new_data %>% - pad_to_end(rel_keys, object$forecast_date) %>% + pad_to_end(rel_keys, object$forecast_date, unnamed_columns) %>% # group_by_at(rel_keys) %>% arrange(time_value) %>% as_tibble() %>% diff --git a/R/utils-latency.R b/R/utils-latency.R index b32edf7b5..35db0b484 100644 --- a/R/utils-latency.R +++ b/R/utils-latency.R @@ -33,6 +33,7 @@ construct_shift_tibble <- function(terms_used, recipe, rel_step_type, shift_name #' @keywords internal #' @importFrom dplyr select #' @importFrom tidyr drop_na +#' @importFrom utils capture.output set_forecast_date <- function(new_data, info, epi_keys_checked, latency) { original_columns <- info %>% filter(source == "original") %>% @@ -161,25 +162,6 @@ get_forecast_date_in_layer <- function(this_recipe, workflow_max_time_value, new } -fill_locf <- function(x, forecast_date) { - cannot_be_used <- x %>% - dplyr::filter(forecast_date - time_value <= n_recent) %>% - dplyr::mutate(fillers = forecast_date - time_value > keep) %>% - dplyr::summarise( - dplyr::across( - -tidyselect::any_of(key_colnames(recipe)), - ~ all(is.na(.x[fillers])) & is.na(head(.x[!fillers], 1)) - ), - .groups = "drop" - ) %>% - dplyr::select(-fillers) %>% - dplyr::summarise(dplyr::across( - -tidyselect::any_of(key_colnames(recipe)), ~ any(.x) - )) %>% - unlist() - x <- tidyr::fill(x, !time_value) -} - #' pad every group at the right interval #' @description #' Perform last observation carried forward on a group by group basis. 
It uses @@ -226,6 +208,9 @@ pad_to_end <- function(x, groups, end_date, columns_to_complete = NULL) { } #' return the names of the grouped columns, or `NULL` +#' @param x an epi_df +#' @keywords internal +#' @importFrom utils head get_grouping_columns <- function(x) { group_names <- names(attributes(x)$groups) head(group_names, -1) diff --git a/_pkgdown.yml b/_pkgdown.yml index c6df4c82d..468da62ac 100644 --- a/_pkgdown.yml +++ b/_pkgdown.yml @@ -76,6 +76,7 @@ reference: contents: - quantile_reg - smooth_quantile_reg + - grf_quantiles - title: Custom panel data forecasting workflows contents: - epi_recipe diff --git a/man/get_grouping_columns.Rd b/man/get_grouping_columns.Rd index 6b653628d..f8b61af42 100644 --- a/man/get_grouping_columns.Rd +++ b/man/get_grouping_columns.Rd @@ -6,6 +6,10 @@ \usage{ get_grouping_columns(x) } +\arguments{ +\item{x}{an epi_df} +} \description{ return the names of the grouped columns, or \code{NULL} } +\keyword{internal} diff --git a/tests/testthat/_snaps/snapshots.md b/tests/testthat/_snaps/snapshots.md index 6c439dbd8..6b337c056 100644 --- a/tests/testthat/_snaps/snapshots.md +++ b/tests/testthat/_snaps/snapshots.md @@ -1057,6 +1057,32 @@ 18998, 18998, 18998, 18998, 18998), class = "Date")), row.names = c(NA, -6L), class = c("tbl_df", "tbl", "data.frame")) +--- + + structure(list(geo_value = c("ca", "fl", "ga", "ny", "pa", "tx" + ), .pred = c(0.303244704017743, 0.531332853311082, 0.588827944685979, + 0.988690249216229, 0.794801997001639, 0.306895457225321), .pred_distn = structure(list( + structure(list(values = c(0.136509784083987, 0.469979623951498 + ), quantile_levels = c(0.05, 0.95)), class = c("dist_quantiles", + "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( + values = c(0.364597933377326, 0.698067773244837), quantile_levels = c(0.05, + 0.95)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", + "vctrs_vctr")), structure(list(values = c(0.422093024752224, + 0.755562864619735), quantile_levels = c(0.05, 
0.95)), class = c("dist_quantiles", + "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( + values = c(0.821955329282474, 1.15542516914998), quantile_levels = c(0.05, + 0.95)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", + "vctrs_vctr")), structure(list(values = c(0.628067077067883, + 0.961536916935394), quantile_levels = c(0.05, 0.95)), class = c("dist_quantiles", + "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( + values = c(0.140160537291566, 0.473630377159077), quantile_levels = c(0.05, + 0.95)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", + "vctrs_vctr"))), class = c("distribution", "vctrs_vctr", + "list")), forecast_date = structure(c(18997, 18997, 18997, 18997, + 18997, 18997), class = "Date"), target_date = structure(c(18998, + 18998, 18998, 18998, 18998, 18998), class = "Date")), row.names = c(NA, + -6L), class = c("tbl_df", "tbl", "data.frame")) + # arx_forecaster output format snapshots Code diff --git a/tests/testthat/test-pad_to_end.R b/tests/testthat/test-pad_to_end.R deleted file mode 100644 index 0ea6244b0..000000000 --- a/tests/testthat/test-pad_to_end.R +++ /dev/null @@ -1,37 +0,0 @@ -test_that("test set padding works", { - dat <- tibble::tibble( - gr1 = rep(c("a", "b"), times = c(3, 4)), - time_value = c(1:3, 1:4), - value = 1:7 - ) %>% arrange(time_value, gr1) - expect_identical(pad_to_end(dat, "gr1", 3), dat) - expect_equal(nrow(pad_to_end(dat, "gr1", 4)), 8L) - p <- pad_to_end(dat, "gr1", 5) - expect_equal(nrow(p), 10L) - expect_identical(p$gr1, rep(c("a", "b"), times = 5)) - expect_identical(p$time_value, rep(1:5, each = 2)) - expect_identical(p$value, as.integer(c(1, 4, 2, 5, 3, 6, NA, 7, NA, NA))) - - dat <- dat %>% arrange(gr1) - dat$gr2 <- c("c", "c", "d", "c", "c", "d", "d") - dat <- dat %>% arrange(time_value) - # don't treat it as a group - p <- pad_to_end(dat, "gr1", 4) - expect_identical(nrow(p), 8L) - expect_identical(p$gr2, c(rep("c", 4), "d", "d", NA, "d")) - - # treat it as a 
group (needs different time_value) - dat$time_value <- c(1, 1, 2, 2, 1, 1, 2) # double - p <- pad_to_end(dat, c("gr1", "gr2"), 2) - expect_equal(nrow(p), 8L) - expect_identical(p$gr1, rep(c("a", "a", "b", "b"), times = 2)) - expect_identical(p$gr2, rep(c("c", "d"), times = 4)) - expect_identical(p$time_value, rep(c(1, 2), each = 4)) - expect_identical(p$value, as.integer(c(1, 3, 4, 6, 2, NA, 5, 7))) - - # make sure it maintains the epi_df - dat <- dat %>% - dplyr::rename(geo_value = gr1) %>% - as_epi_df() - expect_s3_class(pad_to_end(dat, "geo_value", 2), "epi_df") -}) diff --git a/tests/testthat/test-snapshots.R b/tests/testthat/test-snapshots.R index 7f1d46006..003fc8319 100644 --- a/tests/testthat/test-snapshots.R +++ b/tests/testthat/test-snapshots.R @@ -94,6 +94,19 @@ test_that("arx_forecaster snapshots", { ) # not the same predictions expect_false(all(arx2$predictions == arx3$predictions)) + + + arx4 <- arx_forecaster( + train_data, + "death_rate_7d_av", + c("death_rate_7d_av", "case_rate_7d_av"), + args_list = arx_args_list( + ahead = 1L, + adjust_latency = "locf" + ) + ) + # consistency check + expect_snapshot_tibble(arx4$predictions) }) test_that("arx_forecaster output format snapshots", { diff --git a/tests/testthat/test-step_adjust_latency.R b/tests/testthat/test-step_adjust_latency.R index a557fc146..75656989a 100644 --- a/tests/testthat/test-step_adjust_latency.R +++ b/tests/testthat/test-step_adjust_latency.R @@ -186,14 +186,14 @@ test_that("epi_adjust_latency correctly locfs", { last_dates, tribble( ~name, ~last_date, - "lag_11_death_rate", max_time + 11, - "lag_6_death_rate", max_time + 6, - "lag_5_case_rate", max_time + 5, - "lag_1_case_rate", max_time + 1, - "case_rate", max_time, - "death_rate", max_time, - "lag_0_death_rate", max_time + 0, - "ahead_7_death_rate", max_time - 7, + "lag_11_death_rate", max_time + 16, + "lag_6_death_rate", max_time + 11, + "lag_5_case_rate", max_time + 10, + "lag_1_case_rate", max_time + 6, + "case_rate", max_time + 
5, + "death_rate", max_time + 5, + "lag_0_death_rate", max_time + 5, + "ahead_7_death_rate", max_time - 2, ) ) # we expect a 5-fold repetition of the last values found in the original @@ -204,7 +204,7 @@ test_that("epi_adjust_latency correctly locfs", { slice_tail() %>% ungroup() %>% select(case_rate, death_rate) %>% - uncount(5) + tidyr::uncount(5) # pulling just the region between the last day and the prediction day filled_values <- baked_x %>% @@ -450,7 +450,7 @@ test_that("`step_adjust_latency` only uses the columns specified in the `...`", summarise(last_date = max(time_value)) %>% arrange(desc(last_date)) %>% mutate(locf_date = last_date - latency) - # iterate over all columns and make sure the latent time period has the exact same values + # iterate over all columns and make sure the latent time period has the exact same values (so the variance is zero) for (ii in seq(nrow(last_dates))) { baked_var <- baked_x %>% filter(last_dates[[ii, "locf_date"]] <= time_value, time_value <= last_dates[[ii, "last_date"]]) %>% diff --git a/tests/testthat/test-utils_latency.R b/tests/testthat/test-utils_latency.R index a23be628b..78f294564 100644 --- a/tests/testthat/test-utils_latency.R +++ b/tests/testthat/test-utils_latency.R @@ -149,7 +149,7 @@ test_that("pad_to_end handles weeks", { ), a = 3, b = .9 ) - ) %>% arrange(time_value, geo_value) + ) %>% arrange(geo_value, time_value) ) }) # todo case where somehow columns of different roles are selected