From 47ae6e0276330f2f344704585becc32c8576fd1b Mon Sep 17 00:00:00 2001 From: dlyakhov Date: Fri, 1 Dec 2023 17:52:46 +0100 Subject: [PATCH] WIP temporary insertion command --- nncf/torch/dynamic_graph/context.py | 33 ++++++++++---- nncf/torch/graph/transformations/commands.py | 29 ++++++++++++ nncf/torch/model_transformer.py | 34 ++++++++++++-- nncf/torch/nncf_network.py | 48 ++++++++++++++++++-- nncf/torch/statistics/aggregator.py | 10 ++-- tests/torch/helpers.py | 2 +- 6 files changed, 133 insertions(+), 23 deletions(-) diff --git a/nncf/torch/dynamic_graph/context.py b/nncf/torch/dynamic_graph/context.py index f279fea49c9..a5d41d4c272 100644 --- a/nncf/torch/dynamic_graph/context.py +++ b/nncf/torch/dynamic_graph/context.py @@ -11,10 +11,11 @@ import threading import weakref +from collections import OrderedDict from collections import defaultdict from collections import deque from contextlib import contextmanager -from typing import Callable, DefaultDict, List, Optional +from typing import Callable, Dict, List, Optional import torch @@ -92,9 +93,9 @@ class TracingContext: def __init__(self): self.graph = DynamicGraph() - self._post_hooks: DefaultDict[OperationAddress, List[Callable]] = defaultdict(list) - self._pre_hooks: DefaultDict[PreHookId, List[Callable]] = defaultdict(list) - self._num_nested_hooks = 0 + self._post_hooks = defaultdict(OrderedDict) + self._pre_hooks: Dict[PreHookId, List[Callable]] = defaultdict(OrderedDict) + self._hooks_counter = 0 self._threading = CopySafeThreadingVars() @@ -260,9 +261,16 @@ def pop_scope(self): self.relative_scopes_stack.pop() self.module_call_stack.pop() - def register_pre_hooks(self, fn_list: List[Callable], op_address: OperationAddress, input_port_id: int): + def register_pre_hooks( + self, fn_list: List[Callable], op_address: OperationAddress, input_port_id: int + ) -> List[int]: pre_hook_id = PreHookId(op_address, input_port_id) - self._pre_hooks[pre_hook_id].extend(fn_list) + hooks_ids = [] + for fn in fn_list: + self._hooks_counter += 1 + self._pre_hooks[pre_hook_id][self._hooks_counter] = fn + hooks_ids.append(self._hooks_counter) + return pre_hook_id, hooks_ids def execute_pre_hooks(self, op_address: OperationAddress, op_inputs: OperatorInput) -> OperatorInput: in_op = getattr(self, "in_operator", False) @@ -274,21 +282,26 @@ def execute_pre_hooks(self, op_address: OperationAddress, op_inputs: OperatorInp for pre_hook_id in pre_hook_ids_for_curr_op: hook_list_for_current_input_port = self._pre_hooks[pre_hook_id] input_arg_to_process = pre_hook_id.input_port_id - for hook in hook_list_for_current_input_port: + for hook in hook_list_for_current_input_port.values(): op_inputs[input_arg_to_process] = hook(op_inputs[input_arg_to_process]) self._threading.thread_local.num_nested_hooks -= 1 self.in_operator = in_op return op_inputs - def register_post_hooks(self, fn_list: List[Callable], op_address: OperationAddress): - self._post_hooks[op_address].extend(fn_list) + def register_post_hooks(self, fn_list: List[Callable], op_address: OperationAddress) -> List[int]: + hooks_ids = [] + for fn in fn_list: + self._hooks_counter += 1 + self._post_hooks[op_address][self._hooks_counter] = fn + hooks_ids.append(self._hooks_counter) + return op_address, hooks_ids def execute_post_hooks(self, op_address: OperationAddress, outputs): in_op = getattr(self, "in_operator", False) self.in_operator = False self._threading.thread_local.num_nested_hooks += 1 if op_address in self._post_hooks: - for hook in self._post_hooks[op_address]: + for hook in self._post_hooks[op_address].values(): outputs = hook(outputs) self._threading.thread_local.num_nested_hooks -= 1 self.in_operator = in_op diff --git a/nncf/torch/graph/transformations/commands.py b/nncf/torch/graph/transformations/commands.py index b74804e417c..3d59b2e16eb 100644 --- a/nncf/torch/graph/transformations/commands.py +++ b/nncf/torch/graph/transformations/commands.py @@ -153,6 +153,35 @@ def requires_graph_rebuild(self): return self.priority == TransformationPriority.QUANTIZATION_PRIORITY +class PTInsertionTemporaryCommand(PTTransformationCommand): + """ + Insertion operation to the models. + """ + + def __init__( + self, + point: PTTargetPoint, + fn: Callable, + priority: TransformationPriority = TransformationPriority.DEFAULT_PRIORITY, + ): + super().__init__(TransformationType.INSERT, point) + self.fn: Callable = fn + self.priority: TransformationPriority = priority + + def union(self, other: "PTTransformationCommand") -> "PTTransformationCommand": + # TODO: keep all TransformationCommands atomic, refactor TransformationLayout instead + raise NotImplementedError() + + def requires_graph_rebuild(self): + """ + Return boolean flag to rebuild graph of model. + + :return: Boolean flag. + """ + # Rebuild graph when adding quantization nodes or an op. + return False + + class PTSharedFnInsertionCommand(PTTransformationCommand): def __init__( self, diff --git a/nncf/torch/model_transformer.py b/nncf/torch/model_transformer.py index 1fbb722aa28..92899e21a37 100644 --- a/nncf/torch/model_transformer.py +++ b/nncf/torch/model_transformer.py @@ -11,7 +11,7 @@ import copy from collections import defaultdict -from typing import Callable, Dict, List, Tuple +from typing import Callable, Dict, Iterator, List, Tuple from torch import Tensor from torch import nn @@ -24,6 +24,7 @@ from nncf.torch.external_hook import EXTERNAL_OP_STORAGE_NAME from nncf.torch.graph.transformations.commands import PTBiasCorrectionCommand from nncf.torch.graph.transformations.commands import PTInsertionCommand +from nncf.torch.graph.transformations.commands import PTInsertionTemporaryCommand from nncf.torch.graph.transformations.commands import PTModelExtractionWithFusedBiasCommand from nncf.torch.graph.transformations.commands import PTQuantizerInsertionCommand from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand @@ -52,6 +53,7 @@ def __init__(self, model: NNCFNetwork): self._command_transformation_ordered_pairs = [ (PTModelExtractionWithFusedBiasCommand, self._apply_extraction_with_fused_bias_transformations), (PTInsertionCommand, self._apply_insertion_transformations), + (PTInsertionTemporaryCommand, self._apply_temporary_insertion_transformation), (PTQuantizerInsertionCommand, self._apply_quantizer_insertion_transformations), (PTBiasCorrectionCommand, self._apply_bias_correction_transformations), (PTSharedFnInsertionCommand, self._apply_shared_nodes_insertion), @@ -82,6 +84,33 @@ def _apply_insertion_transformations(model: NNCFNetwork, transformations: List[P """ Applies insertion transformations to the model. + :param model: Model to apply transformations. + :param transformations: List of the bias correction transformations. + """ + for insert_args in PTModelTransformer._get_nncf_network_insert_arguments(model, transformations): + model.nncf.insert_at_point(*insert_args) + return model + + @staticmethod + def _apply_temporary_insertion_transformation( + model: NNCFNetwork, transformations: List[PTInsertionCommand] + ) -> NNCFNetwork: + """ + Applies temporary insertion transformations to the model. + + :param model: Model to apply transformations. + :param transformations: List of the bias correction transformations. + """ + for insert_args in PTModelTransformer._get_nncf_network_insert_arguments(model, transformations): + model.nncf.temporary_insert_at_point(*insert_args) + return model + + def _get_nncf_network_insert_arguments( + model: NNCFNetwork, transformations: List[PTInsertionCommand] + ) -> Iterator[Tuple[PTInsertionPoint, List[Callable]]]: + """ + Applies insertion transformations to the model. + :param model: Model to apply transformations. :param transformations: List of the bias correction transformations. """ @@ -107,8 +136,7 @@ def _apply_insertion_transformations(model: NNCFNetwork, transformations: List[P for pt_ip, fn_list_with_priority in fns_grouped_by_points.items(): fn_list_with_priority = sorted(fn_list_with_priority, key=lambda x: x[1]) - model.nncf.insert_at_point(pt_ip, [x[0] for x in fn_list_with_priority]) - return model + yield (pt_ip, [x[0] for x in fn_list_with_priority]) @staticmethod def _apply_shared_nodes_insertion( diff --git a/nncf/torch/nncf_network.py b/nncf/torch/nncf_network.py index 48f45240bef..497a9697c8d 100644 --- a/nncf/torch/nncf_network.py +++ b/nncf/torch/nncf_network.py @@ -18,7 +18,7 @@ from dataclasses import dataclass from enum import Enum from enum import IntEnum -from typing import Callable, Dict, Iterator, List, Optional, Tuple, TypeVar +from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, TypeVar import torch from torch import nn @@ -237,6 +237,7 @@ def __init__( self._target_scopes = target_scopes self._user_dummy_forward_fn = dummy_forward_fn self._kd_loss_handler = None + self._temprorary_hooks_adresses = [] if wrap_inputs_fn is not None: self._wrap_inputs_fn = wrap_inputs_fn @@ -409,11 +410,36 @@ def get_modules_in_nncf_modules_by_type(self, class_names: List[str]) -> Dict[Sc def update_model_ref(self, model: torch.nn.Module) -> None: object.__setattr__(self, "__model_ref", model) + def temporary_insert_at_point(self, point: PTInsertionPoint, fn_list: List[Callable]): + hook_addresses = self.insert_at_point(point, fn_list) + self._temprorary_hooks_adresses.append(hook_addresses) + return hook_addresses + + def remove_temporary_ops(self): + for point, hook_address, hook_ids in self._temprorary_hooks_adresses: + for hook_idx in hook_ids: + if point.insertion_type == PTInsertionType.OPERATOR_PRE_HOOK: + hooks = self._compressed_context._pre_hooks[hook_address] + elif point.insertion_type == PTInsertionType.OPERATOR_POST_HOOK: + hooks = self._compressed_context._post_hooks[hook_address] + else: + nncf_module = self.get_module_by_scope(point.module_scope) + if point.insertion_type == PTInsertionType.NNCF_MODULE_PRE_OP: + hooks = nncf_module.pre_ops + else: + hooks = nncf_module.post_ops + + del hooks[hook_idx] + self._temprorary_hooks_adresses.clear() + def insert_at_point(self, point: PTInsertionPoint, fn_list: List[Callable]): + hooks_ids = None if point.insertion_type == PTInsertionType.OPERATOR_PRE_HOOK: - self._compressed_context.register_pre_hooks(fn_list, point.op_address, point.input_port_id) + hook_address, hooks_ids = self._compressed_context.register_pre_hooks( + fn_list, point.op_address, point.input_port_id + ) elif point.insertion_type == PTInsertionType.OPERATOR_POST_HOOK: - self._compressed_context.register_post_hooks(fn_list, point.op_address) + hook_address, hooks_ids = self._compressed_context.register_post_hooks(fn_list, point.op_address) elif point.insertion_type in [PTInsertionType.NNCF_MODULE_PRE_OP, PTInsertionType.NNCF_MODULE_POST_OP]: nncf_module = self.get_module_by_scope(point.module_scope) if not isinstance(nncf_module, _NNCFModuleMixin): @@ -431,14 +457,18 @@ def insert_at_point(self, point: PTInsertionPoint, fn_list: List[Callable]): for scope_list_for_module in self.get_nncf_module_scopes(): norm_nncf_scopes.extend([self._normalize_variable_recurrent_scope(x) for x in scope_list_for_module]) assert norm_target_scope in norm_nncf_scopes # Required for proper Recurrent/VariableRecurrent addressing + hooks_ids = [] if point.insertion_type == PTInsertionType.NNCF_MODULE_PRE_OP: for fn in fn_list: - nncf_module.register_pre_forward_operation(fn) + hooks_ids.append(nncf_module.register_pre_forward_operation(fn)) elif point.insertion_type == PTInsertionType.NNCF_MODULE_POST_OP: for fn in fn_list: - nncf_module.register_post_forward_operation(fn) + hooks_ids.append(nncf_module.register_post_forward_operation(fn)) + hook_address = None else: raise RuntimeError("Unsupported insertion type: {}".format(point.insertion_type)) + hook_addresses = (point, hook_address, hooks_ids) + return hook_addresses def get_graph(self) -> PTNNCFGraph: if self._compressed_context.graph.get_nodes_count() == 0 or self._compressed_graphs_pair.nncf_graph is None: @@ -793,6 +823,14 @@ def strip(self, do_copy: bool = True) -> "NNCFNetwork": return self.compression_controller.strip(do_copy) +class TemporaryOp: + def __init__(self, op: Callable) -> None: + self._op = op + + def __call__(self, *args: Any, **kwargs: Any) -> Any: + return self._op(*args, **kwargs) + + class NNCFNetworkMeta(type): """ Metaclass for the NNCFNetwork mixin. Has magic methods defined so that the original model object could be diff --git a/nncf/torch/statistics/aggregator.py b/nncf/torch/statistics/aggregator.py index e369e57345c..fa49cf3cca3 100644 --- a/nncf/torch/statistics/aggregator.py +++ b/nncf/torch/statistics/aggregator.py @@ -21,7 +21,7 @@ from nncf.common.graph.transformations.layout import TransformationLayout from nncf.common.tensor_statistics.aggregator import StatisticPointsContainer from nncf.common.tensor_statistics.aggregator import StatisticsAggregator -from nncf.torch.graph.transformations.commands import PTInsertionCommand +from nncf.torch.graph.transformations.commands import PTInsertionTemporaryCommand from nncf.torch.nncf_network import NNCFNetwork from nncf.torch.tensor import PTNNCFTensor from nncf.torch.tensor_statistics.algo import create_register_input_hook @@ -60,8 +60,10 @@ def __exit__(self, exc_type, exc_val, exc_tb): class PTStatisticsAggregator(StatisticsAggregator): def collect_statistics(self, model: NNCFNetwork, graph: NNCFGraph) -> None: with torch.no_grad(): - with ModelView(model) as intermediate_model: - super().collect_statistics(intermediate_model, graph) + # with ModelView(model) as intermediate_model: + # super().collect_statistics(intermediate_model, graph) + super().collect_statistics(model, graph) + model.nncf.remove_temporary_ops() def _register_statistics( self, outputs: Dict[str, PTNNCFTensor], statistic_points: StatisticPointsContainer @@ -79,7 +81,7 @@ def _get_transformation_layout_extra_outputs( for collectors in _statistic_point.algorithm_to_tensor_collectors.values(): for collector in collectors: transformation_commands.append( - PTInsertionCommand( + PTInsertionTemporaryCommand( _statistic_point.target_point, create_register_input_hook(collector=collector), TransformationPriority.FP32_TENSOR_STATISTICS_OBSERVATION, diff --git a/tests/torch/helpers.py b/tests/torch/helpers.py index d5c8751bc6c..e15f7671627 100644 --- a/tests/torch/helpers.py +++ b/tests/torch/helpers.py @@ -567,7 +567,7 @@ def _check_weight_update_hooks(ref_hooks): def _check_pre_post_hooks(hooks, ref_hooks): assert len(hooks) == len(ref_hooks) for op_address, ref_hooks in ref_hooks.items(): - actual_hooks = hooks[op_address] + actual_hooks = hooks[op_address].values() assert len(actual_hooks) == len(ref_hooks) for actual_hook, ref_hook in zip(actual_hooks, ref_hooks): assert actual_hook is ref_hook