Dl/fx/experimental quantization conformance #33

Open · wants to merge 15 commits into develop
@@ -0,0 +1,10 @@
# Copyright (c) 2024 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@@ -0,0 +1,103 @@
# Copyright (c) 2024 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import itertools
from typing import Callable, List, Optional, TypeVar

from nncf import Dataset
from nncf.common.graph.graph import NNCFGraph
from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer
from nncf.common.utils.backend import BackendType
from nncf.experimental.common.quantization.algorithms.post_training.pipeline import create_ptq_pipeline
from nncf.experimental.common.quantization.algorithms.quantizer.quantizer import NNCFQuantizer
from nncf.parameters import ModelType
from nncf.quantization.advanced_parameters import AdvancedQuantizationParameters
from nncf.quantization.algorithms.algorithm import Algorithm

TModel = TypeVar("TModel")
TPass = Callable[[TModel], TModel]


class PostTrainingQuantization(Algorithm):
"""
Implements Post-Training Quantization algorithm, which basically includes:
1) ChannelAlignment
2) MinMaxQuantization
3) FastBiasCorrection or BiasCorrection
"""

def __init__(
self,
quantizer: NNCFQuantizer,
subset_size: int = 300,
fast_bias_correction: bool = True,
model_type: Optional[ModelType] = None,
advanced_parameters: Optional[AdvancedQuantizationParameters] = None,
):
"""
:param mode: Special quantization mode that specify different ways of the optimization.
:param preset: A preset controls the quantization mode (symmetric and asymmetric).
It can take the following values:
- `performance`: Symmetric quantization of weights and activations.
- `mixed`: Symmetric quantization of weights and asymmetric quantization of activations.
Default value is None. In this case, `mixed` preset is used for `transformer`
model type otherwise `performace`.
:param target_device: A target device the specificity of which will be taken
into account while compressing in order to obtain the best performance
for this type of device.
:param subset_size: Size of a subset to calculate activations
statistics used for quantization.
:param fast_bias_correction: Setting this option to `False` enables a different
bias correction method which is more accurate, in general, and takes
more time but requires less memory.
:param model_type: Model type is needed to specify additional patterns
in the model. Supported only `transformer` now.
:param ignored_scope: An ignored scope that defined the list of model control
flow graph nodes to be ignored during quantization.
:param advanced_parameters: Advanced quantization parameters for
fine-tuning the quantization algorithm
"""
self._pipeline = create_ptq_pipeline(
quantizer=quantizer,
subset_size=subset_size,
fast_bias_correction=fast_bias_correction,
model_type=model_type,
advanced_parameters=advanced_parameters,
)

@property
def available_backends(self) -> List[BackendType]:
backends = set(BackendType)
for algorithm in itertools.chain.from_iterable(self._pipeline.pipeline_steps):
backends = backends.intersection(algorithm.available_backends)
return list(backends)

def get_statistic_points(self, model: TModel, graph: NNCFGraph) -> StatisticPointsContainer:
return self._pipeline.get_statistic_points_for_step(0, model, graph)

def apply(
self,
model: TModel,
graph: NNCFGraph,
statistic_points: Optional[StatisticPointsContainer] = None,
dataset: Optional[Dataset] = None,
) -> TModel:
if dataset is None and len(self._pipeline.pipeline_steps) > 1:
raise ValueError(
"A dataset is required for the post-training quantization "
"algorithm to collect statistics for intermediate models."
)

step_index_to_statistics = None
if statistic_points:
step_index_to_statistics = {0: statistic_points}

return self._pipeline.run_from_step(model, dataset, graph, 0, step_index_to_statistics)
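For orientation, a minimal usage sketch of `PostTrainingQuantization` with a torch.ao quantizer wrapped in the `NNCFFXQuantizer` adapter introduced later in this PR. The experimental import path and the calibration objects (`model`, `graph`, `calibration_dataset`) are illustrative, not part of this diff:

```python
# Hypothetical usage sketch, not part of this diff.
from torch.ao.quantization.quantizer.xnnpack_quantizer import XNNPACKQuantizer
from torch.ao.quantization.quantizer.xnnpack_quantizer import get_symmetric_quantization_config

from nncf.experimental.common.quantization.algorithms.post_training.algorithm import PostTrainingQuantization

# Configure a stock torch.ao quantizer...
torch_ao_quantizer = XNNPACKQuantizer()
torch_ao_quantizer.set_global(get_symmetric_quantization_config())

# ...and wrap it so the NNCF pipeline can consume its annotations.
# `NNCFFXQuantizer` is defined later in this PR; its import path is not shown in the diff.
algorithm = PostTrainingQuantization(
    quantizer=NNCFFXQuantizer(torch_ao_quantizer),
    subset_size=300,
    fast_bias_correction=True,
)
quantized_model = algorithm.apply(model, graph, dataset=calibration_dataset)
```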
@@ -0,0 +1,139 @@
# Copyright (c) 2024 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, TypeVar

from nncf.common.deprecation import warning_deprecated
from nncf.experimental.common.quantization.algorithms.quantizer.quantizer import NNCFQuantizer
from nncf.experimental.common.quantization.algorithms.range_estimator.range_estimator import MinMaxRangeEstimator
from nncf.parameters import ModelType
from nncf.quantization.advanced_parameters import AdvancedQuantizationParameters
from nncf.quantization.algorithms.bias_correction.algorithm import BIAS_CORRECTION_THRESHOLD
from nncf.quantization.algorithms.bias_correction.algorithm import BiasCorrection
from nncf.quantization.algorithms.channel_alignment.algorithm import ChannelAlignment
from nncf.quantization.algorithms.fast_bias_correction.algorithm import FAST_BIAS_CORRECTION_THRESHOLD
from nncf.quantization.algorithms.fast_bias_correction.algorithm import FastBiasCorrection
from nncf.quantization.algorithms.pipeline import Pipeline
from nncf.quantization.algorithms.smooth_quant.algorithm import SmoothQuant

TModel = TypeVar("TModel")


def create_ptq_pipeline(
quantizer: NNCFQuantizer,
subset_size: int = 300,
fast_bias_correction: bool = True,
model_type: Optional[ModelType] = None,
advanced_parameters: Optional[AdvancedQuantizationParameters] = None,
) -> Pipeline:
"""
Creates a post-training quantization pipeline.
The post-training quantization pipeline includes the following steps:
1) SmoothQuant
2) ChannelAlignment
3) MinMaxQuantization
4) FastBiasCorrection or BiasCorrection
    :param quantizer: NNCFQuantizer instance used to get the quantization setup,
        i.e. where and how quantizers should be placed in the model.
    :param subset_size: Size of a subset to calculate activation
        statistics used for quantization.
    :param fast_bias_correction: Setting this option to `False` enables a different
        bias correction method which is more accurate, in general, and takes
        more time but requires less memory.
    :param model_type: Model type is needed to specify additional patterns
        in the model. Only `transformer` is supported now.
    :param advanced_parameters: Advanced quantization parameters for
        fine-tuning the quantization algorithm.
    :return: A post-training quantization pipeline.
"""

if advanced_parameters is None:
advanced_parameters = AdvancedQuantizationParameters()

# Build the post-training quantization pipeline.
pipeline_steps = []

# Add the `SmoothQuant` algorithm as the first step of the pipeline.
# It is added only for `ModelType.TRANSFORMER`.
sq_params = advanced_parameters.smooth_quant_alphas
sq_alpha = advanced_parameters.smooth_quant_alpha
if sq_alpha is not None:
        warning_deprecated(
            "`AdvancedQuantizationParameters(smooth_quant_alpha=..)` is deprecated. "
            "Please use `AdvancedQuantizationParameters(smooth_quant_alphas)` option "
            "with AdvancedSmoothQuantParameters(convolution=.., matmul=..) as value instead."
        )
if sq_alpha < 0:
sq_params.convolution = -1
sq_params.matmul = -1
else:
sq_params.matmul = sq_alpha

if model_type == ModelType.TRANSFORMER and (sq_params.convolution >= 0 or sq_params.matmul >= 0):
alpha_map = {"convolution": sq_params.convolution, "matmul": sq_params.matmul}
pipeline_steps.append([SmoothQuant(subset_size, advanced_parameters.inplace_statistics, alpha_map=alpha_map)])

# Add the `ChannelAlignment` algorithm as the second step of the pipeline.
if not advanced_parameters.disable_channel_alignment:
pipeline_steps.append([ChannelAlignment(subset_size, advanced_parameters.inplace_statistics)])

# Add the `MinMaxQuantization` algorithm as the third step of the pipeline.
pipeline_steps.append(
[
MinMaxRangeEstimator(
quantizer=quantizer,
subset_size=subset_size,
inplace_statistics=advanced_parameters.inplace_statistics,
batchwise_statistics=advanced_parameters.batchwise_statistics,
activations_range_estimator_params=advanced_parameters.activations_range_estimator_params,
weights_range_estimator_params=advanced_parameters.weights_range_estimator_params,
)
]
)

if not advanced_parameters.disable_bias_correction:
# Add the `FastBiasCorrection` or `BiasCorrection` as additional algorithm
# inside the third step of the pipeline. It is added after `MinMaxQuantization`
# algorithm.
bias_correction_params = advanced_parameters.bias_correction_params
if fast_bias_correction:
threshold = FAST_BIAS_CORRECTION_THRESHOLD
bias_correction_subset_size = subset_size
bias_correction_cls = FastBiasCorrection
else:
threshold = BIAS_CORRECTION_THRESHOLD
bias_correction_subset_size = max(int(subset_size * 0.2), 1)
bias_correction_cls = BiasCorrection

if bias_correction_params.threshold is not None:
threshold = bias_correction_params.threshold

pipeline_steps[-1].append(
bias_correction_cls(
bias_correction_subset_size,
threshold,
bias_correction_params.apply_for_all_nodes,
advanced_parameters.inplace_statistics,
advanced_parameters.backend_params,
)
)

return Pipeline(pipeline_steps)
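A short sketch of how the arguments above reshape the assembled pipeline (hypothetical: `quantizer` stands for any `NNCFQuantizer` implementation):

```python
# Hypothetical sketch of pipeline assembly under different settings.
from nncf.quantization.advanced_parameters import AdvancedQuantizationParameters

# fast_bias_correction=False selects BiasCorrection, which draws
# max(int(300 * 0.2), 1) == 60 samples for its statistics instead of all 300.
pipeline = create_ptq_pipeline(quantizer=quantizer, subset_size=300, fast_bias_correction=False)

# Channel alignment is gated by an advanced parameter rather than a direct argument.
params = AdvancedQuantizationParameters(disable_channel_alignment=False)
pipeline = create_ptq_pipeline(quantizer=quantizer, advanced_parameters=params)
```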
@@ -0,0 +1,110 @@
# Copyright (c) 2024 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from collections import defaultdict
from copy import deepcopy

import torch
import torch.fx
from torch.ao.quantization.pt2e.prepare import _get_edge_or_node_to_group_id
from torch.ao.quantization.pt2e.prepare import _get_edge_or_node_to_qspec
from torch.ao.quantization.pt2e.prepare import _get_obs_or_fq_map
from torch.ao.quantization.quantizer import Quantizer
from torch.ao.quantization.quantizer.quantizer import QuantizationSpec
from torch.ao.quantization.quantizer.quantizer import SharedQuantizationSpec

import nncf
from nncf.common.graph.graph import NNCFGraph
from nncf.common.quantization.quantizer_setup import ActivationQuantizationInsertionPoint
from nncf.common.quantization.quantizer_setup import SingleConfigQuantizationPoint
from nncf.common.quantization.quantizer_setup import SingleConfigQuantizerSetup
from nncf.common.quantization.quantizer_setup import WeightQuantizationInsertionPoint
from nncf.common.quantization.structs import QuantizationScheme as QuantizationMode
from nncf.common.quantization.structs import QuantizerConfig
from nncf.experimental.common.quantization.algorithms.quantizer.quantizer import NNCFQuantizer


class NNCFFXQuantizer(NNCFQuantizer):
def __init__(self, quantizer: Quantizer):
self._quantizer = quantizer

def get_quantization_setup(self, model: torch.fx.GraphModule, nncf_graph: NNCFGraph) -> SingleConfigQuantizerSetup:
        annotated_model = deepcopy(model)

        annotated_model = self._quantizer.transform_for_annotation(annotated_model)
        self._quantizer.annotate(annotated_model)
        self._quantizer.validate(annotated_model)
        return self.get_quantizer_config_from_annotated_model(annotated_model)

@staticmethod
    def get_quantizer_config_from_annotated_model(annotated_model: torch.fx.GraphModule) -> SingleConfigQuantizerSetup:
        is_qat = False
        edge_or_node_to_qspec = _get_edge_or_node_to_qspec(annotated_model)
        edge_or_node_to_group_id = _get_edge_or_node_to_group_id(edge_or_node_to_qspec)
        # Mirrors the torch.ao pt2e prepare flow; the resulting observer map is
        # not used by this conversion yet.
        _get_obs_or_fq_map(edge_or_node_to_group_id, edge_or_node_to_qspec, is_qat)

q_map = defaultdict(list)
for edge, qspec in edge_or_node_to_qspec.items():
if not isinstance(edge, tuple):
continue
from_n, to_n = edge
q_map[from_n].append(to_n)

q_setup = SingleConfigQuantizerSetup()
for from_n, to_nodes in q_map.items():
to_n = to_nodes[0]
qspec = edge_or_node_to_qspec[(from_n, to_n)]
if qspec is None:
continue
if isinstance(qspec, QuantizationSpec):
if qspec.qscheme in [torch.per_channel_affine, torch.per_channel_symmetric]:
per_channel = True
elif qspec.qscheme in [torch.per_tensor_affine, torch.per_tensor_symmetric]:
per_channel = False
else:
raise nncf.InternalError(f"Unknown qscheme: {qspec.qscheme}")
                    # `torch.int8` is the signed quantization dtype (`torch.uint8` is unsigned).
                    signed = qspec.dtype is torch.int8
mode = (
QuantizationMode.SYMMETRIC
if qspec.qscheme in [torch.per_channel_symmetric, torch.per_tensor_symmetric]
else QuantizationMode.ASYMMETRIC
)
qconfig = QuantizerConfig(mode=mode, signedness_to_force=signed, per_channel=per_channel)
qps = []
            # If the input node is a constant connected to a non-activation port
            # (i.e. any port other than 0), insert a weight quantization point.
if from_n.op == "get_attr" and to_n.args.index(from_n) != 0:
qip = WeightQuantizationInsertionPoint(to_n.name)
qp = SingleConfigQuantizationPoint(qip, qconfig, [x.name for x in to_nodes])
qps.append(qp)
            else:
                # All consumers of `from_n` appear in this group, so a single
                # quantization point on the producer's output covers them.
                if len(from_n.users) == len(to_nodes):
qip = ActivationQuantizationInsertionPoint(from_n.name)
qp = SingleConfigQuantizationPoint(qip, qconfig, [x.name for x in to_nodes])
qps.append(qp)
else:
for to_n_ in to_nodes:
input_port_id = to_n_.args.index(from_n)
qip = ActivationQuantizationInsertionPoint(to_n_.name, input_port_id)
qp = SingleConfigQuantizationPoint(qip, qconfig, [to_n_.name])
qps.append(qp)

for qp in qps:
q_setup.add_independent_quantization_point(qp)

            elif isinstance(qspec, SharedQuantizationSpec):
                # No separate quantization point is created for shared specs;
                # they reuse the configuration of the edge/node they reference.
                pass
else:
raise nncf.InternalError(f"Unknown torch.ao quantization spec: {qspec}")

return q_setup
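A sketch of exercising this adapter on its own. `model` is assumed to be a `torch.fx.GraphModule` captured via pt2e tracing, and `nncf_graph` the `NNCFGraph` built for it; both constructions are outside this diff:

```python
# Hypothetical sketch; `model` and `nncf_graph` must be prepared by the caller.
from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer
from torch.ao.quantization.quantizer.x86_inductor_quantizer import get_default_x86_inductor_quantization_config

quantizer = X86InductorQuantizer()
quantizer.set_global(get_default_x86_inductor_quantization_config())

setup = NNCFFXQuantizer(quantizer).get_quantization_setup(model, nncf_graph)

# The setup maps insertion-point ids to quantization points (weight or activation).
for point_id, point in setup.quantization_points.items():
    print(point_id, point.qconfig)
```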
@@ -0,0 +1,26 @@
# Copyright (c) 2024 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC
from abc import abstractmethod
from typing import TypeVar

from nncf.common.graph.graph import NNCFGraph
from nncf.common.quantization.quantizer_setup import SingleConfigQuantizerSetup

TModel = TypeVar("TModel")


class NNCFQuantizer(ABC):
@abstractmethod
    def get_quantization_setup(self, model: TModel, nncf_graph: NNCFGraph) -> SingleConfigQuantizerSetup:
        """
        Builds the quantization setup for the given model.

        :param model: Backend-specific model instance.
        :param nncf_graph: NNCFGraph instance that represents the model graph.
        :return: SingleConfigQuantizerSetup with the quantization points for the model.
        """