Skip to content

Commit

Permalink
feat: regr.featureless quantile prediction (#1125)
Browse files Browse the repository at this point in the history
* feat: regr.featuresless quantile prediction

* docs: update

* fix: pass quantiles to fallback

* feat: set default fallback with set_fallback

* chore: news

* tests: set_fallback quantiles

* refactor: switch to default_fallback

* fix: arguments

* chore: rename

---------

Co-authored-by: Bernd Bischl <[email protected]>
  • Loading branch information
be-marc and berndbischl authored Aug 31, 2024
1 parent 1f51d3c commit 5ffcfee
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 3 deletions.
25 changes: 23 additions & 2 deletions R/LearnerRegrFeatureless.R
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ LearnerRegrFeatureless = R6Class("LearnerRegrFeatureless", inherit = LearnerRegr
super$initialize(
id = "regr.featureless",
feature_types = unname(mlr_reflections$task_feature_types),
predict_types = c("response", "se"),
predict_types = c("response", "se", "quantiles"),
param_set = ps,
properties = c("featureless", "missings", "importance", "selected_features"),
packages = "stats",
Expand Down Expand Up @@ -61,18 +61,39 @@ LearnerRegrFeatureless = R6Class("LearnerRegrFeatureless", inherit = LearnerRegr
.train = function(task) {
pv = self$param_set$get_values(tags = "train")
x = task$data(cols = task$target_names)[[1L]]

quantiles = if (self$predict_type == "quantiles") {
if (is.null(private$.quantiles) || is.null(private$.quantile_response)) {
stop("Quantiles '$quantiles' and response quantile '$quantile_response' must be set")
}
quantile(x, probs = private$.quantiles)
}

if (isFALSE(pv$robust)) {
location = mean(x)
dispersion = sd(x)
} else {
location = stats::median(x)
dispersion = stats::mad(x, center = location)
}
set_class(list(location = location, dispersion = dispersion, features = task$feature_names), "regr.featureless_model")

set_class(list(
location = location,
dispersion = dispersion,
quantiles = quantiles,
features = task$feature_names), "regr.featureless_model")
},

.predict = function(task) {
n = task$nrow

if (self$predict_type == "quantiles") {
quantiles = matrix(rep(self$model$quantiles, n), nrow = n, byrow = TRUE)
attr(quantiles, "probs") = private$.quantiles
attr(quantiles, "response") = private$.quantile_response
return(list(quantiles = quantiles))
}

response = rep(self$model$location, n)
se = if (self$predict_type == "se") rep(self$model$dispersion, n) else NULL
list(response = response, se = se)
Expand Down
2 changes: 1 addition & 1 deletion man/mlr_learners_regr.featureless.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,15 @@ test_that("fallback = default_fallback() works", {

expect_class(fallback, "LearnerRegrFeatureless")
expect_equal(fallback$predict_type, "se")

learner = lrn("regr.debug",
predict_type = "quantiles",
quantiles = c(0.1, 0.9),
quantile_response = 0.1)
fallback = default_fallback(learner)

expect_class(fallback, "LearnerRegrFeatureless")
expect_equal(fallback$predict_type, "quantiles")
expect_equal(fallback$quantiles, c(0.1, 0.9))
expect_equal(fallback$quantile_response, 0.1)
})
24 changes: 24 additions & 0 deletions tests/testthat/test_mlr_learners_regr_featureless.R
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,27 @@ test_that("regr.featureless works on featureless task", {
expect_resample_result(rr)
expect_number(rr$aggregate())
})

test_that("regr.featureless quantile prediction works", {
task = tsk("mtcars")

learner = lrn("regr.featureless",
predict_type = "quantiles",
quantiles = c(0.1, 0.5, 0.9),
quantile_response = 0.5)

learner$train(task)
expect_numeric(learner$model$quantiles, len = 3L)

pred = learner$predict(task)
expect_prediction(pred)
expect_subset("quantiles", pred$predict_types)
expect_matrix(pred$quantiles, ncols = 3L, nrows = task$nrow, any.missing = FALSE)
expect_names(colnames(pred$quantiles), identical.to = c("q0.1", "q0.5", "q0.9"))
expect_equal(pred$response, pred$quantiles[, 2L])

learner = lrn("regr.featureless",
predict_type = "quantiles")

expect_error(learner$train(task), "Quantiles")
})

0 comments on commit 5ffcfee

Please sign in to comment.