-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathinference_recognition.py
More file actions
49 lines (36 loc) · 1.64 KB
/
inference_recognition.py
File metadata and controls
49 lines (36 loc) · 1.64 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
# Warning: this code has not been tested with the original ArcFace model.
import argparse
import math
import cv2
import numpy as np
import torch
from facexlib.utils import build_model
from facexlib.recognition import calculate_sim
def cosine_to_degrees(sim):
    """Convert a cosine similarity into an angle in degrees.

    The value is clipped to [-1, 1] first. Floating-point round-off can push
    the cosine of two (near-)identical embeddings slightly above 1, and
    ``np.arccos`` returns NaN outside its domain. A NaN angle would silently
    fail every ``<`` comparison.
    """
    return np.arccos(np.clip(sim, -1.0, 1.0)) / math.pi * 180


def _read_image(path):
    """Load an image with OpenCV, raising instead of returning ``None``.

    ``cv2.imread`` signals failure (missing file, unsupported format) by
    returning ``None`` rather than raising, so check explicitly.
    """
    img = cv2.imread(path)
    if img is None:
        raise FileNotFoundError(f'Could not read image (missing or unsupported file): {path}')
    return img


def _detect_first_face(det_net, img, path, conf_threshold=0.97):
    """Return the first detection row for ``img``, or raise if there is none.

    The row is split by the caller into ``[:5]`` and ``[5:]``. Presumably these
    are box+score and landmarks; verify against facexlib's ``detect_faces``.
    """
    faces = det_net.detect_faces(img, conf_threshold)
    # len() rather than truthiness: the result may be a NumPy array, for which
    # `not faces` is ambiguous when more than one face is found.
    if len(faces) == 0:
        raise RuntimeError(f'No face detected in {path} (confidence threshold {conf_threshold}).')
    return faces[0]


def main(args):
    """Compare the first detected face in two images and print their angular distance.

    Args:
        args: parsed CLI namespace providing ``img_path1``, ``img_path2``,
            ``det_model_name`` and ``recog_model_name``.

    Raises:
        FileNotFoundError: if either image cannot be read.
        RuntimeError: if no face is detected in either image.
    """
    # Read the inputs before building the (comparatively slow) models so a bad
    # path fails fast.
    img1 = _read_image(args.img_path1)
    img2 = _read_image(args.img_path2)

    det_net = build_model(args.det_model_name)
    recog_net = build_model(args.recog_model_name)

    with torch.no_grad():
        face1 = _detect_first_face(det_net, img1, args.img_path1)
        face2 = _detect_first_face(det_net, img2, args.img_path2)
        if args.recog_model_name == 'facenet512':
            # facenet512 receives the first five values and the remainder as
            # separate arguments; the other models receive only the remainder.
            output = recog_net.get([img1, img2], [face1[:5], face2[:5]], [face1[5:], face2[5:]])
        else:
            output = recog_net.get([img1, img2], [face1[5:], face2[5:]])

    # `output` is converted like a torch tensor: one embedding row per input image.
    embeddings = output.data.cpu().numpy()
    sim = calculate_sim(embeddings[0], embeddings[1])
    angle = cosine_to_degrees(sim)

    # Heuristic cut-off in degrees, kept from the original script.
    if angle < 10:
        print(f'These two images are almost identical (distance: {angle:.2f} degrees).')
    else:
        print(f'These two images are not identical (distance: {angle:.2f} degrees).')
if __name__ == '__main__':
    # Every option is a string flag; a help of None is argparse's own default,
    # so options listed without help text behave exactly as if it were omitted.
    cli = argparse.ArgumentParser()
    for flag, default_value, help_text in (
        ('--img_path1', 'assets/test.jpg', None),
        ('--img_path2', 'assets/test2.jpg', None),
        ('--det_model_name', 'retinaface_resnet50', 'retinaface_resnet50 | retinaface_mobile0.25'),
        ('--recog_model_name', 'antelopev2', 'arcface | antelopev2 | buffalo_l | facenet512 | mtlface'),
    ):
        cli.add_argument(flag, type=str, default=default_value, help=help_text)
    main(cli.parse_args())