Showing 14 changed files with 1439 additions and 0 deletions
client/clinet(window).py
0 → 100644
1 | +################################################## | ||
2 | +#1. Detect faces from the webcam.                                        # | ||
3 | +#2. Send images that are at least 95% likely to be a face to the image server. # | ||
4 | +#3. Send the preprocessed data to the verification server.               # | ||
5 | +################################################## | ||
6 | +import torch | ||
7 | +import numpy as np | ||
8 | +import cv2 | ||
9 | +import asyncio | ||
10 | +import websockets | ||
11 | +import json | ||
12 | +import os | ||
13 | +import timeit | ||
14 | +import base64 | ||
15 | + | ||
16 | +from PIL import Image | ||
17 | +from io import BytesIO | ||
18 | +import requests | ||
19 | + | ||
20 | +from models.mtcnn import MTCNN | ||
21 | + | ||
22 | +device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') | ||
23 | +print('Running on device: {}'.format(device)) | ||
24 | + | ||
25 | +mtcnn = MTCNN(keep_all=True, device=device) | ||
26 | + | ||
27 | +uri = 'ws://localhost:8765' | ||
28 | + | ||
29 | +async def send_face(face_list): | ||
30 | + global uri | ||
31 | + async with websockets.connect(uri) as websocket: | ||
32 | + for face in face_list: | ||
33 | + #type: np.float32 | ||
34 | + print(face.shape) | ||
35 | + data = json.dumps({"action": "verify", "MTCNN": face.tolist()}) | ||
36 | + await websocket.send(data) | ||
37 | + print('send: verify', len(face_list), 'face(s)') | ||
38 | + code = await websocket.recv() | ||
39 | + print('code:', code) | ||
40 | + | ||
41 | +async def send_image(image_list): | ||
42 | + global uri | ||
43 | + async with websockets.connect(uri) as websocket: | ||
44 | + for image in image_list: | ||
45 | + data = json.dumps({"action": "save_image", "image": image.tolist(), "shape": image.shape}) | ||
46 | + await websocket.send(data) | ||
47 | + print('send', len(image_list), 'image(s)') | ||
48 | + code = await websocket.recv() | ||
49 | + print('code:', code) | ||
50 | + | ||
51 | +def detect_face(frame): | ||
52 | + # If required, create a face detection pipeline using MTCNN: | ||
53 | + global mtcnn | ||
54 | + results = mtcnn.detect(frame) | ||
55 | + image_list = [] | ||
56 | +    if results[1][0] is None: | ||
57 | + return [] | ||
58 | + for box, prob in zip(results[0], results[1]): | ||
59 | + if prob < 0.95: | ||
60 | + continue | ||
61 | + print('face detected. prob:', prob) | ||
62 | + x1, y1, x2, y2 = box | ||
63 | +            image = frame[max(int(y1) - 10, 0):int(y2) + 10, max(int(x1) - 10, 0):int(x2) + 10]  # clamp so the 10 px margin cannot go negative | ||
64 | + image_list.append(image) | ||
65 | + return image_list | ||
66 | + | ||
67 | +def make_face_list(frame): | ||
68 | + global mtcnn | ||
69 | + results, prob = mtcnn(frame, return_prob = True) | ||
70 | + face_list = [] | ||
71 | +    if prob[0] is None: | ||
72 | +        return [] | ||
73 | +    for result, p in zip(results, prob): | ||
74 | +        if p < 0.95: | ||
75 | +            continue | ||
76 | + #np.float32 | ||
77 | + face_list.append(result.numpy()) | ||
78 | + return face_list | ||
79 | + | ||
80 | +cap = cv2.VideoCapture(0) | ||
81 | +cap.set(cv2.CAP_PROP_FRAME_WIDTH, 720)   # request 720x480 capture frames | ||
82 | +cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 480) | ||
83 | +while True: | ||
84 | + try: | ||
85 | + #start = timeit.default_timer() | ||
86 | + ret, frame = cap.read() | ||
87 | + frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) | ||
88 | + face_list = make_face_list(frame) | ||
89 | + image_list = detect_face(frame) | ||
90 | +        ## send to the embedding server ## | ||
91 | + if face_list: | ||
92 | + asyncio.get_event_loop().run_until_complete(send_face(face_list)) | ||
93 | + ################### | ||
94 | +        ## send to the image server ## | ||
95 | + if image_list: | ||
96 | + asyncio.get_event_loop().run_until_complete(send_image(image_list)) | ||
97 | + ################### | ||
98 | + #end = timeit.default_timer() | ||
99 | + #print('delta time: ', end - start) | ||
100 | + except Exception as ex: | ||
101 | + print(ex) |
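
The script above speaks a small JSON protocol over ws://localhost:8765: "verify" messages carry the MTCNN-cropped face tensor (3 x 160 x 160, float32) and "save_image" messages carry the raw crop plus its shape, each answered with a status code. The matching servers are not part of this changeset; as a rough sketch only, assuming the same websockets package (with its legacy (websocket, path) handler signature) and an arbitrary "200" acknowledgement, a receiver could look like:

import asyncio
import json

import numpy as np
import websockets

async def handler(websocket, path):
    # Echo-style receiver for the messages produced by clinet(window).py.
    async for message in websocket:
        data = json.loads(message)
        if data["action"] == "verify":
            # Cropped face tensor produced by MTCNN (3 x 160 x 160, float32).
            face = np.array(data["MTCNN"], dtype=np.float32)
            print("verify request, face shape:", face.shape)
        elif data["action"] == "save_image":
            # Raw face crop cut from the frame; its shape is sent alongside it.
            image = np.array(data["image"], dtype=np.uint8)
            print("save_image request, image shape:", image.shape, data["shape"])
        # Status code that the client reads back with websocket.recv().
        await websocket.send("200")

start_server = websockets.serve(handler, "localhost", 8765)
asyncio.get_event_loop().run_until_complete(start_server)
asyncio.get_event_loop().run_forever()
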
client/legacy/detection.py
0 → 100644
1 | +import torch | ||
2 | +import numpy as np | ||
3 | +import cv2 | ||
4 | +import matplotlib.pyplot as plt | ||
5 | +import os | ||
6 | + | ||
7 | +from PIL import Image, ImageDraw | ||
8 | +from IPython import display | ||
9 | + | ||
10 | +from models import mtcnn | ||
11 | +from models import inception_resnet_v1 | ||
12 | + | ||
13 | +device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') | ||
14 | +print('Running on device: {}'.format(device)) | ||
15 | + | ||
16 | +def extract_face(filename, required_size=(224, 224)): | ||
17 | + # If required, create a face detection pipeline using MTCNN: | ||
18 | + mtcnn_model = mtcnn.MTCNN(keep_all=True, device=device) | ||
19 | + pixels = plt.imread(os.path.join(os.path.abspath(''), filename)) | ||
20 | + results = mtcnn_model.detect(pixels) | ||
21 | + face_array = [] | ||
22 | + for box, prob in zip(results[0], results[1]): | ||
23 | + #boxes, _ = result | ||
24 | + print('face detected. prob:', prob) | ||
25 | + x1, y1, x2, y2 = box | ||
26 | + face = pixels[int(y1):int(y2), int(x1):int(x2)] | ||
27 | + image = Image.fromarray(face) | ||
28 | + image = image.resize(required_size) | ||
29 | + face_array.append(np.asarray(image)) | ||
30 | + return face_array | ||
31 | + | ||
32 | +face_array = extract_face('image/test1.jpg') | ||
33 | +for face in face_array: | ||
34 | + plt.figure() | ||
35 | + plt.imshow(face) | ||
36 | + plt.show() | ||
37 | + | ||
38 | +face_array = extract_face('image/test2.jpg') | ||
39 | +for face in face_array: | ||
40 | + plt.figure() | ||
41 | + plt.imshow(face) | ||
42 | + plt.show() | ||
43 | + |
client/legacy/image/test1.jpg
0 → 100644
337 KB
client/legacy/image/test2.jpg
0 → 100644
167 KB
client/models/data/onet.pt
0 → 100644
No preview for this file type
client/models/data/pnet.pt
0 → 100644
No preview for this file type
client/models/data/rnet.pt
0 → 100644
No preview for this file type
client/models/mtcnn.py
0 → 100644
1 | +import torch | ||
2 | +from torch import nn | ||
3 | +import numpy as np | ||
4 | +import os | ||
5 | + | ||
6 | +from .utils.detect_face import detect_face, extract_face | ||
7 | + | ||
8 | + | ||
9 | +class PNet(nn.Module): | ||
10 | + """MTCNN PNet. | ||
11 | + | ||
12 | + Keyword Arguments: | ||
13 | + pretrained {bool} -- Whether or not to load saved pretrained weights (default: {True}) | ||
14 | + """ | ||
15 | + | ||
16 | + def __init__(self, pretrained=True): | ||
17 | + super().__init__() | ||
18 | + | ||
19 | + self.conv1 = nn.Conv2d(3, 10, kernel_size=3) | ||
20 | + self.prelu1 = nn.PReLU(10) | ||
21 | + self.pool1 = nn.MaxPool2d(2, 2, ceil_mode=True) | ||
22 | + self.conv2 = nn.Conv2d(10, 16, kernel_size=3) | ||
23 | + self.prelu2 = nn.PReLU(16) | ||
24 | + self.conv3 = nn.Conv2d(16, 32, kernel_size=3) | ||
25 | + self.prelu3 = nn.PReLU(32) | ||
26 | + self.conv4_1 = nn.Conv2d(32, 2, kernel_size=1) | ||
27 | + self.softmax4_1 = nn.Softmax(dim=1) | ||
28 | + self.conv4_2 = nn.Conv2d(32, 4, kernel_size=1) | ||
29 | + | ||
30 | + self.training = False | ||
31 | + | ||
32 | + if pretrained: | ||
33 | + state_dict_path = os.path.join(os.path.dirname(__file__), 'data/pnet.pt') | ||
34 | + state_dict = torch.load(state_dict_path) | ||
35 | + self.load_state_dict(state_dict) | ||
36 | + | ||
37 | + def forward(self, x): | ||
38 | + x = self.conv1(x) | ||
39 | + x = self.prelu1(x) | ||
40 | + x = self.pool1(x) | ||
41 | + x = self.conv2(x) | ||
42 | + x = self.prelu2(x) | ||
43 | + x = self.conv3(x) | ||
44 | + x = self.prelu3(x) | ||
45 | + a = self.conv4_1(x) | ||
46 | + a = self.softmax4_1(a) | ||
47 | + b = self.conv4_2(x) | ||
48 | + return b, a | ||
49 | + | ||
50 | + | ||
51 | +class RNet(nn.Module): | ||
52 | + """MTCNN RNet. | ||
53 | + | ||
54 | + Keyword Arguments: | ||
55 | + pretrained {bool} -- Whether or not to load saved pretrained weights (default: {True}) | ||
56 | + """ | ||
57 | + | ||
58 | + def __init__(self, pretrained=True): | ||
59 | + super().__init__() | ||
60 | + | ||
61 | + self.conv1 = nn.Conv2d(3, 28, kernel_size=3) | ||
62 | + self.prelu1 = nn.PReLU(28) | ||
63 | + self.pool1 = nn.MaxPool2d(3, 2, ceil_mode=True) | ||
64 | + self.conv2 = nn.Conv2d(28, 48, kernel_size=3) | ||
65 | + self.prelu2 = nn.PReLU(48) | ||
66 | + self.pool2 = nn.MaxPool2d(3, 2, ceil_mode=True) | ||
67 | + self.conv3 = nn.Conv2d(48, 64, kernel_size=2) | ||
68 | + self.prelu3 = nn.PReLU(64) | ||
69 | + self.dense4 = nn.Linear(576, 128) | ||
70 | + self.prelu4 = nn.PReLU(128) | ||
71 | + self.dense5_1 = nn.Linear(128, 2) | ||
72 | + self.softmax5_1 = nn.Softmax(dim=1) | ||
73 | + self.dense5_2 = nn.Linear(128, 4) | ||
74 | + | ||
75 | + self.training = False | ||
76 | + | ||
77 | + if pretrained: | ||
78 | + state_dict_path = os.path.join(os.path.dirname(__file__), 'data/rnet.pt') | ||
79 | + state_dict = torch.load(state_dict_path) | ||
80 | + self.load_state_dict(state_dict) | ||
81 | + | ||
82 | + def forward(self, x): | ||
83 | + x = self.conv1(x) | ||
84 | + x = self.prelu1(x) | ||
85 | + x = self.pool1(x) | ||
86 | + x = self.conv2(x) | ||
87 | + x = self.prelu2(x) | ||
88 | + x = self.pool2(x) | ||
89 | + x = self.conv3(x) | ||
90 | + x = self.prelu3(x) | ||
91 | + x = x.permute(0, 3, 2, 1).contiguous() | ||
92 | + x = self.dense4(x.view(x.shape[0], -1)) | ||
93 | + x = self.prelu4(x) | ||
94 | + a = self.dense5_1(x) | ||
95 | + a = self.softmax5_1(a) | ||
96 | + b = self.dense5_2(x) | ||
97 | + return b, a | ||
98 | + | ||
99 | + | ||
100 | +class ONet(nn.Module): | ||
101 | + """MTCNN ONet. | ||
102 | + | ||
103 | + Keyword Arguments: | ||
104 | + pretrained {bool} -- Whether or not to load saved pretrained weights (default: {True}) | ||
105 | + """ | ||
106 | + | ||
107 | + def __init__(self, pretrained=True): | ||
108 | + super().__init__() | ||
109 | + | ||
110 | + self.conv1 = nn.Conv2d(3, 32, kernel_size=3) | ||
111 | + self.prelu1 = nn.PReLU(32) | ||
112 | + self.pool1 = nn.MaxPool2d(3, 2, ceil_mode=True) | ||
113 | + self.conv2 = nn.Conv2d(32, 64, kernel_size=3) | ||
114 | + self.prelu2 = nn.PReLU(64) | ||
115 | + self.pool2 = nn.MaxPool2d(3, 2, ceil_mode=True) | ||
116 | + self.conv3 = nn.Conv2d(64, 64, kernel_size=3) | ||
117 | + self.prelu3 = nn.PReLU(64) | ||
118 | + self.pool3 = nn.MaxPool2d(2, 2, ceil_mode=True) | ||
119 | + self.conv4 = nn.Conv2d(64, 128, kernel_size=2) | ||
120 | + self.prelu4 = nn.PReLU(128) | ||
121 | + self.dense5 = nn.Linear(1152, 256) | ||
122 | + self.prelu5 = nn.PReLU(256) | ||
123 | + self.dense6_1 = nn.Linear(256, 2) | ||
124 | + self.softmax6_1 = nn.Softmax(dim=1) | ||
125 | + self.dense6_2 = nn.Linear(256, 4) | ||
126 | + self.dense6_3 = nn.Linear(256, 10) | ||
127 | + | ||
128 | + self.training = False | ||
129 | + | ||
130 | + if pretrained: | ||
131 | + state_dict_path = os.path.join(os.path.dirname(__file__), 'data/onet.pt') | ||
132 | + state_dict = torch.load(state_dict_path) | ||
133 | + self.load_state_dict(state_dict) | ||
134 | + | ||
135 | + def forward(self, x): | ||
136 | + x = self.conv1(x) | ||
137 | + x = self.prelu1(x) | ||
138 | + x = self.pool1(x) | ||
139 | + x = self.conv2(x) | ||
140 | + x = self.prelu2(x) | ||
141 | + x = self.pool2(x) | ||
142 | + x = self.conv3(x) | ||
143 | + x = self.prelu3(x) | ||
144 | + x = self.pool3(x) | ||
145 | + x = self.conv4(x) | ||
146 | + x = self.prelu4(x) | ||
147 | + x = x.permute(0, 3, 2, 1).contiguous() | ||
148 | + x = self.dense5(x.view(x.shape[0], -1)) | ||
149 | + x = self.prelu5(x) | ||
150 | + a = self.dense6_1(x) | ||
151 | + a = self.softmax6_1(a) | ||
152 | + b = self.dense6_2(x) | ||
153 | + c = self.dense6_3(x) | ||
154 | + return b, c, a | ||
155 | + | ||
156 | + | ||
157 | +class MTCNN(nn.Module): | ||
158 | + """MTCNN face detection module. | ||
159 | + | ||
160 | + This class loads pretrained P-, R-, and O-nets and returns images cropped to include the face | ||
161 | + only, given raw input images of one of the following types: | ||
162 | + - PIL image or list of PIL images | ||
163 | + - numpy.ndarray (uint8) representing either a single image (3D) or a batch of images (4D). | ||
164 | + Cropped faces can optionally be saved to file | ||
165 | + also. | ||
166 | + | ||
167 | + Keyword Arguments: | ||
168 | + image_size {int} -- Output image size in pixels. The image will be square. (default: {160}) | ||
169 | + margin {int} -- Margin to add to bounding box, in terms of pixels in the final image. | ||
170 | + Note that the application of the margin differs slightly from the davidsandberg/facenet | ||
171 | + repo, which applies the margin to the original image before resizing, making the margin | ||
172 | + dependent on the original image size (this is a bug in davidsandberg/facenet). | ||
173 | + (default: {0}) | ||
174 | + min_face_size {int} -- Minimum face size to search for. (default: {20}) | ||
175 | + thresholds {list} -- MTCNN face detection thresholds (default: {[0.6, 0.7, 0.7]}) | ||
176 | + factor {float} -- Factor used to create a scaling pyramid of face sizes. (default: {0.709}) | ||
177 | + post_process {bool} -- Whether or not to post process images tensors before returning. | ||
178 | + (default: {True}) | ||
179 | + select_largest {bool} -- If True, if multiple faces are detected, the largest is returned. | ||
180 | + If False, the face with the highest detection probability is returned. | ||
181 | + (default: {True}) | ||
182 | + keep_all {bool} -- If True, all detected faces are returned, in the order dictated by the | ||
183 | + select_largest parameter. If a save_path is specified, the first face is saved to that | ||
184 | + path and the remaining faces are saved to <save_path>1, <save_path>2 etc. | ||
185 | + device {torch.device} -- The device on which to run neural net passes. Image tensors and | ||
186 | + models are copied to this device before running forward passes. (default: {None}) | ||
187 | + """ | ||
188 | + | ||
189 | + def __init__( | ||
190 | + self, image_size=160, margin=0, min_face_size=20, | ||
191 | + thresholds=[0.6, 0.7, 0.7], factor=0.709, post_process=True, | ||
192 | + select_largest=True, keep_all=False, device=None | ||
193 | + ): | ||
194 | + super().__init__() | ||
195 | + | ||
196 | + self.image_size = image_size | ||
197 | + self.margin = margin | ||
198 | + self.min_face_size = min_face_size | ||
199 | + self.thresholds = thresholds | ||
200 | + self.factor = factor | ||
201 | + self.post_process = post_process | ||
202 | + self.select_largest = select_largest | ||
203 | + self.keep_all = keep_all | ||
204 | + | ||
205 | + self.pnet = PNet() | ||
206 | + self.rnet = RNet() | ||
207 | + self.onet = ONet() | ||
208 | + | ||
209 | + self.device = torch.device('cpu') | ||
210 | + if device is not None: | ||
211 | + self.device = device | ||
212 | + self.to(device) | ||
213 | + | ||
214 | + def forward(self, img, save_path=None, return_prob=False): | ||
215 | + """Run MTCNN face detection on a PIL image or numpy array. This method performs both | ||
216 | + detection and extraction of faces, returning tensors representing detected faces rather | ||
217 | + than the bounding boxes. To access bounding boxes, see the MTCNN.detect() method below. | ||
218 | + | ||
219 | + Arguments: | ||
220 | + img {PIL.Image, np.ndarray, or list} -- A PIL image, np.ndarray, or list. | ||
221 | + | ||
222 | + Keyword Arguments: | ||
223 | + save_path {str} -- An optional save path for the cropped image. Note that when | ||
224 | + self.post_process=True, although the returned tensor is post processed, the saved | ||
225 | + face image is not, so it is a true representation of the face in the input image. | ||
226 | + If `img` is a list of images, `save_path` should be a list of equal length. | ||
227 | + (default: {None}) | ||
228 | + return_prob {bool} -- Whether or not to return the detection probability. | ||
229 | + (default: {False}) | ||
230 | + | ||
231 | + Returns: | ||
232 | + Union[torch.Tensor, tuple(torch.tensor, float)] -- If detected, cropped image of a face | ||
233 | + with dimensions 3 x image_size x image_size. Optionally, the probability that a | ||
234 | + face was detected. If self.keep_all is True, n detected faces are returned in an | ||
235 | + n x 3 x image_size x image_size tensor with an optional list of detection | ||
236 | + probabilities. If `img` is a list of images, the item(s) returned have an extra | ||
237 | + dimension (batch) as the first dimension. | ||
238 | + | ||
239 | + Example: | ||
240 | + >>> from facenet_pytorch import MTCNN | ||
241 | + >>> mtcnn = MTCNN() | ||
242 | + >>> face_tensor, prob = mtcnn(img, save_path='face.png', return_prob=True) | ||
243 | + """ | ||
244 | + | ||
245 | + # Detect faces | ||
246 | + with torch.no_grad(): | ||
247 | + batch_boxes, batch_probs = self.detect(img) | ||
248 | + | ||
249 | + # Determine if a batch or single image was passed | ||
250 | + batch_mode = True | ||
251 | + if not isinstance(img, (list, tuple)) and not (isinstance(img, np.ndarray) and len(img.shape) == 4): | ||
252 | + img = [img] | ||
253 | + batch_boxes = [batch_boxes] | ||
254 | + batch_probs = [batch_probs] | ||
255 | + batch_mode = False | ||
256 | + | ||
257 | + # Parse save path(s) | ||
258 | + if save_path is not None: | ||
259 | + if isinstance(save_path, str): | ||
260 | + save_path = [save_path] | ||
261 | + else: | ||
262 | + save_path = [None for _ in range(len(img))] | ||
263 | + | ||
264 | + # Process all bounding boxes and probabilities | ||
265 | + faces, probs = [], [] | ||
266 | + for im, box_im, prob_im, path_im in zip(img, batch_boxes, batch_probs, save_path): | ||
267 | + if box_im is None: | ||
268 | + faces.append(None) | ||
269 | + probs.append([None] if self.keep_all else None) | ||
270 | + continue | ||
271 | + | ||
272 | + if not self.keep_all: | ||
273 | + box_im = box_im[[0]] | ||
274 | + | ||
275 | + faces_im = [] | ||
276 | + for i, box in enumerate(box_im): | ||
277 | + face_path = path_im | ||
278 | + if path_im is not None and i > 0: | ||
279 | + save_name, ext = os.path.splitext(path_im) | ||
280 | + face_path = save_name + '_' + str(i + 1) + ext | ||
281 | + | ||
282 | + face = extract_face(im, box, self.image_size, self.margin, face_path) | ||
283 | + if self.post_process: | ||
284 | + face = fixed_image_standardization(face) | ||
285 | + faces_im.append(face) | ||
286 | + | ||
287 | + if self.keep_all: | ||
288 | + faces_im = torch.stack(faces_im) | ||
289 | + else: | ||
290 | + faces_im = faces_im[0] | ||
291 | + prob_im = prob_im[0] | ||
292 | + | ||
293 | + faces.append(faces_im) | ||
294 | + probs.append(prob_im) | ||
295 | + | ||
296 | + if not batch_mode: | ||
297 | + faces = faces[0] | ||
298 | + probs = probs[0] | ||
299 | + | ||
300 | + if return_prob: | ||
301 | + return faces, probs | ||
302 | + else: | ||
303 | + return faces | ||
304 | + | ||
305 | + def detect(self, img, landmarks=False): | ||
306 | + """Detect all faces in PIL image and return bounding boxes and optional facial landmarks. | ||
307 | + | ||
308 | + This method is used by the forward method and is also useful for face detection tasks | ||
309 | + that require lower-level handling of bounding boxes and facial landmarks (e.g., face | ||
310 | + tracking). The functionality of the forward function can be emulated by using this method | ||
311 | + followed by the extract_face() function. | ||
312 | + | ||
313 | + Arguments: | ||
314 | + img {PIL.Image, np.ndarray, or list} -- A PIL image or a list of PIL images. | ||
315 | + | ||
316 | + Keyword Arguments: | ||
317 | + landmarks {bool} -- Whether to return facial landmarks in addition to bounding boxes. | ||
318 | + (default: {False}) | ||
319 | + | ||
320 | + Returns: | ||
321 | + tuple(numpy.ndarray, list) -- For N detected faces, a tuple containing an | ||
322 | + Nx4 array of bounding boxes and a length N list of detection probabilities. | ||
323 | + Returned boxes will be sorted in descending order by detection probability if | ||
324 | + self.select_largest=False, otherwise the largest face will be returned first. | ||
325 | + If `img` is a list of images, the items returned have an extra dimension | ||
326 | + (batch) as the first dimension. Optionally, a third item, the facial landmarks, | ||
327 | + are returned if `landmarks=True`. | ||
328 | + | ||
329 | + Example: | ||
330 | + >>> from PIL import Image, ImageDraw | ||
331 | + >>> from facenet_pytorch import MTCNN, extract_face | ||
332 | + >>> mtcnn = MTCNN(keep_all=True) | ||
333 | + >>> boxes, probs, points = mtcnn.detect(img, landmarks=True) | ||
334 | + >>> # Draw boxes and save faces | ||
335 | + >>> img_draw = img.copy() | ||
336 | + >>> draw = ImageDraw.Draw(img_draw) | ||
337 | + >>> for i, (box, point) in enumerate(zip(boxes, points)): | ||
338 | + ... draw.rectangle(box.tolist(), width=5) | ||
339 | + ... for p in point: | ||
340 | + ... draw.rectangle((p - 10).tolist() + (p + 10).tolist(), width=10) | ||
341 | + ... extract_face(img, box, save_path='detected_face_{}.png'.format(i)) | ||
342 | + >>> img_draw.save('annotated_faces.png') | ||
343 | + """ | ||
344 | + | ||
345 | + with torch.no_grad(): | ||
346 | + batch_boxes, batch_points = detect_face( | ||
347 | + img, self.min_face_size, | ||
348 | + self.pnet, self.rnet, self.onet, | ||
349 | + self.thresholds, self.factor, | ||
350 | + self.device | ||
351 | + ) | ||
352 | + | ||
353 | + boxes, probs, points = [], [], [] | ||
354 | + for box, point in zip(batch_boxes, batch_points): | ||
355 | + box = np.array(box) | ||
356 | + point = np.array(point) | ||
357 | + if len(box) == 0: | ||
358 | + boxes.append(None) | ||
359 | + probs.append([None]) | ||
360 | + points.append(None) | ||
361 | + elif self.select_largest: | ||
362 | + box_order = np.argsort((box[:, 2] - box[:, 0]) * (box[:, 3] - box[:, 1]))[::-1] | ||
363 | + box = box[box_order] | ||
364 | + point = point[box_order] | ||
365 | + boxes.append(box[:, :4]) | ||
366 | + probs.append(box[:, 4]) | ||
367 | + points.append(point) | ||
368 | + else: | ||
369 | + boxes.append(box[:, :4]) | ||
370 | + probs.append(box[:, 4]) | ||
371 | + points.append(point) | ||
372 | + boxes = np.array(boxes) | ||
373 | + probs = np.array(probs) | ||
374 | + points = np.array(points) | ||
375 | + | ||
376 | + if not isinstance(img, (list, tuple)) and not (isinstance(img, np.ndarray) and len(img.shape) == 4): | ||
377 | + boxes = boxes[0] | ||
378 | + probs = probs[0] | ||
379 | + points = points[0] | ||
380 | + | ||
381 | + if landmarks: | ||
382 | + return boxes, probs, points | ||
383 | + | ||
384 | + return boxes, probs | ||
385 | + | ||
386 | + | ||
387 | +def fixed_image_standardization(image_tensor): | ||
388 | + processed_tensor = (image_tensor - 127.5) / 128.0 | ||
389 | + return processed_tensor | ||
390 | + | ||
391 | +def prewhiten(x): | ||
392 | + mean = x.mean() | ||
393 | + std = x.std() | ||
394 | + std_adj = std.clamp(min=1.0/(float(x.numel())**0.5)) | ||
395 | + y = (x - mean) / std_adj | ||
396 | + return y |
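
Both normalisation helpers defined at the end of this file expect a face tensor in the 0-255 range: fixed_image_standardization applies the constant (x - 127.5) / 128 mapping used by the post_process option, while prewhiten standardises each image by its own mean and clamped standard deviation. A small comparison sketch, assuming the module is importable as models.mtcnn from the client directory:

import torch

from models.mtcnn import fixed_image_standardization, prewhiten

# Stand-in for a cropped face in the 0-255 range (3 x 160 x 160).
face = torch.randint(0, 256, (3, 160, 160)).float()

fixed = fixed_image_standardization(face)   # constant shift/scale, image-independent
white = prewhiten(face)                     # per-image zero mean, unit (clamped) std

print(fixed.mean().item(), fixed.std().item())
print(white.mean().item(), white.std().item())
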
client/models/utils/detect_face.py
0 → 100644
1 | +import torch | ||
2 | +from torch.nn.functional import interpolate | ||
3 | +from torchvision.transforms import functional as F | ||
4 | +from torchvision.ops.boxes import batched_nms | ||
5 | +import cv2 | ||
6 | +from PIL import Image | ||
7 | +import numpy as np | ||
8 | +import os | ||
9 | + | ||
10 | + | ||
11 | +def detect_face(imgs, minsize, pnet, rnet, onet, threshold, factor, device): | ||
12 | + if isinstance(imgs, (np.ndarray, torch.Tensor)): | ||
13 | + imgs = torch.as_tensor(imgs, device=device) | ||
14 | + if len(imgs.shape) == 3: | ||
15 | + imgs = imgs.unsqueeze(0) | ||
16 | + else: | ||
17 | + if not isinstance(imgs, (list, tuple)): | ||
18 | + imgs = [imgs] | ||
19 | + if any(img.size != imgs[0].size for img in imgs): | ||
20 | + raise Exception("MTCNN batch processing only compatible with equal-dimension images.") | ||
21 | + imgs = np.stack([np.uint8(img) for img in imgs]) | ||
22 | + | ||
23 | + imgs = torch.as_tensor(imgs, device=device) | ||
24 | + | ||
25 | + model_dtype = next(pnet.parameters()).dtype | ||
26 | + imgs = imgs.permute(0, 3, 1, 2).type(model_dtype) | ||
27 | + | ||
28 | + batch_size = len(imgs) | ||
29 | + h, w = imgs.shape[2:4] | ||
30 | + m = 12.0 / minsize | ||
31 | + minl = min(h, w) | ||
32 | + minl = minl * m | ||
33 | + | ||
34 | + # Create scale pyramid | ||
35 | + scale_i = m | ||
36 | + scales = [] | ||
37 | + while minl >= 12: | ||
38 | + scales.append(scale_i) | ||
39 | + scale_i = scale_i * factor | ||
40 | + minl = minl * factor | ||
41 | + | ||
42 | + # First stage | ||
43 | + boxes = [] | ||
44 | + image_inds = [] | ||
45 | + all_inds = [] | ||
46 | + all_i = 0 | ||
47 | + for scale in scales: | ||
48 | + im_data = imresample(imgs, (int(h * scale + 1), int(w * scale + 1))) | ||
49 | + im_data = (im_data - 127.5) * 0.0078125 | ||
50 | + reg, probs = pnet(im_data) | ||
51 | + | ||
52 | + boxes_scale, image_inds_scale = generateBoundingBox(reg, probs[:, 1], scale, threshold[0]) | ||
53 | + boxes.append(boxes_scale) | ||
54 | + image_inds.append(image_inds_scale) | ||
55 | + all_inds.append(all_i + image_inds_scale) | ||
56 | + all_i += batch_size | ||
57 | + | ||
58 | + boxes = torch.cat(boxes, dim=0) | ||
59 | + image_inds = torch.cat(image_inds, dim=0).cpu() | ||
60 | + all_inds = torch.cat(all_inds, dim=0) | ||
61 | + | ||
62 | + # NMS within each scale + image | ||
63 | + pick = batched_nms(boxes[:, :4], boxes[:, 4], all_inds, 0.5) | ||
64 | + boxes, image_inds = boxes[pick], image_inds[pick] | ||
65 | + | ||
66 | + # NMS within each image | ||
67 | + pick = batched_nms(boxes[:, :4], boxes[:, 4], image_inds, 0.7) | ||
68 | + boxes, image_inds = boxes[pick], image_inds[pick] | ||
69 | + | ||
70 | + regw = boxes[:, 2] - boxes[:, 0] | ||
71 | + regh = boxes[:, 3] - boxes[:, 1] | ||
72 | + qq1 = boxes[:, 0] + boxes[:, 5] * regw | ||
73 | + qq2 = boxes[:, 1] + boxes[:, 6] * regh | ||
74 | + qq3 = boxes[:, 2] + boxes[:, 7] * regw | ||
75 | + qq4 = boxes[:, 3] + boxes[:, 8] * regh | ||
76 | + boxes = torch.stack([qq1, qq2, qq3, qq4, boxes[:, 4]]).permute(1, 0) | ||
77 | + boxes = rerec(boxes) | ||
78 | + y, ey, x, ex = pad(boxes, w, h) | ||
79 | + | ||
80 | + # Second stage | ||
81 | + if len(boxes) > 0: | ||
82 | + im_data = [] | ||
83 | + for k in range(len(y)): | ||
84 | + if ey[k] > (y[k] - 1) and ex[k] > (x[k] - 1): | ||
85 | + img_k = imgs[image_inds[k], :, (y[k] - 1):ey[k], (x[k] - 1):ex[k]].unsqueeze(0) | ||
86 | + im_data.append(imresample(img_k, (24, 24))) | ||
87 | + im_data = torch.cat(im_data, dim=0) | ||
88 | + im_data = (im_data - 127.5) * 0.0078125 | ||
89 | + out = rnet(im_data) | ||
90 | + | ||
91 | + out0 = out[0].permute(1, 0) | ||
92 | + out1 = out[1].permute(1, 0) | ||
93 | + score = out1[1, :] | ||
94 | + ipass = score > threshold[1] | ||
95 | + boxes = torch.cat((boxes[ipass, :4], score[ipass].unsqueeze(1)), dim=1) | ||
96 | + image_inds = image_inds[ipass] | ||
97 | + mv = out0[:, ipass].permute(1, 0) | ||
98 | + | ||
99 | + # NMS within each image | ||
100 | + pick = batched_nms(boxes[:, :4], boxes[:, 4], image_inds, 0.7) | ||
101 | + boxes, image_inds, mv = boxes[pick], image_inds[pick], mv[pick] | ||
102 | + boxes = bbreg(boxes, mv) | ||
103 | + boxes = rerec(boxes) | ||
104 | + | ||
105 | + # Third stage | ||
106 | + points = torch.zeros(0, 5, 2, device=device) | ||
107 | + if len(boxes) > 0: | ||
108 | + y, ey, x, ex = pad(boxes, w, h) | ||
109 | + im_data = [] | ||
110 | + for k in range(len(y)): | ||
111 | + if ey[k] > (y[k] - 1) and ex[k] > (x[k] - 1): | ||
112 | + img_k = imgs[image_inds[k], :, (y[k] - 1):ey[k], (x[k] - 1):ex[k]].unsqueeze(0) | ||
113 | + im_data.append(imresample(img_k, (48, 48))) | ||
114 | + im_data = torch.cat(im_data, dim=0) | ||
115 | + im_data = (im_data - 127.5) * 0.0078125 | ||
116 | + out = onet(im_data) | ||
117 | + | ||
118 | + out0 = out[0].permute(1, 0) | ||
119 | + out1 = out[1].permute(1, 0) | ||
120 | + out2 = out[2].permute(1, 0) | ||
121 | + score = out2[1, :] | ||
122 | + points = out1 | ||
123 | + ipass = score > threshold[2] | ||
124 | + points = points[:, ipass] | ||
125 | + boxes = torch.cat((boxes[ipass, :4], score[ipass].unsqueeze(1)), dim=1) | ||
126 | + image_inds = image_inds[ipass] | ||
127 | + mv = out0[:, ipass].permute(1, 0) | ||
128 | + | ||
129 | + w_i = boxes[:, 2] - boxes[:, 0] + 1 | ||
130 | + h_i = boxes[:, 3] - boxes[:, 1] + 1 | ||
131 | + points_x = w_i.repeat(5, 1) * points[:5, :] + boxes[:, 0].repeat(5, 1) - 1 | ||
132 | + points_y = h_i.repeat(5, 1) * points[5:10, :] + boxes[:, 1].repeat(5, 1) - 1 | ||
133 | + points = torch.stack((points_x, points_y)).permute(2, 1, 0) | ||
134 | + boxes = bbreg(boxes, mv) | ||
135 | + | ||
136 | + # NMS within each image using "Min" strategy | ||
137 | + # pick = batched_nms(boxes[:, :4], boxes[:, 4], image_inds, 0.7) | ||
138 | + pick = batched_nms_numpy(boxes[:, :4], boxes[:, 4], image_inds, 0.7, 'Min') | ||
139 | + boxes, image_inds, points = boxes[pick], image_inds[pick], points[pick] | ||
140 | + | ||
141 | + boxes = boxes.cpu().numpy() | ||
142 | + points = points.cpu().numpy() | ||
143 | + | ||
144 | + batch_boxes = [] | ||
145 | + batch_points = [] | ||
146 | + for b_i in range(batch_size): | ||
147 | + b_i_inds = np.where(image_inds == b_i) | ||
148 | + batch_boxes.append(boxes[b_i_inds].copy()) | ||
149 | + batch_points.append(points[b_i_inds].copy()) | ||
150 | + | ||
151 | + batch_boxes, batch_points = np.array(batch_boxes), np.array(batch_points) | ||
152 | + | ||
153 | + return batch_boxes, batch_points | ||
154 | + | ||
155 | + | ||
156 | +def bbreg(boundingbox, reg): | ||
157 | + if reg.shape[1] == 1: | ||
158 | + reg = torch.reshape(reg, (reg.shape[2], reg.shape[3])) | ||
159 | + | ||
160 | + w = boundingbox[:, 2] - boundingbox[:, 0] + 1 | ||
161 | + h = boundingbox[:, 3] - boundingbox[:, 1] + 1 | ||
162 | + b1 = boundingbox[:, 0] + reg[:, 0] * w | ||
163 | + b2 = boundingbox[:, 1] + reg[:, 1] * h | ||
164 | + b3 = boundingbox[:, 2] + reg[:, 2] * w | ||
165 | + b4 = boundingbox[:, 3] + reg[:, 3] * h | ||
166 | + boundingbox[:, :4] = torch.stack([b1, b2, b3, b4]).permute(1, 0) | ||
167 | + | ||
168 | + return boundingbox | ||
169 | + | ||
170 | + | ||
171 | +def generateBoundingBox(reg, probs, scale, thresh): | ||
172 | + stride = 2 | ||
173 | + cellsize = 12 | ||
174 | + | ||
175 | + reg = reg.permute(1, 0, 2, 3) | ||
176 | + | ||
177 | + mask = probs >= thresh | ||
178 | + mask_inds = mask.nonzero() | ||
179 | + image_inds = mask_inds[:, 0] | ||
180 | + score = probs[mask] | ||
181 | + reg = reg[:, mask].permute(1, 0) | ||
182 | + bb = mask_inds[:, 1:].type(reg.dtype).flip(1) | ||
183 | + q1 = ((stride * bb + 1) / scale).floor() | ||
184 | + q2 = ((stride * bb + cellsize - 1 + 1) / scale).floor() | ||
185 | + boundingbox = torch.cat([q1, q2, score.unsqueeze(1), reg], dim=1) | ||
186 | + return boundingbox, image_inds | ||
187 | + | ||
188 | + | ||
189 | +def nms_numpy(boxes, scores, threshold, method): | ||
190 | + if boxes.size == 0: | ||
191 | + return np.empty((0, 3)) | ||
192 | + | ||
193 | + x1 = boxes[:, 0].copy() | ||
194 | + y1 = boxes[:, 1].copy() | ||
195 | + x2 = boxes[:, 2].copy() | ||
196 | + y2 = boxes[:, 3].copy() | ||
197 | + s = scores | ||
198 | + area = (x2 - x1 + 1) * (y2 - y1 + 1) | ||
199 | + | ||
200 | + I = np.argsort(s) | ||
201 | + pick = np.zeros_like(s, dtype=np.int16) | ||
202 | + counter = 0 | ||
203 | + while I.size > 0: | ||
204 | + i = I[-1] | ||
205 | + pick[counter] = i | ||
206 | + counter += 1 | ||
207 | + idx = I[0:-1] | ||
208 | + | ||
209 | + xx1 = np.maximum(x1[i], x1[idx]).copy() | ||
210 | + yy1 = np.maximum(y1[i], y1[idx]).copy() | ||
211 | + xx2 = np.minimum(x2[i], x2[idx]).copy() | ||
212 | + yy2 = np.minimum(y2[i], y2[idx]).copy() | ||
213 | + | ||
214 | + w = np.maximum(0.0, xx2 - xx1 + 1).copy() | ||
215 | + h = np.maximum(0.0, yy2 - yy1 + 1).copy() | ||
216 | + | ||
217 | + inter = w * h | ||
218 | +        if method == "Min": | ||
219 | + o = inter / np.minimum(area[i], area[idx]) | ||
220 | + else: | ||
221 | + o = inter / (area[i] + area[idx] - inter) | ||
222 | + I = I[np.where(o <= threshold)] | ||
223 | + | ||
224 | + pick = pick[:counter].copy() | ||
225 | + return pick | ||
226 | + | ||
227 | + | ||
228 | +def batched_nms_numpy(boxes, scores, idxs, threshold, method): | ||
229 | + device = boxes.device | ||
230 | + if boxes.numel() == 0: | ||
231 | + return torch.empty((0,), dtype=torch.int64, device=device) | ||
232 | + # strategy: in order to perform NMS independently per class. | ||
233 | + # we add an offset to all the boxes. The offset is dependent | ||
234 | + # only on the class idx, and is large enough so that boxes | ||
235 | + # from different classes do not overlap | ||
236 | + max_coordinate = boxes.max() | ||
237 | + offsets = idxs.to(boxes) * (max_coordinate + 1) | ||
238 | + boxes_for_nms = boxes + offsets[:, None] | ||
239 | + boxes_for_nms = boxes_for_nms.cpu().numpy() | ||
240 | + scores = scores.cpu().numpy() | ||
241 | + keep = nms_numpy(boxes_for_nms, scores, threshold, method) | ||
242 | + return torch.as_tensor(keep, dtype=torch.long, device=device) | ||
243 | + | ||
244 | + | ||
245 | +def pad(boxes, w, h): | ||
246 | + boxes = boxes.trunc().int().cpu().numpy() | ||
247 | + x = boxes[:, 0] | ||
248 | + y = boxes[:, 1] | ||
249 | + ex = boxes[:, 2] | ||
250 | + ey = boxes[:, 3] | ||
251 | + | ||
252 | + x[x < 1] = 1 | ||
253 | + y[y < 1] = 1 | ||
254 | + ex[ex > w] = w | ||
255 | + ey[ey > h] = h | ||
256 | + | ||
257 | + return y, ey, x, ex | ||
258 | + | ||
259 | + | ||
260 | +def rerec(bboxA): | ||
261 | + h = bboxA[:, 3] - bboxA[:, 1] | ||
262 | + w = bboxA[:, 2] - bboxA[:, 0] | ||
263 | + | ||
264 | + l = torch.max(w, h) | ||
265 | + bboxA[:, 0] = bboxA[:, 0] + w * 0.5 - l * 0.5 | ||
266 | + bboxA[:, 1] = bboxA[:, 1] + h * 0.5 - l * 0.5 | ||
267 | + bboxA[:, 2:4] = bboxA[:, :2] + l.repeat(2, 1).permute(1, 0) | ||
268 | + | ||
269 | + return bboxA | ||
270 | + | ||
271 | + | ||
272 | +def imresample(img, sz): | ||
273 | + im_data = interpolate(img, size=sz, mode="area") | ||
274 | + return im_data | ||
275 | + | ||
276 | + | ||
277 | +def crop_resize(img, box, image_size): | ||
278 | + if isinstance(img, np.ndarray): | ||
279 | + out = cv2.resize( | ||
280 | + img[box[1]:box[3], box[0]:box[2]], | ||
281 | + (image_size, image_size), | ||
282 | + interpolation=cv2.INTER_AREA | ||
283 | + ).copy() | ||
284 | + else: | ||
285 | + out = img.crop(box).copy().resize((image_size, image_size), Image.BILINEAR) | ||
286 | + return out | ||
287 | + | ||
288 | + | ||
289 | +def save_img(img, path): | ||
290 | + if isinstance(img, np.ndarray): | ||
291 | + cv2.imwrite(path, cv2.cvtColor(img, cv2.COLOR_RGB2BGR)) | ||
292 | + else: | ||
293 | + img.save(path) | ||
294 | + | ||
295 | + | ||
296 | +def get_size(img): | ||
297 | + if isinstance(img, np.ndarray): | ||
298 | + return img.shape[1::-1] | ||
299 | + else: | ||
300 | + return img.size | ||
301 | + | ||
302 | + | ||
303 | +def extract_face(img, box, image_size=160, margin=0, save_path=None): | ||
304 | + """Extract face + margin from PIL Image given bounding box. | ||
305 | + | ||
306 | + Arguments: | ||
307 | + img {PIL.Image} -- A PIL Image. | ||
308 | + box {numpy.ndarray} -- Four-element bounding box. | ||
309 | + image_size {int} -- Output image size in pixels. The image will be square. | ||
310 | + margin {int} -- Margin to add to bounding box, in terms of pixels in the final image. | ||
311 | + Note that the application of the margin differs slightly from the davidsandberg/facenet | ||
312 | + repo, which applies the margin to the original image before resizing, making the margin | ||
313 | + dependent on the original image size. | ||
314 | + save_path {str} -- Save path for extracted face image. (default: {None}) | ||
315 | + | ||
316 | + Returns: | ||
317 | + torch.tensor -- tensor representing the extracted face. | ||
318 | + """ | ||
319 | + margin = [ | ||
320 | + margin * (box[2] - box[0]) / (image_size - margin), | ||
321 | + margin * (box[3] - box[1]) / (image_size - margin), | ||
322 | + ] | ||
323 | + raw_image_size = get_size(img) | ||
324 | + box = [ | ||
325 | + int(max(box[0] - margin[0] / 2, 0)), | ||
326 | + int(max(box[1] - margin[1] / 2, 0)), | ||
327 | + int(min(box[2] + margin[0] / 2, raw_image_size[0])), | ||
328 | + int(min(box[3] + margin[1] / 2, raw_image_size[1])), | ||
329 | + ] | ||
330 | + | ||
331 | + face = crop_resize(img, box, image_size) | ||
332 | + | ||
333 | + if save_path is not None: | ||
334 | + os.makedirs(os.path.dirname(save_path) + "/", exist_ok=True) | ||
335 | + save_img(face, save_path) | ||
336 | + | ||
337 | + face = F.to_tensor(np.float32(face)) | ||
338 | + | ||
339 | + return face |
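
Compared with the hand-rolled 10-pixel crop in clinet(window).py, the extract_face helper above handles the crop, margin, resize, and optional saving in one call. A standalone usage sketch, assuming the models package is importable from the client directory and using the bundled legacy/image/test1.jpg as a stand-in input:

from PIL import Image

from models.mtcnn import MTCNN
from models.utils.detect_face import extract_face

mtcnn = MTCNN(keep_all=True)
img = Image.open('legacy/image/test1.jpg')   # placeholder path to any RGB image

boxes, probs = mtcnn.detect(img)
if boxes is not None:
    for i, (box, prob) in enumerate(zip(boxes, probs)):
        if prob < 0.95:
            continue
        # Returns a 3 x 160 x 160 float tensor; margin is counted in output pixels.
        face = extract_face(img, box, image_size=160, margin=20,
                            save_path='faces/face_{}.png'.format(i))
        print(i, prob, tuple(face.shape))
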
client/models/utils/tensorflow2pytorch.py
0 → 100644
1 | +import tensorflow as tf | ||
2 | +import torch | ||
3 | +import json | ||
4 | +import os, sys | ||
5 | + | ||
6 | +from dependencies.facenet.src import facenet | ||
7 | +from dependencies.facenet.src.models import inception_resnet_v1 as tf_mdl | ||
8 | +from dependencies.facenet.src.align import detect_face | ||
9 | + | ||
10 | +from models.inception_resnet_v1 import InceptionResnetV1 | ||
11 | +from models.mtcnn import PNet, RNet, ONet | ||
12 | + | ||
13 | + | ||
14 | +def import_tf_params(tf_mdl_dir, sess): | ||
15 | + """Import tensorflow model from save directory. | ||
16 | + | ||
17 | + Arguments: | ||
18 | + tf_mdl_dir {str} -- Location of protobuf, checkpoint, meta files. | ||
19 | + sess {tensorflow.Session} -- Tensorflow session object. | ||
20 | + | ||
21 | + Returns: | ||
22 | + (list, list, list) -- Tuple of lists containing the layer names, | ||
23 | + parameter arrays as numpy ndarrays, parameter shapes. | ||
24 | + """ | ||
25 | + print('\nLoading tensorflow model\n') | ||
26 | + if callable(tf_mdl_dir): | ||
27 | + tf_mdl_dir(sess) | ||
28 | + else: | ||
29 | + facenet.load_model(tf_mdl_dir) | ||
30 | + | ||
31 | + print('\nGetting model weights\n') | ||
32 | + tf_layers = tf.trainable_variables() | ||
33 | + tf_params = sess.run(tf_layers) | ||
34 | + | ||
35 | + tf_shapes = [p.shape for p in tf_params] | ||
36 | + tf_layers = [l.name for l in tf_layers] | ||
37 | + | ||
38 | + if not callable(tf_mdl_dir): | ||
39 | + path = os.path.join(tf_mdl_dir, 'layer_description.json') | ||
40 | + else: | ||
41 | + path = 'data/layer_description.json' | ||
42 | + with open(path, 'w') as f: | ||
43 | + json.dump({l: s for l, s in zip(tf_layers, tf_shapes)}, f) | ||
44 | + | ||
45 | + return tf_layers, tf_params, tf_shapes | ||
46 | + | ||
47 | + | ||
48 | +def get_layer_indices(layer_lookup, tf_layers): | ||
49 | +    """Given a lookup of model layer attribute names and tensorflow variable names, | ||
50 | + find matching parameters. | ||
51 | + | ||
52 | + Arguments: | ||
53 | + layer_lookup {dict} -- Dictionary mapping pytorch attribute names to (partial) | ||
54 | + tensorflow variable names. Expects dict of the form {'attr': ['tf_name', ...]} | ||
55 | + where the '...'s are ignored. | ||
56 | + tf_layers {list} -- List of tensorflow variable names. | ||
57 | + | ||
58 | + Returns: | ||
59 | + list -- The input dictionary with the list of matching inds appended to each item. | ||
60 | + """ | ||
61 | + layer_inds = {} | ||
62 | + for name, value in layer_lookup.items(): | ||
63 | + layer_inds[name] = value + [[i for i, n in enumerate(tf_layers) if value[0] in n]] | ||
64 | + return layer_inds | ||
65 | + | ||
66 | + | ||
67 | +def load_tf_batchNorm(weights, layer): | ||
68 | + """Load tensorflow weights into nn.BatchNorm object. | ||
69 | + | ||
70 | + Arguments: | ||
71 | + weights {list} -- Tensorflow parameters. | ||
72 | + layer {torch.nn.Module} -- nn.BatchNorm. | ||
73 | + """ | ||
74 | + layer.bias.data = torch.tensor(weights[0]).view(layer.bias.data.shape) | ||
75 | + layer.weight.data = torch.ones_like(layer.weight.data) | ||
76 | + layer.running_mean = torch.tensor(weights[1]).view(layer.running_mean.shape) | ||
77 | + layer.running_var = torch.tensor(weights[2]).view(layer.running_var.shape) | ||
78 | + | ||
79 | + | ||
80 | +def load_tf_conv2d(weights, layer, transpose=False): | ||
81 | + """Load tensorflow weights into nn.Conv2d object. | ||
82 | + | ||
83 | + Arguments: | ||
84 | + weights {list} -- Tensorflow parameters. | ||
85 | + layer {torch.nn.Module} -- nn.Conv2d. | ||
86 | + """ | ||
87 | + if isinstance(weights, list): | ||
88 | + if len(weights) == 2: | ||
89 | + layer.bias.data = ( | ||
90 | + torch.tensor(weights[1]) | ||
91 | + .view(layer.bias.data.shape) | ||
92 | + ) | ||
93 | + weights = weights[0] | ||
94 | + | ||
95 | + if transpose: | ||
96 | + dim_order = (3, 2, 1, 0) | ||
97 | + else: | ||
98 | + dim_order = (3, 2, 0, 1) | ||
99 | + | ||
100 | + layer.weight.data = ( | ||
101 | + torch.tensor(weights) | ||
102 | + .permute(dim_order) | ||
103 | + .view(layer.weight.data.shape) | ||
104 | + ) | ||
105 | + | ||
106 | + | ||
107 | +def load_tf_conv2d_trans(weights, layer): | ||
108 | + return load_tf_conv2d(weights, layer, transpose=True) | ||
109 | + | ||
110 | + | ||
111 | +def load_tf_basicConv2d(weights, layer): | ||
112 | + """Load tensorflow weights into grouped Conv2d+BatchNorm object. | ||
113 | + | ||
114 | + Arguments: | ||
115 | + weights {list} -- Tensorflow parameters. | ||
116 | + layer {torch.nn.Module} -- Object containing Conv2d+BatchNorm. | ||
117 | + """ | ||
118 | + load_tf_conv2d(weights[0], layer.conv) | ||
119 | + load_tf_batchNorm(weights[1:], layer.bn) | ||
120 | + | ||
121 | + | ||
122 | +def load_tf_linear(weights, layer): | ||
123 | + """Load tensorflow weights into nn.Linear object. | ||
124 | + | ||
125 | + Arguments: | ||
126 | + weights {list} -- Tensorflow parameters. | ||
127 | + layer {torch.nn.Module} -- nn.Linear. | ||
128 | + """ | ||
129 | + if isinstance(weights, list): | ||
130 | + if len(weights) == 2: | ||
131 | + layer.bias.data = ( | ||
132 | + torch.tensor(weights[1]) | ||
133 | + .view(layer.bias.data.shape) | ||
134 | + ) | ||
135 | + weights = weights[0] | ||
136 | + layer.weight.data = ( | ||
137 | + torch.tensor(weights) | ||
138 | + .transpose(-1, 0) | ||
139 | + .view(layer.weight.data.shape) | ||
140 | + ) | ||
141 | + | ||
142 | + | ||
143 | +# High-level parameter-loading functions: | ||
144 | + | ||
145 | +def load_tf_block35(weights, layer): | ||
146 | + load_tf_basicConv2d(weights[:4], layer.branch0) | ||
147 | + load_tf_basicConv2d(weights[4:8], layer.branch1[0]) | ||
148 | + load_tf_basicConv2d(weights[8:12], layer.branch1[1]) | ||
149 | + load_tf_basicConv2d(weights[12:16], layer.branch2[0]) | ||
150 | + load_tf_basicConv2d(weights[16:20], layer.branch2[1]) | ||
151 | + load_tf_basicConv2d(weights[20:24], layer.branch2[2]) | ||
152 | + load_tf_conv2d(weights[24:26], layer.conv2d) | ||
153 | + | ||
154 | + | ||
155 | +def load_tf_block17_8(weights, layer): | ||
156 | + load_tf_basicConv2d(weights[:4], layer.branch0) | ||
157 | + load_tf_basicConv2d(weights[4:8], layer.branch1[0]) | ||
158 | + load_tf_basicConv2d(weights[8:12], layer.branch1[1]) | ||
159 | + load_tf_basicConv2d(weights[12:16], layer.branch1[2]) | ||
160 | + load_tf_conv2d(weights[16:18], layer.conv2d) | ||
161 | + | ||
162 | + | ||
163 | +def load_tf_mixed6a(weights, layer): | ||
164 | + if len(weights) != 16: | ||
165 | + raise ValueError(f'Number of weight arrays ({len(weights)}) not equal to 16') | ||
166 | + load_tf_basicConv2d(weights[:4], layer.branch0) | ||
167 | + load_tf_basicConv2d(weights[4:8], layer.branch1[0]) | ||
168 | + load_tf_basicConv2d(weights[8:12], layer.branch1[1]) | ||
169 | + load_tf_basicConv2d(weights[12:16], layer.branch1[2]) | ||
170 | + | ||
171 | + | ||
172 | +def load_tf_mixed7a(weights, layer): | ||
173 | + if len(weights) != 28: | ||
174 | + raise ValueError(f'Number of weight arrays ({len(weights)}) not equal to 28') | ||
175 | + load_tf_basicConv2d(weights[:4], layer.branch0[0]) | ||
176 | + load_tf_basicConv2d(weights[4:8], layer.branch0[1]) | ||
177 | + load_tf_basicConv2d(weights[8:12], layer.branch1[0]) | ||
178 | + load_tf_basicConv2d(weights[12:16], layer.branch1[1]) | ||
179 | + load_tf_basicConv2d(weights[16:20], layer.branch2[0]) | ||
180 | + load_tf_basicConv2d(weights[20:24], layer.branch2[1]) | ||
181 | + load_tf_basicConv2d(weights[24:28], layer.branch2[2]) | ||
182 | + | ||
183 | + | ||
184 | +def load_tf_repeats(weights, layer, rptlen, subfun): | ||
185 | + if len(weights) % rptlen != 0: | ||
186 | + raise ValueError(f'Number of weight arrays ({len(weights)}) not divisible by {rptlen}') | ||
187 | + weights_split = [weights[i:i+rptlen] for i in range(0, len(weights), rptlen)] | ||
188 | + for i, w in enumerate(weights_split): | ||
189 | + subfun(w, getattr(layer, str(i))) | ||
190 | + | ||
191 | + | ||
192 | +def load_tf_repeat_1(weights, layer): | ||
193 | + load_tf_repeats(weights, layer, 26, load_tf_block35) | ||
194 | + | ||
195 | + | ||
196 | +def load_tf_repeat_2(weights, layer): | ||
197 | + load_tf_repeats(weights, layer, 18, load_tf_block17_8) | ||
198 | + | ||
199 | + | ||
200 | +def load_tf_repeat_3(weights, layer): | ||
201 | + load_tf_repeats(weights, layer, 18, load_tf_block17_8) | ||
202 | + | ||
203 | + | ||
204 | +def test_loaded_params(mdl, tf_params, tf_layers): | ||
205 | + """Check each parameter in a pytorch model for an equivalent parameter | ||
206 | + in a list of tensorflow variables. | ||
207 | + | ||
208 | + Arguments: | ||
209 | + mdl {torch.nn.Module} -- Pytorch model. | ||
210 | + tf_params {list} -- List of ndarrays representing tensorflow variables. | ||
211 | + tf_layers {list} -- Corresponding list of tensorflow variable names. | ||
212 | + """ | ||
213 | + tf_means = torch.stack([torch.tensor(p).mean() for p in tf_params]) | ||
214 | + for name, param in mdl.named_parameters(): | ||
215 | + pt_mean = param.data.mean() | ||
216 | + matching_inds = ((tf_means - pt_mean).abs() < 1e-8).nonzero() | ||
217 | + print(f'{name} equivalent to {[tf_layers[i] for i in matching_inds]}') | ||
218 | + | ||
219 | + | ||
220 | +def compare_model_outputs(pt_mdl, sess, test_data): | ||
221 | + """Given some testing data, compare the output of pytorch and tensorflow models. | ||
222 | + | ||
223 | + Arguments: | ||
224 | + pt_mdl {torch.nn.Module} -- Pytorch model. | ||
225 | + sess {tensorflow.Session} -- Tensorflow session object. | ||
226 | + test_data {torch.Tensor} -- Pytorch tensor. | ||
227 | + """ | ||
228 | + print('\nPassing test data through TF model\n') | ||
229 | + if isinstance(sess, tf.Session): | ||
230 | + images_placeholder = tf.get_default_graph().get_tensor_by_name("input:0") | ||
231 | + phase_train_placeholder = tf.get_default_graph().get_tensor_by_name("phase_train:0") | ||
232 | + embeddings = tf.get_default_graph().get_tensor_by_name("embeddings:0") | ||
233 | + feed_dict = {images_placeholder: test_data.numpy(), phase_train_placeholder: False} | ||
234 | + tf_output = torch.tensor(sess.run(embeddings, feed_dict=feed_dict)) | ||
235 | + else: | ||
236 | + tf_output = sess(test_data) | ||
237 | + | ||
238 | + print(tf_output) | ||
239 | + | ||
240 | + print('\nPassing test data through PT model\n') | ||
241 | + pt_output = pt_mdl(test_data.permute(0, 3, 1, 2)) | ||
242 | + print(pt_output) | ||
243 | + | ||
244 | + distance = (tf_output - pt_output).norm() | ||
245 | + print(f'\nDistance {distance}\n') | ||
246 | + | ||
247 | + | ||
248 | +def compare_mtcnn(pt_mdl, tf_fun, sess, ind, test_data): | ||
249 | + tf_mdls = tf_fun(sess) | ||
250 | + tf_mdl = tf_mdls[ind] | ||
251 | + | ||
252 | + print('\nPassing test data through TF model\n') | ||
253 | + tf_output = tf_mdl(test_data.numpy()) | ||
254 | + tf_output = [torch.tensor(out) for out in tf_output] | ||
255 | + print('\n'.join([str(o.view(-1)[:10]) for o in tf_output])) | ||
256 | + | ||
257 | + print('\nPassing test data through PT model\n') | ||
258 | + with torch.no_grad(): | ||
259 | + pt_output = pt_mdl(test_data.permute(0, 3, 2, 1)) | ||
260 | + pt_output = [torch.tensor(out) for out in pt_output] | ||
261 | + for i in range(len(pt_output)): | ||
262 | + if len(pt_output[i].shape) == 4: | ||
263 | + pt_output[i] = pt_output[i].permute(0, 3, 2, 1).contiguous() | ||
264 | + print('\n'.join([str(o.view(-1)[:10]) for o in pt_output])) | ||
265 | + | ||
266 | + distance = [(tf_o - pt_o).norm() for tf_o, pt_o in zip(tf_output, pt_output)] | ||
267 | + print(f'\nDistance {distance}\n') | ||
268 | + | ||
269 | + | ||
270 | +def load_tf_model_weights(mdl, layer_lookup, tf_mdl_dir, is_resnet=True, arg_num=None): | ||
271 | + """Load tensorflow parameters into a pytorch model. | ||
272 | + | ||
273 | + Arguments: | ||
274 | + mdl {torch.nn.Module} -- Pytorch model. | ||
275 | + layer_lookup {[type]} -- Dictionary mapping pytorch attribute names to (partial) | ||
276 | + tensorflow variable names, and a function suitable for loading weights. | ||
277 | + Expects dict of the form {'attr': ['tf_name', function]}. | ||
278 | + tf_mdl_dir {str} -- Location of protobuf, checkpoint, meta files. | ||
279 | + """ | ||
280 | + tf.reset_default_graph() | ||
281 | + with tf.Session() as sess: | ||
282 | + tf_layers, tf_params, tf_shapes = import_tf_params(tf_mdl_dir, sess) | ||
283 | + layer_info = get_layer_indices(layer_lookup, tf_layers) | ||
284 | + | ||
285 | + for layer_name, info in layer_info.items(): | ||
286 | + print(f'Loading {info[0]}/* into {layer_name}') | ||
287 | + weights = [tf_params[i] for i in info[2]] | ||
288 | + layer = getattr(mdl, layer_name) | ||
289 | + info[1](weights, layer) | ||
290 | + | ||
291 | + test_loaded_params(mdl, tf_params, tf_layers) | ||
292 | + | ||
293 | + if is_resnet: | ||
294 | + compare_model_outputs(mdl, sess, torch.randn(5, 160, 160, 3).detach()) | ||
295 | + | ||
296 | + | ||
297 | +def tensorflow2pytorch(): | ||
298 | + lookup_inception_resnet_v1 = { | ||
299 | + 'conv2d_1a': ['InceptionResnetV1/Conv2d_1a_3x3', load_tf_basicConv2d], | ||
300 | + 'conv2d_2a': ['InceptionResnetV1/Conv2d_2a_3x3', load_tf_basicConv2d], | ||
301 | + 'conv2d_2b': ['InceptionResnetV1/Conv2d_2b_3x3', load_tf_basicConv2d], | ||
302 | + 'conv2d_3b': ['InceptionResnetV1/Conv2d_3b_1x1', load_tf_basicConv2d], | ||
303 | + 'conv2d_4a': ['InceptionResnetV1/Conv2d_4a_3x3', load_tf_basicConv2d], | ||
304 | + 'conv2d_4b': ['InceptionResnetV1/Conv2d_4b_3x3', load_tf_basicConv2d], | ||
305 | + 'repeat_1': ['InceptionResnetV1/Repeat/block35', load_tf_repeat_1], | ||
306 | + 'mixed_6a': ['InceptionResnetV1/Mixed_6a', load_tf_mixed6a], | ||
307 | + 'repeat_2': ['InceptionResnetV1/Repeat_1/block17', load_tf_repeat_2], | ||
308 | + 'mixed_7a': ['InceptionResnetV1/Mixed_7a', load_tf_mixed7a], | ||
309 | + 'repeat_3': ['InceptionResnetV1/Repeat_2/block8', load_tf_repeat_3], | ||
310 | + 'block8': ['InceptionResnetV1/Block8', load_tf_block17_8], | ||
311 | + 'last_linear': ['InceptionResnetV1/Bottleneck/weights', load_tf_linear], | ||
312 | + 'last_bn': ['InceptionResnetV1/Bottleneck/BatchNorm', load_tf_batchNorm], | ||
313 | + 'logits': ['Logits', load_tf_linear], | ||
314 | + } | ||
315 | + | ||
316 | + print('\nLoad VGGFace2-trained weights and save\n') | ||
317 | + mdl = InceptionResnetV1(num_classes=8631).eval() | ||
318 | + tf_mdl_dir = 'data/20180402-114759' | ||
319 | + data_name = 'vggface2' | ||
320 | + load_tf_model_weights(mdl, lookup_inception_resnet_v1, tf_mdl_dir) | ||
321 | + state_dict = mdl.state_dict() | ||
322 | + torch.save(state_dict, f'{tf_mdl_dir}-{data_name}.pt') | ||
323 | + torch.save( | ||
324 | + { | ||
325 | + 'logits.weight': state_dict['logits.weight'], | ||
326 | + 'logits.bias': state_dict['logits.bias'], | ||
327 | + }, | ||
328 | + f'{tf_mdl_dir}-{data_name}-logits.pt' | ||
329 | + ) | ||
330 | + state_dict.pop('logits.weight') | ||
331 | + state_dict.pop('logits.bias') | ||
332 | + torch.save(state_dict, f'{tf_mdl_dir}-{data_name}-features.pt') | ||
333 | + | ||
334 | + print('\nLoad CASIA-Webface-trained weights and save\n') | ||
335 | + mdl = InceptionResnetV1(num_classes=10575).eval() | ||
336 | + tf_mdl_dir = 'data/20180408-102900' | ||
337 | + data_name = 'casia-webface' | ||
338 | + load_tf_model_weights(mdl, lookup_inception_resnet_v1, tf_mdl_dir) | ||
339 | + state_dict = mdl.state_dict() | ||
340 | + torch.save(state_dict, f'{tf_mdl_dir}-{data_name}.pt') | ||
341 | + torch.save( | ||
342 | + { | ||
343 | + 'logits.weight': state_dict['logits.weight'], | ||
344 | + 'logits.bias': state_dict['logits.bias'], | ||
345 | + }, | ||
346 | + f'{tf_mdl_dir}-{data_name}-logits.pt' | ||
347 | + ) | ||
348 | + state_dict.pop('logits.weight') | ||
349 | + state_dict.pop('logits.bias') | ||
350 | + torch.save(state_dict, f'{tf_mdl_dir}-{data_name}-features.pt') | ||
351 | + | ||
352 | + lookup_pnet = { | ||
353 | + 'conv1': ['pnet/conv1', load_tf_conv2d_trans], | ||
354 | + 'prelu1': ['pnet/PReLU1', load_tf_linear], | ||
355 | + 'conv2': ['pnet/conv2', load_tf_conv2d_trans], | ||
356 | + 'prelu2': ['pnet/PReLU2', load_tf_linear], | ||
357 | + 'conv3': ['pnet/conv3', load_tf_conv2d_trans], | ||
358 | + 'prelu3': ['pnet/PReLU3', load_tf_linear], | ||
359 | + 'conv4_1': ['pnet/conv4-1', load_tf_conv2d_trans], | ||
360 | + 'conv4_2': ['pnet/conv4-2', load_tf_conv2d_trans], | ||
361 | + } | ||
362 | + lookup_rnet = { | ||
363 | + 'conv1': ['rnet/conv1', load_tf_conv2d_trans], | ||
364 | + 'prelu1': ['rnet/prelu1', load_tf_linear], | ||
365 | + 'conv2': ['rnet/conv2', load_tf_conv2d_trans], | ||
366 | + 'prelu2': ['rnet/prelu2', load_tf_linear], | ||
367 | + 'conv3': ['rnet/conv3', load_tf_conv2d_trans], | ||
368 | + 'prelu3': ['rnet/prelu3', load_tf_linear], | ||
369 | + 'dense4': ['rnet/conv4', load_tf_linear], | ||
370 | + 'prelu4': ['rnet/prelu4', load_tf_linear], | ||
371 | + 'dense5_1': ['rnet/conv5-1', load_tf_linear], | ||
372 | + 'dense5_2': ['rnet/conv5-2', load_tf_linear], | ||
373 | + } | ||
374 | + lookup_onet = { | ||
375 | + 'conv1': ['onet/conv1', load_tf_conv2d_trans], | ||
376 | + 'prelu1': ['onet/prelu1', load_tf_linear], | ||
377 | + 'conv2': ['onet/conv2', load_tf_conv2d_trans], | ||
378 | + 'prelu2': ['onet/prelu2', load_tf_linear], | ||
379 | + 'conv3': ['onet/conv3', load_tf_conv2d_trans], | ||
380 | + 'prelu3': ['onet/prelu3', load_tf_linear], | ||
381 | + 'conv4': ['onet/conv4', load_tf_conv2d_trans], | ||
382 | + 'prelu4': ['onet/prelu4', load_tf_linear], | ||
383 | + 'dense5': ['onet/conv5', load_tf_linear], | ||
384 | + 'prelu5': ['onet/prelu5', load_tf_linear], | ||
385 | + 'dense6_1': ['onet/conv6-1', load_tf_linear], | ||
386 | + 'dense6_2': ['onet/conv6-2', load_tf_linear], | ||
387 | + 'dense6_3': ['onet/conv6-3', load_tf_linear], | ||
388 | + } | ||
389 | + | ||
390 | + print('\nLoad PNet weights and save\n') | ||
391 | + tf_mdl_dir = lambda sess: detect_face.create_mtcnn(sess, None) | ||
392 | + mdl = PNet() | ||
393 | + data_name = 'pnet' | ||
394 | + load_tf_model_weights(mdl, lookup_pnet, tf_mdl_dir, is_resnet=False, arg_num=0) | ||
395 | + torch.save(mdl.state_dict(), f'data/{data_name}.pt') | ||
396 | + tf.reset_default_graph() | ||
397 | + with tf.Session() as sess: | ||
398 | + compare_mtcnn(mdl, tf_mdl_dir, sess, 0, torch.randn(1, 256, 256, 3).detach()) | ||
399 | + | ||
400 | + print('\nLoad RNet weights and save\n') | ||
401 | + mdl = RNet() | ||
402 | + data_name = 'rnet' | ||
403 | + load_tf_model_weights(mdl, lookup_rnet, tf_mdl_dir, is_resnet=False, arg_num=1) | ||
404 | + torch.save(mdl.state_dict(), f'data/{data_name}.pt') | ||
405 | + tf.reset_default_graph() | ||
406 | + with tf.Session() as sess: | ||
407 | + compare_mtcnn(mdl, tf_mdl_dir, sess, 1, torch.randn(1, 24, 24, 3).detach()) | ||
408 | + | ||
409 | + print('\nLoad ONet weights and save\n') | ||
410 | + mdl = ONet() | ||
411 | + data_name = 'onet' | ||
412 | + load_tf_model_weights(mdl, lookup_onet, tf_mdl_dir, is_resnet=False, arg_num=2) | ||
413 | + torch.save(mdl.state_dict(), f'data/{data_name}.pt') | ||
414 | + tf.reset_default_graph() | ||
415 | + with tf.Session() as sess: | ||
416 | + compare_mtcnn(mdl, tf_mdl_dir, sess, 2, torch.randn(1, 48, 48, 3).detach()) |
client/models/utils/training.py
0 → 100644
1 | +import torch | ||
2 | +import numpy as np | ||
3 | +import time | ||
4 | + | ||
5 | + | ||
6 | +class Logger(object): | ||
7 | + | ||
8 | + def __init__(self, mode, length, calculate_mean=False): | ||
9 | + self.mode = mode | ||
10 | + self.length = length | ||
11 | + self.calculate_mean = calculate_mean | ||
12 | + if self.calculate_mean: | ||
13 | + self.fn = lambda x, i: x / (i + 1) | ||
14 | + else: | ||
15 | + self.fn = lambda x, i: x | ||
16 | + | ||
17 | + def __call__(self, loss, metrics, i): | ||
18 | + track_str = '\r{} | {:5d}/{:<5d}| '.format(self.mode, i + 1, self.length) | ||
19 | + loss_str = 'loss: {:9.4f} | '.format(self.fn(loss, i)) | ||
20 | + metric_str = ' | '.join('{}: {:9.4f}'.format(k, self.fn(v, i)) for k, v in metrics.items()) | ||
21 | + print(track_str + loss_str + metric_str + ' ', end='') | ||
22 | + if i + 1 == self.length: | ||
23 | + print('') | ||
24 | + | ||
25 | + | ||
26 | +class BatchTimer(object): | ||
27 | + """Batch timing class. | ||
28 | + Use this class for tracking training and testing time/rate per batch or per sample. | ||
29 | + | ||
30 | + Keyword Arguments: | ||
31 | + rate {bool} -- Whether to report a rate (batches or samples per second) or a time (seconds | ||
32 | + per batch or sample). (default: {True}) | ||
33 | + per_sample {bool} -- Whether to report times or rates per sample or per batch. | ||
34 | + (default: {True}) | ||
35 | + """ | ||
36 | + | ||
37 | + def __init__(self, rate=True, per_sample=True): | ||
38 | + self.start = time.time() | ||
39 | + self.end = None | ||
40 | + self.rate = rate | ||
41 | + self.per_sample = per_sample | ||
42 | + | ||
43 | + def __call__(self, y_pred, y): | ||
44 | + self.end = time.time() | ||
45 | + elapsed = self.end - self.start | ||
46 | + self.start = self.end | ||
47 | + self.end = None | ||
48 | + | ||
49 | + if self.per_sample: | ||
50 | + elapsed /= len(y_pred) | ||
51 | + if self.rate: | ||
52 | + elapsed = 1 / elapsed | ||
53 | + | ||
54 | + return torch.tensor(elapsed) | ||
55 | + | ||
56 | + | ||
57 | +def accuracy(logits, y): | ||
58 | + _, preds = torch.max(logits, 1) | ||
59 | + return (preds == y).float().mean() | ||
60 | + | ||
61 | + | ||
62 | +def pass_epoch( | ||
63 | + model, loss_fn, loader, optimizer=None, scheduler=None, | ||
64 | + batch_metrics={'time': BatchTimer()}, show_running=True, | ||
65 | + device='cpu', writer=None | ||
66 | +): | ||
67 | + """Train or evaluate over a data epoch. | ||
68 | + | ||
69 | + Arguments: | ||
70 | + model {torch.nn.Module} -- Pytorch model. | ||
71 | + loss_fn {callable} -- A function to compute (scalar) loss. | ||
72 | + loader {torch.utils.data.DataLoader} -- A pytorch data loader. | ||
73 | + | ||
74 | + Keyword Arguments: | ||
75 | + optimizer {torch.optim.Optimizer} -- A pytorch optimizer. | ||
76 | + scheduler {torch.optim.lr_scheduler._LRScheduler} -- LR scheduler (default: {None}) | ||
77 | + batch_metrics {dict} -- Dictionary of metric functions to call on each batch. The default | ||
78 | + is a simple timer. A progressive average of these metrics, along with the average | ||
79 | + loss, is printed every batch. (default: {{'time': iter_timer()}}) | ||
80 | +        show_running {bool} -- Whether to print running averages of the loss and metrics (True) | ||
81 | +            or the values for the current batch (False). (default: {True}) | ||
82 | + device {str or torch.device} -- Device for pytorch to use. (default: {'cpu'}) | ||
83 | + writer {torch.utils.tensorboard.SummaryWriter} -- Tensorboard SummaryWriter. (default: {None}) | ||
84 | + | ||
85 | + Returns: | ||
86 | + tuple(torch.Tensor, dict) -- A tuple of the average loss and a dictionary of average | ||
87 | + metric values across the epoch. | ||
88 | + """ | ||
89 | + | ||
90 | + mode = 'Train' if model.training else 'Valid' | ||
91 | + logger = Logger(mode, length=len(loader), calculate_mean=show_running) | ||
92 | + loss = 0 | ||
93 | + metrics = {} | ||
94 | + | ||
95 | + for i_batch, (x, y) in enumerate(loader): | ||
96 | + x = x.to(device) | ||
97 | + y = y.to(device) | ||
98 | + y_pred = model(x) | ||
99 | + loss_batch = loss_fn(y_pred, y) | ||
100 | + | ||
101 | + if model.training: | ||
102 | + loss_batch.backward() | ||
103 | + optimizer.step() | ||
104 | + optimizer.zero_grad() | ||
105 | + | ||
106 | + metrics_batch = {} | ||
107 | + for metric_name, metric_fn in batch_metrics.items(): | ||
108 | + metrics_batch[metric_name] = metric_fn(y_pred, y).detach().cpu() | ||
109 | + metrics[metric_name] = metrics.get(metric_name, 0) + metrics_batch[metric_name] | ||
110 | + | ||
111 | + if writer is not None and model.training: | ||
112 | + if writer.iteration % writer.interval == 0: | ||
113 | + writer.add_scalars('loss', {mode: loss_batch.detach().cpu()}, writer.iteration) | ||
114 | + for metric_name, metric_batch in metrics_batch.items(): | ||
115 | + writer.add_scalars(metric_name, {mode: metric_batch}, writer.iteration) | ||
116 | + writer.iteration += 1 | ||
117 | + | ||
118 | + loss_batch = loss_batch.detach().cpu() | ||
119 | + loss += loss_batch | ||
120 | + if show_running: | ||
121 | + logger(loss, metrics, i_batch) | ||
122 | + else: | ||
123 | + logger(loss_batch, metrics_batch, i_batch) | ||
124 | + | ||
125 | + if model.training and scheduler is not None: | ||
126 | + scheduler.step() | ||
127 | + | ||
128 | + loss = loss / (i_batch + 1) | ||
129 | + metrics = {k: v / (i_batch + 1) for k, v in metrics.items()} | ||
130 | + | ||
131 | + if writer is not None and not model.training: | ||
132 | + writer.add_scalars('loss', {mode: loss.detach()}, writer.iteration) | ||
133 | + for metric_name, metric in metrics.items(): | ||
134 | + writer.add_scalars(metric_name, {mode: metric}) | ||
135 | + | ||
136 | + return loss, metrics | ||
137 | + | ||
138 | + | ||
139 | +def collate_pil(x): | ||
140 | + out_x, out_y = [], [] | ||
141 | + for xx, yy in x: | ||
142 | + out_x.append(xx) | ||
143 | + out_y.append(yy) | ||
144 | + return out_x, out_y |
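
pass_epoch above drives one full pass over a DataLoader, printing running loss and metrics through Logger and stepping the optimizer only when the model is in training mode. A minimal usage sketch with placeholder model and data (the import path assumes the script runs from the client directory):

import torch
from torch import nn, optim
from torch.utils.data import DataLoader, TensorDataset

from models.utils.training import BatchTimer, accuracy, pass_epoch

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Placeholder classifier and data standing in for face crops and identity labels.
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 160 * 160, 10)).to(device)
x = torch.randn(32, 3, 160, 160)
y = torch.randint(0, 10, (32,))
loader = DataLoader(TensorDataset(x, y), batch_size=8)

loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)
metrics = {'acc': accuracy, 'fps': BatchTimer()}

model.train()
epoch_loss, epoch_metrics = pass_epoch(
    model, loss_fn, loader, optimizer=optimizer,
    batch_metrics=metrics, show_running=True, device=device
)
print(epoch_loss, epoch_metrics)
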