diff --git a/R/step_adjust_latency.R b/R/step_adjust_latency.R index 0dd0f3600..cdef74565 100644 --- a/R/step_adjust_latency.R +++ b/R/step_adjust_latency.R @@ -301,7 +301,7 @@ bake.step_adjust_latency <- function(object, new_data, ...) { # locf doesn't need to mess with the metadata at all, it just forward-fills the requested columns rel_keys <- setdiff(key_colnames(new_data), "time_value") unnamed_columns <- object$columns %>% unname() - new_data %>% + new_data <- new_data %>% pad_to_end(rel_keys, object$forecast_date) %>% # group_by_at(rel_keys) %>% arrange(time_value) %>% diff --git a/tests/testthat/test-step_adjust_latency.R b/tests/testthat/test-step_adjust_latency.R index 1f8a2889f..a557fc146 100644 --- a/tests/testthat/test-step_adjust_latency.R +++ b/tests/testthat/test-step_adjust_latency.R @@ -165,6 +165,58 @@ test_that("epi_adjust_latency correctly extends the ahead", { expect_equal(length(fit2$fit$fit$fit$coefficients), 6) }) +test_that("epi_adjust_latency correctly locfs", { + r1 <- epi_recipe(x) %>% + step_adjust_latency(method = "locf") %>% + step_epi_lag(death_rate, lag = c(0, 6, 11)) %>% + step_epi_lag(case_rate, lag = c(1, 5)) %>% + step_epi_ahead(death_rate, ahead = ahead) + + # directly checking the shifts + baked_x <- r1 %>% + prep(real_x) %>% + bake(real_x) + # map each column to its last non-NA value + last_dates <- baked_x %>% + tidyr::pivot_longer(cols = contains("rate"), values_drop_na = TRUE) %>% + group_by(name) %>% + summarise(last_date = max(time_value)) %>% + arrange(desc(last_date)) + expect_equal( + last_dates, + tribble( + ~name, ~last_date, + "lag_11_death_rate", max_time + 11, + "lag_6_death_rate", max_time + 6, + "lag_5_case_rate", max_time + 5, + "lag_1_case_rate", max_time + 1, + "case_rate", max_time, + "death_rate", max_time, + "lag_0_death_rate", max_time + 0, + "ahead_7_death_rate", max_time - 7, + ) + ) + # we expect a 5-fold repetition of the last values found in the original + # epi_df + last_real <- real_x %>% + group_by(geo_value) %>% + arrange(time_value) %>% + slice_tail() %>% + ungroup() %>% + select(case_rate, death_rate) %>% + uncount(5) + # pulling just the region between the last day and the prediction day + filled_values <- + baked_x %>% + filter( + time_value > max(real_x$time_value), + time_value <= attributes(real_x)$metadata$as_of + ) %>% + ungroup() %>% + select(case_rate, death_rate) + expect_equal(last_real, filled_values) +}) + test_that("epi_adjust_latency extends multiple aheads", { aheads <- 1:3 r3 <- epi_recipe(x) %>%