From 43956acbeec76de9f4f20cda3963cd76deae5b4c Mon Sep 17 00:00:00 2001
From: Sebastian Fischer
Date: Tue, 26 Nov 2024 08:40:56 +0100
Subject: [PATCH] feat(mlp): add n_layers parameter

---
 NEWS.md                               |  1 +
 R/LearnerTorchMLP.R                   | 13 ++++++++++++-
 man/mlr_learners.mlp.Rd               |  2 ++
 tests/testthat/test_LearnerTorchMLP.R | 18 +++++++++++++++++-
 4 files changed, 32 insertions(+), 2 deletions(-)

diff --git a/NEWS.md b/NEWS.md
index ab97ff93..f45b16c0 100644
--- a/NEWS.md
+++ b/NEWS.md
@@ -3,6 +3,7 @@
 * perf: Use a faster image loader
 * feat: Add parameter `num_interop_threads` to `LearnerTorch`
 * feat: Add adaptive average pooling
+* feat: Added `n_layers` parameter to MLP
 
 # mlr3torch 0.1.2
 
diff --git a/R/LearnerTorchMLP.R b/R/LearnerTorchMLP.R
index e1079293..4b2a955f 100644
--- a/R/LearnerTorchMLP.R
+++ b/R/LearnerTorchMLP.R
@@ -22,6 +22,8 @@
 #' * `neurons` :: `integer()`\cr
 #'   The number of neurons per hidden layer. By default there is no hidden layer.
 #'   Setting this to `c(10, 20)` would give the first hidden layer 10 neurons and the second 20.
+#' * `n_layers` :: `integer()`\cr
+#'   The number of hidden layers. This parameter may only be set when `neurons` has length 1, in which case that value is used for every layer.
 #' * `p` :: `numeric(1)`\cr
 #'   The dropout probability. Is initialized to `0.5`.
 #' * `shape` :: `integer()` or `NULL`\cr
@@ -48,6 +50,7 @@ LearnerTorchMLP = R6Class("LearnerTorchMLP",
       param_set = ps(
         neurons = p_uty(tags = c("train", "predict"), custom_check = check_neurons),
         p = p_dbl(lower = 0, upper = 1, tags = "train"),
+        n_layers = p_int(lower = 1L, tags = "train"),
         activation = p_uty(tags = c("required", "train"), custom_check = check_nn_module),
         activation_args = p_uty(tags = c("required", "train"), custom_check = check_activation_args),
         shape = p_uty(tags = "train", custom_check = check_shape)
@@ -127,8 +130,16 @@ single_lazy_tensor = function(task) {
 }
 
 # shape is (NA, x) if present
-make_mlp = function(task, d_in, d_out, activation, neurons = integer(0), p, activation_args, ...) {
+make_mlp = function(task, d_in, d_out, activation, neurons = integer(0), p, activation_args, n_layers = NULL, ...) {
   # This way, dropout_args will have length 0 if p is `NULL`
+
+  if (!is.null(n_layers)) {
+    if (length(neurons) != 1L) {
+      stopf("Can only supply `n_layers` when neurons has length 1.")
+    }
+    neurons = rep(neurons, n_layers)
+  }
+
   dropout_args = list()
   dropout_args$p = p
   prev_dim = d_in
diff --git a/man/mlr_learners.mlp.Rd b/man/mlr_learners.mlp.Rd
index 6eb586aa..510d0e97 100644
--- a/man/mlr_learners.mlp.Rd
+++ b/man/mlr_learners.mlp.Rd
@@ -43,6 +43,8 @@ This is initialized to an empty list.
 \item \code{neurons} :: \code{integer()}\cr
 The number of neurons per hidden layer. By default there is no hidden layer.
 Setting this to \code{c(10, 20)} would give the first hidden layer 10 neurons and the second 20.
+\item \code{n_layers} :: \code{integer()}\cr
+The number of hidden layers. This parameter may only be set when \code{neurons} has length 1, in which case that value is used for every layer.
 \item \code{p} :: \code{numeric(1)}\cr
 The dropout probability. Is initialized to \code{0.5}.
 \item \code{shape} :: \code{integer()} or \code{NULL}\cr
diff --git a/tests/testthat/test_LearnerTorchMLP.R b/tests/testthat/test_LearnerTorchMLP.R
index 1869e554..36b34e2f 100644
--- a/tests/testthat/test_LearnerTorchMLP.R
+++ b/tests/testthat/test_LearnerTorchMLP.R
@@ -52,4 +52,20 @@ test_that("works for lazy tensor", {
   expect_class(pred, "Prediction")
 })
 
-# TODO: More tests
+test_that("neurons and n_layers", {
+  l1 = lrn("classif.mlp", batch_size = 32, epochs = 0L)
+  l2 = l1$clone(deep = TRUE)
+  task = tsk("iris")
+  l1$param_set$set_values(neurons = c(10, 10))
+  l2$param_set$set_values(neurons = 10, n_layers = 2)
+  l1$train(task)
+  l2$train(task)
+  expect_equal(l1$network$parameters[[1]]$shape, l2$network$parameters[[1]]$shape)
+  expect_equal(l1$network$parameters[[3]]$shape, l2$network$parameters[[3]]$shape)
+  expect_equal(l1$network$parameters[[1]]$shape, c(10, 4))
+  expect_equal(l1$network$parameters[[3]]$shape, c(3, 10))
+
+  # setting n_layers on l1, whose neurons has length 2, must error
+  l1$param_set$set_values(n_layers = 2)
+  expect_error(l1$train(task), "Can only supply")
+})
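
For reference, a minimal usage sketch of the new parameter. It reuses only the
`classif.mlp` learner id, parameter names, and `tsk("iris")` task that appear
in the tests above; the batch size and epoch count are illustrative, not
prescriptive.

    library(mlr3)
    library(mlr3torch)

    # neurons = 10 with n_layers = 2 expands to neurons = c(10, 10),
    # i.e. two hidden layers of width 10
    learner = lrn("classif.mlp",
      neurons = 10, n_layers = 2,
      batch_size = 32, epochs = 10
    )
    learner$train(tsk("iris"))

    # n_layers combined with a neurons vector of length > 1 raises
    # "Can only supply `n_layers` when neurons has length 1."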