Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
daniil-lyakhov committed Sep 11, 2023
1 parent f97299f commit e9aaa44
Show file tree
Hide file tree
Showing 13 changed files with 31 additions and 14 deletions.
2 changes: 1 addition & 1 deletion nncf/common/tensor_statistics/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@ def sub(cls, a: NNCFTensor, b: NNCFTensor) -> NNCFTensor:

@classmethod
@abstractmethod
def non_zero_elements(cls, x: NNCFTensor) -> NNCFTensor:
def zero_elements(cls, x: NNCFTensor) -> NNCFTensor:
""" """


Expand Down
2 changes: 1 addition & 1 deletion nncf/experimental/common/tensor_statistics/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -664,7 +664,7 @@ def _register_reduced_input_impl(self, x: TensorType) -> None:
def _aggregate_impl(self) -> Any:
stacked_val = self._tensor_processor.stack(self._container)
median_fn = partial(self._tensor_processor.masked_median, axis=self._aggregation_axes, keepdims=True)
filter_fn = self._tensor_processor.non_zero_elements
filter_fn = self._tensor_processor.zero_elements
median_per_ch = self._tensor_processor.masked_map(stacked_val, median_fn, filter_fn)

mad_values = self._tensor_processor.median(
Expand Down
2 changes: 1 addition & 1 deletion nncf/onnx/statistics/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def sub(cls, a: NNCFTensor, b: NNCFTensor) -> NNCFTensor:
raise NotImplementedError()

@classmethod
def non_zero_elements(cls, x: NNCFTensor) -> NNCFTensor:
def zero_elements(cls, x: NNCFTensor) -> NNCFTensor:
raise NotImplementedError()


Expand Down
8 changes: 4 additions & 4 deletions nncf/openvino/statistics/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,15 +153,15 @@ def quantile(

@classmethod
def masked_map(cls, x: NNCFTensor, fn: MaskedReduceFN, filter_fn) -> NNCFTensor:
raise NotImplemented()
raise NotImplementedError()

@classmethod
def sub(cls, a: NNCFTensor, b: NNCFTensor) -> NNCFTensor:
raise NotImplemented()
raise NotImplementedError()

@classmethod
def non_zero_elements(cls, x: NNCFTensor) -> NNCFTensor:
raise NotImplemented()
def zero_elements(cls, x: NNCFTensor) -> NNCFTensor:
raise NotImplementedError()


class OVNoopReducer(NoopReducer):
Expand Down
12 changes: 11 additions & 1 deletion nncf/tensorflow/tensor_statistics/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,16 @@ def quantile(
) -> List[NNCFTensor]:
raise NotImplementedError()

@classmethod
def precentile(
cls,
tensor: NNCFTensor,
precentile: Union[float, List[float]],
axis: Union[int, tuple, list],
keepdims: bool = False,
) -> List[TensorElementsType]:
raise NotImplementedError()

@staticmethod
def mean_per_channel(x: NNCFTensor, axis: int) -> NNCFTensor:
raise NotImplementedError()
Expand All @@ -119,7 +129,7 @@ def sub(cls, a: NNCFTensor, b: NNCFTensor) -> NNCFTensor:
raise NotImplementedError()

@classmethod
def non_zero_elements(cls, x: NNCFTensor) -> NNCFTensor:
def zero_elements(cls, x: NNCFTensor) -> NNCFTensor:
raise NotImplementedError()


Expand Down
4 changes: 2 additions & 2 deletions nncf/torch/tensor_statistics/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,10 +192,10 @@ def sub(cls, a: NNCFTensor, b: NNCFTensor) -> NNCFTensor:
return NNCFTensor(a.tensor - b.tensor)

@classmethod
def non_zero_elements(cls, x: NNCFTensor) -> NNCFTensor:
def zero_elements(cls, x: NNCFTensor) -> NNCFTensor:
pt_tensor = x.tensor
eps = torch.finfo(pt_tensor.dtype).eps
return NNCFTensor(pt_tensor.abs() > eps)
return NNCFTensor(pt_tensor.abs() < eps)


class PTReducerMixIn:
Expand Down
3 changes: 2 additions & 1 deletion tests/experimental/common/test_statistic_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,9 @@
from nncf.experimental.common.tensor_statistics.collectors import TensorReducerBase
from nncf.experimental.common.tensor_statistics.collectors import TensorType


# pylint: disable=(protected-access)


class DummyTensorReducer(TensorReducerBase):
def __init__(self, output_name: str, inplace: bool = False, inplace_mock=None):
super().__init__(inplace=inplace)
Expand Down
3 changes: 3 additions & 0 deletions tests/openvino/native/quantization/test_quantizer_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@
ParamsCls = TemplateTestQuantizerConfig.TestGetStatisticsCollectorParameters


# pylint: disable=protected-access


class TestQuantizerConfig(TemplateTestQuantizerConfig):
def get_algo_backend(self):
return OVMinMaxAlgoBackend()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ def test_calculate_quantizer_parameters(self, case_to_test):
max_values = np.amax(data, axis=axes, keepdims=q_config.per_channel)

statistics = self.tensor_statistic(
{MinMaxTensorStatistic.MIN_STAT: max_values, MinMaxTensorStatistic.MAX_STAT: min_values}
{MinMaxTensorStatistic.MIN_STAT: min_values, MinMaxTensorStatistic.MAX_STAT: max_values}
)

if not case_to_test.should_fail:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ def test_num_samples(self, collector_for_num_samples_test: OfflineTensorStatisti


class TestCollectorTensorProcessor:
tensor_processor = TFNNCFCollectorTensorProcessor()
tensor_processor = TFNNCFCollectorTensorProcessor

def test_unstack(self):
# Unstack tensor with dimensions
Expand Down
3 changes: 3 additions & 0 deletions tests/torch/ptq/test_quantizer_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@
ParamsCls = TemplateTestQuantizerConfig.TestGetStatisticsCollectorParameters


# pylint: disable=protected-access


class TestQuantizerConfig(TemplateTestQuantizerConfig):
def get_algo_backend(self):
return PTMinMaxAlgoBackend()
Expand Down
2 changes: 1 addition & 1 deletion tests/torch/tensor_statistics/test_tensor_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ def test_collected_statistics(
collector: Type[TensorStatisticCollectorBase],
reduction_shapes_vs_ref_statistic: Dict[ReductionShape, TensorStatistic],
):
for reduction_shape in reduction_shapes_vs_ref_statistic.keys():
for reduction_shape in reduction_shapes_vs_ref_statistic:
if len(reduction_shape) > 1:
reducer_axes = ([dim for dim, val in enumerate(reduction_shape) if val == 1][0],)
aggregator_keep_dims = False
Expand Down

0 comments on commit e9aaa44

Please sign in to comment.