From 96c83d664406f749cf357d01061e479bf0346189 Mon Sep 17 00:00:00 2001 From: Nat DeFries <42820733+nmdefries@users.noreply.github.com> Date: Wed, 17 Jan 2024 11:50:14 -0500 Subject: [PATCH] use closure to fetch min_ref_time_values from `starts` instead of recalculating --- R/slide.R | 59 ++++++++++----------------------- tests/testthat/test-epi_slide.R | 44 ++++++++++++++++++++++++ 2 files changed, 61 insertions(+), 42 deletions(-) diff --git a/R/slide.R b/R/slide.R index 2a10efce..9f2cccad 100644 --- a/R/slide.R +++ b/R/slide.R @@ -230,37 +230,15 @@ epi_slide <- function(x, f, ..., before, after, ref_time_values, after <- time_step(after) } + # Do set up to let us recover `ref_time_value`s later. min_ref_time_values <- ref_time_values - before min_ref_time_values_not_in_x <- min_ref_time_values[!(min_ref_time_values %in% unique(x$time_value))] - # Do set up to let us recover `ref_time_value`s later. - # A helper column marking real observations. - x$.real <- TRUE - - # Create df containing phony data. Df has the same columns and attributes as - # `x`, but filled with `NA`s aside from grouping columns. Number of rows is - # equal to the number of `min_ref_time_values_not_in_x` we have * the - # number of unique levels seen in the grouping columns. - before_time_values_df <- data.frame(time_value = min_ref_time_values_not_in_x) - if (length(group_vars(x)) != 0) { - before_time_values_df <- dplyr::cross_join( - # Get unique combinations of grouping columns seen in real data. - unique(x[, group_vars(x)]), - before_time_values_df - ) - } - # Automatically fill in all other columns from `x` with `NA`s, and carry - # attributes over to new df. - before_time_values_df <- bind_rows(x[0, ], before_time_values_df) - before_time_values_df$.real <- FALSE - - x <- bind_rows(before_time_values_df, x) - # Arrange by increasing time_value x <- arrange(x, time_value) # Now set up starts and stops for sliding/hopping - time_range <- range(unique(x$time_value)) + time_range <- range(unique(c(x$time_value, min_ref_time_values_not_in_x))) starts <- in_range(ref_time_values - before, time_range) stops <- in_range(ref_time_values + after, time_range) @@ -273,7 +251,7 @@ epi_slide <- function(x, f, ..., before, after, ref_time_values, # Computation for one group, all time values slide_one_grp <- function(.data_group, - f, ..., + f_factory, ..., starts, stops, time_values, @@ -288,6 +266,8 @@ epi_slide <- function(x, f, ..., before, after, ref_time_values, stops <- stops[o] time_values <- time_values[o] + f <- f_factory(starts) + # Compute the slide values slide_values_list <- slider::hop_index( .x = .data_group, @@ -349,7 +329,6 @@ epi_slide <- function(x, f, ..., before, after, ref_time_values, # fills with NA equivalent. vctrs::vec_slice(slide_values, o) <- orig_values } else { - # This implicitly removes phony (`.real` == FALSE) observations. .data_group <- filter(.data_group, o) } return(mutate(.data_group, !!new_col := slide_values)) @@ -372,15 +351,20 @@ epi_slide <- function(x, f, ..., before, after, ref_time_values, f <- as_slide_computation(f, ...) # Create a wrapper that calculates and passes `.ref_time_value` to the - # computation. - f_wrapper <- function(.x, .group_key, ...) { - .ref_time_value <- min(.x$time_value) + before - .x <- .x[.x$.real, ] - .x$.real <- NULL - f(.x, .group_key, .ref_time_value, ...) + # computation. `i` is contained in the `f_wrapper_factory` environment such + # that when called within `slide_one_grp` `i` is reset for every group. + f_wrapper_factory <- function(starts) { + # Use `i` to advance through list of start dates. + i <- 1L + f_wrapper <- function(.x, .group_key, ...) { + .ref_time_value <- starts[[i]] + before + i <<- i + 1L + f(.x, .group_key, .ref_time_value, ...) + } + return(f_wrapper) } x <- group_modify(x, slide_one_grp, - f = f_wrapper, ..., + f_factory = f_wrapper_factory, ..., starts = starts, stops = stops, time_values = ref_time_values, @@ -394,14 +378,5 @@ epi_slide <- function(x, f, ..., before, after, ref_time_values, x <- unnest(x, !!new_col, names_sep = names_sep) } - # Remove any remaining phony observations. When `all_rows` is TRUE, phony - # observations aren't necessarily removed in `slide_one_grp`. - if (all_rows) { - x <- x[x$.real, ] - } - - # Drop helper column `.real`. - x$.real <- NULL - return(x) } diff --git a/tests/testthat/test-epi_slide.R b/tests/testthat/test-epi_slide.R index e2bbc040..cd38dc97 100644 --- a/tests/testthat/test-epi_slide.R +++ b/tests/testthat/test-epi_slide.R @@ -626,3 +626,47 @@ test_that("`epi_slide` can access objects inside of helper functions", { NA ) }) + +test_that("epi_slide basic behavior is correct when groups have non-overlapping date ranges", { + small_x_misaligned_dates <- dplyr::bind_rows( + dplyr::tibble(geo_value = "ak", time_value = d + 1:5, value = 11:15), + dplyr::tibble(geo_value = "al", time_value = d + 151:155, value = -(1:5)) + ) %>% + as_epi_df(as_of = d + 6) %>% + group_by(geo_value) + + expected_output <- dplyr::bind_rows( + dplyr::tibble(geo_value = "ak", time_value = d + 1:5, value = 11:15, slide_value = cumsum(11:15)), + dplyr::tibble(geo_value = "al", time_value = d + 151:155, value = -(1:5), slide_value = cumsum(-(1:5))) + ) %>% + group_by(geo_value) %>% + as_epi_df(as_of = d + 6) + + result1 <- epi_slide(small_x_misaligned_dates, f = ~ sum(.x$value), before = 50) + expect_identical(result1, expected_output) +}) + + +test_that("epi_slide gets correct ref_time_value when groups have non-overlapping date ranges", { + small_x_misaligned_dates <- dplyr::bind_rows( + dplyr::tibble(geo_value = "ak", time_value = d + 1:5, value = 11:15), + dplyr::tibble(geo_value = "al", time_value = d + 151:155, value = -(1:5)) + ) %>% + as_epi_df(as_of = d + 6) %>% + group_by(geo_value) + + expected_output <- dplyr::bind_rows( + dplyr::tibble(geo_value = "ak", time_value = d + 1:5, value = 11:15, slide_value = d + 1:5), + dplyr::tibble(geo_value = "al", time_value = d + 151:155, value = -(1:5), slide_value = d + 151:155) + ) %>% + group_by(geo_value) %>% + as_epi_df(as_of = d + 6) + + result1 <- small_x_misaligned_dates %>% + epi_slide( + before = 50, + slide_value = .ref_time_value + ) + + expect_identical(result1, expected_output) +})