Skip to content

Commit

Permalink
Merge pull request #271 from cmu-delphi/resid-hotfix
Browse files Browse the repository at this point in the history
simplify grab_residuals
  • Loading branch information
dajmcdon authored Dec 20, 2023
2 parents 378577a + 3038433 commit b5121c3
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 6 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: epipredict
Title: Basic epidemiology forecasting methods
Version: 0.0.6
Version: 0.0.7
Authors@R: c(
person("Daniel", "McDonald", , "[email protected]", role = c("aut", "cre")),
person("Ryan", "Tibshirani", , "[email protected]", role = "aut"),
Expand Down
4 changes: 4 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# epipredict (development)

# epipredict 0.0.7

* simplify `layer_residual_quantiles()` to avoid timesuck in `utils::methods()`

# epipredict 0.0.6

* rename the `dist_quantiles()` to be more descriptive, breaking change)
Expand Down
8 changes: 3 additions & 5 deletions R/layer_residual_quantiles.R
Original file line number Diff line number Diff line change
Expand Up @@ -133,12 +133,10 @@ slather.layer_residual_quantiles <-

grab_residuals <- function(the_fit, components) {
if (the_fit$spec$mode != "regression") {
rlang::abort("For meaningful residuals, the predictor should be a regression model.")
cli::cli_abort("For meaningful residuals, the predictor should be a regression model.")
}
r_generic <- attr(utils::methods(class = class(the_fit$fit)[1]), "info")$generic
if ("residuals" %in% r_generic) { # Try to use the available method.
cl <- class(the_fit$fit)[1]
r <- residuals(the_fit$fit)
r <- stats::residuals(the_fit$fit)
if (!is.null(r)) { # Got something from the method
if (inherits(r, "data.frame")) {
if (".resid" %in% names(r)) { # success
return(r)
Expand Down
19 changes: 19 additions & 0 deletions tests/testthat/test-layer_residual_quantiles.R
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,22 @@ test_that("Returns expected number or rows and columns", {
expect_equal(nrow(unnested), 9L)
expect_equal(unique(unnested$quantile_levels), c(.0275, .8, .95))
})


test_that("Errors when used with a classifier", {
tib <- tibble(
y = factor(rep(c("a", "b"), length.out = 100)),
x1 = rnorm(100),
x2 = rnorm(100),
time_value = 1:100,
geo_value = "ak"
) %>% as_epi_df()

r <- epi_recipe(y ~ x1 + x2, data = tib)
wf <- epi_workflow(r, parsnip::logistic_reg()) %>% fit(tib)
f <- frosting() %>%
layer_predict() %>%
layer_residual_quantiles()
wf <- wf %>% add_frosting(f)
expect_error(predict(wf, tib))
})

0 comments on commit b5121c3

Please sign in to comment.