Skip to content

Commit

Permalink
WIP embedded hooks test
Browse files Browse the repository at this point in the history
  • Loading branch information
daniil-lyakhov committed Nov 29, 2023
1 parent adfe87b commit 6d7c1b3
Show file tree
Hide file tree
Showing 3 changed files with 179 additions and 10 deletions.
2 changes: 1 addition & 1 deletion nncf/torch/dynamic_graph/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def wrap_operator(operator, operator_info: PatchedOperatorInfo):
Wraps the input callable object (`operator`) with the functionality that allows the calls to this object
to be tracked by the currently set global TracingContext. The wrapped functions can be then intercepted,
their arguments and return values modified arbitrarily and, for functions that correspond to operations on
tensors in a DNN, their general position and address in the DNN's model control flow graph can be established.
tensors in a DNN, their general position and address in the DNN's model control flow graph can be established.
:param: operator: A callable object to be wrapped.
:param: operator_info (PatchedOperatorInfo): An informational struct containing the specifics of wrapping
Expand Down
20 changes: 11 additions & 9 deletions nncf/torch/return_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,29 +9,29 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
from typing import Any, Optional, Tuple, Type, Union

import torch


def __get_supported_torch_return_types() -> Tuple[Type[object]]:
def __get_supported_torch_return_types() -> Tuple[Type[tuple], ...]:
"""
Collects types from torch.return_type which can be wrapped/unwrapped by nncf.
Collects types from torch.return_type which can be wrapped/unwrapped by NNCF.
NNCF can wrap/unwrap only return types that have two attributes, one of them
should be the `values` attribute.
:return: List of types from torch.return_type which can be wrapped/unwrapped by nncf.
:return: List of types from torch.return_type which can be wrapped/unwrapped by NNCF.
"""
return_type_names = [t for t in dir(torch.return_types) if not t.startswith("_") and not t.startswith("linalg")]
return_types = [getattr(torch.return_types, t_name) for t_name in return_type_names]
return_types = [t for t in return_types if hasattr(t, "values")]
return tuple(return_types)
return tuple(t for _, t in inspect.getmembers(torch.return_types) if inspect.isclass(t) and hasattr(t, "values"))


_TORCH_RETURN_TYPES = __get_supported_torch_return_types()


def maybe_unwrap_from_torch_return_type(tensor: Any) -> torch.Tensor:
"""
Attempts to unwrap the tensor value from one of torch.return_types instantces
Attempts to unwrap the tensor value from one of torch.return_types instances
in case torch operation output is wrapped by a torch return_type.
:param tensor: Torch tensor or torch return type instance to unwrap values from.
Expand All @@ -52,5 +52,7 @@ def maybe_wrap_to_torch_return_type(tensor: torch.Tensor, wrapped_input: Optiona
"""

if isinstance(wrapped_input, _TORCH_RETURN_TYPES):
return wrapped_input.__class__([tensor] + [arg for arg in wrapped_input[1:]])
# We assume that return_type has only two attributes, the first one is `value`.
# This assumption is checked by `test_unwrap_wrap_torch_return_type`.
return wrapped_input.__class__((tensor, wrapped_input[1]))
return tensor
167 changes: 167 additions & 0 deletions tests/torch/test_model_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from nncf.torch.dynamic_graph.io_handling import FillerInputElement
from nncf.torch.dynamic_graph.io_handling import FillerInputInfo
from nncf.torch.dynamic_graph.operation_address import OperationAddress
from nncf.torch.dynamic_graph.patch_pytorch import register_operator
from nncf.torch.graph.operator_metatypes import PTConv2dMetatype
from nncf.torch.graph.operator_metatypes import PTInputNoopMetatype
from nncf.torch.graph.operator_metatypes import PTModuleConv2dMetatype
Expand All @@ -51,6 +52,7 @@
from nncf.torch.graph.transformations.commands import PTTargetPoint
from nncf.torch.graph.transformations.layout import PTTransformationLayout
from nncf.torch.layers import NNCFConv2d
from nncf.torch.layers import register_module
from nncf.torch.model_transformer import PTModelTransformer
from nncf.torch.module_operations import BaseOp
from nncf.torch.module_operations import UpdateWeight
Expand Down Expand Up @@ -596,3 +598,168 @@ def test_successive_insertion_transformation(target_type, node_name, input_port_
assert len(hook_ops) == 2
for hook_op, op in zip(hook_ops, ops):
assert hook_op is op


GLOBAL_LIST = []


def get_dummy_op(op_id):
@register_operator()
def dummy_op(x):
GLOBAL_LIST.append(op_id)
return x

return dummy_op


@register_module()
class DummyModule(torch.nn.Module):
def __init__(self, module_id):
super().__init__()
self.weight = torch.nn.Parameter(torch.zeros((1,)))
self._module_id = module_id

def forward(self, x):
GLOBAL_LIST.append(self._module_id)
return x + self.weight


def get_model_to_test_nested_modules():
class TestModel(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.op1 = get_dummy_op("op1")
self.m = DummyModule("DummyModule")
self.op2 = get_dummy_op("op2")

def forward(self, x):
x = self.op1(x)
x = self.m(x)
x = self.op2(x)

return TestModel()


@pytest.mark.parametrize(
"target_type, node_name, input_port_id, ref_hooks",
(
(
TargetType.OPERATOR_POST_HOOK,
"/nncf_model_input_0",
None,
(
[
"pre_hook_1",
"op1",
"DummyModule",
"op2",
],
[
"pre_hook_0",
"pre_hook_1",
"pre_hook_2",
"op1",
"DummyModule",
"op2",
],
),
),
(
TargetType.OPERATOR_PRE_HOOK,
"TestModel/dummy_op_1",
0,
(
[
"op1",
"DummyModule",
"pre_hook_1",
"op2",
],
[
"op1",
"DummyModule",
"pre_hook_0",
"pre_hook_1",
"pre_hook_2",
"op2",
],
),
),
(
TargetType.OPERATION_WITH_WEIGHTS,
"TestModel/NNCFUserDummyModule[m]/__add___0",
None,
(
[
"op1",
"pre_hook_1",
"DummyModule",
"op2",
],
[
"op1",
"pre_hook_0",
"pre_hook_1",
"pre_hook_2",
"DummyModule",
"op2",
],
),
),
),
)
@pytest.mark.parametrize("pre_hook_op", [DummyModule])
def test_nested_hooks(target_type, node_name, input_port_id, ref_hooks, pre_hook_op):
model = NNCFNetwork(get_model_to_test_nested_modules(), FillerInputInfo([FillerInputElement([10])]))

# Check test model is working as expected
GLOBAL_LIST.clear()
model.nncf.rebuild_graph()
assert GLOBAL_LIST == [
"op1",
"DummyModule",
"op2",
]

target_point = PTTargetPoint(target_type, node_name, input_port_id=input_port_id)
transformed_model = model

command = PTInsertionCommand(target_point, pre_hook_op("pre_hook_1"))

model_transformer = PTModelTransformer(transformed_model)
transformation_layout = PTTransformationLayout()
transformation_layout.register(command)
transformed_model = model_transformer.transform(transformation_layout)

GLOBAL_LIST.clear()
transformed_model.nncf.rebuild_graph()
assert GLOBAL_LIST == ref_hooks[0]

graph = transformed_model.nncf.get_graph()
target_node = graph.get_node_by_name(node_name)
if target_type == TargetType.OPERATOR_POST_HOOK:
target_node = graph.get_next_nodes(target_node)[0]
elif target_type == TargetType.OPERATOR_PRE_HOOK:
target_node = graph.get_previous_nodes(target_node)[0]
else:
target_node = graph.get_previous_nodes(target_node)[1]
transformation_layout = PTTransformationLayout()
for i, target_type_ in enumerate([TargetType.OPERATOR_PRE_HOOK, TargetType.OPERATOR_POST_HOOK]):
target_point_on_hook = PTTargetPoint(target_type_, target_node.node_name, input_port_id=0)
transformation_layout.register(PTInsertionCommand(target_point_on_hook, pre_hook_op(f"pre_hook_{i * 2}")))
model_transformer = PTModelTransformer(transformed_model)
model_with_nested_hooks = model_transformer.transform(transformation_layout)

GLOBAL_LIST.clear()
model_with_nested_hooks.nncf.rebuild_graph()
assert GLOBAL_LIST == ref_hooks[1]

if isinstance(pre_hook_op, torch.nn.Module):
transformation_layout = PTTransformationLayout()
target_point_on_hook = PTTargetPoint(TargetType.OPERATION_WITH_WEIGHTS, target_node.node_name, input_port_id=0)
transformation_layout.register(PTInsertionCommand(target_point_on_hook, pre_hook_op("pre_hook_3")))
model_transformer = PTModelTransformer(model_with_nested_hooks)
model_with_nested_hooks = model_transformer.transform(transformation_layout)
GLOBAL_LIST.clear()
model_with_nested_hooks.nncf.rebuild_graph()
assert GLOBAL_LIST == ref_hooks[1]

0 comments on commit 6d7c1b3

Please sign in to comment.