forked from longcw/faster_rcnn_pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
demo.py
46 lines (38 loc) · 1.59 KB
/
demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import cv2
import numpy as np
from faster_rcnn import network
from faster_rcnn.faster_rcnn import FasterRCNN
from faster_rcnn.utils.timer import Timer
def test(im_file='demo/004545.jpg',
         model_file='/media/longc/Data/models/VGGnet_fast_rcnn_iter_70000.h5',
         thresh=0.7):
    """Run the Faster R-CNN detector on one image and display the detections.

    Loads the image, builds a ``FasterRCNN`` detector with weights restored
    from ``model_file``, runs a single timed ``detect`` call, draws every
    returned box with its class label and score, writes the annotated image
    to ``demo/out.jpg`` and shows it in a window until a key is pressed.

    Args:
        im_file: Path of the image to run detection on.
        model_file: Path of the saved network weights passed to
            ``network.load_net``.
        thresh: Second argument forwarded to ``detector.detect``
            (presumably the minimum detection score -- confirm against
            ``FasterRCNN.detect``).

    Raises:
        IOError: If ``im_file`` cannot be read as an image.
    """
    import os

    image = cv2.imread(im_file)
    # cv2.imread signals failure by returning None instead of raising; stop
    # here so a bad path is reported before the expensive model load.
    if image is None:
        raise IOError('failed to read image: {}'.format(im_file))

    detector = FasterRCNN()
    network.load_net(model_file, detector)
    detector.cuda()
    detector.eval()
    print('load model successfully!')

    # Time only the detection call, not the model loading above.
    t = Timer()
    t.tic()
    dets, scores, classes = detector.detect(image, thresh)
    runtime = t.toc()
    print('total spend: {}s'.format(runtime))

    # Draw on a copy so the loaded image itself stays unmodified.
    im2show = np.copy(image)
    for i, det in enumerate(dets):
        # OpenCV drawing functions need integer pixel coordinates.
        det = tuple(int(x) for x in det)
        cv2.rectangle(im2show, det[0:2], det[2:4], (255, 205, 51), 2)
        cv2.putText(im2show, '%s: %.3f' % (classes[i], scores[i]),
                    (det[0], det[1] + 15), cv2.FONT_HERSHEY_PLAIN,
                    1.0, (0, 0, 255), thickness=1)

    out_file = os.path.join('demo', 'out.jpg')
    # cv2.imwrite reports failure through its boolean return value; surface
    # it, but keep going so the result is still displayed.
    if not cv2.imwrite(out_file, im2show):
        print('warning: failed to write {}'.format(out_file))
    cv2.imshow('demo', im2show)
    cv2.waitKey(0)
# Script entry point: run the demo only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    test()