From 026cbd68c831773dd97b5efcf813d6e588c02877 Mon Sep 17 00:00:00 2001 From: JanPalasek Date: Sun, 26 Apr 2020 21:42:22 +0200 Subject: [PATCH] Fixed loading from pretrained model --- src/captcha_detection/captcha_network.py | 4 ++-- src/test.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/captcha_detection/captcha_network.py b/src/captcha_detection/captcha_network.py index 82f8c00..b197aa3 100644 --- a/src/captcha_detection/captcha_network.py +++ b/src/captcha_detection/captcha_network.py @@ -69,12 +69,12 @@ def __init__(self, image_shape, classes: int, image_preprocess_pipeline, label_p self._model = tf.keras.Model(inputs=input, outputs=output) else: - self._model = tf.saved_model.load(args.pretrained_model) + self._model = tf.keras.models.load_model(args.pretrained_model) if args.weights_file is not None: self._model.load_weights(args.weights_file) - print(f"Total layers: {len(self._model.layers)}") + # print(f"Total layers: {len(self._model.layers)}") if args.remove_layers: # remove classification header and add new one input = self._model.layers[0].input diff --git a/src/test.py b/src/test.py index b717755..cfcbe5f 100644 --- a/src/test.py +++ b/src/test.py @@ -25,7 +25,7 @@ if __name__ == "__main__": # Parse arguments parser = argparse.ArgumentParser() - parser.add_argument("--weights_file", default="src/captcha_detection/model.h5", type=str, + parser.add_argument("--weights_file", default=None, type=str, help="Path to file that contains pre-trained weights.") parser.add_argument("--pretrained_model", default=None, type=str) parser.add_argument("--freeze_layers", default=0, type=int,