Skip to content

Commit

Permalink
OpenVINOQuantizer as torch.ao Quantizer
Browse files Browse the repository at this point in the history
  • Loading branch information
daniil-lyakhov committed Dec 5, 2024
1 parent 0827e21 commit 908deae
Show file tree
Hide file tree
Showing 9 changed files with 11,051 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,26 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import defaultdict
from typing import Optional, Union

import torch.fx
from torch.ao.quantization.observer import HistogramObserver
from torch.ao.quantization.observer import PerChannelMinMaxObserver
from torch.ao.quantization.quantizer.quantizer import QuantizationAnnotation as InductorQAnotation
from torch.ao.quantization.quantizer.quantizer import QuantizationSpec as InductorQuantizationSpec
from torch.ao.quantization.quantizer.quantizer import Quantizer as InductorQuantizer

from nncf.common.graph.graph import NNCFGraph
from nncf.common.quantization.quantizer_propagation.solver import QuantizerPropagationRule
from nncf.common.quantization.quantizer_setup import ActivationQuantizationInsertionPoint
from nncf.common.quantization.quantizer_setup import SingleConfigQuantizerSetup
from nncf.common.quantization.structs import QuantizationPreset
from nncf.common.quantization.structs import QuantizationScheme
from nncf.common.quantization.structs import QuantizerConfig as NNCFQuantizerConfig
from nncf.experimental.common.quantization.algorithms.quantizer.base_quantizer import NNCFQuantizer
from nncf.experimental.torch.fx.nncf_graph_builder import GraphConverter
from nncf.experimental.torch.fx.node_utils import get_graph_node_by_name
from nncf.parameters import ModelType
from nncf.parameters import QuantizationMode
from nncf.parameters import TargetDevice
Expand All @@ -27,8 +38,10 @@
from nncf.quantization.algorithms.min_max.algorithm import MinMaxQuantization
from nncf.scopes import IgnoredScope

QUANT_ANNOTATION_KEY = "quantization_annotation"

class OpenVINOQuantizer(NNCFQuantizer):

class OpenVINOQuantizer(InductorQuantizer, NNCFQuantizer):
def __init__(
self,
mode: Optional[QuantizationMode] = None,
Expand Down Expand Up @@ -80,12 +93,74 @@ def __init__(
)

def get_quantization_setup(self, model: torch.fx.GraphModule, nncf_graph: NNCFGraph) -> SingleConfigQuantizerSetup:
"""
Builds SingleConfigQuantizerSetup for the given model.
:param model: Backend-specific model, for which Quantization Target Points are being seek.
:param nncf_graph: NNCFGraph instance.
:return: SingleConfigQuantizerSetup for the given model.
"""
self._min_max_algo._set_backend_entity(model)
return self._min_max_algo._find_quantization_setup(model, nncf_graph)
return self._min_max_algo.find_quantization_setup(model, nncf_graph)

def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
nncf_grpah = GraphConverter.create_nncf_graph(model)
quantization_setup = self.get_quantization_setup(model, nncf_grpah)
target_node_vs_qp = defaultdict(list)
graph = model.graph
for qp in quantization_setup.quantization_points.values():
target_node_vs_qp[qp.insertion_point.target_node_name].append(qp)

for target_node_name, qps in target_node_vs_qp.items():
input_qspec_map = dict()
output_qspec = None
target_node = get_graph_node_by_name(graph, target_node_name)
for qp in qps:
ip = qp.insertion_point
if isinstance(ip, ActivationQuantizationInsertionPoint):
inductor_qspec = self._convert_nncf_qspec_to_inductor_qspec(qp.qconfig, is_weight=False)
if ip.input_port_id is None:
output_qspec = inductor_qspec
else:
node = target_node.all_input_nodes[ip.input_port_id]
input_qspec_map[node] = inductor_qspec
else:
inductor_qspec = self._convert_nncf_qspec_to_inductor_qspec(qp.qconfig, is_weight=True)
weight_node = target_node.all_input_nodes[1]
input_qspec_map[weight_node] = inductor_qspec

annotation = InductorQAnotation(input_qspec_map=input_qspec_map, output_qspec=output_qspec)
assert QUANT_ANNOTATION_KEY not in target_node.meta
target_node.meta[QUANT_ANNOTATION_KEY] = annotation

def _convert_nncf_qspec_to_inductor_qspec(self, qspec: NNCFQuantizerConfig, is_weight: bool):
extra_args = {"eps": 2**-12}
if qspec.per_channel:
torch_qscheme = (
torch.per_channel_symmetric if qspec.mode is QuantizationScheme.SYMMETRIC else torch.per_channel_affine
)
else:
torch_qscheme = (
torch.per_tensor_symmetric if qspec.mode is QuantizationScheme.SYMMETRIC else torch.per_tensor_affine
)
if is_weight:
observer = PerChannelMinMaxObserver
quant_min = -128
quant_max = 127
dtype = torch.int8
channel_axis = 0
else:
observer = (
HistogramObserver
if torch_qscheme in [torch.per_tensor_symmetric, torch.per_tensor_affine]
else PerChannelMinMaxObserver
)
quant_min = 0
quant_max = 255
dtype = torch.int8 if qspec.signedness_to_force else torch.uint8
channel_axis = 1 # channel dim for activations
return InductorQuantizationSpec(
dtype=dtype,
observer_or_fake_quant_ctr=observer.with_args(**extra_args),
quant_min=quant_min,
quant_max=quant_max,
qscheme=torch_qscheme,
ch_axis=channel_axis,
is_dynamic=False,
)

def validate(self, model: torch.fx.GraphModule) -> None:
pass
Loading

0 comments on commit 908deae

Please sign in to comment.