-
Notifications
You must be signed in to change notification settings - Fork 0
/
predict.py
47 lines (40 loc) · 1.21 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
# Imports
from unicodedata import category
import tensorflow as tf
tf.get_logger().setLevel('WARNING')
import tensorflow_hub as hub
import numpy as np
import matplotlib.pyplot as plt
import argparse
import json
import time
from PIL import Image
import logging
logger = tf.get_logger()
logger.setLevel(logging.ERROR)
import utilities as ut
def _build_parser():
    """Return the command-line parser for the flower image classifier."""
    parser = argparse.ArgumentParser(description='Flowers Image Classifier')
    parser.add_argument('path', help='path to the image to classify')
    parser.add_argument('model', help='path to the saved Keras model')
    # type=int with default=1 replaces the former ``int(args.top_k)`` call,
    # which raised TypeError whenever --top_k was omitted, so the intended
    # "fall back to 1" branch could never run.
    parser.add_argument('--top_k', type=int, default=1,
                        help='number of most likely classes to report '
                             '(default: 1)')
    parser.add_argument('--category_names',
                        help='JSON file mapping class labels to flower names')
    return parser


def _lookup_flower_names(classes, class_names):
    """Return the flower name for each predicted class in ``classes[0]``.

    Each class index ``c`` is looked up in ``class_names`` under the string
    key ``str(c + 1)``; a missing key raises ``KeyError``.
    """
    # NOTE(review): the +1 presumably converts 0-based model indices to the
    # 1-based labels used by the JSON file -- confirm against the label map.
    return [class_names[str(c + 1)] for c in classes[0]]


def main(argv=None):
    """Classify one image and print its top-k classes and probabilities.

    If ``--category_names`` is given, the predicted classes are also mapped
    to flower names and the first one is reported as the most likely.

    Args:
        argv: Argument list to parse; ``None`` (the default) means
            ``sys.argv[1:]``, as with ``argparse.ArgumentParser.parse_args``.
    """
    parser = _build_parser()
    args = parser.parse_args(argv)
    if args.top_k < 1:
        # A non-positive k would leave nothing to report (flower_names[0]).
        parser.error('--top_k must be a positive integer')

    # Open and immediately close the image so a missing or unreadable file
    # fails before the slow model load. The old code kept this handle open
    # and never used it; ut.predict is given the path, not this object.
    with Image.open(args.path):
        pass

    # custom_objects lets Keras deserialize TF-Hub KerasLayer layers.
    model = tf.keras.models.load_model(
        args.model, custom_objects={'KerasLayer': hub.KerasLayer})
    probs, classes = ut.predict(args.path, model, args.top_k)
    print("\nClasses:", classes)
    print("\nProbabilities:", probs)

    if args.category_names is not None:
        with open(args.category_names, 'r', encoding='utf-8') as f:
            class_names = json.load(f)
        print("\nFlowernames:")
        flower_names = _lookup_flower_names(classes, class_names)
        print(flower_names)
        # NOTE(review): assumes ut.predict returns classes ordered most
        # likely first -- its contract is not visible here.
        print("\nMost probably it is a", flower_names[0])


if __name__ == "__main__":
    main()