Skip to content

Commit

Permalink
Fixes to work properly with saved model.
Browse files Browse the repository at this point in the history
  • Loading branch information
JanPalasek committed Apr 26, 2020
1 parent e31fb08 commit 27cda1e
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 3 deletions.
1 change: 1 addition & 0 deletions src/accuracy/correctly_classified_captcha_accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import tensorflow as tf


def all_correct_acc(y_true: tf.Tensor, y_pred):
if y_true.shape[0] is None and y_true.shape[1] is None and y_true.shape[2] is None:
return tf.convert_to_tensor(0)
Expand Down
9 changes: 8 additions & 1 deletion src/captcha_detection/captcha_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,9 +141,13 @@ def __init__(self, image_shape, classes: int, image_preprocess_pipeline, label_p

self._loss = tf.keras.losses.SparseCategoricalCrossentropy()
self._optimizer = tf.keras.optimizers.Adam()

metrics = [tf.keras.metrics.sparse_categorical_accuracy]
if not args.save_model_path:
metrics.append(all_correct_acc)
self._model.compile(optimizer=tf.keras.optimizers.Adam(),
loss=tf.keras.losses.SparseCategoricalCrossentropy(),
metrics=[tf.keras.metrics.sparse_categorical_accuracy, all_correct_acc])
metrics=metrics)

self._model.summary()
plot_model(self._model, to_file=os.path.join(args.out_dir, "model.png"), show_shapes=True)
Expand All @@ -154,6 +158,9 @@ def __init__(self, image_shape, classes: int, image_preprocess_pipeline, label_p
self._check_callback = tf.keras.callbacks.ModelCheckpoint(
checkpoint_path, save_weights_only=True)

if args.save_model_path:
self.save_model(args.save_model_path)

def train(self, train_x, train_y, val_x, val_y, args):
train_inputs, train_labels = self._image_preprocess_pipeline(train_x), self._label_preprocess_pipeline(train_y)
dev_inputs, dev_labels = self._image_preprocess_pipeline(val_x), self._label_preprocess_pipeline(
Expand Down
3 changes: 1 addition & 2 deletions src/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
datetime.datetime.now().strftime("%Y-%m-%d_%H%M%S"),
",".join(("{}={}".format(re.sub("(.)[^_]*_?", r"\1", key), value) for key, value in sorted(vars(args).items())))
))
args.save_model_path = os.path.join(out_dir, "model")

dataset = CaptchaDataset(annotations_path, len(args.available_chars))
inputs, labels = dataset.get_data()
Expand All @@ -84,8 +85,6 @@
label_preprocess_pipeline=label_preprocess_pipeline,
args=args)

network.save_model(os.path.join(out_dir, "model"))

labels = label_preprocess_pipeline(labels)

pred_labels = network.predict(inputs)
Expand Down

0 comments on commit 27cda1e

Please sign in to comment.