Skip to content

Commit

Permalink
depthwise and non module conv metatypes are removed from torch sq/ co…
Browse files Browse the repository at this point in the history
…nv model test case
  • Loading branch information
daniil-lyakhov committed Jan 24, 2024
1 parent 546ef06 commit ad5e134
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 27 deletions.
2 changes: 1 addition & 1 deletion nncf/experimental/tensor/functions/torch_numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ def _(
device = a.device
# See https://github.com/pytorch/pytorch/issues/61582
# https://github.com/pytorch/pytorch/issues/64947
if len(a) <= 16_000_000 and isinstance(axis, int):
if a.numel() <= 16_000_000 and isinstance(axis, int):
result = torch.quantile(
a,
torch.tensor(q, dtype=a.dtype, device=a.device),
Expand Down
9 changes: 0 additions & 9 deletions nncf/quantization/algorithms/smooth_quant/torch_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,18 +50,9 @@ class PTSmoothQuantAlgoBackend(SmoothQuantAlgoBackend):
@property
def convolution_metatypes(self) -> List[OperatorMetatype]:
return [
om.PTConv1dMetatype,
om.PTConv2dMetatype,
om.PTConv3dMetatype,
om.PTModuleConv1dMetatype,
om.PTModuleConv2dMetatype,
om.PTModuleConv3dMetatype,
om.PTDepthwiseConv1dSubtype,
om.PTDepthwiseConv2dSubtype,
om.PTDepthwiseConv3dSubtype,
om.PTConvTranspose1dMetatype,
om.PTConvTranspose2dMetatype,
om.PTConvTranspose3dMetatype,
]

@property
Expand Down
24 changes: 19 additions & 5 deletions tests/openvino/native/test_smooth_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
from nncf.openvino.graph.metatypes.openvino_metatypes import OVConvolutionMetatype
from nncf.openvino.graph.metatypes.openvino_metatypes import OVMatMulMetatype
from nncf.quantization.algorithms.smooth_quant.openvino_backend import OVSmoothQuantAlgoBackend
from tests.post_training.test_templates.helpers import ConvTestModel
from tests.post_training.test_templates.helpers import LinearMultiShapeModel
from tests.post_training.test_templates.test_smooth_quant import TemplateTestSQAlgorithm

OV_LINEAR_MODEL_MM_OP_MAP = {
Expand All @@ -40,7 +42,6 @@
"Linear4": "/linear_4/MatMul",
}


OV_LINEAR_MODEL_SQ_OP_MAP = {
"MatMul1": "/Reshape_0_0/nncf_smooth_quant",
"MatMul2": "/Reshape_0_0/nncf_smooth_quant",
Expand All @@ -56,6 +57,14 @@
"Linear4": "/Add_0_0/nncf_smooth_quant",
}

OV_CONV_MODEL_MM_OP_MAP = {
"Conv1": "/conv/Conv/WithoutBiases",
}

OV_CONV_MODEL_SQ_OP_MAP = {
"Conv1": "input.1_0_0/nncf_smooth_quant",
}


class TestOVSQAlgorithm(TemplateTestSQAlgorithm):
@staticmethod
Expand All @@ -66,8 +75,12 @@ def fn_to_type(tensor) -> np.ndarray:
def inplace_statistics(self, request) -> bool:
return request.param

def get_node_name_map(self) -> Dict[str, str]:
return OV_LINEAR_MODEL_MM_OP_MAP
def get_node_name_map(self, model_cls) -> Dict[str, str]:
if model_cls is LinearMultiShapeModel:
return OV_LINEAR_MODEL_MM_OP_MAP
if model_cls is ConvTestModel:
return OV_CONV_MODEL_MM_OP_MAP
raise NotImplementedError

@staticmethod
def get_target_node_name(command: TransformationCommand):
Expand All @@ -94,12 +107,13 @@ def backend_specific_model(model: torch.nn.Module, tmp_dir: str) -> ov.Model:
return ov_model

@staticmethod
def check_scales(model: ov.Model, reference_values: Dict[str, np.ndarray]) -> None:
def check_scales(model: ov.Model, reference_values: Dict[str, np.ndarray], model_cls) -> None:
names_map = OV_LINEAR_MODEL_SQ_OP_MAP if model_cls is LinearMultiShapeModel else OV_CONV_MODEL_SQ_OP_MAP
ops_list = {op.get_friendly_name(): op for op in model.get_ops()}
for ref_names, ref_value in reference_values.items():
const_nodes = []
for ref_name in ref_names:
node = ops_list[OV_LINEAR_MODEL_SQ_OP_MAP[ref_name]]
node = ops_list[names_map[ref_name]]
const_nodes.append(node.input(1).get_source_output().get_node())
# Check unified group acutally shares one constant
assert all(node is const_nodes[0] for node in const_nodes[1:])
Expand Down
20 changes: 14 additions & 6 deletions tests/post_training/test_templates/test_smooth_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from nncf.quantization.algorithms.post_training.algorithm import PostTrainingQuantization
from nncf.quantization.algorithms.smooth_quant.algorithm import SmoothQuant
from nncf.quantization.algorithms.smooth_quant.backend import SmoothQuantAlgoBackend
from tests.post_training.test_templates.helpers import ConvTestModel
from tests.post_training.test_templates.helpers import LinearMultiShapeModel
from tests.post_training.test_templates.helpers import NonZeroLinearModel
from tests.post_training.test_templates.helpers import get_static_dataset
Expand All @@ -48,9 +49,9 @@ def inplace_statistics(self) -> bool:
"""

@abstractmethod
def get_node_name_map(self) -> Dict[str, str]:
def get_node_name_map(self, model_cls) -> Dict[str, str]:
"""
Return backend specific map from the LinearMultiShapeModel labels
Return backend specific map from the given model class labels
to nncf_grpah nodes names.
"""

Expand All @@ -77,7 +78,7 @@ def backend_specific_model(model: TModel, tmp_dir: str) -> TModel:

@staticmethod
@abstractmethod
def check_scales(model: TModel, reference_values: Dict[str, TTensor]) -> None:
def check_scales(model: TModel, reference_values: Dict[str, TTensor], model_cls) -> None:
"""
Checking scales from model with references.
"""
Expand All @@ -103,7 +104,7 @@ def get_quantization_algorithm():
model_type=ModelType.TRANSFORMER,
advanced_parameters=AdvancedQuantizationParameters(
overflow_fix=OverflowFix.DISABLE,
smooth_quant_alphas=AdvancedSmoothQuantParameters(matmul=0.95),
smooth_quant_alphas=AdvancedSmoothQuantParameters(matmul=0.95, convolution=0.95),
inplace_statistics=False,
),
)
Expand Down Expand Up @@ -141,6 +142,12 @@ def get_quantization_algorithm():
("Linear3", "Linear4"): [[[[0.33630377, 0.3288621, 0.9898262, 0.7217065]]]],
},
),
(
ConvTestModel,
{
("Conv1",): [[[[1.0723]]]],
},
),
),
)
def test_smooth_quant_algo(self, model_cls, reference_values, tmpdir):
Expand All @@ -151,7 +158,7 @@ def test_smooth_quant_algo(self, model_cls, reference_values, tmpdir):
graph = NNCFGraphFactory.create(model)
quantized_model = quantization_algorithm.apply(model, graph, dataset=dataset)

self.check_scales(quantized_model, reference_values)
self.check_scales(quantized_model, reference_values, model_cls)

def test_get_abs_max_channel_collector(self, inplace_statistics: bool):
backend = self.get_backend()
Expand Down Expand Up @@ -195,6 +202,7 @@ def test_get_abs_max_channel_collector(self, inplace_statistics: bool):
("Linear4", 0),
],
),
(ConvTestModel, [("Conv1", 0)]),
),
)
def test__get_nodes_to_smooth_data(self, model_cls, references, tmpdir):
Expand All @@ -207,7 +215,7 @@ def test__get_nodes_to_smooth_data(self, model_cls, references, tmpdir):
smooth_data = algo._get_nodes_to_smooth_data(nncf_graph, alpha_map.keys())
smooth_data = {d["node_to_smooth"].node_name: d["input_act_port"] for d in smooth_data}

name_map = self.get_node_name_map()
name_map = self.get_node_name_map(model_cls)
assert len(name_map) == len(smooth_data)
matched = 0
for ref_node_name, ref_port_id in references:
Expand Down
23 changes: 17 additions & 6 deletions tests/torch/ptq/test_smooth_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand
from nncf.torch.model_creation import wrap_model
from nncf.torch.nncf_network import ExtraCompressionModuleType
from tests.post_training.test_templates.helpers import ConvTestModel
from tests.post_training.test_templates.helpers import LinearMultiShapeModel
from tests.post_training.test_templates.test_smooth_quant import TemplateTestSQAlgorithm

PT_LINEAR_MODEL_SQ_MAP = {
Expand All @@ -39,6 +41,10 @@
"Linear4": "LinearMultiShapeModel/NNCFLinear[linear_4]/linear_0",
}

PT_CONV_MODEL_SQ_MAP = {("Conv1",): "/nncf_model_input_0_0_0/nncf_smooth_quant"}

PT_CONV_MODEL_MM_MAP = {"Conv1": "ConvTestModel/NNCFConv2d[conv]/conv2d_0"}


class TestTorchSQAlgorithm(TemplateTestSQAlgorithm):
@staticmethod
Expand All @@ -49,8 +55,12 @@ def fn_to_type(tensor) -> torch.Tensor:
def inplace_statistics(self, request) -> bool:
return request.param

def get_node_name_map(self) -> Dict[str, str]:
return PT_LINEAR_MODEL_MM_MAP
def get_node_name_map(self, model_cls) -> Dict[str, str]:
if model_cls is LinearMultiShapeModel:
return PT_LINEAR_MODEL_MM_MAP
if model_cls is ConvTestModel:
return PT_CONV_MODEL_MM_MAP
raise NotImplementedError

@staticmethod
def get_target_node_name(command: TransformationCommand):
Expand All @@ -74,14 +84,15 @@ def backend_specific_model(model: torch.nn.Module, tmp_dir: str) -> ov.Model:
return wrap_model(model.eval(), torch.rand(model.INPUT_SIZE))

@staticmethod
def check_scales(model: torch.nn.Module, reference_values: Dict[str, np.ndarray]) -> None:
def check_scales(model: torch.nn.Module, reference_values: Dict[str, np.ndarray], model_cls) -> None:
names_map = PT_LINEAR_MODEL_SQ_MAP if model_cls is LinearMultiShapeModel else PT_CONV_MODEL_SQ_MAP
modules = model.nncf.get_compression_modules_by_type(ExtraCompressionModuleType.EXTERNAL_OP)
for ref_names, ref_value in reference_values.items():
if not all(name.startswith("Linear") for name in ref_names):
# Pytorch SQ algorithm supports only linear modules by far,
if not all(name.startswith("Linear") or name.startswith("Conv") for name in ref_names):
# Pytorch SQ algorithm supports only linear and conv modules by far,
# so other multiplies are skipped
continue
sq_node = modules[PT_LINEAR_MODEL_SQ_MAP[ref_names]]
sq_node = modules[names_map[ref_names]]

assert isinstance(sq_node, SQMultiply)

Expand Down

0 comments on commit ad5e134

Please sign in to comment.