이민규

add lambda code

+#*****************************************************
+#                                                     *
+# Copyright 2018 Amazon.com, Inc. or its affiliates.  *
+# All Rights Reserved.                                *
+#                                                     *
+#*****************************************************
+""" A sample lambda for bird detection"""
+from threading import Thread, Event
+import os
+import json
+import numpy as np
+import awscam
+import cv2
+import mo
+import greengrasssdk
+
+class LocalDisplay(Thread):
+    """ Class for facilitating the local display of inference results
+        (as images). The class is designed to run on its own thread. In
+        particular the class dumps the inference results into a FIFO
+        located in the tmp directory (which lambda has access to). The
+        results can be rendered using mplayer by typing:
+        mplayer -demuxer lavf -lavfdopts format=mjpeg:probesize=32 /tmp/results.mjpeg
+    """
+    def __init__(self, resolution):
+        """ resolution - Desired resolution of the project stream """
+        super(LocalDisplay, self).__init__()
+        # List of valid resolutions
+        RESOLUTION = {'1080p': (1920, 1080), '720p': (1280, 720), '480p': (858, 480)}
+        if resolution not in RESOLUTION:
+            raise Exception("Invalid resolution")
+        self.resolution = RESOLUTION[resolution]
+        # Initialize the default image to be a white canvas. Clients
+        # will update the image when ready.
+        self.frame = cv2.imencode('.jpg', 255 * np.ones([640, 480, 3], dtype=np.uint8))[1]
+        self.stop_request = Event()
+
+    def run(self):
+        """ Overridden method that continually dumps images to the desired
+            FIFO file.
+        """
+        # Path to the FIFO file. The lambda only has permissions to the tmp
+        # directory. Pointing to a FIFO file in another directory
+        # will cause the lambda to crash.
+        result_path = '/tmp/results.mjpeg'
+        # Create the FIFO file if it doesn't exist.
+        if not os.path.exists(result_path):
+            os.mkfifo(result_path)
+        # This call will block until a consumer is available
+        with open(result_path, 'w') as fifo_file:
+            while not self.stop_request.isSet():
+                try:
+                    # Write the data to the FIFO file. This call will block
+                    # meaning the code will come to a halt here until a consumer
+                    # is available.
+                    fifo_file.write(self.frame.tobytes())
+                except IOError:
+                    continue
+
+    def set_frame_data(self, frame):
+        """ Method updates the image data. This currently encodes the
+            numpy array to jpg but can be modified to support other encodings.
+            frame - Numpy array containing the image data of the next frame
+                    in the project stream.
+        """
+        ret, jpeg = cv2.imencode('.jpg', cv2.resize(frame, self.resolution))
+        if not ret:
+            raise Exception('Failed to set frame data')
+        self.frame = jpeg
+
+    def join(self):
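+        """ Overridden join that signals the run() loop to stop writing frames. """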
+        self.stop_request.set()
+
+def infinite_infer_run():
+    """ Entry point of the lambda function"""
+    try:
+        # The label set for this model is quite large, so load it into a list to map
+        # the machine labels to human readable labels.
+        model_type = 'classification'
+        with open('caltech256_labels.txt', 'r') as labels_file:
+            output_map = [class_label.rstrip() for class_label in labels_file]
+        # Create an IoT client for sending messages to the cloud.
+        client = greengrasssdk.client('iot-data')
+        iot_topic = '$aws/things/{}/infer'.format(os.environ['AWS_IOT_THING_NAME'])
+        # Create a local display instance that will dump the image bytes to a FIFO
+        # file so that the image can be rendered locally.
+        local_display = LocalDisplay('480p')
+        local_display.start()
+        # The height and width of the training set images
+        input_height = 224
+        input_width = 224
+        # Optimize the trained model artifacts (epoch 5) for the DeepLens inference
+        # engine at FP16 precision before loading them.
+        model_name = "image-classification"
+        ret, model_path = mo.optimize(model_name, input_width, input_height,
+                                      aux_inputs={'--epoch': 5, '--precision': 'FP16'})
+        # Load the model onto the GPU.
+        client.publish(topic=iot_topic, payload='Loading bird detection model')
+        model = awscam.Model(model_path, {'GPU': 1})
+        client.publish(topic=iot_topic, payload='Bird detection loaded')
+        # The number of top results to stream to IoT.
+        num_top_k = 3
+        # Define the detection region size.
+        region_size = 1500
+        # Define the inference display region size. This size was decided based on the longest label.
+        label_region_width = 590
+        label_region_height = 400
+
+        # Do inference until the lambda is killed.
+        while True:
+            # Get a frame from the video stream
+            ret, frame = awscam.getLastFrame()
+            if not ret:
+                raise Exception('Failed to get frame from the stream')
+            # Crop the detection region for inference.
+            frame_crop = frame[int(frame.shape[0]/2-region_size/2):int(frame.shape[0]/2+region_size/2),
+                               int(frame.shape[1]/2-region_size/2):int(frame.shape[1]/2+region_size/2), :]
+            # Resize frame to the same size as the training set.
+            frame_resize = cv2.resize(frame_crop, (input_height, input_width))
+            # The model was trained in RGB format but getLastFrame returns the image
+            # in BGR format, so the channels need to be swapped.
+            frame_resize = cv2.cvtColor(frame_resize, cv2.COLOR_BGR2RGB)
+            # Run the image through the inference engine and parse the results using
+            # the parser API. Note that it is possible to take the output of doInference
+            # and do the parsing manually, but since this is a classification model,
+            # a simple API is provided.
+            parsed_inference_results = model.parseResult(model_type,
+                                                         model.doInference(frame_resize))
+            # Get the top k results with the highest probabilities.
+            top_k = parsed_inference_results[model_type][0:num_top_k]
+            # Create a copy of the original frame.
+            overlay = frame.copy()
+            # Create the rectangle that shows the inference results.
+            cv2.rectangle(overlay, (0, 0),
+                          (int(label_region_width), int(label_region_height)), (211, 211, 211), -1)
+            # Blend with the original frame.
+            opacity = 0.7
+            cv2.addWeighted(overlay, opacity, frame, 1 - opacity, 0, frame)
+
+            # Add the label along with the probability of the top result to the frame
+            # used by the local display.
+            # See https://docs.opencv.org/3.4.1/d6/d6e/group__imgproc__draw.html
+            # for more information about the cv2.putText method.
+            # Method signature: image, text, origin, font face, font scale, color, and thickness
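+            # Only draw the predicted label when the top probability is at least 60%;
+            # below that threshold a prompt ('Take your pose') is shown instead.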
+            if top_k[0]['prob'] * 100 < 60:
+                cv2.putText(frame, 'Take your pose', (0, 50),
+                            cv2.FONT_HERSHEY_SIMPLEX, 1.5, (255, 0, 0), 3)
+            else:
+                cv2.putText(frame, output_map[top_k[0]['label']] + ' ' +
+                            str(round(top_k[0]['prob'] * 100, 1)) + '%',
+                            (0, 150), cv2.FONT_HERSHEY_SIMPLEX, 1.5, (255, 0, 0), 3)
+
+            # Display the detection region.
+            cv2.rectangle(frame, (int(frame.shape[1]/2-region_size/2), int(frame.shape[0]/2-region_size/2)),
+                          (int(frame.shape[1]/2+region_size/2), int(frame.shape[0]/2+region_size/2)), (255, 0, 0), 5)
+            # Set the next frame in the local display stream.
+            local_display.set_frame_data(frame)
+            # Send the top k results to the IoT console via MQTT.
+            cloud_output = {}
+            for obj in top_k:
+                cloud_output[output_map[obj['label']]] = obj['prob']
+            client.publish(topic=iot_topic, payload=json.dumps(cloud_output))
+    except Exception as ex:
+        client.publish(topic=iot_topic, payload='Error in bird detection lambda: {}'.format(ex))
+
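+# Kick off the inference loop when the lambda function is loaded.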
+infinite_infer_run()
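
For reference, a minimal sketch of how the LocalDisplay class above could be exercised off-device. It assumes the class has been copied into a standalone module (the local_display module name below is hypothetical), since awscam, mo, and greengrasssdk exist only on the DeepLens. It writes synthetic frames into the FIFO so the mplayer command from the class docstring can be tried on any machine with OpenCV and numpy; a consumer such as mplayer must be attached for the writer thread to make progress.

# local_display_demo.py -- sketch only; LocalDisplay is assumed to be importable
# from a hypothetical local_display module containing the class defined above.
import time

import cv2
import numpy as np

from local_display import LocalDisplay  # hypothetical module name


def main():
    display = LocalDisplay('480p')
    display.start()
    try:
        while True:
            # Draw a timestamped gray test frame in place of a real camera frame.
            frame = np.full((480, 858, 3), 128, dtype=np.uint8)
            cv2.putText(frame, time.strftime('%H:%M:%S'), (50, 260),
                        cv2.FONT_HERSHEY_SIMPLEX, 2.0, (255, 255, 255), 3)
            display.set_frame_data(frame)
            time.sleep(0.1)
    except KeyboardInterrupt:
        # Signal the writer thread to stop before exiting.
        display.join()


if __name__ == '__main__':
    main()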