# test.py
"""
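Gintautas Plonis 1812957
EfficientDet | Focal loss | Raven, Coffee, Headphones
(Optional) REST API

Runs a trained EfficientDet model over the Open Images 'val' split and writes
a copy of every image that produced detections, with the predicted boxes and
class labels drawn on it, into the output directory.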
"""
import os
import argparse
import torch
from torchvision import transforms
from tqdm import tqdm
from src.dataset import Resizer, Normalizer, OpenImagesDataset
from src.config import OPEN_IMAGES_COLORS, OPEN_IMAGES_CLASSES
import cv2
import shutil
def get_args(argv=None):
    """Parse the command-line options of the prediction script.

    Args:
        argv: Optional list of argument strings to parse instead of the
            process command line. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, which is what callers without arguments get.

    Returns:
        argparse.Namespace with the attributes ``image_size``, ``data_path``,
        ``cls_threshold``, ``nms_threshold``, ``model`` and ``output``.
    """
    # The first positional parameter of ArgumentParser is ``prog``. The title
    # is a description, so pass it by keyword and let argparse derive the
    # program name shown in the usage line from the script name.
    parser = argparse.ArgumentParser(
        description="EfficientDet: Scalable and Efficient Object Detection")
    parser.add_argument("--image_size", type=int, default=512,
                        help="The common width and height for all images")
    parser.add_argument("--data_path", type=str, default="data",
                        help="The root folder of dataset")
    parser.add_argument("--cls_threshold", type=float, default=0.5,
                        help="Minimum score a detection needs to be drawn")
    parser.add_argument("--nms_threshold", type=float, default=0.5,
                        help="Non-maximum suppression threshold")
    parser.add_argument("--model", type=str,
                        default="trained_models/efficientdet10-100000.pth",
                        help="Path of the saved model loaded with torch.load")
    parser.add_argument("--output", type=str, default="predictions",
                        help="Folder for annotated images; an existing folder "
                             "is deleted and recreated")
    return parser.parse_args(argv)
def _load_model(model_path, device):
    """Load the saved model from *model_path*, move it to *device*, set eval mode."""
    # NOTE(review): torch.load unpickles arbitrary Python objects, so --model
    # must only point at checkpoints from a trusted source. Newer PyTorch
    # releases default to weights_only=True, which may refuse a fully pickled
    # module -- confirm against the installed version.
    # ``.module`` unwraps the container the checkpoint was saved in
    # (presumably torch.nn.DataParallel -- confirm against the training code).
    model = torch.load(model_path, map_location=device).module.to(device)
    # Inference only: layers such as dropout and batch norm must not run in
    # training mode, whatever state the checkpoint was pickled in.
    model.eval()
    return model


def _prepare_output_dir(output_dir, class_names):
    """Recreate *output_dir* from scratch with one lowercased folder per class."""
    if os.path.isdir(output_dir):
        # Deliberately destructive: results of a previous run are discarded.
        shutil.rmtree(output_dir)
    os.makedirs(output_dir)
    for name in class_names:
        # exist_ok guards against two class names that differ only by case.
        os.makedirs(os.path.join(output_dir, name.lower()), exist_ok=True)


def _draw_detection(image, box, label, prob):
    """Draw one box plus a filled '<class>: <score>' caption onto *image* in place."""
    x1, y1, x2, y2 = box
    color = OPEN_IMAGES_COLORS[label]
    caption = f'{OPEN_IMAGES_CLASSES[label]}: {prob:.2f}'
    top_left = (int(x1), int(y1))
    cv2.rectangle(image, top_left, (int(x2), int(y2)), color, thickness=2)
    text_w, text_h = cv2.getTextSize(caption, cv2.FONT_HERSHEY_PLAIN, 1, 1)[0]
    # Filled rectangle behind the caption so the white text stays readable.
    cv2.rectangle(image, top_left,
                  (int(x1 + text_w + 3), int(y1 + text_h + 4)), color, thickness=-1)
    cv2.putText(image, caption, (int(x1), int(y1 + text_h + 4)),
                cv2.FONT_HERSHEY_PLAIN, 1, (255, 255, 255), 1)


def test(opt):
    """Run the model over the 'val' split and save annotated prediction images.

    For every dataset item the model is called on the transformed image, the
    returned boxes are divided by the item's ``scale``, and every detection
    whose score reaches ``opt.cls_threshold`` is drawn on the image file read
    from ``<data_path>/val/<class>/images/<id>.jpg``. The result is written to
    ``<output>/<class>/<id>_prediction.jpg``. Items for which the model
    returns no boxes produce no output file.

    Args:
        opt: Namespace as returned by ``get_args()``; the attributes read here
            are ``model``, ``data_path``, ``output`` and ``cls_threshold``.

    Raises:
        FileNotFoundError: If the source image of an item cannot be read.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = _load_model(opt.model, device)
    dataset = OpenImagesDataset(root_dir=opt.data_path, set_name='val',
                                transform=transforms.Compose([Normalizer(), Resizer()]))
    _prepare_output_dir(opt.output, dataset.class_names)

    for idx in tqdm(range(len(dataset))):
        data = dataset[idx]
        scale = data['scale']
        with torch.no_grad():
            # The permute/unsqueeze turns the image into a one-item batch with
            # the last axis moved to the front (assumes H x W x C input --
            # TODO confirm against Resizer).
            batch = data['img'].to(device).permute(2, 0, 1).float().unsqueeze(dim=0)
            scores, labels, boxes = model(batch)
            # Undo the resize so the boxes match the image file on disk.
            boxes /= scale
        if boxes.shape[0] == 0:
            continue

        image_id = dataset.images[idx]
        class_name = dataset.image_to_category_name[image_id]
        path = f'{opt.data_path}/val/{class_name}/images/{image_id}.jpg'
        output_image = cv2.imread(path)
        if output_image is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f'Could not read image: {path}')

        for score, label, box in zip(scores.tolist(), labels.tolist(), boxes.tolist()):
            # Skip rather than stop: this stays correct even when the scores
            # are not sorted in descending order.
            if score < opt.cls_threshold:
                continue
            _draw_detection(output_image, box, int(label), score)

        # The class folders were created lowercased; make sure the folder that
        # is actually written to exists, because cv2.imwrite fails silently.
        out_dir = os.path.join(opt.output, class_name)
        os.makedirs(out_dir, exist_ok=True)
        cv2.imwrite(os.path.join(out_dir, f'{image_id}_prediction.jpg'), output_image)
if __name__ == "__main__":
    # Script entry point: parse the command-line flags and run the predictions.
    test(get_args())