Skip to content

Commit

Permalink
testing the step
Browse files Browse the repository at this point in the history
  • Loading branch information
dsweber2 committed Sep 4, 2024
1 parent 3194dfc commit 9c4f465
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 1 deletion.
2 changes: 1 addition & 1 deletion R/step_adjust_latency.R
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ bake.step_adjust_latency <- function(object, new_data, ...) {
# locf doesn't need to mess with the metadata at all, it just forward-fills the requested columns
rel_keys <- setdiff(key_colnames(new_data), "time_value")
unnamed_columns <- object$columns %>% unname()
new_data %>%
new_data <- new_data %>%
pad_to_end(rel_keys, object$forecast_date) %>%
# group_by_at(rel_keys) %>%
arrange(time_value) %>%
Expand Down
52 changes: 52 additions & 0 deletions tests/testthat/test-step_adjust_latency.R
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,58 @@ test_that("epi_adjust_latency correctly extends the ahead", {
expect_equal(length(fit2$fit$fit$fit$coefficients), 6)
})

test_that("epi_adjust_latency correctly locfs", {
r1 <- epi_recipe(x) %>%
step_adjust_latency(method = "locf") %>%
step_epi_lag(death_rate, lag = c(0, 6, 11)) %>%
step_epi_lag(case_rate, lag = c(1, 5)) %>%
step_epi_ahead(death_rate, ahead = ahead)

# directly checking the shifts
baked_x <- r1 %>%
prep(real_x) %>%
bake(real_x)
# map each column to its last non-NA value
last_dates <- baked_x %>%
tidyr::pivot_longer(cols = contains("rate"), values_drop_na = TRUE) %>%
group_by(name) %>%
summarise(last_date = max(time_value)) %>%
arrange(desc(last_date))
expect_equal(
last_dates,
tribble(
~name, ~last_date,
"lag_11_death_rate", max_time + 11,
"lag_6_death_rate", max_time + 6,
"lag_5_case_rate", max_time + 5,
"lag_1_case_rate", max_time + 1,
"case_rate", max_time,
"death_rate", max_time,
"lag_0_death_rate", max_time + 0,
"ahead_7_death_rate", max_time - 7,
)
)
# we expect a 5-fold repetition of the last values found in the original
# epi_df
last_real <- real_x %>%
group_by(geo_value) %>%
arrange(time_value) %>%
slice_tail() %>%
ungroup() %>%
select(case_rate, death_rate) %>%
uncount(5)
# pulling just the region between the last day and the prediction day
filled_values <-
baked_x %>%
filter(
time_value > max(real_x$time_value),
time_value <= attributes(real_x)$metadata$as_of
) %>%
ungroup() %>%
select(case_rate, death_rate)
expect_equal(last_real, filled_values)
})

test_that("epi_adjust_latency extends multiple aheads", {
aheads <- 1:3
r3 <- epi_recipe(x) %>%
Expand Down

0 comments on commit 9c4f465

Please sign in to comment.