Skip to content

Commit

Permalink
delta method and refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
sebffischer committed Jul 17, 2024
1 parent 4954825 commit 17a99a8
Show file tree
Hide file tree
Showing 9 changed files with 71 additions and 110 deletions.
20 changes: 17 additions & 3 deletions R/MeasureAbstractCi.R
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,16 @@
#' The resample result.
#' @section Inheriting:
#' To define a new CI method, inherit from the abstract base class and implement the private method:
#' `ci(tbl: data.table, rr: ResampleResult, param_vals: named `list()`) -> numeric(3)`
#' `ci: function(tbl: data.table, rr: ResampleResult, param_vals: named `list()`) -> numeric(3)`
#' Here, `tbl` contains the columns `loss`, `row_id` and `iteration`, which are the pointwise loss,
#' the identifier of the observation and the resampling iteration.
#' It should return a vector containing the `estimate`, `lower` and `upper` boundary in that order.
#' In case the confidence interval is not of the form `(estimate, estimate - z * se, estimate + z * se)`
#' it is also necessary to implement
#' `trafo: function(ci: numeric(3), measure: Measure) -> numeric(3)`
#' Which receives a confidence interval for a pointwise loss (e.g. squared-error) and transforms it according
#' to the transformation `measure$trafo` (e.g. sqrt to go from mse to rmse).
#'
#' @export
MeasureAbstractCi = R6Class("MeasureAbstractCi",
inherit = Measure,
Expand Down Expand Up @@ -97,13 +103,12 @@ MeasureAbstractCi = R6Class("MeasureAbstractCi",
stopf("CI for Measure '%s' requires one of: %s", self$measure$id, paste0(self$resamplings, sep = ", "))
}


param_vals = self$param_set$get_values()
tbl = rr$obs_loss(self$measure)
names(tbl)[names(tbl) == self$measure$id] = "loss"
ci = private$.ci(tbl, rr, param_vals)
if (!is.null(self$measure$trafo)) {
ci = self$measure$trafo(ci)
ci = private$.trafo(ci)
}
if (param_vals$within_range) {
ci = pmin(pmax(ci, self$range[1L]), self$range[2L])
Expand All @@ -112,6 +117,15 @@ MeasureAbstractCi = R6Class("MeasureAbstractCi",
}
),
private = list(
.trafo = function(ci) {
measure = self$measure
ci[[1]] = measure$trafo$fn(ci[[1]])
halfwidth = (ci[[3]] - ci[[1]])
multiplier = measure$trafo$deriv(ci[[1]])
est_t = measure$trafo$fn(ci[[1]])
ci_t = c(est_t, est_t - halfwidth * multiplier, est_t + halfwidth * multiplier)
set_names(ci_t, names(ci))
},
.score = function(prediction, ...) {
stopf("CI measures must be passed to $aggregate(), not $score()")
},
Expand Down
7 changes: 6 additions & 1 deletion man/mlr_measures_abstract_ci.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

57 changes: 34 additions & 23 deletions tests/testthat/helpers.R
Original file line number Diff line number Diff line change
@@ -1,27 +1,38 @@
expect_ci_measure = function(m, rr, symmetric = TRUE) {
m = m$clone(deep = TRUE)
get("expect_measure", envir = .GlobalEnv)(m)
testthat::expect_s3_class(m, "MeasureAbstractCi")
testthat::expect_error(rr$score(m), "$aggregate", fixed = TRUE)
ci = rr$aggregate(m)
checkmate::expect_numeric(ci[[m$id]])
checkmate::expect_numeric(ci[[paste0(m$id, ".lower")]])
checkmate::expect_numeric(ci[[paste0(m$id, ".upper")]])
testthat::expect_true(ci[[m$id]] > ci[[paste0(m$id, ".lower")]])
testthat::expect_true(ci[[m$id]] < ci[[paste0(m$id, ".upper")]])
if (symmetric) {
d1 = ci[[m$id]] - ci[[paste0(m$id, ".lower")]]
d2 = ci[[paste0(m$id, ".upper")]] - ci[[m$id]]
testthat::expect_equal(d1, d2)
}
expect_ci_measure = function(id, resampling, task = tsk("boston_housing"),
symmetric = TRUE, stratum = "chas", ...) {
check = function(m, rr) {
m = m$clone(deep = TRUE)
get("expect_measure", envir = .GlobalEnv)(m)
testthat::expect_s3_class(m, "MeasureAbstractCi")
testthat::expect_error(rr$score(m), "$aggregate", fixed = TRUE)
ci = rr$aggregate(m)
checkmate::expect_numeric(ci[[m$id]])
checkmate::expect_numeric(ci[[paste0(m$id, ".lower")]])
checkmate::expect_numeric(ci[[paste0(m$id, ".upper")]])
testthat::expect_true(ci[[m$id]] > ci[[paste0(m$id, ".lower")]])
testthat::expect_true(ci[[m$id]] < ci[[paste0(m$id, ".upper")]])
if (symmetric && ci[[2]] != m$range[[1L]] && ci[[3]] != m$range[2L]) {
d1 = ci[[m$id]] - ci[[paste0(m$id, ".lower")]]
d2 = ci[[paste0(m$id, ".upper")]] - ci[[m$id]]
testthat::expect_equal(d1, d2)
}

m$param_set$values$alpha = 0.05
ci1 = rr$aggregate(m)
m$param_set$values$alpha = 0.5
ci2 = rr$aggregate(m)

m$param_set$values$alpha = 0.05
ci1 = rr$aggregate(m)
m$param_set$values$alpha = 0.5
ci2 = rr$aggregate(m)
expect_equal(ci1[1L], ci2[1L])
expect_true(ci2[2L] >= ci1[2L])
expect_true(ci2[3L] <= ci1[3L])
}
rr = resample(task, lrn("regr.featureless"), resampling)
check(msr(id, measure = "regr.rmse", within_range = FALSE), rr)
check(msr(id, measure = "regr.mse", within_range = FALSE), rr)

expect_equal(ci1[1L], ci2[1L])
expect_true(ci2[2L] >= ci1[2L])
expect_true(ci2[3L] <= ci1[3L])
task$col_roles$stratum = "chas"
rr_strat = resample(task, lrn("regr.featureless"), resampling)
check(msr(id, measure = "regr.rmse", within_range = FALSE), rr)
check(msr(id, measure = "regr.mse", within_range = FALSE), rr)
}

13 changes: 1 addition & 12 deletions tests/testthat/test_MeasureCIConZ.R
Original file line number Diff line number Diff line change
@@ -1,15 +1,4 @@
test_that("basic", {
withr::local_seed(1)
mci = msr("ci.con_z", "regr.mae")
rr = resample(tsk("boston_housing"), lrn("regr.featureless"), rsmp("paired_subsampling", repeats_in = 5, repeats_out = 10))
expect_ci_measure(mci, rr)
})

test_that("stratification", {
withr::local_seed(1)
mci = msr("ci.con_z", "regr.mae")
task = tsk("boston_housing")
task$col_roles$stratum = "chas"
rr = resample(task, lrn("regr.featureless"), rsmp("paired_subsampling", repeats_in = 5, repeats_out = 10))
expect_ci_measure(mci, rr)
expect_ci_measure("ci.con_z", rsmp("paired_subsampling", repeats_in = 5, repeats_out = 10))
})
12 changes: 1 addition & 11 deletions tests/testthat/test_MeasureCIHoldout.R
Original file line number Diff line number Diff line change
@@ -1,13 +1,3 @@
test_that("simple", {
mci = msr("ci.holdout", "regr.mse")
rr = resample(tsk("boston_housing"), lrn("regr.featureless"), rsmp("holdout"))
expect_ci_measure(mci, rr)
})

test_that("stratification", {
mci = msr("ci.holdout", "regr.mse")
task = tsk("boston_housing")
task$col_roles$stratum = "chas"
rr = resample(task, lrn("regr.featureless"), rsmp("holdout"))
expect_ci_measure(mci, rr)
expect_ci_measure("ci.holdout", rsmp("holdout"))
})
30 changes: 4 additions & 26 deletions tests/testthat/test_MeasureCINaiveCV.R
Original file line number Diff line number Diff line change
@@ -1,28 +1,6 @@
test_that("basic", {
task = tsk("mtcars")
learner = lrn("regr.featureless")

mci = msr("ci.naive_cv", "regr.mse", variance = "all-pairs")
rr = resample(task, learner, rsmp("cv"))
expect_ci_measure(mci, rr)

mci = msr("ci.naive_cv", "regr.mse", variance = "within-fold")
rr = resample(task, learner, rsmp("cv"))
expect_ci_measure(mci, rr)

mci = msr("ci.naive_cv", "regr.mse")
rr = resample(task, learner, rsmp("loo"))
expect_ci_measure(mci, rr)

mci = msr("ci.naive_cv", "regr.mse", variance = "within-fold")
rr = resample(task, learner, rsmp("loo"))
expect_error(rr$aggregate(mci), "LOO")
})

test_that("stratification", {
mci = msr("ci.naive_cv", "regr.mse")
task = tsk("boston_housing")
task$col_roles$stratum = "chas"
rr = resample(task, lrn("regr.featureless"), rsmp("cv"))
expect_ci_measure(mci, rr)
expect_ci_measure("ci.naive_cv", rsmp("cv"), variance = "all-pairs")
expect_ci_measure("ci.naive_cv", rsmp("cv"), variance = "within-fold")
expect_ci_measure("ci.naive_cv", rsmp("loo"), variance = "all-pairs")
expect_ci_measure("ci.naive_cv", rsmp("loo"), variance = "within-fold")
})
11 changes: 1 addition & 10 deletions tests/testthat/test_MeasureCi.R
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,5 @@ test_that("basic", {
ci1 = rr$aggregate(msr("ci", "classif.acc"))
ci2 = rr$aggregate(msr("ci.holdout", "classif.acc"))
expect_equal(ci1, ci2)
mci = msr("ci", "classif.acc")
expect_ci_measure(mci, rr)
})

test_that("obs_loss with trafo", {
withr::local_seed(1)
rr = resample(tsk("boston_housing"), lrn("regr.featureless"), rsmp("cv"))
ci = rr$aggregate(msr("ci.naive_cv", "regr.rmse"))
expect_ci_measure(msr("ci.naive_cv", "regr.rmse"), rr, symmetric = FALSE)
expect_ci_measure(msr("ci", "regr.rmse"), rr, symmetric = FALSE)
expect_ci_measure("ci", rsmp("holdout"))
})
12 changes: 1 addition & 11 deletions tests/testthat/test_MeasureCiCorT.R
Original file line number Diff line number Diff line change
@@ -1,13 +1,3 @@
test_that("simple", {
mci = msr("ci.cor_t", "regr.mse")
rr = resample(tsk("boston_housing"), lrn("regr.featureless"), rsmp("subsampling", repeats = 10))
expect_ci_measure(mci, rr)
})

test_that("simple", {
mci = msr("ci.cor_t", "regr.mse")
task = tsk("boston_housing")
task$col_roles$stratum = "chas"
rr = resample(task, lrn("regr.featureless"), rsmp("subsampling", repeats = 10))
expect_ci_measure(mci, rr)
expect_ci_measure("ci.cor_t", rsmp("subsampling", repeats = 10L))
})
19 changes: 6 additions & 13 deletions tests/testthat/test_MeasureCiNestedCV.R
Original file line number Diff line number Diff line change
@@ -1,15 +1,8 @@
test_that("basic", {
withr::local_seed(1)
mci = msr("ci.ncv", "classif.acc")
rr = resample(tsk("iris"), lrn("classif.featureless"), rsmp("nested_cv", repeats = 20, folds = 5))
expect_ci_measure(mci, rr)
})

test_that("stratification", {
withr::local_seed(1)
mci = msr("ci.ncv", "classif.acc")
task = tsk("iris")
task$col_roles$stratum = "Species"
rr = resample(task, lrn("classif.featureless"), rsmp("nested_cv", repeats = 20, folds = 5))
expect_ci_measure(mci, rr)
task = tsk("mtcars")$cbind(data.frame(chas = rep(c("a", "b"), times = 16)))
expect_ci_measure(
"ci.ncv",
rsmp("nested_cv", folds = 3L, repeats = 5L),
task = task
)
})

0 comments on commit 17a99a8

Please sign in to comment.