Skip to content

Commit

Permalink
some changes from PR review
Browse files Browse the repository at this point in the history
  • Loading branch information
cxzhang4 committed Nov 29, 2024
1 parent 7a2ae78 commit 5663df4
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 22 deletions.
14 changes: 6 additions & 8 deletions R/CallbackSetUnfreeze.R
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,7 @@
#' @name mlr_callback_set.unfreeze
#'
#' @description
#' Unfreeze some weights after some number of steps or epochs. Select either a given module or a parameter.
#'
#' @details
#' TODO: add
#' Unfreeze some weights (parameters of the network) after some number of steps or epochs.
#'
#' @param starting_weights (`Select`)\cr
#' A `Select` denoting the weights that are trainable from the start.
Expand All @@ -30,8 +27,10 @@ CallbackSetUnfreeze = R6Class("CallbackSetUnfreeze",
#' @description
#' Sets the starting weights
on_begin = function() {
weights = select_invert(self$starting_weights)(names(self$ctx$network$parameters))
walk(self$ctx$network$parameters[weights], function(param) param$requires_grad_(FALSE))
trainable_weights = self$starting_weights(names(self$ctx$network$parameters))
walk(self$ctx$network$parameters[trainable_weights], function(param) param$requires_grad_(TRUE))
frozen_weights = select_invert(self$starting_weights)(names(self$ctx$network$parameters))
walk(self$ctx$network$parameters[frozen_weights], function(param) param$requires_grad_(FALSE))
},
#' @description
#' Unfreezes weights if the training is at the correct epoch
Expand All @@ -54,8 +53,7 @@ CallbackSetUnfreeze = R6Class("CallbackSetUnfreeze",
}
}
}
),
private = list()
)
)

#' @include TorchCallback.R
Expand Down
6 changes: 2 additions & 4 deletions R/Select.R
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,7 @@
#' @name Select
#'
#' @description
#' A [`Select`] function is used by the callback `CallbackSetUnfreeze` to determine a subset of parameters to freeze or unfreeze during training.
#'
#' @section Details:
#' A [`Select`] function subsets a character vector. They are used by the callback `CallbackSetUnfreeze` to select parameters to freeze or unfreeze during training.
#' ...
NULL

Expand Down Expand Up @@ -49,7 +47,7 @@ select_grep = function(pattern, ignore.case = FALSE, perl = FALSE, fixed = FALSE

#' @describeIn Select `select_name` selects parameters with names matching the given names
#' @export
select_name = function(param_names, assert_present = FALSE) {
select_name = function(param_names, assert_present = TRUE) {
assert_character(param_names, any.missing = FALSE)
assert_flag(assert_present)
str_assert_present = if (assert_present) ", assert_present = TRUE" else ""
Expand Down
11 changes: 1 addition & 10 deletions tests/testthat/test_Select.R
Original file line number Diff line number Diff line change
@@ -1,14 +1,5 @@
test_that("selectors work", {
n_epochs = 1

task = tsk("iris")

mlp = lrn("classif.mlp",
epochs = 10, batch_size = 150, neurons = c(100, 200, 300)
)
mlp$train(task)

all_params = names(mlp$network$parameters)
all_params = c("0.weight", "0.bias", "3.weight", "3.bias", "6.weight", "6.bias", "9.weight", "9.bias")

expect_equal(selectorparam_none()(all_params), character(0))
expect_equal(selectorparam_all()(all_params), all_params)
Expand Down

0 comments on commit 5663df4

Please sign in to comment.