Skip to content

Commit

Permalink
use closure to fetch min_ref_time_values from starts instead of
Browse files Browse the repository at this point in the history
recalculating
  • Loading branch information
nmdefries committed Jan 17, 2024
1 parent 71e11f7 commit 94a4d27
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 42 deletions.
59 changes: 17 additions & 42 deletions R/slide.R
Original file line number Diff line number Diff line change
Expand Up @@ -227,37 +227,15 @@ epi_slide = function(x, f, ..., before, after, ref_time_values,
after <- time_step(after)
}

# Do set up to let us recover `ref_time_value`s later.
min_ref_time_values = ref_time_values - before
min_ref_time_values_not_in_x <- min_ref_time_values[!(min_ref_time_values %in% unique(x$time_value))]

# Do set up to let us recover `ref_time_value`s later.
# A helper column marking real observations.
x$.real = TRUE

# Create df containing phony data. Df has the same columns and attributes as
# `x`, but filled with `NA`s aside from grouping columns. Number of rows is
# equal to the number of `min_ref_time_values_not_in_x` we have * the
# number of unique levels seen in the grouping columns.
before_time_values_df = data.frame(time_value=min_ref_time_values_not_in_x)
if (length(group_vars(x)) != 0) {
before_time_values_df = dplyr::cross_join(
# Get unique combinations of grouping columns seen in real data.
unique(x[, group_vars(x)]),
before_time_values_df
)
}
# Automatically fill in all other columns from `x` with `NA`s, and carry
# attributes over to new df.
before_time_values_df <- bind_rows(x[0,], before_time_values_df)
before_time_values_df$.real <- FALSE

x <- bind_rows(before_time_values_df, x)

# Arrange by increasing time_value
x = arrange(x, time_value)

# Now set up starts and stops for sliding/hopping
time_range = range(unique(x$time_value))
time_range = range(unique(c(x$time_value, min_ref_time_values_not_in_x)))
starts = in_range(ref_time_values - before, time_range)
stops = in_range(ref_time_values + after, time_range)

Expand All @@ -270,7 +248,7 @@ epi_slide = function(x, f, ..., before, after, ref_time_values,

# Computation for one group, all time values
slide_one_grp = function(.data_group,
f, ...,
f_factory, ...,
starts,
stops,
time_values,
Expand All @@ -285,6 +263,8 @@ epi_slide = function(x, f, ..., before, after, ref_time_values,
stops = stops[o]
time_values = time_values[o]

f <- f_factory(starts)

# Compute the slide values
slide_values_list = slider::hop_index(.x = .data_group,
.i = .data_group$time_value,
Expand Down Expand Up @@ -344,7 +324,6 @@ epi_slide = function(x, f, ..., before, after, ref_time_values,
# fills with NA equivalent.
vctrs::vec_slice(slide_values, o) = orig_values
} else {
# This implicitly removes phony (`.real` == FALSE) observations.
.data_group = filter(.data_group, o)
}
return(mutate(.data_group, !!new_col := slide_values))
Expand All @@ -367,15 +346,20 @@ epi_slide = function(x, f, ..., before, after, ref_time_values,

f = as_slide_computation(f, ...)
# Create a wrapper that calculates and passes `.ref_time_value` to the
# computation.
f_wrapper = function(.x, .group_key, ...) {
.ref_time_value = min(.x$time_value) + before
.x <- .x[.x$.real,]
.x$.real <- NULL
f(.x, .group_key, .ref_time_value, ...)
# computation. `i` lives in the environment created by each call to
# `f_wrapper_factory`; because `slide_one_grp` calls the factory once per
# group, `i` is reset to 1 for every group.
f_wrapper_factory = function(starts) {
# Use `i` to advance through list of start dates.
i <- 1L
f_wrapper = function(.x, .group_key, ...) {
.ref_time_value = starts[[i]] + before
i <<- i + 1L
f(.x, .group_key, .ref_time_value, ...)
}
return(f_wrapper)
}
x = group_modify(x, slide_one_grp,
f = f_wrapper, ...,
f_factory = f_wrapper_factory, ...,
starts = starts,
stops = stops,
time_values = ref_time_values,
Expand All @@ -388,14 +372,5 @@ epi_slide = function(x, f, ..., before, after, ref_time_values,
x = unnest(x, !!new_col, names_sep = names_sep)
}

# Remove any remaining phony observations. When `all_rows` is TRUE, phony
# observations aren't necessarily removed in `slide_one_grp`.
if (all_rows) {
x <- x[x$.real,]
}

# Drop helper column `.real`.
x$.real <- NULL

return(x)
}
44 changes: 44 additions & 0 deletions tests/testthat/test-epi_slide.R
Original file line number Diff line number Diff line change
Expand Up @@ -526,3 +526,47 @@ test_that("`epi_slide` can access objects inside of helper functions", {
NA
)
})

test_that("epi_slide basic behavior is correct when groups have non-overlapping date ranges", {
  # Two groups whose time values sit 150 steps apart (offsets from `d`), so
  # their ranges never overlap.
  ak_rows <- dplyr::tibble(
    geo_value = "ak",
    time_value = d + 1:5,
    value = 11:15
  )
  al_rows <- dplyr::tibble(
    geo_value = "al",
    time_value = d + 151:155,
    value = -(1:5)
  )

  misaligned_input <- dplyr::bind_rows(ak_rows, al_rows) %>%
    as_epi_df(as_of = d + 6) %>%
    group_by(geo_value)

  # Each group has only five rows, so a trailing window of 50 should take in
  # every earlier row of the same group: the expected sum is a running total.
  expected_ak <- dplyr::mutate(ak_rows, slide_value = cumsum(value))
  expected_al <- dplyr::mutate(al_rows, slide_value = cumsum(value))
  expected_output <- dplyr::bind_rows(expected_ak, expected_al) %>%
    group_by(geo_value) %>%
    as_epi_df(as_of = d + 6)

  slid <- epi_slide(misaligned_input, f = ~ sum(.x$value), before = 50)
  expect_identical(slid, expected_output)
})


test_that("epi_slide gets correct ref_time_value when groups have non-overlapping date ranges", {
  # Same fixture shape as the basic-behavior test: two groups whose time
  # values (offsets from `d`) do not overlap.
  ak_rows <- dplyr::tibble(
    geo_value = "ak",
    time_value = d + 1:5,
    value = 11:15
  )
  al_rows <- dplyr::tibble(
    geo_value = "al",
    time_value = d + 151:155,
    value = -(1:5)
  )

  misaligned_input <- dplyr::bind_rows(ak_rows, al_rows) %>%
    as_epi_df(as_of = d + 6) %>%
    group_by(geo_value)

  # The computation returns `.ref_time_value` itself, so every row's
  # `slide_value` is expected to equal that row's own `time_value`.
  expected_ak <- dplyr::mutate(ak_rows, slide_value = time_value)
  expected_al <- dplyr::mutate(al_rows, slide_value = time_value)
  expected_output <- dplyr::bind_rows(expected_ak, expected_al) %>%
    group_by(geo_value) %>%
    as_epi_df(as_of = d + 6)

  slid <- epi_slide(
    misaligned_input,
    before = 50,
    slide_value = .ref_time_value
  )

  expect_identical(slid, expected_output)
})

0 comments on commit 94a4d27

Please sign in to comment.