separate preproc and trafo class, generate pipeops programmatically
sebffischer committed Oct 13, 2023
1 parent e5c5a32 commit ca0c6c9
Showing 19 changed files with 554 additions and 404 deletions.
4 changes: 3 additions & 1 deletion DESCRIPTION
@@ -103,7 +103,8 @@ Collate:
'PipeOpTorchHead.R'
'PipeOpTorchIngress.R'
'PipeOpTorchLayerNorm.R'
'PipeOpTorchLazyTransform.R'
'PipeOpTorchLazyPreproc.R'
'PipeOpTorchLazyTrafo.R'
'PipeOpTorchLinear.R'
'TorchLoss.R'
'PipeOpTorchLoss.R'
@@ -128,6 +129,7 @@ Collate:
'nn_graph.R'
'paramset_torchlearner.R'
'rd_info.R'
'register_lazy.R'
'reset_last_layer.R'
'task_dataset.R'
'utils.R'
4 changes: 2 additions & 2 deletions NAMESPACE
@@ -64,6 +64,7 @@ export(LearnerTorchMLP)
export(LearnerTorchModel)
export(ModelDescriptor)
export(PipeOpModule)
export(PipeOpTaskPreprocLazy)
export(PipeOpTorch)
export(PipeOpTorchAvgPool1D)
export(PipeOpTorchAvgPool2D)
@@ -94,7 +95,7 @@ export(PipeOpTorchIngressImage)
export(PipeOpTorchIngressLazyTensor)
export(PipeOpTorchIngressNumeric)
export(PipeOpTorchLayerNorm)
export(PipeOpTorchLazyTransform)
export(PipeOpTorchLazyTrafo)
export(PipeOpTorchLeakyReLU)
export(PipeOpTorchLinear)
export(PipeOpTorchLogSigmoid)
@@ -125,7 +126,6 @@ export(PipeOpTorchSqueeze)
export(PipeOpTorchTanh)
export(PipeOpTorchTanhShrink)
export(PipeOpTorchThreshold)
export(PipeOpTorchTransformResize)
export(PipeOpTorchUnsqueeze)
export(ResamplingRowRoles)
export(TorchCallback)
7 changes: 4 additions & 3 deletions R/PipeOpModule.R
@@ -83,8 +83,8 @@ PipeOpModule = R6Class("PipeOpModule",
#' @description
#' Creates a new instance of this [R6][R6::R6Class] class.
#' @template param_id
#' @param module ([`nn_module`])\cr
#' The torch module that is being wrapped.
#' @param module ([`nn_module`] or `function()`)\cr
#' The torch module or function that is being wrapped.
#' @param inname (`character()`)\cr
#' The names of the input channels.
#' @param outname (`character()`)\cr
@@ -95,7 +95,8 @@ PipeOpModule = R6Class("PipeOpModule",
initialize = function(id = "module", module = nn_identity(), inname = "input", outname = "output",
param_vals = list(), packages = character(0)) {
private$.multi_output = length(outname) > 1L
self$module = assert_class(module, "nn_module")
assert(check_class(module, "nn_module"), check_class(module, "function"), combine = "or")
self$module = module
assert_names(outname, type = "strict")
assert_character(packages, any.missing = FALSE)

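Since `PipeOpModule` now accepts a plain function as its `module`, a stateless transformation no longer needs to be wrapped in an `nn_module`. A minimal sketch of the new usage; the flatten function and tensor shapes here are illustrative, not part of this commit:

```r
library(mlr3torch)
library(torch)

# wrap a stateless function instead of an nn_module
po_flatten = PipeOpModule$new(
  id = "flatten",
  module = function(x) torch_flatten(x, start_dim = 2L)
)

x = torch_randn(16, 3, 8, 8)
# PipeOps consume and produce named lists of channel values
out = po_flatten$train(list(input = x))
out$output$shape  # 16 192
```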
17 changes: 11 additions & 6 deletions R/PipeOpTorch.R
@@ -249,17 +249,16 @@ PipeOpTorch = R6Class("PipeOpTorch",
#' [`PipeOp`] during training must return a named `list()`, where the names of the list are the
#' names of the output channels. The default is `"output"`.
initialize = function(id, module_generator, param_set = ps(), param_vals = list(),
inname = "input", outname = "output", packages = "torch", tags = NULL, variant_types = FALSE) {
inname = "input", outname = "output", packages = "torch", tags = NULL) {
self$module_generator = assert_class(module_generator, "nn_module_generator", null.ok = TRUE)
assert_character(inname, .var.name = "input channel names")
assert_character(outname, .var.name = "output channel names", min.len = 1L)
assert_character(tags, null.ok = TRUE)
assert_character(packages, any.missing = FALSE)

packages = union(packages, "torch")
data_type = if (variant_types) "*" else "ModelDescriptor"
input = data.table(name = inname, train = data_type, predict = "*")
output = data.table(name = outname, train = data_type, predict = "*")
input = data.table(name = inname, train = "ModelDescriptor", predict = "*")
output = data.table(name = outname, train = "ModelDescriptor", predict = "*")

assert_r6(param_set, "ParamSet")
#walk(param_set$params, function(p) {
@@ -296,9 +295,15 @@ PipeOpTorch = R6Class("PipeOpTorch",
} else {
assert_list(shapes_in, len = nrow(self$input), types = "numeric")
}
pv = self$param_set$get_values()

set_names(private$.shapes_out(shapes_in, pv, task = task), self$output$name)
s = if (is.null(private$.shapes_out)) {
shapes_in
} else {
pv = self$param_set$get_values()
private$.shapes_out(shapes_in, pv, task = task)
}

set_names(s, self$output$name)
}

# TODO: printer that calls the nn_module's printer
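The effect of the new fallback: when a `PipeOp` defines no private `.shapes_out()` method, `shapes_out()` now returns the input shapes unchanged, which covers shape-preserving operations. A hedged sketch of the resulting behavior (`nn_relu` chosen for illustration, assuming it defines no `.shapes_out()`):

```r
library(mlr3pipelines)
library(mlr3torch)

# a shape-preserving op: without a private .shapes_out() method,
# the input shapes are passed through unchanged (NA is the batch dimension)
po("nn_relu")$shapes_out(list(c(NA, 64)))
#> $output
#> [1] NA 64
```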
87 changes: 87 additions & 0 deletions R/PipeOpTorchLazyPreproc.R
@@ -0,0 +1,87 @@
#' @export
PipeOpTaskPreprocLazy = R6Class("PipeOpTaskPreprocLazy",
inherit = PipeOpTaskPreproc,
public = list(
initialize = function(fn, id = "lazy_preproc", param_vals = list(), param_set = ps(), packages = character(0)) {
private$.fn = assert_function(fn, null.ok = FALSE)

if ("augment" %in% param_set$ids()) {
stopf("Parameter name 'augment' is reserved and cannot be used.")
}
if (is.null(private$.shapes_out)) {
param_set$add(ps(
augment = p_lgl(tags = c("predict", "required"))
))
param_set$set_values(augment = FALSE)
}

super$initialize(
id = id,
param_set = param_set,
param_vals = param_vals,
packages = packages,
feature_types = "lazy_tensor"
)
}
),
private = list(
.train_dt = function(dt, levels, target) {
if (ncol(dt) != 1L) {
# checking this during training suffices, as it then also holds during prediction
stop("Can only use PipeOpTaskPreprocLazy on tasks with exactly one lazy tensor column.")
}
param_vals = self$param_set$get_values(tags = "train")

trafo = private$.fn

fn = if (length(param_vals)) {
crate(function(x) {
invoke(.f = trafo, x, .args = param_vals)
}, param_vals, trafo, .parent = topenv())
} else {
trafo
}

self$state = list()

private$.transform_dt(dt, fn)
},
.predict_dt = function(dt, levels) {
param_vals = self$param_set$get_values(tags = "predict")
augment = param_vals$augment
param_vals$augment = NULL

# augment can be NULL (in case the PipeOp changes the output shape and hence augmentation is not supported)
# or a logical(1)
fn = if (isTRUE(augment)) {
identity
} else {
trafo = private$.fn
crate(function(x) {
invoke(.f = trafo, x, .args = param_vals)
}, param_vals, trafo, .parent = topenv())
}

private$.transform_dt(dt, fn)
},
.transform_dt = function(dt, fn) {
po_fn = PipeOpModule$new(
id = self$id,
module = fn,
inname = self$input$name,
outname = self$output$name,
packages = self$packages
)

lt = dt[[1L]]

shapes_in = attr(lt[[1L]], "data_descriptor")$.pointer_shape

# no task is available here, so shape inference relies on the input shapes only
shapes_out = self$shapes_out(shapes_in, task = NULL)
dt[[1L]] = transform_lazy_tensor(lt, po_fn, shapes_out[[1L]])
return(dt)
},
.fn = NULL,
.additional_phash_input = function() {
list(self$param_set$ids(), private$.fn, self$packages)
}
)
)
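A usage sketch for the new class, assuming the `lazy_iris` task with its single `lazy_tensor` column `x`; the doubling function is purely illustrative:

```r
library(mlr3torch)

# hypothetical preprocessing step: double every value, lazily
po_scale = PipeOpTaskPreprocLazy$new(
  fn = function(x) x * 2,
  id = "times_two"
)

task = tsk("lazy_iris")
task_out = po_scale$train(list(task))[[1L]]

# nothing is computed yet; the function only runs on materialize()
materialize(task_out$data(cols = "x")$x[1:2])
```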
142 changes: 142 additions & 0 deletions R/PipeOpTorchLazyTrafo.R
@@ -0,0 +1,142 @@
#' @title Base Class for Lazy Transformations
#' @name mlr_pipeops_torch_lazy_trafo
#'
#' @description
#' This `PipeOp` represents simple preprocessing transformations of torch tensors.
#' These can be used in two situations:
#'
#' 1. To preprocess a task, which works analogously to standard preprocessing PipeOps like [`PipeOpPCA`].
#' Because the [`lazy_tensor()`] does not make any assumptions on how the data is stored, the transformation is
#' applied lazily, i.e. when [`materialize()`] is called.
#'    During training of a learner, this transformation is applied during data loading on the CPU.
#'
#' 2. To add a preprocessing step in an [`nn_graph()`] that is being built up in a [`ModelDescriptor`].
#' In this case, the transformation is applied during the forward pass of the model, i.e. the tensor is then
#' also on the specified device.
#'
#' Currently the `PipeOp` must have exactly one input and one output.
#'
#' @section Inheriting:
#' You need to:
#' * Initialize the `fn` argument. This function should take one torch tensor as input and return a torch tensor.
#' Additional parameters that are passed to the function can be specified via the parameter set.
#' This function needs to be a simple, stateless function, see section *Internals* for more information.
#' * In case the transformation changes the tensor shape you must provide a private `.shapes_out()` method like
#' for [`PipeOpTorch`].
#'
#' @section Input and Output Channels:
#' During *training* and *prediction*, all inputs and outputs are of class [`Task`] or [`ModelDescriptor`].
#'
#' @template pipeop_torch_state_default
#' @section Parameters:
#' * `augment` :: `logical(1)`\cr
#'   This parameter is only present when the `PipeOp` does not modify the input shape.
#'   Whether the transformation is applied only during training (`TRUE`) or also during prediction,
#'   which includes validation (`FALSE`).
#'   This parameter is initialized to `FALSE`.
#'
#' Additional parameters can be specified by the class.
#'
#' @section Internals:
#'
#' Applied to a **Task**:
#'
#' When this PipeOp is used for preprocessing, it creates a [`PipeOpModule`] from the function `fn` (additionally
#' passing the `param_vals` if there are any) and then adds it to the preprocessing graph that is part of the
#' [`DataDescriptor`] contained in the [`lazy_tensor`] column that is being preprocessed.
#' When the output of this PipeOp is then preprocessed by a different `PipeOpTorchLazyTrafo`, a deep clone of the
#' preprocessing graph is made. However, this deep clone does not clone the environment of the
#' function or its attributes in case they have a state (as e.g. in [`nn_module()`]s).
#' Setting the parameter `augment` to `TRUE` means that the preprocessing is only applied during training
#' and skipped during prediction.
#'
#' Applied to a **ModelDescriptor**:
#'
#' When this PipeOp is applied to a [`ModelDescriptor`], the function `fn` (again wrapping the `param_vals` if
#' there are any) is used as the module that is added to the network graph, so the transformation is applied
#' during the forward pass on the device of the input tensor.
#' If `augment` is `TRUE`, the function is wrapped in an [`nn_module()`] that applies `fn` only in training mode.
#'
#' @template param_id
#' @template param_param_vals
#' @template param_param_set
#' @param packages (`character()`)\cr
#' The packages the function depends on.
#' @param fn (`function()`)\cr
#' A function that will be applied to a (lazy) tensor.
#' Additional arguments can be passed as parameters.
#' During actual preprocessing the (lazy) tensor will be passed by position.
#' The transformation is always applied to a whole batch of tensors, i.e. the first dimension is the batch dimension.
#'
#' @export
PipeOpTorchLazyTrafo = R6Class("PipeOpTorchLazyTrafo",
inherit = PipeOpTorch,
public = list(
#' @description Creates a new instance of this [R6][R6::R6Class] class.
initialize = function(fn, id = "lazy_trafo", param_vals = list(), param_set = ps(), packages = character(0)) {
private$.fn = assert_function(fn, null.ok = FALSE)

if ("augment" %in% param_set$ids()) {
stopf("Parameter name 'augment' is reserved and cannot be used.")
}

if (is.null(private$.shapes_out)) {
param_set$add(ps(
augment = p_lgl(tags = c("train", "required"))
))
param_set$set_values(augment = FALSE)
}

super$initialize(
id = id,
inname = "input",
outname = "output",
param_vals = param_vals,
param_set = param_set,
packages = packages,
module_generator = NULL
)

}
),
private = list(
.fn = NULL,
.make_module = function(shapes_in, param_vals, task) {
# this function is only called when the input is the ModelDescriptor
augment = param_vals$augment
param_vals$augment = NULL
trafo = private$.fn
fn = if (length(param_vals)) {
crate(function(x) {
invoke(.f = trafo, x, .args = param_vals)
}, param_vals, trafo, .parent = topenv())
} else {
trafo
}

# augment can be NULL or logical(1)
if (!isTRUE(augment)) {
return(fn)
}

nn_module(self$id,
initialize = function(fn) {
self$fn = fn
},
forward = function(x) {
if (self$training) {
self$fn(x)
} else {
x
}
}
)(fn)
},
.additional_phash_input = function() {
list(self$param_set$ids(), private$.fn, self$packages)
}
)
)

#' @include zzz.R
register_po("lazy_trafo", PipeOpTorchLazyTrafo)
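To illustrate the *Inheriting* section above, a hedged sketch of a subclass that changes the tensor shape and therefore must implement `.shapes_out()`; the class name and id are made up for this example:

```r
library(R6)
library(mlr3torch)
library(torch)

PipeOpTorchTrafoFlatten = R6Class("PipeOpTorchTrafoFlatten",
  inherit = PipeOpTorchLazyTrafo,
  public = list(
    initialize = function(id = "trafo_flatten", param_vals = list()) {
      super$initialize(
        fn = function(x) torch_flatten(x, start_dim = 2L),
        id = id,
        param_vals = param_vals
      )
    }
  ),
  private = list(
    # flattens all but the batch dimension; because this method exists,
    # no `augment` parameter is added by the base class
    .shapes_out = function(shapes_in, param_vals, task) {
      shape = shapes_in[[1L]]
      list(c(shape[1L], prod(shape[-1L])))
    }
  )
)
```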

