Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/adaptive avg pool #291

Merged
merged 11 commits into from
Oct 18, 2024
1 change: 1 addition & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ Collate:
'PipeOpTorch.R'
'PipeOpTaskPreprocTorch.R'
'PipeOpTorchActivation.R'
'PipeOpTorchAdaptiveAvgPool.R'
'PipeOpTorchAvgPool.R'
'PipeOpTorchBatchNorm.R'
'PipeOpTorchBlock.R'
Expand Down
3 changes: 3 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,9 @@ export(ModelDescriptor)
export(PipeOpModule)
export(PipeOpTaskPreprocTorch)
export(PipeOpTorch)
export(PipeOpTorchAdaptiveAvgPool1D)
export(PipeOpTorchAdaptiveAvgPool2D)
export(PipeOpTorchAdaptiveAvgPool3D)
export(PipeOpTorchAvgPool1D)
export(PipeOpTorchAvgPool2D)
export(PipeOpTorchAvgPool3D)
Expand Down
126 changes: 126 additions & 0 deletions R/PipeOpTorchAdaptiveAvgPool.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
PipeOpTorchAdaptiveAvgPool = R6Class("PipeOpTorchAdaptiveAvgPool",
inherit = PipeOpTorch,
public = list(
initialize = function(id, d, param_vals = list()) {
private$.d = assert_int(d, lower = 1, upper = 3)
module_generator = switch(d, nn_adaptive_avg_pool1d, nn_adaptive_avg_pool2d, nn_adaptive_avg_pool3d)
check_vector = make_check_vector(private$.d)
param_set = ps(
output_size = p_uty(custom_check = check_vector, tags = c("required", "train"))
)

super$initialize(
id = id,
param_set = param_set,
param_vals = param_vals,
module_generator = module_generator
)
}
),
private = list(
.shapes_out = function(shapes_in, param_vals, task) {
list(adaptive_avg_output_shape(
shape_in = shapes_in[[1]],
conv_dim = private$.d,
output_size = param_vals$output_size
))
},
.d = NULL
)
)

adaptive_avg_output_shape = function(shape_in, conv_dim, output_size) {
shape_in = assert_integerish(shape_in, min.len = conv_dim, coerce = TRUE)

if (length(output_size) == 1) output_size = rep(output_size, conv_dim)

shape_head = utils::head(shape_in, -conv_dim)
if (length(shape_head) <= 1) warningf("Input tensor does not have batch dimension.")

shape_tail = output_size

c(shape_head, shape_tail)
}

#' @title 1D Adaptive Average Pooling
#'
#' @templateVar id nn_adaptive_avg_pool1d
#' @template pipeop_torch_channels_default
#' @template pipeop_torch
#' @template pipeop_torch_example
#'
#' @inherit torch::nnf_adaptive_avg_pool1d description
#'
#' @section Parameters:
#' * `output_size` :: `integer(1)`\cr
#' The target output size. A single number.
#'
#' @section Internals:
#' Calls [`nn_adaptive_avg_pool1d()`][torch::nn_adaptive_avg_pool1d] during training.
#' @export
PipeOpTorchAdaptiveAvgPool1D = R6Class("PipeOpTorchAdaptiveAvgPool1D", inherit = PipeOpTorchAdaptiveAvgPool,
public = list(
#' @description Creates a new instance of this [R6][R6::R6Class] class.
#' @template params_pipelines
initialize = function(id = "nn_adaptive_avg_pool1d", param_vals = list()) {
super$initialize(id = id, d = 1, param_vals = param_vals)
}
)
)

#' @title 2D Adaptive Average Pooling
#'
#' @templateVar id nn_adaptive_avg_pool2d
#' @template pipeop_torch_channels_default
#' @template pipeop_torch
#' @template pipeop_torch_example
#'
#' @inherit torch::nnf_adaptive_avg_pool2d description
#'
#' @section Parameters:
#' * `output_size` :: `integer()`\cr
#' The target output size. Can be a single number or a vector.
#'
#' @section Internals:
#' Calls [`nn_adaptive_avg_pool2d()`][torch::nn_adaptive_avg_pool2d] during training.
#' @export
PipeOpTorchAdaptiveAvgPool2D = R6Class("PipeOpTorchAdaptiveAvgPool2D", inherit = PipeOpTorchAdaptiveAvgPool,
public = list(
#' @description Creates a new instance of this [R6][R6::R6Class] class.
#' @template params_pipelines
initialize = function(id = "nn_adaptive_avg_pool2d", param_vals = list()) {
super$initialize(id = id, d = 2, param_vals = param_vals)
}
)
)

#' @title 3D Adaptive Average Pooling
#'
#' @templateVar id nn_adaptive_avg_pool3d
#' @template pipeop_torch_channels_default
#' @template pipeop_torch
#' @template pipeop_torch_example
#'
#' @inherit torch::nnf_adaptive_avg_pool3d description
#'
#' @section Parameters:
#' * `output_size` :: `integer()`\cr
#' The target output size. Can be a single number or a vector.
#'
#' @section Internals:
#' Calls [`nn_adaptive_avg_pool3d()`][torch::nn_adaptive_avg_pool3d] during training.
#' @export
PipeOpTorchAdaptiveAvgPool3D = R6Class("PipeOpTorchAdaptiveAvgPool3D", inherit = PipeOpTorchAdaptiveAvgPool,
public = list(
#' @description Creates a new instance of this [R6][R6::R6Class] class.
#' @template params_pipelines
initialize = function(id = "nn_adaptive_avg_pool3d", param_vals = list()) {
super$initialize(id = id, d = 3, param_vals = param_vals)
}
)
)

#' @include zzz.R
register_po("nn_adaptive_avg_pool1d", PipeOpTorchAdaptiveAvgPool1D)
register_po("nn_adaptive_avg_pool2d", PipeOpTorchAdaptiveAvgPool2D)
register_po("nn_adaptive_avg_pool3d", PipeOpTorchAdaptiveAvgPool3D)
6 changes: 3 additions & 3 deletions R/PipeOpTorchAvgPool.R
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ avg_output_shape = function(shape_in, conv_dim, padding, stride, kernel_size, ce
#' @template pipeop_torch
#' @template pipeop_torch_example
#'
#' @inherit torch::nnf_adaptive_avg_pool1d description
#' @inherit torch::nnf_avg_pool1d description
#'
#' @section Parameters:
#' * `kernel_size` :: (`integer()`)\cr
Expand Down Expand Up @@ -104,7 +104,7 @@ PipeOpTorchAvgPool1D = R6Class("PipeOpTorchAvgPool1D", inherit = PipeOpTorchAvgP
#' @template pipeop_torch
#' @template pipeop_torch_example
#'
#' @inherit torch::nnf_adaptive_avg_pool2d description
#' @inherit torch::nnf_avg_pool2d description
#'
#' @inheritSection mlr_pipeops_nn_avg_pool1d Parameters
#'
Expand All @@ -128,7 +128,7 @@ PipeOpTorchAvgPool2D = R6Class("PipeOpTorchAvgPool2D", inherit = PipeOpTorchAvgP
#' @template pipeop_torch
#' @template pipeop_torch_example
#'
#' @inherit torch::nnf_adaptive_avg_pool3d description
#' @inherit torch::nnf_avg_pool3d description
#'
#' @inheritSection mlr_pipeops_nn_avg_pool1d Parameters
#'
Expand Down
176 changes: 176 additions & 0 deletions man/mlr_pipeops_nn_adaptive_avg_pool1d.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading
Loading