
Commit

Test finishing
Cleanup
daniil-lyakhov committed Dec 12, 2023
1 parent 1662996 commit f3ffaf3
Showing 18 changed files with 285 additions and 150 deletions.
1 change: 0 additions & 1 deletion nncf/common/graph/transformations/commands.py
@@ -35,7 +35,6 @@ class TransformationPriority(IntEnum):
     FP32_TENSOR_STATISTICS_OBSERVATION = 1
     PRUNING_PRIORITY = 2
     SPARSIFICATION_PRIORITY = 3
-    OP_INSERTION_PRIORITY = 4
     QUANTIZATION_PRIORITY = 11

5 changes: 1 addition & 4 deletions nncf/experimental/tensor/functions.py
@@ -340,10 +340,7 @@ def stack(x: List[Tensor], axis: int = 0) -> Tensor:
     :return: Stacked Tensor.
     """
     if isinstance(x, List):
-        unwrapped_x = [i.data for i in x]
-        # singledispatch cannot dispatch function by element in a list
-        res = stack.dispatch(type(unwrapped_x[0]))(unwrapped_x, axis=axis)
-        return Tensor(res)
+        return Tensor(_dispatch_list(stack, x, axis=axis))
     raise NotImplementedError(f"Function `stack` is not implemented for {type(x)}")

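Note: the new _dispatch_list helper is not part of this diff. A minimal sketch of what it could look like, reconstructed from the removed inline logic (the actual helper in nncf/experimental/tensor/functions.py may differ):

# Hypothetical reconstruction for illustration only, not the repository's code.
from typing import Any, Callable, List

def _dispatch_list(fn: Callable, tensor_list: List["Tensor"], *args: Any, **kwargs: Any) -> Any:
    # singledispatch cannot dispatch by an element inside a list, so dispatch
    # on the backend type of the first unwrapped tensor instead.
    unwrapped_list = [tensor.data for tensor in tensor_list]
    return fn.dispatch(type(unwrapped_list[0]))(unwrapped_list, *args, **kwargs)
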
2 changes: 1 addition & 1 deletion nncf/torch/dynamic_graph/context.py
@@ -95,7 +95,7 @@ def __init__(self):
 
         self._post_hooks = defaultdict(OrderedDict)
         self._pre_hooks: Dict[PreHookId, List[Callable]] = defaultdict(OrderedDict)
-        self._hooks_counter = 0
+        self._hooks_counter = -1
 
         self._threading = CopySafeThreadingVars()

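A plausible reading of this change, assuming the counter is incremented before each new hook ID is handed out elsewhere in context.py (not verified against the rest of that file): starting at -1 makes the first issued ID 0.

# Illustrative sketch only; the real context class's logic is not shown in this diff.
class _HookIdAllocator:
    def __init__(self) -> None:
        self._hooks_counter = -1  # first increment below yields 0

    def next_hook_id(self) -> int:
        self._hooks_counter += 1
        return self._hooks_counter
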
10 changes: 5 additions & 5 deletions nncf/torch/graph/transformations/command_creation.py
@@ -16,7 +16,6 @@
 
 from nncf.common.graph.graph import NNCFNode
 from nncf.common.graph.transformations.commands import TargetType
-from nncf.common.graph.transformations.commands import TransformationPriority
 from nncf.torch.graph.transformations.commands import PTBiasCorrectionCommand
 from nncf.torch.graph.transformations.commands import PTInsertionCommand
 from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand
@@ -60,9 +59,10 @@ def forward(self, x):
 def multiply_insertion_command(
     target_nodes: List[NNCFNode], scale_value: Tensor, scale_node_name: str, input_port_id: int
 ) -> PTInsertionCommand:
-    commands = []
+    target_points = []
     for target_node in target_nodes:
-        target_point = PTTargetPoint(TargetType.OPERATOR_PRE_HOOK, target_node.node_name, input_port_id=input_port_id)
-        commands.append(PTInsertionCommand(target_point, None, priority=TransformationPriority.OP_INSERTION_PRIORITY))
+        target_points.append(
+            PTTargetPoint(TargetType.OPERATOR_PRE_HOOK, target_node.node_name, input_port_id=input_port_id)
+        )
 
-    return PTSharedFnInsertionCommand(commands, SQMultiply(scale_value), scale_node_name)
+    return PTSharedFnInsertionCommand(target_points, SQMultiply(scale_value), scale_node_name)

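For context, a hypothetical usage sketch of the updated factory (the node list, scale value and node name are made up; only multiply_insertion_command itself comes from this commit): a single shared SQMultiply op now serves several target points instead of one insertion command per node.

# Hypothetical sketch; callers would pass real NNCFNode objects and a real scale tensor.
def build_shared_scale_command(target_nodes, scale_value):
    return multiply_insertion_command(
        target_nodes=target_nodes,
        scale_value=scale_value,
        scale_node_name="sq_multiply_scale",  # assumed name for the shared op
        input_port_id=0,
    )
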
32 changes: 1 addition & 31 deletions nncf/torch/nncf_network.py
@@ -17,7 +17,7 @@
 from copy import deepcopy
 from enum import Enum
 from enum import IntEnum
-from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, TypeVar
+from typing import Callable, Dict, Iterator, List, Optional, Tuple, TypeVar
 
 import torch
 from torch import nn
@@ -345,25 +345,6 @@ def reset_nncf_modules(self):
             module = self.get_module_by_scope(some_scope)
             module.reset()
 
-    def get_shallow_copy(self) -> "NNCFNetwork":
-        from nncf.torch.utils import load_module_state
-        from nncf.torch.utils import save_module_state
-
-        saved_state = save_module_state(self._model_ref)
-        new_interface = NNCFNetworkInterface(
-            self._model_ref,
-            self._input_infos,
-            self._user_dummy_forward_fn,
-            self._wrap_inputs_fn,
-            self._scopes_without_shape_matching,
-            self._ignored_scopes,
-            self._target_scopes,
-            wrap_outputs_fn=self._wrap_outputs_fn,
-        )
-        self._model_ref._nncf = new_interface
-        load_module_state(self._model_ref, saved_state)
-        return self._model_ref
-
     def get_clean_shallow_copy(self) -> "NNCFNetwork":
         # WARNING: Will reset pre- and post-ops of the underlying model. Use save_nncf_module_additions
         # and load_nncf_module_additions to preserve these, or temporary_clean_view().
@@ -395,9 +376,6 @@ def get_modules_in_nncf_modules_by_type(self, class_names: List[str]) -> Dict[Sc
             retval[nncf_module_scope + relative_scope] = target_module
         return retval
 
-    def update_model_ref(self, model: torch.nn.Module) -> None:
-        object.__setattr__(self, "__model_ref", model)
-
     def temporary_insert_at_point(self, point: PTInsertionPoint, fn_list: List[Callable]):
         hook_addresses = self.insert_at_point(point, fn_list)
         self._temprorary_hooks_adresses.append(hook_addresses)
@@ -802,14 +780,6 @@ def strip(self, do_copy: bool = True) -> "NNCFNetwork":
         return self.compression_controller.strip(do_copy)
 
 
-class TemporaryOp:
-    def __init__(self, op: Callable) -> None:
-        self._op = op
-
-    def __call__(self, *args: Any, **kwargs: Any) -> Any:
-        return self._op(*args, **kwargs)
-
-
 class NNCFNetworkMeta(type):
     """
     Metaclass for the NNCFNetwork mixin. Has magic methods defined so that the original model object could be

1 change: 0 additions & 1 deletion nncf/torch/quantization/external_quantizer.py
@@ -14,7 +14,6 @@
 from nncf.torch.quantization.debug_interface import QuantizationDebugInterface
 
 EXTERNAL_QUANTIZERS_STORAGE_NAME = "external_quantizers"
-EXTERNAL_OP_STORAGE_NAME = "external_op"
 EXTERNAL_QUANTIZERS_STORAGE_PREFIX = "_nncf." + EXTERNAL_QUANTIZERS_STORAGE_NAME

33 changes: 0 additions & 33 deletions nncf/torch/statistics/aggregator.py
@@ -9,7 +9,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from copy import deepcopy
 from typing import Dict
 
 import numpy as np
@@ -27,41 +26,9 @@
 from nncf.torch.tensor_statistics.algo import create_register_input_hook
 
 
-class ModelView:
-    def __init__(self, model: NNCFNetwork):
-        self.model = model
-        self.nncf_module_additions = self.model.nncf.save_nncf_module_additions()
-
-    def __enter__(self):
-        # Model ref removed to prevent copying
-        self.model.nncf.update_model_ref(None)
-
-        # nncf_replaced_models removed to prevent copying
-        replaced_modules = self.model.nncf._nncf_replaced_modules
-        self.model.nncf._nncf_replaced_modules = None
-
-        self.nncf_interface = deepcopy(self.model.nncf)
-
-        # Model ref is recovering
-        self.model.nncf.update_model_ref(self.model)
-        self.nncf_interface.update_model_ref(self.model)
-
-        # nncf_replaced_models is recovering
-        self.model.nncf._nncf_replaced_modules = replaced_modules
-        self.nncf_interface._nncf_replaced_modules = replaced_modules
-        return self.model
-
-    def __exit__(self, exc_type, exc_val, exc_tb):
-        self.model._nncf = self.nncf_interface
-        self.model.nncf.reset_nncf_modules()
-        self.model.nncf.load_nncf_module_additions(self.nncf_module_additions)
-
-
 class PTStatisticsAggregator(StatisticsAggregator):
     def collect_statistics(self, model: NNCFNetwork, graph: NNCFGraph) -> None:
         with torch.no_grad():
-            # with ModelView(model) as intermediate_model:
-            #     super().collect_statistics(intermediate_model, graph)
             super().collect_statistics(model, graph)
             model.nncf.remove_temporary_ops()

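The removed ModelView context manager deep-copied the NNCF interface around statistics collection; the simplified flow above instead relies on temporarily inserted hooks plus remove_temporary_ops. A hedged sketch of that pattern (only temporary_insert_at_point and remove_temporary_ops appear in this commit; the insertion point, hook and calibration batches are placeholders):

# Hypothetical sketch of collecting statistics via temporary hooks.
import torch

def collect_with_temporary_hooks(model, insertion_point, register_input_hook, calibration_batches):
    model.nncf.temporary_insert_at_point(insertion_point, [register_input_hook])
    try:
        with torch.no_grad():
            for batch in calibration_batches:
                model(batch)
    finally:
        model.nncf.remove_temporary_ops()  # drop every temporarily inserted hook
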
8 changes: 4 additions & 4 deletions tests/common/experimental/test_statistic_collector.py
@@ -329,7 +329,7 @@ class BadStatContainer:
 
 class TemplateTestStatisticCollector:
     @abstractmethod
-    def get_nncf_tensor_cls(self):
+    def get_nncf_tensor(self, value: np.ndarray) -> NNCFTensor:
         pass
 
     @abstractmethod
@@ -366,10 +366,10 @@ def test_empty_tensors_register(self, inplace, any_not_empty):
         collector.register_statistic_branch("A", reducer, aggregator)
         input_name = "input_name"
         full_inputs = TensorCollector.get_tensor_collector_inputs(
-            {input_name: self.get_nncf_tensor_cls()(np.array([100]))}, [(hash(reducer), [input_name])]
+            {input_name: self.get_nncf_tensor(np.array([100]))}, [(hash(reducer), [input_name])]
         )
         empty_inputs = TensorCollector.get_tensor_collector_inputs(
-            {input_name: self.get_nncf_tensor_cls()(np.array([]))}, [(hash(reducer), [input_name])]
+            {input_name: self.get_nncf_tensor(np.array([]))}, [(hash(reducer), [input_name])]
         )
 
         stats = collector.get_statistics()
@@ -385,7 +385,7 @@ def test_empty_tensors_register(self, inplace, any_not_empty):
             assert aggregator._collected_samples == 2
             stats = collector.get_statistics()
            assert len(stats) == 1
-            assert stats["A"] == self.get_nncf_tensor_cls()([100])
+            assert stats["A"] == self.get_nncf_tensor([100])
             return
 
         assert len(aggregator._container) == 0

9 changes: 6 additions & 3 deletions tests/common/test_statistics_aggregator.py
@@ -49,8 +49,9 @@ class BCStatsCollectors(Enum):
 
 
 class TemplateTestStatisticsAggregator:
+    @classmethod
     @abstractmethod
-    def get_min_max_algo_backend_cls(self) -> Type[MinMaxAlgoBackend]:
+    def get_min_max_algo_backend_cls(cls) -> Type[MinMaxAlgoBackend]:
         pass
 
     @abstractmethod
@@ -73,6 +74,7 @@ def get_statistics_aggregator(self, dataset):
     def get_dataset(self, samples):
         pass
 
+    @staticmethod
     @abstractmethod
     def get_target_point(self, target_type: TargetType) -> TargetPoint:
         pass
@@ -631,10 +633,11 @@ def filter_func(point):
             assert ref.shape == val.shape
             assert np.allclose(val, ref)
 
+    @classmethod
     def create_statistics_point(
-        self, model, q_config, target_point, subset_size, algorithm_name, inplace_statistics, range_estimator
+        cls, model, q_config, target_point, subset_size, algorithm_name, inplace_statistics, range_estimator
     ):
-        algo_backend = self.get_min_max_algo_backend_cls()
+        algo_backend = cls.get_min_max_algo_backend_cls()
         nncf_graph = NNCFGraphFactory.create(model)
         tensor_collector = algo_backend.get_statistic_collector(
             range_estimator,

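The template's factory hooks are being converted from instance methods to classmethods/staticmethods. A small self-contained sketch of the decorator stacking this relies on (generic names, not taken from the repository); @classmethod or @staticmethod must sit above @abstractmethod:

# Generic illustration of abstract classmethod/staticmethod declarations.
from abc import ABC, abstractmethod

class TemplateBackendTest(ABC):
    @classmethod
    @abstractmethod
    def get_backend_cls(cls) -> type:
        ...

    @staticmethod
    @abstractmethod
    def get_target_point(target_type: str) -> tuple:
        ...

class ConcreteBackendTest(TemplateBackendTest):
    @classmethod
    def get_backend_cls(cls) -> type:
        return dict  # placeholder "backend" class

    @staticmethod
    def get_target_point(target_type: str) -> tuple:
        return ("node_name", target_type)
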
6 changes: 4 additions & 2 deletions tests/onnx/test_statistics_aggregator.py
@@ -32,7 +32,8 @@
 
 
 class TestStatisticsAggregator(TemplateTestStatisticsAggregator):
-    def get_min_max_algo_backend_cls(self) -> Type[ONNXMinMaxAlgoBackend]:
+    @classmethod
+    def get_min_max_algo_backend_cls(cls) -> Type[ONNXMinMaxAlgoBackend]:
         return ONNXMinMaxAlgoBackend
 
     def get_bias_correction_algo_backend_cls(self) -> Type[ONNXBiasCorrectionAlgoBackend]:
@@ -65,7 +66,8 @@ def transform_fn(data_item):
 
         return Dataset(samples, transform_fn)
 
-    def get_target_point(self, target_type: TargetType):
+    @staticmethod
+    def get_target_point(target_type: TargetType):
         target_node_name = IDENTITY_NODE_NAME
         port_id = 0
         if target_type == TargetType.OPERATION_WITH_WEIGHTS:

6 changes: 4 additions & 2 deletions tests/openvino/native/test_statistic_collector.py
@@ -11,8 +11,10 @@
 
 from typing import Type
 
+import numpy as np
 import pytest
 
+from nncf.common.tensor import NNCFTensor
 from nncf.common.tensor_statistics.statistics import MeanTensorStatistic
 from nncf.common.tensor_statistics.statistics import MedianMADTensorStatistic
 from nncf.common.tensor_statistics.statistics import MinMaxTensorStatistic
@@ -26,8 +28,8 @@
 
 
 class TestOVStatisticCollector(TemplateTestStatisticCollector):
-    def get_nncf_tensor_cls(self):
-        return OVNNCFTensor
+    def get_nncf_tensor(self, value: np.ndarray) -> NNCFTensor:
+        return OVNNCFTensor(value)
 
     @pytest.fixture
     def min_max_statistic_cls(self) -> Type[MinMaxTensorStatistic]:

6 changes: 4 additions & 2 deletions tests/openvino/native/test_statistics_aggregator.py
@@ -50,7 +50,8 @@ def get_StatisticAgregatorTestModel(input_shape, kernel):
 
 
 class TestStatisticsAggregator(TemplateTestStatisticsAggregator):
-    def get_min_max_algo_backend_cls(self) -> Type[OVMinMaxAlgoBackend]:
+    @classmethod
+    def get_min_max_algo_backend_cls(cls) -> Type[OVMinMaxAlgoBackend]:
         return OVMinMaxAlgoBackend
 
     def get_bias_correction_algo_backend_cls(self) -> Type[OVBiasCorrectionAlgoBackend]:
@@ -82,7 +83,8 @@ def get_target_point_cls(self):
     def get_dataset(self, samples):
         return Dataset(samples, lambda data: {INPUT_NAME: data})
 
-    def get_target_point(self, target_type: TargetType) -> TargetPoint:
+    @staticmethod
+    def get_target_point(target_type: TargetType) -> TargetPoint:
         target_node_name = INPUT_NAME
         port_id = 0
         if target_type == TargetType.OPERATION_WITH_WEIGHTS:

6 changes: 6 additions & 0 deletions tests/torch/helpers.py
@@ -549,6 +549,12 @@ def check_with_reference(self):
         hooks = self._target_model.nncf._compressed_context._post_hooks
         self._check_pre_post_hooks(hooks, self._ref_hooks[TargetType.OPERATOR_POST_HOOK])
 
+    def clear(self):
+        """
+        Removes all recorded references.
+        """
+        self._ref_hooks.clear()
+
     @staticmethod
     def _check_weight_update_hooks(ref_hooks):
         for target_module, ref_hooks_per_module in ref_hooks.items():

2 changes: 1 addition & 1 deletion tests/torch/ptq/test_smooth_quant.py
@@ -58,7 +58,7 @@ def get_node_name_map(self) -> Dict[str, str]:
     @staticmethod
     def get_target_node_name(command: TransformationCommand):
         if isinstance(command, PTSharedFnInsertionCommand):
-            return command.target_commands[0].target_point.target_node_name
+            return command.target_points[0].target_node_name
         return command.target_point.target_node_name
 
     @staticmethod

9 changes: 6 additions & 3 deletions tests/torch/ptq/test_statistic_collector.py
@@ -11,8 +11,11 @@
 
 from typing import Type
 
+import numpy as np
 import pytest
+import torch
 
+from nncf.common.tensor import NNCFTensor
 from nncf.common.tensor_statistics.statistics import MeanTensorStatistic
 from nncf.common.tensor_statistics.statistics import MedianMADTensorStatistic
 from nncf.common.tensor_statistics.statistics import MinMaxTensorStatistic
@@ -26,9 +29,9 @@
 from tests.common.experimental.test_statistic_collector import TemplateTestStatisticCollector
 
 
-class TestOVStatisticCollector(TemplateTestStatisticCollector):
-    def get_nncf_tensor_cls(self):
-        return PTNNCFTensor
+class TestPTStatisticCollector(TemplateTestStatisticCollector):
+    def get_nncf_tensor(self, value: np.ndarray) -> NNCFTensor:
+        return PTNNCFTensor(torch.tensor(value))
 
     @pytest.fixture
     def min_max_statistic_cls(self) -> Type[MinMaxTensorStatistic]: