Skip to content

Commit

Permalink
...
Browse files Browse the repository at this point in the history
  • Loading branch information
sebffischer committed Oct 15, 2024
1 parent 238c8f4 commit 6907072
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
4 changes: 2 additions & 2 deletions R/LearnerTorch.R
Original file line number Diff line number Diff line change
Expand Up @@ -434,7 +434,7 @@ LearnerTorch = R6Class("LearnerTorch",
if (identical(param_vals$seed, "random")) param_vals$seed = sample.int(.Machine$integer.max, 1)

model = with_torch_settings(seed = param_vals$seed, num_threads = param_vals$num_threads,
num_interop_threads = param_vals$num_threads_interop, expr = {
num_interop_threads = param_vals$num_interop_threads, expr = {
learner_torch_train(self, private, super, task, param_vals)
})
model$task_col_info = copy(task$col_info[c(task$feature_names, task$target_names), c("id", "type", "levels")])
Expand All @@ -455,7 +455,7 @@ LearnerTorch = R6Class("LearnerTorch",
private$.verify_predict_task(task, param_vals)

with_torch_settings(seed = self$model$seed, num_threads = param_vals$num_threads,
num_interop_threads = param_vals$num_threads_interop, expr = {
num_interop_threads = param_vals$num_interop_threads, expr = {
learner_torch_predict(self, private, super, task, param_vals)
})
},
Expand Down
2 changes: 1 addition & 1 deletion R/paramset_torchlearner.R
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ paramset_torchlearner = function(task_type) {
aggr = epochs_aggr, in_tune_fn = epochs_tune_fn, disable_in_tune = list(patience = 0)),
device = p_fct(tags = c("train", "predict", "required"), levels = mlr_reflections$torch$devices, init = "auto"),
num_threads = p_int(lower = 1L, tags = c("train", "predict", "required", "threads"), init = 1L),
num_interop_threads = p_int(lower = 1L, tags = c("train", "predict", "required", "threads"), init = 1L),
num_interop_threads = p_int(lower = 1L, tags = c("train", "predict", "required"), init = 1L),
seed = p_int(tags = c("train", "predict", "required"), special_vals = list("random", NULL), init = "random"),
# evaluation
eval_freq = p_int(lower = 1L, tags = c("train", "required"), init = 1L),
Expand Down

0 comments on commit 6907072

Please sign in to comment.