From 2e307577fb0c611a22d11806898483fd2795b194 Mon Sep 17 00:00:00 2001 From: "Logan C. Brooks" Date: Mon, 28 Oct 2024 17:08:51 -0700 Subject: [PATCH] Make key inference more consistent; allow non-`epi_df` forged data --- R/key_colnames.R | 13 +++-- R/layer_population_scaling.R | 3 +- R/step_population_scaling.R | 10 ++-- R/utils-misc.R | 61 ++++++++++++++---------- tests/testthat/test-population_scaling.R | 20 ++++---- 5 files changed, 62 insertions(+), 45 deletions(-) diff --git a/R/key_colnames.R b/R/key_colnames.R index b8d07ce8..9e0d44dc 100644 --- a/R/key_colnames.R +++ b/R/key_colnames.R @@ -11,11 +11,14 @@ key_colnames.recipe <- function(x, ..., exclude = character()) { key_colnames.epi_workflow <- function(x, ..., exclude = character()) { # safer to look at the mold than the preprocessor mold <- hardhat::extract_mold(x) - molded_names <- names(mold$extras$roles) - geo_key <- names(mold$extras$roles[molded_names %in% "geo_value"]$geo_value) - time_key <- names(mold$extras$roles[molded_names %in% "time_value"]$time_value) - keys <- names(mold$extras$roles[molded_names %in% "key"]$key) - full_key <- c(geo_key, keys, time_key) %||% character(0L) + molded_roles <- mold$extras$roles + extras <- bind_cols(molded_roles$geo_value, molded_roles$key, molded_roles$time_value) + full_key <- names(extras) + if (length(full_key) == 0L) { + # No epikeytime role assignment; infer from all columns: + potential_keys <- c("geo_value", "time_value") + full_key <- potential_keys[potential_keys %in% names(bind_cols(molded_roles))] + } full_key[!full_key %in% exclude] } diff --git a/R/layer_population_scaling.R b/R/layer_population_scaling.R index 5a982d34..f47e3f29 100644 --- a/R/layer_population_scaling.R +++ b/R/layer_population_scaling.R @@ -145,7 +145,8 @@ slather.layer_population_scaling <- suggested_min_keys <- kill_time_value(lhs_potential_keys) if (!all(suggested_min_keys %in% object$by)) { cli_warn(c( - "Couldn't find {setdiff(suggested_min_keys, object$by)} in population `df`", + "{setdiff(suggested_min_keys, object$by)} {?was an/were} epikey column{?s} in the predictions, + but {?wasn't/weren't} found in the population `df`.", "i" = "Defaulting to join by {object$by}", ">" = "Double-check whether column names on the population `df` match those expected in your predictions", ">" = "Consider using population data with breakdowns by {suggested_min_keys}", diff --git a/R/step_population_scaling.R b/R/step_population_scaling.R index d7b7893e..297c3072 100644 --- a/R/step_population_scaling.R +++ b/R/step_population_scaling.R @@ -177,9 +177,10 @@ prep.step_population_scaling <- function(x, training, info = NULL, ...) { } if (!all(suggested_min_keys %in% x$by)) { cli_warn(c( - "Couldn't find {setdiff(suggested_min_keys, x$by)} in population `df`.", + "{setdiff(suggested_min_keys, x$by)} {?was an/were} epikey column{?s} in the training data, + but {?wasn't/weren't} found in the population `df`.", "i" = "Defaulting to join by {x$by}.", - ">" = "Double-check whether column names on the population `df` match those for your time series.", + ">" = "Double-check whether column names on the population `df` match those for your training data.", ">" = "Consider using population data with breakdowns by {suggested_min_keys}.", ">" = "Manually specify `by =` to silence." ), class = "epipredict__step_population_scaling__default_by_missing_suggested_keys") @@ -229,8 +230,9 @@ bake.step_population_scaling <- function(object, new_data, ...) { col_to_remove <- setdiff(colnames(object$df), colnames(new_data)) inner_join(new_data, object$df, - by = object$by, relationship = "many-to-one", unmatched = c("error", "drop"), - suffix = c("", ".df")) %>% + by = object$by, relationship = "many-to-one", unmatched = c("error", "drop"), + suffix = c("", ".df") + ) %>% mutate( across( all_of(object$columns), diff --git a/R/utils-misc.R b/R/utils-misc.R index 7ab4ad95..fec70791 100644 --- a/R/utils-misc.R +++ b/R/utils-misc.R @@ -54,40 +54,53 @@ format_varnames <- function(x, empty = "*none*") { grab_forged_keys <- function(forged, workflow, new_data) { # 1. keys in the training data post-prep, based on roles: old_keys <- key_colnames(workflow) - # 3. keys in the test data post-bake, based on roles: - forged_roles <- names(forged$extras$roles) - extras <- bind_cols(forged$extras$roles[forged_roles %in% c("geo_value", "time_value", "key")]) - new_keys <- names(extras) + # 2. keys in the test data post-bake, based on roles & structure: + forged_roles <- forged$extras$roles + new_key_tbl <- bind_cols(forged_roles$geo_value, forged_roles$key, forged_roles$time_value) + new_keys <- names(new_key_tbl) if (length(new_keys) == 0L) { # No epikeytime role assignment; infer from all columns: - potential_keys <- c("geo_value", "time_value") - new_keys <- potential_keys[potential_keys %in% names(bind_cols(forged$extras$roles))] + potential_new_keys <- c("geo_value", "time_value") + forged_tbl <- bind_cols(forged$extras$roles) + new_keys <- potential_new_keys[potential_new_keys %in% names(forged_tbl)] + new_key_tbl <- forged_tbl[new_keys] } - # 2. keys in the test data pre-bake based on data structure + post-bake roles: - new_df_keys <- key_colnames(new_data, other_keys = setdiff(new_keys, c("geo_value", "time_value"))) - # Softly validate, assuming that no steps change epikeytime role assignments: - if (!(setequal(old_keys, new_df_keys) && setequal(new_df_keys, new_keys))) { + # Softly validate: + if (!(setequal(old_keys, new_keys))) { cli_warn(c( - "Inconsistent epikeytime identifier columns specified/inferred.", + "Inconsistent epikeytime identifier columns specified/inferred in training vs. in testing data.", "i" = "training epikeytime columns, based on roles post-mold/prep: {format_varnames(old_keys)}", - "i" = " testing epikeytime columns, based on data structure pre-bake and roles post-forge/bake: {format_varnames(new_df_keys)}", - "i" = " testing epikeytime columns, based on roles post-forge/bake: {format_varnames(new_keys)}", - "*" = "Keys will be set giving preference to test-time `epi_df` metadata followed by test-time - post-bake role settings.", + "i" = "testing epikeytime columns, based on roles post-forge/bake: {format_varnames(new_keys)}", + "*" = "", ">" = 'Some mismatches can be addressed by using `epi_df`s instead of tibbles, or by using `update_role` to assign pre-`prep` columns the "geo_value", "key", and "time_value" roles.' )) } - if (is_epi_df(new_data)) { - # Inference based on test data pre-bake data structure "wins": - meta <- attr(new_data, "metadata") - extras <- as_epi_df(extras, as_of = meta$as_of, other_keys = meta$other_keys) - } else if (all(c("geo_value", "time_value") %in% new_keys)) { - # Inference based on test data post-bake roles "wins": - other_keys <- new_keys[!new_keys %in% c("geo_value", "time_value")] - extras <- as_epi_df(extras, other_keys = other_keys) + # Convert `new_key_tbl` to `epi_df` if not renaming columns nor violating + # `epi_df` invariants. Require that our key is a unique key in any case. + if (all(c("geo_value", "time_value") %in% new_keys)) { + maybe_as_of <- attr(new_data, "metadata")$as_of # NULL if wasn't epi_df + try(return(as_epi_df(new_key_tbl, other_keys = new_keys, as_of = maybe_as_of)), + silent = TRUE + ) + } + if (anyDuplicated(new_key_tbl)) { + duplicate_key_tbl <- new_key_tbl %>% filter(.by = everything(), dplyr::n() > 1L) + error_part1 <- cli::format_error( + c( + "Specified/inferred key columns had repeated combinations in the forged/baked test data.", + "i" = "Key columns: {format_varnames(new_keys)}", + "Duplicated keys:" + ) + ) + error_part2 <- capture.output(print(duplicate_key_tbl)) + rlang::abort( + paste(collapse = "\n", c(error_part1, error_part2)), + class = "epipredict__grab_forged_keys__nonunique_key" + ) + } else { + return(new_key_tbl) } - extras } get_parsnip_mode <- function(trainer) { diff --git a/tests/testthat/test-population_scaling.R b/tests/testthat/test-population_scaling.R index eabca9f7..63c23a38 100644 --- a/tests/testthat/test-population_scaling.R +++ b/tests/testthat/test-population_scaling.R @@ -398,7 +398,9 @@ test_that("test joining by default columns with less common keys/classes", { update_role("age_group", new_role = "key") %>% update_role("time_value", new_role = "time_value") %>% step_population_scaling(y, df = pop1b2, df_pop_col = "population", role = "outcome") %>% - {.}, + { + . + }, model_spec, frosting() %>% layer_predict() %>% @@ -432,14 +434,16 @@ test_that("test joining by default columns with less common keys/classes", { class = "epipredict__layer_population_scaling__default_by_missing_suggested_keys" ) - # Same thing but with time series in tibble, but no role hints -> different inference&messaging: + # Same thing but with time series in tibble, but no role hints -> different behavior: dat1b3 <- dat1b2 pop1b3 <- pop1b2 ewf1b3 <- epi_workflow( # Can't use epi_recipe or step_epi_ahead; adjust. recipe(dat1b3) %>% step_population_scaling(y, df = pop1b3, df_pop_col = "population", role = "outcome") %>% - {.}, + { + . + }, model_spec, frosting() %>% layer_predict() %>% @@ -453,15 +457,10 @@ test_that("test joining by default columns with less common keys/classes", { # geo 1 scaling used for both: mutate(y_scaled = c(3e-6, 7 * 11 / 5e6)) ) - expect_equal( + expect_error( predict(fit(ewf1b3, dat1b3), dat1b3) %>% pivot_quantiles_wider(.pred), - dat1b3 %>% - select(!"y") %>% - as_tibble() %>% - # geo 1 scaling used for both: - mutate(`0.5` = c(2 * 5, 2 * 5)) %>% - select(geo_value, age_group, time_value, `0.5`) + class = "epipredict__grab_forged_keys__nonunique_key" ) # With geo x age_group breakdown on both: @@ -571,7 +570,6 @@ test_that("test joining by default columns with less common keys/classes", { # TODO non-`epi_df` scaling? # TODO multikey scaling? - })