client.py
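# Mosaic worker client: polls the server at SERVER_CHECK_ENDPOINT for a pending
# video, downloads it, runs detection with the local weight.pt using YOLOv5-style
# utilities, pixelates every detected region, remuxes the processed frames with
# the original audio via ffmpeg, and uploads the resulting output.ts back to the
# server.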
import json
import os
import subprocess
import time

import cv2
import requests
import torch

from models.experimental import attempt_load
from utils.datasets import LoadImages
from utils.general import check_img_size, non_max_suppression, set_logging, scale_coords
from utils.torch_utils import select_device, time_synchronized

SERVER_CHECK_ENDPOINT = 'http://mosaic.khunet.net'
WEIGHT_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'weight.pt')
INPUT_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'input.mp4')
OUTPUT_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'output.ts')

def download(url, file_path):
    # Fetch the source video from the server and write it to disk.
    res = requests.get(url)
    res.raise_for_status()  # do not write an error response to disk
    with open(file_path, 'wb') as file:
        file.write(res.content)
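
# mosaic() pixelates a region: the crop is shrunk to `ratio` of its size and
# scaled back up with nearest-neighbour interpolation, so fine detail is lost.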
def mosaic(src, ratio=0.07):
    small = cv2.resize(src, None, fx=ratio, fy=ratio)
    return cv2.resize(small, src.shape[:2][::-1], interpolation=cv2.INTER_NEAREST)
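
# detect() runs a YOLOv5-style detection loop over input_path, pixelates every
# detected region and pipes the raw BGR frames to an ffmpeg child process over
# stdin; ffmpeg muxes those frames with the audio track of the original file
# into output_path. The pipe is declared as 1280x720 at 30 fps below, so the
# input video is expected to match that resolution and frame rate.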
@torch.no_grad()
def detect(weight_path, input_path, output_path):
    command = ['ffmpeg',
               '-loglevel', 'panic',
               '-y',
               '-f', 'rawvideo',
               '-pixel_format', 'bgr24',
               '-video_size', "{}x{}".format(1280, 720),
               '-framerate', str(30),
               '-i', '-',
               '-i', input_path,
               '-c:a', 'copy',
               '-map', '0:v:0',
               '-map', '1:a:0',
               '-c:v', 'libx264',
               '-pix_fmt', 'yuv420p',
               '-preset', 'ultrafast',
               output_path]
    writer = subprocess.Popen(command, stdin=subprocess.PIPE)
    source, weights, imgsz = input_path, weight_path, 640

    # Initialize
    set_logging()
    device = select_device('')

    # Load model
    model = attempt_load(weights, map_location=device)  # load FP32 model
    stride = int(model.stride.max())  # model stride
    imgsz = check_img_size(imgsz, s=stride)  # check img_size
    names = model.module.names if hasattr(model, 'module') else model.names  # get class names

    # Set Dataloader
    dataset = LoadImages(source, img_size=imgsz, stride=stride)

    # Run inference
    if device.type != 'cpu':
        model(torch.zeros(1, 3, imgsz, imgsz).to(device).type_as(next(model.parameters())))  # run once
    t0 = time.time()
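    # Per frame: convert the letterboxed image to a float tensor, run the model,
    # apply NMS, rescale the boxes to the original frame, pixelate each box and
    # pipe the frame to ffmpeg.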
    for path, img, im0s, vid_cap in dataset:
        img = torch.from_numpy(img).to(device)
        img = img.float()  # uint8 to fp16/32
        img /= 255.0  # 0 - 255 to 0.0 - 1.0
        if img.ndimension() == 3:
            img = img.unsqueeze(0)

        # Inference
        t1 = time_synchronized()
        pred = model(img, augment=False)[0]

        # Apply NMS
        pred = non_max_suppression(pred, max_det=1000)
        t2 = time_synchronized()

        # Process detections
        for i, det in enumerate(pred):  # detections per image
            p, s, im0, frame = path, '', im0s.copy(), getattr(dataset, 'frame', 0)
            s += '%gx%g ' % img.shape[2:]  # print string
            if len(det):
                # Rescale boxes from img_size to im0 size
                det[:, :4] = scale_coords(img.shape[2:], det[:, :4], im0.shape).round()

                # Print results
                for c in det[:, -1].unique():
                    n = (det[:, -1] == c).sum()  # detections per class
                    s += f"{n} {names[int(c)]}{'s' * (n > 1)}, "  # add to string

                # Write results
                for *xyxy, conf, cls in reversed(det):
                    x1, y1, x2, y2 = int(xyxy[0]), int(xyxy[1]), int(xyxy[2]), int(xyxy[3])
                    # Pixelate the detected region in place
                    im0[y1:y2, x1:x2] = mosaic(im0[y1:y2, x1:x2])

            # Print time (inference + NMS)
            # print(f'{s}Done. ({t2 - t1:.3f}s)')

            # Pipe the processed frame to ffmpeg (runs for every frame, with or without detections)
            writer.stdin.write(im0.tobytes())

    writer.stdin.close()
    writer.wait()
    print(f'Done. ({time.time() - t0:.3f}s)')
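
# Polling loop: GET /check returns JSON whose 'data' field is either a file name
# to process or null. When a name is returned, the original video is fetched
# from /origin/<name>, processed locally, and the result is POSTed to /upload
# under the same name.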
if __name__ == '__main__':
    while True:
        try:
            response = requests.get(SERVER_CHECK_ENDPOINT + '/check')
            data = json.loads(response.text)
            if data['data'] is not None:
                download(SERVER_CHECK_ENDPOINT + '/origin/' + data['data'], INPUT_PATH)
                detect(WEIGHT_PATH, INPUT_PATH, OUTPUT_PATH)
                with open(OUTPUT_PATH, 'rb') as result:
                    response = requests.post(SERVER_CHECK_ENDPOINT + '/upload', files={'file': (data['data'], result)})
                print(data['data'] + ' : ' + response.text)
        except Exception as e:
            print('Error!', e)
        time.sleep(0.5)  # poll interval