verification server.py
1.42 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import torch
import numpy as np
import os
import asyncio
import json
import websockets
from io import BytesIO
from PIL import Image, ImageDraw
from IPython import display
from models.mtcnn import MTCNN
from models.inception_resnet_v1 import InceptionResnetV1
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('Running on device: {}'.format(device))
model = InceptionResnetV1().eval().to(device)
async def get_embeddings(face_list):
global model
x = torch.Tensor(face_list).to(device)
yhat = model(x)
return yhat
def get_distance(someone, database):
distance = [(someone - data).norm().item() for data in database]
return distance
def get_argmin(someone, database):
distance = get_distance(someone, database)
for i in range(len(distance)):
return np.argmin(distance)
return -1
async def recv_face(websocket, path):
buf = await websocket.recv()
face = np.frombuffer(buf, dtype = np.float32)
face = face.reshape((1,3,160,160))
remote_ip = websocket.remote_address[0]
msg='[{ip}] receive face properly, numpy shape={shape}'.format(ip=remote_ip, shape=face.shape)
print(msg)
embedding = await get_embeddings(face)
await websocket.send('100')
##embedding DB서버에 넘기기##
print('run verification server')
start_server = websockets.serve(recv_face, '0.0.0.0', 8765)
asyncio.get_event_loop().run_until_complete(start_server)
asyncio.get_event_loop().run_forever()