step locf tests passing, grf pkgdown
dsweber2 committed Sep 5, 2024
1 parent 9c4f465 commit 92ea3e3
Showing 11 changed files with 64 additions and 69 deletions.
3 changes: 3 additions & 0 deletions NAMESPACE
@@ -302,6 +302,7 @@ importFrom(stats,family)
importFrom(stats,lm)
importFrom(stats,median)
importFrom(stats,model.frame)
importFrom(stats,na.omit)
importFrom(stats,poly)
importFrom(stats,predict)
importFrom(stats,qnorm)
@@ -315,6 +316,8 @@ importFrom(tidyr,expand_grid)
importFrom(tidyr,fill)
importFrom(tidyr,unnest)
importFrom(tidyselect,all_of)
importFrom(utils,capture.output)
importFrom(utils,head)
importFrom(vctrs,as_list_of)
importFrom(vctrs,field)
importFrom(vctrs,new_rcrd)
2 changes: 1 addition & 1 deletion R/get_test_data.R
@@ -25,8 +25,8 @@
#' step_epi_lag(case_rate, lag = c(0, 7, 14))
#' get_test_data(recipe = rec, x = case_death_rate_subset)
#' @importFrom rlang %@%
#' @importFrom stats na.omit
#' @export

get_test_data <- function(recipe, x) {
if (!is_epi_df(x)) cli::cli_abort("`x` must be an `epi_df`.")

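For context, the roxygen example above boils down to the following call pattern. This is a minimal sketch: the recipe construction is assumed from the truncated example text, and case_death_rate_subset is the dataset that example references.

library(epipredict)
library(dplyr)
# assumed reconstruction of the truncated roxygen example above
rec <- epi_recipe(case_death_rate_subset) %>%
  step_epi_lag(case_rate, lag = c(0, 7, 14))
# get_test_data() keeps only the most recent observations needed to fill those lags
test_data <- get_test_data(recipe = rec, x = case_death_rate_subset)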
2 changes: 1 addition & 1 deletion R/step_adjust_latency.R
@@ -302,7 +302,7 @@ bake.step_adjust_latency <- function(object, new_data, ...) {
rel_keys <- setdiff(key_colnames(new_data), "time_value")
unnamed_columns <- object$columns %>% unname()
new_data <- new_data %>%
pad_to_end(rel_keys, object$forecast_date) %>%
pad_to_end(rel_keys, object$forecast_date, unnamed_columns) %>%
# group_by_at(rel_keys) %>%
arrange(time_value) %>%
as_tibble() %>%
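The point of threading unnamed_columns through to pad_to_end() is that only the columns this step was told to adjust get carried forward to the forecast date. A rough tidyverse illustration of that behaviour follows; it is not the package's internal code, and the toy columns and dates are made up.

library(dplyr)
library(tidyr)
toy <- tibble(
  geo_value = "ca",
  time_value = as.Date("2024-01-01") + 0:2,
  case_rate = c(1, 2, 3),
  death_rate = c(0.1, 0.2, 0.3)
)
forecast_date <- as.Date("2024-01-05")
padded <- toy %>%
  # add the missing dates up to the forecast date for each geo
  complete(geo_value, time_value = seq(min(time_value), forecast_date, by = "day")) %>%
  arrange(geo_value, time_value) %>%
  group_by(geo_value) %>%
  # carry forward only the selected column; death_rate stays NA on the padded rows
  fill(case_rate, .direction = "down") %>%
  ungroup()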
23 changes: 4 additions & 19 deletions R/utils-latency.R
@@ -33,6 +33,7 @@ construct_shift_tibble <- function(terms_used, recipe, rel_step_type, shift_name
#' @keywords internal
#' @importFrom dplyr select
#' @importFrom tidyr drop_na
#' @importFrom utils capture.output
set_forecast_date <- function(new_data, info, epi_keys_checked, latency) {
original_columns <- info %>%
filter(source == "original") %>%
@@ -161,25 +162,6 @@ get_forecast_date_in_layer <- function(this_recipe, workflow_max_time_value, new
}


fill_locf <- function(x, forecast_date) {
cannot_be_used <- x %>%
dplyr::filter(forecast_date - time_value <= n_recent) %>%
dplyr::mutate(fillers = forecast_date - time_value > keep) %>%
dplyr::summarise(
dplyr::across(
-tidyselect::any_of(key_colnames(recipe)),
~ all(is.na(.x[fillers])) & is.na(head(.x[!fillers], 1))
),
.groups = "drop"
) %>%
dplyr::select(-fillers) %>%
dplyr::summarise(dplyr::across(
-tidyselect::any_of(key_colnames(recipe)), ~ any(.x)
)) %>%
unlist()
x <- tidyr::fill(x, !time_value)
}

#' pad every group at the right interval
#' @description
#' Perform last observation carried forward on a group by group basis. It uses
Expand Down Expand Up @@ -226,6 +208,9 @@ pad_to_end <- function(x, groups, end_date, columns_to_complete = NULL) {
}

#' return the names of the grouped columns, or `NULL`
#' @param x an epi_df
#' @keywords internal
#' @importFrom utils head
get_grouping_columns <- function(x) {
group_names <- names(attributes(x)$groups)
head(group_names, -1)
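A quick sketch of what get_grouping_columns() relies on: for a grouped data frame, the groups attribute is a tibble whose columns are the grouping variables plus a trailing .rows list-column, so dropping the last name leaves just the grouping columns. Toy data below, not from the package.

library(dplyr)
x <- tibble(geo_value = c("ca", "ca", "ny"), time_value = 1:3, value = c(0.5, 0.7, 0.2)) %>%
  group_by(geo_value)
names(attributes(x)$groups)
# expected: "geo_value" ".rows"
head(names(attributes(x)$groups), -1)
# expected: "geo_value"   (what get_grouping_columns() returns)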
1 change: 1 addition & 0 deletions _pkgdown.yml
@@ -76,6 +76,7 @@ reference:
contents:
- quantile_reg
- smooth_quantile_reg
- grf_quantiles
- title: Custom panel data forecasting workflows
contents:
- epi_recipe
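grf_quantiles now appears next to quantile_reg and smooth_quantile_reg in the reference index. Assuming it is exposed as a parsnip engine for random forests (wrapping grf's quantile forests), a model spec would look roughly like the sketch below; only the engine name comes from the entry above, the rest of the call pattern is a guess and should be checked against ?grf_quantiles.

library(epipredict)
library(parsnip)
# hypothetical call pattern; verify against the grf_quantiles documentation
qrf_spec <- rand_forest(mode = "regression") %>%
  set_engine("grf_quantiles")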
4 changes: 4 additions & 0 deletions man/get_grouping_columns.Rd

Some generated files are not rendered by default.

26 changes: 26 additions & 0 deletions tests/testthat/_snaps/snapshots.md
@@ -1057,6 +1057,32 @@
18998, 18998, 18998, 18998, 18998), class = "Date")), row.names = c(NA,
-6L), class = c("tbl_df", "tbl", "data.frame"))

---

structure(list(geo_value = c("ca", "fl", "ga", "ny", "pa", "tx"
), .pred = c(0.303244704017743, 0.531332853311082, 0.588827944685979,
0.988690249216229, 0.794801997001639, 0.306895457225321), .pred_distn = structure(list(
structure(list(values = c(0.136509784083987, 0.469979623951498
), quantile_levels = c(0.05, 0.95)), class = c("dist_quantiles",
"dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list(
values = c(0.364597933377326, 0.698067773244837), quantile_levels = c(0.05,
0.95)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd",
"vctrs_vctr")), structure(list(values = c(0.422093024752224,
0.755562864619735), quantile_levels = c(0.05, 0.95)), class = c("dist_quantiles",
"dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list(
values = c(0.821955329282474, 1.15542516914998), quantile_levels = c(0.05,
0.95)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd",
"vctrs_vctr")), structure(list(values = c(0.628067077067883,
0.961536916935394), quantile_levels = c(0.05, 0.95)), class = c("dist_quantiles",
"dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list(
values = c(0.140160537291566, 0.473630377159077), quantile_levels = c(0.05,
0.95)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd",
"vctrs_vctr"))), class = c("distribution", "vctrs_vctr",
"list")), forecast_date = structure(c(18997, 18997, 18997, 18997,
18997, 18997), class = "Date"), target_date = structure(c(18998,
18998, 18998, 18998, 18998, 18998), class = "Date")), row.names = c(NA,
-6L), class = c("tbl_df", "tbl", "data.frame"))

# arx_forecaster output format snapshots

Code
37 changes: 0 additions & 37 deletions tests/testthat/test-pad_to_end.R

This file was deleted.

13 changes: 13 additions & 0 deletions tests/testthat/test-snapshots.R
@@ -94,6 +94,19 @@ test_that("arx_forecaster snapshots", {
)
# not the same predictions
expect_false(all(arx2$predictions == arx3$predictions))


arx4 <- arx_forecaster(
train_data,
"death_rate_7d_av",
c("death_rate_7d_av", "case_rate_7d_av"),
args_list = arx_args_list(
ahead = 1L,
adjust_latency = "locf"
)
)
# consistency check
expect_snapshot_tibble(arx3$predictions)
})

test_that("arx_forecaster output format snapshots", {
20 changes: 10 additions & 10 deletions tests/testthat/test-step_adjust_latency.R
@@ -186,14 +186,14 @@ test_that("epi_adjust_latency correctly locfs", {
last_dates,
tribble(
~name, ~last_date,
"lag_11_death_rate", max_time + 11,
"lag_6_death_rate", max_time + 6,
"lag_5_case_rate", max_time + 5,
"lag_1_case_rate", max_time + 1,
"case_rate", max_time,
"death_rate", max_time,
"lag_0_death_rate", max_time + 0,
"ahead_7_death_rate", max_time - 7,
"lag_11_death_rate", max_time + 16,
"lag_6_death_rate", max_time + 11,
"lag_5_case_rate", max_time + 10,
"lag_1_case_rate", max_time + 6,
"case_rate", max_time + 5,
"death_rate", max_time + 5,
"lag_0_death_rate", max_time + 5,
"ahead_7_death_rate", max_time - 2,
)
)
# we expect a 5-fold repetition of the last values found in the original
@@ -204,7 +204,7 @@
slice_tail() %>%
ungroup() %>%
select(case_rate, death_rate) %>%
uncount(5)
tidyr::uncount(5)
# pulling just the region between the last day and the prediction day
filled_values <-
baked_x %>%
@@ -450,7 +450,7 @@ test_that("`step_adjust_latency` only uses the columns specified in the `...`",
summarise(last_date = max(time_value)) %>%
arrange(desc(last_date)) %>%
mutate(locf_date = last_date - latency)
# iterate over all columns and make sure the latent time period has the exact same values
# iterate over all columns and make sure the latent time period has the exact same values (so the variance is zero)
for (ii in seq(nrow(last_dates))) {
baked_var <- baked_x %>%
filter(last_dates[[ii, "locf_date"]] <= time_value, time_value <= last_dates[[ii, "last_date"]]) %>%
2 changes: 1 addition & 1 deletion tests/testthat/test-utils_latency.R
@@ -149,7 +149,7 @@ test_that("pad_to_end handles weeks", {
),
a = 3, b = .9
)
) %>% arrange(time_value, geo_value)
) %>% arrange(geo_value, time_value)
)
})
# todo case where somehow columns of different roles are selected
