Make key inference more consistent; allow non-epi_df forged data
brookslogan committed Oct 29, 2024
1 parent 7090cf0 commit 2e30757
Showing 5 changed files with 62 additions and 45 deletions.
13 changes: 8 additions & 5 deletions R/key_colnames.R
@@ -11,11 +11,14 @@ key_colnames.recipe <- function(x, ..., exclude = character()) {
key_colnames.epi_workflow <- function(x, ..., exclude = character()) {
# safer to look at the mold than the preprocessor
mold <- hardhat::extract_mold(x)
molded_names <- names(mold$extras$roles)
geo_key <- names(mold$extras$roles[molded_names %in% "geo_value"]$geo_value)
time_key <- names(mold$extras$roles[molded_names %in% "time_value"]$time_value)
keys <- names(mold$extras$roles[molded_names %in% "key"]$key)
full_key <- c(geo_key, keys, time_key) %||% character(0L)
molded_roles <- mold$extras$roles
extras <- bind_cols(molded_roles$geo_value, molded_roles$key, molded_roles$time_value)
full_key <- names(extras)
if (length(full_key) == 0L) {
# No epikeytime role assignment; infer from all columns:
potential_keys <- c("geo_value", "time_value")
full_key <- potential_keys[potential_keys %in% names(bind_cols(molded_roles))]
}
full_key[!full_key %in% exclude]
}
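For orientation, here is a minimal standalone sketch of the inference order implemented above, using a hand-built stand-in for hardhat::extract_mold(x)$extras$roles (the molded_roles list and its contents are invented for illustration; only dplyr is assumed):

library(dplyr)

# Stand-in for mold$extras$roles: one small tibble per role.
molded_roles <- list(
  geo_value  = tibble(geo_value = c("ca", "ny")),
  key        = tibble(age_group = c("0-17", "18+")),
  time_value = tibble(time_value = as.Date("2024-10-01") + 0:1)
)

# Preferred path: take the key names from the role-assigned columns, in epikey order.
full_key <- names(bind_cols(molded_roles$geo_value, molded_roles$key, molded_roles$time_value))
if (length(full_key) == 0L) {
  # Fallback: no epikeytime roles were assigned, so look for the conventional
  # column names among all molded columns.
  potential_keys <- c("geo_value", "time_value")
  full_key <- potential_keys[potential_keys %in% names(bind_cols(molded_roles))]
}
full_key
#> [1] "geo_value"  "age_group"  "time_value"

Binding the role tibbles, rather than subsetting by role name as before, keeps the geo/key/time ordering stable and simply skips roles that are absent.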

3 changes: 2 additions & 1 deletion R/layer_population_scaling.R
@@ -145,7 +145,8 @@ slather.layer_population_scaling <-
suggested_min_keys <- kill_time_value(lhs_potential_keys)
if (!all(suggested_min_keys %in% object$by)) {
cli_warn(c(
"Couldn't find {setdiff(suggested_min_keys, object$by)} in population `df`",
"{setdiff(suggested_min_keys, object$by)} {?was an/were} epikey column{?s} in the predictions,
but {?wasn't/weren't} found in the population `df`.",
"i" = "Defaulting to join by {object$by}",
">" = "Double-check whether column names on the population `df` match those expected in your predictions",
">" = "Consider using population data with breakdowns by {suggested_min_keys}",
10 changes: 6 additions & 4 deletions R/step_population_scaling.R
@@ -177,9 +177,10 @@ prep.step_population_scaling <- function(x, training, info = NULL, ...) {
}
if (!all(suggested_min_keys %in% x$by)) {
cli_warn(c(
"Couldn't find {setdiff(suggested_min_keys, x$by)} in population `df`.",
"{setdiff(suggested_min_keys, x$by)} {?was an/were} epikey column{?s} in the training data,
but {?wasn't/weren't} found in the population `df`.",
"i" = "Defaulting to join by {x$by}.",
">" = "Double-check whether column names on the population `df` match those for your time series.",
">" = "Double-check whether column names on the population `df` match those for your training data.",
">" = "Consider using population data with breakdowns by {suggested_min_keys}.",
">" = "Manually specify `by =` to silence."
), class = "epipredict__step_population_scaling__default_by_missing_suggested_keys")
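An aside on the message template above: cli's inline pluralization markers ({?was an/were}, {?s}, {?wasn't/weren't}) choose the singular or plural form from the length of the most recently interpolated vector. A tiny sketch, with a made-up missing_keys vector standing in for setdiff(suggested_min_keys, x$by):

library(cli)

missing_keys <- "age_group" # try c("age_group", "geo_value") to see the plural forms
cli_warn(c(
  "{missing_keys} {?was an/were} epikey column{?s} in the training data,
  but {?wasn't/weren't} found in the population `df`.",
  "i" = "Defaulting to join by geo_value."
))
# warns roughly: "age_group was an epikey column in the training data,
# but wasn't found in the population `df`."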
@@ -229,8 +230,9 @@ bake.step_population_scaling <- function(object, new_data, ...) {
col_to_remove <- setdiff(colnames(object$df), colnames(new_data))

inner_join(new_data, object$df,
by = object$by, relationship = "many-to-one", unmatched = c("error", "drop"),
suffix = c("", ".df")) %>%
by = object$by, relationship = "many-to-one", unmatched = c("error", "drop"),
suffix = c("", ".df")
) %>%
mutate(
across(
all_of(object$columns),
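For context, a small self-contained sketch (invented data; requires dplyr >= 1.1.1) of the join guardrails used in the inner_join() above: relationship = "many-to-one" errors if a row of new_data matches more than one population row, and unmatched = c("error", "drop") errors on new_data rows with no population match while quietly dropping unused population rows:

library(dplyr)

new_data <- tibble(geo_value = c("ca", "ca", "ny"), y = c(1, 2, 3))
pop_df   <- tibble(geo_value = c("ca", "ny", "tx"), population = c(39e6, 19e6, 30e6))

inner_join(new_data, pop_df,
  by = "geo_value", relationship = "many-to-one", unmatched = c("error", "drop"),
  suffix = c("", ".df")
)
# 3 rows: each new_data row picks up exactly one population value;
# the unused "tx" population row is dropped without complaint.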
61 changes: 37 additions & 24 deletions R/utils-misc.R
@@ -54,40 +54,53 @@ format_varnames <- function(x, empty = "*none*") {
grab_forged_keys <- function(forged, workflow, new_data) {
# 1. keys in the training data post-prep, based on roles:
old_keys <- key_colnames(workflow)
# 3. keys in the test data post-bake, based on roles:
forged_roles <- names(forged$extras$roles)
extras <- bind_cols(forged$extras$roles[forged_roles %in% c("geo_value", "time_value", "key")])
new_keys <- names(extras)
# 2. keys in the test data post-bake, based on roles & structure:
forged_roles <- forged$extras$roles
new_key_tbl <- bind_cols(forged_roles$geo_value, forged_roles$key, forged_roles$time_value)
new_keys <- names(new_key_tbl)
if (length(new_keys) == 0L) {
# No epikeytime role assignment; infer from all columns:
potential_keys <- c("geo_value", "time_value")
new_keys <- potential_keys[potential_keys %in% names(bind_cols(forged$extras$roles))]
potential_new_keys <- c("geo_value", "time_value")
forged_tbl <- bind_cols(forged$extras$roles)
new_keys <- potential_new_keys[potential_new_keys %in% names(forged_tbl)]
new_key_tbl <- forged_tbl[new_keys]
}
# 2. keys in the test data pre-bake based on data structure + post-bake roles:
new_df_keys <- key_colnames(new_data, other_keys = setdiff(new_keys, c("geo_value", "time_value")))
# Softly validate, assuming that no steps change epikeytime role assignments:
if (!(setequal(old_keys, new_df_keys) && setequal(new_df_keys, new_keys))) {
# Softly validate:
if (!(setequal(old_keys, new_keys))) {
cli_warn(c(
"Inconsistent epikeytime identifier columns specified/inferred.",
"Inconsistent epikeytime identifier columns specified/inferred in training vs. in testing data.",
"i" = "training epikeytime columns, based on roles post-mold/prep: {format_varnames(old_keys)}",
"i" = " testing epikeytime columns, based on data structure pre-bake and roles post-forge/bake: {format_varnames(new_df_keys)}",
"i" = " testing epikeytime columns, based on roles post-forge/bake: {format_varnames(new_keys)}",
"*" = "Keys will be set giving preference to test-time `epi_df` metadata followed by test-time
post-bake role settings.",
"i" = "testing epikeytime columns, based on roles post-forge/bake: {format_varnames(new_keys)}",
"*" = "",
">" = 'Some mismatches can be addressed by using `epi_df`s instead of tibbles, or by using `update_role`
to assign pre-`prep` columns the "geo_value", "key", and "time_value" roles.'
))
}
if (is_epi_df(new_data)) {
# Inference based on test data pre-bake data structure "wins":
meta <- attr(new_data, "metadata")
extras <- as_epi_df(extras, as_of = meta$as_of, other_keys = meta$other_keys)
} else if (all(c("geo_value", "time_value") %in% new_keys)) {
# Inference based on test data post-bake roles "wins":
other_keys <- new_keys[!new_keys %in% c("geo_value", "time_value")]
extras <- as_epi_df(extras, other_keys = other_keys)
# Convert `new_key_tbl` to `epi_df` if not renaming columns nor violating
# `epi_df` invariants. Require that our key is a unique key in any case.
if (all(c("geo_value", "time_value") %in% new_keys)) {
maybe_as_of <- attr(new_data, "metadata")$as_of # NULL if wasn't epi_df
try(return(as_epi_df(new_key_tbl, other_keys = new_keys, as_of = maybe_as_of)),
silent = TRUE
)
}
if (anyDuplicated(new_key_tbl)) {
duplicate_key_tbl <- new_key_tbl %>% filter(.by = everything(), dplyr::n() > 1L)
error_part1 <- cli::format_error(
c(
"Specified/inferred key columns had repeated combinations in the forged/baked test data.",
"i" = "Key columns: {format_varnames(new_keys)}",
"Duplicated keys:"
)
)
error_part2 <- capture.output(print(duplicate_key_tbl))
rlang::abort(
paste(collapse = "\n", c(error_part1, error_part2)),
class = "epipredict__grab_forged_keys__nonunique_key"
)
} else {
return(new_key_tbl)
}
extras
}
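A minimal sketch of the duplicate-key check used in grab_forged_keys() above, with invented data and an invented error message (requires dplyr >= 1.1.0 for `.by`); the real code additionally formats the offending rows with cli::format_error() and capture.output() before aborting:

library(dplyr)

new_key_tbl <- tibble(
  geo_value  = c("ca", "ca", "ny"),
  time_value = as.Date(c("2024-10-01", "2024-10-01", "2024-10-01"))
)

anyDuplicated(new_key_tbl) # 2: row 2 repeats row 1's key combination

# Rows whose full key combination appears more than once:
new_key_tbl %>% filter(.by = everything(), n() > 1L)

# Signal a classed error so callers (and tests) can match on the class:
if (anyDuplicated(new_key_tbl)) {
  rlang::abort("Repeated epikeytime combinations in the forged/baked test data.",
    class = "epipredict__grab_forged_keys__nonunique_key"
  )
}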

get_parsnip_mode <- function(trainer) {
20 changes: 9 additions & 11 deletions tests/testthat/test-population_scaling.R
@@ -398,7 +398,9 @@ test_that("test joining by default columns with less common keys/classes", {
update_role("age_group", new_role = "key") %>%
update_role("time_value", new_role = "time_value") %>%
step_population_scaling(y, df = pop1b2, df_pop_col = "population", role = "outcome") %>%
{.},
{
.
},
model_spec,
frosting() %>%
layer_predict() %>%
@@ -432,14 +434,16 @@
class = "epipredict__layer_population_scaling__default_by_missing_suggested_keys"
)

# Same thing but with time series in tibble, but no role hints -> different inference&messaging:
# Same thing but with time series in tibble, but no role hints -> different behavior:
dat1b3 <- dat1b2
pop1b3 <- pop1b2
ewf1b3 <- epi_workflow(
# Can't use epi_recipe or step_epi_ahead; adjust.
recipe(dat1b3) %>%
step_population_scaling(y, df = pop1b3, df_pop_col = "population", role = "outcome") %>%
{.},
{
.
},
model_spec,
frosting() %>%
layer_predict() %>%
@@ -453,15 +457,10 @@
# geo 1 scaling used for both:
mutate(y_scaled = c(3e-6, 7 * 11 / 5e6))
)
expect_equal(
expect_error(
predict(fit(ewf1b3, dat1b3), dat1b3) %>%
pivot_quantiles_wider(.pred),
dat1b3 %>%
select(!"y") %>%
as_tibble() %>%
# geo 1 scaling used for both:
mutate(`0.5` = c(2 * 5, 2 * 5)) %>%
select(geo_value, age_group, time_value, `0.5`)
class = "epipredict__grab_forged_keys__nonunique_key"
)

# With geo x age_group breakdown on both:
@@ -571,7 +570,6 @@
# TODO non-`epi_df` scaling?

# TODO multikey scaling?

})
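The reworked expectation in this test now checks for a classed error (class = "epipredict__grab_forged_keys__nonunique_key") rather than comparing predicted values. A minimal, self-contained illustration of that testthat pattern (the fail_on_dupes() helper is made up):

library(testthat)

# A made-up function that signals the same classed condition:
fail_on_dupes <- function() {
  rlang::abort("Repeated key combinations.",
    class = "epipredict__grab_forged_keys__nonunique_key"
  )
}

expect_error(fail_on_dupes(), class = "epipredict__grab_forged_keys__nonunique_key")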

