# main.py (71 lines, 55 loc, 2.62 KB)
# Import external libraries
from tensorflow.keras.applications.vgg16 import VGG16, decode_predictions, preprocess_input
from tensorflow.keras.preprocessing import image
import matplotlib.pyplot as plt
import os
import glob
from tqdm import tqdm
# Import internal libraries
from analysis import analysis
from utils import process_image, rescale, get_args
def _collect_input_files(source):
    """Return the sorted list of regular files selected by *source*.

    *source* may be a directory, in which case every (non-hidden) file directly
    inside it is used, or a glob pattern such as ``images/*.jpg``.
    """
    # glob.glob() on a bare directory path only returns the directory itself,
    # so expand directories to their contents first.
    pattern = os.path.join(source, "*") if os.path.isdir(source) else source
    return sorted(path for path in glob.glob(pattern) if os.path.isfile(path))


def main():
    """Run the explanation methods over every input image and save the results.

    The input and output folder names come from ``get_args()`` and are resolved
    relative to this script's directory. For each input image, ``analysis`` is
    run with a VGG16 (ImageNet weights) model. Four images are written to
    ``<target>/<image name>/``: the ground truth plus the backprop, gradcam
    and squadcam outputs.

    Images that cannot be loaded (``OSError``) or for which ``analysis`` raises
    ``ValueError`` are skipped. They are reported together at the end instead
    of aborting the run.
    """
    # Resolve input/output locations relative to this script's directory.
    current_folder = os.path.dirname(os.path.abspath(__file__))
    in_folder, out_folder = get_args()
    source = os.path.join(current_folder, in_folder)
    target = os.path.join(current_folder, out_folder)

    # Define the target model.
    modelname = "vgg"  # tag used in the output file names
    model = VGG16(weights="imagenet")
    preprocessor = preprocess_input
    decoder = decode_predictions
    target_layer = "block5_conv3"  # layer name handed to analysis()
    n_classes = 1000  # number of outputs of the ImageNet-weighted VGG16

    filenames = _collect_input_files(source)
    if not filenames:
        print("No input files found for {}".format(source))

    # Create the output root (and any missing parents).
    os.makedirs(target, exist_ok=True)
    errorlog = []
    print("Analysing all provided images.")
    for filename in tqdm(filenames):
        # File name without its final extension; used as the output sub-folder.
        name = os.path.splitext(os.path.basename(filename))[0]

        try:
            img = image.load_img(filename, target_size=(224, 224))
        except OSError as err:
            # Unreadable / non-image file: record it and move on.
            errorlog.append("OSError: Image {} - could not be loaded ({})".format(filename, err))
            continue

        # Only the second return value (the ground-truth image) is kept.
        _, ground_truth = process_image(img, preprocessor)

        # Apply explanation methods.
        try:
            backprop, gradcam, squadcam, data = analysis(
                img, model, preprocessor, decoder, target_layer, n_classes
            )
        except ValueError as err:
            # Skip this image: falling through would raise NameError on the
            # first image, or re-save the previous image's explanations.
            errorlog.append("ValueError: Image {} - {} (image size {})".format(filename, err, img.size))
            continue

        out_dir = os.path.join(target, name)
        os.makedirs(out_dir, exist_ok=True)
        # data[1] is used as a file-name tag; presumably the predicted class
        # label returned by analysis() -- verify against its implementation.
        label = data[1]
        plt.imsave(os.path.join(out_dir, "ground_truth.jpg"), rescale(ground_truth))
        explanations = (("backprop", backprop), ("gradcam", gradcam), ("squadcam", squadcam))
        for method, explanation in explanations:
            plt.imsave(
                os.path.join(out_dir, "{}_{}_{}.jpg".format(method, modelname, label)),
                explanation,
            )

    if errorlog:
        print("Errors:\n{}".format("\n".join(errorlog)))
    print("Analysis completed")
if __name__ == "__main__":
    # Script entry point: run the batch analysis only when executed directly,
    # not when this module is imported.
    main()