-
Notifications
You must be signed in to change notification settings - Fork 0
/
VGGFace2_model.py
114 lines (102 loc) · 4.18 KB
/
VGGFace2_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
# Face Verification with the VGGFace2 model
import os
import cv2
from PIL import Image
from numpy import asarray
from matplotlib import pyplot
from mtcnn.mtcnn import MTCNN
from keras_vggface.vggface import VGGFace
from scipy.spatial.distance import cosine
from keras_vggface.utils import preprocess_input
"""The function load images from received folder."""
def load_images_from_folder(folder):
    """Return the paths of all readable images directly inside ``folder``.

    Every directory entry is opened with ``cv2.imread``; entries it cannot
    decode (``imread`` returns None) are skipped, so non-image files are
    ignored. Sub-folders are not searched recursively.

    Args:
        folder: directory to scan.

    Returns:
        list of ``os.path.join(folder, name)`` strings, sorted by file name.
    """
    images = []
    # os.listdir() returns entries in arbitrary, filesystem-dependent order;
    # sort them so the face grid and the verification output are reproducible.
    for filename in sorted(os.listdir(folder)):
        path_image = os.path.join(folder, filename)
        if cv2.imread(path_image) is not None:
            images.append(path_image)
    return images
# Shared MTCNN instance used by extract_face() when the caller passes none.
# It is created lazily on first use so that the detector (and its default
# weights) is built once instead of once per image.
_default_detector = None


def _get_default_detector():
    """Return the module-wide MTCNN detector, creating it on first call."""
    global _default_detector
    if _default_detector is None:
        _default_detector = MTCNN()
    return _default_detector


def extract_face(filename, required_size=(224, 224), detector=None):
    """Detect the first face in an image file and return it as a pixel array.

    Args:
        filename: path of the image to read.
        required_size: (width, height) passed to ``PIL.Image.resize``; the
            default matches the 224x224 input of the VGGFace model used in
            this file.
        detector: optional MTCNN instance to reuse. When omitted, a single
            MTCNN with default weights is created on first use and shared by
            later calls.

    Returns:
        numpy array (uint8, RGB) of the cropped face resized to
        ``required_size``.

    Raises:
        ValueError: if the detector finds no face in the image.
    """
    # Load through Pillow and force 3-channel RGB: pyplot.imread returns
    # floats in [0, 1] for PNG files (which Image.fromarray cannot turn back
    # into an image) and 2-D or 4-channel arrays for grayscale/RGBA files.
    with Image.open(filename) as img:
        pixels = asarray(img.convert('RGB'))
    if detector is None:
        detector = _get_default_detector()
    results = detector.detect_faces(pixels)
    if not results:
        raise ValueError('No face detected in image: {}'.format(filename))
    # Bounding box of the first detected face as (x, y, width, height).
    x1, y1, width, height = results[0]['box']
    x2, y2 = x1 + width, y1 + height
    # Guard against a box starting outside the image (negative x/y): a
    # negative slice start would wrap around to the far end of the array.
    x1, y1 = max(0, x1), max(0, y1)
    face = pixels[y1:y2, x1:x2]
    image = Image.fromarray(face).resize(required_size)
    return asarray(image)
def is_match(image, known_embedding, candidate_embedding, thresh):
    """Report whether a candidate face matches a known face.

    The faces are compared by the cosine distance between their embeddings;
    a distance at or below ``thresh`` counts as a match. The verdict is
    printed and also returned so callers can use it programmatically.

    Args:
        image: label shown in the printed message (e.g. the image file name).
        known_embedding: embedding vector of the reference face.
        candidate_embedding: embedding vector of the face to check.
        thresh: maximum cosine distance accepted as a match.

    Returns:
        True if the faces match, False otherwise.
    """
    score = cosine(known_embedding, candidate_embedding)
    matched = score <= thresh
    print('\nThe face in the image: {} is:'.format(image))
    if matched:
        print('- a Match (score {0:.2f} <= threshold {1:.2f})'.format(score, thresh))
    else:
        print('- NOT a Match (score {0:.2f} > threshold {1:.2f})'.format(score, thresh))
    # score may be a numpy float, so normalise the verdict to a plain bool.
    return bool(matched)
def _plot_faces(faces, filenames, folder, foldlen, size):
    """Show the extracted faces on a size x size grid and save the figure."""
    figsize = 2 * size
    # squeeze=False keeps ``ax`` a 2-D array even when size == 1, so the
    # flat iteration below always works.
    fig, ax = pyplot.subplots(nrows=size, ncols=size, figsize=(figsize, figsize), squeeze=False)
    for axi in ax.flat:
        axi.set_axis_off()
    fig.tight_layout(pad=3.5)
    # zip() stops at the shorter sequence: grid cells beyond the number of
    # faces stay empty, and faces beyond size*size are not drawn (their
    # embeddings are still computed by the caller).
    for axi, face, filename in zip(ax.flat, faces, filenames):
        axi.imshow(face)
        # Title: path without the folder prefix and without the extension,
        # whatever its length ('.jpeg' as well as '.jpg').
        axi.set_title(os.path.splitext(filename[foldlen:])[0])
    # NOTE(review): this writes '<folder>faces.png' (e.g. 'datafaces.png')
    # next to the folder, not inside it; presumably deliberate, since saving
    # inside the folder would make load_images_from_folder pick the grid up
    # as an input image on the next run.
    pyplot.savefig('{}faces.png'.format(folder), dpi=900, transparent=True)
    pyplot.show()


def get_embeddings(filenames, folder, foldlen, size):
    """Extract the face from each photo, plot them, and compute embeddings.

    Args:
        filenames: image paths, each starting with the ``folder`` prefix.
        folder: folder name; also used as the prefix of the saved plot file.
        foldlen: number of leading characters (folder plus separator) stripped
            from each path to build the plot titles.
        size: the faces are plotted on a ``size`` x ``size`` grid.

    Returns:
        the output of the VGGFace model's ``predict``: one embedding per file,
        in the order of ``filenames``.

    Raises:
        ValueError: if ``filenames`` is empty.
    """
    if not filenames:
        raise ValueError('No image files were given to get_embeddings().')
    faces = [extract_face(f) for f in filenames]
    _plot_faces(faces, filenames, folder, foldlen, size)
    # Stack the faces into one float batch and apply the version=2 (VGGFace2)
    # preprocessing for the ResNet50 model below.
    samples = preprocess_input(asarray(faces, 'float32'), version=2)
    # include_top=False / pooling='avg': keep the convolutional features and
    # average-pool them into one vector per face, used as the embedding.
    model = VGGFace(model='resnet50', pooling='avg', include_top=False, input_shape=(224, 224, 3))
    return model.predict(samples)
def search_name(name, filenames):
    """Return the position of the first entry of ``filenames`` containing ``name``.

    The test is a plain substring check on each entry as given (including any
    folder prefix). Returns None when no entry contains ``name``.
    """
    return next(
        (position for position, candidate in enumerate(filenames) if name in candidate),
        None,
    )
def main():
    """Verify one known person against the other faces in the ``data`` folder.

    The first image whose path contains ``verify_name`` is the reference
    identity. Every other image is compared with it and reported as a match
    or not by ``is_match``.
    """
    threshold = 0.5
    # Side length of the face grid plotted by get_embeddings (4 -> 4x4 grid).
    number_images = 4
    folder = 'data'
    # Length of the folder + separator prefix that os.path.join puts in
    # front of every file name returned by load_images_from_folder.
    folderlen = len(folder) + 1
    verify_name = 'sharon_stone'
    filenames = load_images_from_folder(folder)
    # Locate the reference image before running face detection and the
    # embedding model, so a missing reference is reported straight away.
    index = search_name(verify_name, filenames)
    if index is None:
        print("""Error! The person for whom the verification was performed is not included in the input images.""")
        return
    embeddings = get_embeddings(filenames, folder, folderlen, number_images)
    reference_embedding = embeddings[index]
    print("""\nFor Positive Tests print Match, for Negative Tests print NOT a Match""")
    # Compare every other image with the reference. The reference itself is
    # skipped: comparing it with its own embedding is trivially a match.
    for i, filename in enumerate(filenames):
        if i == index:
            continue
        is_match(filename[folderlen:], reference_embedding, embeddings[i], threshold)


if __name__ == '__main__':
    main()