Skip to content

Commit

Permalink
Experimental quantization
Browse files Browse the repository at this point in the history
  • Loading branch information
daniil-lyakhov committed Nov 28, 2024
1 parent e6bf1d5 commit 7f12283
Show file tree
Hide file tree
Showing 13 changed files with 9,600 additions and 169 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,18 @@
from nncf.common.graph.graph import NNCFGraph
from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer
from nncf.common.utils.backend import BackendType
from nncf.experimental.common.quantization.algorithms.post_training.pipeline import create_ptq_pipeline
from nncf.experimental.common.quantization.algorithms.post_training.pipeline import experimental_create_ptq_pipeline
from nncf.experimental.common.quantization.algorithms.quantizer.quantizer import NNCFQuantizer
from nncf.parameters import ModelType
from nncf.quantization.advanced_parameters import AdvancedQuantizationParameters
from nncf.quantization.advanced_parameters import AdvancedBiasCorrectionParameters
from nncf.quantization.advanced_parameters import AdvancedSmoothQuantParameters
from nncf.quantization.advanced_parameters import RangeEstimatorParameters
from nncf.quantization.algorithms.algorithm import Algorithm

TModel = TypeVar("TModel")
TPass = Callable[[TModel], TModel]


class PostTrainingQuantization(Algorithm):
class ExperimentalPostTrainingQuantization(Algorithm):
"""
Implements Post-Training Quantization algorithm, which basically includes:
1) ChannelAlignment
Expand All @@ -38,9 +39,12 @@ def __init__(
self,
quantizer: NNCFQuantizer,
subset_size: int = 300,
fast_bias_correction: bool = True,
model_type: Optional[ModelType] = None,
advanced_parameters: Optional[AdvancedQuantizationParameters] = None,
fast_bias_correction: Optional[bool] = True,
smooth_quant: bool = False,
bias_correction_params: Optional[AdvancedBiasCorrectionParameters] = None,
smooth_quant_params: Optional[AdvancedSmoothQuantParameters] = None,
activations_range_estimator_params: Optional[RangeEstimatorParameters] = None,
weights_range_estimator_params: Optional[RangeEstimatorParameters] = None,
):
"""
:param mode: Special quantization mode that specifies different ways of the optimization.
Expand All @@ -65,12 +69,15 @@ def __init__(
:param advanced_parameters: Advanced quantization parameters for
fine-tuning the quantization algorithm
"""
self._pipeline = create_ptq_pipeline(
self._pipeline = experimental_create_ptq_pipeline(
quantizer=quantizer,
subset_size=subset_size,
fast_bias_correction=fast_bias_correction,
model_type=model_type,
advanced_parameters=advanced_parameters,
smooth_quant=smooth_quant,
bias_correction_params=bias_correction_params,
smooth_quant_params=smooth_quant_params,
activations_range_estimator_params=activations_range_estimator_params,
weights_range_estimator_params=weights_range_estimator_params,
)

@property
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,13 @@

from typing import Optional, TypeVar

from nncf.common.deprecation import warning_deprecated
from nncf.experimental.common.quantization.algorithms.quantizer.quantizer import NNCFQuantizer
from nncf.experimental.common.quantization.algorithms.range_estimator.range_estimator import MinMaxRangeEstimator
from nncf.parameters import ModelType
from nncf.quantization.advanced_parameters import AdvancedQuantizationParameters
from nncf.quantization.advanced_parameters import AdvancedBiasCorrectionParameters
from nncf.quantization.advanced_parameters import AdvancedSmoothQuantParameters
from nncf.quantization.advanced_parameters import RangeEstimatorParameters
from nncf.quantization.algorithms.bias_correction.algorithm import BIAS_CORRECTION_THRESHOLD
from nncf.quantization.algorithms.bias_correction.algorithm import BiasCorrection
from nncf.quantization.algorithms.channel_alignment.algorithm import ChannelAlignment
from nncf.quantization.algorithms.fast_bias_correction.algorithm import FAST_BIAS_CORRECTION_THRESHOLD
from nncf.quantization.algorithms.fast_bias_correction.algorithm import FastBiasCorrection
from nncf.quantization.algorithms.pipeline import Pipeline
Expand All @@ -27,19 +26,21 @@
TModel = TypeVar("TModel")


def create_ptq_pipeline(
def experimental_create_ptq_pipeline(
quantizer: NNCFQuantizer,
subset_size: int = 300,
fast_bias_correction: bool = True,
model_type: Optional[ModelType] = None,
advanced_parameters: Optional[AdvancedQuantizationParameters] = None,
fast_bias_correction: Optional[bool] = True,
smooth_quant: bool = False,
bias_correction_params: Optional[AdvancedBiasCorrectionParameters] = None,
smooth_quant_params: Optional[AdvancedSmoothQuantParameters] = None,
activations_range_estimator_params: Optional[RangeEstimatorParameters] = None,
weights_range_estimator_params: Optional[RangeEstimatorParameters] = None,
) -> Pipeline:
"""
Creates a post-training quantization pipeline.
The post-training quantization pipeline includes the following steps:
1) SmoothQuant
2) ChannelAlignment
3) MinMaxQuantization
4) FastBiasCorrection or BiasCorrection
Expand All @@ -60,60 +61,36 @@ def create_ptq_pipeline(
more time but requires less memory.
:param model_type: Model type is needed to specify additional patterns
in the model. Supported only `transformer` now.
:param advanced_parameters: Advanced quantization parameters for
fine-tuning the quantization algorithm
:return: A post-training quantization pipeline.
"""

if advanced_parameters is None:
advanced_parameters = AdvancedQuantizationParameters()

# Build the post-training quantization pipeline.
pipeline_steps = []

# Add the `SmoothQuant` algorithm as the first step of the pipeline.
# It is added only for `ModelType.TRANSFORMER`.
sq_params = advanced_parameters.smooth_quant_alphas
sq_alpha = advanced_parameters.smooth_quant_alpha
if sq_alpha is not None:
warning_deprecated(
"`AdvancedQuantizationParameters(smooth_quant_alpha=..)` is deprecated."
"Please, use `AdvancedQuantizationParameters(smooth_quant_alphas)` option "
"with AdvancedSmoothQuantParameters(convolution=.., matmul=..) as value instead."
)
if sq_alpha < 0:
sq_params.convolution = -1
sq_params.matmul = -1
else:
sq_params.matmul = sq_alpha
if smooth_quant_params is None:
smooth_quant_params = AdvancedSmoothQuantParameters()

if model_type == ModelType.TRANSFORMER and (sq_params.convolution >= 0 or sq_params.matmul >= 0):
alpha_map = {"convolution": sq_params.convolution, "matmul": sq_params.matmul}
pipeline_steps.append([SmoothQuant(subset_size, advanced_parameters.inplace_statistics, alpha_map=alpha_map)])

# Add the `ChannelAlignment` algorithm as the second step of the pipeline.
if not advanced_parameters.disable_channel_alignment:
pipeline_steps.append([ChannelAlignment(subset_size, advanced_parameters.inplace_statistics)])
# SmoothQuant is opt-in: the step is added only when `smooth_quant` is set AND at
# least one of the alphas is non-negative. The parentheses are required because
# `and` binds tighter than `or`; without them a non-negative `matmul` alpha alone
# would enable the step even when `smooth_quant` is False.
if smooth_quant and (smooth_quant_params.convolution >= 0 or smooth_quant_params.matmul >= 0):
    alpha_map = {"convolution": smooth_quant_params.convolution, "matmul": smooth_quant_params.matmul}
    pipeline_steps.append([SmoothQuant(subset_size, False, alpha_map=alpha_map)])

# Add the `MinMaxQuantization` algorithm as the third step of the pipeline.
pipeline_steps.append(
[
MinMaxRangeEstimator(
quantizer=quantizer,
subset_size=subset_size,
inplace_statistics=advanced_parameters.inplace_statistics,
batchwise_statistics=advanced_parameters.batchwise_statistics,
activations_range_estimator_params=advanced_parameters.activations_range_estimator_params,
weights_range_estimator_params=advanced_parameters.weights_range_estimator_params,
inplace_statistics=False,
activations_range_estimator_params=activations_range_estimator_params,
weights_range_estimator_params=weights_range_estimator_params,
)
]
)

if not advanced_parameters.disable_bias_correction:
if fast_bias_correction is not None:
# Add the `FastBiasCorrection` or `BiasCorrection` as additional algorithm
# inside the third step of the pipeline. It is added after `MinMaxQuantization`
# algorithm.
bias_correction_params = advanced_parameters.bias_correction_params
if fast_bias_correction:
threshold = FAST_BIAS_CORRECTION_THRESHOLD
bias_correction_subset_size = subset_size
Expand All @@ -123,6 +100,9 @@ def create_ptq_pipeline(
bias_correction_subset_size = max(int(subset_size * 0.2), 1)
bias_correction_cls = BiasCorrection

if bias_correction_params is None:
bias_correction_params = AdvancedBiasCorrectionParameters()

if bias_correction_params.threshold is not None:
threshold = bias_correction_params.threshold

Expand All @@ -131,8 +111,6 @@ def create_ptq_pipeline(
bias_correction_subset_size,
threshold,
bias_correction_params.apply_for_all_nodes,
advanced_parameters.inplace_statistics,
advanced_parameters.backend_params,
)
)

Expand Down
42 changes: 34 additions & 8 deletions nncf/experimental/torch/fx/quantization/quantize_pt2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,16 @@
from nncf.common.factory import NNCFGraphFactory
from nncf.common.logging import nncf_logger
from nncf.data import Dataset
from nncf.experimental.common.quantization.algorithms.post_training.algorithm import PostTrainingQuantization
from nncf.experimental.common.quantization.algorithms.post_training.algorithm import (
ExperimentalPostTrainingQuantization,
)
from nncf.experimental.common.quantization.algorithms.quantizer.fx_quantizer import NNCFFXQuantizer
from nncf.experimental.torch.fx.constant_folding import constant_fold
from nncf.experimental.torch.fx.transformations import QUANTIZE_NODE_TARGETS
from nncf.experimental.torch.fx.transformations import fuse_conv_bn
from nncf.parameters import ModelType
from nncf.quantization.advanced_parameters import AdvancedQuantizationParameters
from nncf.quantization.advanced_parameters import AdvancedBiasCorrectionParameters
from nncf.quantization.advanced_parameters import AdvancedSmoothQuantParameters
from nncf.quantization.advanced_parameters import RangeEstimatorParameters

DEFAULT_RANGE_TYPE = "mean_min_max"

Expand All @@ -40,8 +45,12 @@ def quantize_pt2e(
calibration_dataset: Dataset,
subset_size: int = 300,
fast_bias_correction: bool = True,
model_type: Optional[ModelType] = None,
advanced_parameters: Optional[AdvancedQuantizationParameters] = None,
smooth_quant: bool = False,
bias_correction_params: Optional[AdvancedBiasCorrectionParameters] = None,
smooth_quant_params: Optional[AdvancedSmoothQuantParameters] = None,
activations_range_estimator_params: Optional[RangeEstimatorParameters] = None,
weights_range_estimator_params: Optional[RangeEstimatorParameters] = None,
fold_quantize: Optional[bool] = False,
) -> torch.fx.GraphModule:
"""
Implementation of the `quantize()` method for the Torch FX backend.
Expand All @@ -56,12 +65,15 @@ def quantize_pt2e(

copied_model = deepcopy(model)

quantization_algorithm = PostTrainingQuantization(
quantization_algorithm = ExperimentalPostTrainingQuantization(
quantizer=NNCFFXQuantizer(quantizer),
subset_size=subset_size,
fast_bias_correction=fast_bias_correction,
model_type=model_type,
advanced_parameters=advanced_parameters,
smooth_quant=smooth_quant,
bias_correction_params=bias_correction_params,
smooth_quant_params=smooth_quant_params,
activations_range_estimator_params=activations_range_estimator_params,
weights_range_estimator_params=weights_range_estimator_params,
)

# To make it easier for bias correction algorithms,
Expand All @@ -76,6 +88,9 @@ def quantize_pt2e(
quantized_model = GraphModule(quantized_model, quantized_model.graph)

quantized_model = _fold_conv_bn_qat(quantized_model)
if fold_quantize:
constant_fold(quantized_model, _quant_node_constraint)

pm = PassManager([DuplicateDQPass()])

quantized_model = pm(quantized_model).graph_module
Expand All @@ -89,3 +104,14 @@ def quantize_pt2e(
quantized_model = GraphModule(quantized_model, quantized_model.graph)

return quantized_model


def _quant_node_constraint(n: torch.fx.Node) -> bool:
    """
    Constraint passed to `constant_fold` that selects quantize nodes only.

    Pure ops located between a `get_attr` node and a quantize op are constant
    propagated, e.g. get_attr(weight) -> transpose -> quantize -> dequantize*
    (the dequantize op itself is not constant propagated). The filter exists so
    that constant folding does not touch parts of the graph that are unrelated
    to quantization.

    :param n: Graph node to check.
    :return: True if `n` is a `call_function` node whose target is listed in
        `QUANTIZE_NODE_TARGETS`, False otherwise.
    """
    if n.op != "call_function":
        return False
    return n.target in QUANTIZE_NODE_TARGETS
2 changes: 2 additions & 0 deletions nncf/experimental/torch/fx/transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,12 @@

# Targets of quantize nodes: the `default` and `tensor` overloads of
# `quantize_per_tensor` and the `default` overload of `quantize_per_channel`.
QUANTIZE_NODE_TARGETS = [
    torch.ops.quantized_decomposed.quantize_per_tensor.default,
    torch.ops.quantized_decomposed.quantize_per_tensor.tensor,
    torch.ops.quantized_decomposed.quantize_per_channel.default,
]
# Targets of dequantize nodes, listed in the same order as the quantize targets above.
DEQUANTIZE_NODE_TARGETS = [
    torch.ops.quantized_decomposed.dequantize_per_tensor.default,
    torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
    torch.ops.quantized_decomposed.dequantize_per_channel.default,
]
# Map quantize_per_tensor to dequantize_per_tensor, the same for per_channel and vice-versa
Expand Down
Loading

0 comments on commit 7f12283

Please sign in to comment.