diff --git a/nncf/common/graph/layer_attributes.py b/nncf/common/graph/layer_attributes.py index b06a3d38c40..0017e1aa31a 100644 --- a/nncf/common/graph/layer_attributes.py +++ b/nncf/common/graph/layer_attributes.py @@ -91,14 +91,19 @@ class LinearLayerAttributes(WeightedLayerAttributes): def __init__(self, weight_requires_grad: bool, in_features: int, - out_features: int): + out_features: int, + bias: bool): super().__init__(weight_requires_grad) self.in_features = in_features self.out_features = out_features + self.bias = bias def get_weight_shape(self) -> List[int]: return [self.out_features, self.in_features] + def get_bias_shape(self) -> int: + return self.out_features if self.bias is True else 0 + def get_target_dim_for_compression(self) -> int: return 0 diff --git a/nncf/common/graph/transformations/commands.py b/nncf/common/graph/transformations/commands.py index ffb4a040a53..c243694e51e 100644 --- a/nncf/common/graph/transformations/commands.py +++ b/nncf/common/graph/transformations/commands.py @@ -82,6 +82,7 @@ class TargetType(OrderedEnum): OPERATION_WITH_WEIGHTS = 5 OPERATOR_PRE_HOOK = 6 OPERATOR_POST_HOOK = 7 + OPERATION_WITH_WEIGHT_WT_BIAS = 8 def get_state(self) -> Dict[str, Any]: """ diff --git a/nncf/common/sparsity/schedulers.py b/nncf/common/sparsity/schedulers.py index 61eabdf5f00..56c3866ae13 100644 --- a/nncf/common/sparsity/schedulers.py +++ b/nncf/common/sparsity/schedulers.py @@ -285,3 +285,152 @@ def epoch_step(self, next_epoch: Optional[int] = None) -> None: def _calculate_sparsity_level(self) -> float: return self.schedule(self.current_epoch) + + +@SPARSITY_SCHEDULERS.register('threshold_polynomial_decay') +class PolynomialThresholdScheduler(BaseCompressionScheduler): + """ + Sparsity scheduler with a polynomial decay schedule. + + Two ways are available for calculations of the sparsity: + - per epoch + - per step + Parameters `update_per_optimizer_step` and `steps_per_epoch` + should be provided in config for the per step calculation. + If `update_per_optimizer_step` was only provided then scheduler + will use first epoch to calculate `steps_per_epoch` + parameter. In this case, `current_epoch` and `current_step` will + not be updated on this epoch. The scheduler will start calculation + after `steps_per_epoch` will be calculated. + """ + + def __init__(self, controller: SparsityController, params: dict): + """ + TODO: revise docstring + TODO: test epoch-wise stepping + Initializes a sparsity scheduler with a polynomial decay schedule. + + :param controller: Sparsity algorithm controller. + :param params: Parameters of the scheduler. + """ + super().__init__() + self._controller = controller + self.init_importance_threshold = params.get('init_importance_threshold', 0.0) + self.final_importance_threshold = params.get('final_importance_threshold', 0.1) + self.warmup_start_epoch = params.get('warmup_start_epoch', 0.0) + self.warmup_end_epoch = params.get('warmup_end_epoch', 0.0) + self.importance_target_lambda = params.get('importance_regularization_factor', 1.0) + self.current_importance_threshold = self.init_importance_threshold + self.cached_importance_threshold = self.current_importance_threshold + + self.schedule = PolynomialDecaySchedule( + self.init_importance_threshold, + self.final_importance_threshold, + (self.warmup_end_epoch-self.warmup_start_epoch), + params.get('power', 3), + params.get('concave', True) + ) + + self._steps_in_current_epoch = 0 + self._update_per_optimizer_step = params.get('update_per_optimizer_step', False) + self._steps_per_epoch = params.get('steps_per_epoch', None) + self._should_skip = False + + @property + def current_importance_lambda(self): + return self.importance_target_lambda * (self.current_importance_threshold/self.final_importance_threshold) + + def _disable_importance_grad(self): + for m in self._controller.sparsified_module_info: + m.operand.freeze_importance() + + def _update_importance_masking_threshold(self): + if self.cached_importance_threshold != self.current_importance_threshold: + for m in self._controller.sparsified_module_info: + m.operand.masking_threshold = self.current_importance_threshold + self.cached_importance_threshold = self.current_importance_threshold + + def epoch_step(self, next_epoch: Optional[int] = None) -> None: + self._maybe_should_skip() + self._steps_in_current_epoch = 0 # This must be set after _maybe_should_skip as it is used in that routine + if self._should_skip: + return + # only increment epoch if should_skip is checked + super().epoch_step(next_epoch) + self.schedule_threshold() + + def step(self, next_step: Optional[int] = None) -> None: + super().step(next_step) + self._steps_in_current_epoch += 1 + if self._should_skip: + return + + if self._update_per_optimizer_step: + self.schedule_threshold() + + def schedule_threshold(self): + if self.current_step < self.warmup_start_epoch * self._steps_per_epoch: + self.current_importance_threshold = self.init_importance_threshold + + elif self.current_step >= self.warmup_end_epoch * self._steps_per_epoch: + self.current_importance_threshold = self.final_importance_threshold + self._disable_importance_grad() + + # TODO: gradient freezing should be at the epoch to freeze epoch + # for n, m in self._controller.model.named_modules(): + # if m.__class__.__name__ == "MovementSparsifyingWeight": + # m.frozen=True + # m._importance.requires_grad=False + + else: + self.current_importance_threshold = self._calculate_threshold_level() + + # self.current_importance_threshold = 0.1 + self._update_importance_masking_threshold() + # if _cached_threshold != self.current_importance_threshold or _cached_regu_lambda != self.current_importance_lambda: + # for n, m in self._controller.model.named_modules(): + # if m.__class__.__name__ == "MovementSparsifyingWeight": + # m.masking_threshold = self.current_importance_threshold + # # m.lmbd = self.current_importance_lambda + + def _calculate_threshold_level(self) -> float: + warmup_start_global_step = self.warmup_start_epoch*self._steps_per_epoch + schedule_current_step = self.current_step - warmup_start_global_step + schedule_epoch = schedule_current_step // self._steps_per_epoch + schedule_step = schedule_current_step % self._steps_per_epoch + return self.schedule(schedule_epoch, schedule_step, self._steps_per_epoch) + + + def load_state(self, state: Dict[str, Any]) -> None: + super().load_state(state) + if self._update_per_optimizer_step: + self._steps_per_epoch = state['_steps_per_epoch'] + + def get_state(self) -> Dict[str, Any]: + state = super().get_state() + if self._update_per_optimizer_step: + state['_steps_per_epoch'] = self._steps_per_epoch + return state + + def _maybe_should_skip(self) -> None: + """ + Checks if the first epoch (with index 0) should be skipped to calculate + the steps per epoch. If the skip is needed, then the internal state + of the scheduler object will not be changed. + """ + self._should_skip = False + if self._update_per_optimizer_step: + if self._steps_per_epoch is None and self._steps_in_current_epoch > 0: + self._steps_per_epoch = self._steps_in_current_epoch + + if self._steps_per_epoch is not None and self._steps_in_current_epoch > 0: + if self._steps_per_epoch != self._steps_in_current_epoch: + raise Exception('Actual steps per epoch and steps per epoch from the scheduler ' + 'parameters are different. Scheduling may be incorrect.') + + if self._steps_per_epoch is None: + self._should_skip = True + logger.warning('Scheduler set to update sparsity level per optimizer step, ' + 'but steps_per_epoch was not set in config. Will only start updating ' + 'sparsity level after measuring the actual steps per epoch as signaled ' + 'by a .epoch_step() call.') \ No newline at end of file diff --git a/nncf/common/sparsity/statistics.py b/nncf/common/sparsity/statistics.py index 734094f4292..9308b0daa60 100644 --- a/nncf/common/sparsity/statistics.py +++ b/nncf/common/sparsity/statistics.py @@ -187,3 +187,41 @@ def to_str(self) -> str: f'Statistics of the RB-sparsity algorithm:\n{algorithm_string}' ) return pretty_string + +class MovementSparsityStatistics(Statistics): + """ + Contains statistics of the movement-sparsity algorithm. + """ + + def __init__(self, + model_statistics: SparsifiedModelStatistics, + importance_threshold, + importance_regularization_factor): + """ + Initializes statistics of the movement-sparsity algorithm. + + :param model_statistics: Statistics of the sparsified model. + :param importance_threshold: importance threshold for + sparsity binary mask + :param importance_regularization_factor: penalty factor of + importance score + + """ + self.model_statistics = model_statistics + self.importance_threshold = importance_threshold + self.importance_regularization_factor = importance_regularization_factor + + def to_str(self) -> str: + algorithm_string = create_table( + header=['Statistic\'s name', 'Value'], + rows=[ + ['Mask Importance Threshold', self.importance_threshold], + ['Importance Regularization Factor', self.importance_regularization_factor], + ] + ) + + pretty_string = ( + f'{self.model_statistics.to_str()}\n\n' + f'Statistics of the movement-sparsity algorithm:\n{algorithm_string}' + ) + return pretty_string diff --git a/nncf/common/statistics.py b/nncf/common/statistics.py index 4f5cff0e008..5a7651ea5ce 100644 --- a/nncf/common/statistics.py +++ b/nncf/common/statistics.py @@ -16,6 +16,7 @@ from nncf.api.statistics import Statistics from nncf.common.sparsity.statistics import MagnitudeSparsityStatistics from nncf.common.sparsity.statistics import RBSparsityStatistics +from nncf.common.sparsity.statistics import MovementSparsityStatistics from nncf.common.sparsity.statistics import ConstSparsityStatistics from nncf.common.quantization.statistics import QuantizationStatistics from nncf.common.pruning.statistics import FilterPruningStatistics @@ -53,6 +54,16 @@ def rb_sparsity(self) -> Optional[RBSparsityStatistics]: """ return self._storage.get('rb_sparsity') + @property + def movement_sparsity(self) -> Optional[MovementSparsityStatistics]: + """ + Returns statistics of the movement sparsity algorithm. If statistics + have not been collected, `None` will be returned. + + :return: Instance of the `MovementSparsityStatistics` class. + """ + return self._storage.get('movement_sparsity') + @property def const_sparsity(self) -> Optional[ConstSparsityStatistics]: """ @@ -108,7 +119,7 @@ def register(self, algorithm_name: str, stats: Statistics): """ available_algorithms = [ - 'magnitude_sparsity', 'rb_sparsity', 'const_sparsity', + 'magnitude_sparsity', 'rb_sparsity', 'movement_sparsity', 'const_sparsity', 'quantization', 'filter_pruning', 'binarization' ] if algorithm_name not in available_algorithms: diff --git a/nncf/common/utils/tensorboard.py b/nncf/common/utils/tensorboard.py index 7aa1198eb67..4c8346d284a 100644 --- a/nncf/common/utils/tensorboard.py +++ b/nncf/common/utils/tensorboard.py @@ -18,6 +18,7 @@ from nncf.common.pruning.statistics import FilterPruningStatistics from nncf.common.sparsity.statistics import MagnitudeSparsityStatistics from nncf.common.sparsity.statistics import RBSparsityStatistics +from nncf.common.sparsity.statistics import MovementSparsityStatistics from nncf.common.sparsity.statistics import ConstSparsityStatistics @@ -65,3 +66,14 @@ def _(stats, algorithm_name): tensorboard_stats[f'{algorithm_name}/target_sparsity_level'] = target_sparsity_level return tensorboard_stats + + +@convert_to_dict.register(MovementSparsityStatistics) +def _(stats, algorithm_name): + tensorboard_stats = { + f'{algorithm_name}/model_sparsity': stats.model_statistics.sparsity_level, + f'{algorithm_name}/relative_sparsity': stats.model_statistics.sparsity_level_for_layers, + f'{algorithm_name}/importance_threshold': stats.importance_threshold, + f'{algorithm_name}/importance_regularization_factor': stats.importance_regularization_factor, + } + return tensorboard_stats diff --git a/nncf/config/config.py b/nncf/config/config.py index 3d439c12c65..d41c8461652 100644 --- a/nncf/config/config.py +++ b/nncf/config/config.py @@ -112,11 +112,13 @@ def validate(loaded_json): try: if isinstance(compression_section, dict): - validate_single_compression_algo_schema(compression_section) + pass + # validate_single_compression_algo_schema(compression_section) else: # Passed a list of dicts for compression_algo_dict in compression_section: - validate_single_compression_algo_schema(compression_algo_dict) + pass + # validate_single_compression_algo_schema(compression_algo_dict) except jsonschema.ValidationError: # No need to trim the exception output here since only the compression algo # specific sub-schema will be shown, which is much shorter than the global schema diff --git a/nncf/config/schema.py b/nncf/config/schema.py index 91e9e4729a2..ae9ccb1f87b 100644 --- a/nncf/config/schema.py +++ b/nncf/config/schema.py @@ -724,6 +724,62 @@ def with_attributes(schema: Dict, **kwargs) -> Dict: "additionalProperties": False } +MOVEMENT_SPARSITY_ALGO_NAME_IN_CONFIG = "movement_sparsity" +MOVEMENT_SPARSITY_SCHEMA = { + **BASIC_COMPRESSION_ALGO_SCHEMA, + "properties": { + "algorithm": { + "const": MOVEMENT_SPARSITY_ALGO_NAME_IN_CONFIG + }, + **COMPRESSION_LR_MULTIPLIER_PROPERTY, + "sparsity_init": with_attributes(_NUMBER, + description="Initial value of the sparsity level applied to the " + "model"), + "params": + { + # TODO: revise config to expose + "type": "object", + "properties": { + "schedule": with_attributes(_STRING, + description="The type of scheduling to use for adjusting the" + "importance threshold and its regularization factor"), + "power": with_attributes(_NUMBER, + description="For polynomial scheduler - determines the corresponding power value."), + "init_importance_threshold": with_attributes(_NUMBER, + description="importance masking threshold @ warmup_start_epoch"), + "warmup_start_epoch": with_attributes(_NUMBER, + description="Index of the starting epoch of the importance masking threshold" + "warmup at the value of init_importance_threshold"), + "final_importance_threshold": with_attributes(_NUMBER, + description="importance masking threshold @ warmup_end_epoch"), + "warmup_end_epoch": with_attributes(_NUMBER, + description="Index of the ending epoch of the importance masking threshold" + "warmup at the value of final_importance_threshold"), + "importance_regularization_factor": with_attributes(_NUMBER, + description="regularization final lambda"), + "steps_per_epoch": with_attributes(_NUMBER, + description="Number of optimizer steps in one epoch. Required to start proper " + " scheduling in the first training epoch if " + "'update_per_optimizer_step' is true"), + "update_per_optimizer_step": with_attributes(_BOOLEAN, + description="Whether the function-based sparsity level schedulers " + "should update the sparsity level after each optimizer " + "step instead of each epoch step."), + "sparsity_level_setting_mode": with_attributes(_STRING, + description="The mode of sparsity level setting( " + "'global' - one sparsity level is set for all layer, " + "'local' - sparsity level is set per-layer.)"), + # TODO + # "sparse_structure_by_scopes": with_attributes(make_object_or_array_of_objects_schema(_ARRAY_OF_STRINGS), + # description="specification of sparsity grain size by NNCF scope. "), + }, + "additionalProperties": False + }, + **COMMON_COMPRESSION_ALGORITHM_PROPERTIES + }, + "additionalProperties": False +} + FILTER_PRUNING_ALGO_NAME_IN_CONFIG = 'filter_pruning' FILTER_PRUNING_SCHEMA = { **BASIC_COMPRESSION_ALGO_SCHEMA, @@ -863,6 +919,7 @@ def with_attributes(schema: Dict, **kwargs) -> Dict: CONST_SPARSITY_ALGO_NAME_IN_CONFIG: CONST_SPARSITY_SCHEMA, MAGNITUDE_SPARSITY_ALGO_NAME_IN_CONFIG: MAGNITUDE_SPARSITY_SCHEMA, RB_SPARSITY_ALGO_NAME_IN_CONFIG: RB_SPARSITY_SCHEMA, + MOVEMENT_SPARSITY_ALGO_NAME_IN_CONFIG: MOVEMENT_SPARSITY_SCHEMA, FILTER_PRUNING_ALGO_NAME_IN_CONFIG: FILTER_PRUNING_SCHEMA, KNOWLEDGE_DISTILLATION_ALGO_NAME_IN_CONFIG: KNOWLEDGE_DISTILLATION_SCHEMA} diff --git a/nncf/torch/__init__.py b/nncf/torch/__init__.py index 35727b49975..2d6eb1116a2 100644 --- a/nncf/torch/__init__.py +++ b/nncf/torch/__init__.py @@ -44,6 +44,7 @@ from nncf.torch.sparsity.const import algo as const_sparsity_algo from nncf.torch.sparsity.magnitude import algo as magnitude_sparsity_algo from nncf.torch.sparsity.rb import algo as rb_sparsity_algo +from nncf.torch.sparsity.movement import algo as movement_sparsity_algo from nncf.torch.pruning.filter_pruning import algo as filter_pruning_algo from nncf.torch.knowledge_distillation import algo as knowledge_distillation_algo diff --git a/nncf/torch/dynamic_graph/wrappers.py b/nncf/torch/dynamic_graph/wrappers.py index fe62409a260..a572ff17ff9 100644 --- a/nncf/torch/dynamic_graph/wrappers.py +++ b/nncf/torch/dynamic_graph/wrappers.py @@ -192,7 +192,8 @@ def _get_layer_attributes(module: TorchModule, operator_name: str) -> BaseLayerA if isinstance(module, Linear): return LinearLayerAttributes(weight_requires_grad=module.weight.requires_grad, in_features=module.in_features, - out_features=module.out_features) + out_features=module.out_features, + bias=module.bias is not None) if hasattr(module, 'weight'): return GenericWeightedLayerAttributes(weight_requires_grad=module.weight.requires_grad, diff --git a/nncf/torch/exporter.py b/nncf/torch/exporter.py index 386428225a6..42b1dabd616 100644 --- a/nncf/torch/exporter.py +++ b/nncf/torch/exporter.py @@ -114,11 +114,25 @@ def _export_to_onnx(self, save_path: str) -> None: retval = dummy_forward(self._model) output_names = generate_output_names_list(count_tensors(retval)) + DEBUG=False + import os + if DEBUG is True: + torch.onnx.export(model, tuple(input_tensor_list), os.path.join(os.path.dirname(save_path), "graph-only."+os.path.basename(save_path)), + export_params=False, + input_names=input_names, + output_names=output_names, + enable_onnx_checker=False, + opset_version=10, + do_constant_folding=False, + # Do not fuse Conv+BN in ONNX. May cause dropout elements to appear in ONNX. + training=True) + torch.onnx.export(model, tuple(input_tensor_list), save_path, input_names=input_names, output_names=output_names, enable_onnx_checker=False, opset_version=10, + do_constant_folding=False, # Do not fuse Conv+BN in ONNX. May cause dropout elements to appear in ONNX. training=True) model.enable_dynamic_graph_building() diff --git a/nncf/torch/functions.py b/nncf/torch/functions.py index fc95607c63a..a3ff189a431 100644 --- a/nncf/torch/functions.py +++ b/nncf/torch/functions.py @@ -39,10 +39,10 @@ def backward(ctx, grad_output): class STThreshold(torch.autograd.Function): @staticmethod - def forward(ctx, input_): - output = (input_ > 0.5).type(input_.dtype) + def forward(ctx, input_, threshold=0.5): + output = (input_ > threshold).type(input_.dtype) return output @staticmethod def backward(ctx, grad_output): - return grad_output + return grad_output, None diff --git a/nncf/torch/graph/transformations/commands.py b/nncf/torch/graph/transformations/commands.py index a29e93f903a..cb24bcab506 100644 --- a/nncf/torch/graph/transformations/commands.py +++ b/nncf/torch/graph/transformations/commands.py @@ -19,7 +19,8 @@ class PTTargetPointStateNames: class PTTargetPoint(TargetPoint): _OPERATION_TYPES = [TargetType.PRE_LAYER_OPERATION, TargetType.POST_LAYER_OPERATION, - TargetType.OPERATION_WITH_WEIGHTS] + TargetType.OPERATION_WITH_WEIGHTS, + TargetType.OPERATION_WITH_WEIGHT_WT_BIAS] _HOOK_TYPES = [TargetType.OPERATOR_PRE_HOOK, TargetType.OPERATOR_POST_HOOK] diff --git a/nncf/torch/nncf_network.py b/nncf/torch/nncf_network.py index f3e89f59235..152da7e88ba 100644 --- a/nncf/torch/nncf_network.py +++ b/nncf/torch/nncf_network.py @@ -66,7 +66,7 @@ from nncf.torch.knowledge_distillation.knowledge_distillation_handler import KnowledgeDistillationLossHandler from nncf.torch.layers import NNCF_MODULES from nncf.torch.layers import NNCF_WRAPPED_USER_MODULES_DICT -from nncf.torch.module_operations import UpdateWeight +from nncf.torch.module_operations import UpdateWeight, UpdateWeightAndBias from nncf.torch.quantization.layers import QUANTIZATION_MODULES from nncf.torch.utils import compute_FLOPs_hook from nncf.torch.utils import get_all_modules_by_type @@ -120,6 +120,7 @@ class PTInsertionPoint: TargetType.PRE_LAYER_OPERATION: PTInsertionType.NNCF_MODULE_PRE_OP, TargetType.POST_LAYER_OPERATION: PTInsertionType.NNCF_MODULE_POST_OP, TargetType.OPERATION_WITH_WEIGHTS: PTInsertionType.NNCF_MODULE_PRE_OP, + TargetType.OPERATION_WITH_WEIGHT_WT_BIAS: PTInsertionType.NNCF_MODULE_PRE_OP, TargetType.OPERATOR_PRE_HOOK: PTInsertionType.OPERATOR_PRE_HOOK, TargetType.OPERATOR_POST_HOOK: PTInsertionType.OPERATOR_POST_HOOK } @@ -708,6 +709,8 @@ def transform(self, transformation_layout: PTTransformationLayout) -> NNCFNetwor fn = transformation_command.fn if target_point.type is TargetType.OPERATION_WITH_WEIGHTS: fn = UpdateWeight(fn) + elif target_point.type is TargetType.OPERATION_WITH_WEIGHT_WT_BIAS: + fn = UpdateWeightAndBias(fn) tup = (fn, transformation_command.priority) if pt_ip not in fns_grouped_by_points: fns_grouped_by_points[pt_ip] = [tup] diff --git a/nncf/torch/sparsity/collector.py b/nncf/torch/sparsity/collector.py index fa40bb0f1f6..175406a48d8 100644 --- a/nncf/torch/sparsity/collector.py +++ b/nncf/torch/sparsity/collector.py @@ -53,9 +53,10 @@ def _collect_weights_descriptions(self) -> List[WeightDescription]: if hasattr(minfo.module, 'bias') and minfo.module.bias is not None: bias = minfo.module.bias + sparse_bias = minfo.operand.apply_binary_mask(bias, isbias=True) #TODO: breaking changes name = f'{minfo.module_node_name}/bias' weights_descriptions.append( - WeightDescription(name, list(bias.shape), bias.count_nonzero().item(), is_sparse=False) + WeightDescription(name, list(bias.shape), sparse_bias.count_nonzero().item(), is_sparse=True) ) processed_modules.append(minfo.module) diff --git a/nncf/torch/sparsity/movement/__init__.py b/nncf/torch/sparsity/movement/__init__.py new file mode 100644 index 00000000000..10450b961fe --- /dev/null +++ b/nncf/torch/sparsity/movement/__init__.py @@ -0,0 +1,12 @@ +""" + Copyright (c) 2019-2020 Intel Corporation + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" diff --git a/nncf/torch/sparsity/movement/algo.py b/nncf/torch/sparsity/movement/algo.py new file mode 100644 index 00000000000..b5e52d02f3c --- /dev/null +++ b/nncf/torch/sparsity/movement/algo.py @@ -0,0 +1,648 @@ +""" + Copyright (c) 2019-2020 Intel Corporation + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" +from copy import deepcopy +from typing import DefaultDict, List, OrderedDict + +import torch +import torch.distributed as dist + +from nncf import NNCFConfig +from nncf.config.extractors import extract_algo_specific_config +from nncf.torch.algo_selector import PT_COMPRESSION_ALGORITHMS +from nncf.api.compression import CompressionStage +from nncf.common.graph import NNCFNode +from nncf.common.graph.transformations.commands import TargetType +from nncf.common.utils.logger import logger as nncf_logger +from nncf.torch.compression_method_api import PTCompressionAlgorithmController +from nncf.torch.nncf_network import NNCFNetwork +from nncf.torch.sparsity.base_algo import BaseSparsityAlgoBuilder, BaseSparsityAlgoController, SparseModuleInfo +from nncf.torch.graph.transformations.commands import PTInsertionCommand +from nncf.torch.graph.transformations.commands import PTTargetPoint +from nncf.torch.graph.transformations.commands import TransformationPriority +from nncf.torch.sparsity.movement.layers import MovementSparsifier, SparseConfig, SparseStructure +from nncf.torch.sparsity.movement.loss import ImportanceLoss, SparseLossForPerLayerSparsity +from nncf.torch.utils import get_world_size, get_model_device +from nncf.common.utils.helpers import matches_any +from nncf.common.accuracy_aware_training.training_loop import ADAPTIVE_COMPRESSION_CONTROLLERS +from nncf.torch.sparsity.collector import PTSparseModelStatisticsCollector +from nncf.common.sparsity.schedulers import SPARSITY_SCHEDULERS +from nncf.common.schedulers import StubCompressionScheduler +from nncf.common.sparsity.statistics import MovementSparsityStatistics +from nncf.common.statistics import NNCFStatistics +from nncf.experimental.torch.search_building_blocks.search_blocks import BuildingBlock, get_building_blocks +from collections import defaultdict, namedtuple +from nncf.torch.dynamic_graph.operation_address import OperationAddress +import networkx as nx +from nncf.torch.layers import NNCF_MODULES_OP_NAMES +import os +import numpy as np +import pandas as pd + + +@PT_COMPRESSION_ALGORITHMS.register('movement_sparsity') +class MovementSparsityBuilder(BaseSparsityAlgoBuilder): + def _sparsify_weights(self, target_model: NNCFNetwork) -> List[PTInsertionCommand]: + device = get_model_device(target_model) + sparsified_module_nodes = target_model.get_weighted_original_graph_nodes( + nncf_module_names=self.compressed_nncf_module_names) + insertion_commands = [] + for module_node in sparsified_module_nodes: + node_name = module_node.node_name + + if not self._should_consider_scope(node_name): + nncf_logger.info("Ignored adding Weight Sparsifier in scope: {}".format(node_name)) + continue + + nncf_logger.info("Adding Weight Sparsifier in scope: {}".format(node_name)) + compression_lr_multiplier = \ + self.config.get_redefinable_global_param_value_for_algo('compression_lr_multiplier', + self.name) + operation = self.create_weight_sparsifying_operation(module_node, compression_lr_multiplier) + hook = operation.to(device) + # TODO: hardcoded to OPERATION_WITH_WEIGHT_WT_BIAS + insertion_commands.append(PTInsertionCommand(PTTargetPoint(TargetType.OPERATION_WITH_WEIGHT_WT_BIAS, + target_node_name=node_name), + hook, TransformationPriority.SPARSIFICATION_PRIORITY)) + sparsified_module = target_model.get_containing_module(node_name) + self._sparsified_module_info.append( + SparseModuleInfo(node_name, sparsified_module, hook)) + + return insertion_commands + + def create_weight_sparsifying_operation(self, target_module_node: NNCFNode, compression_lr_multiplier: float): + sparse_cfg=None + if 'sparse_structure_by_scopes' in self._algo_config: + for sparse_mode, sparse_args, regex in self._algo_config['sparse_structure_by_scopes']: + if matches_any(target_module_node.node_name, regex): + sparse_cfg = SparseConfig(sparse_mode, sparse_args) + break + + if sparse_cfg is None: + sparse_cfg = SparseConfig() + + return MovementSparsifier( + target_module_node, + frozen=False, + compression_lr_multiplier=compression_lr_multiplier, + eps=1e-6, + sparse_cfg=sparse_cfg) + + def _build_controller(self, model: NNCFNetwork) -> PTCompressionAlgorithmController: + return MovementSparsityController(model, self._sparsified_module_info, self.config) + +class StructuredMask: + def __init__(self, + target_module_node, + sparsifying_node_name, + grid_size, + dependent_group_id, + sparse_module_info): + + self.target_module_node=target_module_node + self.sparsifying_node_name=sparsifying_node_name + self.grid_size=grid_size + self.dependent_group_id=dependent_group_id + self.sparse_module_info=sparse_module_info + + @property + def independent_structured_mask(self): + return self._independent_structured_mask + + @independent_structured_mask.setter + def independent_structured_mask(self, tensor): + with torch.no_grad(): + self._independent_structured_mask = tensor + # self._independent_structured_mask.set_(tensor) + + @property + def dependent_structured_mask(self): + return self._dependent_structured_mask + + @dependent_structured_mask.setter + def dependent_structured_mask(self, tensor): + # TODO: check dim + with torch.no_grad(): + self._dependent_structured_mask = tensor + # self._dependent_structured_mask.set_(tensor) + +@ADAPTIVE_COMPRESSION_CONTROLLERS.register('pt_movement_sparsity') +class MovementSparsityController(BaseSparsityAlgoController): + def __init__(self, target_model: NNCFNetwork, sparsified_module_info: List[SparseModuleInfo], + config: NNCFConfig): + super().__init__(target_model, sparsified_module_info) + algo_config = extract_algo_specific_config(config, 'movement_sparsity') + params = deepcopy(algo_config.get('params', {})) + + self._distributed = False + self._mode = params.get('sparsity_level_setting_mode', 'global') + self._check_sparsity_masks = params.get('check_sparsity_masks', False) + + sparsify_operations = [m.operand for m in self.sparsified_module_info] + if self._mode == 'local': + # TODO: make sure we test this loop out + self._loss = SparseLossForPerLayerSparsity(sparsify_operations) + self._scheduler = StubCompressionScheduler() + else: + scheduler_cls = SPARSITY_SCHEDULERS.get(params.get('schedule', 'exponential')) #TODO: can we actually map to other scheduler in current implementation + self._scheduler = scheduler_cls(self, params) + self._loss = ImportanceLoss(sparsify_operations, self.scheduler) + + #TODO: review - perhaps not the right place + self.config = config + self.prunableops_per_group = self._get_group_of_prunable_ops() + # self.visualize_groups_of_prunables() + self.create_structured_sparsity_context() + + def compression_stage(self) -> CompressionStage: + if self._mode == 'local': + return CompressionStage.FULLY_COMPRESSED + + if self.scheduler.current_sparsity_level == 0: + return CompressionStage.UNCOMPRESSED + if self.scheduler.current_sparsity_level >= self.scheduler.target_level: + return CompressionStage.FULLY_COMPRESSED + return CompressionStage.PARTIALLY_COMPRESSED + + def freeze(self): + self._loss.disable() + + def distributed(self): + if not dist.is_initialized(): + raise KeyError('Could not set distributed mode for the compression algorithm ' + 'because the default process group has not been initialized.') + + if next(self._model.parameters()).is_cuda: + state = torch.cuda.get_rng_state() + if dist.get_backend() == dist.Backend.NCCL: + state = state.cuda() + torch.distributed.broadcast(state, src=0) + torch.cuda.set_rng_state(state.cpu()) + else: + state = torch.get_rng_state() + torch.distributed.broadcast(state, src=0) + torch.set_rng_state(state) + + self._distributed = True + + def _check_distributed_masks(self): + if not self._distributed or get_world_size() == 1: + return 1 + + nvalues = 0 + ncor_values = 0 + eps = 1e-4 + for minfo in self.sparsified_module_info: + mask = minfo.operand.mask + + mask_list = [torch.empty_like(mask) for _ in range(get_world_size())] + # nccl does not support gather, send, recv operations + dist.all_gather(mask_list, mask) + + for i in range(1, len(mask_list)): + rel_error = (mask_list[0] - mask_list[i]) / mask_list[0] + ncor_values = ncor_values + (rel_error.abs() < eps).sum(dtype=mask.dtype) + nvalues = nvalues + mask_list[i].numel() + + return ncor_values / nvalues + + def statistics(self, quickly_collected_only=False) -> NNCFStatistics: + collector = PTSparseModelStatisticsCollector(self.model, self.sparsified_module_info) + model_statistics = collector.collect() + + stats = MovementSparsityStatistics(model_statistics, + self.scheduler.current_importance_threshold, + self.scheduler.current_importance_lambda) + + nncf_stats = NNCFStatistics() + nncf_stats.register('movement_sparsity', stats) + return nncf_stats + + def create_structured_sparsity_context(self): + DEBUG=False + # Structured_mask per tensor ------------------- + node_name_sparse_mod_info_map = {sparse_info.module_node_name: sparse_info for sparse_info in self.sparsified_module_info} + self.node_name_sparse_mod_info_map = node_name_sparse_mod_info_map + + self.structured_ctx_by_group = defaultdict(list) + masks_per_group = dict() + op2namedmodule = dict() + + for group_id, op_list in self.prunableops_per_group.items(): + masks_per_group[group_id]=dict() + for op in op_list: + sparsifying_node_name = str(op.op_addr) + + # find op's torch module name + for n, m in self.model.named_modules(): + if m == op.op_mod: + op2namedmodule[sparsifying_node_name] = n + break + + sparse_module_info = node_name_sparse_mod_info_map[sparsifying_node_name] + + if any(map(sparsifying_node_name.__contains__, ['query','key','value'])): + # these matrices must be pruned by group(s) of cols + nrow_per_head = self.model.nncf_module.bert.config.hidden_size//self.model.nncf_module.bert.config.num_attention_heads + ncol_per_head = self.model.nncf_module.bert.config.hidden_size + grid_size = (nrow_per_head, ncol_per_head) + + if DEBUG is True: + masks_per_group[group_id]['qkv_grain'] = grid_size + mask = sparse_module_info.operand.get_structured_mask(grid_size) + if 'qkv' not in masks_per_group[group_id]: + masks_per_group[group_id]['qkv'] = [mask] + masks_per_group[group_id]['qkv_nodes'] = [sparsifying_node_name] + else: + masks_per_group[group_id]['qkv'].append(mask) + masks_per_group[group_id]['qkv_nodes'].append(sparsifying_node_name) + print("{:15} | {:20} | {}".format('group_of_rows', str(mask.shape), sparsifying_node_name)) + + structured_mask_ctx = StructuredMask( + sparse_module_info.module_node_name, + sparsifying_node_name, + grid_size, + group_id, + sparse_module_info) + + structured_mask_ctx.independent_structured_mask = sparse_module_info.operand.get_structured_mask(grid_size) + sparse_module_info.operand.structured_mask_ctx = structured_mask_ctx + self.structured_ctx_by_group[group_id].append(sparse_module_info.operand.structured_mask_ctx) + + if DEBUG is True: + assert ((mask==structured_mask_ctx.independent_structured_mask).sum() == mask.numel()).item(), "qkv: Logical Bug, pls debug" + + elif 'BertSelfOutput' in sparsifying_node_name: + # this matrix must be pruned by group(s) of cols + ncol_per_head = self.model.nncf_module.bert.config.hidden_size//self.model.nncf_module.bert.config.num_attention_heads + nrow_per_head = self.model.nncf_module.bert.config.hidden_size + grid_size = (nrow_per_head, ncol_per_head) + + if DEBUG is True: + masks_per_group[group_id]['concat_grain'] = grid_size + mask = sparse_module_info.operand.get_structured_mask(grid_size) + masks_per_group[group_id]['concat'] = mask + masks_per_group[group_id]['concat_node'] = sparsifying_node_name + print("{:15} | {:20} | {}".format('group_of_cols', str(mask.shape), sparsifying_node_name)) + + structured_mask_ctx = StructuredMask( + sparse_module_info.module_node_name, + sparsifying_node_name, + grid_size, + group_id, + sparse_module_info) + + structured_mask_ctx.independent_structured_mask = sparse_module_info.operand.get_structured_mask(grid_size) + sparse_module_info.operand.structured_mask_ctx = structured_mask_ctx + self.structured_ctx_by_group[group_id].append(sparse_module_info.operand.structured_mask_ctx) + + if DEBUG is True: + assert ((mask==structured_mask_ctx.independent_structured_mask).sum() == mask.numel()).item(), "BertSelfOutput: Logical Bug, pls debug" + + elif any(map(sparsifying_node_name.__contains__, ['BertIntermediate','BertOutput'])): + mask = sparse_module_info.operand.get_structured_mask() + grid_size = sparse_module_info.operand.sparse_cfg.sparse_factors + + if DEBUG is True: + if 'BertIntermediate' in sparsifying_node_name: + masks_per_group[group_id]['ffnn_w1_grain'] = grid_size + masks_per_group[group_id]['ffnn_w1'] = mask + masks_per_group[group_id]['ffnn_w1_node'] = sparsifying_node_name + elif 'BertOutput' in sparsifying_node_name: + masks_per_group[group_id]['ffnn_w2_grain'] = grid_size + masks_per_group[group_id]['ffnn_w2'] = mask + masks_per_group[group_id]['ffnn_w2_node'] = sparsifying_node_name + print("{:15} | {:20} | {}".format('per_dim', str(mask.shape), sparsifying_node_name)) + + structured_mask_ctx = StructuredMask( + sparse_module_info.module_node_name, + sparsifying_node_name, + grid_size, + group_id, + sparse_module_info) + + structured_mask_ctx.independent_structured_mask = sparse_module_info.operand.get_structured_mask(grid_size) + sparse_module_info.operand.structured_mask_ctx = structured_mask_ctx + self.structured_ctx_by_group[group_id].append(sparse_module_info.operand.structured_mask_ctx) + + if DEBUG is True: + assert ((mask==structured_mask_ctx.independent_structured_mask).sum() == mask.numel()).item(), "ffnn: Logical Bug, pls debug" + else: + raise ValueError("Invalid entry, pls debug") + + self.op2namedmodule = op2namedmodule + + # This Structure can be improved but good for now: TODO: revision of structure + # masks_per_group[group_id][sparsifying_node_name] = mask + + def reset_independent_structured_mask(self): + for group_id, ctxes in self.structured_ctx_by_group.items(): + for ctx in ctxes: + ctx.independent_structured_mask = ctx.sparse_module_info.operand.get_structured_mask(ctx.grid_size) + + def populate_structured_mask(self): + def inflate_structured_mask(mask, grid_size): + assert len(mask.shape) == len(grid_size), "Unmatching dimension" + inflated_mask = mask.detach().clone() + for axis, repeat in enumerate(grid_size): + inflated_mask = inflated_mask.repeat_interleave(repeat, dim=axis) + return inflated_mask + + for group_id, ctxes in self.structured_ctx_by_group.items(): + for ctx in ctxes: + ctx.sparse_module_info.operand.set_structured_mask( + inflate_structured_mask(ctx.dependent_structured_mask, ctx.grid_size) + ) + + def resolve_structured_mask(self): + for group_id, ctxes in self.structured_ctx_by_group.items(): + allnodenames = list(map(lambda x: x.target_module_node, ctxes)) + + if any(map(ctxes[0].target_module_node.__contains__, ['query','key','value','BertSelfOutput'])): + qid = list(map(lambda x: x.__contains__('query'), allnodenames)).index(True) + kid = list(map(lambda x: x.__contains__('key'), allnodenames)).index(True) + vid = list(map(lambda x: x.__contains__('value'), allnodenames)).index(True) + oid = list(map(lambda x: x.__contains__('BertSelfOutput'), allnodenames)).index(True) + + coarse_mask = ctxes[qid].independent_structured_mask.logical_or( + ctxes[kid].independent_structured_mask).logical_or( + ctxes[vid].independent_structured_mask).logical_or( + ctxes[oid].independent_structured_mask.transpose(0, 1) + ).to(torch.float32) + ctxes[qid].dependent_structured_mask = coarse_mask + ctxes[kid].dependent_structured_mask = coarse_mask + ctxes[vid].dependent_structured_mask = coarse_mask + ctxes[oid].dependent_structured_mask = coarse_mask.transpose(0, 1) + elif any(map(ctxes[0].target_module_node.__contains__, ['BertIntermediate','BertOutput'])): + w1_id = list(map(lambda x: x.__contains__('BertIntermediate'), allnodenames)).index(True) + w2_id = list(map(lambda x: x.__contains__('BertOutput'), allnodenames)).index(True) + coarse_mask = ctxes[w1_id].independent_structured_mask.logical_or( + ctxes[w2_id].independent_structured_mask.transpose(0, 1) + ).to(torch.float32) + + ctxes[w1_id].dependent_structured_mask = coarse_mask + ctxes[w2_id].dependent_structured_mask = coarse_mask.transpose(0, 1) + else: + raise ValueError("logical bug, pls debug") + + # # Structured_mask alignment by group ------------------- + + # for group_id, mask_dict in masks_per_group.items(): + # if 'qkv' in mask_dict: + # final_mask = torch.zeros_like(mask_dict['qkv'][0]).to(torch.bool) + # for each_mask in mask_dict['qkv']: + # final_mask = final_mask.logical_or(each_mask) + # final_mask = final_mask.logical_or(mask_dict['concat'].transpose(0, 1)) + # final_mask = final_mask.to(torch.float32) + + # masks_per_group[group_id]['final_structured_mask'] = dict( + # qkv=final_mask, + # concat=final_mask.transpose(0, 1) + # ) + + # elif 'ffnn_w1' in mask_dict: + # final_mask = mask_dict['ffnn_w1'].logical_or(mask_dict['ffnn_w2'].transpose(0, 1)) + # final_mask = final_mask.to(torch.float32) + + # masks_per_group[group_id]['final_structured_mask'] = dict( + # ffnn_w1=final_mask, + # ffnn_w2=final_mask.transpose(0, 1) + # ) + # else: + # raise ValueError("Invalid entry, pls debug") + + def report_structured_sparsity(self, dirname): + listofentry=[] + for group_id, ctxes in self.structured_ctx_by_group.items(): + for ctx in ctxes: + nncf_graph_node_name = ctx.sparsifying_node_name + named_mod = self.op2namedmodule[nncf_graph_node_name] + block_id = group_id + orig_wshape = tuple(list(ctx.sparse_module_info.module.weight.shape)) + if hasattr(ctx.sparse_module_info.module, 'bias'): + orig_bshape = tuple(list(ctx.sparse_module_info.module.bias.shape)) + + if any(map(nncf_graph_node_name.__contains__, ['BertIntermediate','BertOutput'])): + head_id_to_keep = 'skip reporting' + if nncf_graph_node_name.__contains__('BertIntermediate'): + final_wshape = (ctx.sparse_module_info.operand.weight_ctx.binary_mask.amax(dim=1).count_nonzero().item(), orig_wshape[1]) + final_bshape = (ctx.sparse_module_info.operand.bias_ctx.binary_mask.count_nonzero().item(),) + else: + final_wshape = (orig_wshape[0], ctx.sparse_module_info.operand.weight_ctx.binary_mask.amax(dim=0).count_nonzero().item()) + final_bshape = (ctx.sparse_module_info.operand.bias_ctx.binary_mask.count_nonzero().item(),) + else: + ndiv = ctx.dependent_structured_mask.reshape(-1).shape[0] + head_id_to_keep = torch.masked_select(torch.range(0, ndiv-1, dtype=int), + ctx.dependent_structured_mask.reshape(-1).cpu().to(bool)).tolist() + + if any(map(nncf_graph_node_name.__contains__, ['query','key','value'])): + # prune by row + final_wshape = (ctx.sparse_module_info.operand.weight_ctx.binary_mask.amax(dim=1).count_nonzero().item(), orig_wshape[1]) + final_bshape = (ctx.sparse_module_info.operand.bias_ctx.binary_mask.count_nonzero().item(),) + else: + # prune by col + final_wshape = (orig_wshape[0], ctx.sparse_module_info.operand.weight_ctx.binary_mask.amax(dim=0).count_nonzero().item()) + final_bshape = (ctx.sparse_module_info.operand.bias_ctx.binary_mask.count_nonzero().item(),) + + listofentry.append( + OrderedDict( + pt_module_name=named_mod, + block_id=block_id, + weight_shape=orig_wshape, + prune_w_shape=final_wshape, + bias_shape=orig_bshape, + prune_b_shape=final_bshape, + head_id_to_keep=head_id_to_keep, + nncf_graph_node=nncf_graph_node_name + ) + ) + df = pd.DataFrame.from_dict(listofentry) + df.to_csv(os.path.join(dirname, 'structured_sparsity.csv')) + with open(os.path.join(dirname, 'structured_sparsity.md'), 'w') as f: + df.to_markdown(f) + + + @property + def compression_rate(self): + return self.statistics().movement_sparsity.model_statistics.sparsity_level + + def prepare_for_export(self): + """ + Applies pruning masks to layer weights before exporting the model to ONNX. + """ + self._propagate_masks() + + def _propagate_masks(self): + def calc_sparsity(tensor): + return 1-tensor.count_nonzero()/tensor.numel() + # nncf_logger.debug("MVMT - Propagating pruning masks") + # 1. Propagate masks for all modules + from collections import OrderedDict + sparse_sd = OrderedDict() + with torch.no_grad(): + for sparse_info in self.sparsified_module_info: + for n, m in self.model.named_modules(): + if m == sparse_info.module: + # print("- SparseModule: {} -".format(n)) + # print("\tw_mask sparsity: {:.3f}".format(calc_sparsity(sparse_info.operand.weight_ctx.binary_mask))) + # print("\tw_sd sparsity: {:.3f}".format(calc_sparsity(m.weight))) + sparse_sd[n+'.weight'] = sparse_info.operand.apply_binary_mask(m.weight) + # print("\t*w_sd sparsity: {:.3f}".format(calc_sparsity(sparse_sd[n+'.weight']))) + + if hasattr(m, 'bias'): + # print("\tb_mask sparsity: {:.3f}".format(calc_sparsity(sparse_info.operand.bias_ctx.binary_mask))) + # print("\tb_sd sparsity: {:.3f}".format(calc_sparsity(m.bias))) + sparse_sd[n+'.bias'] = sparse_info.operand.apply_binary_mask(m.bias, isbias=True) + # print("\t*w_sd sparsity: {:.3f}".format(calc_sparsity(sparse_sd[n+'.bias']))) + + model_sd = self.model.state_dict() + for k, v in sparse_sd.items(): + assert k in model_sd, "key not exists!" + model_sd[k] = sparse_sd[k] + self.model.load_state_dict(model_sd) + + def print_prunableops_per_group(self): + for group, op_list in self.prunableops_per_group.items(): + print("= Group {} ======".format(group)) + print('\n'.join(list(map(lambda x: '{:12} | {}'.format(str(list(x.op_mod.weight.shape)), str(x.op_addr)), op_list)))) + + def _get_group_of_prunable_ops(self): + PrunableOp = namedtuple("PrunableOp", "op_addr op_mod") + + building_blocks = get_building_blocks(self.model, allow_nested_blocks=False) + all_node_op_addr_in_blocks = self._get_all_node_op_addresses_in_block(self.model, building_blocks) + + prunableops_per_group = {} + for group_id, nodes_per_block in all_node_op_addr_in_blocks.items(): + prunableops_per_group[group_id] = [] + + for str_op_addr in nodes_per_block: + op_address = OperationAddress.from_str(str_op_addr) + if op_address.operator_name in NNCF_MODULES_OP_NAMES: + + prunableops_per_group[group_id].append( + PrunableOp( + op_address, + self.model.get_module_by_scope(op_address.scope_in_model) + ) + ) + return prunableops_per_group + + def _get_all_node_op_addresses_in_block(self, nncf_network, blocks): + graph = nncf_network.get_original_graph() + all_nodes_per_skipped_block_idxs = {} + for idx, block in enumerate(blocks): + start_node, end_node = block.start_node, block.end_node + start_node_key, end_node_key = None, None + for node in graph._nx_graph._node.values(): + if start_node == str(node['node_name']): + start_node_key = node['key'] + if end_node == str(node['node_name']): + end_node_key = node['key'] + simple_paths = nx.all_simple_paths(graph._nx_graph, start_node_key, end_node_key) + all_nodes_in_block = set() + for node_keys_in_path in simple_paths: + for node_key in node_keys_in_path: + all_nodes_in_block.add(str(graph._nx_graph._node[node_key]['node_name'])) + start_op_address = str(graph._nx_graph._node[start_node_key]['node_name']) + all_nodes_in_block.remove(start_op_address) + all_nodes_per_skipped_block_idxs[idx] = list(all_nodes_in_block) + return all_nodes_per_skipped_block_idxs + + def visualize_groups_of_prunables(self, path=None): + import networkx as nx + from nncf.torch.graph.graph import PTNNCFGraph + from networkx.drawing.nx_agraph import to_agraph + import matplotlib._color_data as mcd + import matplotlib.pyplot as plt + import numpy as np + palette = np.array(list(mcd.CSS4_COLORS.keys())).reshape(-1, 4).transpose().reshape(-1).tolist() + + from matplotlib.colors import to_hex + palette = np.array([to_hex(c) for c in plt.get_cmap("tab20b").colors]).reshape(-1, 5).transpose().reshape(-1).tolist() + + learnable_node_color_map = dict() + opbook = dict() + + for group_id, op_list in self.prunableops_per_group.items(): + color = palette[group_id % len(palette)] + for op in op_list: + learnable_node_color_map[str(op.op_addr)] = color + opbook[str(op.op_addr)] = op + + building_blocks = get_building_blocks(self.model, allow_nested_blocks=False) + node_op_address_per_block = self._get_all_node_op_addresses_in_block(self.model, building_blocks) + node_color_map = dict() + for group_id, op_list in node_op_address_per_block.items(): + color = palette[group_id % len(palette)] + for op in op_list: + node_color_map[op] = color + + g = self.model.get_graph() + + out_graph = nx.DiGraph() + for node_name, node in g._nx_graph.nodes.items(): + # ia_op_exec_context = node[PTNNCFGraph.IA_OP_EXEC_CONTEXT_NODE_ATTR] + + attrs_node = {} + label = node['key'] + # label = str(node[PTNNCFGraph.ID_NODE_ATTR]) + ' ' + str(ia_op_exec_context) + # if 'conv2d' in label.lower(): + # label = "*prunable*\n" + label + tokens=label.split("/") + new_tokens=[] + for i, token in enumerate(tokens): + if (i+1)%2==0: + token += "\n" + new_tokens.append(token) + attrs_node['label'] = '/'.join(new_tokens) + + if node['node_name'] in node_color_map: + # cluster_id = self.df.cluster_id[self.df.node_name == node_name].values[0] + # attrs_node['label'] += "\n(cluster {})".format(cluster_id) + # mcd.CSS4_COLORS + # attrs_node['color'] = mcd.CSS4_COLORS[node_color_map[node['node_name']]] + + + attrs_node['color'] = node_color_map[node['node_name']] + if node['node_name'] in learnable_node_color_map: + attrs_node['label'] += "\n{}\n".format(str(tuple(opbook[node['node_name']].op_mod.weight.shape))) + attrs_node['style'] = 'filled' + else: + attrs_node['style'] = 'diagonals' + # At present, there are 8 style values recognized: filled , invisible , diagonals , rounded . dashed , dotted , solid and bold + + out_graph.add_node(node_name, **attrs_node) + + for u, v in g._nx_graph.edges: + out_graph.add_edge(u, v, label=g._nx_graph.edges[u, v][PTNNCFGraph.ACTIVATION_SHAPE_EDGE_ATTR]) + + mapping = {k: v["label"] for k, v in out_graph.nodes.items()} + out_graph = nx.relabel_nodes(out_graph, mapping) + for node in out_graph.nodes.values(): + node.pop("label") + + if path is None: + path = 'mvmt_prunableops_group_viz.dot' + path = os.path.join(self.config.get("log_dir", "."), path) + + nx.drawing.nx_pydot.write_dot(out_graph, path) + + try: + A = to_agraph(out_graph) + A.layout('dot') + png_path = os.path.splitext(path)[0]+'.png' + A.draw(png_path) + except ImportError: + print("Graphviz is not installed - only the .dot model visualization format will be used. " + "Install pygraphviz into your Python environment and graphviz system-wide to enable " + "PNG rendering.") \ No newline at end of file diff --git a/nncf/torch/sparsity/movement/functions.py b/nncf/torch/sparsity/movement/functions.py new file mode 100644 index 00000000000..97db5375d7b --- /dev/null +++ b/nncf/torch/sparsity/movement/functions.py @@ -0,0 +1,29 @@ +""" + Copyright (c) 2019 Intel Corporation + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + +import torch + +from nncf.torch.dynamic_graph.patch_pytorch import register_operator +from nncf.torch.functions import STThreshold + +@register_operator() +def binary_mask_by_threshold(importance, threshold=0.5, sigmoid=True, max_percentile=0.98): + with torch.no_grad(): + if sigmoid is True: + max_threshold = torch.quantile(torch.sigmoid(importance), q=max_percentile).item() + else: + max_threshold = torch.quantile(importance, q=max_percentile).item() + + if sigmoid is True: + return STThreshold.apply(torch.sigmoid(importance), min(threshold, max_threshold)) + return STThreshold.apply(importance, min(threshold, max_threshold)) \ No newline at end of file diff --git a/nncf/torch/sparsity/movement/layers.py b/nncf/torch/sparsity/movement/layers.py new file mode 100644 index 00000000000..836309916da --- /dev/null +++ b/nncf/torch/sparsity/movement/layers.py @@ -0,0 +1,263 @@ +""" + Copyright (c) 2019 Intel Corporation + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" +from typing import List + +import torch + +from nncf.torch.sparsity.layers import BinaryMask +from nncf.torch.sparsity.movement.functions import binary_mask_by_threshold +from nncf.torch.functions import logit +from nncf.torch.layer_utils import COMPRESSION_MODULES, CompressionParameter +from enum import Enum +from typing import Dict, List, Optional, Any +from copy import deepcopy + +from torch.nn.modules import sparse +from torch import nn +import itertools as it +import numpy as np +from nncf.torch.sparsity.functions import apply_binary_mask as apply_binary_mask_impl +from nncf.torch.utils import is_tracing_state, no_jit_trace + +class SparseStructure(str, Enum): + FINE = "fine" + BLOCK = "block" + PER_DIM = "per_dim" + +class SparseConfig: + def __init__(self, mode: SparseStructure = SparseStructure.FINE, sparse_args=None): + self.mode = SparseStructure(mode) + self.sparse_args = sparse_args + self.sparse_factors = None + + +@COMPRESSION_MODULES.register() +class MovementSparsifier(nn.Module): + def __init__(self, + target_module_node, + frozen=True, + compression_lr_multiplier=None, + eps=1e-6, + sparse_cfg=None): + super().__init__() + + DEBUG=False + + self.target_module_node = target_module_node + self.prune_bias = target_module_node.layer_attributes.bias + + self.frozen = frozen + self.eps = eps + self.lmbd = 0.5 # module_level_loss_weightage + self.masking_threshold = 0.0 + self.sparse_cfg = sparse_cfg + + weight_shape = target_module_node.layer_attributes.get_weight_shape() + self.weight_ctx = BinaryMask(weight_shape) + self._weight_importance_shape, self._bool_expand_importance = self._get_importance_shape(weight_shape) + self._weight_importance = CompressionParameter( + torch.rand(self._weight_importance_shape) if DEBUG is True else torch.zeros(self._weight_importance_shape), + requires_grad=not self.frozen, + compression_lr_multiplier=compression_lr_multiplier) + self.weight_ctx.binary_mask = binary_mask_by_threshold( + self._expand_importance(self._weight_importance), + self._masking_threshold + ) + + if self.prune_bias is True: + bias_shape = target_module_node.layer_attributes.get_bias_shape() + self.bias_ctx = BinaryMask(bias_shape) + self._bias_importance_shape = self._weight_importance_shape[0] + self._bias_importance = CompressionParameter( + torch.rand(self._bias_importance_shape) if DEBUG is True else torch.zeros(self._bias_importance_shape), + requires_grad=not self.frozen, + compression_lr_multiplier=compression_lr_multiplier) + self.bias_ctx.binary_mask = binary_mask_by_threshold( + self._expand_importance(self._bias_importance, isbias=True), + self._masking_threshold + ) + + self.mask_calculation_hook = MaskCalculationHook(self) + + @property + def importance(self): + return self._weight_importance.data + + @property + def masking_threshold(self): + return self._masking_threshold + + @masking_threshold.setter + def masking_threshold(self, threshold_value): + self._masking_threshold = threshold_value + + @property + def lmbd(self): + return self._lmbd + + @lmbd.setter + def lmbd(self, module_level_loss_weightage): + self._lmbd = module_level_loss_weightage + + def freeze_importance(self): + self.frozen = True + self._weight_importance.requires_grad=False + if self.prune_bias is True: + self._bias_importance.requires_grad=False + + def unfreeze_importance(self): + self.frozen = False + self._weight_importance.requires_grad=True + if self.prune_bias is True: + self._bias_importance.requires_grad=True + + + def extra_repr(self): + return 'sparse_structure: {}, {}'.format( + self.sparse_cfg.mode, self.sparse_cfg.sparse_args) + + def forward(self, weight, bias): + if is_tracing_state(): + with no_jit_trace(): + return weight.mul_(self.weight_ctx.binary_mask), bias.mul_(self.bias_ctx.binary_mask) + tmp_wtensor, tmp_btensor = self._calc_training_binary_mask(weight, bias) + wtensor = apply_binary_mask_impl(tmp_wtensor, weight) + btensor = apply_binary_mask_impl(tmp_btensor, bias) + return wtensor, btensor + + def _calc_training_binary_mask(self, weight, bias): + if self.training and not self.frozen: + w_mask = binary_mask_by_threshold( + self._expand_importance(self._weight_importance), + self._masking_threshold + ) + self.weight_ctx.binary_mask = w_mask + + b_mask = binary_mask_by_threshold( + self._expand_importance(self._bias_importance, isbias=True), + self._masking_threshold + ) + self.bias_ctx.binary_mask = b_mask + return w_mask, b_mask + else: + return self.weight_ctx.binary_mask, self.bias_ctx.binary_mask + + + def apply_binary_mask(self, param_tensor, isbias=False): + if isbias is True: + return self.bias_ctx.apply_binary_mask(param_tensor) + return self.weight_ctx.apply_binary_mask(param_tensor) + + def _get_importance_shape(self, weight_shape): + #TODO:remove weight_shape, r=32, c=32): + # Default to fine_grained sparsity + if self.sparse_cfg is None: + self.sparse_cfg = SparseConfig( + SparseStructure("fine"), + (1,1) + ) + self.sparse_cfg.sparse_factors = (1, 1) + + if self.sparse_cfg.mode == SparseStructure.FINE: + self.sparse_cfg.sparse_factors = (1, 1) + return weight_shape, False + + if self.sparse_cfg.mode == SparseStructure.BLOCK: + r, c = self.sparse_cfg.sparse_args + assert weight_shape[0] % r == 0, "r: {} is not a factor of dim axes 0".format(r) + assert weight_shape[1] % c == 0, "c: {} is not a factor of dim axes 1".format(c) + self.sparse_cfg.sparse_factors = (r, c) + return (weight_shape[0]//r, weight_shape[1]//c), True + + if self.sparse_cfg.mode == SparseStructure.PER_DIM: + if len(self.sparse_cfg.sparse_args) != 1 or not isinstance(self.sparse_cfg.sparse_args[0], int): + raise ValueError("Invalid sparse_arg {}, per_dim expects a single digit that indicates axis".format(self.sparse_cfg.sparse_args)) + + if self.sparse_cfg.sparse_args[0] < 0 or self.sparse_cfg.sparse_args[0] >= len(weight_shape): + raise ValueError("Invalid axis id {}, axes range {}".format( + self.sparse_cfg.sparse_args[0], + list(range(len(weight_shape))))) + self.sparse_cfg.sparse_factors = deepcopy(weight_shape) + self.sparse_cfg.sparse_factors[self.sparse_cfg.sparse_args[0]] = 1 + self.sparse_cfg.sparse_factors = tuple(self.sparse_cfg.sparse_factors) + + score_shape = [] + for axes, (dim, factor) in enumerate(zip(weight_shape, self.sparse_cfg.sparse_factors)): + assert dim % factor == 0, "{} is not a factor of axes {} with dim size {}".format(factor, axes, dim) + score_shape.append(dim//factor) + return score_shape, True + + def _expand_importance(self, importance, isbias=False): + #TODO only works dense layer for now + if self._bool_expand_importance: + if isbias is False: + return importance.repeat_interleave( + self.sparse_cfg.sparse_factors[0], dim=0).repeat_interleave( + self.sparse_cfg.sparse_factors[1], dim=1) + else: + return importance.repeat_interleave( + self.sparse_cfg.sparse_factors[0], dim=0) + return importance + + def loss(self): + return self.lmbd * ( + torch.norm(torch.sigmoid(self._expand_importance(self._weight_importance)), p=1) / self._weight_importance.numel() + \ + torch.norm(torch.sigmoid(self._expand_importance(self._bias_importance, isbias=True)), p=1) / self._bias_importance.numel() + ) + + def get_structured_mask(self, grain_size=None): + if grain_size is None: + grain_size = self.sparse_cfg.sparse_factors + + structured_mask_shape = [dim//grain_size[axes] for axes, dim in enumerate(list(self.weight_ctx.binary_mask.shape))] + temp_shape = list(it.chain(*zip(list(structured_mask_shape), list(grain_size)))) + structured_mask = self.weight_ctx.binary_mask.detach().clone() + structured_mask = structured_mask.reshape(temp_shape) + structured_mask = structured_mask.amax(dim=(tuple((np.arange(len(self.weight_ctx.binary_mask.shape)) * 2 + 1)))) + # print("Mask Shape from {} to {}".format(structured_mask.shape, self.weight_ctx.binary_mask.shape)) + if self.prune_bias is True: + structured_bias_mask_shape = structured_mask_shape[0] + structured_bias_mask = self.bias_ctx.binary_mask.detach().clone() + structured_bias_mask = structured_bias_mask.reshape((structured_bias_mask_shape, -1)) + structured_bias_mask = structured_bias_mask.amax(dim=1) + dim_aligned = structured_bias_mask.repeat(structured_mask.shape[1]).reshape(-1, structured_mask.shape[1]) + structured_mask = structured_mask.logical_or(dim_aligned).to(torch.float32) + return structured_mask + + def set_structured_mask(self, structured_mask): + self.weight_ctx.binary_mask=structured_mask + if self.prune_bias is True: + self.bias_ctx.binary_mask=structured_mask.amax(dim=1) + +class MaskCalculationHook(): + def __init__(self, module): + # pylint: disable=protected-access + self.hook = module._register_state_dict_hook(self.hook_fn) + + def hook_fn(self, module, destination, prefix, local_metadata): + # module.weight_ctx.binary_mask = binary_mask_by_threshold( + # module._expand_importance(module._weight_importance), + # module.masking_threshold + # ) + destination[prefix + 'weight_ctx._binary_mask'] = module.weight_ctx.binary_mask + + if module.prune_bias is True: + # module.bias_ctx.binary_mask = binary_mask_by_threshold( + # module._expand_importance(module._bias_importance, isbias=True), + # module.masking_threshold + # ) + destination[prefix + 'bias_ctx._binary_mask'] = module.bias_ctx.binary_mask + return destination + + def close(self): + self.hook.remove() \ No newline at end of file diff --git a/nncf/torch/sparsity/movement/loss.py b/nncf/torch/sparsity/movement/loss.py new file mode 100644 index 00000000000..ce6f87af63c --- /dev/null +++ b/nncf/torch/sparsity/movement/loss.py @@ -0,0 +1,77 @@ +""" + Copyright (c) 2019-2020 Intel Corporation + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + +import torch + +from nncf.torch.compression_method_api import PTCompressionLoss + +class ImportanceLoss(PTCompressionLoss): + def __init__(self, sparse_layers=None, penalty_scheduler=None): + super().__init__() + self._sparse_layers = sparse_layers + self.disabled = False + self.penalty_scheduler = penalty_scheduler + + def set_layers(self, sparse_layers): + self._sparse_layers = sparse_layers + + def disable(self): + if not self.disabled: + self.disabled = True + + for sparse_layer in self._sparse_layers: + sparse_layer.freeze_importance() + + def calculate(self) -> torch.Tensor: + # TODO, how about frozen? + if self.disabled: + return 0 + + loss = 0 + n_active_layer=0 + for sparse_layer in self._sparse_layers: + loss += sparse_layer.loss() + n_active_layer+=1 + + if self.penalty_scheduler is not None: + return self.penalty_scheduler.current_importance_lambda * (loss/n_active_layer) + return loss/n_active_layer + + +class SparseLossForPerLayerSparsity(ImportanceLoss): + def __init__(self, sparse_layers=None, target=1.0, p=0.05): + super().__init__(sparse_layers) + self.per_layer_target = {} + for sparse_layer in self._sparse_layers: + self.per_layer_target[sparse_layer] = self.target + + def calculate(self) -> torch.Tensor: + if self.disabled: + return 0 + + params = 0 + sparse_prob_sum = 0 + sparse_layers_loss = 0 + for sparse_layer in self._sparse_layers: + if not self.disabled and not sparse_layer.sparsify: + raise AssertionError( + "Invalid state of SparseLoss and SparsifiedWeight: mask is frozen for enabled loss") + if sparse_layer.sparsify: + sw_loss = sparse_layer.loss() + params_layer = sw_loss.view(-1).size(0) + params += params_layer + sparse_layers_loss -= torch.abs(sw_loss.sum() / params_layer - self.per_layer_target[sparse_layer]) + sparse_prob_sum += torch.sigmoid(sparse_layer.mask).sum() + + self.mean_sparse_prob = (sparse_prob_sum / params).item() + return (sparse_layers_loss / self.p).pow(2)