Showing
2 changed files
with
163 additions
and
0 deletions
Code/bird_classficate_example.py
0 → 100644
| 1 | +#***************************************************** | ||
| 2 | +# * | ||
| 3 | +# Copyright 2018 Amazon.com, Inc. or its affiliates. * | ||
| 4 | +# All Rights Reserved. * | ||
| 5 | +# * | ||
| 6 | +#***************************************************** | ||
| 7 | +""" A sample lambda for bird detection""" | ||
| 8 | +from threading import Thread, Event | ||
| 9 | +import os | ||
| 10 | +import json | ||
| 11 | +import numpy as np | ||
| 12 | +import awscam | ||
| 13 | +import cv2 | ||
| 14 | +import mo | ||
| 15 | +import greengrasssdk | ||
| 16 | + | ||
class LocalDisplay(Thread):
    """ Class for facilitating the local display of inference results
        (as images). The class is designed to run on its own thread. In
        particular the class dumps the inference results into a FIFO
        located in the tmp directory (which lambda has access to). The
        results can be rendered using mplayer by typing:
        mplayer -demuxer lavf -lavfdopts format=mjpeg:probesize=32 /tmp/results.mjpeg
    """
    # Valid stream resolutions, keyed by name. Each value is the size tuple
    # handed to cv2.resize in set_frame_data.
    RESOLUTION = {'1080p': (1920, 1080), '720p': (1280, 720), '480p': (858, 480)}

    def __init__(self, resolution):
        """ resolution - Desired resolution of the project stream; must be one
                         of the keys of RESOLUTION ('1080p', '720p', '480p').
            Raises Exception if the resolution is not recognized.
        """
        super(LocalDisplay, self).__init__()
        if resolution not in self.RESOLUTION:
            raise Exception("Invalid resolution")
        self.resolution = self.RESOLUTION[resolution]
        # Initialize the default image to be a white canvas. Clients
        # will update the image when ready.
        self.frame = cv2.imencode('.jpg', 255*np.ones([640, 480, 3]))[1]
        # Set by join() to ask the run() loop to exit.
        self.stop_request = Event()

    def run(self):
        """ Overridden method that continually dumps images to the desired
            FIFO file until a stop is requested through join().
        """
        # Path to the FIFO file. The lambda only has permissions to the tmp
        # directory. Pointing to a FIFO file in another directory
        # will cause the lambda to crash.
        result_path = '/tmp/results.mjpeg'
        # Create the FIFO file if it doesn't exist.
        if not os.path.exists(result_path):
            os.mkfifo(result_path)
        # This call will block until a consumer is available. The FIFO is
        # opened in binary mode because the frame is written as raw bytes;
        # a text-mode file rejects bytes on Python 3.
        with open(result_path, 'wb') as fifo_file:
            while not self.stop_request.is_set():
                try:
                    # Write the data to the FIFO file. This call will block
                    # meaning the code will come to a halt here until a consumer
                    # is available.
                    fifo_file.write(self.frame.tobytes())
                except IOError:
                    # Best effort: a failed write is skipped and the current
                    # frame is retried on the next pass.
                    continue

    def set_frame_data(self, frame):
        """ Method updates the image data. This currently encodes the
            numpy array to jpg but can be modified to support other encodings.
            frame - Numpy array containing the image data of the next frame
                    in the project stream.
            Raises Exception if the jpg encoding fails.
        """
        ret, jpeg = cv2.imencode('.jpg', cv2.resize(frame, self.resolution))
        if not ret:
            raise Exception('Failed to set frame data')
        self.frame = jpeg

    def join(self, timeout=None):
        """ Requests the display loop to stop. Unlike Thread.join this
            override only sets the stop flag and returns immediately; it does
            not wait for the thread to finish.
            timeout - Accepted for compatibility with Thread.join; ignored.
        """
        self.stop_request.set()
| 73 | + | ||
def _centered_region_bounds(frame_shape, region_size):
    """ Returns the (top, bottom, left, right) pixel bounds of a square of
        side region_size centered in a frame of shape (rows, cols, ...).
    """
    center_row = frame_shape[0]/2
    center_col = frame_shape[1]/2
    half_size = region_size/2
    return (int(center_row-half_size), int(center_row+half_size),
            int(center_col-half_size), int(center_col+half_size))


def infinite_infer_run():
    """ Entry point of the lambda function. Loads the bird classification
        model, then loops forever: grabs a camera frame, classifies the
        centered detection region, draws the top results onto the frame for
        the local display and publishes them to the IoT topic. Any exception
        raised during setup or inference is reported on the IoT topic.
    """
    # Create an IoT client for sending messages to the cloud. This is done
    # before entering the try block so that the error handler below always has
    # a client and topic to report with, even when setup fails early.
    client = greengrasssdk.client('iot-data')
    iot_topic = '$aws/things/{}/infer'.format(os.environ['AWS_IOT_THING_NAME'])
    try:
        # This bird detection model is implemented as multi classifier. The number of labels
        # is quite large so we upload them to a list to map the machine labels to human readable
        # labels.
        model_type = 'classification'
        with open('labels.txt', 'r') as labels_file:
            output_map = [class_label.rstrip() for class_label in labels_file]
        # Create a local display instance that will dump the image bytes to a FIFO
        # file that the image can be rendered locally.
        local_display = LocalDisplay('480p')
        local_display.start()
        # The height and width of the training set images
        input_height = 224
        input_width = 224
        # The sample projects come with optimized artifacts, hence only the artifact
        # path is required.
        # NOTE(review): the first value returned by mo.optimize is not checked
        # here - confirm whether it should be.
        _, model_path = mo.optimize('bird_classification_resnet-18', input_width,
                                    input_height, 'mx')
        # Load the model onto the GPU.
        client.publish(topic=iot_topic, payload='Loading bird detection model')
        model = awscam.Model(model_path, {'GPU': 1})
        client.publish(topic=iot_topic, payload='Bird detection loaded')
        # The number of top results to stream to IoT.
        num_top_k = 5
        # Define the detection region size.
        region_size = 800
        # Define the inference display region size. This size was decided based on the longest label.
        label_region_width = 940
        label_region_height = 600
        # Heading for the inference display.
        prediction_label = 'Top 5 bird predictions'
        # Weight of the label background when blended with the original frame.
        opacity = 0.7
        # Do inference until the lambda is killed.
        while True:
            # Get a frame from the video stream
            ret, frame = awscam.getLastFrame()
            if not ret:
                raise Exception('Failed to get frame from the stream')
            # Crop the detection region for inference.
            top, bottom, left, right = _centered_region_bounds(frame.shape, region_size)
            frame_crop = frame[top:bottom, left:right, :]
            # Resize frame to the same size as the training set. cv2.resize
            # expects the target size as (width, height).
            frame_resize = cv2.resize(frame_crop, (input_width, input_height))
            # Model was trained in RGB format but getLastFrame returns image
            # in BGR format so need to switch.
            frame_resize = cv2.cvtColor(frame_resize, cv2.COLOR_BGR2RGB)
            # Run the images through the inference engine and parse the results using
            # the parser API, note it is possible to get the output of doInference
            # and do the parsing manually, but since it is a classification model,
            # a simple API is provided.
            parsed_inference_results = model.parseResult(model_type,
                                                         model.doInference(frame_resize))
            # Get top k results with highest probabilities
            top_k = parsed_inference_results[model_type][0:num_top_k]
            # Create a copy of the original frame.
            overlay = frame.copy()
            # Create the rectangle that shows the inference results.
            cv2.rectangle(overlay, (0, 0),
                          (int(label_region_width), int(label_region_height)),
                          (211, 211, 211), -1)
            # Blend with the original frame.
            cv2.addWeighted(overlay, opacity, frame, 1 - opacity, 0, frame)
            # Add the header for the inference results.
            cv2.putText(frame, prediction_label, (0, 50),
                        cv2.FONT_HERSHEY_SIMPLEX, 2, (0, 0, 255), 4)
            # Add the label along with the probability of each top result to the frame
            # used by local display.
            # See https://docs.opencv.org/3.4.1/d6/d6e/group__imgproc__draw.html
            # for more information about the cv2.putText method.
            # Method signature: image, text, origin, font face, font scale, color, and thickness
            # Iterating over top_k itself keeps this safe when fewer than
            # num_top_k results are returned.
            for i, obj in enumerate(top_k):
                text = '{} {:.1f}%'.format(output_map[obj['label']], obj['prob'] * 100)
                cv2.putText(frame, text, (0, 100*i+150),
                            cv2.FONT_HERSHEY_SIMPLEX, 1.5, (255, 0, 0), 3)
            # Display the detection region.
            cv2.rectangle(frame, (left, top), (right, bottom), (255, 0, 0), 5)
            # Set the next frame in the local display stream.
            local_display.set_frame_data(frame)
            # Send the top k results to the IoT console via MQTT
            cloud_output = {}
            for obj in top_k:
                cloud_output[output_map[obj['label']]] = obj['prob']
            client.publish(topic=iot_topic, payload=json.dumps(cloud_output))
    except Exception as ex:
        client.publish(topic=iot_topic, payload='Error in bird detection lambda: {}'.format(ex))
| 162 | + | ||
# Start the inference loop as soon as this module is loaded. The call is
# deliberately made at module level rather than behind a __main__ guard.
infinite_infer_run()
| ... | \ No newline at end of file | ... | \ No newline at end of file |
Report/wk12 주간보고서.docx
0 → 100644
No preview for this file type
-
Please register or login to post a comment