From 7c8489b7e327ed69cdb27ff8a4819e9b9f1d9b08 Mon Sep 17 00:00:00 2001 From: Daniil Lyakhov Date: Fri, 30 Jul 2021 16:02:55 +0300 Subject: [PATCH] Setup retinanet quantization --- examples/tensorflow/object_detection/main.py | 25 +++- op_insertion.py | 123 +++++++++++++++---- 2 files changed, 119 insertions(+), 29 deletions(-) diff --git a/examples/tensorflow/object_detection/main.py b/examples/tensorflow/object_detection/main.py index 2b96838a4f6..e2a3172e061 100644 --- a/examples/tensorflow/object_detection/main.py +++ b/examples/tensorflow/object_detection/main.py @@ -18,6 +18,8 @@ import tensorflow as tf import numpy as np +from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2 + from nncf.tensorflow import AdaptiveCompressionTrainingLoop from nncf.tensorflow import create_compressed_model from nncf.tensorflow.helpers.model_manager import TFOriginalModelManager @@ -47,6 +49,20 @@ from examples.tensorflow.object_detection.models.model_selector import get_predefined_config from examples.tensorflow.object_detection.models.model_selector import get_model_builder +def keras_model_to_frozen_graph(model): + input_signature = [] + for item in model.inputs: + input_signature.append(tf.TensorSpec(item.shape, item.dtype)) + concrete_function = tf.function(model).get_concrete_function(input_signature) + frozen_func = convert_variables_to_constants_v2(concrete_function, lower_control_flow=False) + return frozen_func.graph.as_graph_def(add_shapes=True) + + +def save_model_as_frozen_graph(model, save_path, as_text=False): + frozen_graph = keras_model_to_frozen_graph(model) + save_dir, name = os.path.split(save_path) + tf.io.write_graph(frozen_graph, save_dir, name, as_text=as_text) + def get_argument_parser(): parser = get_common_argument_parser(precision=False, @@ -311,7 +327,7 @@ def model_eval_fn(model): compression_ctrl, model = create_compressed_model(model, nncf_config, compression_state) from op_insertion import NNCFWrapperCustom args = [model] - inputs = tf.keras.layers.Input(shape=model.inputs[0].shape[1:]) + inputs = tf.keras.layers.Input(shape=model.inputs[0].shape[1:], name=model.inputs[0].name.split(':')[0]) outputs = NNCFWrapperCustom(*args, caliblration_dataset=train_dataset, enable_mirrored_vars_split=False)(inputs) compress_model = tf.keras.Model(inputs=inputs, outputs=outputs) @@ -381,9 +397,10 @@ def validate_fn(model, **kwargs): write_metrics(metric_result['AP'], config.metrics_dump) if 'export' in config.mode: - save_path, save_format = get_saving_parameters(config) - compression_ctrl.export_model(save_path, save_format) - logger.info("Saved to {}".format(save_path)) + save_model_as_frozen_graph(compress_model, config.to_frozen_graph) + #save_path, save_format = get_saving_parameters(config) + #compression_ctrl.export_model(save_path, save_format) + #logger.info("Saved to {}".format(save_path)) def export(config): diff --git a/op_insertion.py b/op_insertion.py index c268e855845..6da5f72cfbc 100644 --- a/op_insertion.py +++ b/op_insertion.py @@ -24,11 +24,28 @@ class InsertionPoint(object): AFTER_LAYER = 'after' BEFORE_LAYER = 'before' + @staticmethod + def from_str(input_str): + if input_str == "AFTER_LAYER": + return InsertionPoint.AFTER_LAYER + if input_str == "BEFORE_LAYER": + return InsertionPoint.BEFORE_LAYER + if input_str == "OPERATION_WITH_WEIGHTS": + return InsertionPoint.WEIGHTS + + raise RuntimeError('Wrong type of insertion point') + class QuantizationSetup(object): - def __init__(self, signed=None, narrow_range=False, init_value=6): + def __init__(self, signed=True, + narrow_range=False, + per_channel=False, + symmetric=True, + init_value=6): self.signed = signed self.narrow_range = narrow_range + self.per_channel = per_channel + self.symmetric = symmetric self.init_value = init_value @@ -76,8 +93,46 @@ def __init__(self, # point_dict['_target_type'].pop('__objclass__') # res.append(point_dict) def get_functional_retinanet_fq_placing_simular_to_nncf2_0(self, g): - path = 'examples/tensorflow/object_detection/configs/quantization/retinanet_quantization_layout.json' - layout = json.load(path) + path = 'configs/quantization/retinanet_quantization_layout.json' + with open(path, 'r') as inp: + layout = json.load(inp) + for l in layout: + l.update({'ops': [op for op in g.get_operations() if op.name.startswith(l['_layer_name'] +'/')]}) + + transformations = [] + for op_layout in layout: + layout_name = op_layout['_layer_name'] + setup = QuantizationSetup(signed=op_layout['signedness_to_force'] in (True, None), + narrow_range=op_layout['narrow_range'] or op_layout['half_range'], + per_channel=op_layout['per_channel']) + + insertion_point = InsertionPoint.from_str(op_layout['_target_type']['_name_']) + if layout_name.startswith('input'): + op = [g.get_operations()[0]] + elif layout_name.startswith('batch_normalization') or layout_name.endswith('bn'): + op = [op for op in op_layout['ops'] if op.type == 'FusedBatchNormV3'] + elif layout_name.startswith('l') or layout_name.startswith("post_hoc"): + op_type = 'BiasAdd' if insertion_point == InsertionPoint.AFTER_LAYER else 'Conv2D' + op = [op for op in op_layout['ops'] if op.type == op_type] + elif layout_name.startswith('class') or layout_name.startswith('box'): + # Skip shared conv by now + continue + elif (layout_name.startswith('p') and not layout_name.startswith('post_hoc')) \ + or layout_name.startswith('conv2d'): + op = [op for op in op_layout['ops'] if op.type == 'Conv2D'] + elif layout_name.startswith('up_sampling'): + op = [op for op in op_layout['ops'] if op.type == 'ResizeNearestNeighbor'] + elif any(any(layout_name.split('_')[-i].endswith(x) for i in [1, 2]) for x in ['Relu', 'add']): + op = op_layout['ops'] + if 'Relu' in layout_name: + setup.signed = False + else: + raise RuntimeError(f'You forgot about operation {layout_name}') + + assert len(op) == 1 + transformations.append((op[0], insertion_point, setup)) + + return transformations def get_keras_layer_mobilenet_v2_fq_placing_simular_to_nncf2_0(self, g): """Hardcode fq placing for examples.classification.test_models.get_KerasLayer_model""" @@ -108,9 +163,9 @@ def get_keras_layer_mobilenet_v2_fq_placing_simular_to_nncf2_0(self, g): # transformations = [] # Transformations for blocks - transformations.extend([(op, InsertionPoint.WEIGHTS, QuantizationSetup(signed=True)) for op in depthwise_conv]) - transformations.extend([(op, InsertionPoint.WEIGHTS, QuantizationSetup(signed=True)) for op in project_ops]) - transformations.extend([(op, InsertionPoint.WEIGHTS, QuantizationSetup(signed=True)) for op in expand_ops]) + transformations.extend([(op, InsertionPoint.WEIGHTS, QuantizationSetup(signed=True, narrow_range=False)) for op in depthwise_conv]) + transformations.extend([(op, InsertionPoint.WEIGHTS, QuantizationSetup(signed=True, narrow_range=False)) for op in project_ops]) + transformations.extend([(op, InsertionPoint.WEIGHTS, QuantizationSetup(signed=True, narrow_range=False)) for op in expand_ops]) transformations.extend([(op, InsertionPoint.AFTER_LAYER, QuantizationSetup(signed=False)) for op in depthwise_conv_relu]) transformations.extend([(op, InsertionPoint.AFTER_LAYER, QuantizationSetup(signed=True)) for op in project_bn]) @@ -120,14 +175,14 @@ def get_keras_layer_mobilenet_v2_fq_placing_simular_to_nncf2_0(self, g): # FQ on inputs transformations.append((first_conv, InsertionPoint.BEFORE_LAYER, QuantizationSetup(signed=True))) # FQ on first conv weights - transformations.append((first_conv, InsertionPoint.WEIGHTS, QuantizationSetup(signed=True))) + transformations.append((first_conv, InsertionPoint.WEIGHTS, QuantizationSetup(signed=True, narrow_range=False))) # FQ after first conv relu transformations.append((first_conv_relu, InsertionPoint.AFTER_LAYER, QuantizationSetup(signed=False))) # Transformation for net tail - transformations.append((last_conv, InsertionPoint.WEIGHTS, QuantizationSetup(signed=True))) + transformations.append((last_conv, InsertionPoint.WEIGHTS, QuantizationSetup(signed=True, narrow_range=False))) transformations.append((last_conv_relu, InsertionPoint.AFTER_LAYER, QuantizationSetup(signed=False))) transformations.append((avg_pool, InsertionPoint.AFTER_LAYER, QuantizationSetup(signed=False))) - transformations.append((prediction_mul, InsertionPoint.WEIGHTS, QuantizationSetup(signed=True))) + transformations.append((prediction_mul, InsertionPoint.WEIGHTS, QuantizationSetup(signed=True, narrow_range=False))) assert len(transformations) == 117 return transformations @@ -136,8 +191,12 @@ def build(self, input_shape=None): for training, model in zip([True, False], [self.trainable_model, self.eval_model]): if self.model_type != ModelType.KerasLayer: tf_f = tf.function(lambda x: model.orig_model.call(x, training=training)) - concrete = tf_f.get_concrete_function(*[tf.TensorSpec(input_shape, tf.float32)]) + input_signature = [] + for item in model.orig_model.inputs: + input_signature.append(tf.TensorSpec(item.shape, item.dtype)) + concrete = tf_f.get_concrete_function(input_signature) + structured_outputs = concrete.structured_outputs sorted_vars = get_sorted_on_captured_vars(concrete) model.mirrored_variables = model.orig_model.variables @@ -150,10 +209,6 @@ def build(self, input_shape=None): sorted_vars = get_sorted_on_captured_vars(concrete) model.mirrored_variables = self.create_mirrored_variables(sorted_vars) - ### - ### Generated weights preprocessing - ### - ### Insert compression operation if not self.initial_model_weights: self.initial_model_weights = self.get_numpy_weights_list(sorted_vars) @@ -173,31 +228,32 @@ def build(self, input_shape=None): # Add new op to layer if not self.ops_vars_created: self.op_vars = [] - enable_quantization = False + enable_quantization = True if enable_quantization: new_vars = [] transformations = self.get_functional_retinanet_fq_placing_simular_to_nncf2_0(concrete.graph) #transformations = self.get_keras_layer_mobilenet_v2_fq_placing_simular_to_nncf2_0(concrete.graph) if training: - pass - #self.initialize_trainsformations(concrete, transformations) + #pass + self.initialize_trainsformations(concrete, transformations) with concrete.graph.as_default() as g: # Insert given transformations for op, insertion_point, setup in transformations: def fq_creation(input_tensor, name): return create_fq_with_weights(input_tensor=input_tensor, + per_channel=setup.per_channel, name=name, signed=setup.signed, init_value=setup.init_value, narrow_range=setup.narrow_range) if insertion_point == InsertionPoint.AFTER_LAYER: - new_vars.append(insert_op_after(g, op, 0, fq_creation, op.name)) + new_vars.append(insert_op_after(g, op, 0, fq_creation, f'{op.name}_after_layer')) elif insertion_point == InsertionPoint.BEFORE_LAYER: new_vars.append(insert_op_before(g, op, 0, fq_creation, f'{op.name}_before_layer')) elif insertion_point == InsertionPoint.WEIGHTS: - new_vars.append(insert_op_before(g, op, 1, fq_creation, op.name)) + new_vars.append(insert_op_before(g, op, 1, fq_creation, f'{op.name}_weights')) else: raise RuntimeError('Wrong insertion point in quantization algo') @@ -220,12 +276,22 @@ def fq_creation(input_tensor, name): for new_var, (_, placeholder) in zip(new_ops_vars, old_captures[-len(self.op_vars):]): new_captures.append((new_var.handle, placeholder)) new_variables = [v for v in concrete.variables] + new_ops_vars + if len(new_variables) != len(new_captures): + raise RuntimeError('Len of the new vars should be the same as len' + ' of new captures (possible some compression weights missing)') + concrete = make_new_func(concrete.graph.as_graph_def(), new_captures, new_variables, concrete.inputs, concrete.outputs) + if structured_outputs is not None: + # The order should be the same because + # we use concrete.outputs when building new concrete function + #outputs_list = nest.flatten(structured_outputs, expand_composites=True) + concrete._func_graph.structured_outputs = \ + nest.pack_sequence_as(structured_outputs, concrete.outputs, expand_composites=True) model.output_tensor = concrete.graph.outputs model.fn_train = concrete @@ -297,14 +363,14 @@ def initialize_trainsformations(self, concrete, trainsformations): min_val, max_val = self.get_min_max_op_weights(concrete.graph, op, concrete.inputs, self.initial_model_weights) setup.init_value = max(abs(min_val), abs(max_val)) - setup.narrow_range = True + #setup.narrow_range = True if self.calibration_dataset is None: return outputs = [] activation_transformations = [t for t in trainsformations if t[1] != InsertionPoint.WEIGHTS] - for op, insertion_point, setup in activation_transformations: + for op, _, _ in activation_transformations: outputs.append(op.outputs[0]) # Create concrete function with outputs from each activation @@ -324,7 +390,7 @@ def initialize_trainsformations(self, concrete, trainsformations): # Update quantization setup for i, (_, _, setup) in enumerate(activation_transformations): setup.init_value = max(abs(np.mean(mins[i])), abs(np.mean(maxs[i]))) - setup.narrow_range = False + #setup.narrow_range = False def get_min_max_op_weights(self, graph, op, placeholders, np_vars): try: @@ -480,9 +546,11 @@ def insert_op_after(graph, target_parent, output_index, node_creation_fn, name): return node_weights -def create_fq_with_weights(input_tensor, name, signed, init_value, narrow_range): +def create_fq_with_weights(input_tensor, per_channel, name, signed, init_value, narrow_range): """Should be called in graph context""" with variable_scope.variable_scope('new_node'): + # Should check if variable already exist + # if it exist through error scale = variable_scope.get_variable( f'scale_{name}', shape=(), @@ -491,8 +559,13 @@ def create_fq_with_weights(input_tensor, name, signed, init_value, narrow_range) trainable=True) min = -scale if signed else 0. - output_tensor = tf.quantization.fake_quant_with_min_max_vars(input_tensor, min, scale, - narrow_range=narrow_range) + if False:#per_channel: + # Per channel not implemented yet + output_tensor = tf.quantization.fake_quant_with_min_max_vars_per_channel(input_tensor, min, scale, + narrow_range=narrow_range) + else: + output_tensor = tf.quantization.fake_quant_with_min_max_vars(input_tensor, min, scale, + narrow_range=narrow_range) return output_tensor, scale