-
Notifications
You must be signed in to change notification settings - Fork 1
/
inference.py
266 lines (242 loc) · 13.7 KB
/
inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
from tqdm import tqdm
from ensemble_scheme import apply_voting_to_ensemble_predictions
from inf_helpers import *
from time import time
from datagen import DataGen
import argparse
import pickle
from focal_loss import BinaryFocalLoss
# Prediction
def make_prediction(_model, img_path, test_label, img_size, model_type, visual=False, ensemble=False):
    """
    Generate a segmentation prediction for a single raw image.

    :param _model: A model to use for the prediction - it must be able to accept images with size = 'img_size'.
    :param img_path: The path to the raw image whose mask will be generated by the model.
    :param test_label: The path to the corresponding ground-truth annotated mask. It is only read in
        ensemble mode (with ``visual=True``), where it is returned as a 1-channel ('L') PIL image.
        If it is None, None is returned in its place.
    :param img_size: The size images have to be in to fit the model's input representation.
    :param model_type: The network architecture type - 'unet' or 'mob_net'.
    :param visual: Whether to build the mask/overlay images. When False (e.g. when only the inference
        time is being measured) no visuals are produced and the function returns None.
    :param ensemble: Whether we are in ensemble mode. In ensemble mode no image is shown and the raw test
        image and ground-truth label are returned as well.
    :return: None when ``visual`` is False. Otherwise:
        - non-ensemble: ``(img_name, to_score)``
        - ensemble: ``(img_name, raw_test_image, to_score, ground_truth_label)``
        where ``to_score`` is whatever ``generate_true_and_predicted_masked_img_and_overlay`` returns
        (its last element is the overlay that gets displayed).
    """
    test_image = Image.open(img_path)
    test_arr = prepare_for_prediction(test_image, img_size)
    prediction = _model.predict(test_arr)
    # UNet outputs are post-processed by interpret_prediction; other architectures go through create_mask.
    if model_type == 'unet':
        resulting_img = interpret_prediction(prediction, test_image.size)
    else:
        resulting_img = create_mask(prediction).resize(test_image.size)

    if not visual:
        # Timing-only run: the caller treats a falsy result as "no visuals to store".
        return None

    img_file_name = img_path.rpartition('/')[-1]
    to_score = generate_true_and_predicted_masked_img_and_overlay(test_image=test_image, pred_mask=resulting_img,
                                                                  mdl_version=model_type)
    if not ensemble:
        show_img(to_score[-1], title=f'Prediction made for image: {img_file_name}')
        # Other visualisations:
        # show_img(test_image, title=f'Original raw image for image - {img_file_name}')  # REAL IMAGE ONLY
        # show_img(resulting_img, title='Predicted segmentation mask')  # PREDICTED GRAYSCALE SEGMENTATION MASK ONLY
        return img_file_name, to_score

    # Ensemble mode also needs the ground-truth label. It must be read from `test_label`
    # (the annotated mask), not from `img_path` (the raw image). The context manager closes the file once the
    # converted copy exists.
    test_lbl = None
    if test_label is not None:
        with Image.open(test_label) as label_img:
            test_lbl = label_img.convert('L')
    return img_file_name, test_image, to_score, test_lbl
if __name__ == '__main__':
    # This argument parser is responsible for handling command-line arguments
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument(
        '-c', '--config_filepath',
        help='Path to a pickled dictionary object (include the .pickle extension, too) containing all configuration '
             'params.\n The config.py file can be used to generate such a file if it is not present in current '
             'directory. \nThis file is required because its config params will be used to create a DataGen object for '
             'testing data generation.',
        type=str, default='default_config.pickle'
    )
    parser.add_argument(
        '-s', '--saved_models_dir',
        help='The path where the pretrained models are saved.\nThis will be used to load models into memory for '
             'evaluation.',
        type=str, default='models/'
    )
    parser.add_argument(
        '-e', '--ensemble',
        help='Enabling this flag option will set the evaluation in ensemble mode. \n'
             'This will require several models to be loaded (specified during model selection process).\n',
        action='store_true'
    )
    parser.add_argument(
        '-i', '--intermediate_path',
        help='This option allows you to decide whether to save the intermediate \n'
             'outputs during the ensemble voting stage.\n'
             'They are all saved to directory named \'intermediate_outputs\'.\n',
        action='store_true'
    )
    # Get values of arguments - if none are given, the defaults are used
    args = parser.parse_args()

    # Load the pickled config. If it is missing, abort and tell the user how to create it.
    # NOTE: unpickling can execute arbitrary code - only load config files you created/trust.
    try:
        with open(args.config_filepath, 'rb') as pickled_dict:
            params = pickle.load(pickled_dict)
    except FileNotFoundError as err:
        print(f'The following error occurred during execution:\n\t-> {err}\n'
              f'Please define a valid path to a pickled config dict. It can be generated by executing:\n\t'
              f'>>> python config.py | for help run python config.py -h')
        raise SystemExit(1)

    # The saved-models path must exist and be a directory; otherwise abort with a non-zero exit status.
    if not os.path.exists(args.saved_models_dir):
        print(f'The given model path {args.saved_models_dir} does not exist.\n'
              f'Ensure that the correct name of the dir where all pretrained models reside is\n'
              f'given as an argument to this execution call. Aborting...')
        raise SystemExit(1)
    if not os.path.isdir(args.saved_models_dir):
        print(f'The given model path {args.saved_models_dir} is not a directory.\n'
              f'Ensure that the correct name of the dir where all pretrained models reside is\n'
              f'given as an argument to this execution call. Aborting...')
        raise SystemExit(1)

    no_models_msg = (f'The given directory for saved models [{args.saved_models_dir}] is empty. '
                     f'\nEither there are no models available or'
                     ' their folders\' names do not start with the word \'model\'.')

    if args.ensemble:
        # Ensemble scheme: the user selects several models, all of which are loaded
        models_paths = get_saved_models_and_choice(save_dir=args.saved_models_dir, ensemble=args.ensemble)
        if models_paths == -1:
            print(no_models_msg)
            raise SystemExit(1)
        print('\nLoading all models...')
        loaded_models = [tf.keras.models.load_model(model_path) for model_path in tqdm(models_paths)]
        print('All models have been loaded successfully.\n')
        # NOTE(review): unlike single-model mode, no 'used_images.txt' is read here, so images that some
        # ensemble member was trained on may end up in the test set - confirm this is intended.
        images_used_to_train = []
    else:
        # Get user to select a pretrained model from all available ones
        pretrained_model_path = get_saved_models_and_choice(save_dir=args.saved_models_dir)
        if pretrained_model_path == -1:
            print(no_models_msg)
            raise SystemExit(1)
        # Load it using the Keras API
        loaded_model = tf.keras.models.load_model(pretrained_model_path)
        print("\nModel loaded successfully. Printing model summary...\n")
        loaded_model.summary()
        # Read the names of all images used for training so that, during inference,
        # the model is not given already seen examples
        with open(os.path.join(pretrained_model_path, 'used_images.txt'), 'r') as f:
            images_used_to_train = [name.strip() for name in f]  # strip() removes the trailing '\n'

    # Choose whether to time the model execution or to generate visuals
    try:
        time_it = bool(int(input(
            "To time the execution of the predictor, type an integer. "
            "All other inputs will be regarded as a 'NO'\n\t-> ")))
    except ValueError:
        time_it = False

    # Ensemble voting consumes the masks that make_prediction only builds when visual=True,
    # so ensemble mode always requests them. No plots are shown in ensemble mode.
    visual = args.ensemble or not time_it
    if args.ensemble and time_it:
        print('Note: ensemble voting needs the predicted masks, so the reported time also '
              'includes mask generation and voting.')

    # Create a data generator object from variables in the config file
    data_gen = DataGen(
        data_home=params['data_home'],
        train_dir=params['train_dir'],
        label_dir=params['label_dir'],
        test_dir=params['test_dir'],
        batch_size=params['BATCH_SIZE'],
        img_size=params['IMG_SIZE'],
        model_type=params['VERSION'],
        val_split=params['VAL_SPLIT'],
        augmentation_data=params['AUGMENTATION_DATA'],
        partial_sampling=params['PARTIAL_SAMPLING']
    )
    # Create a generator for testing samples - pairs of img and label paths
    test_gen = data_gen.get_testing_data(images_used_to_train)
    # Count all available original testing images without materialising them into a list
    total_test_imgs = sum(1 for _ in data_gen.get_testing_data(images_used_to_train))
    # Get user to select how many samples they want to predict
    try:
        count = min(int(input(f"How many testing samples would you like to predict (Total {total_test_imgs}): \n\t")),
                    total_test_imgs)
    except ValueError:
        count = total_test_imgs

    times = []  # All timing scores (averaged at the end)
    masked_predictions = []  # (img_name, masks_and_overlay) tuples that may be saved to storage later
    outs = ()
    for _ in tqdm(range(count)):
        start = time()  # get time (can be used for debugging even if not requested by user)
        try:
            path_to_imgs = next(test_gen)  # (path to test img, path to its label)
        except StopIteration:
            # All testing samples exhausted: start over from the beginning
            print('\nResetting testing data generator as all images were exhausted')
            test_gen = data_gen.get_testing_data(images_used_to_train)
            path_to_imgs = next(test_gen)

        if not args.ensemble:
            outs = make_prediction(
                _model=loaded_model,
                img_path=path_to_imgs[0],
                test_label=path_to_imgs[1],
                img_size=params['IMG_SIZE'],
                model_type=params['VERSION'],
                visual=visual
            )
        else:
            ensemble_predictions = []
            test_img, test_lbl_ = None, None
            for ensemble_member in loaded_models:
                outs = make_prediction(
                    _model=ensemble_member,
                    img_path=path_to_imgs[0],
                    test_label=path_to_imgs[1],
                    img_size=params['IMG_SIZE'],
                    model_type=params['VERSION'],
                    visual=visual,
                    ensemble=True
                )
                ensemble_predictions.append(outs[2][1])
                test_img = outs[1]
                test_lbl_ = outs[3]
            if args.intermediate_path:
                img_name = outs[0].rpartition('\\')[-1]
                save_path = f'intermediate_outputs/{img_name}/'
                os.makedirs(save_path, exist_ok=True)
            else:
                save_path = None
            # Get average voting result from all predictions
            averaged_pixelwise_voting = apply_voting_to_ensemble_predictions(
                predictions=ensemble_predictions, savepath=save_path
            )
            # Multiply by 255 to rescale the vote back to the full colour range
            averaged_pixelwise_voting = averaged_pixelwise_voting * 255
            averaged_voted_image_arr = generate_true_and_predicted_masked_img_and_overlay(
                test_image=test_img,
                pred_mask=Image.fromarray(averaged_pixelwise_voting[:, :, 2]),
                # real_mask=test_lbl_,
                mdl_version=params['VERSION']
            )[-1]
            # Replace the last model's overlay with the ensemble-voted one and keep (name, visuals)
            outs[2][-1] = averaged_voted_image_arr
            outs = (outs[0], outs[2])
        times.append(time() - start)  # elapsed time for prediction (and maybe visuals generation)
        if outs:  # visuals were generated, keep them for optional saving
            masked_predictions.append(outs)

    if masked_predictions:  # If visuals were produced, ask the user whether to save them to storage
        save_ = input('\nWould you like to save the predicted images in a directory? (Y\\n) \n'
                      'Typing anything else will abort this functionality. \n')
        # Explicit membership test: `save_ in 'yY'` would also accept an empty answer.
        if save_.strip() in ('y', 'Y'):
            save_dir_ = input('\nNow, input a valid directory name to save the images in. \n')
            # os.path.join keeps an empty directory name relative to the CWD instead of
            # resolving to the filesystem root as '/predicted_masks/...'.
            predicted_dir = os.path.join(save_dir_, 'predicted_masks')
            os.makedirs(predicted_dir, exist_ok=True)
            print(f'Saving all images to "{save_dir_}" now...')
            for name, gt_and_pred_masks_and_overlay in tqdm(masked_predictions):
                name = name.rpartition('\\')[-1]
                Image.fromarray(gt_and_pred_masks_and_overlay[0]).save(os.path.join(predicted_dir, f'test_{name}'))
                Image.fromarray(gt_and_pred_masks_and_overlay[1]).save(os.path.join(save_dir_, f'test_{name}'))

    if times:  # np.mean of an empty list would print nan (and warn)
        print(f"\nAverage time taken for prediction in seconds: {np.mean(times):.2f}")
        if not time_it:
            print('The given time includes various transforms and plotting operations.')
            print('If an accurate result is sought that measures only the inference time')
            print('and its corresponding image pre-processing, consider requesting timing at the beginning.')