Skip to content

Commit

Permalink
nest min grid by postprocessor parameters
Browse files Browse the repository at this point in the history
an attempt to enable the submodel trick with postprocessors. doesn't quite do the trick--see the newly added but skipped test.
  • Loading branch information
simonpcouch committed Nov 22, 2024
1 parent bfdd585 commit 4c763ed
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 38 deletions.
37 changes: 18 additions & 19 deletions R/grid_helpers.R
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,9 @@ compute_grid_info <- function(workflow, grid) {
syms_post <- rlang::syms(parameters_postprocessor$id)

res <- min_grid(extract_spec_parsnip(workflow), grid)
if (any_parameters_postprocessor) {
res <- nest_min_grid(res, parameters_postprocessor$id)
}

# ----------------------------------------------------------------------------
# Create an order of execution to train the preprocessor (if any). This will
Expand Down Expand Up @@ -370,18 +373,6 @@ compute_grid_info <- function(workflow, grid) {
# Now make a similar iterator across models. Conditioning on each unique
# preprocessing candidate set, make an iterator for the model candidate sets
# (if any)
if (any_parameters_postprocessor) {
# Ensure that the submodel trick kicks in by temporarily nesting the
# postprocessor parameters while iterating in the model grid
# TODO: will this introduce issues when there are matching postprocessor
# values across models?
# ... i think we actually want to (temporarily?) situate these as submodels
res <- tidyr::nest(
res,
.data_post = all_of(parameters_postprocessor$id)
)
}

res <-
res %>%
dplyr::group_nest(.iter_preprocessor, keep = TRUE) %>%
Expand Down Expand Up @@ -415,29 +406,36 @@ compute_grid_info <- function(workflow, grid) {
res %>%
dplyr::group_nest(.iter_config, keep = TRUE) %>%
dplyr::mutate(
data = purrr::map(data, make_iter_postprocessor)
data = purrr::map(data, make_iter_postprocessor, parameters_postprocessor$id)
) %>%
tidyr::unnest(cols = data) %>%
dplyr::relocate(dplyr::starts_with(".iter"), dplyr::starts_with(".msg")) %>%
tidyr::unnest(.data_post)
dplyr::relocate(dplyr::starts_with(".iter"), dplyr::starts_with(".msg"))

res
}

make_iter_config <- function(dat) {
# Compute labels for the models *within* each preprocessing loop.
num_submodels <- purrr::map_int(dat$.submodels, ~ length(unlist(.x)))
num_submodels <- purrr::map_int(
dat$.submodels,
function(.x) {if (length(.x) == 0) 0 else length(.x[[1]])}
)
num_models <- sum(num_submodels + 1) # +1 for the model being trained
.mod_label <- recipes::names0(num_models, "Model")
.iter_config <- paste(dat$.lab_pre[1], .mod_label, sep = "_")
.iter_config <- vctrs::vec_chop(.iter_config, sizes = num_submodels + 1)
tibble::tibble(.iter_config = .iter_config)
}

make_iter_postprocessor <- function(data) {
make_iter_postprocessor <- function(data, post_params) {
nested_by_post <- "post" %in% names(data)
if (nested_by_post) {
data <- data %>% unnest(post)
}

data %>%
mutate(
.iter_postprocessor = seq_len(nrow(data)),
.iter_postprocessor = seq_len(nrow(.)),
.msg_postprocessor = new_msgs_postprocessor(
i = .iter_postprocessor,
n = max(.iter_postprocessor),
Expand All @@ -449,7 +447,8 @@ make_iter_postprocessor <- function(data) {
make_iter_config_post
)
) %>%
select(-.iter_config)
select(-.iter_config) %>%
nest(post = c(any_of(post_params), ".iter_postprocessor", ".msg_postprocessor", ".iter_config_post"))
}

make_iter_config_post <- function(iter_config, iter_postprocessor) {
Expand Down
39 changes: 37 additions & 2 deletions R/min_grid.R
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,41 @@ min_grid.poisson_reg <- fit_max_value
# .submodels to effectively enable the submodel trick.
# See: https://gist.github.com/simonpcouch/28d984cdcc3fc6d22ff776ed8740004e
nest_min_grid <- function(min_grid, post_params) {
# TODO
min_grid
if (!has_submodels(min_grid)) {
return(min_grid)
}
non_post_param_cols <- names(min_grid)[
!names(min_grid) %in% c(post_params, ".submodels")
]
submodel_param_name <- names(min_grid$.submodels[[1]])

res <-
min_grid %>%
# unnest from `list(list())` to `list()`
unnest(.submodels) %>%
# unnest from `list()` to vector
unnest(.submodels)

tibble(
vctrs::vec_unique(res[non_post_param_cols]),
post = list(vctrs::vec_unique(res[post_params])),
.submodels = list(
res[c(post_params, ".submodels")] %>%
rename(!!submodel_param_name := .submodels) %>%
group_by(across(all_of(submodel_param_name))) %>%
group_split()
)
)
}

has_submodels <- function(min_grid) {
if (!".submodels" %in% names(min_grid)) {
return(FALSE)
}

if (length(min_grid$.submodels[[1]]) == 0) {
return(FALSE)
}

TRUE
}
65 changes: 48 additions & 17 deletions tests/testthat/test-grid_helpers.R
Original file line number Diff line number Diff line change
Expand Up @@ -367,16 +367,15 @@ test_that("compute_grid_info - model and postprocessor (no submodels)", {
expect_equal(res$.msg_model, paste0("preprocessor 1/1, model ", 1:5, "/5"))
expect_named(
res,
c(".iter_config", ".iter_preprocessor", ".iter_model", ".iter_postprocessor",
".iter_config_post", ".msg_preprocessor", ".msg_model", ".msg_postprocessor",
"tree_depth", "threshold", ".submodels"),
c(".iter_config", ".iter_preprocessor", ".iter_model",
".msg_preprocessor", ".msg_model",
"tree_depth", "post", ".submodels"),
ignore.order = TRUE
)
expect_equal(nrow(res), 5)
})

test_that("compute_grid_info - model and postprocessor (with submodels)", {
skip("not quite ready yet")
# when a workflow has a model with submodels and a postprocessor, we want
# to hook into the submodel trick in the same way we would have before
library(workflows)
Expand All @@ -394,22 +393,54 @@ test_that("compute_grid_info - model and postprocessor (with submodels)", {
grid <- grid_regular(extract_parameter_set_dials(wflow), levels = 3)
res <- compute_grid_info(wflow, grid)

expect_equal(res$.iter_preprocessor, rep(1, 3))
expect_equal(res$.msg_preprocessor, rep("preprocessor 1/1", 3))
expect_equal(res$trees, rep(max(grid$trees), 3))
expect_equal(res$.iter_model, rep(1, 3))
expect_equal(res$.iter_config, rep(list(paste0("Preprocessor1_Model", 1:3)), 3))
expect_equal(res$.msg_model, rep("preprocessor 1/1, model 1/1", 3))
# TODO: the second and third have the max trees in them...
# expect_equal(res$.submodels, list(list(trees = grid$trees[-which.max(grid$trees)])))
expect_equal(nrow(res), 1)
expect_equal(res$.iter_preprocessor, 1)
expect_equal(res$.msg_preprocessor, "preprocessor 1/1")
expect_equal(res$trees, max(grid$trees))
expect_equal(res$.iter_model, 1)
expect_equal(res$.iter_config, list(paste0("Preprocessor1_Model", 1:3)))
expect_equal(res$.msg_model, "preprocessor 1/1, model 1/1")

res_post <- res$post[[1]]
expect_equal(res_post$threshold, unique(grid$threshold))
expect_equal(res_post$.iter_postprocessor, 1:3)
expect_equal(
res_post$.msg_postprocessor,
paste0("preprocessor 1/1, model 1/1, postprocessor ", 1:3, "/3")
)
expect_equal(
res_post$.iter_config_post,
list(
paste0("Preprocessor1_Model", 1:3, "_Postprocessor1"),
paste0("Preprocessor1_Model", 1:3, "_Postprocessor2"),
paste0("Preprocessor1_Model", 1:3, "_Postprocessor3")
)
)
expect_named(
res,
c(".iter_config", ".iter_preprocessor", ".iter_model", ".iter_postprocessor",
".iter_config_post", ".msg_preprocessor", ".msg_model", ".msg_postprocessor",
"trees", ".submodels", "threshold"),
c(".iter_config", ".iter_preprocessor", ".iter_model",
".msg_preprocessor", ".msg_model", "trees", ".submodels", "post"),
ignore.order = TRUE
)
expect_equal(nrow(res), 3)
})

tune_grid(wflow, bootstraps(mtcars), grid = grid)
test_that("compute_grid_info - model and postprocessor (with submodels but irregular)", {
library(workflows)
library(parsnip)
library(dials)

spec <- boost_tree(mode = "regression", trees = tune())
tlr <- tailor() %>% adjust_probability_threshold(threshold = tune())

wflow <- workflow()
wflow <- add_model(wflow, spec)
wflow <- add_formula(wflow, mpg ~ .)
wflow <- add_tailor(wflow, tlr)

grid <- grid_regular(extract_parameter_set_dials(wflow), levels = 3)
grid <- grid[c(1:2, 5:nrow(grid)), ]
res <- compute_grid_info(wflow, grid)

skip("does not work--removing some model fits shouldn't increase the number
of rows in the grid")
})

0 comments on commit 4c763ed

Please sign in to comment.