Showing 17 changed files with 569 additions and 0 deletions
client/client(window).py
0 → 100644
+##################################################
+# 1. Detect faces from the webcam.
+# 2. Send a face to the server when its probability is at least 97%
+#    and its bounding-box area is at least 15000 pixels.
+##################################################
+import torch
+import numpy as np
+import cv2
+import asyncio
+import websockets
+import json
+import os
+import timeit
+import base64
+import time
+
+from PIL import Image
+from io import BytesIO
+import requests
+
+from models.mtcnn import MTCNN
+
+device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
+print('Running on device: {}'.format(device))
+
+mtcnn = MTCNN(keep_all=True, post_process=True, device=device)
+
+uri = 'ws://169.56.95.131:8765'
+
+async def send_face(face_list, image_list):
+    async with websockets.connect(uri) as websocket:
+        for face, image in zip(face_list, image_list):
+            # face is an np.float32 array
+            send = json.dumps({'action': 'verify', 'MTCNN': face.tolist()})
+            await websocket.send(send)
+            recv = await websocket.recv()
+            data = json.loads(recv)
+            if data['status'] == 'success':
+                print(data['student_id'], 'is attend')
+            else:
+                print('verification failed:', data['status'])
+            if data['status'] == 'failed':
+                # Ask the server to store the unrecognized face crop.
+                send = json.dumps({'action': 'save_image', 'image': image.tolist()})
+                await websocket.send(send)
+
+def detect_face(frame):
+    results = mtcnn.detect(frame)
+    faces = mtcnn(frame, return_prob=False)
+    image_list = []
+    face_list = []
+    if results[1][0] is None:
+        return [], []
+    for box, face, prob in zip(results[0], faces, results[1]):
+        if prob < 0.97:
+            continue
+        print('face detected. prob:', prob)
+        x1, y1, x2, y2 = box
+        if (x2 - x1) * (y2 - y1) < 15000:
+            # Skip faces whose bounding box is too small (resolution too low).
+            continue
+        # Crop the face region with a ±3 px margin.
+        image = frame[int(y1 - 3):int(y2 + 3), int(x1 - 3):int(x2 + 3)]
+        image_list.append(image)
+        # Store the MTCNN-preprocessed face tensor.
+        face_list.append(face.cpu().numpy())
+    return image_list, face_list
+
+def make_face_list(frame):
+    results, probs = mtcnn(frame, return_prob=True)
+    face_list = []
+    if probs[0] is None:
+        return []
+    for result, prob in zip(results, probs):
+        if prob < 0.97:
+            continue
+        # np.float32
+        face_list.append(result.cpu().numpy())
+    return face_list
+
+if __name__ == '__main__':
+    cap = cv2.VideoCapture(0, cv2.CAP_DSHOW)
+    cap.set(3, 720)  # frame width
+    cap.set(4, 480)  # frame height
+    #cv2.namedWindow("img", cv2.WINDOW_NORMAL)
+    while True:
+        try:
+            ret, frame = cap.read()
+            #cv2.imshow('img', frame)
+            #cv2.waitKey(10)
+            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+            image_list, face_list = detect_face(frame)
+            if not face_list:
+                continue
+            asyncio.get_event_loop().run_until_complete(send_face(face_list, image_list))
+            time.sleep(1)
+        except Exception as ex:
+            print(ex)
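
The verification server itself is not included in this diff. For reference, below is a minimal sketch of a server that would speak the protocol this client expects; the message fields ('action', 'MTCNN', 'status', 'student_id', 'image') are taken from the client code above, while verify_face and save_unknown_image are hypothetical placeholders for the embedding-match and storage steps.

import asyncio
import json

import numpy as np
import websockets

async def handler(websocket):
    # One-argument handler form (websockets >= 10); older releases pass an
    # extra `path` argument.
    async for message in websocket:
        data = json.loads(message)
        if data['action'] == 'verify':
            # 'MTCNN' carries the preprocessed face tensor from the client.
            face = np.array(data['MTCNN'], dtype=np.float32)
            student_id = verify_face(face)  # hypothetical: embed and match against an enrollment DB
            if student_id is not None:
                await websocket.send(json.dumps({'status': 'success', 'student_id': student_id}))
            else:
                await websocket.send(json.dumps({'status': 'failed'}))
        elif data['action'] == 'save_image':
            # hypothetical: persist the unrecognized RGB crop for later labeling
            save_unknown_image(np.array(data['image'], dtype=np.uint8))

async def main():
    async with websockets.serve(handler, '0.0.0.0', 8765):
        await asyncio.Future()  # run forever

# asyncio.run(main())

Note that the client never awaits a reply for 'save_image', so the sketch deliberately does not send one.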
client/legacy/clinet(window)-06042035.py
0 → 100644
+##################################################
+# 1. Detect faces from the webcam.
+# 2. Send images whose face probability is at least 95% to the image server.
+# 3. Send the preprocessed data to the verification server.
+##################################################
+import torch
+import numpy as np
+import cv2
+import asyncio
+import websockets
+import json
+import os
+import timeit
+import base64
+
+from PIL import Image
+from io import BytesIO
+import requests
+
+from models.mtcnn import MTCNN
+
+device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
+print('Running on device: {}'.format(device))
+
+mtcnn = MTCNN(keep_all=True, device=device)
+
+uri = 'ws://localhost:8765'
+
+async def send_face(face_list, image_list):
+    global uri
+    async with websockets.connect(uri) as websocket:
+        for face, image in zip(face_list, image_list):
+            # face is an np.float32 array
+            send = json.dumps({"action": "verify", "MTCNN": face.tolist()})
+            await websocket.send(send)
+            recv = await websocket.recv()
+            data = json.loads(recv)
+            if data['status'] == 'success':
+                print(data['id'], 'is attend')
+            else:
+                print('verification failed')
+                send = json.dumps({'action': 'save_image', 'image': image.tolist(), 'shape': image.shape})
+                await websocket.send(send)
+
+async def send_image(image_list):
+    global uri
+    async with websockets.connect(uri) as websocket:
+        for image in image_list:
+            data = json.dumps({'action': 'save_image', 'image': image.tolist(), 'shape': image.shape})
+            await websocket.send(data)
+        print('send', len(image_list), 'image(s)')
+        code = await websocket.recv()
+        print('code:', code)
+
+def detect_face(frame):
+    # If required, create a face detection pipeline using MTCNN:
+    global mtcnn
+    results = mtcnn.detect(frame)
+    image_list = []
+    if results[1][0] is None:
+        return []
+    for box, prob in zip(results[0], results[1]):
+        if prob < 0.95:
+            continue
+        print('face detected. prob:', prob)
+        x1, y1, x2, y2 = box
+        # Crop the face region with a ±10 px margin.
+        image = frame[int(y1 - 10):int(y2 + 10), int(x1 - 10):int(x2 + 10)]
+        image_list.append(image)
+    return image_list
+
+def make_face_list(frame):
+    global mtcnn
+    results, probs = mtcnn(frame, return_prob=True)
+    face_list = []
+    if probs[0] is None:
+        return []
+    for result, prob in zip(results, probs):
+        if prob < 0.95:
+            continue
+        # np.float32
+        face_list.append(result.cpu().numpy())
+    return face_list
+
+cap = cv2.VideoCapture(0)
+cap.set(3, 720)  # frame width
+cap.set(4, 480)  # frame height
+while True:
+    try:
+        #start = timeit.default_timer()
+        ret, frame = cap.read()
+        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+        face_list = make_face_list(frame)
+        image_list = detect_face(frame)
+        ## send to the embedding server ##
+        if face_list:
+            asyncio.get_event_loop().run_until_complete(send_face(face_list, image_list))
+        ###################
+        ## send to the image server ##
+        #if image_list:
+        #    asyncio.get_event_loop().run_until_complete(send_image(image_list))
+        ###################
+        #end = timeit.default_timer()
+        #print('delta time: ', end - start)
+    except Exception as ex:
+        print(ex)
client/legacy/clinet(window)0605.py
0 → 100644
+##################################################
+# 1. Detect faces from the webcam.
+# 2. Send images whose face probability is at least 95% to the image server.
+# 3. Send the preprocessed data to the verification server.
+##################################################
+import torch
+import numpy as np
+import cv2
+import asyncio
+import websockets
+import json
+import os
+import timeit
+import base64
+
+from PIL import Image
+from io import BytesIO
+import requests
+
+from models.mtcnn import MTCNN
+
+device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
+print('Running on device: {}'.format(device))
+
+mtcnn = MTCNN(keep_all=True, device=device)
+
+uri = 'ws://localhost:8765'
+
+async def send_face(face_list, image_list):
+    global uri
+    async with websockets.connect(uri) as websocket:
+        for face, image in zip(face_list, image_list):
+            # face is an np.float32 array
+            send = json.dumps({"action": "verify", "MTCNN": face.tolist()})
+            await websocket.send(send)
+            recv = await websocket.recv()
+            data = json.loads(recv)
+            if data['status'] == 'success':
+                print(data['student_id'], 'is attend')
+            else:
+                print('verification failed')
+                send = json.dumps({'action': 'save_image', 'image': image.tolist()})
+                await websocket.send(send)
+
+def detect_face(frame):
+    # If required, create a face detection pipeline using MTCNN:
+    global mtcnn
+    results = mtcnn.detect(frame)
+    image_list = []
+    if results[1][0] is None:
+        return []
+    for box, prob in zip(results[0], results[1]):
+        if prob < 0.95:
+            continue
+        print('face detected. prob:', prob)
+        x1, y1, x2, y2 = box
+        # Crop the face region with a ±10 px margin.
+        image = frame[int(y1 - 10):int(y2 + 10), int(x1 - 10):int(x2 + 10)]
+        image_list.append(image)
+    return image_list
+
+def make_face_list(frame):
+    global mtcnn
+    results, probs = mtcnn(frame, return_prob=True)
+    face_list = []
+    if probs[0] is None:
+        return []
+    for result, prob in zip(results, probs):
+        if prob < 0.95:
+            continue
+        # np.float32
+        face_list.append(result.cpu().numpy())
+    return face_list
+
+cap = cv2.VideoCapture(0)
+cap.set(3, 720)  # frame width
+cap.set(4, 480)  # frame height
+while True:
+    try:
+        #start = timeit.default_timer()
+        ret, frame = cap.read()
+        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+        face_list = make_face_list(frame)
+        image_list = detect_face(frame)
+        ## send to the embedding server ##
+        if face_list:
+            asyncio.get_event_loop().run_until_complete(send_face(face_list, image_list))
+        #end = timeit.default_timer()
+        #print('delta time: ', end - start)
+    except Exception as ex:
+        print(ex)
client/legacy/clinet(window)200605.py
0 → 100644
+##################################################
+# 1. Detect faces from the webcam.
+# 2. Send images whose face probability is at least 97% to the image server.
+# 3. Send the preprocessed data to the verification server.
+##################################################
+import torch
+import numpy as np
+import cv2
+import asyncio
+import websockets
+import json
+import os
+import timeit
+import base64
+
+from PIL import Image
+from io import BytesIO
+import requests
+
+from models.mtcnn import MTCNN
+
+device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
+print('Running on device: {}'.format(device))
+
+mtcnn = MTCNN(keep_all=True, device=device)
+
+uri = 'ws://localhost:8765'
+
+async def send_face(face_list, image_list):
+    async with websockets.connect(uri) as websocket:
+        for face, image in zip(face_list, image_list):
+            # face is an np.float32 array
+            send = json.dumps({'action': 'verify', 'image': image.tolist(), 'MTCNN': face.tolist()})
+            await websocket.send(send)
+            recv = await websocket.recv()
+            data = json.loads(recv)
+            if data['status'] == 'success':
+                print(data['student_id'], 'is attend')
+            elif data['status'] == 'failed':
+                print('verification failed:', data['status'])
+
+def detect_face(frame):
+    results = mtcnn.detect(frame)
+    image_list = []
+    if results[1][0] is None:
+        return []
+    for box, prob in zip(results[0], results[1]):
+        if prob < 0.97:
+            continue
+        print('face detected. prob:', prob)
+        x1, y1, x2, y2 = box
+        # Crop the face region with a ±3 px margin.
+        image = frame[int(y1 - 3):int(y2 + 3), int(x1 - 3):int(x2 + 3)]
+        image_list.append(image)
+        print(image.shape)
+    return image_list
+
+def make_face_list(frame):
+    results, probs = mtcnn(frame, return_prob=True)
+    face_list = []
+    if probs[0] is None:
+        return []
+    for result, prob in zip(results, probs):
+        if prob < 0.97:
+            continue
+        # np.float32
+        face_list.append(result.cpu().numpy())
+    return face_list
+
+if __name__ == '__main__':
+    cap = cv2.VideoCapture(0, cv2.CAP_DSHOW)
+    cap.set(3, 720)  # frame width
+    cap.set(4, 480)  # frame height
+    cv2.namedWindow("img", cv2.WINDOW_NORMAL)
+    while True:
+        try:
+            ret, frame = cap.read()
+            cv2.imshow('img', frame)
+            cv2.waitKey(10)
+            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+            image_list = detect_face(frame)
+            if not image_list:
+                continue
+            face_list = make_face_list(frame)
+            ## send to the embedding server ##
+            if face_list:
+                asyncio.get_event_loop().run_until_complete(send_face(face_list, image_list))
+        except Exception as ex:
+            print(ex)
client/legacy/detection.py
0 → 100644
+import torch
+import numpy as np
+import cv2
+import matplotlib.pyplot as plt
+import os
+
+from PIL import Image, ImageDraw
+from IPython import display
+
+from models import mtcnn
+from models import inception_resnet_v1
+
+device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
+print('Running on device: {}'.format(device))
+
+def extract_face(filename, required_size=(224, 224)):
+    # If required, create a face detection pipeline using MTCNN:
+    mtcnn_model = mtcnn.MTCNN(keep_all=True, device=device)
+    pixels = plt.imread(os.path.join(os.path.abspath(''), filename))
+    results = mtcnn_model.detect(pixels)
+    face_array = []
+    for box, prob in zip(results[0], results[1]):
+        print('face detected. prob:', prob)
+        x1, y1, x2, y2 = box
+        face = pixels[int(y1):int(y2), int(x1):int(x2)]
+        image = Image.fromarray(face)
+        image = image.resize(required_size)
+        face_array.append(np.asarray(image))
+    return face_array
+
+face_array = extract_face('image/test1.jpg')
+for face in face_array:
+    plt.figure()
+    plt.imshow(face)
+    plt.show()
+
+face_array = extract_face('image/test2.jpg')
+for face in face_array:
+    plt.figure()
+    plt.imshow(face)
+    plt.show()
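
A side note on detection.py: ImageDraw is imported but never used, which suggests box visualization was planned. As a complement, here is a minimal sketch (not part of this commit) that draws the detected boxes on the source image, assuming the same mtcnn.MTCNN detect() output as extract_face above; draw_boxes is an illustrative helper name.

import numpy as np
import torch
from PIL import Image, ImageDraw

from models import mtcnn

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

def draw_boxes(filename):
    # Illustrative helper: overlay MTCNN detections on the image.
    image = Image.open(filename).convert('RGB')
    mtcnn_model = mtcnn.MTCNN(keep_all=True, device=device)
    boxes, probs = mtcnn_model.detect(np.asarray(image))
    draw = ImageDraw.Draw(image)
    if boxes is not None:
        for box in boxes:
            # box is [x1, y1, x2, y2] in pixel coordinates
            draw.rectangle(box.tolist(), outline=(255, 0, 0), width=3)
    return image

# draw_boxes('image/test1.jpg').show()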
client/legacy/image/test1.jpg
0 → 100644
337 KB
client/legacy/image/test2.jpg
0 → 100644
167 KB
client/models/data/onet.pt
0 → 100644
client/models/data/pnet.pt
0 → 100644
client/models/data/rnet.pt
0 → 100644
client/models/mtcnn.py
0 → 100644
client/models/utils/detect_face.py
0 → 100644
client/models/utils/tensorflow2pytorch.py
0 → 100644
client/models/utils/training.py
0 → 100644
+import torch
+import numpy as np
+import time
+
+
+class Logger(object):
+
+    def __init__(self, mode, length, calculate_mean=False):
+        self.mode = mode
+        self.length = length
+        self.calculate_mean = calculate_mean
+        if self.calculate_mean:
+            self.fn = lambda x, i: x / (i + 1)
+        else:
+            self.fn = lambda x, i: x
+
+    def __call__(self, loss, metrics, i):
+        track_str = '\r{} | {:5d}/{:<5d}| '.format(self.mode, i + 1, self.length)
+        loss_str = 'loss: {:9.4f} | '.format(self.fn(loss, i))
+        metric_str = ' | '.join('{}: {:9.4f}'.format(k, self.fn(v, i)) for k, v in metrics.items())
+        print(track_str + loss_str + metric_str + ' ', end='')
+        if i + 1 == self.length:
+            print('')
+
+
+class BatchTimer(object):
+    """Batch timing class.
+    Use this class for tracking training and testing time/rate per batch or per sample.
+
+    Keyword Arguments:
+        rate {bool} -- Whether to report a rate (batches or samples per second) or a time (seconds
+            per batch or sample). (default: {True})
+        per_sample {bool} -- Whether to report times or rates per sample or per batch.
+            (default: {True})
+    """
+
+    def __init__(self, rate=True, per_sample=True):
+        self.start = time.time()
+        self.end = None
+        self.rate = rate
+        self.per_sample = per_sample
+
+    def __call__(self, y_pred, y):
+        self.end = time.time()
+        elapsed = self.end - self.start
+        self.start = self.end
+        self.end = None
+
+        if self.per_sample:
+            elapsed /= len(y_pred)
+        if self.rate:
+            elapsed = 1 / elapsed
+
+        return torch.tensor(elapsed)
+
+
+def accuracy(logits, y):
+    _, preds = torch.max(logits, 1)
+    return (preds == y).float().mean()
+
+
+def pass_epoch(
+    model, loss_fn, loader, optimizer=None, scheduler=None,
+    batch_metrics={'time': BatchTimer()}, show_running=True,
+    device='cpu', writer=None
+):
+    """Train or evaluate over a data epoch.
+
+    Arguments:
+        model {torch.nn.Module} -- Pytorch model.
+        loss_fn {callable} -- A function to compute (scalar) loss.
+        loader {torch.utils.data.DataLoader} -- A pytorch data loader.
+
+    Keyword Arguments:
+        optimizer {torch.optim.Optimizer} -- A pytorch optimizer.
+        scheduler {torch.optim.lr_scheduler._LRScheduler} -- LR scheduler (default: {None})
+        batch_metrics {dict} -- Dictionary of metric functions to call on each batch. The default
+            is a simple timer. A progressive average of these metrics, along with the average
+            loss, is printed every batch. (default: {{'time': iter_timer()}})
+        show_running {bool} -- Whether to print losses and metrics as rolling averages (True)
+            or for the current batch only (False). (default: {True})
+        device {str or torch.device} -- Device for pytorch to use. (default: {'cpu'})
+        writer {torch.utils.tensorboard.SummaryWriter} -- Tensorboard SummaryWriter. (default: {None})
+
+    Returns:
+        tuple(torch.Tensor, dict) -- A tuple of the average loss and a dictionary of average
+            metric values across the epoch.
+    """
+
+    mode = 'Train' if model.training else 'Valid'
+    logger = Logger(mode, length=len(loader), calculate_mean=show_running)
+    loss = 0
+    metrics = {}
+
+    for i_batch, (x, y) in enumerate(loader):
+        x = x.to(device)
+        y = y.to(device)
+        y_pred = model(x)
+        loss_batch = loss_fn(y_pred, y)
+
+        if model.training:
+            loss_batch.backward()
+            optimizer.step()
+            optimizer.zero_grad()
+
+        metrics_batch = {}
+        for metric_name, metric_fn in batch_metrics.items():
+            metrics_batch[metric_name] = metric_fn(y_pred, y).detach().cpu()
+            metrics[metric_name] = metrics.get(metric_name, 0) + metrics_batch[metric_name]
+
+        if writer is not None and model.training:
+            if writer.iteration % writer.interval == 0:
+                writer.add_scalars('loss', {mode: loss_batch.detach().cpu()}, writer.iteration)
+                for metric_name, metric_batch in metrics_batch.items():
+                    writer.add_scalars(metric_name, {mode: metric_batch}, writer.iteration)
+            writer.iteration += 1
+
+        loss_batch = loss_batch.detach().cpu()
+        loss += loss_batch
+        if show_running:
+            logger(loss, metrics, i_batch)
+        else:
+            logger(loss_batch, metrics_batch, i_batch)
+
+    if model.training and scheduler is not None:
+        scheduler.step()
+
+    loss = loss / (i_batch + 1)
+    metrics = {k: v / (i_batch + 1) for k, v in metrics.items()}
+
+    if writer is not None and not model.training:
+        writer.add_scalars('loss', {mode: loss.detach()}, writer.iteration)
+        for metric_name, metric in metrics.items():
+            writer.add_scalars(metric_name, {mode: metric})
+
+    return loss, metrics
+
+
+def collate_pil(x):
+    out_x, out_y = [], []
+    for xx, yy in x:
+        out_x.append(xx)
+        out_y.append(yy)
+    return out_x, out_y
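
training.py is a general-purpose epoch runner, but nothing in this diff shows it being called. For context, here is a hedged usage sketch: the dataset path is hypothetical, and InceptionResnetV1 assumes the repo's models/inception_resnet_v1 module (imported by detection.py above) mirrors the facenet_pytorch class of that name.

import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

from models.inception_resnet_v1 import InceptionResnetV1  # assumed to mirror facenet_pytorch
from models.utils.training import BatchTimer, accuracy, pass_epoch

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Hypothetical folder of aligned face crops, one subdirectory per identity.
dataset = datasets.ImageFolder('data/faces', transform=transforms.ToTensor())
loader = DataLoader(dataset, batch_size=32, shuffle=True)

model = InceptionResnetV1(classify=True, num_classes=len(dataset.classes)).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()

model.train()  # pass_epoch infers Train/Valid mode from model.training
loss, metrics = pass_epoch(
    model, loss_fn, loader,
    optimizer=optimizer,
    batch_metrics={'time': BatchTimer(), 'acc': accuracy},
    device=device,
)
print(loss, metrics)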