diff --git a/R/slide.R b/R/slide.R index ba5f3f3a..d94dd01a 100644 --- a/R/slide.R +++ b/R/slide.R @@ -186,15 +186,16 @@ epi_slide <- function( # Validate arguments assert_class(.x, "epi_df") - if (checkmate::test_class(.x, "grouped_df")) { + .x_orig_groups <- groups(.x) + if (inherits(.x, "grouped_df")) { expected_group_keys <- .x %>% key_colnames(exclude = "time_value") %>% sort() if (!identical(.x %>% group_vars() %>% sort(), expected_group_keys)) { cli_abort( - "epi_slide: `.x` must be either grouped by {expected_group_keys}. (Or you can just ungroup - `.x` and we'll do this grouping automatically.) You may need to aggregate your data first, - see aggregate_epi_df().", + "`.x` must be either grouped by {expected_group_keys} or ungrouped; if the latter, + we'll temporarily group by {expected_group_keys} for this operation. You may need + to aggregate your data first, see aggregate_epi_df().", class = "epiprocess__epi_slide__invalid_grouping" ) } @@ -300,7 +301,6 @@ epi_slide <- function( # `epi_slide_one_group`. # - `...` from top of `epi_slide` are forwarded to `.f` here through # group_modify and through the lambda. - .x_groups <- groups(.x) result <- group_map( .x, .f = function(.data_group, .group_key, ...) { @@ -324,7 +324,7 @@ epi_slide <- function( filter(.real) %>% select(-.real) %>% arrange_col_canonical() %>% - group_by(!!!.x_groups) + group_by(!!!.x_orig_groups) # If every group in epi_slide_one_group takes the # length(available_ref_time_values) == 0 branch then we end up here. @@ -691,12 +691,30 @@ epi_slide_opt <- function( ) } + assert_class(.x, "epi_df") + .x_orig_groups <- groups(.x) + if (inherits(.x, "grouped_df")) { + expected_group_keys <- .x %>% + key_colnames(exclude = "time_value") %>% + sort() + if (!identical(.x %>% group_vars() %>% sort(), expected_group_keys)) { + cli_abort( + "`.x` must be either grouped by {expected_group_keys} or ungrouped; if the latter, + we'll temporarily group by {expected_group_keys} for this operation. You may need + to aggregate your data first, see aggregate_epi_df().", + class = "epiprocess__epi_slide__invalid_grouping" + ) + } + } else { + .x <- group_epi_df(.x, exclude = "time_value") + } if (nrow(.x) == 0L) { cli_abort( c( "input data `.x` unexpectedly has 0 rows", "i" = "If this computation is occuring within an `epix_slide` call, - check that `epix_slide` `.versions` argument was set appropriately" + check that `epix_slide` `.versions` argument was set appropriately + so that you don't get any completely-empty snapshots" ), class = "epiprocess__epi_slide_opt__0_row_input", epiprocess__x = .x @@ -857,27 +875,9 @@ epi_slide_opt <- function( arrange(.data$time_value) if (f_from_package == "data.table") { - # If a group contains duplicate time values, `frollmean` will still only - # use the last `k` obs. It isn't looking at dates, it just goes in row - # order. So if the computation is aggregating across multiple obs for the - # same date, `epi_slide_opt` and derivates will produce incorrect results; - # `epi_slide` should be used instead. - if (anyDuplicated(.data_group$time_value) != 0L) { - cli_abort( - c( - "group contains duplicate time values. Using `epi_slide_[opt/mean/sum]` on this - group will result in incorrect results", - "i" = "Please change the grouping structure of the input data so that - each group has non-duplicate time values (e.g. `x %>% group_by(geo_value) - %>% epi_slide_opt(.f = frollmean)`)", - "i" = "Use `epi_slide` to aggregate across groups" - ), - class = "epiprocess__epi_slide_opt__duplicate_time_values", - epiprocess__data_group = .data_group, - epiprocess__group_key = .group_key - ) - } - + # Grouping should ensure that we don't have duplicate time values. + # Completion above should ensure we have at least .window_size rows. Check + # that we don't have more than .window_size rows (or fewer somehow): if (nrow(.data_group) != length(c(all_dates, pad_early_dates, pad_late_dates))) { cli_abort( c( @@ -928,7 +928,8 @@ epi_slide_opt <- function( group_modify(slide_one_grp, ..., .keep = FALSE) %>% filter(.data$.real) %>% select(-.real) %>% - arrange_col_canonical() + arrange_col_canonical() %>% + group_by(!!!.x_orig_groups) if (.all_rows) { result[!(result$time_value %in% ref_time_values), result_col_names] <- NA diff --git a/man-roxygen/basic-slide-params.R b/man-roxygen/basic-slide-params.R index 638307d6..8ccd35f9 100644 --- a/man-roxygen/basic-slide-params.R +++ b/man-roxygen/basic-slide-params.R @@ -1,5 +1,5 @@ -#' @param .x An `epi_df` object. If ungrouped, we group by `geo_value` and any -#' columns in `other_keys`. If grouped, we make sure the grouping is by +#' @param .x An `epi_df` object. If ungrouped, we temporarily group by `geo_value` +#' and any columns in `other_keys`. If grouped, we make sure the grouping is by #' `geo_value` and `other_keys`. #' @param .window_size The size of the sliding window. The accepted values #' depend on the type of the `time_value` column in `.x`: diff --git a/man/epi_slide.Rd b/man/epi_slide.Rd index c497b5d3..9f4abd36 100644 --- a/man/epi_slide.Rd +++ b/man/epi_slide.Rd @@ -16,8 +16,8 @@ epi_slide( ) } \arguments{ -\item{.x}{An \code{epi_df} object. If ungrouped, we group by \code{geo_value} and any -columns in \code{other_keys}. If grouped, we make sure the grouping is by +\item{.x}{An \code{epi_df} object. If ungrouped, we temporarily group by \code{geo_value} +and any columns in \code{other_keys}. If grouped, we make sure the grouping is by \code{geo_value} and \code{other_keys}.} \item{.f}{Function, formula, or missing; together with \code{...} specifies the diff --git a/man/epi_slide_opt.Rd b/man/epi_slide_opt.Rd index e61e94cd..48453641 100644 --- a/man/epi_slide_opt.Rd +++ b/man/epi_slide_opt.Rd @@ -47,8 +47,8 @@ epi_slide_sum( ) } \arguments{ -\item{.x}{An \code{epi_df} object. If ungrouped, we group by \code{geo_value} and any -columns in \code{other_keys}. If grouped, we make sure the grouping is by +\item{.x}{An \code{epi_df} object. If ungrouped, we temporarily group by \code{geo_value} +and any columns in \code{other_keys}. If grouped, we make sure the grouping is by \code{geo_value} and \code{other_keys}.} \item{.col_names}{<\code{\link[=dplyr_tidy_select]{tidy-select}}> An unquoted column diff --git a/tests/testthat/test-epi_slide.R b/tests/testthat/test-epi_slide.R index 2cb04eec..0aa4aca7 100644 --- a/tests/testthat/test-epi_slide.R +++ b/tests/testthat/test-epi_slide.R @@ -899,3 +899,41 @@ test_that("epi_slide_opt output naming features", { class = "epiprocess__epi_slide_opt_new_name_duplicated" ) }) + +test_that("epi_slide* output grouping matches input grouping", { + toy_edf <- as_epi_df(bind_rows(list( + tibble(geo_value = 1, age_group = 1, time_value = as.Date("2020-01-01") + 1:10 - 1, value = 1:10), + tibble(geo_value = 1, age_group = 2, time_value = as.Date("2020-01-01") + 1:10 - 1, value = 20:11), + tibble(geo_value = 2, age_group = 2, time_value = as.Date("2020-01-01") + 1:10 - 1, value = 31:40) + )), other_keys = "age_group", as_of = as.Date("2020-01-01") + 20) + + # Preserving existing grouping: + expect_equal( + toy_edf %>% + group_by(age_group, geo_value) %>% + epi_slide(value_7dsum = sum(value), .window_size = 7) %>% + group_vars(), + c("age_group", "geo_value") + ) + expect_equal( + toy_edf %>% + group_by(age_group, geo_value) %>% + epi_slide_sum(value, .window_size = 7) %>% + group_vars(), + c("age_group", "geo_value") + ) + + # Removing automatic grouping: + expect_equal( + toy_edf %>% + epi_slide(value_7dsum = sum(value), .window_size = 7) %>% + group_vars(), + character(0) + ) + expect_equal( + toy_edf %>% + epi_slide_sum(value, .window_size = 7) %>% + group_vars(), + character(0) + ) +})