Commit: Embedded hooks test

daniil-lyakhov committed Dec 15, 2023
1 parent b3d2cf1 commit 56bb77e
Showing 3 changed files with 58 additions and 62 deletions.
2 changes: 1 addition & 1 deletion nncf/torch/dynamic_graph/wrappers.py
@@ -55,7 +55,7 @@ def wrap_operator(operator, operator_info: PatchedOperatorInfo):
     Wraps the input callable object (`operator`) with the functionality that allows the calls to this object
     to be tracked by the currently set global TracingContext. The wrapped functions can be then intercepted,
     their arguments and return values modified arbitrarily and, for functions that correspond to operations on
-    tensors in a DNN, their general position and address in the DNN's model control flow graph can be established.
+    tensors in a DNN, their general position and address in the DNN's model control flow graph can be established.

     :param: operator: A callable object to be wrapped.
     :param: operator_info (PatchedOperatorInfo): An informational struct containing the specifics of wrapping
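The docstring above describes the pre-/post-hook interception pattern only at a high level. The minimal sketch below illustrates that pattern in isolation; it is a toy, not NNCF's actual tracing machinery, and ToyTracingContext, pre_hooks, post_hooks, and wrap_operator_sketch are invented names used for illustration only.

# Toy illustration of the interception pattern described above; not NNCF's real code.
from typing import Callable, List


class ToyTracingContext:
    """Holds hooks that run before and after every wrapped operator call."""

    def __init__(self) -> None:
        self.pre_hooks: List[Callable] = []
        self.post_hooks: List[Callable] = []


def wrap_operator_sketch(operator: Callable, ctx: ToyTracingContext) -> Callable:
    """Wrap `operator` so its inputs and outputs can be intercepted and modified."""

    def wrapped(*args, **kwargs):
        for hook in ctx.pre_hooks:
            args = hook(args)  # a pre-hook may inspect or replace the positional arguments
        result = operator(*args, **kwargs)
        for hook in ctx.post_hooks:
            result = hook(result)  # a post-hook may inspect or replace the output
        return result

    return wrapped


ctx = ToyTracingContext()
ctx.post_hooks.append(lambda x: x * 2)  # same shape as the `fn` hook used in the tests below
doubled_add = wrap_operator_sketch(lambda a, b: a + b, ctx)
assert doubled_add(1, 2) == 6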
17 changes: 0 additions & 17 deletions tests/torch/test_model_transformer.py
@@ -690,23 +690,6 @@ def test_successive_insertion_transformation(target_type, node_name, input_port_
     transformed_model = model_transformer.transform(transformation_layout)
     transformed_model.nncf.rebuild_graph()

-    if target_type == TargetType.OPERATION_WITH_WEIGHTS:
-        pre_ops = transformed_model.conv1.pre_ops
-        assert len(pre_ops) == 2
-        for module, op_ref in zip(pre_ops._modules.values(), ops):
-            assert isinstance(module, UpdateWeight)
-            assert module.op is op_ref
-    else:
-        if target_type == TargetType.OPERATOR_POST_HOOK:
-            hooks = transformed_model.nncf._compressed_context._post_hooks
-        else:
-            hooks = transformed_model.nncf._compressed_context._pre_hooks
-        assert len(hooks) == 1
-        _, hook_ops = hooks.popitem()
-        assert len(hook_ops) == 2
-        for hook_op, op in zip(hook_ops, ops):
-            assert hook_op is op
-
     checker = HookChecker(transformed_model, "conv1")
     checker.add_ref(
         ref_hooks=ops,
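The HookChecker helper used here lives in tests/torch/helpers.py and its implementation is not part of this diff. As a rough orientation, a checker of this kind records the expected hooks per target and then compares them with what the transformed model actually holds; the sketch below is an invented stand-in (the class name, constructor argument, and key layout are assumptions), not the real HookChecker API.

# Illustrative stand-in for a hook checker; the real HookChecker API may differ.
from typing import Callable, Dict, List, Tuple

Key = Tuple[str, str]  # (target_type, target_node_name), both as plain strings here


class ToyHookChecker:
    def __init__(self, actual_hooks: Dict[Key, List[Callable]]) -> None:
        self._actual = actual_hooks  # hooks actually found in the model, keyed by target
        self._refs: Dict[Key, List[Callable]] = {}

    def add_ref(self, ref_hooks: List[Callable], target_type: str, target_node_name: str) -> None:
        self._refs[(target_type, target_node_name)] = list(ref_hooks)

    def check_with_reference(self) -> None:
        for key, ref_hooks in self._refs.items():
            actual = self._actual.get(key, [])
            assert len(actual) == len(ref_hooks)
            for actual_hook, ref_hook in zip(actual, ref_hooks):
                assert actual_hook is ref_hook  # identity check, as in the tests


def fn(x):
    return x * 2


checker = ToyHookChecker({("post_hook", "conv1"): [fn]})
checker.add_ref(ref_hooks=[fn], target_type="post_hook", target_node_name="conv1")
checker.check_with_reference()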
101 changes: 57 additions & 44 deletions tests/torch/test_statistics_aggregator.py
@@ -30,9 +30,9 @@
 from nncf.quantization.range_estimator import RangeEstimatorParametersSet
 from nncf.torch.graph.graph import PTTargetPoint
 from nncf.torch.model_transformer import PTInsertionCommand
-from nncf.torch.module_operations import UpdateWeight
 from nncf.torch.statistics.aggregator import PTStatisticsAggregator
 from tests.common.test_statistics_aggregator import TemplateTestStatisticsAggregator
+from tests.torch.helpers import HookChecker
 from tests.torch.ptq.helpers import get_nncf_network
 from tests.torch.ptq.test_ptq_params import ToNNCFNetworkInterface

@@ -186,19 +186,37 @@ def test_successive_statistics_aggregation(
             pytest.skip("Custom estimators are not supported for this backend yet")

         ### Register operations before statistic collection
-        def fn(*args, **kwargs):
-            return args[0] * 2
+        def fn(x):
+            return x * 2

         layout = TransformationLayout()
-        for target_point in [test_parameters.target_type]:
-            target_point = self.get_target_point(target_point)
-            command = PTInsertionCommand(target_point, fn)
-            layout.register(command)
+        target_point = self.get_target_point(test_parameters.target_type)
+        command = PTInsertionCommand(target_point, fn)
+        layout.register(command)
         model_transformer = factory.ModelTransformerFactory.create(model)
         model = model_transformer.transform(layout)
         model.nncf.rebuild_graph()

+        ### Check hook inserted correctly
+        self.__check_hooks(test_parameters, model, target_point, fn)
+
         ### Register and collect statistics after inserted operations
+        tensor_collector = self.__collect_statistics_get_collector(
+            test_parameters, model, quantizer_config, dataset_samples, inplace_statistics
+        )
+        ### Check values are changed because of the inserted operation
+        self.__check_collector(
+            test_parameters,
+            tensor_collector,
+            is_stat_in_shape_of_scale,
+        )
+
+        ### Check the inserted operation is inside the model
+        self.__check_hooks(test_parameters, model, target_point, fn)
+
+    def __collect_statistics_get_collector(
+        self, test_parameters: MinMaxTestParameters, model, quantizer_config, dataset_samples, inplace_statistics
+    ):
         statistics_points = StatisticPointsContainer()
         for target_point in [test_parameters.target_type]:
             target_point = self.get_target_point(target_point)
@@ -228,43 +246,38 @@ def filter_func(point):
         tensor_collectors = list(
             statistics_points.get_algo_statistics_for_node(target_point.target_node_name, filter_func, algorithm_name)
         )
-
-        ### Check values are changed because of the inserted operation
         assert len(tensor_collectors) == 1
-        for tensor_collector in tensor_collectors:
-            stat = tensor_collector.get_statistics()
-            # Torch and Openvino backends tensor collectors return values in shape of scale
-            # in comparison to ONNX backends.
-            ref_min_val, ref_max_val = test_parameters.ref_min_val, test_parameters.ref_max_val
-            if isinstance(ref_min_val, np.ndarray) and is_stat_in_shape_of_scale:
-                shape = (1, 3, 1, 1)
-                if test_parameters.target_type == TargetType.OPERATION_WITH_WEIGHTS:
-                    shape = (3, 1, 1, 1)
-                ref_min_val, ref_max_val = map(lambda x: np.reshape(x, shape), (ref_min_val, ref_max_val))
-
-            assert np.allclose(stat.min_values, ref_min_val)
-            assert np.allclose(stat.max_values, ref_max_val)
-            if isinstance(ref_min_val, np.ndarray):
-                assert stat.min_values.shape == ref_min_val.shape
-                assert stat.max_values.shape == ref_max_val.shape
-            else:
-                ref_shape = (1, 1, 1, 1) if is_stat_in_shape_of_scale else ()
-                assert stat.min_values.shape == ref_shape
-                assert stat.max_values.shape == ref_shape
+        return tensor_collectors[0]

-        ### Check the inserted operation is inside the model
-        if test_parameters.target_type == TargetType.OPERATION_WITH_WEIGHTS:
-            pre_ops = model.conv.pre_ops
-            assert len(pre_ops) == 1
-            for module in pre_ops.values():
-                assert isinstance(module, UpdateWeight)
-                assert module.op is fn
+    @staticmethod
+    def __check_collector(test_parameters, tensor_collector, stat_in_shape_of_scale):
+        stat = tensor_collector.get_statistics()
+        # Torch and Openvino backends tensor collectors return values in shape of scale
+        # in comparison to ONNX backends.
+        ref_min_val, ref_max_val = test_parameters.ref_min_val, test_parameters.ref_max_val
+        if isinstance(ref_min_val, np.ndarray) and stat_in_shape_of_scale:
+            shape = (1, 3, 1, 1)
+            if test_parameters.target_type == TargetType.OPERATION_WITH_WEIGHTS:
+                shape = (3, 1, 1, 1)
+            ref_min_val, ref_max_val = map(lambda x: np.reshape(x, shape), (ref_min_val, ref_max_val))
+
+        assert np.allclose(stat.min_values, ref_min_val)
+        assert np.allclose(stat.max_values, ref_max_val)
+        if isinstance(ref_min_val, np.ndarray):
+            assert stat.min_values.shape == ref_min_val.shape
+            assert stat.max_values.shape == ref_max_val.shape
         else:
-            if test_parameters.target_type == TargetType.OPERATOR_POST_HOOK:
-                hooks = model.nncf._compressed_context._post_hooks
-            else:
-                hooks = model.nncf._compressed_context._pre_hooks
-            assert len(hooks) == 1
-            _, hook_ops = hooks.popitem()
-            assert len(hook_ops) == 1
-            assert hook_ops[0] is fn
+            ref_shape = (1, 1, 1, 1) if stat_in_shape_of_scale else ()
+            assert stat.min_values.shape == ref_shape
+            assert stat.max_values.shape == ref_shape
+
+    @staticmethod
+    def __check_hooks(test_parameters, model, target_point, fn):
+        checker = HookChecker(model, "conv")
+        checker.add_ref(
+            ref_hooks=[fn],
+            target_type=test_parameters.target_type,
+            target_node_name=target_point.target_node_name,
+            input_port_id=0,
+        )
+        checker.check_with_reference()
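A plain-NumPy illustration of what the reworked test asserts follows; the sample values and tensor sizes are invented for the example, while the (1, 3, 1, 1) scale shape comes from __check_collector above. A hook that multiplies the observed tensor by two should double the collected min/max statistics, and per-channel statistics keep the scale-shaped layout.

import numpy as np

# Three NCHW-like samples, each filled with a single value (invented for the example).
samples = np.stack([np.full((3, 4, 4), v, dtype=np.float32) for v in (-1.0, 0.5, 2.0)])


def fn(x):
    return x * 2  # the operation inserted before statistic collection


observed = fn(samples)

# Per-tensor statistics: scalar min/max are doubled by the inserted hook.
assert observed.min() == -2.0 and observed.max() == 4.0

# Per-channel statistics for an activation target, reshaped to the scale shape (1, 3, 1, 1).
min_vals = observed.min(axis=(0, 2, 3)).reshape(1, 3, 1, 1)
max_vals = observed.max(axis=(0, 2, 3)).reshape(1, 3, 1, 1)
assert min_vals.shape == (1, 3, 1, 1) and max_vals.shape == (1, 3, 1, 1)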
