Embedded hooks test
daniil-lyakhov committed Dec 1, 2023
1 parent adfe87b commit 2abd131
Showing 5 changed files with 338 additions and 70 deletions.
2 changes: 1 addition & 1 deletion nncf/torch/dynamic_graph/wrappers.py
@@ -55,7 +55,7 @@ def wrap_operator(operator, operator_info: PatchedOperatorInfo):
     Wraps the input callable object (`operator`) with the functionality that allows the calls to this object
     to be tracked by the currently set global TracingContext. The wrapped functions can then be intercepted,
     their arguments and return values modified arbitrarily and, for functions that correspond to operations on
-    tensors in a DNN, their general position and address in the DNN's model control flow graph can be established.
+    tensors in a DNN, their general position and address in the DNN's model control flow graph can be established.

     :param: operator: A callable object to be wrapped.
     :param: operator_info (PatchedOperatorInfo): An informational struct containing the specifics of wrapping
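
To make the mechanism concrete, here is a minimal sketch of the general decorator pattern such call-tracking wrapping follows. It is illustrative only, not NNCF's actual implementation: `pre_hooks` and `post_hooks` are hypothetical stand-ins for the TracingContext interception described above.

# Illustrative sketch of a call-tracking wrapper; not NNCF's actual wrap_operator.
# `pre_hooks` and `post_hooks` are hypothetical stand-ins for the TracingContext machinery.
import functools

def wrap_operator_sketch(operator, pre_hooks, post_hooks):
    @functools.wraps(operator)
    def wrapped(*args, **kwargs):
        for hook in pre_hooks:
            args, kwargs = hook(args, kwargs)  # arguments may be modified arbitrarily
        result = operator(*args, **kwargs)
        for hook in post_hooks:
            result = hook(result)  # return values may be modified arbitrarily
        return result
    return wrapped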
20 changes: 11 additions & 9 deletions nncf/torch/return_types.py
@@ -9,29 +9,29 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import inspect
 from typing import Any, Optional, Tuple, Type, Union

 import torch


-def __get_supported_torch_return_types() -> Tuple[Type[object]]:
+def __get_supported_torch_return_types() -> Tuple[Type[tuple], ...]:
     """
-    Collects types from torch.return_types which can be wrapped/unwrapped by nncf.
+    Collects types from torch.return_types which can be wrapped/unwrapped by NNCF.
     NNCF can wrap/unwrap only return types that have two attributes, one of them
     should be the `values` attribute.

-    :return: List of types from torch.return_types which can be wrapped/unwrapped by nncf.
+    :return: List of types from torch.return_types which can be wrapped/unwrapped by NNCF.
     """
-    return_type_names = [t for t in dir(torch.return_types) if not t.startswith("_") and not t.startswith("linalg")]
-    return_types = [getattr(torch.return_types, t_name) for t_name in return_type_names]
-    return_types = [t for t in return_types if hasattr(t, "values")]
-    return tuple(return_types)
+    return tuple(t for _, t in inspect.getmembers(torch.return_types) if inspect.isclass(t) and hasattr(t, "values"))


 _TORCH_RETURN_TYPES = __get_supported_torch_return_types()


 def maybe_unwrap_from_torch_return_type(tensor: Any) -> torch.Tensor:
     """
-    Attempts to unwrap the tensor value from one of torch.return_types instantces
+    Attempts to unwrap the tensor value from one of torch.return_types instances
     in case torch operation output is wrapped by a torch return_type.

     :param tensor: Torch tensor or torch return type instance to unwrap values from.
@@ -52,5 +52,7 @@ def maybe_wrap_to_torch_return_type(tensor: torch.Tensor, wrapped_input: Optiona
"""

if isinstance(wrapped_input, _TORCH_RETURN_TYPES):
return wrapped_input.__class__([tensor] + [arg for arg in wrapped_input[1:]])
# We assume that return_type has only two attributes, the first one is `value`.
# This assumption is checked by `test_unwrap_wrap_torch_return_type`.
return wrapped_input.__class__((tensor, wrapped_input[1]))
return tensor
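
A minimal usage sketch of these helpers (assuming, per the docstring above, that the unwrap helper returns the `values` field of a wrapped result and passes plain tensors through unchanged):

import torch

from nncf.torch.return_types import maybe_unwrap_from_torch_return_type
from nncf.torch.return_types import maybe_wrap_to_torch_return_type

wrapped = torch.max(torch.ones(2, 3), dim=1)  # torch.return_types.max(values=..., indices=...)
values = maybe_unwrap_from_torch_return_type(wrapped)  # the plain `values` tensor
rewrapped = maybe_wrap_to_torch_return_type(values * 2, wrapped)  # back to torch.return_types.max
plain = maybe_unwrap_from_torch_return_type(torch.ones(3))  # non-return-type inputs pass through as-is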
68 changes: 68 additions & 0 deletions tests/torch/helpers.py
@@ -12,6 +12,7 @@
 import numbers
 from abc import ABC
 from abc import abstractmethod
+from collections import defaultdict
 from copy import deepcopy
 from pathlib import Path
 from typing import Any, Callable, Dict, List, Tuple, Union
@@ -26,17 +27,20 @@
 from torch.utils.data import DataLoader
 from torch.utils.data import Dataset

+from nncf.common.graph.transformations.commands import TargetType
 from nncf.config import NNCFConfig
 from nncf.config.extractors import extract_algorithm_names
 from nncf.config.structures import BNAdaptationInitArgs
 from nncf.torch.algo_selector import PT_COMPRESSION_ALGORITHMS
 from nncf.torch.compression_method_api import PTCompressionAlgorithmController
+from nncf.torch.dynamic_graph.context import PreHookId
 from nncf.torch.dynamic_graph.io_handling import FillerInputInfo
 from nncf.torch.dynamic_graph.scope import Scope
 from nncf.torch.initialization import PTInitializingDataLoader
 from nncf.torch.initialization import register_default_init_args
 from nncf.torch.layers import NNCF_MODULES_MAP
 from nncf.torch.model_creation import create_compressed_model
+from nncf.torch.module_operations import UpdateWeight
 from nncf.torch.nncf_module_replacement import get_original_module_scope_from_nncf_module_scope
 from nncf.torch.nncf_network import NNCFNetwork
 from nncf.torch.utils import get_all_modules_by_type
@@ -497,3 +501,67 @@ def load_exported_onnx_version(
     compression_ctrl.export_model(str(onnx_checkpoint_path), save_format=save_format)
     model_proto = onnx.load_model(str(onnx_checkpoint_path))
     return model_proto


+class HookChecker:
+    """
+    Class to check that pre/post hooks and pre ops are placed correctly.
+    Supports checks for a single wrapped NNCFModule for now.
+    """
+
+    def __init__(self, target_model: torch.nn.Module, nncf_module_attr_name: str):
+        """
+        :param target_model: Target model containing the wrapped NNCF module.
+        :param nncf_module_attr_name: Name of the NNCF module attribute in the target model.
+        """
+        self._nncf_module_attr_name = nncf_module_attr_name
+        self._target_model = target_model
+        self._ref_hooks = defaultdict(dict)
+
+    def add_ref(
+        self,
+        ref_hooks: List[Callable],
+        target_type: TargetType,
+        target_node_name: str,
+        input_port_id: int,
+    ) -> None:
+        """
+        Adds reference hooks.
+        """
+        op_address = self._convert_to_op_address(target_type, target_node_name, input_port_id)
+        self._ref_hooks[target_type].update({op_address: ref_hooks})
+
+    def _convert_to_op_address(self, target_type: TargetType, target_node_name: str, input_port_id: int) -> Any:
+        address_map = self._target_model.nncf.get_node_to_op_address_mapping()
+        address = address_map[target_node_name]
+        if target_type == TargetType.OPERATOR_PRE_HOOK:
+            address = PreHookId(address, input_port_id)
+        elif target_type == TargetType.OPERATION_WITH_WEIGHTS:
+            address = getattr(self._target_model, self._nncf_module_attr_name)
+        return address
+
+    def check_with_reference(self):
+        """
+        Checks that the hooks in the target model match the reference hooks.
+        """
+        self._check_weight_update_hooks(self._ref_hooks[TargetType.OPERATION_WITH_WEIGHTS])
+        hooks = self._target_model.nncf._compressed_context._pre_hooks
+        self._check_pre_post_hooks(hooks, self._ref_hooks[TargetType.OPERATOR_PRE_HOOK])
+        hooks = self._target_model.nncf._compressed_context._post_hooks
+        self._check_pre_post_hooks(hooks, self._ref_hooks[TargetType.OPERATOR_POST_HOOK])
+
+    @staticmethod
+    def _check_weight_update_hooks(ref_hooks):
+        for target_module, ref_hooks_per_module in ref_hooks.items():
+            assert len(target_module.pre_ops) == len(ref_hooks_per_module)
+            for actual_op, ref_op in zip(target_module.pre_ops.values(), ref_hooks_per_module):
+                assert isinstance(actual_op, UpdateWeight)
+                assert actual_op.op is ref_op
+
+    @staticmethod
+    def _check_pre_post_hooks(hooks, ref_hooks):
+        assert len(hooks) == len(ref_hooks)
+        for op_address, ref_hooks_per_address in ref_hooks.items():
+            actual_hooks = hooks[op_address]
+            assert len(actual_hooks) == len(ref_hooks_per_address)
+            for actual_hook, ref_hook in zip(actual_hooks, ref_hooks_per_address):
+                assert actual_hook is ref_hook
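
A hypothetical test using HookChecker might look as follows; the model, its "conv" attribute, the node name, and `ref_hook` are illustrative assumptions, not part of this commit:

# Hypothetical usage sketch: `compressed_model`, its "conv" submodule, the node name,
# and `ref_hook` are illustrative assumptions, not part of this commit.
checker = HookChecker(compressed_model, "conv")
checker.add_ref(
    ref_hooks=[ref_hook],
    target_type=TargetType.OPERATOR_PRE_HOOK,
    target_node_name="/nncf_model_input_0",
    input_port_id=0,
)
checker.check_with_reference()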
(The remaining two changed files are not shown here.)
