Skip to content

Commit

Permalink
tests
Browse files Browse the repository at this point in the history
  • Loading branch information
cxzhang4 committed Nov 28, 2024
1 parent 2fcc8b2 commit 44cfccc
Showing 1 changed file with 8 additions and 0 deletions.
8 changes: 8 additions & 0 deletions tests/testthat/test_CallbackSetUnfreeze.R
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@ check_frozen = torch_callback("check_frozen",
},
on_epoch_end = function() {
if ("epoch" %in% names(self$unfreeze)) {
if ((self$ctx$epoch + 1) %in% self$unfreeze$epoch) {
weights = (self$unfreeze[epoch == (self$ctx$epoch + 1)]$unfreeze)[[1]](names(self$ctx$network$parameters))
walk(self$ctx$network$parameters[weights], function(param) print(!param$requires_grad))
}
if (self$ctx$epoch %in% self$unfreeze$epoch) {
weights = (self$unfreeze[epoch == self$ctx$epoch]$unfreeze)[[1]](names(self$ctx$network$parameters))
walk(self$ctx$network$parameters[weights], function(param) print(param$requires_grad))
Expand All @@ -15,6 +19,10 @@ check_frozen = torch_callback("check_frozen",
on_batch_end = function() {
if ("batch" %in% names(self$unfreeze)) {
batch_num = (self$ctx$epoch - 1) * length(self$ctx$loader_train) + self$ctx$step
if ((batch_num + 1) %in% self$unfreeze$batch) {
weights = (self$unfreeze[batch == (batch_num + 1)]$unfreeze)[[1]](names(self$ctx$network$parameters))
walk(self$ctx$network$parameters[weights], function(param) print(!param$requires_grad))
}
if (batch_num %in% self$unfreeze$batch) {
weights = (self$unfreeze[batch == batch_num]$unfreeze)[[1]](names(self$ctx$network$parameters))
walk(self$ctx$network$parameters[weights], function(param) print(param$requires_grad))
Expand Down

0 comments on commit 44cfccc

Please sign in to comment.