diff --git a/examples/tensorflow/common/object_detection/base_model.py b/examples/tensorflow/common/object_detection/base_model.py
index 0ec1bc39655..3bdf080ce26 100644
--- a/examples/tensorflow/common/object_detection/base_model.py
+++ b/examples/tensorflow/common/object_detection/base_model.py
@@ -49,7 +49,7 @@ def __init__(self, params):
         # One can use 'RESNET_FROZEN_VAR_PREFIX' to speed up ResNet training when loading from the checkpoint
         # RESNET_FROZEN_VAR_PREFIX = r'(resnet\d+)\/(conv2d(|_([1-9]|10))|batch_normalization(|_([1-9]|10)))\/'
         self._frozen_variable_prefix = ""
-        params_train_regularization_variable_regex = r'.*(kernel|weight):0$'
+        params_train_regularization_variable_regex = r'.*(kernel|weight|kernel_mirrored|weight_mirrored):0$'
         self._regularization_var_regex = params_train_regularization_variable_regex
         self._l2_weight_decay = params.weight_decay

diff --git a/examples/tensorflow/object_detection/main.py b/examples/tensorflow/object_detection/main.py
index e2a3172e061..189a8de6487 100644
--- a/examples/tensorflow/object_detection/main.py
+++ b/examples/tensorflow/object_detection/main.py
@@ -329,7 +329,7 @@ def model_eval_fn(model):
    args = [model]
    inputs = tf.keras.layers.Input(shape=model.inputs[0].shape[1:], name=model.inputs[0].name.split(':')[0])
    outputs = NNCFWrapperCustom(*args, caliblration_dataset=train_dataset,
-                               enable_mirrored_vars_split=False)(inputs)
+                               enable_mirrored_vars_split=True)(inputs)
    compress_model = tf.keras.Model(inputs=inputs, outputs=outputs)

    scheduler = build_scheduler(
diff --git a/examples/tensorflow/object_detection/models/retinanet_model.py b/examples/tensorflow/object_detection/models/retinanet_model.py
index f2593a3e352..3a98f620840 100644
--- a/examples/tensorflow/object_detection/models/retinanet_model.py
+++ b/examples/tensorflow/object_detection/models/retinanet_model.py
@@ -68,10 +68,14 @@ def build_outputs(self, inputs, is_training):

         return model_outputs

+    @staticmethod
+    def get_zero_replica_from_mirrored_var(var):
+        return var._get_replica(0)
+
     def build_loss_fn(self, keras_model, compression_loss_fn):
         #filter_fn = self.make_filter_trainable_variables_fn()
         #trainable_variables = filter_fn(keras_model.trainable_variables)
-        trainable_variables = [v for v in keras_model.layers[1].trainable_model.mirrored_variables if v.trainable]
+        trainable_variables = [self.get_zero_replica_from_mirrored_var(v) for v in keras_model.layers[1].trainable_model.mirrored_variables if v.trainable]

         def _total_loss_fn(labels, outputs):
             cls_loss = self._cls_loss_fn(outputs['cls_outputs'],
diff --git a/op_insertion.py b/op_insertion.py
index 6da5f72cfbc..1f4539fbfd5 100644
--- a/op_insertion.py
+++ b/op_insertion.py
@@ -198,7 +198,11 @@ def build(self, input_shape=None):
                 concrete = tf_f.get_concrete_function(input_signature)
                 structured_outputs = concrete.structured_outputs
                 sorted_vars = get_sorted_on_captured_vars(concrete)
-                model.mirrored_variables = model.orig_model.variables
+                if isinstance(model.orig_model.variables[0], MirroredVariable):
+                    model.mirrored_variables = model.orig_model.variables
+                else:
+                    # Case when model build before replica context
+                    model.mirrored_variables = self.create_mirrored_variables(sorted_vars)
             else:
                 concrete = make_new_func(model.graph_def,
@@ -350,6 +354,14 @@ def call(self, inputs, training=None):
                               model_obj.fn_train.inputs,
                               model_obj.output_tensor)

+        if model_obj.fn_train.structured_outputs is not None:
+            # The order should be the same because
+            # we use concrete.outputs when building new concrete function
+            #outputs_list = nest.flatten(structured_outputs, expand_composites=True)
+            fn_train._func_graph.structured_outputs = \
+                nest.pack_sequence_as(model_obj.fn_train.structured_outputs,
+                                      fn_train.outputs,
+                                      expand_composites=True)
         return fn_train(inputs)

     def initialize_trainsformations(self, concrete, trainsformations):