forked from JosephKJ/InsightFace_Pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
infer_on_video.py
90 lines (78 loc) · 3.55 KB
/
infer_on_video.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import cv2
from PIL import Image
import argparse
from pathlib import Path
import torch
from config import get_config
from mtcnn import MTCNN
from Learner import face_learner
from utils import load_facebank, draw_box_name, prepare_facebank
# Frame rate / frame size assumed when the input video does not report usable
# values; these are the constants the script previously hard-coded.
_FALLBACK_FPS = 25.0
_FALLBACK_FRAME_SIZE = (1280, 720)


def _parse_args():
    """Parse the command-line options of the video inference script."""
    parser = argparse.ArgumentParser(description='for face verification')
    parser.add_argument("-f", "--file_name", help="video file name",
                        default='video.mp4', type=str)
    parser.add_argument("-s", "--save_name", help="output file name",
                        default='recording', type=str)
    parser.add_argument('-th', '--threshold',
                        help='threshold to decide identical faces',
                        default=1.54, type=float)
    parser.add_argument("-u", "--update",
                        help="whether perform update the facebank",
                        action="store_true")
    parser.add_argument("-tta", "--tta",
                        help="whether test time augmentation",
                        action="store_true")
    parser.add_argument("-c", "--score",
                        help="whether show the confidence score",
                        action="store_true")
    parser.add_argument("-b", "--begin",
                        help="from when to start detection(in seconds)",
                        default=0, type=int)
    parser.add_argument("-d", "--duration",
                        help="perform detection for how long(in seconds)",
                        default=0, type=int)
    return parser.parse_args()


def _annotate_frame(frame, conf, mtcnn, learner, targets, names, show_score):
    """Detect faces in one frame, draw a labelled box per face, return the frame.

    If detection fails or finds nothing, the frame is returned unchanged so
    the caller can still write it to the output video.
    """
    # The frame is handed to PIL without channel reordering; the BGR -> RGB
    # swap (frame[..., ::-1]) is left disabled, as in the original script.
    image = Image.fromarray(frame)
    try:
        bboxes, faces = mtcnn.align_multi(image, conf.face_limit, 16)
    except Exception:
        # Best effort: any detector failure is treated as "no face".
        # `Exception` (not a bare except) keeps Ctrl-C / SystemExit working.
        # NOTE(review): presumably align_multi raises when nothing is
        # detected -- confirm against mtcnn and narrow this if possible.
        bboxes, faces = [], []

    if len(bboxes) == 0:
        print('no face')
        return frame

    # Drop the last column of every detection row (presumably the detector
    # score -- TODO confirm), keeping the integer box coordinates.
    bboxes = bboxes[:, :-1].astype(int)
    bboxes = bboxes + [-1, -1, 1, 1]  # personal choice (pads the box by 1px)
    # NOTE(review): the last argument is always True here even though a
    # --tta flag exists; confirm whether it should follow args.tta.
    results, score = learner.infer(conf, faces, targets, True)
    for idx, bbox in enumerate(bboxes):
        # results[idx] + 1 indexes into names; presumably -1 means
        # "no match" so that names[0] is the unknown label -- verify in utils.
        label = names[results[idx] + 1]
        if show_score:
            label = label + '_{:.2f}'.format(score[idx])
        frame = draw_box_name(bbox, label, frame)
    return frame


def main():
    """Run face recognition over a video file and save an annotated copy.

    Reads ``conf.facebank_path/<file_name>``, starts at ``--begin`` seconds,
    processes ``--duration`` seconds of video (0 means until the end) and
    writes ``conf.facebank_path/<save_name>.avi``.

    Raises:
        SystemExit: if the input video cannot be opened.
    """
    args = _parse_args()

    conf = get_config(False)
    mtcnn = MTCNN()
    print('mtcnn loaded')

    learner = face_learner(conf, True)
    learner.threshold = args.threshold
    weights = 'cpu_final.pth' if conf.device.type == 'cpu' else 'final.pth'
    learner.load_state(conf, weights, True, True)
    learner.model.eval()
    print('learner loaded')

    if args.update:
        targets, names = prepare_facebank(conf, learner.model, mtcnn, tta=args.tta)
        print('facebank updated')
    else:
        targets, names = load_facebank(conf)
        print('facebank loaded')

    source_path = conf.facebank_path / args.file_name
    cap = cv2.VideoCapture(str(source_path))
    video_writer = None
    try:
        if not cap.isOpened():
            # Previously this silently produced an empty output file.
            raise SystemExit('cannot open video: {}'.format(source_path))
        cap.set(cv2.CAP_PROP_POS_MSEC, args.begin * 1000)

        fps = cap.get(cv2.CAP_PROP_FPS)
        if not fps > 0:  # also true for NaN
            fps = _FALLBACK_FPS

        # The writer's frame size must match the frames it is given, so use
        # the capture's real dimensions instead of assuming 1280x720.
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        if width <= 0 or height <= 0:
            width, height = _FALLBACK_FRAME_SIZE

        video_writer = cv2.VideoWriter(
            str(conf.facebank_path / '{}.avi'.format(args.save_name)),
            cv2.VideoWriter_fourcc(*'XVID'), fps, (width, height))

        # Duration is tracked in frames, using the real frame rate rather
        # than a fixed 25 fps.
        frames_per_second = max(1, int(round(fps)))
        frame_limit = frames_per_second * args.duration
        processed = 0

        while cap.isOpened():
            is_success, frame = cap.read()
            if not is_success:
                break

            frame = _annotate_frame(frame, conf, mtcnn, learner,
                                    targets, names, args.score)
            # Every decoded frame is written and counted, including frames
            # without faces, so the output keeps the input's timing and
            # --duration measures video time rather than "frames with faces".
            video_writer.write(frame)

            if args.duration != 0:
                processed += 1
                if processed % frames_per_second == 0:
                    print('{} second'.format(processed // frames_per_second))
                if processed >= frame_limit:
                    break
    finally:
        # Release the capture and finalize the output even if inference fails.
        cap.release()
        if video_writer is not None:
            video_writer.release()


if __name__ == '__main__':
    main()