Move SQ to experimental tensor
daniil-lyakhov committed Nov 29, 2023
1 parent ba1de09 commit adfe87b
Showing 8 changed files with 195 additions and 170 deletions.
44 changes: 43 additions & 1 deletion nncf/experimental/tensor/functions.py
@@ -340,7 +340,10 @@ def stack(x: List[Tensor], axis: int = 0) -> Tensor:
:return: Stacked Tensor.
"""
if isinstance(x, List):
return Tensor(_dispatch_list(stack, x, axis=axis))
unwrapped_x = [i.data for i in x]
# singledispatch cannot select an implementation based on the type of an element inside a list
res = stack.dispatch(type(unwrapped_x[0]))(unwrapped_x, axis=axis)
return Tensor(res)
raise NotImplementedError(f"Function `stack` is not implemented for {type(x)}")
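
As a side note on the change above: functools.singledispatch picks an implementation from the type of the first positional argument, so a bare list always falls through to the generic overload; the commit therefore looks up the implementation registered for the element type and calls it directly. A standalone toy sketch of that pattern (not NNCF code):

import functools

@functools.singledispatch
def describe(x):
    raise NotImplementedError(f"No implementation for {type(x)}")

@describe.register(int)
def _(x):
    return f"int: {x}"

values = [1, 2, 3]
# Dispatching on the list itself would hit the generic overload above,
# so look up the implementation registered for the element type instead.
impl = describe.dispatch(type(values[0]))
print([impl(v) for v in values])  # ['int: 1', 'int: 2', 'int: 3']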


@@ -400,6 +403,45 @@ def round(a: Tensor, decimals=0) -> Tensor:
return Tensor(round(a.data, decimals))


@functools.singledispatch
@_tensor_guard
def clip(a: Tensor, min_val: float, max_val: Optional[float] = None) -> Tensor:
return Tensor(clip(a.data, min_val, max_val))


@functools.singledispatch
@_tensor_guard
def eps(a: Tensor, dtype: TensorDataType) -> float:
return eps(a.data, dtype)


@functools.singledispatch
@_tensor_guard
def power(a: Tensor, pwr: float) -> Tensor:
return Tensor(power(a.data, pwr))


@functools.singledispatch
@_tensor_guard
def quantile(
a: Tensor,
q: Union[float, List[float]],
axis: Union[int, List[int]] = None,
keepdims: Optional[bool] = None,
) -> Union[float, Tensor]:
retval = quantile(a.data, q, axis, keepdims)

if isinstance(retval, float):
return retval
return Tensor(retval)


@functools.singledispatch
@_tensor_guard
def size(a: Tensor) -> int:
return size(a.data)


@functools.singledispatch
@_tensor_guard
def _binary_op_nowarn(a: Tensor, b: Union[Tensor, float], operator_fn: Callable) -> Tensor:
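
For orientation, a short usage sketch of the newly added wrappers with a NumPy-backed Tensor (illustrative values; the imports follow the ones used in algorithm.py below):

import numpy as np

from nncf.experimental.tensor import Tensor, TensorDataType
from nncf.experimental.tensor import functions as fns

t = Tensor(np.array([0.1, 0.5, 2.0, 8.0], dtype=np.float32))

clipped = fns.clip(t, 0.2, 4.0)                    # values clipped to [0.2, 4.0] -> [0.2, 0.5, 2.0, 4.0]
powered = fns.power(t, 2.0)                        # element-wise square, still a Tensor
q10 = fns.quantile(t, 0.1, keepdims=False)         # plain float for a scalar quantile
machine_eps = fns.eps(t, TensorDataType.float32)   # machine epsilon of the requested dtype (~1.19e-07)
n = t.size                                         # 4, via the new `size` property
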
34 changes: 34 additions & 0 deletions nncf/experimental/tensor/numpy_functions.py
@@ -190,6 +190,40 @@ def _(a: Union[np.ndarray, np.generic], decimals: int = 0) -> np.ndarray:
return np.round(a, decimals=decimals)


@_register_numpy_types(fns.clip)
def _(
a: Union[np.ndarray, np.generic], min_val: float, max_val: Optional[float] = None
) -> Union[np.ndarray, np.generic]:
return np.clip(a, a_min=min_val, a_max=max_val)


@_register_numpy_types(fns.eps)
def _(a: Union[np.ndarray, np.generic], dtype: TensorDataType) -> float:
return np.finfo(DTYPE_MAP[dtype]).eps


@_register_numpy_types(fns.power)
def _(a: Union[np.ndarray, np.generic], pwr: float) -> Union[np.ndarray, np.generic]:
return np.power(a, pwr)


@_register_numpy_types(fns.quantile)
def _(
a: Union[np.ndarray, np.generic],
q: Union[float, List[float]],
axis: Union[int, List[int]] = None,
keepdims: Optional[bool] = None,
) -> Union[float, Union[np.ndarray, np.generic]]:
if keepdims is None:
keepdims = np._NoValue
return np.quantile(a, q=q, axis=axis, keepdims=keepdims)


@_register_numpy_types(fns.size)
def _(a: Union[np.ndarray, np.generic]) -> int:
return a.size


@_register_numpy_types(fns._binary_op_nowarn)
def _(
a: Union[np.ndarray, np.generic], b: Union[np.ndarray, np.generic, float], operator_fn: Callable
4 changes: 4 additions & 0 deletions nncf/experimental/tensor/tensor.py
@@ -47,6 +47,10 @@ def device(self) -> TensorDeviceType:
def dtype(self) -> TensorDataType:
return _call_function("dtype", self)

@property
def size(self) -> int:
return _call_function("size", self)

def __bool__(self) -> bool:
return bool(self.data)

39 changes: 39 additions & 0 deletions nncf/experimental/tensor/torch_functions.py
@@ -11,6 +11,7 @@

from typing import Callable, List, Optional, Tuple, Union

import numpy as np
import torch

from nncf.experimental.tensor import TensorDataType
@@ -192,6 +193,44 @@ def _(a: torch.Tensor, decimals=0) -> torch.Tensor:
return torch.round(a, decimals=decimals)


@fns.clip.register(torch.Tensor)
def _(a: torch.Tensor, min_val: float, max_val: Optional[float] = None) -> torch.Tensor:
return torch.clip(a, min=min_val, max=max_val)


@fns.eps.register(torch.Tensor)
def _(a: torch.Tensor, dtype: TensorDataType) -> float:
return torch.finfo(DTYPE_MAP[dtype]).eps


@fns.power.register(torch.Tensor)
def _(a: torch.Tensor, pwr: float) -> torch.Tensor:
return torch.pow(a, exponent=pwr)


@fns.quantile.register(torch.Tensor)
def _(
a: torch.Tensor,
q: Union[float, List[float]],
axis: Union[int, List[int]] = None,
keepdims: Optional[bool] = None,
) -> Union[float, torch.Tensor]:
# See https://github.com/pytorch/pytorch/issues/61582
# https://github.com/pytorch/pytorch/issues/64947
device = a.device
if keepdims is None:
keepdims = np._NoValue
np_result = np.quantile(a.detach().cpu().numpy(), q=q, axis=axis, keepdims=keepdims)
if isinstance(np_result, np.ndarray):
return torch.tensor(np_result).type(a.dtype).to(device)
return np_result
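
The NumPy round-trip above works around the torch.quantile limitations referenced in the linked issues, casting the result back to the input's dtype and device. A hedged sketch of how it surfaces through the experimental Tensor wrapper (assuming the dispatch wiring from functions.py above):

import torch

from nncf.experimental.tensor import Tensor
from nncf.experimental.tensor import functions as fns

t = Tensor(torch.arange(0, 101, dtype=torch.float32))
q10 = fns.quantile(t, 0.1, keepdims=False)            # plain float (10.0), computed via the NumPy fallback
q_pair = fns.quantile(t, [0.1, 0.9], keepdims=False)  # Tensor wrapping tensor([10., 90.]) with t's dtype and device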


@fns.size.register(torch.Tensor)
def _(a: torch.Tensor) -> int:
return a.numel()


@fns._binary_op_nowarn.register(torch.Tensor)
def _(a: torch.Tensor, b: Union[torch.Tensor, float], operator_fn: Callable) -> torch.Tensor:
return operator_fn(a, b)
81 changes: 64 additions & 17 deletions nncf/quantization/algorithms/smooth_quant/algorithm.py
@@ -36,6 +36,9 @@
from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer
from nncf.common.utils.backend import BackendType
from nncf.common.utils.backend import get_backend
from nncf.experimental.tensor import Tensor
from nncf.experimental.tensor import TensorDataType
from nncf.experimental.tensor import functions as fns
from nncf.quantization.algorithms.algorithm import Algorithm

TModel = TypeVar("TModel")
@@ -123,21 +126,19 @@ def apply(
activations_value = self._get_statistics_for_node(
statistic_points, node_to_smooth.node_name, input_port_id
)
if any(val is None for val in activations_value):
if any(val.data is None for val in activations_value):
empty_statistic = True
break
assert len(activations_value) == 1
activations_value = self._backend_entity.clip_statistics(activations_value[0])
activations_value = self._clip_statistics(activations_value)

weight_value = self._backend_entity.get_weight_value(node_to_smooth, model)
weight_statistics = self._process_weight_statistics(node_to_smooth, weight_value)
weight_statistics = self._backend_entity.clip_statistics(weight_statistics)
weight_statistics = self._clip_statistics([weight_statistics])

alpha = alpha_map[node_to_smooth.metatype]

scales, ratio = self._backend_entity.calculate_scale_and_ratio(
activations_value, weight_statistics, alpha
)
scales, ratio = self._calculate_scale_and_ratio(activations_value, weight_statistics, alpha)

if ratio > best_ratio:
best_ratio = ratio
@@ -158,24 +159,46 @@ def apply(
for node_to_smooth in nodes:
weight_value = self._backend_entity.get_weight_value(node_to_smooth, model)
weights_scale = self._calculate_weight_scale(best_scale, node_to_smooth, weight_value)
### TODO: DO it as NNCFTensor op
scaled_weight = weight_value * weights_scale
###
weight_update_command = self._backend_entity.weight_update_command(node_to_smooth, scaled_weight)
weight_update_command = self._backend_entity.weight_update_command(node_to_smooth, scaled_weight.data)
transformation_layout.register(weight_update_command)

activations_shape = graph.get_output_edges(source_node)[source_output_port_id].tensor_shape
activation_scale = self._calculate_activation_scale(best_scale, activations_shape, nodes, graph)

scale_node_name = self._create_scale_node_name(source_node.node_name, source_output_port_id)
scale_insertion_command = self._backend_entity.scale_insertion_command(
source_node, activation_scale, source_output_port_id, nodes, scale_node_name
source_node, activation_scale.data, source_output_port_id, nodes, scale_node_name
)
transformation_layout.register(scale_insertion_command)

transformed_model = model_transformer.transform(transformation_layout)
return transformed_model

@staticmethod
def _calculate_scale_and_ratio(
activations: Tensor, weights: Tensor, alpha: float, quantile: Optional[float] = 0.1
) -> Tuple[Tensor, float]:
"""
Calculates base scale value and its ratio.
:param activations: Activation statistics value.
:param weights: Weights statistics value.
:param alpha: Base value for exponentiation.
:param quantile: Base quantile value.
:return: Calculated base scale value & ratio.
"""

eps = fns.eps(activations, TensorDataType.float32)
scales = fns.power(activations, alpha) / (fns.power(weights, 1 - alpha) + eps)

a_min = fns.quantile(scales, quantile, keepdims=False)
a_max = 1e2

scales = fns.clip(scales, min_val=a_min, max_val=a_max)
ratio = scales.min() / (scales.max() + eps)
return scales, ratio
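
To make the formula above concrete: scales = |X|^alpha / (|W|^(1 - alpha) + eps), clipped between its 0.1-quantile and 1e2, and ratio = min(scales) / max(scales); the apply loop above keeps the scale whose ratio is largest. A toy NumPy sketch with made-up per-channel statistics:

import numpy as np

alpha = 0.95
activations = np.array([0.5, 2.0, 8.0], dtype=np.float32)  # per-channel activation statistics |X|
weights = np.array([1.0, 0.25, 4.0], dtype=np.float32)     # per-channel weight statistics max|W|

eps = np.finfo(np.float32).eps
scales = np.power(activations, alpha) / (np.power(weights, 1 - alpha) + eps)
a_min = np.quantile(scales, 0.1)
scales = np.clip(scales, a_min, 1e2)
ratio = scales.min() / (scales.max() + eps)  # closer to 1.0 means more uniform scales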

def _group_nodes_by_source(self, nodes_to_smooth: List[Dict], nncf_graph: NNCFGraph) -> Dict[tuple, List]:
"""
Groups nodes that will be smoothed by source (parent node).
@@ -216,7 +239,8 @@ def _get_statistics_for_node(
self._backend_entity.get_filter_fn_for_statistics(act_port),
self._algorithm_key,
):
statistics_for_node.append(tensor_collector.get_statistics()[STATISTIC_BRANCH_KEY])
statistic = tensor_collector.get_statistics()[STATISTIC_BRANCH_KEY]
statistics_for_node.append(Tensor(statistic))
return statistics_for_node

def get_statistic_points(self, model: TModel, graph: NNCFGraph) -> StatisticPointsContainer:
@@ -308,9 +332,14 @@ def _calculate_activation_scale(
raise RuntimeError(f"Channel axes for nodes {[n.node_name for n in nodes]} are not identical")

activations_size = len(activations_shape)
return self._backend_entity.calculate_activation_scale(scale_value, activations_size, channel_axis)

def _calculate_weight_scale(self, scale_value: TTensor, node: NNCFNode, weights_value: TTensor) -> TTensor:
activation_scale = scale_value ** (-1)
if activations_size > 1:
reshape_shape = [1 for _ in range(activations_size)]
reshape_shape[channel_axis] = activation_scale.size
activation_scale = activation_scale.reshape(reshape_shape)
return activation_scale
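
The reshape above places the inverted per-channel scale on the activation's channel axis so it broadcasts over the remaining dimensions; the weight-scale path below follows the same pattern. A small NumPy sketch with illustrative shapes (channel axis 1, NCHW-style activations):

import numpy as np

scale_value = np.array([2.0, 4.0, 8.0])    # one best scale per channel
activation_scale = scale_value ** (-1)      # inverted, as in the code above

activations_size, channel_axis = 4, 1
reshape_shape = [1] * activations_size
reshape_shape[channel_axis] = activation_scale.size
activation_scale = activation_scale.reshape(reshape_shape)  # shape (1, 3, 1, 1)

x = np.ones((2, 3, 5, 5))                   # toy activation tensor
smoothed = x * activation_scale             # broadcasts over batch and spatial dims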

def _calculate_weight_scale(self, scale_value: Tensor, node: NNCFNode, weights_value: Tensor) -> Tensor:
"""
Calculates scale for weight tensor.
@@ -321,7 +350,12 @@ def _calculate_weight_scale(self, scale_value: TTensor, node: NNCFNode, weights_
weights_size = len(weights_value.shape)
if weights_size > 1:
channel_axis = self._backend_entity.get_weight_channel_axis(node)
return self._backend_entity.calculate_weight_scale(scale_value, weights_size, channel_axis)
weight_scale = scale_value
if weights_size > 1:
reshape_shape = [1 for _ in range(weights_size)]
reshape_shape[channel_axis] = scale_value.size
weight_scale = scale_value.reshape(reshape_shape)
return weight_scale
return scale_value

def _calculate_input_reduction_axes(self, nncf_graph: NNCFGraph, node: NNCFNode, input_port: int) -> Tuple[int]:
@@ -340,7 +374,7 @@ def _calculate_input_reduction_axes(self, nncf_graph: NNCFGraph, node: NNCFNode,
reduction_axes = self._backend_entity.get_channel_agnostic_reduction_axes(channel_axis, shape)
return reduction_axes

def _process_weight_statistics(self, node: NNCFNode, weights: TTensor) -> TTensor:
def _process_weight_statistics(self, node: NNCFNode, weights: Tensor) -> Tensor:
"""
Returns processed weight statistics for node.
@@ -354,7 +388,7 @@ def _process_weight_statistics(self, node: NNCFNode, weights: TTensor) -> TTenso
channel_axis = self._backend_entity.get_weight_channel_axis(node)
reduction_shape = [i for i, _ in enumerate(weights.shape)]
reduction_shape.pop(channel_axis)
return self._backend_entity.process_weight_statistics(weights, tuple(reduction_shape))
return fns.max(fns.abs(weights), axis=tuple(reduction_shape))
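
Put differently, the weight statistic is the per-output-channel maximum of |W|, reduced over every axis except the channel one. A NumPy equivalent with a toy convolution weight (channel axis 0 is only an example):

import numpy as np

weights = np.random.randn(16, 3, 3, 3)  # e.g. (out_channels, in_channels, kernel_h, kernel_w)
channel_axis = 0
reduction_axes = tuple(i for i in range(weights.ndim) if i != channel_axis)
weight_stats = np.max(np.abs(weights), axis=reduction_axes)  # shape (16,): one value per output channel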

def _create_scale_node_name(self, source_name: str, source_port_id: int) -> str:
"""
@@ -391,3 +425,16 @@ def _get_alpha_map(self) -> Dict[OperatorMetatype, float]:
for metatype in metatypes:
alpha_by_metatype_map[metatype] = alpha_value
return alpha_by_metatype_map

@staticmethod
def _clip_statistics(statistics: List[Tensor]) -> Tensor:
"""
Clips statistics for further calculation.
:param statistics: Input statistics.
:return: Clipped statistics.
"""
a_min = 1e-5

statistics = fns.stack(statistics)
squeezed = fns.squeeze(statistics)
return fns.clip(squeezed, min_val=a_min, max_val=None)
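
A minimal NumPy analogue of this helper, with toy values, to show what the stack/squeeze/clip sequence does to the collected statistics:

import numpy as np

stats = [np.array([[1e-7, 0.3, 2.0]])]   # one collected statistic with an extra leading dim
stacked = np.stack(stats)                 # shape (1, 1, 3)
squeezed = np.squeeze(stacked)            # shape (3,)
clipped = np.clip(squeezed, 1e-5, None)   # tiny values are floored to 1e-5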