
Commit

Fix retinanet training
daniil-lyakhov committed Aug 30, 2021
1 parent 7c8489b commit 96a094b
Showing 4 changed files with 20 additions and 4 deletions.
2 changes: 1 addition & 1 deletion examples/tensorflow/common/object_detection/base_model.py
@@ -49,7 +49,7 @@ def __init__(self, params):
         # One can use 'RESNET_FROZEN_VAR_PREFIX' to speed up ResNet training when loading from the checkpoint
         # RESNET_FROZEN_VAR_PREFIX = r'(resnet\d+)\/(conv2d(|_([1-9]|10))|batch_normalization(|_([1-9]|10)))\/'
         self._frozen_variable_prefix = ""
-        params_train_regularization_variable_regex = r'.*(kernel|weight):0$'
+        params_train_regularization_variable_regex = r'.*(kernel|weight|kernel_mirrored|weight_mirrored):0$'
         self._regularization_var_regex = params_train_regularization_variable_regex
         self._l2_weight_decay = params.weight_decay

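For context on the base_model.py change: under a mirrored (multi-replica) setup the wrapper appears to expose extra weight copies whose names end in "_mirrored", which the old pattern would have skipped, so L2 weight decay would not have been applied to them. A minimal sketch of how the widened regex filters variable names — the example names below are hypothetical, not taken from a real checkpoint:

    import re

    # Widened pattern from the diff: also pick up variables whose names carry a
    # "_mirrored" suffix. Example names are hypothetical.
    pattern = re.compile(r'.*(kernel|weight|kernel_mirrored|weight_mirrored):0$')

    names = [
        'retinanet/conv2d/kernel:0',              # matched before and after the change
        'retinanet/conv2d/kernel_mirrored:0',     # matched only by the widened pattern
        'retinanet/batch_normalization/gamma:0',  # never matched, so no weight decay is applied
    ]
    for name in names:
        print(name, bool(pattern.match(name)))
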
2 changes: 1 addition & 1 deletion examples/tensorflow/object_detection/main.py
@@ -329,7 +329,7 @@ def model_eval_fn(model):
    args = [model]
    inputs = tf.keras.layers.Input(shape=model.inputs[0].shape[1:], name=model.inputs[0].name.split(':')[0])
    outputs = NNCFWrapperCustom(*args, caliblration_dataset=train_dataset,
-                               enable_mirrored_vars_split=False)(inputs)
+                               enable_mirrored_vars_split=True)(inputs)
    compress_model = tf.keras.Model(inputs=inputs, outputs=outputs)

    scheduler = build_scheduler(
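
The main.py change only flips enable_mirrored_vars_split to True, but the surrounding wrap-and-rebuild pattern is worth spelling out: the original model is fed through a wrapper layer and a fresh functional model is built around it. A rough sketch with a stand-in passthrough layer (PassthroughWrapper is hypothetical; the real NNCFWrapperCustom also inserts compression ops and takes the calibration dataset):

    import tensorflow as tf

    class PassthroughWrapper(tf.keras.layers.Layer):
        # Stand-in for NNCFWrapperCustom: just forwards the call to the wrapped model.
        def __init__(self, wrapped_model, **kwargs):
            super().__init__(**kwargs)
            self.wrapped_model = wrapped_model

        def call(self, inputs, training=None):
            return self.wrapped_model(inputs, training=training)

    base = tf.keras.applications.MobileNetV2(weights=None)  # any built Keras model
    inputs = tf.keras.layers.Input(shape=base.inputs[0].shape[1:],
                                   name=base.inputs[0].name.split(':')[0])
    outputs = PassthroughWrapper(base)(inputs)
    compress_model = tf.keras.Model(inputs=inputs, outputs=outputs)
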
6 changes: 5 additions & 1 deletion
@@ -68,10 +68,14 @@ def build_outputs(self, inputs, is_training):

        return model_outputs

+    @staticmethod
+    def get_zero_replica_from_mirrored_var(var):
+        return var._get_replica(0)
+
    def build_loss_fn(self, keras_model, compression_loss_fn):
        #filter_fn = self.make_filter_trainable_variables_fn()
        #trainable_variables = filter_fn(keras_model.trainable_variables)
-       trainable_variables = [v for v in keras_model.layers[1].trainable_model.mirrored_variables if v.trainable]
+       trainable_variables = [self.get_zero_replica_from_mirrored_var(v) for v in keras_model.layers[1].trainable_model.mirrored_variables if v.trainable]

        def _total_loss_fn(labels, outputs):
            cls_loss = self._cls_loss_fn(outputs['cls_outputs'],
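
The new get_zero_replica_from_mirrored_var helper unwraps a MirroredVariable to the plain per-replica variable on replica 0 via the private _get_replica method, so the loss function works with concrete variables rather than the distributed wrappers. A small sketch of the same idea using the public strategy API instead (assumes a local MirroredStrategy; this is not the code path the commit itself takes):

    import tensorflow as tf

    strategy = tf.distribute.MirroredStrategy()   # falls back to the local CPU/GPU devices
    with strategy.scope():
        weight = tf.Variable(1.0, name='weight')  # created as a mirrored variable under the scope

    # Public way to reach the per-replica components; the commit calls the private
    # DistributedVariable._get_replica(0) for the same effect.
    replica_zero = strategy.experimental_local_results(weight)[0]
    print(type(weight).__name__, type(replica_zero).__name__)
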
14 changes: 13 additions & 1 deletion op_insertion.py
@@ -198,7 +198,11 @@ def build(self, input_shape=None):
            concrete = tf_f.get_concrete_function(input_signature)
            structured_outputs = concrete.structured_outputs
            sorted_vars = get_sorted_on_captured_vars(concrete)
-           model.mirrored_variables = model.orig_model.variables
+           if isinstance(model.orig_model.variables[0], MirroredVariable):
+               model.mirrored_variables = model.orig_model.variables
+           else:
+               # Case when the model was built before the replica context
+               model.mirrored_variables = self.create_mirrored_variables(sorted_vars)

        else:
            concrete = make_new_func(model.graph_def,
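
The first op_insertion.py hunk makes the build step distribution-aware: if the wrapped model was created inside a strategy scope, its variables are already MirroredVariable instances and can be reused; otherwise mirrored copies are created from the captured variables. A rough public-API sketch of the same check (tf.distribute.DistributedValues stands in for the private MirroredVariable class; the toy models are hypothetical):

    import tensorflow as tf

    def variables_are_mirrored(model):
        # True when the model's variables were created under a distribution strategy scope.
        return isinstance(model.variables[0], tf.distribute.DistributedValues)

    strategy = tf.distribute.MirroredStrategy()
    with strategy.scope():
        mirrored_model = tf.keras.Sequential([tf.keras.layers.Dense(4, input_shape=(8,))])
    plain_model = tf.keras.Sequential([tf.keras.layers.Dense(4, input_shape=(8,))])

    print(variables_are_mirrored(mirrored_model))  # expected True
    print(variables_are_mirrored(plain_model))     # expected False
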
@@ -350,6 +354,14 @@ def call(self, inputs, training=None):
                                 model_obj.fn_train.inputs,
                                 model_obj.output_tensor)

+       if model_obj.fn_train.structured_outputs is not None:
+           # The order should be the same because
+           # we use concrete.outputs when building the new concrete function
+           #outputs_list = nest.flatten(structured_outputs, expand_composites=True)
+           fn_train._func_graph.structured_outputs = \
+               nest.pack_sequence_as(model_obj.fn_train.structured_outputs,
+                                     fn_train.outputs,
+                                     expand_composites=True)
        return fn_train(inputs)

    def initialize_trainsformations(self, concrete, trainsformations):
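
The second op_insertion.py hunk restores the structured outputs on the rebuilt concrete function so downstream code can keep indexing results by name (the loss above reads outputs['cls_outputs'], for example). The mechanism is tf.nest.pack_sequence_as, which rebuilds a nested structure from a flat tensor list as long as the flat order matches. A minimal illustration with a made-up two-headed output dict:

    import tensorflow as tf

    # Made-up structured outputs in the spirit of a detection model's heads.
    structured_outputs = {'cls_outputs': tf.zeros([1, 10]), 'box_outputs': tf.zeros([1, 4])}

    flat_outputs = tf.nest.flatten(structured_outputs, expand_composites=True)
    repacked = tf.nest.pack_sequence_as(structured_outputs, flat_outputs,
                                        expand_composites=True)

    print(repacked['cls_outputs'].shape, repacked['box_outputs'].shape)  # (1, 10) (1, 4)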
