Graduate

make client

1 +##################################################
2 +# 1. Detect faces from the webcam.               #
3 +# 2. Send images whose face probability is >= 95% to the image server. #
4 +# 3. Send the preprocessed face data to the verification server.       #
5 +##################################################
6 +import torch
7 +import numpy as np
8 +import cv2
9 +import asyncio
10 +import websockets
11 +import json
12 +import os
13 +import timeit
14 +import base64
15 +
16 +from PIL import Image
17 +from io import BytesIO
18 +import requests
19 +
20 +from models.mtcnn import MTCNN
21 +
22 +device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
23 +print('Running on device: {}'.format(device))
24 +
25 +mtcnn = MTCNN(keep_all=True, device=device)
26 +
27 +uri = 'ws://localhost:8765'
28 +
29 +async def send_face(face_list):
30 + global uri
31 + async with websockets.connect(uri) as websocket:
32 + for face in face_list:
33 + #type: np.float32
34 + print(face.shape)
35 + data = json.dumps({"action": "verify", "MTCNN": face.tolist()})
36 + await websocket.send(data)
37 + print('send: verify', len(face_list), 'face(s)')
38 + code = await websocket.recv()
39 + print('code:', code)
40 +
41 +async def send_image(image_list):
42 + global uri
43 + async with websockets.connect(uri) as websocket:
44 + for image in image_list:
45 + data = json.dumps({"action": "save_image", "image": image.tolist(), "shape": image.shape})
46 + await websocket.send(data)
47 + print('send', len(image_list), 'image(s)')
48 + code = await websocket.recv()
49 + print('code:', code)
50 +
51 +def detect_face(frame):
52 + # Detect face bounding boxes with the global MTCNN pipeline
53 + global mtcnn
54 + results = mtcnn.detect(frame)
55 + image_list = []
56 + if results[1][0] is None:
57 + return []
58 + for box, prob in zip(results[0], results[1]):
59 + if prob < 0.95:
60 + continue
61 + print('face detected. prob:', prob)
62 + x1, y1, x2, y2 = box
63 + image = frame[max(int(y1) - 10, 0):int(y2) + 10, max(int(x1) - 10, 0):int(x2) + 10]
64 + image_list.append(image)
65 + return image_list
66 +
67 +def make_face_list(frame):
68 + global mtcnn
69 + results, probs = mtcnn(frame, return_prob=True)
70 + face_list = []
71 + if probs[0] is None:
72 + return []
73 + for result, prob in zip(results, probs):
74 + if prob < 0.95:
75 + continue
76 + #np.float32
77 + face_list.append(result.numpy())
78 + return face_list
79 +
80 +cap = cv2.VideoCapture(0)
81 +cap.set(cv2.CAP_PROP_FRAME_WIDTH, 720)
82 +cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 480)
83 +while True:
84 + try:
85 + #start = timeit.default_timer()
86 + ret, frame = cap.read()
87 + frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
88 + face_list = make_face_list(frame)
89 + image_list = detect_face(frame)
90 + ## send to the embedding server ##
91 + if face_list:
92 + asyncio.get_event_loop().run_until_complete(send_face(face_list))
93 + ###################
94 + ## send to the image server ##
95 + if image_list:
96 + asyncio.get_event_loop().run_until_complete(send_image(image_list))
97 + ###################
98 + #end = timeit.default_timer()
99 + #print('delta time: ', end - start)
100 + except Exception as ex:
101 + print(ex)
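The verification and image servers that this client talks to are not part of this commit. As a rough sketch only, assuming the same legacy asyncio-style websockets API the client uses and made-up handling logic, a compatible endpoint could look like this:

import asyncio
import json

import numpy as np
import websockets


async def handler(websocket, path):
    # One connection per send_face()/send_image() call; each message is JSON.
    async for message in websocket:
        data = json.loads(message)
        if data['action'] == 'verify':
            # Preprocessed MTCNN face, shape (3, 160, 160), float32
            face = np.array(data['MTCNN'], dtype=np.float32)
            print('verify request, face shape:', face.shape)
        elif data['action'] == 'save_image':
            # Raw RGB crop plus its original shape
            image = np.array(data['image'], dtype=np.uint8).reshape(data['shape'])
            print('save_image request, image shape:', image.shape)
        # Status code that the client prints after its recv()
        await websocket.send('200')


start_server = websockets.serve(handler, 'localhost', 8765)
asyncio.get_event_loop().run_until_complete(start_server)
asyncio.get_event_loop().run_forever()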
1 +import torch
2 +import numpy as np
3 +import cv2
4 +import matplotlib.pyplot as plt
5 +import os
6 +
7 +from PIL import Image, ImageDraw
8 +from IPython import display
9 +
10 +from models import mtcnn
11 +from models import inception_resnet_v1
12 +
13 +device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
14 +print('Running on device: {}'.format(device))
15 +
16 +def extract_face(filename, required_size=(224, 224)):
17 + # If required, create a face detection pipeline using MTCNN:
18 + mtcnn_model = mtcnn.MTCNN(keep_all=True, device=device)
19 + pixels = plt.imread(os.path.join(os.path.abspath(''), filename))
20 + results = mtcnn_model.detect(pixels)
21 + face_array = []
22 + for box, prob in zip(results[0], results[1]):
23 + #boxes, _ = result
24 + print('face detected. prob:', prob)
25 + x1, y1, x2, y2 = box
26 + face = pixels[int(y1):int(y2), int(x1):int(x2)]
27 + image = Image.fromarray(face)
28 + image = image.resize(required_size)
29 + face_array.append(np.asarray(image))
30 + return face_array
31 +
32 +face_array = extract_face('image/test1.jpg')
33 +for face in face_array:
34 + plt.figure()
35 + plt.imshow(face)
36 + plt.show()
37 +
38 +face_array = extract_face('image/test2.jpg')
39 +for face in face_array:
40 + plt.figure()
41 + plt.imshow(face)
42 + plt.show()
43 +
1 +import torch
2 +from torch import nn
3 +import numpy as np
4 +import os
5 +
6 +from .utils.detect_face import detect_face, extract_face
7 +
8 +
9 +class PNet(nn.Module):
10 + """MTCNN PNet.
11 +
12 + Keyword Arguments:
13 + pretrained {bool} -- Whether or not to load saved pretrained weights (default: {True})
14 + """
15 +
16 + def __init__(self, pretrained=True):
17 + super().__init__()
18 +
19 + self.conv1 = nn.Conv2d(3, 10, kernel_size=3)
20 + self.prelu1 = nn.PReLU(10)
21 + self.pool1 = nn.MaxPool2d(2, 2, ceil_mode=True)
22 + self.conv2 = nn.Conv2d(10, 16, kernel_size=3)
23 + self.prelu2 = nn.PReLU(16)
24 + self.conv3 = nn.Conv2d(16, 32, kernel_size=3)
25 + self.prelu3 = nn.PReLU(32)
26 + self.conv4_1 = nn.Conv2d(32, 2, kernel_size=1)
27 + self.softmax4_1 = nn.Softmax(dim=1)
28 + self.conv4_2 = nn.Conv2d(32, 4, kernel_size=1)
29 +
30 + self.training = False
31 +
32 + if pretrained:
33 + state_dict_path = os.path.join(os.path.dirname(__file__), 'data/pnet.pt')
34 + state_dict = torch.load(state_dict_path)
35 + self.load_state_dict(state_dict)
36 +
37 + def forward(self, x):
38 + x = self.conv1(x)
39 + x = self.prelu1(x)
40 + x = self.pool1(x)
41 + x = self.conv2(x)
42 + x = self.prelu2(x)
43 + x = self.conv3(x)
44 + x = self.prelu3(x)
45 + a = self.conv4_1(x)
46 + a = self.softmax4_1(a)
47 + b = self.conv4_2(x)
48 + return b, a
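As a quick illustration (not part of the commit) of what PNet produces: a single 12 x 12 patch, PNet's base receptive field, yields one box regression and one face/non-face score per spatial location. pretrained=False is used here only to avoid requiring data/pnet.pt.

import torch

pnet = PNet(pretrained=False).eval()
patch = torch.randn(1, 3, 12, 12)   # one RGB patch at PNet's base scale
with torch.no_grad():
    reg, prob = pnet(patch)
print(reg.shape)    # torch.Size([1, 4, 1, 1])  -- bounding-box regression offsets
print(prob.shape)   # torch.Size([1, 2, 1, 1])  -- softmaxed non-face/face scores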
49 +
50 +
51 +class RNet(nn.Module):
52 + """MTCNN RNet.
53 +
54 + Keyword Arguments:
55 + pretrained {bool} -- Whether or not to load saved pretrained weights (default: {True})
56 + """
57 +
58 + def __init__(self, pretrained=True):
59 + super().__init__()
60 +
61 + self.conv1 = nn.Conv2d(3, 28, kernel_size=3)
62 + self.prelu1 = nn.PReLU(28)
63 + self.pool1 = nn.MaxPool2d(3, 2, ceil_mode=True)
64 + self.conv2 = nn.Conv2d(28, 48, kernel_size=3)
65 + self.prelu2 = nn.PReLU(48)
66 + self.pool2 = nn.MaxPool2d(3, 2, ceil_mode=True)
67 + self.conv3 = nn.Conv2d(48, 64, kernel_size=2)
68 + self.prelu3 = nn.PReLU(64)
69 + self.dense4 = nn.Linear(576, 128)
70 + self.prelu4 = nn.PReLU(128)
71 + self.dense5_1 = nn.Linear(128, 2)
72 + self.softmax5_1 = nn.Softmax(dim=1)
73 + self.dense5_2 = nn.Linear(128, 4)
74 +
75 + self.training = False
76 +
77 + if pretrained:
78 + state_dict_path = os.path.join(os.path.dirname(__file__), 'data/rnet.pt')
79 + state_dict = torch.load(state_dict_path)
80 + self.load_state_dict(state_dict)
81 +
82 + def forward(self, x):
83 + x = self.conv1(x)
84 + x = self.prelu1(x)
85 + x = self.pool1(x)
86 + x = self.conv2(x)
87 + x = self.prelu2(x)
88 + x = self.pool2(x)
89 + x = self.conv3(x)
90 + x = self.prelu3(x)
91 + x = x.permute(0, 3, 2, 1).contiguous()
92 + x = self.dense4(x.view(x.shape[0], -1))
93 + x = self.prelu4(x)
94 + a = self.dense5_1(x)
95 + a = self.softmax5_1(a)
96 + b = self.dense5_2(x)
97 + return b, a
98 +
99 +
100 +class ONet(nn.Module):
101 + """MTCNN ONet.
102 +
103 + Keyword Arguments:
104 + pretrained {bool} -- Whether or not to load saved pretrained weights (default: {True})
105 + """
106 +
107 + def __init__(self, pretrained=True):
108 + super().__init__()
109 +
110 + self.conv1 = nn.Conv2d(3, 32, kernel_size=3)
111 + self.prelu1 = nn.PReLU(32)
112 + self.pool1 = nn.MaxPool2d(3, 2, ceil_mode=True)
113 + self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
114 + self.prelu2 = nn.PReLU(64)
115 + self.pool2 = nn.MaxPool2d(3, 2, ceil_mode=True)
116 + self.conv3 = nn.Conv2d(64, 64, kernel_size=3)
117 + self.prelu3 = nn.PReLU(64)
118 + self.pool3 = nn.MaxPool2d(2, 2, ceil_mode=True)
119 + self.conv4 = nn.Conv2d(64, 128, kernel_size=2)
120 + self.prelu4 = nn.PReLU(128)
121 + self.dense5 = nn.Linear(1152, 256)
122 + self.prelu5 = nn.PReLU(256)
123 + self.dense6_1 = nn.Linear(256, 2)
124 + self.softmax6_1 = nn.Softmax(dim=1)
125 + self.dense6_2 = nn.Linear(256, 4)
126 + self.dense6_3 = nn.Linear(256, 10)
127 +
128 + self.training = False
129 +
130 + if pretrained:
131 + state_dict_path = os.path.join(os.path.dirname(__file__), 'data/onet.pt')
132 + state_dict = torch.load(state_dict_path)
133 + self.load_state_dict(state_dict)
134 +
135 + def forward(self, x):
136 + x = self.conv1(x)
137 + x = self.prelu1(x)
138 + x = self.pool1(x)
139 + x = self.conv2(x)
140 + x = self.prelu2(x)
141 + x = self.pool2(x)
142 + x = self.conv3(x)
143 + x = self.prelu3(x)
144 + x = self.pool3(x)
145 + x = self.conv4(x)
146 + x = self.prelu4(x)
147 + x = x.permute(0, 3, 2, 1).contiguous()
148 + x = self.dense5(x.view(x.shape[0], -1))
149 + x = self.prelu5(x)
150 + a = self.dense6_1(x)
151 + a = self.softmax6_1(a)
152 + b = self.dense6_2(x)
153 + c = self.dense6_3(x)
154 + return b, c, a
155 +
156 +
157 +class MTCNN(nn.Module):
158 + """MTCNN face detection module.
159 +
160 + This class loads pretrained P-, R-, and O-nets and returns images cropped to include the face
161 + only, given raw input images of one of the following types:
162 + - PIL image or list of PIL images
163 + - numpy.ndarray (uint8) representing either a single image (3D) or a batch of images (4D).
164 + Cropped faces can optionally be saved to file
165 + also.
166 +
167 + Keyword Arguments:
168 + image_size {int} -- Output image size in pixels. The image will be square. (default: {160})
169 + margin {int} -- Margin to add to bounding box, in terms of pixels in the final image.
170 + Note that the application of the margin differs slightly from the davidsandberg/facenet
171 + repo, which applies the margin to the original image before resizing, making the margin
172 + dependent on the original image size (this is a bug in davidsandberg/facenet).
173 + (default: {0})
174 + min_face_size {int} -- Minimum face size to search for. (default: {20})
175 + thresholds {list} -- MTCNN face detection thresholds (default: {[0.6, 0.7, 0.7]})
176 + factor {float} -- Factor used to create a scaling pyramid of face sizes. (default: {0.709})
177 + post_process {bool} -- Whether or not to post process images tensors before returning.
178 + (default: {True})
179 + select_largest {bool} -- If True, if multiple faces are detected, the largest is returned.
180 + If False, the face with the highest detection probability is returned.
181 + (default: {True})
182 + keep_all {bool} -- If True, all detected faces are returned, in the order dictated by the
183 + select_largest parameter. If a save_path is specified, the first face is saved to that
184 + path and the remaining faces are saved to <save_path>1, <save_path>2 etc. (default: {False})
185 + device {torch.device} -- The device on which to run neural net passes. Image tensors and
186 + models are copied to this device before running forward passes. (default: {None})
187 + """
188 +
189 + def __init__(
190 + self, image_size=160, margin=0, min_face_size=20,
191 + thresholds=[0.6, 0.7, 0.7], factor=0.709, post_process=True,
192 + select_largest=True, keep_all=False, device=None
193 + ):
194 + super().__init__()
195 +
196 + self.image_size = image_size
197 + self.margin = margin
198 + self.min_face_size = min_face_size
199 + self.thresholds = thresholds
200 + self.factor = factor
201 + self.post_process = post_process
202 + self.select_largest = select_largest
203 + self.keep_all = keep_all
204 +
205 + self.pnet = PNet()
206 + self.rnet = RNet()
207 + self.onet = ONet()
208 +
209 + self.device = torch.device('cpu')
210 + if device is not None:
211 + self.device = device
212 + self.to(device)
213 +
214 + def forward(self, img, save_path=None, return_prob=False):
215 + """Run MTCNN face detection on a PIL image or numpy array. This method performs both
216 + detection and extraction of faces, returning tensors representing detected faces rather
217 + than the bounding boxes. To access bounding boxes, see the MTCNN.detect() method below.
218 +
219 + Arguments:
220 + img {PIL.Image, np.ndarray, or list} -- A PIL image, np.ndarray, or list.
221 +
222 + Keyword Arguments:
223 + save_path {str} -- An optional save path for the cropped image. Note that when
224 + self.post_process=True, although the returned tensor is post processed, the saved
225 + face image is not, so it is a true representation of the face in the input image.
226 + If `img` is a list of images, `save_path` should be a list of equal length.
227 + (default: {None})
228 + return_prob {bool} -- Whether or not to return the detection probability.
229 + (default: {False})
230 +
231 + Returns:
232 + Union[torch.Tensor, tuple(torch.tensor, float)] -- If detected, cropped image of a face
233 + with dimensions 3 x image_size x image_size. Optionally, the probability that a
234 + face was detected. If self.keep_all is True, n detected faces are returned in an
235 + n x 3 x image_size x image_size tensor with an optional list of detection
236 + probabilities. If `img` is a list of images, the item(s) returned have an extra
237 + dimension (batch) as the first dimension.
238 +
239 + Example:
240 + >>> from facenet_pytorch import MTCNN
241 + >>> mtcnn = MTCNN()
242 + >>> face_tensor, prob = mtcnn(img, save_path='face.png', return_prob=True)
243 + """
244 +
245 + # Detect faces
246 + with torch.no_grad():
247 + batch_boxes, batch_probs = self.detect(img)
248 +
249 + # Determine if a batch or single image was passed
250 + batch_mode = True
251 + if not isinstance(img, (list, tuple)) and not (isinstance(img, np.ndarray) and len(img.shape) == 4):
252 + img = [img]
253 + batch_boxes = [batch_boxes]
254 + batch_probs = [batch_probs]
255 + batch_mode = False
256 +
257 + # Parse save path(s)
258 + if save_path is not None:
259 + if isinstance(save_path, str):
260 + save_path = [save_path]
261 + else:
262 + save_path = [None for _ in range(len(img))]
263 +
264 + # Process all bounding boxes and probabilities
265 + faces, probs = [], []
266 + for im, box_im, prob_im, path_im in zip(img, batch_boxes, batch_probs, save_path):
267 + if box_im is None:
268 + faces.append(None)
269 + probs.append([None] if self.keep_all else None)
270 + continue
271 +
272 + if not self.keep_all:
273 + box_im = box_im[[0]]
274 +
275 + faces_im = []
276 + for i, box in enumerate(box_im):
277 + face_path = path_im
278 + if path_im is not None and i > 0:
279 + save_name, ext = os.path.splitext(path_im)
280 + face_path = save_name + '_' + str(i + 1) + ext
281 +
282 + face = extract_face(im, box, self.image_size, self.margin, face_path)
283 + if self.post_process:
284 + face = fixed_image_standardization(face)
285 + faces_im.append(face)
286 +
287 + if self.keep_all:
288 + faces_im = torch.stack(faces_im)
289 + else:
290 + faces_im = faces_im[0]
291 + prob_im = prob_im[0]
292 +
293 + faces.append(faces_im)
294 + probs.append(prob_im)
295 +
296 + if not batch_mode:
297 + faces = faces[0]
298 + probs = probs[0]
299 +
300 + if return_prob:
301 + return faces, probs
302 + else:
303 + return faces
304 +
305 + def detect(self, img, landmarks=False):
306 + """Detect all faces in PIL image and return bounding boxes and optional facial landmarks.
307 +
308 + This method is used by the forward method and is also useful for face detection tasks
309 + that require lower-level handling of bounding boxes and facial landmarks (e.g., face
310 + tracking). The functionality of the forward function can be emulated by using this method
311 + followed by the extract_face() function.
312 +
313 + Arguments:
314 + img {PIL.Image, np.ndarray, or list} -- A PIL image or a list of PIL images.
315 +
316 + Keyword Arguments:
317 + landmarks {bool} -- Whether to return facial landmarks in addition to bounding boxes.
318 + (default: {False})
319 +
320 + Returns:
321 + tuple(numpy.ndarray, list) -- For N detected faces, a tuple containing an
322 + Nx4 array of bounding boxes and a length N list of detection probabilities.
323 + Returned boxes will be sorted in descending order by detection probability if
324 + self.select_largest=False, otherwise the largest face will be returned first.
325 + If `img` is a list of images, the items returned have an extra dimension
326 + (batch) as the first dimension. Optionally, a third item, the facial landmarks,
327 + are returned if `landmarks=True`.
328 +
329 + Example:
330 + >>> from PIL import Image, ImageDraw
331 + >>> from facenet_pytorch import MTCNN, extract_face
332 + >>> mtcnn = MTCNN(keep_all=True)
333 + >>> boxes, probs, points = mtcnn.detect(img, landmarks=True)
334 + >>> # Draw boxes and save faces
335 + >>> img_draw = img.copy()
336 + >>> draw = ImageDraw.Draw(img_draw)
337 + >>> for i, (box, point) in enumerate(zip(boxes, points)):
338 + ... draw.rectangle(box.tolist(), width=5)
339 + ... for p in point:
340 + ... draw.rectangle((p - 10).tolist() + (p + 10).tolist(), width=10)
341 + ... extract_face(img, box, save_path='detected_face_{}.png'.format(i))
342 + >>> img_draw.save('annotated_faces.png')
343 + """
344 +
345 + with torch.no_grad():
346 + batch_boxes, batch_points = detect_face(
347 + img, self.min_face_size,
348 + self.pnet, self.rnet, self.onet,
349 + self.thresholds, self.factor,
350 + self.device
351 + )
352 +
353 + boxes, probs, points = [], [], []
354 + for box, point in zip(batch_boxes, batch_points):
355 + box = np.array(box)
356 + point = np.array(point)
357 + if len(box) == 0:
358 + boxes.append(None)
359 + probs.append([None])
360 + points.append(None)
361 + elif self.select_largest:
362 + box_order = np.argsort((box[:, 2] - box[:, 0]) * (box[:, 3] - box[:, 1]))[::-1]
363 + box = box[box_order]
364 + point = point[box_order]
365 + boxes.append(box[:, :4])
366 + probs.append(box[:, 4])
367 + points.append(point)
368 + else:
369 + boxes.append(box[:, :4])
370 + probs.append(box[:, 4])
371 + points.append(point)
372 + boxes = np.array(boxes)
373 + probs = np.array(probs)
374 + points = np.array(points)
375 +
376 + if not isinstance(img, (list, tuple)) and not (isinstance(img, np.ndarray) and len(img.shape) == 4):
377 + boxes = boxes[0]
378 + probs = probs[0]
379 + points = points[0]
380 +
381 + if landmarks:
382 + return boxes, probs, points
383 +
384 + return boxes, probs
385 +
386 +
387 +def fixed_image_standardization(image_tensor):
388 + processed_tensor = (image_tensor - 127.5) / 128.0
389 + return processed_tensor
390 +
391 +def prewhiten(x):
392 + mean = x.mean()
393 + std = x.std()
394 + std_adj = std.clamp(min=1.0/(float(x.numel())**0.5))
395 + y = (x - mean) / std_adj
396 + return y
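A quick sanity check (illustrative only) of the post-processing used above: fixed_image_standardization maps uint8 pixel values in [0, 255] to roughly [-1, 1].

x = torch.tensor([0.0, 127.5, 255.0])
print(fixed_image_standardization(x))   # tensor([-0.9961,  0.0000,  0.9961])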
1 +import torch
2 +from torch.nn.functional import interpolate
3 +from torchvision.transforms import functional as F
4 +from torchvision.ops.boxes import batched_nms
5 +import cv2
6 +from PIL import Image
7 +import numpy as np
8 +import os
9 +
10 +
11 +def detect_face(imgs, minsize, pnet, rnet, onet, threshold, factor, device):
12 + if isinstance(imgs, (np.ndarray, torch.Tensor)):
13 + imgs = torch.as_tensor(imgs, device=device)
14 + if len(imgs.shape) == 3:
15 + imgs = imgs.unsqueeze(0)
16 + else:
17 + if not isinstance(imgs, (list, tuple)):
18 + imgs = [imgs]
19 + if any(img.size != imgs[0].size for img in imgs):
20 + raise Exception("MTCNN batch processing only compatible with equal-dimension images.")
21 + imgs = np.stack([np.uint8(img) for img in imgs])
22 +
23 + imgs = torch.as_tensor(imgs, device=device)
24 +
25 + model_dtype = next(pnet.parameters()).dtype
26 + imgs = imgs.permute(0, 3, 1, 2).type(model_dtype)
27 +
28 + batch_size = len(imgs)
29 + h, w = imgs.shape[2:4]
30 + m = 12.0 / minsize
31 + minl = min(h, w)
32 + minl = minl * m
33 +
34 + # Create scale pyramid
35 + scale_i = m
36 + scales = []
37 + while minl >= 12:
38 + scales.append(scale_i)
39 + scale_i = scale_i * factor
40 + minl = minl * factor
41 +
42 + # First stage
43 + boxes = []
44 + image_inds = []
45 + all_inds = []
46 + all_i = 0
47 + for scale in scales:
48 + im_data = imresample(imgs, (int(h * scale + 1), int(w * scale + 1)))
49 + im_data = (im_data - 127.5) * 0.0078125
50 + reg, probs = pnet(im_data)
51 +
52 + boxes_scale, image_inds_scale = generateBoundingBox(reg, probs[:, 1], scale, threshold[0])
53 + boxes.append(boxes_scale)
54 + image_inds.append(image_inds_scale)
55 + all_inds.append(all_i + image_inds_scale)
56 + all_i += batch_size
57 +
58 + boxes = torch.cat(boxes, dim=0)
59 + image_inds = torch.cat(image_inds, dim=0).cpu()
60 + all_inds = torch.cat(all_inds, dim=0)
61 +
62 + # NMS within each scale + image
63 + pick = batched_nms(boxes[:, :4], boxes[:, 4], all_inds, 0.5)
64 + boxes, image_inds = boxes[pick], image_inds[pick]
65 +
66 + # NMS within each image
67 + pick = batched_nms(boxes[:, :4], boxes[:, 4], image_inds, 0.7)
68 + boxes, image_inds = boxes[pick], image_inds[pick]
69 +
70 + regw = boxes[:, 2] - boxes[:, 0]
71 + regh = boxes[:, 3] - boxes[:, 1]
72 + qq1 = boxes[:, 0] + boxes[:, 5] * regw
73 + qq2 = boxes[:, 1] + boxes[:, 6] * regh
74 + qq3 = boxes[:, 2] + boxes[:, 7] * regw
75 + qq4 = boxes[:, 3] + boxes[:, 8] * regh
76 + boxes = torch.stack([qq1, qq2, qq3, qq4, boxes[:, 4]]).permute(1, 0)
77 + boxes = rerec(boxes)
78 + y, ey, x, ex = pad(boxes, w, h)
79 +
80 + # Second stage
81 + if len(boxes) > 0:
82 + im_data = []
83 + for k in range(len(y)):
84 + if ey[k] > (y[k] - 1) and ex[k] > (x[k] - 1):
85 + img_k = imgs[image_inds[k], :, (y[k] - 1):ey[k], (x[k] - 1):ex[k]].unsqueeze(0)
86 + im_data.append(imresample(img_k, (24, 24)))
87 + im_data = torch.cat(im_data, dim=0)
88 + im_data = (im_data - 127.5) * 0.0078125
89 + out = rnet(im_data)
90 +
91 + out0 = out[0].permute(1, 0)
92 + out1 = out[1].permute(1, 0)
93 + score = out1[1, :]
94 + ipass = score > threshold[1]
95 + boxes = torch.cat((boxes[ipass, :4], score[ipass].unsqueeze(1)), dim=1)
96 + image_inds = image_inds[ipass]
97 + mv = out0[:, ipass].permute(1, 0)
98 +
99 + # NMS within each image
100 + pick = batched_nms(boxes[:, :4], boxes[:, 4], image_inds, 0.7)
101 + boxes, image_inds, mv = boxes[pick], image_inds[pick], mv[pick]
102 + boxes = bbreg(boxes, mv)
103 + boxes = rerec(boxes)
104 +
105 + # Third stage
106 + points = torch.zeros(0, 5, 2, device=device)
107 + if len(boxes) > 0:
108 + y, ey, x, ex = pad(boxes, w, h)
109 + im_data = []
110 + for k in range(len(y)):
111 + if ey[k] > (y[k] - 1) and ex[k] > (x[k] - 1):
112 + img_k = imgs[image_inds[k], :, (y[k] - 1):ey[k], (x[k] - 1):ex[k]].unsqueeze(0)
113 + im_data.append(imresample(img_k, (48, 48)))
114 + im_data = torch.cat(im_data, dim=0)
115 + im_data = (im_data - 127.5) * 0.0078125
116 + out = onet(im_data)
117 +
118 + out0 = out[0].permute(1, 0)
119 + out1 = out[1].permute(1, 0)
120 + out2 = out[2].permute(1, 0)
121 + score = out2[1, :]
122 + points = out1
123 + ipass = score > threshold[2]
124 + points = points[:, ipass]
125 + boxes = torch.cat((boxes[ipass, :4], score[ipass].unsqueeze(1)), dim=1)
126 + image_inds = image_inds[ipass]
127 + mv = out0[:, ipass].permute(1, 0)
128 +
129 + w_i = boxes[:, 2] - boxes[:, 0] + 1
130 + h_i = boxes[:, 3] - boxes[:, 1] + 1
131 + points_x = w_i.repeat(5, 1) * points[:5, :] + boxes[:, 0].repeat(5, 1) - 1
132 + points_y = h_i.repeat(5, 1) * points[5:10, :] + boxes[:, 1].repeat(5, 1) - 1
133 + points = torch.stack((points_x, points_y)).permute(2, 1, 0)
134 + boxes = bbreg(boxes, mv)
135 +
136 + # NMS within each image using "Min" strategy
137 + # pick = batched_nms(boxes[:, :4], boxes[:, 4], image_inds, 0.7)
138 + pick = batched_nms_numpy(boxes[:, :4], boxes[:, 4], image_inds, 0.7, 'Min')
139 + boxes, image_inds, points = boxes[pick], image_inds[pick], points[pick]
140 +
141 + boxes = boxes.cpu().numpy()
142 + points = points.cpu().numpy()
143 +
144 + batch_boxes = []
145 + batch_points = []
146 + for b_i in range(batch_size):
147 + b_i_inds = np.where(image_inds == b_i)
148 + batch_boxes.append(boxes[b_i_inds].copy())
149 + batch_points.append(points[b_i_inds].copy())
150 +
151 + batch_boxes, batch_points = np.array(batch_boxes), np.array(batch_points)
152 +
153 + return batch_boxes, batch_points
154 +
155 +
156 +def bbreg(boundingbox, reg):
157 + if reg.shape[1] == 1:
158 + reg = torch.reshape(reg, (reg.shape[2], reg.shape[3]))
159 +
160 + w = boundingbox[:, 2] - boundingbox[:, 0] + 1
161 + h = boundingbox[:, 3] - boundingbox[:, 1] + 1
162 + b1 = boundingbox[:, 0] + reg[:, 0] * w
163 + b2 = boundingbox[:, 1] + reg[:, 1] * h
164 + b3 = boundingbox[:, 2] + reg[:, 2] * w
165 + b4 = boundingbox[:, 3] + reg[:, 3] * h
166 + boundingbox[:, :4] = torch.stack([b1, b2, b3, b4]).permute(1, 0)
167 +
168 + return boundingbox
169 +
170 +
171 +def generateBoundingBox(reg, probs, scale, thresh):
172 + stride = 2
173 + cellsize = 12
174 +
175 + reg = reg.permute(1, 0, 2, 3)
176 +
177 + mask = probs >= thresh
178 + mask_inds = mask.nonzero()
179 + image_inds = mask_inds[:, 0]
180 + score = probs[mask]
181 + reg = reg[:, mask].permute(1, 0)
182 + bb = mask_inds[:, 1:].type(reg.dtype).flip(1)
183 + q1 = ((stride * bb + 1) / scale).floor()
184 + q2 = ((stride * bb + cellsize - 1 + 1) / scale).floor()
185 + boundingbox = torch.cat([q1, q2, score.unsqueeze(1), reg], dim=1)
186 + return boundingbox, image_inds
187 +
188 +
189 +def nms_numpy(boxes, scores, threshold, method):
190 + if boxes.size == 0:
191 + return np.empty((0, 3))
192 +
193 + x1 = boxes[:, 0].copy()
194 + y1 = boxes[:, 1].copy()
195 + x2 = boxes[:, 2].copy()
196 + y2 = boxes[:, 3].copy()
197 + s = scores
198 + area = (x2 - x1 + 1) * (y2 - y1 + 1)
199 +
200 + I = np.argsort(s)
201 + pick = np.zeros_like(s, dtype=np.int16)
202 + counter = 0
203 + while I.size > 0:
204 + i = I[-1]
205 + pick[counter] = i
206 + counter += 1
207 + idx = I[0:-1]
208 +
209 + xx1 = np.maximum(x1[i], x1[idx]).copy()
210 + yy1 = np.maximum(y1[i], y1[idx]).copy()
211 + xx2 = np.minimum(x2[i], x2[idx]).copy()
212 + yy2 = np.minimum(y2[i], y2[idx]).copy()
213 +
214 + w = np.maximum(0.0, xx2 - xx1 + 1).copy()
215 + h = np.maximum(0.0, yy2 - yy1 + 1).copy()
216 +
217 + inter = w * h
218 + if method == "Min":
219 + o = inter / np.minimum(area[i], area[idx])
220 + else:
221 + o = inter / (area[i] + area[idx] - inter)
222 + I = I[np.where(o <= threshold)]
223 +
224 + pick = pick[:counter].copy()
225 + return pick
226 +
227 +
228 +def batched_nms_numpy(boxes, scores, idxs, threshold, method):
229 + device = boxes.device
230 + if boxes.numel() == 0:
231 + return torch.empty((0,), dtype=torch.int64, device=device)
232 + # strategy: in order to perform NMS independently per class.
233 + # we add an offset to all the boxes. The offset is dependent
234 + # only on the class idx, and is large enough so that boxes
235 + # from different classes do not overlap
236 + max_coordinate = boxes.max()
237 + offsets = idxs.to(boxes) * (max_coordinate + 1)
238 + boxes_for_nms = boxes + offsets[:, None]
239 + boxes_for_nms = boxes_for_nms.cpu().numpy()
240 + scores = scores.cpu().numpy()
241 + keep = nms_numpy(boxes_for_nms, scores, threshold, method)
242 + return torch.as_tensor(keep, dtype=torch.long, device=device)
243 +
244 +
245 +def pad(boxes, w, h):
246 + boxes = boxes.trunc().int().cpu().numpy()
247 + x = boxes[:, 0]
248 + y = boxes[:, 1]
249 + ex = boxes[:, 2]
250 + ey = boxes[:, 3]
251 +
252 + x[x < 1] = 1
253 + y[y < 1] = 1
254 + ex[ex > w] = w
255 + ey[ey > h] = h
256 +
257 + return y, ey, x, ex
258 +
259 +
260 +def rerec(bboxA):
261 + h = bboxA[:, 3] - bboxA[:, 1]
262 + w = bboxA[:, 2] - bboxA[:, 0]
263 +
264 + l = torch.max(w, h)
265 + bboxA[:, 0] = bboxA[:, 0] + w * 0.5 - l * 0.5
266 + bboxA[:, 1] = bboxA[:, 1] + h * 0.5 - l * 0.5
267 + bboxA[:, 2:4] = bboxA[:, :2] + l.repeat(2, 1).permute(1, 0)
268 +
269 + return bboxA
270 +
271 +
272 +def imresample(img, sz):
273 + im_data = interpolate(img, size=sz, mode="area")
274 + return im_data
275 +
276 +
277 +def crop_resize(img, box, image_size):
278 + if isinstance(img, np.ndarray):
279 + out = cv2.resize(
280 + img[box[1]:box[3], box[0]:box[2]],
281 + (image_size, image_size),
282 + interpolation=cv2.INTER_AREA
283 + ).copy()
284 + else:
285 + out = img.crop(box).copy().resize((image_size, image_size), Image.BILINEAR)
286 + return out
287 +
288 +
289 +def save_img(img, path):
290 + if isinstance(img, np.ndarray):
291 + cv2.imwrite(path, cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
292 + else:
293 + img.save(path)
294 +
295 +
296 +def get_size(img):
297 + if isinstance(img, np.ndarray):
298 + return img.shape[1::-1]
299 + else:
300 + return img.size
301 +
302 +
303 +def extract_face(img, box, image_size=160, margin=0, save_path=None):
304 + """Extract face + margin from PIL Image given bounding box.
305 +
306 + Arguments:
307 + img {PIL.Image} -- A PIL Image.
308 + box {numpy.ndarray} -- Four-element bounding box.
309 + image_size {int} -- Output image size in pixels. The image will be square.
310 + margin {int} -- Margin to add to bounding box, in terms of pixels in the final image.
311 + Note that the application of the margin differs slightly from the davidsandberg/facenet
312 + repo, which applies the margin to the original image before resizing, making the margin
313 + dependent on the original image size.
314 + save_path {str} -- Save path for extracted face image. (default: {None})
315 +
316 + Returns:
317 + torch.tensor -- tensor representing the extracted face.
318 + """
319 + margin = [
320 + margin * (box[2] - box[0]) / (image_size - margin),
321 + margin * (box[3] - box[1]) / (image_size - margin),
322 + ]
323 + raw_image_size = get_size(img)
324 + box = [
325 + int(max(box[0] - margin[0] / 2, 0)),
326 + int(max(box[1] - margin[1] / 2, 0)),
327 + int(min(box[2] + margin[0] / 2, raw_image_size[0])),
328 + int(min(box[3] + margin[1] / 2, raw_image_size[1])),
329 + ]
330 +
331 + face = crop_resize(img, box, image_size)
332 +
333 + if save_path is not None:
334 + os.makedirs(os.path.dirname(save_path) + "/", exist_ok=True)
335 + save_img(face, save_path)
336 +
337 + face = F.to_tensor(np.float32(face))
338 +
339 + return face
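To make the margin convention concrete, here is a small worked example (the box and parameter values are made up) of the conversion from final-image pixels to raw-image pixels performed above:

import numpy as np

box = np.array([100.0, 100.0, 228.0, 228.0])   # a 128 x 128 px face in the raw image
image_size, margin = 160, 32
# Margin requested in final-image pixels, converted to raw-image pixels:
margin_raw = margin * (box[2] - box[0]) / (image_size - margin)   # 32 * 128 / 128 = 32.0
# The crop grows by margin_raw / 2 on every side (here to 160 x 160 px) and is then
# resized to image_size x image_size, leaving exactly 32 px of context in the output.
print(margin_raw)   # 32.0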
1 +import tensorflow as tf
2 +import torch
3 +import json
4 +import os, sys
5 +
6 +from dependencies.facenet.src import facenet
7 +from dependencies.facenet.src.models import inception_resnet_v1 as tf_mdl
8 +from dependencies.facenet.src.align import detect_face
9 +
10 +from models.inception_resnet_v1 import InceptionResnetV1
11 +from models.mtcnn import PNet, RNet, ONet
12 +
13 +
14 +def import_tf_params(tf_mdl_dir, sess):
15 + """Import tensorflow model from save directory.
16 +
17 + Arguments:
18 + tf_mdl_dir {str} -- Location of protobuf, checkpoint, meta files.
19 + sess {tensorflow.Session} -- Tensorflow session object.
20 +
21 + Returns:
22 + (list, list, list) -- Tuple of lists containing the layer names,
23 + parameter arrays as numpy ndarrays, parameter shapes.
24 + """
25 + print('\nLoading tensorflow model\n')
26 + if callable(tf_mdl_dir):
27 + tf_mdl_dir(sess)
28 + else:
29 + facenet.load_model(tf_mdl_dir)
30 +
31 + print('\nGetting model weights\n')
32 + tf_layers = tf.trainable_variables()
33 + tf_params = sess.run(tf_layers)
34 +
35 + tf_shapes = [p.shape for p in tf_params]
36 + tf_layers = [l.name for l in tf_layers]
37 +
38 + if not callable(tf_mdl_dir):
39 + path = os.path.join(tf_mdl_dir, 'layer_description.json')
40 + else:
41 + path = 'data/layer_description.json'
42 + with open(path, 'w') as f:
43 + json.dump({l: s for l, s in zip(tf_layers, tf_shapes)}, f)
44 +
45 + return tf_layers, tf_params, tf_shapes
46 +
47 +
48 +def get_layer_indices(layer_lookup, tf_layers):
49 + """Given a lookup of model layer attribute names and tensorflow variable names,
50 + find matching parameters.
51 +
52 + Arguments:
53 + layer_lookup {dict} -- Dictionary mapping pytorch attribute names to (partial)
54 + tensorflow variable names. Expects dict of the form {'attr': ['tf_name', ...]}
55 + where the '...'s are ignored.
56 + tf_layers {list} -- List of tensorflow variable names.
57 +
58 + Returns:
59 + list -- The input dictionary with the list of matching inds appended to each item.
60 + """
61 + layer_inds = {}
62 + for name, value in layer_lookup.items():
63 + layer_inds[name] = value + [[i for i, n in enumerate(tf_layers) if value[0] in n]]
64 + return layer_inds
65 +
66 +
67 +def load_tf_batchNorm(weights, layer):
68 + """Load tensorflow weights into nn.BatchNorm object.
69 +
70 + Arguments:
71 + weights {list} -- Tensorflow parameters.
72 + layer {torch.nn.Module} -- nn.BatchNorm.
73 + """
74 + layer.bias.data = torch.tensor(weights[0]).view(layer.bias.data.shape)
75 + layer.weight.data = torch.ones_like(layer.weight.data)
76 + layer.running_mean = torch.tensor(weights[1]).view(layer.running_mean.shape)
77 + layer.running_var = torch.tensor(weights[2]).view(layer.running_var.shape)
78 +
79 +
80 +def load_tf_conv2d(weights, layer, transpose=False):
81 + """Load tensorflow weights into nn.Conv2d object.
82 +
83 + Arguments:
84 + weights {list} -- Tensorflow parameters.
85 + layer {torch.nn.Module} -- nn.Conv2d.
86 + """
87 + if isinstance(weights, list):
88 + if len(weights) == 2:
89 + layer.bias.data = (
90 + torch.tensor(weights[1])
91 + .view(layer.bias.data.shape)
92 + )
93 + weights = weights[0]
94 +
95 + if transpose:
96 + dim_order = (3, 2, 1, 0)
97 + else:
98 + dim_order = (3, 2, 0, 1)
99 +
100 + layer.weight.data = (
101 + torch.tensor(weights)
102 + .permute(dim_order)
103 + .view(layer.weight.data.shape)
104 + )
105 +
106 +
107 +def load_tf_conv2d_trans(weights, layer):
108 + return load_tf_conv2d(weights, layer, transpose=True)
109 +
110 +
111 +def load_tf_basicConv2d(weights, layer):
112 + """Load tensorflow weights into grouped Conv2d+BatchNorm object.
113 +
114 + Arguments:
115 + weights {list} -- Tensorflow parameters.
116 + layer {torch.nn.Module} -- Object containing Conv2d+BatchNorm.
117 + """
118 + load_tf_conv2d(weights[0], layer.conv)
119 + load_tf_batchNorm(weights[1:], layer.bn)
120 +
121 +
122 +def load_tf_linear(weights, layer):
123 + """Load tensorflow weights into nn.Linear object.
124 +
125 + Arguments:
126 + weights {list} -- Tensorflow parameters.
127 + layer {torch.nn.Module} -- nn.Linear.
128 + """
129 + if isinstance(weights, list):
130 + if len(weights) == 2:
131 + layer.bias.data = (
132 + torch.tensor(weights[1])
133 + .view(layer.bias.data.shape)
134 + )
135 + weights = weights[0]
136 + layer.weight.data = (
137 + torch.tensor(weights)
138 + .transpose(-1, 0)
139 + .view(layer.weight.data.shape)
140 + )
141 +
142 +
143 +# High-level parameter-loading functions:
144 +
145 +def load_tf_block35(weights, layer):
146 + load_tf_basicConv2d(weights[:4], layer.branch0)
147 + load_tf_basicConv2d(weights[4:8], layer.branch1[0])
148 + load_tf_basicConv2d(weights[8:12], layer.branch1[1])
149 + load_tf_basicConv2d(weights[12:16], layer.branch2[0])
150 + load_tf_basicConv2d(weights[16:20], layer.branch2[1])
151 + load_tf_basicConv2d(weights[20:24], layer.branch2[2])
152 + load_tf_conv2d(weights[24:26], layer.conv2d)
153 +
154 +
155 +def load_tf_block17_8(weights, layer):
156 + load_tf_basicConv2d(weights[:4], layer.branch0)
157 + load_tf_basicConv2d(weights[4:8], layer.branch1[0])
158 + load_tf_basicConv2d(weights[8:12], layer.branch1[1])
159 + load_tf_basicConv2d(weights[12:16], layer.branch1[2])
160 + load_tf_conv2d(weights[16:18], layer.conv2d)
161 +
162 +
163 +def load_tf_mixed6a(weights, layer):
164 + if len(weights) != 16:
165 + raise ValueError(f'Number of weight arrays ({len(weights)}) not equal to 16')
166 + load_tf_basicConv2d(weights[:4], layer.branch0)
167 + load_tf_basicConv2d(weights[4:8], layer.branch1[0])
168 + load_tf_basicConv2d(weights[8:12], layer.branch1[1])
169 + load_tf_basicConv2d(weights[12:16], layer.branch1[2])
170 +
171 +
172 +def load_tf_mixed7a(weights, layer):
173 + if len(weights) != 28:
174 + raise ValueError(f'Number of weight arrays ({len(weights)}) not equal to 28')
175 + load_tf_basicConv2d(weights[:4], layer.branch0[0])
176 + load_tf_basicConv2d(weights[4:8], layer.branch0[1])
177 + load_tf_basicConv2d(weights[8:12], layer.branch1[0])
178 + load_tf_basicConv2d(weights[12:16], layer.branch1[1])
179 + load_tf_basicConv2d(weights[16:20], layer.branch2[0])
180 + load_tf_basicConv2d(weights[20:24], layer.branch2[1])
181 + load_tf_basicConv2d(weights[24:28], layer.branch2[2])
182 +
183 +
184 +def load_tf_repeats(weights, layer, rptlen, subfun):
185 + if len(weights) % rptlen != 0:
186 + raise ValueError(f'Number of weight arrays ({len(weights)}) not divisible by {rptlen}')
187 + weights_split = [weights[i:i+rptlen] for i in range(0, len(weights), rptlen)]
188 + for i, w in enumerate(weights_split):
189 + subfun(w, getattr(layer, str(i)))
190 +
191 +
192 +def load_tf_repeat_1(weights, layer):
193 + load_tf_repeats(weights, layer, 26, load_tf_block35)
194 +
195 +
196 +def load_tf_repeat_2(weights, layer):
197 + load_tf_repeats(weights, layer, 18, load_tf_block17_8)
198 +
199 +
200 +def load_tf_repeat_3(weights, layer):
201 + load_tf_repeats(weights, layer, 18, load_tf_block17_8)
202 +
203 +
204 +def test_loaded_params(mdl, tf_params, tf_layers):
205 + """Check each parameter in a pytorch model for an equivalent parameter
206 + in a list of tensorflow variables.
207 +
208 + Arguments:
209 + mdl {torch.nn.Module} -- Pytorch model.
210 + tf_params {list} -- List of ndarrays representing tensorflow variables.
211 + tf_layers {list} -- Corresponding list of tensorflow variable names.
212 + """
213 + tf_means = torch.stack([torch.tensor(p).mean() for p in tf_params])
214 + for name, param in mdl.named_parameters():
215 + pt_mean = param.data.mean()
216 + matching_inds = ((tf_means - pt_mean).abs() < 1e-8).nonzero()
217 + print(f'{name} equivalent to {[tf_layers[i] for i in matching_inds]}')
218 +
219 +
220 +def compare_model_outputs(pt_mdl, sess, test_data):
221 + """Given some testing data, compare the output of pytorch and tensorflow models.
222 +
223 + Arguments:
224 + pt_mdl {torch.nn.Module} -- Pytorch model.
225 + sess {tensorflow.Session} -- Tensorflow session object.
226 + test_data {torch.Tensor} -- Pytorch tensor.
227 + """
228 + print('\nPassing test data through TF model\n')
229 + if isinstance(sess, tf.Session):
230 + images_placeholder = tf.get_default_graph().get_tensor_by_name("input:0")
231 + phase_train_placeholder = tf.get_default_graph().get_tensor_by_name("phase_train:0")
232 + embeddings = tf.get_default_graph().get_tensor_by_name("embeddings:0")
233 + feed_dict = {images_placeholder: test_data.numpy(), phase_train_placeholder: False}
234 + tf_output = torch.tensor(sess.run(embeddings, feed_dict=feed_dict))
235 + else:
236 + tf_output = sess(test_data)
237 +
238 + print(tf_output)
239 +
240 + print('\nPassing test data through PT model\n')
241 + pt_output = pt_mdl(test_data.permute(0, 3, 1, 2))
242 + print(pt_output)
243 +
244 + distance = (tf_output - pt_output).norm()
245 + print(f'\nDistance {distance}\n')
246 +
247 +
248 +def compare_mtcnn(pt_mdl, tf_fun, sess, ind, test_data):
249 + tf_mdls = tf_fun(sess)
250 + tf_mdl = tf_mdls[ind]
251 +
252 + print('\nPassing test data through TF model\n')
253 + tf_output = tf_mdl(test_data.numpy())
254 + tf_output = [torch.tensor(out) for out in tf_output]
255 + print('\n'.join([str(o.view(-1)[:10]) for o in tf_output]))
256 +
257 + print('\nPassing test data through PT model\n')
258 + with torch.no_grad():
259 + pt_output = pt_mdl(test_data.permute(0, 3, 2, 1))
260 + pt_output = [torch.tensor(out) for out in pt_output]
261 + for i in range(len(pt_output)):
262 + if len(pt_output[i].shape) == 4:
263 + pt_output[i] = pt_output[i].permute(0, 3, 2, 1).contiguous()
264 + print('\n'.join([str(o.view(-1)[:10]) for o in pt_output]))
265 +
266 + distance = [(tf_o - pt_o).norm() for tf_o, pt_o in zip(tf_output, pt_output)]
267 + print(f'\nDistance {distance}\n')
268 +
269 +
270 +def load_tf_model_weights(mdl, layer_lookup, tf_mdl_dir, is_resnet=True, arg_num=None):
271 + """Load tensorflow parameters into a pytorch model.
272 +
273 + Arguments:
274 + mdl {torch.nn.Module} -- Pytorch model.
275 + layer_lookup {[type]} -- Dictionary mapping pytorch attribute names to (partial)
276 + tensorflow variable names, and a function suitable for loading weights.
277 + Expects dict of the form {'attr': ['tf_name', function]}.
278 + tf_mdl_dir {str} -- Location of protobuf, checkpoint, meta files.
279 + """
280 + tf.reset_default_graph()
281 + with tf.Session() as sess:
282 + tf_layers, tf_params, tf_shapes = import_tf_params(tf_mdl_dir, sess)
283 + layer_info = get_layer_indices(layer_lookup, tf_layers)
284 +
285 + for layer_name, info in layer_info.items():
286 + print(f'Loading {info[0]}/* into {layer_name}')
287 + weights = [tf_params[i] for i in info[2]]
288 + layer = getattr(mdl, layer_name)
289 + info[1](weights, layer)
290 +
291 + test_loaded_params(mdl, tf_params, tf_layers)
292 +
293 + if is_resnet:
294 + compare_model_outputs(mdl, sess, torch.randn(5, 160, 160, 3).detach())
295 +
296 +
297 +def tensorflow2pytorch():
298 + lookup_inception_resnet_v1 = {
299 + 'conv2d_1a': ['InceptionResnetV1/Conv2d_1a_3x3', load_tf_basicConv2d],
300 + 'conv2d_2a': ['InceptionResnetV1/Conv2d_2a_3x3', load_tf_basicConv2d],
301 + 'conv2d_2b': ['InceptionResnetV1/Conv2d_2b_3x3', load_tf_basicConv2d],
302 + 'conv2d_3b': ['InceptionResnetV1/Conv2d_3b_1x1', load_tf_basicConv2d],
303 + 'conv2d_4a': ['InceptionResnetV1/Conv2d_4a_3x3', load_tf_basicConv2d],
304 + 'conv2d_4b': ['InceptionResnetV1/Conv2d_4b_3x3', load_tf_basicConv2d],
305 + 'repeat_1': ['InceptionResnetV1/Repeat/block35', load_tf_repeat_1],
306 + 'mixed_6a': ['InceptionResnetV1/Mixed_6a', load_tf_mixed6a],
307 + 'repeat_2': ['InceptionResnetV1/Repeat_1/block17', load_tf_repeat_2],
308 + 'mixed_7a': ['InceptionResnetV1/Mixed_7a', load_tf_mixed7a],
309 + 'repeat_3': ['InceptionResnetV1/Repeat_2/block8', load_tf_repeat_3],
310 + 'block8': ['InceptionResnetV1/Block8', load_tf_block17_8],
311 + 'last_linear': ['InceptionResnetV1/Bottleneck/weights', load_tf_linear],
312 + 'last_bn': ['InceptionResnetV1/Bottleneck/BatchNorm', load_tf_batchNorm],
313 + 'logits': ['Logits', load_tf_linear],
314 + }
315 +
316 + print('\nLoad VGGFace2-trained weights and save\n')
317 + mdl = InceptionResnetV1(num_classes=8631).eval()
318 + tf_mdl_dir = 'data/20180402-114759'
319 + data_name = 'vggface2'
320 + load_tf_model_weights(mdl, lookup_inception_resnet_v1, tf_mdl_dir)
321 + state_dict = mdl.state_dict()
322 + torch.save(state_dict, f'{tf_mdl_dir}-{data_name}.pt')
323 + torch.save(
324 + {
325 + 'logits.weight': state_dict['logits.weight'],
326 + 'logits.bias': state_dict['logits.bias'],
327 + },
328 + f'{tf_mdl_dir}-{data_name}-logits.pt'
329 + )
330 + state_dict.pop('logits.weight')
331 + state_dict.pop('logits.bias')
332 + torch.save(state_dict, f'{tf_mdl_dir}-{data_name}-features.pt')
333 +
334 + print('\nLoad CASIA-Webface-trained weights and save\n')
335 + mdl = InceptionResnetV1(num_classes=10575).eval()
336 + tf_mdl_dir = 'data/20180408-102900'
337 + data_name = 'casia-webface'
338 + load_tf_model_weights(mdl, lookup_inception_resnet_v1, tf_mdl_dir)
339 + state_dict = mdl.state_dict()
340 + torch.save(state_dict, f'{tf_mdl_dir}-{data_name}.pt')
341 + torch.save(
342 + {
343 + 'logits.weight': state_dict['logits.weight'],
344 + 'logits.bias': state_dict['logits.bias'],
345 + },
346 + f'{tf_mdl_dir}-{data_name}-logits.pt'
347 + )
348 + state_dict.pop('logits.weight')
349 + state_dict.pop('logits.bias')
350 + torch.save(state_dict, f'{tf_mdl_dir}-{data_name}-features.pt')
351 +
352 + lookup_pnet = {
353 + 'conv1': ['pnet/conv1', load_tf_conv2d_trans],
354 + 'prelu1': ['pnet/PReLU1', load_tf_linear],
355 + 'conv2': ['pnet/conv2', load_tf_conv2d_trans],
356 + 'prelu2': ['pnet/PReLU2', load_tf_linear],
357 + 'conv3': ['pnet/conv3', load_tf_conv2d_trans],
358 + 'prelu3': ['pnet/PReLU3', load_tf_linear],
359 + 'conv4_1': ['pnet/conv4-1', load_tf_conv2d_trans],
360 + 'conv4_2': ['pnet/conv4-2', load_tf_conv2d_trans],
361 + }
362 + lookup_rnet = {
363 + 'conv1': ['rnet/conv1', load_tf_conv2d_trans],
364 + 'prelu1': ['rnet/prelu1', load_tf_linear],
365 + 'conv2': ['rnet/conv2', load_tf_conv2d_trans],
366 + 'prelu2': ['rnet/prelu2', load_tf_linear],
367 + 'conv3': ['rnet/conv3', load_tf_conv2d_trans],
368 + 'prelu3': ['rnet/prelu3', load_tf_linear],
369 + 'dense4': ['rnet/conv4', load_tf_linear],
370 + 'prelu4': ['rnet/prelu4', load_tf_linear],
371 + 'dense5_1': ['rnet/conv5-1', load_tf_linear],
372 + 'dense5_2': ['rnet/conv5-2', load_tf_linear],
373 + }
374 + lookup_onet = {
375 + 'conv1': ['onet/conv1', load_tf_conv2d_trans],
376 + 'prelu1': ['onet/prelu1', load_tf_linear],
377 + 'conv2': ['onet/conv2', load_tf_conv2d_trans],
378 + 'prelu2': ['onet/prelu2', load_tf_linear],
379 + 'conv3': ['onet/conv3', load_tf_conv2d_trans],
380 + 'prelu3': ['onet/prelu3', load_tf_linear],
381 + 'conv4': ['onet/conv4', load_tf_conv2d_trans],
382 + 'prelu4': ['onet/prelu4', load_tf_linear],
383 + 'dense5': ['onet/conv5', load_tf_linear],
384 + 'prelu5': ['onet/prelu5', load_tf_linear],
385 + 'dense6_1': ['onet/conv6-1', load_tf_linear],
386 + 'dense6_2': ['onet/conv6-2', load_tf_linear],
387 + 'dense6_3': ['onet/conv6-3', load_tf_linear],
388 + }
389 +
390 + print('\nLoad PNet weights and save\n')
391 + tf_mdl_dir = lambda sess: detect_face.create_mtcnn(sess, None)
392 + mdl = PNet()
393 + data_name = 'pnet'
394 + load_tf_model_weights(mdl, lookup_pnet, tf_mdl_dir, is_resnet=False, arg_num=0)
395 + torch.save(mdl.state_dict(), f'data/{data_name}.pt')
396 + tf.reset_default_graph()
397 + with tf.Session() as sess:
398 + compare_mtcnn(mdl, tf_mdl_dir, sess, 0, torch.randn(1, 256, 256, 3).detach())
399 +
400 + print('\nLoad RNet weights and save\n')
401 + mdl = RNet()
402 + data_name = 'rnet'
403 + load_tf_model_weights(mdl, lookup_rnet, tf_mdl_dir, is_resnet=False, arg_num=1)
404 + torch.save(mdl.state_dict(), f'data/{data_name}.pt')
405 + tf.reset_default_graph()
406 + with tf.Session() as sess:
407 + compare_mtcnn(mdl, tf_mdl_dir, sess, 1, torch.randn(1, 24, 24, 3).detach())
408 +
409 + print('\nLoad ONet weights and save\n')
410 + mdl = ONet()
411 + data_name = 'onet'
412 + load_tf_model_weights(mdl, lookup_onet, tf_mdl_dir, is_resnet=False, arg_num=2)
413 + torch.save(mdl.state_dict(), f'data/{data_name}.pt')
414 + tf.reset_default_graph()
415 + with tf.Session() as sess:
416 + compare_mtcnn(mdl, tf_mdl_dir, sess, 2, torch.randn(1, 48, 48, 3).detach())
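Nothing in this module invokes the conversion itself; presumably it is run once as a script, with TF 1.x, the davidsandberg/facenet dependency, and the checkpoint directories (data/20180402-114759, data/20180408-102900) in place. A minimal, assumed entry point:

if __name__ == '__main__':
    # Converts both Inception-ResNet checkpoints and the P/R/O-net weights,
    # writing the .pt files referenced by models/mtcnn.py under data/.
    tensorflow2pytorch()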
1 +import torch
2 +import numpy as np
3 +import time
4 +
5 +
6 +class Logger(object):
7 +
8 + def __init__(self, mode, length, calculate_mean=False):
9 + self.mode = mode
10 + self.length = length
11 + self.calculate_mean = calculate_mean
12 + if self.calculate_mean:
13 + self.fn = lambda x, i: x / (i + 1)
14 + else:
15 + self.fn = lambda x, i: x
16 +
17 + def __call__(self, loss, metrics, i):
18 + track_str = '\r{} | {:5d}/{:<5d}| '.format(self.mode, i + 1, self.length)
19 + loss_str = 'loss: {:9.4f} | '.format(self.fn(loss, i))
20 + metric_str = ' | '.join('{}: {:9.4f}'.format(k, self.fn(v, i)) for k, v in metrics.items())
21 + print(track_str + loss_str + metric_str + ' ', end='')
22 + if i + 1 == self.length:
23 + print('')
24 +
25 +
26 +class BatchTimer(object):
27 + """Batch timing class.
28 + Use this class for tracking training and testing time/rate per batch or per sample.
29 +
30 + Keyword Arguments:
31 + rate {bool} -- Whether to report a rate (batches or samples per second) or a time (seconds
32 + per batch or sample). (default: {True})
33 + per_sample {bool} -- Whether to report times or rates per sample or per batch.
34 + (default: {True})
35 + """
36 +
37 + def __init__(self, rate=True, per_sample=True):
38 + self.start = time.time()
39 + self.end = None
40 + self.rate = rate
41 + self.per_sample = per_sample
42 +
43 + def __call__(self, y_pred, y):
44 + self.end = time.time()
45 + elapsed = self.end - self.start
46 + self.start = self.end
47 + self.end = None
48 +
49 + if self.per_sample:
50 + elapsed /= len(y_pred)
51 + if self.rate:
52 + elapsed = 1 / elapsed
53 +
54 + return torch.tensor(elapsed)
55 +
56 +
57 +def accuracy(logits, y):
58 + _, preds = torch.max(logits, 1)
59 + return (preds == y).float().mean()
60 +
61 +
62 +def pass_epoch(
63 + model, loss_fn, loader, optimizer=None, scheduler=None,
64 + batch_metrics={'time': BatchTimer()}, show_running=True,
65 + device='cpu', writer=None
66 +):
67 + """Train or evaluate over a data epoch.
68 +
69 + Arguments:
70 + model {torch.nn.Module} -- Pytorch model.
71 + loss_fn {callable} -- A function to compute (scalar) loss.
72 + loader {torch.utils.data.DataLoader} -- A pytorch data loader.
73 +
74 + Keyword Arguments:
75 + optimizer {torch.optim.Optimizer} -- A pytorch optimizer.
76 + scheduler {torch.optim.lr_scheduler._LRScheduler} -- LR scheduler (default: {None})
77 + batch_metrics {dict} -- Dictionary of metric functions to call on each batch. The default
78 + is a simple timer. A progressive average of these metrics, along with the average
79 + loss, is printed every batch. (default: {{'time': BatchTimer()}})
80 + show_running {bool} -- If True, print rolling averages of the loss and metrics; if False,
81 + print values for the current batch only. (default: {True})
82 + device {str or torch.device} -- Device for pytorch to use. (default: {'cpu'})
83 + writer {torch.utils.tensorboard.SummaryWriter} -- Tensorboard SummaryWriter. (default: {None})
84 +
85 + Returns:
86 + tuple(torch.Tensor, dict) -- A tuple of the average loss and a dictionary of average
87 + metric values across the epoch.
88 + """
89 +
90 + mode = 'Train' if model.training else 'Valid'
91 + logger = Logger(mode, length=len(loader), calculate_mean=show_running)
92 + loss = 0
93 + metrics = {}
94 +
95 + for i_batch, (x, y) in enumerate(loader):
96 + x = x.to(device)
97 + y = y.to(device)
98 + y_pred = model(x)
99 + loss_batch = loss_fn(y_pred, y)
100 +
101 + if model.training:
102 + loss_batch.backward()
103 + optimizer.step()
104 + optimizer.zero_grad()
105 +
106 + metrics_batch = {}
107 + for metric_name, metric_fn in batch_metrics.items():
108 + metrics_batch[metric_name] = metric_fn(y_pred, y).detach().cpu()
109 + metrics[metric_name] = metrics.get(metric_name, 0) + metrics_batch[metric_name]
110 +
111 + if writer is not None and model.training:
112 + if writer.iteration % writer.interval == 0:
113 + writer.add_scalars('loss', {mode: loss_batch.detach().cpu()}, writer.iteration)
114 + for metric_name, metric_batch in metrics_batch.items():
115 + writer.add_scalars(metric_name, {mode: metric_batch}, writer.iteration)
116 + writer.iteration += 1
117 +
118 + loss_batch = loss_batch.detach().cpu()
119 + loss += loss_batch
120 + if show_running:
121 + logger(loss, metrics, i_batch)
122 + else:
123 + logger(loss_batch, metrics_batch, i_batch)
124 +
125 + if model.training and scheduler is not None:
126 + scheduler.step()
127 +
128 + loss = loss / (i_batch + 1)
129 + metrics = {k: v / (i_batch + 1) for k, v in metrics.items()}
130 +
131 + if writer is not None and not model.training:
132 + writer.add_scalars('loss', {mode: loss.detach()}, writer.iteration)
133 + for metric_name, metric in metrics.items():
134 + writer.add_scalars(metric_name, {mode: metric})
135 +
136 + return loss, metrics
137 +
138 +
139 +def collate_pil(x):
140 + out_x, out_y = [], []
141 + for xx, yy in x:
142 + out_x.append(xx)
143 + out_y.append(yy)
144 + return out_x, out_y
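For context, a minimal training-loop sketch showing how pass_epoch, accuracy and BatchTimer fit together. The import path, dataset location, classify argument and hyperparameters below are assumptions, not part of this commit.

import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Assumed module path for this file; adjust to wherever it lives in the repo.
from models.utils.training import pass_epoch, accuracy, BatchTimer
from models.inception_resnet_v1 import InceptionResnetV1

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

trans = transforms.Compose([transforms.Resize((160, 160)), transforms.ToTensor()])
dataset = datasets.ImageFolder('data/faces', transform=trans)   # hypothetical folder of face crops
loader = DataLoader(dataset, batch_size=32, shuffle=True)

# classify=True assumes the InceptionResnetV1 here follows facenet_pytorch's signature.
model = InceptionResnetV1(classify=True, num_classes=len(dataset.classes)).to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(8):
    model.train()
    pass_epoch(model, loss_fn, loader, optimizer, device=device,
               batch_metrics={'acc': accuracy, 'fps': BatchTimer()})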