Showing
4 changed files
with
80 additions
and
17 deletions
No preview for this file type
No preview for this file type
No preview for this file type
1 | +import os | ||
1 | import torch | 2 | import torch |
2 | -import torch.multiprocessing as mp | ||
3 | import numpy as np | 3 | import numpy as np |
4 | -import os | ||
5 | import asyncio | 4 | import asyncio |
6 | import json | 5 | import json |
7 | import base64 | 6 | import base64 |
8 | import websockets | 7 | import websockets |
9 | from io import BytesIO | 8 | from io import BytesIO |
10 | 9 | ||
10 | +import pymysql | ||
11 | +import datetime | ||
12 | + | ||
11 | from PIL import Image, ImageDraw | 13 | from PIL import Image, ImageDraw |
12 | from IPython import display | 14 | from IPython import display |
13 | 15 | ||
... | @@ -18,6 +20,13 @@ device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') | ... | @@ -18,6 +20,13 @@ device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') |
18 | print('Running on device: {}'.format(device)) | 20 | print('Running on device: {}'.format(device)) |
19 | 21 | ||
20 | model = InceptionResnetV1().eval().to(device) | 22 | model = InceptionResnetV1().eval().to(device) |
23 | +attendance_db = pymysql.connect( | ||
24 | + user='root', | ||
25 | + passwd='5978', | ||
26 | + host='localhost', | ||
27 | + db='attendance', | ||
28 | + charset='utf8' | ||
29 | +) | ||
21 | 30 | ||
22 | lock = asyncio.Lock() | 31 | lock = asyncio.Lock() |
23 | clients = set() | 32 | clients = set() |
... | @@ -29,8 +38,8 @@ async def get_embeddings(face_list): | ... | @@ -29,8 +38,8 @@ async def get_embeddings(face_list): |
29 | yhat = model(x) | 38 | yhat = model(x) |
30 | return yhat | 39 | return yhat |
31 | 40 | ||
32 | -def get_distance(someone, database): | 41 | +async def get_distance(arr1, arr2): |
33 | - distance = [(someone - data).norm().item() for data in database] | 42 | + distance = (arr1 - arr2).norm().item() |
34 | return distance | 43 | return distance |
35 | 44 | ||
36 | def get_argmin(someone, database): | 45 | def get_argmin(someone, database): |
... | @@ -64,29 +73,83 @@ async def thread(websocket, path): | ... | @@ -64,29 +73,83 @@ async def thread(websocket, path): |
64 | # await websocket.send(state_event()) | 73 | # await websocket.send(state_event()) |
65 | async for message in websocket: | 74 | async for message in websocket: |
66 | data = json.loads(message) | 75 | data = json.loads(message) |
67 | - if data['action'] == 'register': | ||
68 | - # data['id'] | ||
69 | - face = np.asarray(data['MTCNN'], dtype = np.float32) | ||
70 | - face = face.reshape((1,3,160,160)) | ||
71 | remote_ip = websocket.remote_address[0] | 76 | remote_ip = websocket.remote_address[0] |
77 | + if data['action'] == 'register': | ||
78 | + # log | ||
72 | msg='[{ip}] register face'.format(ip=remote_ip) | 79 | msg='[{ip}] register face'.format(ip=remote_ip) |
73 | print(msg) | 80 | print(msg) |
74 | - embedding = await get_embeddings(face) | 81 | + |
75 | - await websocket.send('registered') | 82 | + # load json |
76 | - #await notify_state() | 83 | + student_id = data['student_id'] |
77 | - elif data['action'] == "verify": | 84 | + student_name = data['student_name'] |
78 | face = np.asarray(data['MTCNN'], dtype = np.float32) | 85 | face = np.asarray(data['MTCNN'], dtype = np.float32) |
79 | - print(face.shape) | ||
80 | face = face.reshape((1,3,160,160)) | 86 | face = face.reshape((1,3,160,160)) |
81 | - remote_ip = websocket.remote_address[0] | 87 | + |
88 | + # DB에 연결 | ||
89 | + cursor = attendance_db.cursor(pymysql.cursors.DictCursor) | ||
90 | + | ||
91 | + # 학생을 찾음 | ||
92 | + sql = "SELECT student_id FROM student WHERE student_id = %s;" | ||
93 | + cursor.execute(sql, (student_id)) | ||
94 | + | ||
95 | + # DB에 학생이 없으면 등록 | ||
96 | + if not cursor.fetchone(): | ||
97 | + sql = "insert into student(student_id, student_name) values (%s, %s)" | ||
98 | + cursor.execute(sql, (student_id, student_name)) | ||
99 | + attendance_db.commit() | ||
100 | + | ||
101 | + # student_embedding Table에 등록 | ||
102 | + embedding = await get_embeddings(face) | ||
103 | + embedding = embedding.detach().numpy().tobytes() | ||
104 | + embedding_date = datetime.now().strftime('%Y-%m-%d %H:%M:%S') | ||
105 | + sql = "insert into student_embedding(student_id, embedding_date, embedding) values (%s, %s, %s)" | ||
106 | + cursor.execute(sql, (student_id, embedding_date, embedding)) | ||
107 | + attendance_db.commit() | ||
108 | + await websocket.send('{id} registered'.format(id=student_id)) | ||
109 | + elif data['action'] == "verify": | ||
110 | + # log | ||
82 | msg='[{ip}] verify face'.format(ip=remote_ip) | 111 | msg='[{ip}] verify face'.format(ip=remote_ip) |
83 | print(msg) | 112 | print(msg) |
113 | + ############### | ||
114 | + | ||
115 | + # load json | ||
116 | + face = np.asarray(data['MTCNN'], dtype = np.float32) | ||
117 | + face = face.reshape((1,3,160,160)) | ||
118 | + | ||
119 | + # embedding 구하기 | ||
84 | embedding = await get_embeddings(face) | 120 | embedding = await get_embeddings(face) |
85 | - # Todo: 아래 embedding.numpy()를 데이터베이스에 저장해야함. | 121 | + embedding = embedding.detach().numpy() |
86 | # embedding.numpy() | 122 | # embedding.numpy() |
87 | # [1, 512] numpy()임 | 123 | # [1, 512] numpy()임 |
88 | - # np.bytes() 명령으로 바꾼 뒤 np.frombuffer()로 불러오는 것이 좋을 듯. | 124 | + # np.frombuffer()로 불러오는 것이 좋을 듯. |
89 | - await websocket.send('정해갑') | 125 | + # DB에 연결 |
126 | + cursor = attendance_db.cursor(pymysql.cursors.DictCursor) | ||
127 | + | ||
128 | + # 학생을 찾음 | ||
129 | + sql = "SELECT student_id, embedding FROM student_embedding;" | ||
130 | + cursor.execute(sql) | ||
131 | + result = cursor.fetchall() | ||
132 | + verified_id = '0000000000' | ||
133 | + distance_min = 1 | ||
134 | + for row_data in result: | ||
135 | + db_embedding = np.frombuffer(row_data['embedding'], dtype=np.float32) | ||
136 | + db_embedding = db_embedding.reshape((1,512)) | ||
137 | + distance = get_distance(embedding, db_embedding) | ||
138 | + if (distance < distance_min): | ||
139 | + verified_id = row_data['student_id'] | ||
140 | + distance_min = distance | ||
141 | + | ||
142 | + # 출석 데이터 전송 | ||
143 | + data = '' | ||
144 | + if distance_min >= 0.6: | ||
145 | + # 해당하는 사람 DB에 없음 | ||
146 | + print('verification failed: not in DB') | ||
147 | + data = json.dumps({'state': 'fail'}) | ||
148 | + else: | ||
149 | + # 해당하는 사람 DB에 있음 | ||
150 | + print('verification success:', verified_id) | ||
151 | + data = json.dumps({'state': 'success', 'id': verified_id}) | ||
152 | + await websocket.send(data) | ||
90 | elif data['action'] == "save_image": | 153 | elif data['action'] == "save_image": |
91 | # 출석이 제대로 이뤄지지 않으면 이미지를 저장하여 | 154 | # 출석이 제대로 이뤄지지 않으면 이미지를 저장하여 |
92 | # 나중에 교강사가 출석을 확인할 수 있도록 한다 | 155 | # 나중에 교강사가 출석을 확인할 수 있도록 한다 | ... | ... |
-
Please register or login to post a comment