Skip to content

Commit

Permalink
feat: add check_enough_data
Browse files Browse the repository at this point in the history
  • Loading branch information
dshemetov committed Jan 19, 2024
1 parent cf65b1b commit ad74faa
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 31 deletions.
52 changes: 31 additions & 21 deletions R/check_enough_data.R → R/check_enough_train_data.R
Original file line number Diff line number Diff line change
Expand Up @@ -67,16 +67,18 @@ check_enough_train_data <-
...,
n,
epi_keys = NULL,
drop_na = TRUE,
role = NA,
trained = FALSE,
columns = NULL,
skip = FALSE,
skip = TRUE,
id = rand_id("enough_train_data")) {
add_check(
recipe,
check_enough_train_data_new(
n = n,
epi_keys = epi_keys,
drop_na = drop_na,
terms = rlang::enquos(...),
role = role,
trained = trained,
Expand All @@ -88,12 +90,13 @@ check_enough_train_data <-
}

check_enough_train_data_new <-
function(n, epi_keys, terms, role, trained, columns, skip, id) {
function(n, epi_keys, drop_na, terms, role, trained, columns, skip, id) {
check(
subclass = "enough_train_data",
prefix = "check_",
n = n,
epi_keys = epi_keys,
drop_na = drop_na,
terms = terms,
role = role,
trained = trained,
Expand All @@ -107,9 +110,33 @@ check_enough_train_data_new <-
prep.check_enough_train_data <- function(x, training, info = NULL, ...) {
col_names <- recipes_eval_select(x$terms, training, info)

cols_not_enough_data <- purrr::map(col_names, function(col) {
groups_below_thresh <- training %>%
dplyr::select(all_of(c(epi_keys(training), col))) %>%
{
if (x$drop_na) {
tidyr::drop_na(.)
} else {
.
}
} %>%
dplyr::count(dplyr::across(dplyr::all_of(x$epi_keys))) %>%
dplyr::filter(n < x$n)
if (nrow(groups_below_thresh) > 0) {
col
}
}) %>% purrr::keep(~ !is.null(.))

if (length(cols_not_enough_data) > 0) {
cli::cli_abort(
"The following columns don't have enough data to predict: {cols_not_enough_data}."
)
}

check_enough_train_data_new(
n = x$n,
epi_keys = x$epi_keys,
drop_na = x$drop_na,
terms = x$terms,
role = x$role,
trained = TRUE,
Expand All @@ -121,25 +148,6 @@ prep.check_enough_train_data <- function(x, training, info = NULL, ...) {

#' @export
bake.check_enough_train_data <- function(object, new_data, ...) {
col_names <- object$columns
check_new_data(col_names, object, new_data)

cols_not_enough_data <- purrr::map(col_names, function(col) {
groups_below_thresh <- new_data %>%
dplyr::select(all_of(c(epi_keys(new_data), col))) %>%
tidyr::drop_na() %>%
dplyr::count(dplyr::across(dplyr::all_of(object$epi_keys))) %>%
dplyr::filter(n < object$n)
if (nrow(groups_below_thresh) > 0) {
col
}
}) %>% purrr::keep(~ !is.null(.))

if (length(cols_not_enough_data) > 0) {
cli::cli_abort(
"The following columns don't have enough data to predict: {cols_not_enough_data}."
)
}
new_data
}

Expand All @@ -160,5 +168,7 @@ tidy.check_enough_train_data <- function(x, ...) {
}
res$id <- x$id
res$n <- x$n
res$epi_keys <- x$epi_keys
res$drop_na <- x$drop_na
res
}
5 changes: 3 additions & 2 deletions man/check_enough_train_data.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,14 @@ test_that("check_enough_train_data", {
# Setup toy data
n <- 10
toy_epi_df <- tibble::tibble(
time_value = rep(seq(as.Date("2020-01-01"),
by = 1,
length.out = n
), times = 2),
time_value = rep(
seq(
as.Date("2020-01-01"),
by = 1,
length.out = n
),
times = 2
),
geo_value = rep(c("ca", "hi"), each = n),
x = c(1:n, c(1:(n - 2), NA, NA)),
y = 1:(2 * n)
Expand Down Expand Up @@ -57,12 +61,21 @@ test_that("check_enough_train_data", {
recipes::prep(toy_epi_df) %>%
recipes::bake(new_data = NULL)
)
# Check that a column with NAs counts the NAs if drop_na = FALSE, without geo pooling
expect_no_error(
epi_recipe(y ~ x, data = toy_epi_df) %>%
check_enough_train_data(x, n = n - 1, epi_keys = "geo_value", drop_na = FALSE) %>%
recipes::prep(toy_epi_df) %>%
recipes::bake(new_data = NULL)
)

# Sanity check the output of a passing recipe
p <- epi_recipe(y ~ x, data = toy_epi_df) %>%
check_enough_train_data(x, y, n = 2 * n - 2) %>%
recipes::prep(toy_epi_df) %>%
recipes::bake(new_data = NULL)
expect_no_error(
p <- epi_recipe(y ~ x, data = toy_epi_df) %>%
check_enough_train_data(x, y, n = 2 * n - 2) %>%
recipes::prep(toy_epi_df) %>%
recipes::bake(new_data = NULL)
)

expect_equal(nrow(p), 2 * n)
expect_equal(ncol(p), 4L)
Expand All @@ -73,4 +86,49 @@ test_that("check_enough_train_data", {
rep(seq(as.Date("2020-01-01"), by = 1, length.out = n), times = 2)
)
expect_equal(p$geo_value, rep(c("ca", "hi"), each = n))

# Check that the train data has enough data, the test data does not, but
# the check passes anyway (because it should be applied to training data)
n_ <- n - 2
toy_test_data <- tibble::tibble(
time_value = rep(
seq(
as.Date("2020-01-01"),
by = 1,
length.out = n_
),
times = 2
),
geo_value = rep(c("ca", "hi"), each = n_),
x = c(1:n_, c(1:(n_ - 2), NA, NA)),
y = 1:(2 * n_)
) %>% epiprocess::as_epi_df()
expect_no_error(
epi_recipe(y ~ x, data = toy_epi_df) %>%
check_enough_train_data(y, n = n - 1, epi_keys = "geo_value") %>%
recipes::prep(toy_epi_df) %>%
recipes::bake(new_data = toy_test_data)
)
# Same thing, but skip = FALSE
n_ <- n - 2
toy_test_data <- tibble::tibble(
time_value = rep(
seq(
as.Date("2020-01-01"),
by = 1,
length.out = n_
),
times = 2
),
geo_value = rep(c("ca", "hi"), each = n_),
x = c(1:n_, c(1:(n_ - 2), NA, NA)),
y = 1:(2 * n_)
) %>% epiprocess::as_epi_df()
expect_no_error(
epi_recipe(y ~ x, data = toy_epi_df) %>%
check_enough_train_data(y, n = n - 1, epi_keys = "geo_value", skip = FALSE) %>%
recipes::prep(toy_epi_df) %>%
recipes::bake(new_data = toy_test_data)
)
browser()
})

0 comments on commit ad74faa

Please sign in to comment.