Graduate

Commit for merge

import torch
import numpy as np
import time


class Logger(object):

    def __init__(self, mode, length, calculate_mean=False):
        self.mode = mode
        self.length = length
        self.calculate_mean = calculate_mean
        if self.calculate_mean:
            self.fn = lambda x, i: x / (i + 1)
        else:
            self.fn = lambda x, i: x

    def __call__(self, loss, metrics, i):
        track_str = '\r{} | {:5d}/{:<5d}| '.format(self.mode, i + 1, self.length)
        loss_str = 'loss: {:9.4f} | '.format(self.fn(loss, i))
        metric_str = ' | '.join('{}: {:9.4f}'.format(k, self.fn(v, i)) for k, v in metrics.items())
        print(track_str + loss_str + metric_str + ' ', end='')
        if i + 1 == self.length:
            print('')
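# Usage sketch (illustrative, not part of the original file): with
# calculate_mean=True the logger divides whatever totals it is handed by the
# number of batches seen so far, so cumulative sums print as running averages.
#
#   logger = Logger('Train', length=100, calculate_mean=True)
#   logger(total_loss, {'acc': total_acc}, i_batch)  # prints averaged values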


class BatchTimer(object):
    """Batch timing class.
    Use this class for tracking training and testing time/rate per batch or per sample.

    Keyword Arguments:
        rate {bool} -- Whether to report a rate (batches or samples per second) or a time (seconds
            per batch or sample). (default: {True})
        per_sample {bool} -- Whether to report times or rates per sample or per batch.
            (default: {True})
    """

    def __init__(self, rate=True, per_sample=True):
        self.start = time.time()
        self.end = None
        self.rate = rate
        self.per_sample = per_sample

    def __call__(self, y_pred, y):
        self.end = time.time()
        elapsed = self.end - self.start
        self.start = self.end
        self.end = None

        if self.per_sample:
            elapsed /= len(y_pred)
        if self.rate:
            elapsed = 1 / elapsed

        return torch.tensor(elapsed)
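# Usage sketch (illustrative): BatchTimer has the same (y_pred, y) call
# signature as a metric such as accuracy() below, so it can sit in the same
# batch_metrics dict. With the defaults it reports samples per second elapsed
# since its previous call.
#
#   timer = BatchTimer(rate=True, per_sample=True)
#   y_pred = torch.zeros(32, 10)
#   y = torch.zeros(32, dtype=torch.long)
#   print(timer(y_pred, y))  # tensor(...) in samples/sec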


def accuracy(logits, y):
    _, preds = torch.max(logits, 1)
    return (preds == y).float().mean()


def pass_epoch(
    model, loss_fn, loader, optimizer=None, scheduler=None,
    batch_metrics={'time': BatchTimer()}, show_running=True,
    device='cpu', writer=None
):
    """Train or evaluate over a data epoch.

    Arguments:
        model {torch.nn.Module} -- Pytorch model.
        loss_fn {callable} -- A function to compute (scalar) loss.
        loader {torch.utils.data.DataLoader} -- A pytorch data loader.

    Keyword Arguments:
        optimizer {torch.optim.Optimizer} -- A pytorch optimizer. (default: {None})
        scheduler {torch.optim.lr_scheduler._LRScheduler} -- LR scheduler. (default: {None})
        batch_metrics {dict} -- Dictionary of metric functions to call on each batch. The default
            is a simple timer. A progressive average of these metrics, along with the average
            loss, is printed every batch. (default: {{'time': BatchTimer()}})
        show_running {bool} -- Whether to print rolling averages of the loss and metrics (True)
            or the values for the current batch only (False). (default: {True})
        device {str or torch.device} -- Device for pytorch to use. (default: {'cpu'})
        writer {torch.utils.tensorboard.SummaryWriter} -- Tensorboard SummaryWriter. (default: {None})

    Returns:
        tuple(torch.Tensor, dict) -- A tuple of the average loss and a dictionary of average
            metric values across the epoch.
    """

    mode = 'Train' if model.training else 'Valid'
    logger = Logger(mode, length=len(loader), calculate_mean=show_running)
    loss = 0
    metrics = {}

    for i_batch, (x, y) in enumerate(loader):
        x = x.to(device)
        y = y.to(device)
        y_pred = model(x)
        loss_batch = loss_fn(y_pred, y)

        if model.training:
            loss_batch.backward()
            optimizer.step()
            optimizer.zero_grad()

        metrics_batch = {}
        for metric_name, metric_fn in batch_metrics.items():
            metrics_batch[metric_name] = metric_fn(y_pred, y).detach().cpu()
            metrics[metric_name] = metrics.get(metric_name, 0) + metrics_batch[metric_name]

        if writer is not None and model.training:
            if writer.iteration % writer.interval == 0:
                writer.add_scalars('loss', {mode: loss_batch.detach().cpu()}, writer.iteration)
                for metric_name, metric_batch in metrics_batch.items():
                    writer.add_scalars(metric_name, {mode: metric_batch}, writer.iteration)
            writer.iteration += 1

        loss_batch = loss_batch.detach().cpu()
        loss += loss_batch
        if show_running:
            logger(loss, metrics, i_batch)
        else:
            logger(loss_batch, metrics_batch, i_batch)

    if model.training and scheduler is not None:
        scheduler.step()

    loss = loss / (i_batch + 1)
    metrics = {k: v / (i_batch + 1) for k, v in metrics.items()}

    if writer is not None and not model.training:
        writer.add_scalars('loss', {mode: loss.detach()}, writer.iteration)
        for metric_name, metric in metrics.items():
            writer.add_scalars(metric_name, {mode: metric}, writer.iteration)

    return loss, metrics
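# Usage sketch (illustrative; names like train_loader and optimizer are
# assumed). Note that pass_epoch reads and increments two non-standard
# attributes on the SummaryWriter, `iteration` and `interval`, so the caller
# must set them before the first call.
#
#   from torch.utils.tensorboard import SummaryWriter
#   writer = SummaryWriter()
#   writer.iteration = 0
#   writer.interval = 10
#
#   model.train()
#   loss, metrics = pass_epoch(
#       model, torch.nn.CrossEntropyLoss(), train_loader,
#       optimizer=optimizer, scheduler=scheduler,
#       batch_metrics={'acc': accuracy, 'time': BatchTimer()},
#       device=device, writer=writer,
#   )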


def collate_pil(x):
    # Collate function that keeps PIL images (or other non-tensor items) in
    # plain lists rather than stacking them into a batch tensor.
    out_x, out_y = [], []
    for xx, yy in x:
        out_x.append(xx)
        out_y.append(yy)
    return out_x, out_y
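# Usage sketch (illustrative): because collate_pil returns plain lists, a
# DataLoader built with it can yield variable-size PIL images, e.g. to feed a
# face detector one image at a time.
#
#   loader = torch.utils.data.DataLoader(dataset, collate_fn=collate_pil)
#   for imgs, targets in loader:
#       for img in imgs:
#           boxes, probs = mtcnn.detect(img)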
##################################################
# 1. Detect faces from the webcam
# 2. Register the detected faces
##################################################
import torch
import numpy as np
import cv2
import asyncio
import websockets
import json
import os
import timeit
import base64

from PIL import Image
from io import BytesIO
import requests

from models.mtcnn import MTCNN

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('Running on device: {}'.format(device))

mtcnn = MTCNN(keep_all=True, device=device)

uri = 'ws://169.56.95.131:8765'
async def send_face(face_list, image_list):
    global uri
    async with websockets.connect(uri) as websocket:
        for face, image in zip(face_list, image_list):
            # face: np.float32 array holding the cropped face data
            send = json.dumps({'action': 'register', 'student_id': '2014104149', 'student_name': '정해갑', 'MTCNN': face.tolist()})
            await websocket.send(send)
            recv = await websocket.recv()
            data = json.loads(recv)
            if data['status'] == 'success':
                print(data['student_id'], 'is registered')

def detect_face(frame):
    results = mtcnn.detect(frame)
    faces = mtcnn(frame, return_prob=False)
    image_list = []
    face_list = []
    if results[1][0] is None:
        return [], []
    for box, face, prob in zip(results[0], faces, results[1]):
        if prob < 0.97:
            continue
        print('face detected. prob:', prob)
        x1, y1, x2, y2 = box
        if (x2 - x1) * (y2 - y1) < 15000:
            # skip faces whose resolution is too low
            continue
        # save the face region with a ±3 pixel margin (clamped to the frame)
        image = frame[max(0, int(y1 - 3)):int(y2 + 3), max(0, int(x1 - 3)):int(x2 + 3)]
        image_list.append(image)
        # store the MTCNN face tensor (moved to CPU before converting to numpy)
        face_list.append(face.cpu().numpy())
    return image_list, face_list

cap = cv2.VideoCapture(0, cv2.CAP_DSHOW)
cap.set(cv2.CAP_PROP_FRAME_WIDTH, 720)
cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 480)
ret, frame = cap.read()
cap.release()
if not ret:
    raise RuntimeError('failed to read a frame from the webcam')
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
image_list, face_list = detect_face(frame)
if face_list:
    asyncio.get_event_loop().run_until_complete(send_face(face_list, image_list))
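# Server counterpart (hypothetical sketch; the actual service behind `uri` is
# not part of this commit). It accepts the 'register' message sent above and
# replies with the {'status': 'success', 'student_id': ...} shape this client
# checks for.
#
#   import asyncio
#   import json
#   import websockets
#
#   async def handler(websocket):
#       async for message in websocket:
#           data = json.loads(message)
#           if data['action'] == 'register':
#               # persist data['MTCNN'] under data['student_id'] here
#               await websocket.send(json.dumps(
#                   {'status': 'success', 'student_id': data['student_id']}))
#
#   async def main():
#       async with websockets.serve(handler, '0.0.0.0', 8765):
#           await asyncio.Future()  # run forever
#
#   asyncio.run(main())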