Skip to content

Commit

Permalink
Fixed loading from pretrained model
Browse files Browse the repository at this point in the history
  • Loading branch information
JanPalasek committed Apr 26, 2020
1 parent c9316fe commit 026cbd6
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
4 changes: 2 additions & 2 deletions src/captcha_detection/captcha_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,12 +69,12 @@ def __init__(self, image_shape, classes: int, image_preprocess_pipeline, label_p

self._model = tf.keras.Model(inputs=input, outputs=output)
else:
self._model = tf.saved_model.load(args.pretrained_model)
self._model = tf.keras.models.load_model(args.pretrained_model)

if args.weights_file is not None:
self._model.load_weights(args.weights_file)

print(f"Total layers: {len(self._model.layers)}")
# print(f"Total layers: {len(self._model.layers)}")
if args.remove_layers:
# remove classification header and add new one
input = self._model.layers[0].input
Expand Down
2 changes: 1 addition & 1 deletion src/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
if __name__ == "__main__":
# Parse arguments
parser = argparse.ArgumentParser()
parser.add_argument("--weights_file", default="src/captcha_detection/model.h5", type=str,
parser.add_argument("--weights_file", default=None, type=str,
help="Path to file that contains pre-trained weights.")
parser.add_argument("--pretrained_model", default=None, type=str)
parser.add_argument("--freeze_layers", default=0, type=int,
Expand Down

0 comments on commit 026cbd6

Please sign in to comment.