diff --git a/R/CallbackSetUnfreeze.R b/R/CallbackSetUnfreeze.R index 87e92682..e095f944 100644 --- a/R/CallbackSetUnfreeze.R +++ b/R/CallbackSetUnfreeze.R @@ -3,10 +3,7 @@ #' @name mlr_callback_set.unfreeze #' #' @description -#' Unfreeze some weights after some number of steps or epochs. Select either a given module or a parameter. -#' -#' @details -#' TODO: add +#' Unfreeze some weights (parameters of the network) after some number of steps or epochs. #' #' @param starting_weights (`Select`)\cr #' A `Select` denoting the weights that are trainable from the start. @@ -30,8 +27,10 @@ CallbackSetUnfreeze = R6Class("CallbackSetUnfreeze", #' @description #' Sets the starting weights on_begin = function() { - weights = select_invert(self$starting_weights)(names(self$ctx$network$parameters)) - walk(self$ctx$network$parameters[weights], function(param) param$requires_grad_(FALSE)) + trainable_weights = self$starting_weights(names(self$ctx$network$parameters)) + walk(self$ctx$network$parameters[trainable_weights], function(param) param$requires_grad_(TRUE)) + frozen_weights = select_invert(self$starting_weights)(names(self$ctx$network$parameters)) + walk(self$ctx$network$parameters[frozen_weights], function(param) param$requires_grad_(FALSE)) }, #' @description #' Unfreezes weights if the training is at the correct epoch @@ -54,8 +53,7 @@ CallbackSetUnfreeze = R6Class("CallbackSetUnfreeze", } } } - ), - private = list() + ) ) #' @include TorchCallback.R diff --git a/R/Select.R b/R/Select.R index 9a167a53..766d196a 100644 --- a/R/Select.R +++ b/R/Select.R @@ -3,9 +3,7 @@ #' @name Select #' #' @description -#' A [`Select`] function is used by the callback `CallbackSetUnfreeze` to determine a subset of parameters to freeze or unfreeze during training. -#' -#' @section Details: +#' A [`Select`] function subsets a character vector. They are used by the callback `CallbackSetUnfreeze` to select parameters to freeze or unfreeze during training. #' ... NULL @@ -49,7 +47,7 @@ select_grep = function(pattern, ignore.case = FALSE, perl = FALSE, fixed = FALSE #' @describeIn Select `select_name` selects parameters with names matching the given names #' @export -select_name = function(param_names, assert_present = FALSE) { +select_name = function(param_names, assert_present = TRUE) { assert_character(param_names, any.missing = FALSE) assert_flag(assert_present) str_assert_present = if (assert_present) ", assert_present = TRUE" else "" diff --git a/tests/testthat/test_Select.R b/tests/testthat/test_Select.R index 3f353daa..f5145e80 100644 --- a/tests/testthat/test_Select.R +++ b/tests/testthat/test_Select.R @@ -1,14 +1,5 @@ test_that("selectors work", { - n_epochs = 1 - - task = tsk("iris") - - mlp = lrn("classif.mlp", - epochs = 10, batch_size = 150, neurons = c(100, 200, 300) - ) - mlp$train(task) - - all_params = names(mlp$network$parameters) + all_params = c("0.weight", "0.bias", "3.weight", "3.bias", "6.weight", "6.bias", "9.weight", "9.bias") expect_equal(selectorparam_none()(all_params), character(0)) expect_equal(selectorparam_all()(all_params), all_params)