Skip to content

Commit

Permalink
Merge pull request #422 from brookslogan/lcb/key_colnames-downstream
Browse files Browse the repository at this point in the history
Fix population scaling with `other_keys` + supporting fixes/changes
  • Loading branch information
dajmcdon authored Nov 11, 2024
2 parents b73105f + 3685e67 commit ea34700
Show file tree
Hide file tree
Showing 14 changed files with 537 additions and 65 deletions.
2 changes: 2 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,7 @@ importFrom(dplyr,filter)
importFrom(dplyr,full_join)
importFrom(dplyr,group_by)
importFrom(dplyr,group_by_at)
importFrom(dplyr,inner_join)
importFrom(dplyr,join_by)
importFrom(dplyr,left_join)
importFrom(dplyr,mutate)
Expand Down Expand Up @@ -273,6 +274,7 @@ importFrom(hardhat,extract_recipe)
importFrom(hardhat,refresh_blueprint)
importFrom(hardhat,run_mold)
importFrom(magrittr,"%>%")
importFrom(magrittr,extract2)
importFrom(recipes,bake)
importFrom(recipes,detect_step)
importFrom(recipes,prep)
Expand Down
3 changes: 3 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ Pre-1.0.0 numbering scheme: 0.x will indicate releases, while 0.0.x will indicat
## Improvements

- Add `step_adjust_latency`, which give several methods to adjust the forecast if the `forecast_date` is after the last day of data.
- Fix `layer_population_scaling` default `by` with `other_keys`.
- Make key column inference more consistent within the package and with current `epiprocess`.
- Fix `quantile_reg()` producing error when asked to output just median-level predictions.
- (temporary) ahead negative is allowed for `step_epi_ahead` until we have `step_epi_shift`

## Bug fixes
Expand Down
5 changes: 2 additions & 3 deletions R/autoplot.R
Original file line number Diff line number Diff line change
Expand Up @@ -127,11 +127,10 @@ autoplot.epi_workflow <- function(
if (!is.null(shift)) {
edf <- mutate(edf, time_value = time_value + shift)
}
extra_keys <- setdiff(key_colnames(object), c("geo_value", "time_value"))
if (length(extra_keys) == 0L) extra_keys <- NULL
other_keys <- setdiff(key_colnames(object), c("geo_value", "time_value"))
edf <- as_epi_df(edf,
as_of = object$fit$meta$as_of,
other_keys = extra_keys %||% character()
other_keys = other_keys
)
if (is.null(predictions)) {
return(autoplot(
Expand Down
2 changes: 2 additions & 0 deletions R/epipredict-package.R
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
#' @importFrom cli cli_abort cli_warn
#' @importFrom dplyr arrange across all_of any_of bind_cols bind_rows group_by
#' @importFrom dplyr full_join relocate summarise everything
#' @importFrom dplyr inner_join
#' @importFrom dplyr summarize filter mutate select left_join rename ungroup
#' @importFrom magrittr extract2
#' @importFrom rlang := !! %||% as_function global_env set_names !!! caller_arg
#' @importFrom rlang is_logical is_true inject enquo enquos expr sym arg_match
#' @importFrom stats poly predict lm residuals quantile
Expand Down
21 changes: 13 additions & 8 deletions R/key_colnames.R
Original file line number Diff line number Diff line change
@@ -1,20 +1,25 @@
#' @export
key_colnames.recipe <- function(x, ...) {
key_colnames.recipe <- function(x, ..., exclude = character()) {
geo_key <- x$var_info$variable[x$var_info$role %in% "geo_value"]
time_key <- x$var_info$variable[x$var_info$role %in% "time_value"]
keys <- x$var_info$variable[x$var_info$role %in% "key"]
c(geo_key, keys, time_key) %||% character(0L)
full_key <- c(geo_key, keys, time_key) %||% character(0L)
full_key[!full_key %in% exclude]
}

#' @export
key_colnames.epi_workflow <- function(x, ...) {
key_colnames.epi_workflow <- function(x, ..., exclude = character()) {
# safer to look at the mold than the preprocessor
mold <- hardhat::extract_mold(x)
molded_names <- names(mold$extras$roles)
geo_key <- names(mold$extras$roles[molded_names %in% "geo_value"]$geo_value)
time_key <- names(mold$extras$roles[molded_names %in% "time_value"]$time_value)
keys <- names(mold$extras$roles[molded_names %in% "key"]$key)
c(geo_key, keys, time_key) %||% character(0L)
molded_roles <- mold$extras$roles
extras <- bind_cols(molded_roles$geo_value, molded_roles$key, molded_roles$time_value)
full_key <- names(extras)
if (length(full_key) == 0L) {
# No epikeytime role assignment; infer from all columns:
potential_keys <- c("geo_value", "time_value")
full_key <- potential_keys[potential_keys %in% names(bind_cols(molded_roles))]
}
full_key[!full_key %in% exclude]
}

kill_time_value <- function(v) {
Expand Down
41 changes: 34 additions & 7 deletions R/layer_population_scaling.R
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,17 @@
#' inverting the existing scaling.
#' @param by A (possibly named) character vector of variables to join by.
#'
#' If `NULL`, the default, the function will perform a natural join, using all
#' variables in common across the `epi_df` produced by the `predict()` call
#' and the user-provided dataset.
#' If columns in that `epi_df` and `df` have the same name (and aren't
#' included in `by`), `.df` is added to the one from the user-provided data
#' to disambiguate.
#' If `NULL`, the default, the function will try to infer a reasonable set of
#' columns. First, it will try to join by all variables in the test data with
#' roles `"geo_value"`, `"key"`, or `"time_value"` that also appear in `df`;
#' these roles are automatically set if you are using an `epi_df`, or you can
#' use, e.g., `update_role`. If no such roles are set, it will try to perform a
#' natural join, using variables in common between the training/test data and
#' population data.
#'
#' If columns in the training/testing data and `df` have the same name (and
#' aren't included in `by`), a `.df` suffix is added to the one from the
#' user-provided data to disambiguate.
#'
#' To join by different variables on the `epi_df` and `df`, use a named vector.
#' For example, `by = c("geo_value" = "states")` will match `epi_df$geo_value`
Expand Down Expand Up @@ -135,6 +140,26 @@ slather.layer_population_scaling <-
)
rlang::check_dots_empty()

if (is.null(object$by)) {
# Assume `layer_predict` has calculated the prediction keys and other
# layers don't change the prediction key colnames:
prediction_key_colnames <- names(components$keys)
lhs_potential_keys <- prediction_key_colnames
rhs_potential_keys <- colnames(select(object$df, !object$df_pop_col))
object$by <- intersect(lhs_potential_keys, rhs_potential_keys)
suggested_min_keys <- kill_time_value(lhs_potential_keys)
if (!all(suggested_min_keys %in% object$by)) {
cli_warn(c(
"{setdiff(suggested_min_keys, object$by)} {?was an/were} epikey column{?s} in the predictions,
but {?wasn't/weren't} found in the population `df`.",
"i" = "Defaulting to join by {object$by}",
">" = "Double-check whether column names on the population `df` match those expected in your predictions",
">" = "Consider using population data with breakdowns by {suggested_min_keys}",
">" = "Manually specify `by =` to silence"
), class = "epipredict__layer_population_scaling__default_by_missing_suggested_keys")
}
}

object$by <- object$by %||% intersect(
epi_keys_only(components$predictions),
colnames(select(object$df, !object$df_pop_col))
Expand All @@ -152,10 +177,12 @@ slather.layer_population_scaling <-
suffix <- ifelse(object$create_new, object$suffix, "")
col_to_remove <- setdiff(colnames(object$df), colnames(components$predictions))

components$predictions <- left_join(
components$predictions <- inner_join(
components$predictions,
object$df,
by = object$by,
relationship = "many-to-one",
unmatched = c("error", "drop"),
suffix = c("", ".df")
) %>%
mutate(across(
Expand Down
2 changes: 1 addition & 1 deletion R/make_quantile_reg.R
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ make_quantile_reg <- function() {

# can't make a method because object is second
out <- switch(type,
rq = dist_quantiles(unname(as.list(x)), object$quantile_levels), # one quantile
rq = dist_quantiles(unname(as.list(x)), object$tau), # one quantile
rqs = {
x <- lapply(vctrs::vec_chop(x), function(x) sort(drop(x)))
dist_quantiles(x, list(object$tau))
Expand Down
90 changes: 75 additions & 15 deletions R/step_population_scaling.R
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,25 @@
#' inverting the existing scaling.
#' @param by A (possibly named) character vector of variables to join by.
#'
#' If `NULL`, the default, the function will perform a natural join, using all
#' variables in common across the `epi_df` produced by the `predict()` call
#' and the user-provided dataset.
#' If columns in that `epi_df` and `df` have the same name (and aren't
#' included in `by`), `.df` is added to the one from the user-provided data
#' to disambiguate.
#' If `NULL`, the default, the function will try to infer a reasonable set of
#' columns. First, it will try to join by all variables in the training/test
#' data with roles `"geo_value"`, `"key"`, or `"time_value"` that also appear in
#' `df`; these roles are automatically set if you are using an `epi_df`, or you
#' can use, e.g., `update_role`. If no such roles are set, it will try to
#' perform a natural join, using variables in common between the training/test
#' data and population data.
#'
#' If columns in the training/testing data and `df` have the same name (and
#' aren't included in `by`), a `.df` suffix is added to the one from the
#' user-provided data to disambiguate.
#'
#' To join by different variables on the `epi_df` and `df`, use a named vector.
#' For example, `by = c("geo_value" = "states")` will match `epi_df$geo_value`
#' to `df$states`. To join by multiple variables, use a vector with length > 1.
#' For example, `by = c("geo_value" = "states", "county" = "county")` will match
#' `epi_df$geo_value` to `df$states` and `epi_df$county` to `df$county`.
#'
#' See [dplyr::left_join()] for more details.
#' See [dplyr::inner_join()] for more details.
#' @param df_pop_col the name of the column in the data frame `df` that
#' contains the population data and will be used for scaling.
#' This should be one column.
Expand Down Expand Up @@ -89,13 +94,25 @@ step_population_scaling <-
suffix = "_scaled",
skip = FALSE,
id = rand_id("population_scaling")) {
arg_is_scalar(role, df_pop_col, rate_rescaling, create_new, suffix, id)
arg_is_lgl(create_new, skip)
arg_is_chr(df_pop_col, suffix, id)
if (rlang::dots_n(...) == 0L) {
cli_abort(c(
"`...` must not be empty.",
">" = "Please provide one or more tidyselect expressions in `...`
specifying the columns to which scaling should be applied.",
">" = "If you really want to list `step_population_scaling` in your
recipe but not have it do anything, you can use a tidyselection
that selects zero variables, such as `c()`."
))
}
arg_is_scalar(role, df_pop_col, rate_rescaling, create_new, suffix, skip, id)
arg_is_chr(role, df_pop_col, suffix, id)
hardhat::validate_column_names(df, df_pop_col)
arg_is_chr(by, allow_null = TRUE)
arg_is_numeric(rate_rescaling)
if (rate_rescaling <= 0) {
cli_abort("`rate_rescaling` must be a positive number.")
}
arg_is_lgl(create_new, skip)

recipes::add_step(
recipe,
Expand Down Expand Up @@ -138,6 +155,42 @@ step_population_scaling_new <-

#' @export
prep.step_population_scaling <- function(x, training, info = NULL, ...) {
if (is.null(x$by)) {
rhs_potential_keys <- setdiff(colnames(x$df), x$df_pop_col)
lhs_potential_keys <- info %>%
filter(role %in% c("geo_value", "key", "time_value")) %>%
extract2("variable") %>%
unique() # in case of weird var with multiple of above roles
if (length(lhs_potential_keys) == 0L) {
# We're working with a recipe and tibble, and *_role hasn't set up any of
# the above roles. Let's say any column could actually act as a key, and
# lean on `intersect` below to make this something reasonable.
lhs_potential_keys <- names(training)
}
suggested_min_keys <- info %>%
filter(role %in% c("geo_value", "key")) %>%
extract2("variable") %>%
unique()
# (0 suggested keys if we weren't given any epikeytime var info.)
x$by <- intersect(lhs_potential_keys, rhs_potential_keys)
if (length(x$by) == 0L) {
cli_stop(c(
"Couldn't guess a default for `by`",
">" = "Please rename columns in your population data to match those in your training data,
or manually specify `by =` in `step_population_scaling()`."
), class = "epipredict__step_population_scaling__default_by_no_intersection")
}
if (!all(suggested_min_keys %in% x$by)) {
cli_warn(c(
"{setdiff(suggested_min_keys, x$by)} {?was an/were} epikey column{?s} in the training data,
but {?wasn't/weren't} found in the population `df`.",
"i" = "Defaulting to join by {x$by}.",
">" = "Double-check whether column names on the population `df` match those for your training data.",
">" = "Consider using population data with breakdowns by {suggested_min_keys}.",
">" = "Manually specify `by =` to silence."
), class = "epipredict__step_population_scaling__default_by_missing_suggested_keys")
}
}
step_population_scaling_new(
terms = x$terms,
role = x$role,
Expand All @@ -156,10 +209,14 @@ prep.step_population_scaling <- function(x, training, info = NULL, ...) {

#' @export
bake.step_population_scaling <- function(object, new_data, ...) {
object$by <- object$by %||% intersect(
epi_keys_only(new_data),
colnames(select(object$df, !object$df_pop_col))
)
if (is.null(object$by)) {
cli::cli_abort(c(
"`by` was not set and no default was filled in",
">" = "If this was a fit recipe generated from an older version
of epipredict that you loaded in from a file,
please regenerate with the current version of epipredict."
))
}
joinby <- list(x = names(object$by) %||% object$by, y = object$by)
hardhat::validate_column_names(new_data, joinby$x)
hardhat::validate_column_names(object$df, joinby$y)
Expand All @@ -177,7 +234,10 @@ bake.step_population_scaling <- function(object, new_data, ...) {
suffix <- ifelse(object$create_new, object$suffix, "")
col_to_remove <- setdiff(colnames(object$df), colnames(new_data))

left_join(new_data, object$df, by = object$by, suffix = c("", ".df")) %>%
inner_join(new_data, object$df,
by = object$by, relationship = "many-to-one", unmatched = c("error", "drop"),
suffix = c("", ".df")
) %>%
mutate(
across(
all_of(object$columns),
Expand Down
82 changes: 64 additions & 18 deletions R/utils-misc.R
Original file line number Diff line number Diff line change
Expand Up @@ -31,30 +31,76 @@ check_pname <- function(res, preds, object, newname = NULL) {
res
}

# Copied from `epiprocess`:

#' "Format" a character vector of column/variable names for cli interpolation
#'
#' Designed to give good output if interpolated with cli. Main purpose is to add
#' backticks around variable names when necessary, and something other than an
#' empty string if length 0.
#'
#' @param x `chr`; e.g., `colnames` of some data frame
#' @param empty string; what should be output if `x` is of length 0?
#' @return `chr`
#' @keywords internal
format_varnames <- function(x, empty = "*none*") {
if (length(x) == 0L) {
empty
} else {
as.character(syms(x))
}
}

grab_forged_keys <- function(forged, workflow, new_data) {
forged_roles <- names(forged$extras$roles)
extras <- bind_cols(forged$extras$roles[forged_roles %in% c("geo_value", "time_value", "key")])
# 1. these are the keys in the test data after prep/bake
new_keys <- names(extras)
# 2. these are the keys in the training data
# 1. keys in the training data post-prep, based on roles:
old_keys <- key_colnames(workflow)
# 3. these are the keys in the test data as input
new_df_keys <- key_colnames(new_data, extra_keys = setdiff(new_keys, c("geo_value", "time_value")))
if (!(setequal(old_keys, new_df_keys) && setequal(new_keys, new_df_keys))) {
cli_warn(paste(
"Not all epi keys that were present in the training data are available",
"in `new_data`. Predictions will have only the available keys."
# 2. keys in the test data post-bake, based on roles & structure:
forged_roles <- forged$extras$roles
new_key_tbl <- bind_cols(forged_roles$geo_value, forged_roles$key, forged_roles$time_value)
new_keys <- names(new_key_tbl)
if (length(new_keys) == 0L) {
# No epikeytime role assignment; infer from all columns:
potential_new_keys <- c("geo_value", "time_value")
forged_tbl <- bind_cols(forged$extras$roles)
new_keys <- potential_new_keys[potential_new_keys %in% names(forged_tbl)]
new_key_tbl <- forged_tbl[new_keys]
}
# Softly validate:
if (!(setequal(old_keys, new_keys))) {
cli_warn(c(
"Inconsistent epikeytime identifier columns specified/inferred in training vs. in testing data.",
"i" = "training epikeytime columns, based on roles post-mold/prep: {format_varnames(old_keys)}",
"i" = "testing epikeytime columns, based on roles post-forge/bake: {format_varnames(new_keys)}",
"*" = "",
">" = 'Some mismatches can be addressed by using `epi_df`s instead of tibbles, or by using `update_role`
to assign pre-`prep` columns the "geo_value", "key", and "time_value" roles.'
))
}
if (is_epi_df(new_data)) {
meta <- attr(new_data, "metadata")
extras <- as_epi_df(extras, as_of = meta$as_of, other_keys = meta$other_keys %||% character())
} else if (all(c("geo_value", "time_value") %in% new_keys)) {
if (length(new_keys) > 2) other_keys <- new_keys[!new_keys %in% c("geo_value", "time_value")]
extras <- as_epi_df(extras, other_keys = other_keys %||% character())
# Convert `new_key_tbl` to `epi_df` if not renaming columns nor violating
# `epi_df` invariants. Require that our key is a unique key in any case.
if (all(c("geo_value", "time_value") %in% new_keys)) {
maybe_as_of <- attr(new_data, "metadata")$as_of # NULL if wasn't epi_df
try(return(as_epi_df(new_key_tbl, other_keys = new_keys, as_of = maybe_as_of)),
silent = TRUE
)
}
if (anyDuplicated(new_key_tbl)) {
duplicate_key_tbl <- new_key_tbl %>% filter(.by = everything(), dplyr::n() > 1L)
error_part1 <- cli::format_error(
c(
"Specified/inferred key columns had repeated combinations in the forged/baked test data.",
"i" = "Key columns: {format_varnames(new_keys)}",
"Duplicated keys:"
)
)
error_part2 <- capture.output(print(duplicate_key_tbl))
rlang::abort(
paste(collapse = "\n", c(error_part1, error_part2)),
class = "epipredict__grab_forged_keys__nonunique_key"
)
} else {
return(new_key_tbl)
}
extras
}

get_parsnip_mode <- function(trainer) {
Expand Down
Loading

0 comments on commit ea34700

Please sign in to comment.