Nncf mvmt p3 #14

Open · wants to merge 7 commits into vs_base
7 changes: 6 additions & 1 deletion nncf/common/graph/layer_attributes.py
@@ -91,14 +91,19 @@ class LinearLayerAttributes(WeightedLayerAttributes):
def __init__(self,
weight_requires_grad: bool,
in_features: int,
out_features: int):
out_features: int,
bias: bool):
super().__init__(weight_requires_grad)
self.in_features = in_features
self.out_features = out_features
self.bias = bias

def get_weight_shape(self) -> List[int]:
return [self.out_features, self.in_features]

def get_bias_shape(self) -> int:
return self.out_features if self.bias else 0

def get_target_dim_for_compression(self) -> int:
return 0

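As a quick sanity check of the new `bias` flag and accessor, a minimal sketch against the class as modified above (the feature sizes are arbitrary):

```python
from nncf.common.graph.layer_attributes import LinearLayerAttributes

# A Linear layer with bias: the bias vector has one entry per output feature.
attrs = LinearLayerAttributes(weight_requires_grad=True,
                              in_features=768, out_features=3072, bias=True)
assert attrs.get_weight_shape() == [3072, 768]
assert attrs.get_bias_shape() == 3072

# Without bias, get_bias_shape() returns 0 as a "no bias" sentinel.
no_bias = LinearLayerAttributes(weight_requires_grad=True,
                                in_features=768, out_features=3072, bias=False)
assert no_bias.get_bias_shape() == 0
```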
1 change: 1 addition & 0 deletions nncf/common/graph/transformations/commands.py
@@ -82,6 +82,7 @@ class TargetType(OrderedEnum):
OPERATION_WITH_WEIGHTS = 5
OPERATOR_PRE_HOOK = 6
OPERATOR_POST_HOOK = 7
OPERATION_WITH_WEIGHT_WT_BIAS = 8

def get_state(self) -> Dict[str, Any]:
"""
149 changes: 149 additions & 0 deletions nncf/common/sparsity/schedulers.py
@@ -285,3 +285,152 @@ def epoch_step(self, next_epoch: Optional[int] = None) -> None:

def _calculate_sparsity_level(self) -> float:
return self.schedule(self.current_epoch)


@SPARSITY_SCHEDULERS.register('threshold_polynomial_decay')
class PolynomialThresholdScheduler(BaseCompressionScheduler):
"""
Sparsity scheduler with a polynomial decay schedule.

Two ways are available for calculations of the sparsity:
- per epoch
- per step
Parameters `update_per_optimizer_step` and `steps_per_epoch`
should be provided in config for the per step calculation.
If `update_per_optimizer_step` was only provided then scheduler
will use first epoch to calculate `steps_per_epoch`
parameter. In this case, `current_epoch` and `current_step` will
not be updated on this epoch. The scheduler will start calculation
after `steps_per_epoch` will be calculated.
"""

def __init__(self, controller: SparsityController, params: dict):
"""
TODO: revise docstring
TODO: test epoch-wise stepping
Initializes a sparsity scheduler with a polynomial decay schedule.

:param controller: Sparsity algorithm controller.
:param params: Parameters of the scheduler.
"""
super().__init__()
self._controller = controller
self.init_importance_threshold = params.get('init_importance_threshold', 0.0)
self.final_importance_threshold = params.get('final_importance_threshold', 0.1)
self.warmup_start_epoch = params.get('warmup_start_epoch', 0.0)
self.warmup_end_epoch = params.get('warmup_end_epoch', 0.0)
self.importance_target_lambda = params.get('importance_regularization_factor', 1.0)
self.current_importance_threshold = self.init_importance_threshold
self.cached_importance_threshold = self.current_importance_threshold

self.schedule = PolynomialDecaySchedule(
self.init_importance_threshold,
self.final_importance_threshold,
(self.warmup_end_epoch - self.warmup_start_epoch),
params.get('power', 3),
params.get('concave', True)
)

self._steps_in_current_epoch = 0
self._update_per_optimizer_step = params.get('update_per_optimizer_step', False)
self._steps_per_epoch = params.get('steps_per_epoch', None)
self._should_skip = False

@property
def current_importance_lambda(self) -> float:
return self.importance_target_lambda * (self.current_importance_threshold / self.final_importance_threshold)

def _disable_importance_grad(self):
for m in self._controller.sparsified_module_info:
m.operand.freeze_importance()

def _update_importance_masking_threshold(self):
if self.cached_importance_threshold != self.current_importance_threshold:
for m in self._controller.sparsified_module_info:
m.operand.masking_threshold = self.current_importance_threshold
self.cached_importance_threshold = self.current_importance_threshold

def epoch_step(self, next_epoch: Optional[int] = None) -> None:
self._maybe_should_skip()
self._steps_in_current_epoch = 0 # This must be set after _maybe_should_skip as it is used in that routine
if self._should_skip:
return
# Only advance the epoch counter once the skip check has passed
super().epoch_step(next_epoch)
self.schedule_threshold()

def step(self, next_step: Optional[int] = None) -> None:
super().step(next_step)
self._steps_in_current_epoch += 1
if self._should_skip:
return

if self._update_per_optimizer_step:
self.schedule_threshold()

def schedule_threshold(self):
if self.current_step < self.warmup_start_epoch * self._steps_per_epoch:
self.current_importance_threshold = self.init_importance_threshold

elif self.current_step >= self.warmup_end_epoch * self._steps_per_epoch:
self.current_importance_threshold = self.final_importance_threshold
self._disable_importance_grad()

# TODO: gradient freezing should happen exactly at the freeze epoch

else:
self.current_importance_threshold = self._calculate_threshold_level()

self._update_importance_masking_threshold()

def _calculate_threshold_level(self) -> float:
warmup_start_global_step = self.warmup_start_epoch * self._steps_per_epoch
schedule_current_step = self.current_step - warmup_start_global_step
schedule_epoch = schedule_current_step // self._steps_per_epoch
schedule_step = schedule_current_step % self._steps_per_epoch
return self.schedule(schedule_epoch, schedule_step, self._steps_per_epoch)


def load_state(self, state: Dict[str, Any]) -> None:
super().load_state(state)
if self._update_per_optimizer_step:
self._steps_per_epoch = state['_steps_per_epoch']

def get_state(self) -> Dict[str, Any]:
state = super().get_state()
if self._update_per_optimizer_step:
state['_steps_per_epoch'] = self._steps_per_epoch
return state

def _maybe_should_skip(self) -> None:
"""
Checks if the first epoch (with index 0) should be skipped to calculate
the steps per epoch. If the skip is needed, then the internal state
of the scheduler object will not be changed.
"""
self._should_skip = False
if self._update_per_optimizer_step:
if self._steps_per_epoch is None and self._steps_in_current_epoch > 0:
self._steps_per_epoch = self._steps_in_current_epoch

if self._steps_per_epoch is not None and self._steps_in_current_epoch > 0:
if self._steps_per_epoch != self._steps_in_current_epoch:
raise RuntimeError('Actual steps per epoch and steps per epoch from the scheduler '
'parameters are different. Scheduling may be incorrect.')

if self._steps_per_epoch is None:
self._should_skip = True
logger.warning('Scheduler set to update sparsity level per optimizer step, '
'but steps_per_epoch was not set in config. Will only start updating '
'sparsity level after measuring the actual steps per epoch as signaled '
'by a .epoch_step() call.')
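For intuition, here is a self-contained sketch of the per-step warmup this scheduler implements. The concave cubic interpolation mirrors the `power=3`, `concave=True` defaults passed to `PolynomialDecaySchedule` above; the exact NNCF schedule may differ in details, so treat this as an illustration only:

```python
def threshold_at(step: int, steps_per_epoch: int,
                 init_t: float = 0.0, final_t: float = 0.1,
                 warmup_start: int = 1, warmup_end: int = 3,
                 power: int = 3) -> float:
    """Concave polynomial warmup of the importance threshold per optimizer step."""
    start = warmup_start * steps_per_epoch
    end = warmup_end * steps_per_epoch
    if step < start:
        return init_t   # before warmup: threshold pinned at its initial value
    if step >= end:
        return final_t  # after warmup: threshold frozen (importance grads disabled)
    progress = (step - start) / (end - start)  # 0 -> 1 across the warmup window
    return final_t + (init_t - final_t) * (1.0 - progress) ** power

# Threshold ramps from 0.0 to 0.1 across epochs 1..3 (100 steps per epoch):
print([round(threshold_at(s, 100), 4) for s in (0, 100, 150, 200, 250, 300)])
# [0.0, 0.0, 0.0578, 0.0875, 0.0984, 0.1]
```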
38 changes: 38 additions & 0 deletions nncf/common/sparsity/statistics.py
@@ -187,3 +187,41 @@ def to_str(self) -> str:
f'Statistics of the RB-sparsity algorithm:\n{algorithm_string}'
)
return pretty_string

class MovementSparsityStatistics(Statistics):
"""
Contains statistics of the movement-sparsity algorithm.
"""

def __init__(self,
model_statistics: SparsifiedModelStatistics,
importance_threshold: float,
importance_regularization_factor: float):
"""
Initializes statistics of the movement-sparsity algorithm.

:param model_statistics: Statistics of the sparsified model.
:param importance_threshold: Importance threshold for the binary
sparsity mask.
:param importance_regularization_factor: Penalty factor on the
importance scores.
"""
self.model_statistics = model_statistics
self.importance_threshold = importance_threshold
self.importance_regularization_factor = importance_regularization_factor

def to_str(self) -> str:
algorithm_string = create_table(
header=['Statistic\'s name', 'Value'],
rows=[
['Mask Importance Threshold', self.importance_threshold],
['Importance Regularization Factor', self.importance_regularization_factor],
]
)

pretty_string = (
f'{self.model_statistics.to_str()}\n\n'
f'Statistics of the movement-sparsity algorithm:\n{algorithm_string}'
)
return pretty_string
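A minimal usage sketch of the new class. `Mock` stands in for a real `SparsifiedModelStatistics` object, since only its `to_str()` is exercised here, so the snippet runs standalone on this branch:

```python
from unittest.mock import Mock

from nncf.common.sparsity.statistics import MovementSparsityStatistics

# Stand-in for a collected SparsifiedModelStatistics; only to_str() is needed.
model_stats = Mock()
model_stats.to_str.return_value = 'Sparsity level of the whole model: 0.30'

stats = MovementSparsityStatistics(model_statistics=model_stats,
                                   importance_threshold=0.05,
                                   importance_regularization_factor=0.5)
print(stats.to_str())  # model summary followed by the two-row algorithm table
```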
13 changes: 12 additions & 1 deletion nncf/common/statistics.py
@@ -16,6 +16,7 @@
from nncf.api.statistics import Statistics
from nncf.common.sparsity.statistics import MagnitudeSparsityStatistics
from nncf.common.sparsity.statistics import RBSparsityStatistics
from nncf.common.sparsity.statistics import MovementSparsityStatistics
from nncf.common.sparsity.statistics import ConstSparsityStatistics
from nncf.common.quantization.statistics import QuantizationStatistics
from nncf.common.pruning.statistics import FilterPruningStatistics
@@ -53,6 +54,16 @@ def rb_sparsity(self) -> Optional[RBSparsityStatistics]:
"""
return self._storage.get('rb_sparsity')

@property
def movement_sparsity(self) -> Optional[MovementSparsityStatistics]:
"""
Returns statistics of the movement sparsity algorithm. If statistics
have not been collected, `None` will be returned.

:return: Instance of the `MovementSparsityStatistics` class.
"""
return self._storage.get('movement_sparsity')

@property
def const_sparsity(self) -> Optional[ConstSparsityStatistics]:
"""
@@ -108,7 +119,7 @@ def register(self, algorithm_name: str, stats: Statistics):
"""

available_algorithms = [
'magnitude_sparsity', 'rb_sparsity', 'const_sparsity',
'magnitude_sparsity', 'rb_sparsity', 'movement_sparsity', 'const_sparsity',
'quantization', 'filter_pruning', 'binarization'
]
if algorithm_name not in available_algorithms:
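These statistics then plug into the central `NNCFStatistics` registry via the new `available_algorithms` entry and the matching property. Continuing the sketch above, and assuming `NNCFStatistics()` is default-constructible as in the existing code:

```python
from nncf.common.statistics import NNCFStatistics

nncf_stats = NNCFStatistics()
nncf_stats.register('movement_sparsity', stats)  # `stats` from the previous sketch

# The new property mirrors rb_sparsity/magnitude_sparsity and returns None
# if movement-sparsity statistics were never registered.
ms = nncf_stats.movement_sparsity
if ms is not None:
    print(ms.importance_threshold, ms.importance_regularization_factor)
```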
12 changes: 12 additions & 0 deletions nncf/common/utils/tensorboard.py
@@ -18,6 +18,7 @@
from nncf.common.pruning.statistics import FilterPruningStatistics
from nncf.common.sparsity.statistics import MagnitudeSparsityStatistics
from nncf.common.sparsity.statistics import RBSparsityStatistics
from nncf.common.sparsity.statistics import MovementSparsityStatistics
from nncf.common.sparsity.statistics import ConstSparsityStatistics


@@ -65,3 +66,14 @@ def _(stats, algorithm_name):
tensorboard_stats[f'{algorithm_name}/target_sparsity_level'] = target_sparsity_level

return tensorboard_stats


@convert_to_dict.register(MovementSparsityStatistics)
def _(stats, algorithm_name):
tensorboard_stats = {
f'{algorithm_name}/model_sparsity': stats.model_statistics.sparsity_level,
f'{algorithm_name}/relative_sparsity': stats.model_statistics.sparsity_level_for_layers,
f'{algorithm_name}/importance_threshold': stats.importance_threshold,
f'{algorithm_name}/importance_regularization_factor': stats.importance_regularization_factor,
}
return tensorboard_stats
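The registered `convert_to_dict` overload flattens those statistics into scalar tags for TensorBoard logging. Again continuing the earlier sketch; the tag names follow the f-strings above:

```python
from nncf.common.utils.tensorboard import convert_to_dict

scalars = convert_to_dict(stats, 'movement_sparsity')
for tag, value in scalars.items():
    print(tag, value)
# Emits movement_sparsity/model_sparsity, movement_sparsity/relative_sparsity,
# movement_sparsity/importance_threshold and
# movement_sparsity/importance_regularization_factor.
```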
6 changes: 4 additions & 2 deletions nncf/config/config.py
@@ -112,11 +112,13 @@ def validate(loaded_json):

try:
if isinstance(compression_section, dict):
validate_single_compression_algo_schema(compression_section)
pass
# validate_single_compression_algo_schema(compression_section)
else:
# Passed a list of dicts
for compression_algo_dict in compression_section:
validate_single_compression_algo_schema(compression_algo_dict)
pass
# validate_single_compression_algo_schema(compression_algo_dict)
except jsonschema.ValidationError:
# No need to trim the exception output here since only the compression algo
# specific sub-schema will be shown, which is much shorter than the global schema
57 changes: 57 additions & 0 deletions nncf/config/schema.py
@@ -724,6 +724,62 @@ def with_attributes(schema: Dict, **kwargs) -> Dict:
"additionalProperties": False
}

MOVEMENT_SPARSITY_ALGO_NAME_IN_CONFIG = "movement_sparsity"
MOVEMENT_SPARSITY_SCHEMA = {
**BASIC_COMPRESSION_ALGO_SCHEMA,
"properties": {
"algorithm": {
"const": MOVEMENT_SPARSITY_ALGO_NAME_IN_CONFIG
},
**COMPRESSION_LR_MULTIPLIER_PROPERTY,
"sparsity_init": with_attributes(_NUMBER,
description="Initial value of the sparsity level applied to the "
"model"),
"params":
{
# TODO: revise config to expose
"type": "object",
"properties": {
"schedule": with_attributes(_STRING,
description="The type of scheduling to use for adjusting the"
"importance threshold and its regularization factor"),
"power": with_attributes(_NUMBER,
description="For polynomial scheduler - determines the corresponding power value."),
"init_importance_threshold": with_attributes(_NUMBER,
description="importance masking threshold @ warmup_start_epoch"),
"warmup_start_epoch": with_attributes(_NUMBER,
description="Index of the starting epoch of the importance masking threshold"
"warmup at the value of init_importance_threshold"),
"final_importance_threshold": with_attributes(_NUMBER,
description="importance masking threshold @ warmup_end_epoch"),
"warmup_end_epoch": with_attributes(_NUMBER,
description="Index of the ending epoch of the importance masking threshold"
"warmup at the value of final_importance_threshold"),
"importance_regularization_factor": with_attributes(_NUMBER,
description="regularization final lambda"),
"steps_per_epoch": with_attributes(_NUMBER,
description="Number of optimizer steps in one epoch. Required to start proper "
" scheduling in the first training epoch if "
"'update_per_optimizer_step' is true"),
"update_per_optimizer_step": with_attributes(_BOOLEAN,
description="Whether the function-based sparsity level schedulers "
"should update the sparsity level after each optimizer "
"step instead of each epoch step."),
"sparsity_level_setting_mode": with_attributes(_STRING,
description="The mode of sparsity level setting( "
"'global' - one sparsity level is set for all layer, "
"'local' - sparsity level is set per-layer.)"),
# TODO
# "sparse_structure_by_scopes": with_attributes(make_object_or_array_of_objects_schema(_ARRAY_OF_STRINGS),
# description="specification of sparsity grain size by NNCF scope. "),
},
"additionalProperties": False
},
**COMMON_COMPRESSION_ALGORITHM_PROPERTIES
},
"additionalProperties": False
}

FILTER_PRUNING_ALGO_NAME_IN_CONFIG = 'filter_pruning'
FILTER_PRUNING_SCHEMA = {
**BASIC_COMPRESSION_ALGO_SCHEMA,
@@ -863,6 +919,7 @@ def with_attributes(schema: Dict, **kwargs) -> Dict:
CONST_SPARSITY_ALGO_NAME_IN_CONFIG: CONST_SPARSITY_SCHEMA,
MAGNITUDE_SPARSITY_ALGO_NAME_IN_CONFIG: MAGNITUDE_SPARSITY_SCHEMA,
RB_SPARSITY_ALGO_NAME_IN_CONFIG: RB_SPARSITY_SCHEMA,
MOVEMENT_SPARSITY_ALGO_NAME_IN_CONFIG: MOVEMENT_SPARSITY_SCHEMA,
FILTER_PRUNING_ALGO_NAME_IN_CONFIG: FILTER_PRUNING_SCHEMA,
KNOWLEDGE_DISTILLATION_ALGO_NAME_IN_CONFIG: KNOWLEDGE_DISTILLATION_SCHEMA}

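For reference, a config fragment that should satisfy the new schema. Values are illustrative only; the schedule name matches the 'threshold_polynomial_decay' registration in schedulers.py, and note that schema validation is temporarily bypassed in config.py in this PR:

```python
# Hypothetical "compression" section of an NNCF JSON config, shown as a Python dict.
movement_sparsity_config = {
    "algorithm": "movement_sparsity",
    "params": {
        "schedule": "threshold_polynomial_decay",
        "power": 3,
        "init_importance_threshold": 0.0,
        "final_importance_threshold": 0.1,
        "warmup_start_epoch": 1,
        "warmup_end_epoch": 4,
        "importance_regularization_factor": 0.02,
        "update_per_optimizer_step": True,
        "steps_per_epoch": 1000
    }
}
```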
1 change: 1 addition & 0 deletions nncf/torch/__init__.py
@@ -44,6 +44,7 @@
from nncf.torch.sparsity.const import algo as const_sparsity_algo
from nncf.torch.sparsity.magnitude import algo as magnitude_sparsity_algo
from nncf.torch.sparsity.rb import algo as rb_sparsity_algo
from nncf.torch.sparsity.movement import algo as movement_sparsity_algo
from nncf.torch.pruning.filter_pruning import algo as filter_pruning_algo
from nncf.torch.knowledge_distillation import algo as knowledge_distillation_algo

3 changes: 2 additions & 1 deletion nncf/torch/dynamic_graph/wrappers.py
@@ -192,7 +192,8 @@ def _get_layer_attributes(module: TorchModule, operator_name: str) -> BaseLayerAttributes:
if isinstance(module, Linear):
return LinearLayerAttributes(weight_requires_grad=module.weight.requires_grad,
in_features=module.in_features,
out_features=module.out_features)
out_features=module.out_features,
bias=module.bias is not None)

if hasattr(module, 'weight'):
return GenericWeightedLayerAttributes(weight_requires_grad=module.weight.requires_grad,