Graduate

Connect DB

1 +import os
1 import torch 2 import torch
2 -import torch.multiprocessing as mp
3 import numpy as np 3 import numpy as np
4 -import os
5 import asyncio 4 import asyncio
6 import json 5 import json
7 import base64 6 import base64
8 import websockets 7 import websockets
9 from io import BytesIO 8 from io import BytesIO
10 9
10 +import pymysql
11 +import datetime
12 +
11 from PIL import Image, ImageDraw 13 from PIL import Image, ImageDraw
12 from IPython import display 14 from IPython import display
13 15
...@@ -18,6 +20,13 @@ device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') ...@@ -18,6 +20,13 @@ device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
18 print('Running on device: {}'.format(device)) 20 print('Running on device: {}'.format(device))
19 21
20 model = InceptionResnetV1().eval().to(device) 22 model = InceptionResnetV1().eval().to(device)
23 +attendance_db = pymysql.connect(
24 + user='root',
25 + passwd='5978',
26 + host='localhost',
27 + db='attendance',
28 + charset='utf8'
29 +)
21 30
22 lock = asyncio.Lock() 31 lock = asyncio.Lock()
23 clients = set() 32 clients = set()
...@@ -29,8 +38,8 @@ async def get_embeddings(face_list): ...@@ -29,8 +38,8 @@ async def get_embeddings(face_list):
29 yhat = model(x) 38 yhat = model(x)
30 return yhat 39 return yhat
31 40
32 -def get_distance(someone, database): 41 +async def get_distance(arr1, arr2):
33 - distance = [(someone - data).norm().item() for data in database] 42 + distance = (arr1 - arr2).norm().item()
34 return distance 43 return distance
35 44
36 def get_argmin(someone, database): 45 def get_argmin(someone, database):
...@@ -64,29 +73,83 @@ async def thread(websocket, path): ...@@ -64,29 +73,83 @@ async def thread(websocket, path):
64 # await websocket.send(state_event()) 73 # await websocket.send(state_event())
65 async for message in websocket: 74 async for message in websocket:
66 data = json.loads(message) 75 data = json.loads(message)
67 - if data['action'] == 'register':
68 - # data['id']
69 - face = np.asarray(data['MTCNN'], dtype = np.float32)
70 - face = face.reshape((1,3,160,160))
71 remote_ip = websocket.remote_address[0] 76 remote_ip = websocket.remote_address[0]
77 + if data['action'] == 'register':
78 + # log
72 msg='[{ip}] register face'.format(ip=remote_ip) 79 msg='[{ip}] register face'.format(ip=remote_ip)
73 print(msg) 80 print(msg)
74 - embedding = await get_embeddings(face) 81 +
75 - await websocket.send('registered') 82 + # load json
76 - #await notify_state() 83 + student_id = data['student_id']
77 - elif data['action'] == "verify": 84 + student_name = data['student_name']
78 face = np.asarray(data['MTCNN'], dtype = np.float32) 85 face = np.asarray(data['MTCNN'], dtype = np.float32)
79 - print(face.shape)
80 face = face.reshape((1,3,160,160)) 86 face = face.reshape((1,3,160,160))
81 - remote_ip = websocket.remote_address[0] 87 +
88 + # DB에 연결
89 + cursor = attendance_db.cursor(pymysql.cursors.DictCursor)
90 +
91 + # 학생을 찾음
92 + sql = "SELECT student_id FROM student WHERE student_id = %s;"
93 + cursor.execute(sql, (student_id))
94 +
95 + # DB에 학생이 없으면 등록
96 + if not cursor.fetchone():
97 + sql = "insert into student(student_id, student_name) values (%s, %s)"
98 + cursor.execute(sql, (student_id, student_name))
99 + attendance_db.commit()
100 +
101 + # student_embedding Table에 등록
102 + embedding = await get_embeddings(face)
103 + embedding = embedding.detach().numpy().tobytes()
104 + embedding_date = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
105 + sql = "insert into student_embedding(student_id, embedding_date, embedding) values (%s, %s, %s)"
106 + cursor.execute(sql, (student_id, embedding_date, embedding))
107 + attendance_db.commit()
108 + await websocket.send('{id} registered'.format(id=student_id))
109 + elif data['action'] == "verify":
110 + # log
82 msg='[{ip}] verify face'.format(ip=remote_ip) 111 msg='[{ip}] verify face'.format(ip=remote_ip)
83 print(msg) 112 print(msg)
113 + ###############
114 +
115 + # load json
116 + face = np.asarray(data['MTCNN'], dtype = np.float32)
117 + face = face.reshape((1,3,160,160))
118 +
119 + # embedding 구하기
84 embedding = await get_embeddings(face) 120 embedding = await get_embeddings(face)
85 - # Todo: 아래 embedding.numpy()를 데이터베이스에 저장해야함. 121 + embedding = embedding.detach().numpy()
86 # embedding.numpy() 122 # embedding.numpy()
87 # [1, 512] numpy()임 123 # [1, 512] numpy()임
88 - # np.bytes() 명령으로 바꾼 뒤 np.frombuffer()로 불러오는 것이 좋을 듯. 124 + # np.frombuffer()로 불러오는 것이 좋을 듯.
89 - await websocket.send('정해갑') 125 + # DB에 연결
126 + cursor = attendance_db.cursor(pymysql.cursors.DictCursor)
127 +
128 + # 학생을 찾음
129 + sql = "SELECT student_id, embedding FROM student_embedding;"
130 + cursor.execute(sql)
131 + result = cursor.fetchall()
132 + verified_id = '0000000000'
133 + distance_min = 1
134 + for row_data in result:
135 + db_embedding = np.frombuffer(row_data['embedding'], dtype=np.float32)
136 + db_embedding = db_embedding.reshape((1,512))
137 + distance = get_distance(embedding, db_embedding)
138 + if (distance < distance_min):
139 + verified_id = row_data['student_id']
140 + distance_min = distance
141 +
142 + # 출석 데이터 전송
143 + data = ''
144 + if distance_min >= 0.6:
145 + # 해당하는 사람 DB에 없음
146 + print('verification failed: not in DB')
147 + data = json.dumps({'state': 'fail'})
148 + else:
149 + # 해당하는 사람 DB에 있음
150 + print('verification success:', verified_id)
151 + data = json.dumps({'state': 'success', 'id': verified_id})
152 + await websocket.send(data)
90 elif data['action'] == "save_image": 153 elif data['action'] == "save_image":
91 # 출석이 제대로 이뤄지지 않으면 이미지를 저장하여 154 # 출석이 제대로 이뤄지지 않으면 이미지를 저장하여
92 # 나중에 교강사가 출석을 확인할 수 있도록 한다 155 # 나중에 교강사가 출석을 확인할 수 있도록 한다
......