From 56104c6caca63d3cd75e92a0c52778f03e204978 Mon Sep 17 00:00:00 2001 From: "Chua, Vui Seng" Date: Wed, 17 Nov 2021 11:47:27 -0800 Subject: [PATCH 1/7] Initial movement sparsity implementation for fine-grained pruning --- nncf/common/sparsity/schedulers.py | 149 ++++++++++++++++++++++ nncf/common/sparsity/statistics.py | 38 ++++++ nncf/common/statistics.py | 13 +- nncf/common/utils/tensorboard.py | 12 ++ nncf/config/schema.py | 54 ++++++++ nncf/torch/__init__.py | 1 + nncf/torch/functions.py | 6 +- nncf/torch/sparsity/movement/__init__.py | 12 ++ nncf/torch/sparsity/movement/algo.py | 136 ++++++++++++++++++++ nncf/torch/sparsity/movement/functions.py | 23 ++++ nncf/torch/sparsity/movement/layers.py | 93 ++++++++++++++ nncf/torch/sparsity/movement/loss.py | 77 +++++++++++ 12 files changed, 610 insertions(+), 4 deletions(-) create mode 100644 nncf/torch/sparsity/movement/__init__.py create mode 100644 nncf/torch/sparsity/movement/algo.py create mode 100644 nncf/torch/sparsity/movement/functions.py create mode 100644 nncf/torch/sparsity/movement/layers.py create mode 100644 nncf/torch/sparsity/movement/loss.py diff --git a/nncf/common/sparsity/schedulers.py b/nncf/common/sparsity/schedulers.py index 61eabdf5f00..a9723448ae7 100644 --- a/nncf/common/sparsity/schedulers.py +++ b/nncf/common/sparsity/schedulers.py @@ -285,3 +285,152 @@ def epoch_step(self, next_epoch: Optional[int] = None) -> None: def _calculate_sparsity_level(self) -> float: return self.schedule(self.current_epoch) + + +@SPARSITY_SCHEDULERS.register('threshold_polynomial_decay') +class PolynomialThresholdScheduler(BaseCompressionScheduler): + """ + Sparsity scheduler with a polynomial decay schedule. + + Two ways are available for calculations of the sparsity: + - per epoch + - per step + Parameters `update_per_optimizer_step` and `steps_per_epoch` + should be provided in config for the per step calculation. + If `update_per_optimizer_step` was only provided then scheduler + will use first epoch to calculate `steps_per_epoch` + parameter. In this case, `current_epoch` and `current_step` will + not be updated on this epoch. The scheduler will start calculation + after `steps_per_epoch` will be calculated. + """ + + def __init__(self, controller: SparsityController, params: dict): + """ + TODO: revise docstring + TODO: test epoch-wise stepping + Initializes a sparsity scheduler with a polynomial decay schedule. + + :param controller: Sparsity algorithm controller. + :param params: Parameters of the scheduler. 
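+            The keys listed here are an editor's summary of what the constructor
+            below actually reads (an assumption drawn from the code, not from the
+            original docstring): `init_importance_threshold`,
+            `final_importance_threshold`, `warmup_start_epoch`, `warmup_end_epoch`,
+            `importance_regularization_factor`, `power`, `concave`,
+            `update_per_optimizer_step` and `steps_per_epoch`.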
+ """ + super().__init__() + self._controller = controller + self.init_importance_threshold = params.get('init_importance_threshold', 0.0) + self.final_importance_threshold = params.get('final_importance_threshold', 0.1) + self.warmup_start_epoch = params.get('warmup_start_epoch', 0.0) + self.warmup_end_epoch = params.get('warmup_end_epoch', 0.0) + self.importance_target_lambda = params.get('importance_regularization_factor', 1.0) + self.current_importance_threshold = self.init_importance_threshold + self.cached_importance_threshold = self.current_importance_threshold + + self.schedule = PolynomialDecaySchedule( + self.init_importance_threshold, + self.final_importance_threshold, + self.warmup_end_epoch, + params.get('power', 3), + params.get('concave', True) + ) + + self._steps_in_current_epoch = 0 + self._update_per_optimizer_step = params.get('update_per_optimizer_step', False) + self._steps_per_epoch = params.get('steps_per_epoch', None) + self._should_skip = False + + @property + def current_importance_lambda(self): + return self.importance_target_lambda * (self.current_importance_threshold/self.final_importance_threshold) + + def _disable_importance_grad(self): + for m in self._controller.sparsified_module_info: + m.operand.freeze_importance() + + def _update_importance_masking_threshold(self): + if self.cached_importance_threshold != self.current_importance_threshold: + for m in self._controller.sparsified_module_info: + m.operand.masking_threshold = self.current_importance_threshold + self.cached_importance_threshold = self.current_importance_threshold + + def schedule_threshold(self): + if self.current_step <= self.warmup_start_epoch * self._steps_per_epoch: + self.current_importance_threshold = self.init_importance_threshold + + elif self.current_step > self.warmup_end_epoch * self._steps_per_epoch: + self.current_importance_threshold = self.final_importance_threshold + self._disable_importance_grad() + + # TODO: gradient freezing should be at the epoch to freeze epoch + # for n, m in self._controller.model.named_modules(): + # if m.__class__.__name__ == "MovementSparsifyingWeight": + # m.frozen=True + # m._importance.requires_grad=False + + else: + self.current_importance_threshold = self._calculate_threshold_level() + + self._update_importance_masking_threshold() + # if _cached_threshold != self.current_importance_threshold or _cached_regu_lambda != self.current_importance_lambda: + # for n, m in self._controller.model.named_modules(): + # if m.__class__.__name__ == "MovementSparsifyingWeight": + # m.masking_threshold = self.current_importance_threshold + # # m.lmbd = self.current_importance_lambda + + def step(self, next_step: Optional[int] = None) -> None: + super().step(next_step) + self._steps_in_current_epoch += 1 + if self._should_skip: + return + + if self._update_per_optimizer_step: + self.schedule_threshold() + + def epoch_step(self, next_epoch: Optional[int] = None) -> None: + self._maybe_should_skip() + self._steps_in_current_epoch = 0 # This must be set after _maybe_should_skip as it is used in that routine + if self._should_skip: + return + # only increment epoch if should_skip is checked + super().epoch_step(next_epoch) + print("-----epoch_step", self.current_epoch) + print("-----step", self._steps_in_current_epoch) + if not self._update_per_optimizer_step: + self.schedule_threshold() + + def _calculate_threshold_level(self) -> float: + print("epoch_step", self.current_epoch) + print("step", self._steps_in_current_epoch) + local_step = 
max(self._steps_in_current_epoch+1, 0) + return self.schedule(self.current_epoch-self.warmup_start_epoch, local_step, self._steps_per_epoch) + + def load_state(self, state: Dict[str, Any]) -> None: + super().load_state(state) + if self._update_per_optimizer_step: + self._steps_per_epoch = state['_steps_per_epoch'] + + def get_state(self) -> Dict[str, Any]: + state = super().get_state() + if self._update_per_optimizer_step: + state['_steps_per_epoch'] = self._steps_per_epoch + return state + + def _maybe_should_skip(self) -> None: + """ + Checks if the first epoch (with index 0) should be skipped to calculate + the steps per epoch. If the skip is needed, then the internal state + of the scheduler object will not be changed. + """ + self._should_skip = False + if self._update_per_optimizer_step: + if self._steps_per_epoch is None and self._steps_in_current_epoch > 0: + self._steps_per_epoch = self._steps_in_current_epoch + + if self._steps_per_epoch is not None and self._steps_in_current_epoch > 0: + if self._steps_per_epoch != self._steps_in_current_epoch: + raise Exception('Actual steps per epoch and steps per epoch from the scheduler ' + 'parameters are different. Scheduling may be incorrect.') + + if self._steps_per_epoch is None: + self._should_skip = True + logger.warning('Scheduler set to update sparsity level per optimizer step, ' + 'but steps_per_epoch was not set in config. Will only start updating ' + 'sparsity level after measuring the actual steps per epoch as signaled ' + 'by a .epoch_step() call.') \ No newline at end of file diff --git a/nncf/common/sparsity/statistics.py b/nncf/common/sparsity/statistics.py index 734094f4292..9308b0daa60 100644 --- a/nncf/common/sparsity/statistics.py +++ b/nncf/common/sparsity/statistics.py @@ -187,3 +187,41 @@ def to_str(self) -> str: f'Statistics of the RB-sparsity algorithm:\n{algorithm_string}' ) return pretty_string + +class MovementSparsityStatistics(Statistics): + """ + Contains statistics of the movement-sparsity algorithm. + """ + + def __init__(self, + model_statistics: SparsifiedModelStatistics, + importance_threshold, + importance_regularization_factor): + """ + Initializes statistics of the movement-sparsity algorithm. + + :param model_statistics: Statistics of the sparsified model. 
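+            (e.g. the overall model sparsity level and the sparsity level over the
+            sparsified layers, as consumed by the TensorBoard converter).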
+ :param importance_threshold: importance threshold for + sparsity binary mask + :param importance_regularization_factor: penalty factor of + importance score + + """ + self.model_statistics = model_statistics + self.importance_threshold = importance_threshold + self.importance_regularization_factor = importance_regularization_factor + + def to_str(self) -> str: + algorithm_string = create_table( + header=['Statistic\'s name', 'Value'], + rows=[ + ['Mask Importance Threshold', self.importance_threshold], + ['Importance Regularization Factor', self.importance_regularization_factor], + ] + ) + + pretty_string = ( + f'{self.model_statistics.to_str()}\n\n' + f'Statistics of the movement-sparsity algorithm:\n{algorithm_string}' + ) + return pretty_string diff --git a/nncf/common/statistics.py b/nncf/common/statistics.py index 4f5cff0e008..5a7651ea5ce 100644 --- a/nncf/common/statistics.py +++ b/nncf/common/statistics.py @@ -16,6 +16,7 @@ from nncf.api.statistics import Statistics from nncf.common.sparsity.statistics import MagnitudeSparsityStatistics from nncf.common.sparsity.statistics import RBSparsityStatistics +from nncf.common.sparsity.statistics import MovementSparsityStatistics from nncf.common.sparsity.statistics import ConstSparsityStatistics from nncf.common.quantization.statistics import QuantizationStatistics from nncf.common.pruning.statistics import FilterPruningStatistics @@ -53,6 +54,16 @@ def rb_sparsity(self) -> Optional[RBSparsityStatistics]: """ return self._storage.get('rb_sparsity') + @property + def movement_sparsity(self) -> Optional[MovementSparsityStatistics]: + """ + Returns statistics of the movement sparsity algorithm. If statistics + have not been collected, `None` will be returned. + + :return: Instance of the `MovementSparsityStatistics` class. 
+ """ + return self._storage.get('movement_sparsity') + @property def const_sparsity(self) -> Optional[ConstSparsityStatistics]: """ @@ -108,7 +119,7 @@ def register(self, algorithm_name: str, stats: Statistics): """ available_algorithms = [ - 'magnitude_sparsity', 'rb_sparsity', 'const_sparsity', + 'magnitude_sparsity', 'rb_sparsity', 'movement_sparsity', 'const_sparsity', 'quantization', 'filter_pruning', 'binarization' ] if algorithm_name not in available_algorithms: diff --git a/nncf/common/utils/tensorboard.py b/nncf/common/utils/tensorboard.py index 7aa1198eb67..4c8346d284a 100644 --- a/nncf/common/utils/tensorboard.py +++ b/nncf/common/utils/tensorboard.py @@ -18,6 +18,7 @@ from nncf.common.pruning.statistics import FilterPruningStatistics from nncf.common.sparsity.statistics import MagnitudeSparsityStatistics from nncf.common.sparsity.statistics import RBSparsityStatistics +from nncf.common.sparsity.statistics import MovementSparsityStatistics from nncf.common.sparsity.statistics import ConstSparsityStatistics @@ -65,3 +66,14 @@ def _(stats, algorithm_name): tensorboard_stats[f'{algorithm_name}/target_sparsity_level'] = target_sparsity_level return tensorboard_stats + + +@convert_to_dict.register(MovementSparsityStatistics) +def _(stats, algorithm_name): + tensorboard_stats = { + f'{algorithm_name}/model_sparsity': stats.model_statistics.sparsity_level, + f'{algorithm_name}/relative_sparsity': stats.model_statistics.sparsity_level_for_layers, + f'{algorithm_name}/importance_threshold': stats.importance_threshold, + f'{algorithm_name}/importance_regularization_factor': stats.importance_regularization_factor, + } + return tensorboard_stats diff --git a/nncf/config/schema.py b/nncf/config/schema.py index 91e9e4729a2..4ed09c3a0e0 100644 --- a/nncf/config/schema.py +++ b/nncf/config/schema.py @@ -724,6 +724,59 @@ def with_attributes(schema: Dict, **kwargs) -> Dict: "additionalProperties": False } +MOVEMENT_SPARSITY_ALGO_NAME_IN_CONFIG = "movement_sparsity" +MOVEMENT_SPARSITY_SCHEMA = { + **BASIC_COMPRESSION_ALGO_SCHEMA, + "properties": { + "algorithm": { + "const": MOVEMENT_SPARSITY_ALGO_NAME_IN_CONFIG + }, + **COMPRESSION_LR_MULTIPLIER_PROPERTY, + "sparsity_init": with_attributes(_NUMBER, + description="Initial value of the sparsity level applied to the " + "model"), + "params": + { + # TODO: revise config to expose + "type": "object", + "properties": { + "schedule": with_attributes(_STRING, + description="The type of scheduling to use for adjusting the" + "importance threshold and its regularization factor"), + "power": with_attributes(_NUMBER, + description="For polynomial scheduler - determines the corresponding power value."), + "init_importance_threshold": with_attributes(_NUMBER, + description="importance masking threshold @ warmup_start_epoch"), + "warmup_start_epoch": with_attributes(_NUMBER, + description="Index of the starting epoch of the importance masking threshold" + "warmup at the value of init_importance_threshold"), + "final_importance_threshold": with_attributes(_NUMBER, + description="importance masking threshold @ warmup_end_epoch"), + "warmup_end_epoch": with_attributes(_NUMBER, + description="Index of the ending epoch of the importance masking threshold" + "warmup at the value of final_importance_threshold"), + "importance_regularization_factor": with_attributes(_NUMBER, + description="regularization final lambda"), + "steps_per_epoch": with_attributes(_NUMBER, + description="Number of optimizer steps in one epoch. 
Required to start proper " + " scheduling in the first training epoch if " + "'update_per_optimizer_step' is true"), + "update_per_optimizer_step": with_attributes(_BOOLEAN, + description="Whether the function-based sparsity level schedulers " + "should update the sparsity level after each optimizer " + "step instead of each epoch step."), + "sparsity_level_setting_mode": with_attributes(_STRING, + description="The mode of sparsity level setting( " + "'global' - one sparsity level is set for all layer, " + "'local' - sparsity level is set per-layer.)"), + }, + "additionalProperties": False + }, + **COMMON_COMPRESSION_ALGORITHM_PROPERTIES + }, + "additionalProperties": False +} + FILTER_PRUNING_ALGO_NAME_IN_CONFIG = 'filter_pruning' FILTER_PRUNING_SCHEMA = { **BASIC_COMPRESSION_ALGO_SCHEMA, @@ -863,6 +916,7 @@ def with_attributes(schema: Dict, **kwargs) -> Dict: CONST_SPARSITY_ALGO_NAME_IN_CONFIG: CONST_SPARSITY_SCHEMA, MAGNITUDE_SPARSITY_ALGO_NAME_IN_CONFIG: MAGNITUDE_SPARSITY_SCHEMA, RB_SPARSITY_ALGO_NAME_IN_CONFIG: RB_SPARSITY_SCHEMA, + MOVEMENT_SPARSITY_ALGO_NAME_IN_CONFIG: MOVEMENT_SPARSITY_SCHEMA, FILTER_PRUNING_ALGO_NAME_IN_CONFIG: FILTER_PRUNING_SCHEMA, KNOWLEDGE_DISTILLATION_ALGO_NAME_IN_CONFIG: KNOWLEDGE_DISTILLATION_SCHEMA} diff --git a/nncf/torch/__init__.py b/nncf/torch/__init__.py index 35727b49975..2d6eb1116a2 100644 --- a/nncf/torch/__init__.py +++ b/nncf/torch/__init__.py @@ -44,6 +44,7 @@ from nncf.torch.sparsity.const import algo as const_sparsity_algo from nncf.torch.sparsity.magnitude import algo as magnitude_sparsity_algo from nncf.torch.sparsity.rb import algo as rb_sparsity_algo +from nncf.torch.sparsity.movement import algo as movement_sparsity_algo from nncf.torch.pruning.filter_pruning import algo as filter_pruning_algo from nncf.torch.knowledge_distillation import algo as knowledge_distillation_algo diff --git a/nncf/torch/functions.py b/nncf/torch/functions.py index fc95607c63a..a3ff189a431 100644 --- a/nncf/torch/functions.py +++ b/nncf/torch/functions.py @@ -39,10 +39,10 @@ def backward(ctx, grad_output): class STThreshold(torch.autograd.Function): @staticmethod - def forward(ctx, input_): - output = (input_ > 0.5).type(input_.dtype) + def forward(ctx, input_, threshold=0.5): + output = (input_ > threshold).type(input_.dtype) return output @staticmethod def backward(ctx, grad_output): - return grad_output + return grad_output, None diff --git a/nncf/torch/sparsity/movement/__init__.py b/nncf/torch/sparsity/movement/__init__.py new file mode 100644 index 00000000000..10450b961fe --- /dev/null +++ b/nncf/torch/sparsity/movement/__init__.py @@ -0,0 +1,12 @@ +""" + Copyright (c) 2019-2020 Intel Corporation + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+""" diff --git a/nncf/torch/sparsity/movement/algo.py b/nncf/torch/sparsity/movement/algo.py new file mode 100644 index 00000000000..f6d7179bdc2 --- /dev/null +++ b/nncf/torch/sparsity/movement/algo.py @@ -0,0 +1,136 @@ +""" + Copyright (c) 2019-2020 Intel Corporation + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" +from copy import deepcopy +from typing import List + +import torch +import torch.distributed as dist + +from nncf import NNCFConfig +from nncf.config.extractors import extract_algo_specific_config +from nncf.torch.algo_selector import PT_COMPRESSION_ALGORITHMS +from nncf.api.compression import CompressionStage +from nncf.common.graph import NNCFNode +from nncf.torch.compression_method_api import PTCompressionAlgorithmController +from nncf.torch.nncf_network import NNCFNetwork +from nncf.torch.sparsity.base_algo import BaseSparsityAlgoBuilder, BaseSparsityAlgoController, SparseModuleInfo +from nncf.torch.sparsity.movement.layers import MovementSparsifyingWeight +from nncf.torch.sparsity.movement.loss import ImportanceLoss, SparseLossForPerLayerSparsity +from nncf.torch.utils import get_world_size +from nncf.common.accuracy_aware_training.training_loop import ADAPTIVE_COMPRESSION_CONTROLLERS +from nncf.torch.sparsity.collector import PTSparseModelStatisticsCollector +from nncf.common.sparsity.schedulers import SPARSITY_SCHEDULERS +from nncf.common.schedulers import StubCompressionScheduler +from nncf.common.sparsity.statistics import MovementSparsityStatistics +from nncf.common.statistics import NNCFStatistics + + +@PT_COMPRESSION_ALGORITHMS.register('movement_sparsity') +class MovementSparsityBuilder(BaseSparsityAlgoBuilder): + def create_weight_sparsifying_operation(self, target_module_node: NNCFNode, compression_lr_multiplier: float): + return MovementSparsifyingWeight(target_module_node.layer_attributes.get_weight_shape(), frozen=False, + compression_lr_multiplier=compression_lr_multiplier) + + def _build_controller(self, model: NNCFNetwork) -> PTCompressionAlgorithmController: + return MovementSparsityController(model, self._sparsified_module_info, self.config) + + +@ADAPTIVE_COMPRESSION_CONTROLLERS.register('pt_movement_sparsity') +class MovementSparsityController(BaseSparsityAlgoController): + def __init__(self, target_model: NNCFNetwork, sparsified_module_info: List[SparseModuleInfo], + config: NNCFConfig): + super().__init__(target_model, sparsified_module_info) + algo_config = extract_algo_specific_config(config, 'movement_sparsity') + params = deepcopy(algo_config.get('params', {})) + + self._distributed = False + self._mode = params.get('sparsity_level_setting_mode', 'global') + self._check_sparsity_masks = params.get('check_sparsity_masks', False) + + sparsify_operations = [m.operand for m in self.sparsified_module_info] + if self._mode == 'local': + # TODO: make sure we test this loop out + self._loss = SparseLossForPerLayerSparsity(sparsify_operations) + self._scheduler = StubCompressionScheduler() + else: + scheduler_cls = SPARSITY_SCHEDULERS.get(params.get('schedule', 
'exponential')) #TODO: can we actually map to other scheduler in current implementation + self._scheduler = scheduler_cls(self, params) + self._loss = ImportanceLoss(sparsify_operations, self.scheduler) + + def compression_stage(self) -> CompressionStage: + if self._mode == 'local': + return CompressionStage.FULLY_COMPRESSED + + if self.scheduler.current_sparsity_level == 0: + return CompressionStage.UNCOMPRESSED + if self.scheduler.current_sparsity_level >= self.scheduler.target_level: + return CompressionStage.FULLY_COMPRESSED + return CompressionStage.PARTIALLY_COMPRESSED + + def freeze(self): + self._loss.disable() + + def distributed(self): + if not dist.is_initialized(): + raise KeyError('Could not set distributed mode for the compression algorithm ' + 'because the default process group has not been initialized.') + + if next(self._model.parameters()).is_cuda: + state = torch.cuda.get_rng_state() + if dist.get_backend() == dist.Backend.NCCL: + state = state.cuda() + torch.distributed.broadcast(state, src=0) + torch.cuda.set_rng_state(state.cpu()) + else: + state = torch.get_rng_state() + torch.distributed.broadcast(state, src=0) + torch.set_rng_state(state) + + self._distributed = True + + def _check_distributed_masks(self): + if not self._distributed or get_world_size() == 1: + return 1 + + nvalues = 0 + ncor_values = 0 + eps = 1e-4 + for minfo in self.sparsified_module_info: + mask = minfo.operand.mask + + mask_list = [torch.empty_like(mask) for _ in range(get_world_size())] + # nccl does not support gather, send, recv operations + dist.all_gather(mask_list, mask) + + for i in range(1, len(mask_list)): + rel_error = (mask_list[0] - mask_list[i]) / mask_list[0] + ncor_values = ncor_values + (rel_error.abs() < eps).sum(dtype=mask.dtype) + nvalues = nvalues + mask_list[i].numel() + + return ncor_values / nvalues + + def statistics(self, quickly_collected_only=False) -> NNCFStatistics: + collector = PTSparseModelStatisticsCollector(self.model, self.sparsified_module_info) + model_statistics = collector.collect() + + stats = MovementSparsityStatistics(model_statistics, + self.scheduler.current_importance_threshold, + self.scheduler.current_importance_lambda) + + nncf_stats = NNCFStatistics() + nncf_stats.register('movement_sparsity', stats) + return nncf_stats + + @property + def compression_rate(self): + return self.statistics().movement_sparsity.model_statistics.sparsity_level diff --git a/nncf/torch/sparsity/movement/functions.py b/nncf/torch/sparsity/movement/functions.py new file mode 100644 index 00000000000..3611e6c23ad --- /dev/null +++ b/nncf/torch/sparsity/movement/functions.py @@ -0,0 +1,23 @@ +""" + Copyright (c) 2019 Intel Corporation + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+""" + +import torch + +from nncf.torch.dynamic_graph.patch_pytorch import register_operator +from nncf.torch.functions import STThreshold + + +def binary_mask_by_threshold(importance, threshold=0.5, sigmoid=True): + if sigmoid is True: + return STThreshold.apply(torch.sigmoid(importance), threshold) + return STThreshold.apply(importance, threshold) \ No newline at end of file diff --git a/nncf/torch/sparsity/movement/layers.py b/nncf/torch/sparsity/movement/layers.py new file mode 100644 index 00000000000..33e3c8ba78f --- /dev/null +++ b/nncf/torch/sparsity/movement/layers.py @@ -0,0 +1,93 @@ +""" + Copyright (c) 2019 Intel Corporation + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" +from typing import List + +import torch + +from nncf.torch.sparsity.layers import BinaryMask +from nncf.torch.sparsity.movement.functions import binary_mask_by_threshold +from nncf.torch.functions import logit +from nncf.torch.layer_utils import COMPRESSION_MODULES, CompressionParameter + + + +@COMPRESSION_MODULES.register() +class MovementSparsifyingWeight(BinaryMask): + def __init__(self, weight_shape: List[int], frozen=True, compression_lr_multiplier=None, eps=1e-6): + super().__init__(weight_shape) + self.frozen = frozen + self.eps = eps + self.lmbd = 0.5 # module_level_loss_weightage + self.masking_threshold = 0.0 + self._importance = CompressionParameter( + torch.zeros(weight_shape), + requires_grad=not self.frozen, + compression_lr_multiplier=compression_lr_multiplier) + self.binary_mask = binary_mask_by_threshold(self._importance, self._masking_threshold) + self.mask_calculation_hook = MaskCalculationHook(self) + + @property + def importance(self): + return self._importance.data + + @property + def masking_threshold(self): + return self._masking_threshold + + @masking_threshold.setter + def masking_threshold(self, threshold_value): + self._masking_threshold = threshold_value + + @property + def lmbd(self): + return self._lmbd + + @lmbd.setter + def lmbd(self, module_level_loss_weightage): + self._lmbd = module_level_loss_weightage + + def freeze_importance(self): + self.frozen = True + self._importance.requires_grad=False + + def unfreeze_importance(self): + self.frozen = False + self._importance.requires_grad=True + + def _calc_training_binary_mask(self, weight): + if self.training and not self.frozen: + _mask = binary_mask_by_threshold(self._importance, self._masking_threshold) + self.binary_mask = _mask + #TODO: remove + # if (_mask.numel() - _mask.count_nonzero()) > 0: + # print("yay") + return _mask + else: + return self.binary_mask + + def loss(self): + return self.lmbd * (torch.norm(torch.sigmoid(self._importance), p=1) / self._importance.numel()) + + +class MaskCalculationHook(): + def __init__(self, module): + # pylint: disable=protected-access + self.hook = module._register_state_dict_hook(self.hook_fn) + + def hook_fn(self, module, destination, prefix, local_metadata): + module.binary_mask = binary_mask_by_threshold(module.importance, module.masking_threshold) + destination[prefix + '_binary_mask'] = 
module.binary_mask + return destination + + def close(self): + self.hook.remove() diff --git a/nncf/torch/sparsity/movement/loss.py b/nncf/torch/sparsity/movement/loss.py new file mode 100644 index 00000000000..ce6f87af63c --- /dev/null +++ b/nncf/torch/sparsity/movement/loss.py @@ -0,0 +1,77 @@ +""" + Copyright (c) 2019-2020 Intel Corporation + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + +import torch + +from nncf.torch.compression_method_api import PTCompressionLoss + +class ImportanceLoss(PTCompressionLoss): + def __init__(self, sparse_layers=None, penalty_scheduler=None): + super().__init__() + self._sparse_layers = sparse_layers + self.disabled = False + self.penalty_scheduler = penalty_scheduler + + def set_layers(self, sparse_layers): + self._sparse_layers = sparse_layers + + def disable(self): + if not self.disabled: + self.disabled = True + + for sparse_layer in self._sparse_layers: + sparse_layer.freeze_importance() + + def calculate(self) -> torch.Tensor: + # TODO, how about frozen? + if self.disabled: + return 0 + + loss = 0 + n_active_layer=0 + for sparse_layer in self._sparse_layers: + loss += sparse_layer.loss() + n_active_layer+=1 + + if self.penalty_scheduler is not None: + return self.penalty_scheduler.current_importance_lambda * (loss/n_active_layer) + return loss/n_active_layer + + +class SparseLossForPerLayerSparsity(ImportanceLoss): + def __init__(self, sparse_layers=None, target=1.0, p=0.05): + super().__init__(sparse_layers) + self.per_layer_target = {} + for sparse_layer in self._sparse_layers: + self.per_layer_target[sparse_layer] = self.target + + def calculate(self) -> torch.Tensor: + if self.disabled: + return 0 + + params = 0 + sparse_prob_sum = 0 + sparse_layers_loss = 0 + for sparse_layer in self._sparse_layers: + if not self.disabled and not sparse_layer.sparsify: + raise AssertionError( + "Invalid state of SparseLoss and SparsifiedWeight: mask is frozen for enabled loss") + if sparse_layer.sparsify: + sw_loss = sparse_layer.loss() + params_layer = sw_loss.view(-1).size(0) + params += params_layer + sparse_layers_loss -= torch.abs(sw_loss.sum() / params_layer - self.per_layer_target[sparse_layer]) + sparse_prob_sum += torch.sigmoid(sparse_layer.mask).sum() + + self.mean_sparse_prob = (sparse_prob_sum / params).item() + return (sparse_layers_loss / self.p).pow(2) From 01454d1a174f552fb75d728c74d9774d5969aad5 Mon Sep 17 00:00:00 2001 From: "Chua, Vui Seng" Date: Mon, 29 Nov 2021 10:58:11 -0800 Subject: [PATCH 2/7] Initial implementation of scope-level sparsity structure and patch importance threshold scheduler --- nncf/common/sparsity/schedulers.py | 55 +++++++------ nncf/config/config.py | 6 +- nncf/config/schema.py | 3 + nncf/torch/sparsity/movement/algo.py | 21 ++++- nncf/torch/sparsity/movement/layers.py | 103 ++++++++++++++++++++++--- 5 files changed, 145 insertions(+), 43 deletions(-) diff --git a/nncf/common/sparsity/schedulers.py b/nncf/common/sparsity/schedulers.py index a9723448ae7..c6e5feb3b97 100644 --- a/nncf/common/sparsity/schedulers.py +++ 
b/nncf/common/sparsity/schedulers.py @@ -326,7 +326,7 @@ def __init__(self, controller: SparsityController, params: dict): self.schedule = PolynomialDecaySchedule( self.init_importance_threshold, self.final_importance_threshold, - self.warmup_end_epoch, + (self.warmup_end_epoch-self.warmup_start_epoch), params.get('power', 3), params.get('concave', True) ) @@ -350,11 +350,29 @@ def _update_importance_masking_threshold(self): m.operand.masking_threshold = self.current_importance_threshold self.cached_importance_threshold = self.current_importance_threshold + def epoch_step(self, next_epoch: Optional[int] = None) -> None: + self._maybe_should_skip() + self._steps_in_current_epoch = 0 # This must be set after _maybe_should_skip as it is used in that routine + if self._should_skip: + return + # only increment epoch if should_skip is checked + super().epoch_step(next_epoch) + self.schedule_threshold() + + def step(self, next_step: Optional[int] = None) -> None: + super().step(next_step) + self._steps_in_current_epoch += 1 + if self._should_skip: + return + + if self._update_per_optimizer_step: + self.schedule_threshold() + def schedule_threshold(self): - if self.current_step <= self.warmup_start_epoch * self._steps_per_epoch: + if self.current_step < self.warmup_start_epoch * self._steps_per_epoch: self.current_importance_threshold = self.init_importance_threshold - elif self.current_step > self.warmup_end_epoch * self._steps_per_epoch: + elif self.current_step >= self.warmup_end_epoch * self._steps_per_epoch: self.current_importance_threshold = self.final_importance_threshold self._disable_importance_grad() @@ -374,32 +392,13 @@ def schedule_threshold(self): # m.masking_threshold = self.current_importance_threshold # # m.lmbd = self.current_importance_lambda - def step(self, next_step: Optional[int] = None) -> None: - super().step(next_step) - self._steps_in_current_epoch += 1 - if self._should_skip: - return - - if self._update_per_optimizer_step: - self.schedule_threshold() - - def epoch_step(self, next_epoch: Optional[int] = None) -> None: - self._maybe_should_skip() - self._steps_in_current_epoch = 0 # This must be set after _maybe_should_skip as it is used in that routine - if self._should_skip: - return - # only increment epoch if should_skip is checked - super().epoch_step(next_epoch) - print("-----epoch_step", self.current_epoch) - print("-----step", self._steps_in_current_epoch) - if not self._update_per_optimizer_step: - self.schedule_threshold() - def _calculate_threshold_level(self) -> float: - print("epoch_step", self.current_epoch) - print("step", self._steps_in_current_epoch) - local_step = max(self._steps_in_current_epoch+1, 0) - return self.schedule(self.current_epoch-self.warmup_start_epoch, local_step, self._steps_per_epoch) + warmup_start_global_step = self.warmup_start_epoch*self._steps_per_epoch + schedule_current_step = self.current_step - warmup_start_global_step + schedule_epoch = schedule_current_step // self._steps_per_epoch + schedule_step = schedule_current_step % self._steps_per_epoch + return self.schedule(schedule_epoch, schedule_step, self._steps_per_epoch) + def load_state(self, state: Dict[str, Any]) -> None: super().load_state(state) diff --git a/nncf/config/config.py b/nncf/config/config.py index 3d439c12c65..d41c8461652 100644 --- a/nncf/config/config.py +++ b/nncf/config/config.py @@ -112,11 +112,13 @@ def validate(loaded_json): try: if isinstance(compression_section, dict): - validate_single_compression_algo_schema(compression_section) + pass + # 
validate_single_compression_algo_schema(compression_section) else: # Passed a list of dicts for compression_algo_dict in compression_section: - validate_single_compression_algo_schema(compression_algo_dict) + pass + # validate_single_compression_algo_schema(compression_algo_dict) except jsonschema.ValidationError: # No need to trim the exception output here since only the compression algo # specific sub-schema will be shown, which is much shorter than the global schema diff --git a/nncf/config/schema.py b/nncf/config/schema.py index 4ed09c3a0e0..ae9ccb1f87b 100644 --- a/nncf/config/schema.py +++ b/nncf/config/schema.py @@ -769,6 +769,9 @@ def with_attributes(schema: Dict, **kwargs) -> Dict: description="The mode of sparsity level setting( " "'global' - one sparsity level is set for all layer, " "'local' - sparsity level is set per-layer.)"), + # TODO + # "sparse_structure_by_scopes": with_attributes(make_object_or_array_of_objects_schema(_ARRAY_OF_STRINGS), + # description="specification of sparsity grain size by NNCF scope. "), }, "additionalProperties": False }, diff --git a/nncf/torch/sparsity/movement/algo.py b/nncf/torch/sparsity/movement/algo.py index f6d7179bdc2..8bfbb932487 100644 --- a/nncf/torch/sparsity/movement/algo.py +++ b/nncf/torch/sparsity/movement/algo.py @@ -24,9 +24,10 @@ from nncf.torch.compression_method_api import PTCompressionAlgorithmController from nncf.torch.nncf_network import NNCFNetwork from nncf.torch.sparsity.base_algo import BaseSparsityAlgoBuilder, BaseSparsityAlgoController, SparseModuleInfo -from nncf.torch.sparsity.movement.layers import MovementSparsifyingWeight +from nncf.torch.sparsity.movement.layers import MovementSparsifyingWeight, SparseConfig, SparseStructure from nncf.torch.sparsity.movement.loss import ImportanceLoss, SparseLossForPerLayerSparsity from nncf.torch.utils import get_world_size +from nncf.common.utils.helpers import matches_any from nncf.common.accuracy_aware_training.training_loop import ADAPTIVE_COMPRESSION_CONTROLLERS from nncf.torch.sparsity.collector import PTSparseModelStatisticsCollector from nncf.common.sparsity.schedulers import SPARSITY_SCHEDULERS @@ -38,8 +39,22 @@ @PT_COMPRESSION_ALGORITHMS.register('movement_sparsity') class MovementSparsityBuilder(BaseSparsityAlgoBuilder): def create_weight_sparsifying_operation(self, target_module_node: NNCFNode, compression_lr_multiplier: float): - return MovementSparsifyingWeight(target_module_node.layer_attributes.get_weight_shape(), frozen=False, - compression_lr_multiplier=compression_lr_multiplier) + sparse_cfg=None + if 'sparse_structure_by_scopes' in self._algo_config: + for sparse_mode, sparse_args, regex in self._algo_config['sparse_structure_by_scopes']: + if matches_any(target_module_node.node_name, regex): + sparse_cfg = SparseConfig(sparse_mode, sparse_args) + break + + if sparse_cfg is None: + sparse_cfg = SparseConfig() + + return MovementSparsifyingWeight( + target_module_node.layer_attributes.get_weight_shape(), + frozen=False, + compression_lr_multiplier=compression_lr_multiplier, + eps=1e-6, + sparse_cfg=sparse_cfg) def _build_controller(self, model: NNCFNetwork) -> PTCompressionAlgorithmController: return MovementSparsityController(model, self._sparsified_module_info, self.config) diff --git a/nncf/torch/sparsity/movement/layers.py b/nncf/torch/sparsity/movement/layers.py index 33e3c8ba78f..94f00dfaa6c 100644 --- a/nncf/torch/sparsity/movement/layers.py +++ b/nncf/torch/sparsity/movement/layers.py @@ -18,22 +18,47 @@ from nncf.torch.sparsity.movement.functions 
import binary_mask_by_threshold from nncf.torch.functions import logit from nncf.torch.layer_utils import COMPRESSION_MODULES, CompressionParameter +from enum import Enum +from typing import Dict, List, Optional, Any +from copy import deepcopy +class SparseStructure(str, Enum): + FINE = "fine" + BLOCK = "block" + PER_DIM = "per_dim" + +class SparseConfig: + def __init__(self, mode: SparseStructure = SparseStructure.FINE, sparse_args=None): + self.mode = SparseStructure(mode) + self.sparse_args = sparse_args + self.sparse_factors = None @COMPRESSION_MODULES.register() class MovementSparsifyingWeight(BinaryMask): - def __init__(self, weight_shape: List[int], frozen=True, compression_lr_multiplier=None, eps=1e-6): + def __init__(self, + weight_shape: List[int], + frozen=True, + compression_lr_multiplier=None, + eps=1e-6, + sparse_cfg=None): super().__init__(weight_shape) + self.frozen = frozen self.eps = eps - self.lmbd = 0.5 # module_level_loss_weightage - self.masking_threshold = 0.0 + + self.sparse_cfg = sparse_cfg + self._importance_shape, self._bool_expand_importance = self._get_importance_shape(weight_shape) self._importance = CompressionParameter( - torch.zeros(weight_shape), + torch.zeros(self._importance_shape), requires_grad=not self.frozen, compression_lr_multiplier=compression_lr_multiplier) + + self.lmbd = 0.5 # module_level_loss_weightage + + self.masking_threshold = 0.0 self.binary_mask = binary_mask_by_threshold(self._importance, self._masking_threshold) + self.mask_calculation_hook = MaskCalculationHook(self) @property @@ -64,19 +89,74 @@ def unfreeze_importance(self): self.frozen = False self._importance.requires_grad=True + def extra_repr(self): + return '{}, {}'.format( + self.sparse_cfg.mode, self.sparse_cfg.sparse_args) + + def _get_importance_shape(self, weight_shape): + #TODO:remove weight_shape, r=32, c=32): + # Default to fine_grained sparsity + if self.sparse_cfg is None: + self.sparse_cfg = SparseConfig( + SparseStructure("fine"), + (1,1) + ) + self.sparse_cfg.sparse_factors = (1, 1) + + if self.sparse_cfg.mode == SparseStructure.FINE: + self.sparse_cfg.sparse_factors = (1, 1) + return weight_shape, False + + if self.sparse_cfg.mode == SparseStructure.BLOCK: + r, c = self.sparse_cfg.sparse_args + assert weight_shape[0] % r == 0, "r: {} is not a factor of dim axes 0".format(r) + assert weight_shape[1] % c == 0, "c: {} is not a factor of dim axes 1".format(c) + self.sparse_cfg.sparse_factors = (r, c) + return (weight_shape[0]//r, weight_shape[1]//c), True + + if self.sparse_cfg.mode == SparseStructure.PER_DIM: + if len(self.sparse_cfg.sparse_args) != 1 or not isinstance(self.sparse_cfg.sparse_args[0], int): + raise ValueError("Invalid sparse_arg {}, per_dim expects a single digit that indicates axes".format(self.sparse_cfg.sparse_args)) + + if self.sparse_cfg.sparse_args[0] < 0 or self.sparse_cfg.sparse_args[0] >= len(weight_shape): + raise ValueError("Invalid axes id {}, axes range {}".format( + self.sparse_cfg.sparse_args[0], + list(range(len(weight_shape))))) + self.sparse_cfg.sparse_factors = deepcopy(weight_shape) + self.sparse_cfg.sparse_factors[self.sparse_cfg.sparse_args[0]] = 1 + self.sparse_cfg.sparse_factors = tuple(self.sparse_cfg.sparse_factors) + + score_shape = [] + for axes, (dim, factor) in enumerate(zip(weight_shape, self.sparse_cfg.sparse_factors)): + assert dim % factor == 0, "{} is not a factor of axes {} with dim size {}".format(factor, axes, dim) + score_shape.append(dim//factor) + return score_shape, True + + + def _expand_importance(self, 
importance): + #TODO only works dense layer for now + if self._bool_expand_importance: + return importance.repeat_interleave( + self.sparse_cfg.sparse_factors[0], dim=0).repeat_interleave( + self.sparse_cfg.sparse_factors[1], dim=1) + return importance + def _calc_training_binary_mask(self, weight): if self.training and not self.frozen: - _mask = binary_mask_by_threshold(self._importance, self._masking_threshold) + _mask = binary_mask_by_threshold( + self._expand_importance(self._importance), + self._masking_threshold + ) self.binary_mask = _mask - #TODO: remove - # if (_mask.numel() - _mask.count_nonzero()) > 0: - # print("yay") return _mask else: return self.binary_mask def loss(self): - return self.lmbd * (torch.norm(torch.sigmoid(self._importance), p=1) / self._importance.numel()) + return self.lmbd * (torch.norm( + torch.sigmoid( + self._expand_importance(self._importance) + ), p=1) / self._importance.numel()) class MaskCalculationHook(): @@ -85,7 +165,10 @@ def __init__(self, module): self.hook = module._register_state_dict_hook(self.hook_fn) def hook_fn(self, module, destination, prefix, local_metadata): - module.binary_mask = binary_mask_by_threshold(module.importance, module.masking_threshold) + module.binary_mask = binary_mask_by_threshold( + module._expand_importance(module.importance), + module.masking_threshold + ) destination[prefix + '_binary_mask'] = module.binary_mask return destination From a5f136955d9e46a87422c7ff5fe3ef01f78a08df Mon Sep 17 00:00:00 2001 From: "Chua, Vui Seng" Date: Sat, 29 Jan 2022 14:49:08 -0800 Subject: [PATCH 3/7] Add extraction of structured mask, propagation, mvmt thresholding changes --- nncf/torch/sparsity/movement/algo.py | 190 ++++++++++++++++++++++ nncf/torch/sparsity/movement/functions.py | 12 +- nncf/torch/sparsity/movement/layers.py | 16 +- 3 files changed, 214 insertions(+), 4 deletions(-) diff --git a/nncf/torch/sparsity/movement/algo.py b/nncf/torch/sparsity/movement/algo.py index 8bfbb932487..e45ba6d1340 100644 --- a/nncf/torch/sparsity/movement/algo.py +++ b/nncf/torch/sparsity/movement/algo.py @@ -34,6 +34,12 @@ from nncf.common.schedulers import StubCompressionScheduler from nncf.common.sparsity.statistics import MovementSparsityStatistics from nncf.common.statistics import NNCFStatistics +from nncf.torch.search_building_blocks.search_blocks import get_building_blocks +from collections import namedtuple +from nncf.torch.dynamic_graph.operation_address import OperationAddress +import networkx as nx +from nncf.torch.layers import NNCF_MODULES_OP_NAMES +import os @PT_COMPRESSION_ALGORITHMS.register('movement_sparsity') @@ -82,6 +88,11 @@ def __init__(self, target_model: NNCFNetwork, sparsified_module_info: List[Spars self._scheduler = scheduler_cls(self, params) self._loss = ImportanceLoss(sparsify_operations, self.scheduler) + #TODO: review - perhaps not the right place + self.config = config + self.prunableops_per_group = self._get_group_of_prunable_ops() + self.visualize_groups_of_prunables() + def compression_stage(self) -> CompressionStage: if self._mode == 'local': return CompressionStage.FULLY_COMPRESSED @@ -149,3 +160,182 @@ def statistics(self, quickly_collected_only=False) -> NNCFStatistics: @property def compression_rate(self): return self.statistics().movement_sparsity.model_statistics.sparsity_level + + def _propagate_masks(self): + # nncf_logger.debug("MVMT - Propagating pruning masks") + # 1. 
Propagate masks for all modules + from collections import OrderedDict + sparse_sd = OrderedDict() + with torch.no_grad(): + for sparse_info in self.sparsified_module_info: + for n, m in self.model.named_modules(): + if m == sparse_info.module: + # print(n, 1-sparse_info.operand.binary_mask.count_nonzero()/sparse_info.operand.binary_mask.numel()) + # print("pre", 1-m.weight.count_nonzero()/m.weight.numel()) + # print("mask", 1-sparse_info.operand.binary_mask.count_nonzero()/sparse_info.operand.binary_mask.numel()) + sparse_sd[n+'.weight'] = m.weight*sparse_info.operand.binary_mask + # print("post", 1-sparse_sd[n+'.weight'].count_nonzero()/sparse_sd[n+'.weight'].numel()) + # sd = sparse_info.module.state_dict() + # sd['weight'] = sparse_info.module.weight*sparse_info.operand.binary_mask + # sparse_info.module.load_state_dict(sd) + + model_sd = self.model.state_dict() + for k, v in sparse_sd.items(): + assert k in model_sd, "key not exists!" + model_sd[k] = sparse_sd[k] + self.model.load_state_dict(model_sd) + + # init_output_masks_in_graph(graph, self.pruned_module_groups_info.get_all_nodes()) + # MaskPropagationAlgorithm(graph, PT_PRUNING_OPERATOR_METATYPES).mask_propagation() + + # # 2. Set the masks for Batch/Group Norms + # pruned_node_modules = [] + # for node, pruning_block, node_module in self._pruned_norms_operators: + # if node_module not in pruned_node_modules: + # # Setting masks for BN nodes + # pruning_block.binary_filter_pruning_mask = node.data['output_mask'].tensor + # pruned_node_modules.append(node_module) + + def prepare_for_export(self): + """ + Applies pruning masks to layer weights before exporting the model to ONNX. + """ + self._propagate_masks() + + + def print_prunableops_per_group(self): + for group, op_list in self.prunableops_per_group.items(): + print("= Group {} ======".format(group)) + print('\n'.join(list(map(lambda x: '{:12} | {}'.format(str(list(x.op_mod.weight.shape)), str(x.op_addr)), op_list)))) + + def _get_group_of_prunable_ops(self): + PrunableOp = namedtuple("PrunableOp", "op_addr op_mod") + + building_blocks = get_building_blocks(self.model, allow_nested_blocks=False) + all_node_op_addr_in_blocks = self._get_all_node_op_addresses_in_block(self.model, building_blocks) + + prunableops_per_group = {} + for group_id, nodes_per_block in all_node_op_addr_in_blocks.items(): + prunableops_per_group[group_id] = [] + + for str_op_addr in nodes_per_block: + op_address = OperationAddress.from_str(str_op_addr) + if op_address.operator_name in NNCF_MODULES_OP_NAMES: + + prunableops_per_group[group_id].append( + PrunableOp( + op_address, + self.model.get_module_by_scope(op_address.scope_in_model) + ) + ) + return prunableops_per_group + + def _get_all_node_op_addresses_in_block(self, nncf_network, blocks): + graph = nncf_network.get_original_graph() + all_nodes_per_skipped_block_idxs = {} + for idx, block in enumerate(blocks): + start_node, end_node = block + start_node_key, end_node_key = None, None + for node in graph._nx_graph._node.values(): + if start_node == str(node['node_name']): + start_node_key = node['key'] + if end_node == str(node['node_name']): + end_node_key = node['key'] + simple_paths = nx.all_simple_paths(graph._nx_graph, start_node_key, end_node_key) + all_nodes_in_block = set() + for node_keys_in_path in simple_paths: + for node_key in node_keys_in_path: + all_nodes_in_block.add(str(graph._nx_graph._node[node_key]['node_name'])) + start_op_address = str(graph._nx_graph._node[start_node_key]['node_name']) + 
all_nodes_in_block.remove(start_op_address) + all_nodes_per_skipped_block_idxs[idx] = list(all_nodes_in_block) + return all_nodes_per_skipped_block_idxs + + def visualize_groups_of_prunables(self, path=None): + import networkx as nx + from nncf.torch.graph.graph import PTNNCFGraph + from networkx.drawing.nx_agraph import to_agraph + import matplotlib._color_data as mcd + import matplotlib.pyplot as plt + import numpy as np + palette = np.array(list(mcd.CSS4_COLORS.keys())).reshape(-1, 4).transpose().reshape(-1).tolist() + + from matplotlib.colors import to_hex + palette = np.array([to_hex(c) for c in plt.get_cmap("tab20b").colors]).reshape(-1, 5).transpose().reshape(-1).tolist() + + learnable_node_color_map = dict() + opbook = dict() + + for group_id, op_list in self.prunableops_per_group.items(): + color = palette[group_id % len(palette)] + for op in op_list: + learnable_node_color_map[str(op.op_addr)] = color + opbook[str(op.op_addr)] = op + + building_blocks = get_building_blocks(self.model, allow_nested_blocks=False) + node_op_address_per_block = self._get_all_node_op_addresses_in_block(self.model, building_blocks) + node_color_map = dict() + for group_id, op_list in node_op_address_per_block.items(): + color = palette[group_id % len(palette)] + for op in op_list: + node_color_map[op] = color + + g = self.model.get_graph() + + out_graph = nx.DiGraph() + for node_name, node in g._nx_graph.nodes.items(): + # ia_op_exec_context = node[PTNNCFGraph.IA_OP_EXEC_CONTEXT_NODE_ATTR] + + attrs_node = {} + label = node['key'] + # label = str(node[PTNNCFGraph.ID_NODE_ATTR]) + ' ' + str(ia_op_exec_context) + # if 'conv2d' in label.lower(): + # label = "*prunable*\n" + label + tokens=label.split("/") + new_tokens=[] + for i, token in enumerate(tokens): + if (i+1)%2==0: + token += "\n" + new_tokens.append(token) + attrs_node['label'] = '/'.join(new_tokens) + + if node['node_name'] in node_color_map: + # cluster_id = self.df.cluster_id[self.df.node_name == node_name].values[0] + # attrs_node['label'] += "\n(cluster {})".format(cluster_id) + # mcd.CSS4_COLORS + # attrs_node['color'] = mcd.CSS4_COLORS[node_color_map[node['node_name']]] + + + attrs_node['color'] = node_color_map[node['node_name']] + if node['node_name'] in learnable_node_color_map: + attrs_node['label'] += "\n{}\n".format(str(tuple(opbook[node['node_name']].op_mod.weight.shape))) + attrs_node['style'] = 'filled' + else: + attrs_node['style'] = 'diagonals' + # At present, there are 8 style values recognized: filled , invisible , diagonals , rounded . dashed , dotted , solid and bold + + out_graph.add_node(node_name, **attrs_node) + + for u, v in g._nx_graph.edges: + out_graph.add_edge(u, v, label=g._nx_graph.edges[u, v][PTNNCFGraph.ACTIVATION_SHAPE_EDGE_ATTR]) + + mapping = {k: v["label"] for k, v in out_graph.nodes.items()} + out_graph = nx.relabel_nodes(out_graph, mapping) + for node in out_graph.nodes.values(): + node.pop("label") + + if path is None: + path = 'mvmt_prunableops_group_viz.dot' + path = os.path.join(self.config.get("log_dir", "."), path) + + nx.drawing.nx_pydot.write_dot(out_graph, path) + + try: + A = to_agraph(out_graph) + A.layout('dot') + png_path = os.path.splitext(path)[0]+'.png' + A.draw(png_path) + except ImportError: + print("Graphviz is not installed - only the .dot model visualization format will be used. 
" + "Install pygraphviz into your Python environment and graphviz system-wide to enable " + "PNG rendering.") \ No newline at end of file diff --git a/nncf/torch/sparsity/movement/functions.py b/nncf/torch/sparsity/movement/functions.py index 3611e6c23ad..ac2b3edea26 100644 --- a/nncf/torch/sparsity/movement/functions.py +++ b/nncf/torch/sparsity/movement/functions.py @@ -17,7 +17,13 @@ from nncf.torch.functions import STThreshold -def binary_mask_by_threshold(importance, threshold=0.5, sigmoid=True): +def binary_mask_by_threshold(importance, threshold=0.5, sigmoid=True, max_percentile=0.98): + with torch.no_grad(): + if sigmoid is True: + max_threshold = torch.quantile(torch.sigmoid(importance), q=max_percentile).item() + else: + max_threshold = torch.quantile(importance, q=max_percentile).item() + if sigmoid is True: - return STThreshold.apply(torch.sigmoid(importance), threshold) - return STThreshold.apply(importance, threshold) \ No newline at end of file + return STThreshold.apply(torch.sigmoid(importance), min(threshold, max_threshold)) + return STThreshold.apply(importance, min(threshold, max_threshold)) \ No newline at end of file diff --git a/nncf/torch/sparsity/movement/layers.py b/nncf/torch/sparsity/movement/layers.py index 94f00dfaa6c..db17f58f813 100644 --- a/nncf/torch/sparsity/movement/layers.py +++ b/nncf/torch/sparsity/movement/layers.py @@ -22,6 +22,10 @@ from typing import Dict, List, Optional, Any from copy import deepcopy +from torch.nn.modules import sparse +import itertools as it +import numpy as np + class SparseStructure(str, Enum): FINE = "fine" BLOCK = "block" @@ -132,7 +136,6 @@ def _get_importance_shape(self, weight_shape): score_shape.append(dim//factor) return score_shape, True - def _expand_importance(self, importance): #TODO only works dense layer for now if self._bool_expand_importance: @@ -158,6 +161,17 @@ def loss(self): self._expand_importance(self._importance) ), p=1) / self._importance.numel()) + def get_structured_mask(self, grain_size=None): + if grain_size is None: + grain_size = self.sparse_cfg.sparse_factors + + structured_mask_shape = [dim//grain_size[axes] for axes, dim in enumerate(list(self.binary_mask.shape))] + temp_shape = list(it.chain(*zip(list(structured_mask_shape), list(grain_size)))) + structured_mask = self.binary_mask.detach().clone() + structured_mask = structured_mask.reshape(temp_shape) + structured_mask = structured_mask.amax(dim=(tuple((np.arange(len(self.binary_mask.shape)) * 2 + 1)))) + # print("Mask Shape from {} to {}".format(structured_mask.shape, self.binary_mask.shape)) + return structured_mask class MaskCalculationHook(): def __init__(self, module): From 8ead5bc4758273d94a7053a4ef2cd70cdc7d746a Mon Sep 17 00:00:00 2001 From: Vui Seng Chua Date: Sat, 19 Feb 2022 15:51:59 -0800 Subject: [PATCH 4/7] Resolve conflicts of rebasing to nncf/develop and refactor MovementSparsifyingWeight to MovementSparsifier for enabling bias pruning later --- nncf/common/graph/layer_attributes.py | 7 ++- nncf/torch/dynamic_graph/wrappers.py | 3 +- nncf/torch/sparsity/movement/algo.py | 10 ++-- nncf/torch/sparsity/movement/layers.py | 64 ++++++++++++++++---------- 4 files changed, 53 insertions(+), 31 deletions(-) diff --git a/nncf/common/graph/layer_attributes.py b/nncf/common/graph/layer_attributes.py index b06a3d38c40..0017e1aa31a 100644 --- a/nncf/common/graph/layer_attributes.py +++ b/nncf/common/graph/layer_attributes.py @@ -91,14 +91,19 @@ class LinearLayerAttributes(WeightedLayerAttributes): def __init__(self, weight_requires_grad: 
bool, in_features: int, - out_features: int): + out_features: int, + bias: bool): super().__init__(weight_requires_grad) self.in_features = in_features self.out_features = out_features + self.bias = bias def get_weight_shape(self) -> List[int]: return [self.out_features, self.in_features] + def get_bias_shape(self) -> int: + return self.out_features if self.bias is True else 0 + def get_target_dim_for_compression(self) -> int: return 0 diff --git a/nncf/torch/dynamic_graph/wrappers.py b/nncf/torch/dynamic_graph/wrappers.py index fe62409a260..a572ff17ff9 100644 --- a/nncf/torch/dynamic_graph/wrappers.py +++ b/nncf/torch/dynamic_graph/wrappers.py @@ -192,7 +192,8 @@ def _get_layer_attributes(module: TorchModule, operator_name: str) -> BaseLayerA if isinstance(module, Linear): return LinearLayerAttributes(weight_requires_grad=module.weight.requires_grad, in_features=module.in_features, - out_features=module.out_features) + out_features=module.out_features, + bias=module.bias is not None) if hasattr(module, 'weight'): return GenericWeightedLayerAttributes(weight_requires_grad=module.weight.requires_grad, diff --git a/nncf/torch/sparsity/movement/algo.py b/nncf/torch/sparsity/movement/algo.py index e45ba6d1340..dd70ad8f2b2 100644 --- a/nncf/torch/sparsity/movement/algo.py +++ b/nncf/torch/sparsity/movement/algo.py @@ -24,7 +24,7 @@ from nncf.torch.compression_method_api import PTCompressionAlgorithmController from nncf.torch.nncf_network import NNCFNetwork from nncf.torch.sparsity.base_algo import BaseSparsityAlgoBuilder, BaseSparsityAlgoController, SparseModuleInfo -from nncf.torch.sparsity.movement.layers import MovementSparsifyingWeight, SparseConfig, SparseStructure +from nncf.torch.sparsity.movement.layers import MovementSparsifier, SparseConfig, SparseStructure from nncf.torch.sparsity.movement.loss import ImportanceLoss, SparseLossForPerLayerSparsity from nncf.torch.utils import get_world_size from nncf.common.utils.helpers import matches_any @@ -34,7 +34,7 @@ from nncf.common.schedulers import StubCompressionScheduler from nncf.common.sparsity.statistics import MovementSparsityStatistics from nncf.common.statistics import NNCFStatistics -from nncf.torch.search_building_blocks.search_blocks import get_building_blocks +from nncf.experimental.torch.search_building_blocks.search_blocks import BuildingBlock, get_building_blocks from collections import namedtuple from nncf.torch.dynamic_graph.operation_address import OperationAddress import networkx as nx @@ -55,7 +55,7 @@ def create_weight_sparsifying_operation(self, target_module_node: NNCFNode, comp if sparse_cfg is None: sparse_cfg = SparseConfig() - return MovementSparsifyingWeight( + return MovementSparsifier( target_module_node.layer_attributes.get_weight_shape(), frozen=False, compression_lr_multiplier=compression_lr_multiplier, @@ -91,7 +91,7 @@ def __init__(self, target_model: NNCFNetwork, sparsified_module_info: List[Spars #TODO: review - perhaps not the right place self.config = config self.prunableops_per_group = self._get_group_of_prunable_ops() - self.visualize_groups_of_prunables() + # self.visualize_groups_of_prunables() def compression_stage(self) -> CompressionStage: if self._mode == 'local': @@ -234,7 +234,7 @@ def _get_all_node_op_addresses_in_block(self, nncf_network, blocks): graph = nncf_network.get_original_graph() all_nodes_per_skipped_block_idxs = {} for idx, block in enumerate(blocks): - start_node, end_node = block + start_node, end_node = block.start_node, block.end_node start_node_key, end_node_key = None, None 
for node in graph._nx_graph._node.values(): if start_node == str(node['node_name']): diff --git a/nncf/torch/sparsity/movement/layers.py b/nncf/torch/sparsity/movement/layers.py index db17f58f813..df06d96f0c1 100644 --- a/nncf/torch/sparsity/movement/layers.py +++ b/nncf/torch/sparsity/movement/layers.py @@ -23,8 +23,11 @@ from copy import deepcopy from torch.nn.modules import sparse +from torch import nn import itertools as it import numpy as np +from nncf.torch.sparsity.functions import apply_binary_mask as apply_binary_mask_impl +from nncf.torch.utils import is_tracing_state, no_jit_trace class SparseStructure(str, Enum): FINE = "fine" @@ -39,35 +42,38 @@ def __init__(self, mode: SparseStructure = SparseStructure.FINE, sparse_args=Non @COMPRESSION_MODULES.register() -class MovementSparsifyingWeight(BinaryMask): +class MovementSparsifier(nn.Module): def __init__(self, weight_shape: List[int], frozen=True, compression_lr_multiplier=None, eps=1e-6, sparse_cfg=None): - super().__init__(weight_shape) + super().__init__() self.frozen = frozen self.eps = eps + self.weight_ctx = BinaryMask(weight_shape) self.sparse_cfg = sparse_cfg - self._importance_shape, self._bool_expand_importance = self._get_importance_shape(weight_shape) - self._importance = CompressionParameter( - torch.zeros(self._importance_shape), + + self._weight_importance_shape, self._bool_expand_weight_importance = self._get_importance_shape(weight_shape) + self._weight_importance = CompressionParameter( + # torch.rand(self._weight_importance_shape), + torch.zeros(self._weight_importance_shape), requires_grad=not self.frozen, compression_lr_multiplier=compression_lr_multiplier) self.lmbd = 0.5 # module_level_loss_weightage self.masking_threshold = 0.0 - self.binary_mask = binary_mask_by_threshold(self._importance, self._masking_threshold) + self.weight_ctx.binary_mask = binary_mask_by_threshold(self._weight_importance, self._masking_threshold) - self.mask_calculation_hook = MaskCalculationHook(self) + self.weight_ctx_mask_calculation_hook = MaskCalculationHook(self) @property def importance(self): - return self._importance.data + return self._weight_importance.data @property def masking_threshold(self): @@ -87,16 +93,26 @@ def lmbd(self, module_level_loss_weightage): def freeze_importance(self): self.frozen = True - self._importance.requires_grad=False + self._weight_importance.requires_grad=False def unfreeze_importance(self): self.frozen = False - self._importance.requires_grad=True + self._weight_importance.requires_grad=True def extra_repr(self): return '{}, {}'.format( self.sparse_cfg.mode, self.sparse_cfg.sparse_args) + def forward(self, weight): + if is_tracing_state(): + with no_jit_trace(): + return weight.mul_(self.binary_mask) + tmp_tensor = self._calc_training_binary_mask(weight) + return apply_binary_mask_impl(tmp_tensor, weight) + + def apply_binary_mask(self, weight): + return self.weight_ctx.apply_binary_mask(weight) + def _get_importance_shape(self, weight_shape): #TODO:remove weight_shape, r=32, c=32): # Default to fine_grained sparsity @@ -120,10 +136,10 @@ def _get_importance_shape(self, weight_shape): if self.sparse_cfg.mode == SparseStructure.PER_DIM: if len(self.sparse_cfg.sparse_args) != 1 or not isinstance(self.sparse_cfg.sparse_args[0], int): - raise ValueError("Invalid sparse_arg {}, per_dim expects a single digit that indicates axes".format(self.sparse_cfg.sparse_args)) + raise ValueError("Invalid sparse_arg {}, per_dim expects a single digit that indicates axis".format(self.sparse_cfg.sparse_args)) if 
self.sparse_cfg.sparse_args[0] < 0 or self.sparse_cfg.sparse_args[0] >= len(weight_shape): - raise ValueError("Invalid axes id {}, axes range {}".format( + raise ValueError("Invalid axis id {}, axes range {}".format( self.sparse_cfg.sparse_args[0], list(range(len(weight_shape))))) self.sparse_cfg.sparse_factors = deepcopy(weight_shape) @@ -138,7 +154,7 @@ def _get_importance_shape(self, weight_shape): def _expand_importance(self, importance): #TODO only works dense layer for now - if self._bool_expand_importance: + if self._bool_expand_weight_importance: return importance.repeat_interleave( self.sparse_cfg.sparse_factors[0], dim=0).repeat_interleave( self.sparse_cfg.sparse_factors[1], dim=1) @@ -147,30 +163,30 @@ def _expand_importance(self, importance): def _calc_training_binary_mask(self, weight): if self.training and not self.frozen: _mask = binary_mask_by_threshold( - self._expand_importance(self._importance), + self._expand_importance(self._weight_importance), self._masking_threshold ) - self.binary_mask = _mask + self.weight_ctx.binary_mask = _mask return _mask else: - return self.binary_mask + return self.weight_ctx.binary_mask def loss(self): return self.lmbd * (torch.norm( torch.sigmoid( - self._expand_importance(self._importance) - ), p=1) / self._importance.numel()) + self._expand_importance(self._weight_importance) + ), p=1) / self._weight_importance.numel()) def get_structured_mask(self, grain_size=None): if grain_size is None: grain_size = self.sparse_cfg.sparse_factors - structured_mask_shape = [dim//grain_size[axes] for axes, dim in enumerate(list(self.binary_mask.shape))] + structured_mask_shape = [dim//grain_size[axes] for axes, dim in enumerate(list(self.weight_ctx.binary_mask.shape))] temp_shape = list(it.chain(*zip(list(structured_mask_shape), list(grain_size)))) - structured_mask = self.binary_mask.detach().clone() + structured_mask = self.weight_ctx.binary_mask.detach().clone() structured_mask = structured_mask.reshape(temp_shape) - structured_mask = structured_mask.amax(dim=(tuple((np.arange(len(self.binary_mask.shape)) * 2 + 1)))) - # print("Mask Shape from {} to {}".format(structured_mask.shape, self.binary_mask.shape)) + structured_mask = structured_mask.amax(dim=(tuple((np.arange(len(self.weight_ctx.binary_mask.shape)) * 2 + 1)))) + # print("Mask Shape from {} to {}".format(structured_mask.shape, self.weight_ctx.binary_mask.shape)) return structured_mask class MaskCalculationHook(): @@ -179,11 +195,11 @@ def __init__(self, module): self.hook = module._register_state_dict_hook(self.hook_fn) def hook_fn(self, module, destination, prefix, local_metadata): - module.binary_mask = binary_mask_by_threshold( + module.weight_ctx.binary_mask = binary_mask_by_threshold( module._expand_importance(module.importance), module.masking_threshold ) - destination[prefix + '_binary_mask'] = module.binary_mask + destination[prefix + 'weight_ctx._binary_mask'] = module.weight_ctx.binary_mask return destination def close(self): From 0c44bba1291af1a1275e1fca9838a712a6dac422 Mon Sep 17 00:00:00 2001 From: Vui Seng Chua Date: Sun, 20 Feb 2022 08:02:15 -0800 Subject: [PATCH 5/7] Enable bias pruning with mvmt algo, breaking changes to be overcome --- nncf/torch/nncf_network.py | 6 +- nncf/torch/sparsity/collector.py | 3 +- nncf/torch/sparsity/movement/algo.py | 2 +- nncf/torch/sparsity/movement/layers.py | 110 +++++++++++++++++-------- 4 files changed, 81 insertions(+), 40 deletions(-) diff --git a/nncf/torch/nncf_network.py b/nncf/torch/nncf_network.py index f3e89f59235..b3f2c142fcb 
100644 --- a/nncf/torch/nncf_network.py +++ b/nncf/torch/nncf_network.py @@ -66,7 +66,7 @@ from nncf.torch.knowledge_distillation.knowledge_distillation_handler import KnowledgeDistillationLossHandler from nncf.torch.layers import NNCF_MODULES from nncf.torch.layers import NNCF_WRAPPED_USER_MODULES_DICT -from nncf.torch.module_operations import UpdateWeight +from nncf.torch.module_operations import UpdateWeight, UpdateWeightAndBias from nncf.torch.quantization.layers import QUANTIZATION_MODULES from nncf.torch.utils import compute_FLOPs_hook from nncf.torch.utils import get_all_modules_by_type @@ -707,7 +707,9 @@ def transform(self, transformation_layout: PTTransformationLayout) -> NNCFNetwor input_port_id=target_point.input_port_id) fn = transformation_command.fn if target_point.type is TargetType.OPERATION_WITH_WEIGHTS: - fn = UpdateWeight(fn) + # TODO: how to set this according + # fn = UpdateWeight(fn) + fn = UpdateWeightAndBias(fn) tup = (fn, transformation_command.priority) if pt_ip not in fns_grouped_by_points: fns_grouped_by_points[pt_ip] = [tup] diff --git a/nncf/torch/sparsity/collector.py b/nncf/torch/sparsity/collector.py index fa40bb0f1f6..175406a48d8 100644 --- a/nncf/torch/sparsity/collector.py +++ b/nncf/torch/sparsity/collector.py @@ -53,9 +53,10 @@ def _collect_weights_descriptions(self) -> List[WeightDescription]: if hasattr(minfo.module, 'bias') and minfo.module.bias is not None: bias = minfo.module.bias + sparse_bias = minfo.operand.apply_binary_mask(bias, isbias=True) #TODO: breaking changes name = f'{minfo.module_node_name}/bias' weights_descriptions.append( - WeightDescription(name, list(bias.shape), bias.count_nonzero().item(), is_sparse=False) + WeightDescription(name, list(bias.shape), sparse_bias.count_nonzero().item(), is_sparse=True) ) processed_modules.append(minfo.module) diff --git a/nncf/torch/sparsity/movement/algo.py b/nncf/torch/sparsity/movement/algo.py index dd70ad8f2b2..fd9402e1e5c 100644 --- a/nncf/torch/sparsity/movement/algo.py +++ b/nncf/torch/sparsity/movement/algo.py @@ -56,7 +56,7 @@ def create_weight_sparsifying_operation(self, target_module_node: NNCFNode, comp sparse_cfg = SparseConfig() return MovementSparsifier( - target_module_node.layer_attributes.get_weight_shape(), + target_module_node, frozen=False, compression_lr_multiplier=compression_lr_multiplier, eps=1e-6, diff --git a/nncf/torch/sparsity/movement/layers.py b/nncf/torch/sparsity/movement/layers.py index df06d96f0c1..093c42f762f 100644 --- a/nncf/torch/sparsity/movement/layers.py +++ b/nncf/torch/sparsity/movement/layers.py @@ -44,32 +44,43 @@ def __init__(self, mode: SparseStructure = SparseStructure.FINE, sparse_args=Non @COMPRESSION_MODULES.register() class MovementSparsifier(nn.Module): def __init__(self, - weight_shape: List[int], + target_module_node, frozen=True, compression_lr_multiplier=None, eps=1e-6, sparse_cfg=None): super().__init__() + self.prune_bias = target_module_node.layer_attributes.bias + self.frozen = frozen self.eps = eps + self.lmbd = 0.5 # module_level_loss_weightage + self.masking_threshold = 0.0 + self.sparse_cfg = sparse_cfg + weight_shape = target_module_node.layer_attributes.get_weight_shape() self.weight_ctx = BinaryMask(weight_shape) - self.sparse_cfg = sparse_cfg - - self._weight_importance_shape, self._bool_expand_weight_importance = self._get_importance_shape(weight_shape) + self._weight_importance_shape, self._bool_expand_importance = self._get_importance_shape(weight_shape) self._weight_importance = CompressionParameter( # 
torch.rand(self._weight_importance_shape), torch.zeros(self._weight_importance_shape), requires_grad=not self.frozen, compression_lr_multiplier=compression_lr_multiplier) - - self.lmbd = 0.5 # module_level_loss_weightage - - self.masking_threshold = 0.0 self.weight_ctx.binary_mask = binary_mask_by_threshold(self._weight_importance, self._masking_threshold) - self.weight_ctx_mask_calculation_hook = MaskCalculationHook(self) + if self.prune_bias is True: + bias_shape = target_module_node.layer_attributes.get_bias_shape() + self.bias_ctx = BinaryMask(bias_shape) + self._bias_importance_shape = self._weight_importance_shape[0] + self._bias_importance = CompressionParameter( + # torch.rand(self._bias_importance_shape), + torch.zeros(self._bias_importance_shape), + requires_grad=not self.frozen, + compression_lr_multiplier=compression_lr_multiplier) + self.bias_ctx.binary_mask = binary_mask_by_threshold(self._bias_importance, self._masking_threshold) + + self.mask_calculation_hook = MaskCalculationHook(self) @property def importance(self): @@ -94,24 +105,51 @@ def lmbd(self, module_level_loss_weightage): def freeze_importance(self): self.frozen = True self._weight_importance.requires_grad=False + if self.prune_bias is True: + self._bias_importance.requires_grad=False def unfreeze_importance(self): self.frozen = False self._weight_importance.requires_grad=True + if self.prune_bias is True: + self._bias_importance.requires_grad=True + def extra_repr(self): return '{}, {}'.format( self.sparse_cfg.mode, self.sparse_cfg.sparse_args) - def forward(self, weight): + def forward(self, weight, bias): if is_tracing_state(): with no_jit_trace(): return weight.mul_(self.binary_mask) - tmp_tensor = self._calc_training_binary_mask(weight) - return apply_binary_mask_impl(tmp_tensor, weight) + tmp_wtensor, tmp_btensor = self._calc_training_binary_mask(weight, bias) + wtensor = apply_binary_mask_impl(tmp_wtensor, weight) + btensor = apply_binary_mask_impl(tmp_btensor, bias) + return wtensor, btensor - def apply_binary_mask(self, weight): - return self.weight_ctx.apply_binary_mask(weight) + def _calc_training_binary_mask(self, weight, bias): + if self.training and not self.frozen: + w_mask = binary_mask_by_threshold( + self._expand_importance(self._weight_importance), + self._masking_threshold + ) + self.weight_ctx.binary_mask = w_mask + + b_mask = binary_mask_by_threshold( + self._expand_importance(self._bias_importance, isbias=True), + self._masking_threshold + ) + self.bias_ctx.binary_mask = b_mask + return w_mask, b_mask + else: + return self.weight_ctx.binary_mask, self.bias_ctx.binary_mask + + + def apply_binary_mask(self, param_tensor, isbias=False): + if isbias is True: + return self.bias_ctx.apply_binary_mask(param_tensor) + return self.weight_ctx.apply_binary_mask(param_tensor) def _get_importance_shape(self, weight_shape): #TODO:remove weight_shape, r=32, c=32): @@ -152,30 +190,23 @@ def _get_importance_shape(self, weight_shape): score_shape.append(dim//factor) return score_shape, True - def _expand_importance(self, importance): + def _expand_importance(self, importance, isbias=False): #TODO only works dense layer for now - if self._bool_expand_weight_importance: - return importance.repeat_interleave( - self.sparse_cfg.sparse_factors[0], dim=0).repeat_interleave( - self.sparse_cfg.sparse_factors[1], dim=1) + if self._bool_expand_importance: + if isbias is False: + return importance.repeat_interleave( + self.sparse_cfg.sparse_factors[0], dim=0).repeat_interleave( + self.sparse_cfg.sparse_factors[1], 
dim=1) + else: + return importance.repeat_interleave( + self.sparse_cfg.sparse_factors[0], dim=0) return importance - def _calc_training_binary_mask(self, weight): - if self.training and not self.frozen: - _mask = binary_mask_by_threshold( - self._expand_importance(self._weight_importance), - self._masking_threshold - ) - self.weight_ctx.binary_mask = _mask - return _mask - else: - return self.weight_ctx.binary_mask - def loss(self): - return self.lmbd * (torch.norm( - torch.sigmoid( - self._expand_importance(self._weight_importance) - ), p=1) / self._weight_importance.numel()) + return self.lmbd * ( + torch.norm(torch.sigmoid(self._expand_importance(self._weight_importance)), p=1) / self._weight_importance.numel() + \ + torch.norm(torch.sigmoid(self._expand_importance(self._bias_importance, isbias=True)), p=1) / self._bias_importance.numel() + ) def get_structured_mask(self, grain_size=None): if grain_size is None: @@ -196,11 +227,18 @@ def __init__(self, module): def hook_fn(self, module, destination, prefix, local_metadata): module.weight_ctx.binary_mask = binary_mask_by_threshold( - module._expand_importance(module.importance), + module._expand_importance(module._weight_importance), module.masking_threshold ) destination[prefix + 'weight_ctx._binary_mask'] = module.weight_ctx.binary_mask + + if module.prune_bias is True: + module.bias_ctx.binary_mask = binary_mask_by_threshold( + module._expand_importance(module._bias_importance, isbias=True), + module.masking_threshold + ) + destination[prefix + 'bias_ctx._binary_mask'] = module.bias_ctx.binary_mask return destination def close(self): - self.hook.remove() + self.hook.remove() \ No newline at end of file From a11c72d37f03c21140cf8a57ce6d7dd25f552d1a Mon Sep 17 00:00:00 2001 From: Vui Seng Chua Date: Sun, 20 Feb 2022 21:33:10 -0800 Subject: [PATCH 6/7] Enable composite mvmt sparsity and quantization, enable onnx generation with mask burnt into state dict --- nncf/common/graph/transformations/commands.py | 1 + nncf/torch/exporter.py | 12 ++++ nncf/torch/graph/transformations/commands.py | 3 +- nncf/torch/nncf_network.py | 5 +- nncf/torch/sparsity/movement/algo.py | 63 ++++++++++++++++++- nncf/torch/sparsity/movement/functions.py | 2 +- nncf/torch/sparsity/movement/layers.py | 38 ++++++----- 7 files changed, 103 insertions(+), 21 deletions(-) diff --git a/nncf/common/graph/transformations/commands.py b/nncf/common/graph/transformations/commands.py index ffb4a040a53..c243694e51e 100644 --- a/nncf/common/graph/transformations/commands.py +++ b/nncf/common/graph/transformations/commands.py @@ -82,6 +82,7 @@ class TargetType(OrderedEnum): OPERATION_WITH_WEIGHTS = 5 OPERATOR_PRE_HOOK = 6 OPERATOR_POST_HOOK = 7 + OPERATION_WITH_WEIGHT_WT_BIAS = 8 def get_state(self) -> Dict[str, Any]: """ diff --git a/nncf/torch/exporter.py b/nncf/torch/exporter.py index 386428225a6..03c272dc6b8 100644 --- a/nncf/torch/exporter.py +++ b/nncf/torch/exporter.py @@ -114,11 +114,23 @@ def _export_to_onnx(self, save_path: str) -> None: retval = dummy_forward(self._model) output_names = generate_output_names_list(count_tensors(retval)) + import os + torch.onnx.export(model, tuple(input_tensor_list), os.path.join(os.path.dirname(save_path), "graph-only."+os.path.basename(save_path)), + export_params=False, + input_names=input_names, + output_names=output_names, + enable_onnx_checker=False, + opset_version=10, + do_constant_folding=False, + # Do not fuse Conv+BN in ONNX. May cause dropout elements to appear in ONNX. 
+ training=True) + torch.onnx.export(model, tuple(input_tensor_list), save_path, input_names=input_names, output_names=output_names, enable_onnx_checker=False, opset_version=10, + do_constant_folding=False, # Do not fuse Conv+BN in ONNX. May cause dropout elements to appear in ONNX. training=True) model.enable_dynamic_graph_building() diff --git a/nncf/torch/graph/transformations/commands.py b/nncf/torch/graph/transformations/commands.py index a29e93f903a..cb24bcab506 100644 --- a/nncf/torch/graph/transformations/commands.py +++ b/nncf/torch/graph/transformations/commands.py @@ -19,7 +19,8 @@ class PTTargetPointStateNames: class PTTargetPoint(TargetPoint): _OPERATION_TYPES = [TargetType.PRE_LAYER_OPERATION, TargetType.POST_LAYER_OPERATION, - TargetType.OPERATION_WITH_WEIGHTS] + TargetType.OPERATION_WITH_WEIGHTS, + TargetType.OPERATION_WITH_WEIGHT_WT_BIAS] _HOOK_TYPES = [TargetType.OPERATOR_PRE_HOOK, TargetType.OPERATOR_POST_HOOK] diff --git a/nncf/torch/nncf_network.py b/nncf/torch/nncf_network.py index b3f2c142fcb..152da7e88ba 100644 --- a/nncf/torch/nncf_network.py +++ b/nncf/torch/nncf_network.py @@ -120,6 +120,7 @@ class PTInsertionPoint: TargetType.PRE_LAYER_OPERATION: PTInsertionType.NNCF_MODULE_PRE_OP, TargetType.POST_LAYER_OPERATION: PTInsertionType.NNCF_MODULE_POST_OP, TargetType.OPERATION_WITH_WEIGHTS: PTInsertionType.NNCF_MODULE_PRE_OP, + TargetType.OPERATION_WITH_WEIGHT_WT_BIAS: PTInsertionType.NNCF_MODULE_PRE_OP, TargetType.OPERATOR_PRE_HOOK: PTInsertionType.OPERATOR_PRE_HOOK, TargetType.OPERATOR_POST_HOOK: PTInsertionType.OPERATOR_POST_HOOK } @@ -707,8 +708,8 @@ def transform(self, transformation_layout: PTTransformationLayout) -> NNCFNetwor input_port_id=target_point.input_port_id) fn = transformation_command.fn if target_point.type is TargetType.OPERATION_WITH_WEIGHTS: - # TODO: how to set this according - # fn = UpdateWeight(fn) + fn = UpdateWeight(fn) + elif target_point.type is TargetType.OPERATION_WITH_WEIGHT_WT_BIAS: fn = UpdateWeightAndBias(fn) tup = (fn, transformation_command.priority) if pt_ip not in fns_grouped_by_points: diff --git a/nncf/torch/sparsity/movement/algo.py b/nncf/torch/sparsity/movement/algo.py index fd9402e1e5c..30f761f04e1 100644 --- a/nncf/torch/sparsity/movement/algo.py +++ b/nncf/torch/sparsity/movement/algo.py @@ -21,12 +21,17 @@ from nncf.torch.algo_selector import PT_COMPRESSION_ALGORITHMS from nncf.api.compression import CompressionStage from nncf.common.graph import NNCFNode +from nncf.common.graph.transformations.commands import TargetType +from nncf.common.utils.logger import logger as nncf_logger from nncf.torch.compression_method_api import PTCompressionAlgorithmController from nncf.torch.nncf_network import NNCFNetwork from nncf.torch.sparsity.base_algo import BaseSparsityAlgoBuilder, BaseSparsityAlgoController, SparseModuleInfo +from nncf.torch.graph.transformations.commands import PTInsertionCommand +from nncf.torch.graph.transformations.commands import PTTargetPoint +from nncf.torch.graph.transformations.commands import TransformationPriority from nncf.torch.sparsity.movement.layers import MovementSparsifier, SparseConfig, SparseStructure from nncf.torch.sparsity.movement.loss import ImportanceLoss, SparseLossForPerLayerSparsity -from nncf.torch.utils import get_world_size +from nncf.torch.utils import get_world_size, get_model_device from nncf.common.utils.helpers import matches_any from nncf.common.accuracy_aware_training.training_loop import ADAPTIVE_COMPRESSION_CONTROLLERS from nncf.torch.sparsity.collector import 
PTSparseModelStatisticsCollector @@ -44,6 +49,34 @@ @PT_COMPRESSION_ALGORITHMS.register('movement_sparsity') class MovementSparsityBuilder(BaseSparsityAlgoBuilder): + def _sparsify_weights(self, target_model: NNCFNetwork) -> List[PTInsertionCommand]: + device = get_model_device(target_model) + sparsified_module_nodes = target_model.get_weighted_original_graph_nodes( + nncf_module_names=self.compressed_nncf_module_names) + insertion_commands = [] + for module_node in sparsified_module_nodes: + node_name = module_node.node_name + + if not self._should_consider_scope(node_name): + nncf_logger.info("Ignored adding Weight Sparsifier in scope: {}".format(node_name)) + continue + + nncf_logger.info("Adding Weight Sparsifier in scope: {}".format(node_name)) + compression_lr_multiplier = \ + self.config.get_redefinable_global_param_value_for_algo('compression_lr_multiplier', + self.name) + operation = self.create_weight_sparsifying_operation(module_node, compression_lr_multiplier) + hook = operation.to(device) + # TODO: hardcoded to OPERATION_WITH_WEIGHT_WT_BIAS + insertion_commands.append(PTInsertionCommand(PTTargetPoint(TargetType.OPERATION_WITH_WEIGHT_WT_BIAS, + target_node_name=node_name), + hook, TransformationPriority.SPARSIFICATION_PRIORITY)) + sparsified_module = target_model.get_containing_module(node_name) + self._sparsified_module_info.append( + SparseModuleInfo(node_name, sparsified_module, hook)) + + return insertion_commands + def create_weight_sparsifying_operation(self, target_module_node: NNCFNode, compression_lr_multiplier: float): sparse_cfg=None if 'sparse_structure_by_scopes' in self._algo_config: @@ -202,6 +235,34 @@ def prepare_for_export(self): """ self._propagate_masks() + def _propagate_masks(self): + def calc_sparsity(tensor): + return 1-tensor.count_nonzero()/tensor.numel() + # nncf_logger.debug("MVMT - Propagating pruning masks") + # 1. Propagate masks for all modules + from collections import OrderedDict + sparse_sd = OrderedDict() + with torch.no_grad(): + for sparse_info in self.sparsified_module_info: + for n, m in self.model.named_modules(): + if m == sparse_info.module: + # print("- SparseModule: {} -".format(n)) + # print("\tw_mask sparsity: {:.3f}".format(calc_sparsity(sparse_info.operand.weight_ctx.binary_mask))) + # print("\tw_sd sparsity: {:.3f}".format(calc_sparsity(m.weight))) + sparse_sd[n+'.weight'] = sparse_info.operand.apply_binary_mask(m.weight) + # print("\t*w_sd sparsity: {:.3f}".format(calc_sparsity(sparse_sd[n+'.weight']))) + + if hasattr(m, 'bias'): + # print("\tb_mask sparsity: {:.3f}".format(calc_sparsity(sparse_info.operand.bias_ctx.binary_mask))) + # print("\tb_sd sparsity: {:.3f}".format(calc_sparsity(m.bias))) + sparse_sd[n+'.bias'] = sparse_info.operand.apply_binary_mask(m.bias, isbias=True) + # print("\t*w_sd sparsity: {:.3f}".format(calc_sparsity(sparse_sd[n+'.bias']))) + + model_sd = self.model.state_dict() + for k, v in sparse_sd.items(): + assert k in model_sd, "key not exists!" 
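The moving parts above (trainable importance scores for weight and bias, a scheduler-driven masking threshold, and mask burn-in into the state dict before export) are spread across several hunks. The following is a minimal standalone sketch of the same idea in plain PyTorch; class and function names are illustrative, and the straight-through gradient trick used by the real STThreshold function is omitted.

import torch
from torch import nn

def binary_mask_by_threshold(importance: torch.Tensor, threshold: float) -> torch.Tensor:
    # Importance passes through a sigmoid and is compared with the threshold;
    # the mask itself carries no gradient in this simplified version.
    with torch.no_grad():
        return (torch.sigmoid(importance) > threshold).float()

class ToyMovementSparsifier(nn.Module):
    # Illustrative stand-in for the weight-and-bias sparsifier in the patch.
    def __init__(self, out_features: int, in_features: int, prune_bias: bool = True):
        super().__init__()
        self.weight_importance = nn.Parameter(torch.zeros(out_features, in_features))
        self.bias_importance = nn.Parameter(torch.zeros(out_features)) if prune_bias else None
        self.masking_threshold = 0.0

    def forward(self, weight: torch.Tensor, bias: torch.Tensor):
        w_mask = binary_mask_by_threshold(self.weight_importance, self.masking_threshold)
        masked_weight = weight * w_mask
        if self.bias_importance is None or bias is None:
            return masked_weight, bias
        b_mask = binary_mask_by_threshold(self.bias_importance, self.masking_threshold)
        return masked_weight, bias * b_mask

# Mask burn-in before export: multiply the real parameters by their masks so the
# zeros end up in the state dict (the role _propagate_masks plays above).
layer = nn.Linear(4, 2)
sparsifier = ToyMovementSparsifier(out_features=2, in_features=4)
with torch.no_grad():
    w, b = sparsifier(layer.weight, layer.bias)
    layer.weight.copy_(w)
    layer.bias.copy_(b)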
+ model_sd[k] = sparse_sd[k] + self.model.load_state_dict(model_sd) def print_prunableops_per_group(self): for group, op_list in self.prunableops_per_group.items(): diff --git a/nncf/torch/sparsity/movement/functions.py b/nncf/torch/sparsity/movement/functions.py index ac2b3edea26..97db5375d7b 100644 --- a/nncf/torch/sparsity/movement/functions.py +++ b/nncf/torch/sparsity/movement/functions.py @@ -16,7 +16,7 @@ from nncf.torch.dynamic_graph.patch_pytorch import register_operator from nncf.torch.functions import STThreshold - +@register_operator() def binary_mask_by_threshold(importance, threshold=0.5, sigmoid=True, max_percentile=0.98): with torch.no_grad(): if sigmoid is True: diff --git a/nncf/torch/sparsity/movement/layers.py b/nncf/torch/sparsity/movement/layers.py index 093c42f762f..5f6c5a02b45 100644 --- a/nncf/torch/sparsity/movement/layers.py +++ b/nncf/torch/sparsity/movement/layers.py @@ -63,22 +63,28 @@ def __init__(self, self.weight_ctx = BinaryMask(weight_shape) self._weight_importance_shape, self._bool_expand_importance = self._get_importance_shape(weight_shape) self._weight_importance = CompressionParameter( - # torch.rand(self._weight_importance_shape), - torch.zeros(self._weight_importance_shape), + torch.rand(self._weight_importance_shape), + # torch.zeros(self._weight_importance_shape), requires_grad=not self.frozen, compression_lr_multiplier=compression_lr_multiplier) - self.weight_ctx.binary_mask = binary_mask_by_threshold(self._weight_importance, self._masking_threshold) + self.weight_ctx.binary_mask = binary_mask_by_threshold( + self._expand_importance(self._weight_importance), + self._masking_threshold + ) if self.prune_bias is True: bias_shape = target_module_node.layer_attributes.get_bias_shape() self.bias_ctx = BinaryMask(bias_shape) self._bias_importance_shape = self._weight_importance_shape[0] self._bias_importance = CompressionParameter( - # torch.rand(self._bias_importance_shape), - torch.zeros(self._bias_importance_shape), + torch.rand(self._bias_importance_shape), + # torch.zeros(self._bias_importance_shape), requires_grad=not self.frozen, compression_lr_multiplier=compression_lr_multiplier) - self.bias_ctx.binary_mask = binary_mask_by_threshold(self._bias_importance, self._masking_threshold) + self.bias_ctx.binary_mask = binary_mask_by_threshold( + self._expand_importance(self._bias_importance, isbias=True), + self._masking_threshold + ) self.mask_calculation_hook = MaskCalculationHook(self) @@ -116,13 +122,13 @@ def unfreeze_importance(self): def extra_repr(self): - return '{}, {}'.format( + return 'sparse_structure: {}, {}'.format( self.sparse_cfg.mode, self.sparse_cfg.sparse_args) def forward(self, weight, bias): if is_tracing_state(): with no_jit_trace(): - return weight.mul_(self.binary_mask) + return weight.mul_(self.weight_ctx.binary_mask), bias.mul_(self.bias_ctx.binary_mask) tmp_wtensor, tmp_btensor = self._calc_training_binary_mask(weight, bias) wtensor = apply_binary_mask_impl(tmp_wtensor, weight) btensor = apply_binary_mask_impl(tmp_btensor, bias) @@ -226,17 +232,17 @@ def __init__(self, module): self.hook = module._register_state_dict_hook(self.hook_fn) def hook_fn(self, module, destination, prefix, local_metadata): - module.weight_ctx.binary_mask = binary_mask_by_threshold( - module._expand_importance(module._weight_importance), - module.masking_threshold - ) + # module.weight_ctx.binary_mask = binary_mask_by_threshold( + # module._expand_importance(module._weight_importance), + # module.masking_threshold + # ) destination[prefix + 
'weight_ctx._binary_mask'] = module.weight_ctx.binary_mask if module.prune_bias is True: - module.bias_ctx.binary_mask = binary_mask_by_threshold( - module._expand_importance(module._bias_importance, isbias=True), - module.masking_threshold - ) + # module.bias_ctx.binary_mask = binary_mask_by_threshold( + # module._expand_importance(module._bias_importance, isbias=True), + # module.masking_threshold + # ) destination[prefix + 'bias_ctx._binary_mask'] = module.bias_ctx.binary_mask return destination From 49e0a1c8bf13dcac3ff747f8d0a68d11268d4009 Mon Sep 17 00:00:00 2001 From: Vui Seng Chua Date: Mon, 28 Feb 2022 10:09:20 -0800 Subject: [PATCH 7/7] initial commit for fill flow (major changes) --- nncf/common/sparsity/schedulers.py | 1 + nncf/torch/exporter.py | 20 +- nncf/torch/sparsity/movement/algo.py | 318 ++++++++++++++++++++++--- nncf/torch/sparsity/movement/layers.py | 21 +- 4 files changed, 311 insertions(+), 49 deletions(-) diff --git a/nncf/common/sparsity/schedulers.py b/nncf/common/sparsity/schedulers.py index c6e5feb3b97..56c3866ae13 100644 --- a/nncf/common/sparsity/schedulers.py +++ b/nncf/common/sparsity/schedulers.py @@ -385,6 +385,7 @@ def schedule_threshold(self): else: self.current_importance_threshold = self._calculate_threshold_level() + # self.current_importance_threshold = 0.1 self._update_importance_masking_threshold() # if _cached_threshold != self.current_importance_threshold or _cached_regu_lambda != self.current_importance_lambda: # for n, m in self._controller.model.named_modules(): diff --git a/nncf/torch/exporter.py b/nncf/torch/exporter.py index 03c272dc6b8..42b1dabd616 100644 --- a/nncf/torch/exporter.py +++ b/nncf/torch/exporter.py @@ -114,16 +114,18 @@ def _export_to_onnx(self, save_path: str) -> None: retval = dummy_forward(self._model) output_names = generate_output_names_list(count_tensors(retval)) + DEBUG=False import os - torch.onnx.export(model, tuple(input_tensor_list), os.path.join(os.path.dirname(save_path), "graph-only."+os.path.basename(save_path)), - export_params=False, - input_names=input_names, - output_names=output_names, - enable_onnx_checker=False, - opset_version=10, - do_constant_folding=False, - # Do not fuse Conv+BN in ONNX. May cause dropout elements to appear in ONNX. - training=True) + if DEBUG is True: + torch.onnx.export(model, tuple(input_tensor_list), os.path.join(os.path.dirname(save_path), "graph-only."+os.path.basename(save_path)), + export_params=False, + input_names=input_names, + output_names=output_names, + enable_onnx_checker=False, + opset_version=10, + do_constant_folding=False, + # Do not fuse Conv+BN in ONNX. May cause dropout elements to appear in ONNX. + training=True) torch.onnx.export(model, tuple(input_tensor_list), save_path, input_names=input_names, diff --git a/nncf/torch/sparsity/movement/algo.py b/nncf/torch/sparsity/movement/algo.py index 30f761f04e1..b5e52d02f3c 100644 --- a/nncf/torch/sparsity/movement/algo.py +++ b/nncf/torch/sparsity/movement/algo.py @@ -11,7 +11,7 @@ limitations under the License. 
""" from copy import deepcopy -from typing import List +from typing import DefaultDict, List, OrderedDict import torch import torch.distributed as dist @@ -40,11 +40,13 @@ from nncf.common.sparsity.statistics import MovementSparsityStatistics from nncf.common.statistics import NNCFStatistics from nncf.experimental.torch.search_building_blocks.search_blocks import BuildingBlock, get_building_blocks -from collections import namedtuple +from collections import defaultdict, namedtuple from nncf.torch.dynamic_graph.operation_address import OperationAddress import networkx as nx from nncf.torch.layers import NNCF_MODULES_OP_NAMES import os +import numpy as np +import pandas as pd @PT_COMPRESSION_ALGORITHMS.register('movement_sparsity') @@ -98,6 +100,40 @@ def create_weight_sparsifying_operation(self, target_module_node: NNCFNode, comp def _build_controller(self, model: NNCFNetwork) -> PTCompressionAlgorithmController: return MovementSparsityController(model, self._sparsified_module_info, self.config) +class StructuredMask: + def __init__(self, + target_module_node, + sparsifying_node_name, + grid_size, + dependent_group_id, + sparse_module_info): + + self.target_module_node=target_module_node + self.sparsifying_node_name=sparsifying_node_name + self.grid_size=grid_size + self.dependent_group_id=dependent_group_id + self.sparse_module_info=sparse_module_info + + @property + def independent_structured_mask(self): + return self._independent_structured_mask + + @independent_structured_mask.setter + def independent_structured_mask(self, tensor): + with torch.no_grad(): + self._independent_structured_mask = tensor + # self._independent_structured_mask.set_(tensor) + + @property + def dependent_structured_mask(self): + return self._dependent_structured_mask + + @dependent_structured_mask.setter + def dependent_structured_mask(self, tensor): + # TODO: check dim + with torch.no_grad(): + self._dependent_structured_mask = tensor + # self._dependent_structured_mask.set_(tensor) @ADAPTIVE_COMPRESSION_CONTROLLERS.register('pt_movement_sparsity') class MovementSparsityController(BaseSparsityAlgoController): @@ -125,6 +161,7 @@ def __init__(self, target_model: NNCFNetwork, sparsified_module_info: List[Spars self.config = config self.prunableops_per_group = self._get_group_of_prunable_ops() # self.visualize_groups_of_prunables() + self.create_structured_sparsity_context() def compression_stage(self) -> CompressionStage: if self._mode == 'local': @@ -190,44 +227,253 @@ def statistics(self, quickly_collected_only=False) -> NNCFStatistics: nncf_stats.register('movement_sparsity', stats) return nncf_stats - @property - def compression_rate(self): - return self.statistics().movement_sparsity.model_statistics.sparsity_level + def create_structured_sparsity_context(self): + DEBUG=False + # Structured_mask per tensor ------------------- + node_name_sparse_mod_info_map = {sparse_info.module_node_name: sparse_info for sparse_info in self.sparsified_module_info} + self.node_name_sparse_mod_info_map = node_name_sparse_mod_info_map - def _propagate_masks(self): - # nncf_logger.debug("MVMT - Propagating pruning masks") - # 1. 
Propagate masks for all modules - from collections import OrderedDict - sparse_sd = OrderedDict() - with torch.no_grad(): - for sparse_info in self.sparsified_module_info: - for n, m in self.model.named_modules(): - if m == sparse_info.module: - # print(n, 1-sparse_info.operand.binary_mask.count_nonzero()/sparse_info.operand.binary_mask.numel()) - # print("pre", 1-m.weight.count_nonzero()/m.weight.numel()) - # print("mask", 1-sparse_info.operand.binary_mask.count_nonzero()/sparse_info.operand.binary_mask.numel()) - sparse_sd[n+'.weight'] = m.weight*sparse_info.operand.binary_mask - # print("post", 1-sparse_sd[n+'.weight'].count_nonzero()/sparse_sd[n+'.weight'].numel()) - # sd = sparse_info.module.state_dict() - # sd['weight'] = sparse_info.module.weight*sparse_info.operand.binary_mask - # sparse_info.module.load_state_dict(sd) + self.structured_ctx_by_group = defaultdict(list) + masks_per_group = dict() + op2namedmodule = dict() - model_sd = self.model.state_dict() - for k, v in sparse_sd.items(): - assert k in model_sd, "key not exists!" - model_sd[k] = sparse_sd[k] - self.model.load_state_dict(model_sd) + for group_id, op_list in self.prunableops_per_group.items(): + masks_per_group[group_id]=dict() + for op in op_list: + sparsifying_node_name = str(op.op_addr) - # init_output_masks_in_graph(graph, self.pruned_module_groups_info.get_all_nodes()) - # MaskPropagationAlgorithm(graph, PT_PRUNING_OPERATOR_METATYPES).mask_propagation() + # find op's torch module name + for n, m in self.model.named_modules(): + if m == op.op_mod: + op2namedmodule[sparsifying_node_name] = n + break + + sparse_module_info = node_name_sparse_mod_info_map[sparsifying_node_name] + + if any(map(sparsifying_node_name.__contains__, ['query','key','value'])): + # these matrices must be pruned by group(s) of cols + nrow_per_head = self.model.nncf_module.bert.config.hidden_size//self.model.nncf_module.bert.config.num_attention_heads + ncol_per_head = self.model.nncf_module.bert.config.hidden_size + grid_size = (nrow_per_head, ncol_per_head) + + if DEBUG is True: + masks_per_group[group_id]['qkv_grain'] = grid_size + mask = sparse_module_info.operand.get_structured_mask(grid_size) + if 'qkv' not in masks_per_group[group_id]: + masks_per_group[group_id]['qkv'] = [mask] + masks_per_group[group_id]['qkv_nodes'] = [sparsifying_node_name] + else: + masks_per_group[group_id]['qkv'].append(mask) + masks_per_group[group_id]['qkv_nodes'].append(sparsifying_node_name) + print("{:15} | {:20} | {}".format('group_of_rows', str(mask.shape), sparsifying_node_name)) + + structured_mask_ctx = StructuredMask( + sparse_module_info.module_node_name, + sparsifying_node_name, + grid_size, + group_id, + sparse_module_info) + + structured_mask_ctx.independent_structured_mask = sparse_module_info.operand.get_structured_mask(grid_size) + sparse_module_info.operand.structured_mask_ctx = structured_mask_ctx + self.structured_ctx_by_group[group_id].append(sparse_module_info.operand.structured_mask_ctx) + + if DEBUG is True: + assert ((mask==structured_mask_ctx.independent_structured_mask).sum() == mask.numel()).item(), "qkv: Logical Bug, pls debug" + + elif 'BertSelfOutput' in sparsifying_node_name: + # this matrix must be pruned by group(s) of cols + ncol_per_head = self.model.nncf_module.bert.config.hidden_size//self.model.nncf_module.bert.config.num_attention_heads + nrow_per_head = self.model.nncf_module.bert.config.hidden_size + grid_size = (nrow_per_head, ncol_per_head) + + if DEBUG is True: + masks_per_group[group_id]['concat_grain'] = 
grid_size + mask = sparse_module_info.operand.get_structured_mask(grid_size) + masks_per_group[group_id]['concat'] = mask + masks_per_group[group_id]['concat_node'] = sparsifying_node_name + print("{:15} | {:20} | {}".format('group_of_cols', str(mask.shape), sparsifying_node_name)) + + structured_mask_ctx = StructuredMask( + sparse_module_info.module_node_name, + sparsifying_node_name, + grid_size, + group_id, + sparse_module_info) + + structured_mask_ctx.independent_structured_mask = sparse_module_info.operand.get_structured_mask(grid_size) + sparse_module_info.operand.structured_mask_ctx = structured_mask_ctx + self.structured_ctx_by_group[group_id].append(sparse_module_info.operand.structured_mask_ctx) + + if DEBUG is True: + assert ((mask==structured_mask_ctx.independent_structured_mask).sum() == mask.numel()).item(), "BertSelfOutput: Logical Bug, pls debug" + + elif any(map(sparsifying_node_name.__contains__, ['BertIntermediate','BertOutput'])): + mask = sparse_module_info.operand.get_structured_mask() + grid_size = sparse_module_info.operand.sparse_cfg.sparse_factors + + if DEBUG is True: + if 'BertIntermediate' in sparsifying_node_name: + masks_per_group[group_id]['ffnn_w1_grain'] = grid_size + masks_per_group[group_id]['ffnn_w1'] = mask + masks_per_group[group_id]['ffnn_w1_node'] = sparsifying_node_name + elif 'BertOutput' in sparsifying_node_name: + masks_per_group[group_id]['ffnn_w2_grain'] = grid_size + masks_per_group[group_id]['ffnn_w2'] = mask + masks_per_group[group_id]['ffnn_w2_node'] = sparsifying_node_name + print("{:15} | {:20} | {}".format('per_dim', str(mask.shape), sparsifying_node_name)) + + structured_mask_ctx = StructuredMask( + sparse_module_info.module_node_name, + sparsifying_node_name, + grid_size, + group_id, + sparse_module_info) + + structured_mask_ctx.independent_structured_mask = sparse_module_info.operand.get_structured_mask(grid_size) + sparse_module_info.operand.structured_mask_ctx = structured_mask_ctx + self.structured_ctx_by_group[group_id].append(sparse_module_info.operand.structured_mask_ctx) + + if DEBUG is True: + assert ((mask==structured_mask_ctx.independent_structured_mask).sum() == mask.numel()).item(), "ffnn: Logical Bug, pls debug" + else: + raise ValueError("Invalid entry, pls debug") + + self.op2namedmodule = op2namedmodule + + # This Structure can be improved but good for now: TODO: revision of structure + # masks_per_group[group_id][sparsifying_node_name] = mask + + def reset_independent_structured_mask(self): + for group_id, ctxes in self.structured_ctx_by_group.items(): + for ctx in ctxes: + ctx.independent_structured_mask = ctx.sparse_module_info.operand.get_structured_mask(ctx.grid_size) + + def populate_structured_mask(self): + def inflate_structured_mask(mask, grid_size): + assert len(mask.shape) == len(grid_size), "Unmatching dimension" + inflated_mask = mask.detach().clone() + for axis, repeat in enumerate(grid_size): + inflated_mask = inflated_mask.repeat_interleave(repeat, dim=axis) + return inflated_mask + + for group_id, ctxes in self.structured_ctx_by_group.items(): + for ctx in ctxes: + ctx.sparse_module_info.operand.set_structured_mask( + inflate_structured_mask(ctx.dependent_structured_mask, ctx.grid_size) + ) + + def resolve_structured_mask(self): + for group_id, ctxes in self.structured_ctx_by_group.items(): + allnodenames = list(map(lambda x: x.target_module_node, ctxes)) + + if any(map(ctxes[0].target_module_node.__contains__, ['query','key','value','BertSelfOutput'])): + qid = list(map(lambda x: 
x.__contains__('query'), allnodenames)).index(True) + kid = list(map(lambda x: x.__contains__('key'), allnodenames)).index(True) + vid = list(map(lambda x: x.__contains__('value'), allnodenames)).index(True) + oid = list(map(lambda x: x.__contains__('BertSelfOutput'), allnodenames)).index(True) + + coarse_mask = ctxes[qid].independent_structured_mask.logical_or( + ctxes[kid].independent_structured_mask).logical_or( + ctxes[vid].independent_structured_mask).logical_or( + ctxes[oid].independent_structured_mask.transpose(0, 1) + ).to(torch.float32) + ctxes[qid].dependent_structured_mask = coarse_mask + ctxes[kid].dependent_structured_mask = coarse_mask + ctxes[vid].dependent_structured_mask = coarse_mask + ctxes[oid].dependent_structured_mask = coarse_mask.transpose(0, 1) + elif any(map(ctxes[0].target_module_node.__contains__, ['BertIntermediate','BertOutput'])): + w1_id = list(map(lambda x: x.__contains__('BertIntermediate'), allnodenames)).index(True) + w2_id = list(map(lambda x: x.__contains__('BertOutput'), allnodenames)).index(True) + coarse_mask = ctxes[w1_id].independent_structured_mask.logical_or( + ctxes[w2_id].independent_structured_mask.transpose(0, 1) + ).to(torch.float32) + + ctxes[w1_id].dependent_structured_mask = coarse_mask + ctxes[w2_id].dependent_structured_mask = coarse_mask.transpose(0, 1) + else: + raise ValueError("logical bug, pls debug") + + # # Structured_mask alignment by group ------------------- + + # for group_id, mask_dict in masks_per_group.items(): + # if 'qkv' in mask_dict: + # final_mask = torch.zeros_like(mask_dict['qkv'][0]).to(torch.bool) + # for each_mask in mask_dict['qkv']: + # final_mask = final_mask.logical_or(each_mask) + # final_mask = final_mask.logical_or(mask_dict['concat'].transpose(0, 1)) + # final_mask = final_mask.to(torch.float32) + + # masks_per_group[group_id]['final_structured_mask'] = dict( + # qkv=final_mask, + # concat=final_mask.transpose(0, 1) + # ) + + # elif 'ffnn_w1' in mask_dict: + # final_mask = mask_dict['ffnn_w1'].logical_or(mask_dict['ffnn_w2'].transpose(0, 1)) + # final_mask = final_mask.to(torch.float32) + + # masks_per_group[group_id]['final_structured_mask'] = dict( + # ffnn_w1=final_mask, + # ffnn_w2=final_mask.transpose(0, 1) + # ) + # else: + # raise ValueError("Invalid entry, pls debug") + + def report_structured_sparsity(self, dirname): + listofentry=[] + for group_id, ctxes in self.structured_ctx_by_group.items(): + for ctx in ctxes: + nncf_graph_node_name = ctx.sparsifying_node_name + named_mod = self.op2namedmodule[nncf_graph_node_name] + block_id = group_id + orig_wshape = tuple(list(ctx.sparse_module_info.module.weight.shape)) + if hasattr(ctx.sparse_module_info.module, 'bias'): + orig_bshape = tuple(list(ctx.sparse_module_info.module.bias.shape)) + + if any(map(nncf_graph_node_name.__contains__, ['BertIntermediate','BertOutput'])): + head_id_to_keep = 'skip reporting' + if nncf_graph_node_name.__contains__('BertIntermediate'): + final_wshape = (ctx.sparse_module_info.operand.weight_ctx.binary_mask.amax(dim=1).count_nonzero().item(), orig_wshape[1]) + final_bshape = (ctx.sparse_module_info.operand.bias_ctx.binary_mask.count_nonzero().item(),) + else: + final_wshape = (orig_wshape[0], ctx.sparse_module_info.operand.weight_ctx.binary_mask.amax(dim=0).count_nonzero().item()) + final_bshape = (ctx.sparse_module_info.operand.bias_ctx.binary_mask.count_nonzero().item(),) + else: + ndiv = ctx.dependent_structured_mask.reshape(-1).shape[0] + head_id_to_keep = torch.masked_select(torch.range(0, ndiv-1, dtype=int), + 
ctx.dependent_structured_mask.reshape(-1).cpu().to(bool)).tolist() + + if any(map(nncf_graph_node_name.__contains__, ['query','key','value'])): + # prune by row + final_wshape = (ctx.sparse_module_info.operand.weight_ctx.binary_mask.amax(dim=1).count_nonzero().item(), orig_wshape[1]) + final_bshape = (ctx.sparse_module_info.operand.bias_ctx.binary_mask.count_nonzero().item(),) + else: + # prune by col + final_wshape = (orig_wshape[0], ctx.sparse_module_info.operand.weight_ctx.binary_mask.amax(dim=0).count_nonzero().item()) + final_bshape = (ctx.sparse_module_info.operand.bias_ctx.binary_mask.count_nonzero().item(),) + + listofentry.append( + OrderedDict( + pt_module_name=named_mod, + block_id=block_id, + weight_shape=orig_wshape, + prune_w_shape=final_wshape, + bias_shape=orig_bshape, + prune_b_shape=final_bshape, + head_id_to_keep=head_id_to_keep, + nncf_graph_node=nncf_graph_node_name + ) + ) + df = pd.DataFrame.from_dict(listofentry) + df.to_csv(os.path.join(dirname, 'structured_sparsity.csv')) + with open(os.path.join(dirname, 'structured_sparsity.md'), 'w') as f: + df.to_markdown(f) - # # 2. Set the masks for Batch/Group Norms - # pruned_node_modules = [] - # for node, pruning_block, node_module in self._pruned_norms_operators: - # if node_module not in pruned_node_modules: - # # Setting masks for BN nodes - # pruning_block.binary_filter_pruning_mask = node.data['output_mask'].tensor - # pruned_node_modules.append(node_module) + + @property + def compression_rate(self): + return self.statistics().movement_sparsity.model_statistics.sparsity_level def prepare_for_export(self): """ diff --git a/nncf/torch/sparsity/movement/layers.py b/nncf/torch/sparsity/movement/layers.py index 5f6c5a02b45..836309916da 100644 --- a/nncf/torch/sparsity/movement/layers.py +++ b/nncf/torch/sparsity/movement/layers.py @@ -51,6 +51,9 @@ def __init__(self, sparse_cfg=None): super().__init__() + DEBUG=False + + self.target_module_node = target_module_node self.prune_bias = target_module_node.layer_attributes.bias self.frozen = frozen @@ -63,8 +66,7 @@ def __init__(self, self.weight_ctx = BinaryMask(weight_shape) self._weight_importance_shape, self._bool_expand_importance = self._get_importance_shape(weight_shape) self._weight_importance = CompressionParameter( - torch.rand(self._weight_importance_shape), - # torch.zeros(self._weight_importance_shape), + torch.rand(self._weight_importance_shape) if DEBUG is True else torch.zeros(self._weight_importance_shape), requires_grad=not self.frozen, compression_lr_multiplier=compression_lr_multiplier) self.weight_ctx.binary_mask = binary_mask_by_threshold( @@ -77,8 +79,7 @@ def __init__(self, self.bias_ctx = BinaryMask(bias_shape) self._bias_importance_shape = self._weight_importance_shape[0] self._bias_importance = CompressionParameter( - torch.rand(self._bias_importance_shape), - # torch.zeros(self._bias_importance_shape), + torch.rand(self._bias_importance_shape) if DEBUG is True else torch.zeros(self._bias_importance_shape), requires_grad=not self.frozen, compression_lr_multiplier=compression_lr_multiplier) self.bias_ctx.binary_mask = binary_mask_by_threshold( @@ -224,8 +225,20 @@ def get_structured_mask(self, grain_size=None): structured_mask = structured_mask.reshape(temp_shape) structured_mask = structured_mask.amax(dim=(tuple((np.arange(len(self.weight_ctx.binary_mask.shape)) * 2 + 1)))) # print("Mask Shape from {} to {}".format(structured_mask.shape, self.weight_ctx.binary_mask.shape)) + if self.prune_bias is True: + structured_bias_mask_shape = 
structured_mask_shape[0] + structured_bias_mask = self.bias_ctx.binary_mask.detach().clone() + structured_bias_mask = structured_bias_mask.reshape((structured_bias_mask_shape, -1)) + structured_bias_mask = structured_bias_mask.amax(dim=1) + dim_aligned = structured_bias_mask.repeat_interleave(structured_mask.shape[1]).reshape(-1, structured_mask.shape[1]) + structured_mask = structured_mask.logical_or(dim_aligned).to(torch.float32) return structured_mask + def set_structured_mask(self, structured_mask): + self.weight_ctx.binary_mask = structured_mask + if self.prune_bias is True: + self.bias_ctx.binary_mask = structured_mask.amax(dim=1) class MaskCalculationHook(): def __init__(self, module): # pylint: disable=protected-access
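To close out, here is a compact standalone sketch of the two tensor operations the structured-mask code above relies on: coarsening a fine-grained binary mask to one value per grain with amax, and resolving a shared per-head decision across the attention projections with logical_or. Shapes and names are illustrative and not tied to NNCF's classes.

import torch

def coarsen(binary_mask: torch.Tensor, grain: tuple) -> torch.Tensor:
    # One value per (gr x gc) block: a block survives if any element in it survives.
    rows, cols = binary_mask.shape
    gr, gc = grain
    return binary_mask.reshape(rows // gr, gr, cols // gc, gc).amax(dim=(1, 3))

# Toy attention layer: hidden=8, 2 heads -> 4 rows per head in Q/K/V,
# 4 columns per head in the output projection.
hidden, heads = 8, 2
qkv_grain = (hidden // heads, hidden)
out_grain = (hidden, hidden // heads)

q = coarsen((torch.rand(hidden, hidden) > 0.9).float(), qkv_grain)  # shape (2, 1)
k = coarsen((torch.rand(hidden, hidden) > 0.9).float(), qkv_grain)
v = coarsen((torch.rand(hidden, hidden) > 0.9).float(), qkv_grain)
o = coarsen((torch.rand(hidden, hidden) > 0.9).float(), out_grain)  # shape (1, 2)

# A head is kept if any of Q/K/V/O wants to keep it, mirroring resolve_structured_mask;
# the output projection mask is transposed so the head axis lines up.
head_mask = q.logical_or(k).logical_or(v).logical_or(o.transpose(0, 1)).to(torch.float32)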