Skip to content

Commit

Permalink
...
Browse files Browse the repository at this point in the history
  • Loading branch information
sebffischer committed Nov 27, 2024
1 parent 05a985c commit 86e64b9
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
2 changes: 1 addition & 1 deletion R/LearnerTorch.R
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,7 @@ LearnerTorch = R6Class("LearnerTorch",
if (self$state$param_vals$patience == 0) {
named_list()
} else {
list(epochs = self$model$epochs - self$state$param_vals$patience)
list(epochs = self$model$epochs - self$state$param_vals$patience * self$state$param_vals$eval_freq)
}
},
.extract_internal_valid_scores = function() {
Expand Down
8 changes: 4 additions & 4 deletions tests/testthat/test_LearnerTorch.R
Original file line number Diff line number Diff line change
Expand Up @@ -529,14 +529,14 @@ test_that("early stopping works", {

learner$train(task)
# the first evaluation can do no comparison, i.e. the second eval with no improvement is the third epoch
expect_equal(learner$internal_tuned_values, list(epochs = 9))
expect_equal(learner$internal_tuned_values, list(epochs = 3))

# in this scenario early stopping should definitely not trigger yet
learner$param_set$set_values(
min_delta = 0, patience = 5, opt.lr = 0.01, eval_freq = 1
)
learner$train(task)
expect_equal(learner$internal_tuned_values, list(epochs = 10))
expect_equal(learner$internal_tuned_values, list(epochs = 1))
})

test_that("validation works", {
Expand Down Expand Up @@ -579,9 +579,9 @@ test_that("internal tuning", {
term_evals = 2
)
expect_equal(
ti$archive$data$internal_tuned_values, replicate(list(list(epochs = 9L)), n = 2L)
ti$archive$data$internal_tuned_values, replicate(list(list(epochs = 3L)), n = 2L)
)
expect_equal(ti$result_learner_param_vals$epochs, 9L)
expect_equal(ti$result_learner_param_vals$epochs, 3L)
})


Expand Down

0 comments on commit 86e64b9

Please sign in to comment.