Showing 2 changed files with 1405 additions and 0 deletions.

tensorflow/retrain.py  0 → 100644
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# NOTICE: This work was derived from tensorflow/examples/image_retraining
# and modified to use TensorFlow Hub modules.

# pylint: disable=line-too-long
r"""Simple transfer learning with image modules from TensorFlow Hub.

This example shows how to train an image classifier based on any
TensorFlow Hub module that computes image feature vectors. By default,
it uses the feature vectors computed by Inception V3 trained on ImageNet.
For more options, search https://tfhub.dev for image feature vector modules.

The top layer receives as input a 2048-dimensional vector (assuming
Inception V3) for each image. We train a softmax layer on top of this
representation. If the softmax layer contains N labels, this corresponds
to learning N + 2048*N model parameters for the biases and weights.
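For example, with N = 5 classes this comes to 5 + 2048*5 = 10,245
trainable parameters.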

Here's an example, which assumes you have a folder containing class-named
subfolders, each full of images for each label. The example folder flower_photos
should have a structure like this:

~/flower_photos/daisy/photo1.jpg
~/flower_photos/daisy/photo2.jpg
...
~/flower_photos/rose/anotherphoto77.jpg
...
~/flower_photos/sunflower/somepicture.jpg

The subfolder names are important, since they define what label is applied to
each image, but the filenames themselves don't matter. (For a working example,
download http://download.tensorflow.org/example_images/flower_photos.tgz
and run tar xzf flower_photos.tgz to unpack it.)

Once your images are prepared, and you have pip-installed tensorflow-hub and
a sufficiently recent version of tensorflow, you can run the training with a
command like this:

```bash
python retrain.py --image_dir ~/flower_photos
```

You can replace the image_dir argument with any folder containing subfolders of
images. The label for each image is taken from the name of the subfolder it's
in.

This produces a new model file that can be loaded and run by any TensorFlow
program, for example the tensorflow/examples/label_image sample code.

By default this script will use the highly accurate, but comparatively large and
slow Inception V3 model architecture. It's recommended that you start with this
to validate that you have gathered good training data, but if you want to deploy
on resource-limited platforms, you can try the `--tfhub_module` flag with a
Mobilenet model. For more information on Mobilenet, see
https://research.googleblog.com/2017/06/mobilenets-open-source-models-for.html

For example:

Run floating-point version of Mobilenet:

```bash
python retrain.py --image_dir ~/flower_photos \
    --tfhub_module https://tfhub.dev/google/imagenet/mobilenet_v1_100_224/feature_vector/3
```

Run Mobilenet, instrumented for quantization:

```bash
python retrain.py --image_dir ~/flower_photos/ \
    --tfhub_module https://tfhub.dev/google/imagenet/mobilenet_v1_100_224/quantops/feature_vector/3
```

These instrumented models can be converted to fully quantized mobile models via
TensorFlow Lite.
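
For example, a quantized graph produced by this script could then be converted
with the TF 1.x `tflite_convert` tool. The exact values below are a sketch, not
part of the original instructions: the input array name assumes the script's
default unnamed placeholder, and the input shape assumes a 224x224 module:

```bash
tflite_convert --graph_def_file=/tmp/output_graph.pb \
    --output_file=/tmp/output_graph.tflite \
    --input_arrays=Placeholder --input_shapes=1,224,224,3 \
    --output_arrays=final_result --inference_type=QUANTIZED_UINT8 \
    --mean_values=128 --std_dev_values=128
```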

There are different Mobilenet models to choose from, with a variety of file
size and latency options.
  - The first number can be '100', '075', '050', or '025' to control the number
    of neurons (activations of hidden layers); the number of weights (and hence
    to some extent the file size and speed) shrinks with the square of that
    fraction.
  - The second number is the input image size. You can choose '224', '192',
    '160', or '128', with smaller sizes giving faster speeds.
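
For example, 'mobilenet_v1_050_160' halves the number of activations and takes
160x160 inputs. Selecting it would look something like this (the version suffix
here is illustrative; check https://tfhub.dev for the current one):

```bash
python retrain.py --image_dir ~/flower_photos \
    --tfhub_module https://tfhub.dev/google/imagenet/mobilenet_v1_050_160/feature_vector/3
```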

To use with TensorBoard:

By default, this script will log summaries to the /tmp/retrain_logs directory.

Visualize the summaries with this command:

```bash
tensorboard --logdir /tmp/retrain_logs
```

To use with TensorFlow Serving, run this tool with --saved_model_dir set
to some increasingly numbered export location under the model base path, e.g.:

```bash
python retrain.py (... other args as before ...) \
    --saved_model_dir=/tmp/saved_models/$(date +%s)/
tensorflow_model_server --port=9000 --model_name=my_image_classifier \
    --model_base_path=/tmp/saved_models/
```
"""
# pylint: enable=line-too-long

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl import logging

import argparse
import collections
from datetime import datetime
import hashlib
import os.path
import random
import re
import sys

import numpy as np
import tensorflow as tf
import tensorflow_hub as hub
from tensorflow.contrib import quantize as contrib_quantize

FLAGS = None

MAX_NUM_IMAGES_PER_CLASS = 2 ** 27 - 1  # ~134M

# A module is understood as instrumented for quantization with TF-Lite
# if it contains any of these ops.
FAKE_QUANT_OPS = ('FakeQuantWithMinMaxVars',
                  'FakeQuantWithMinMaxVarsPerChannel')


def create_image_lists(image_dir, testing_percentage, validation_percentage):
  """Builds a list of training images from the file system.

  Analyzes the sub folders in the image directory, splits them into stable
  training, testing, and validation sets, and returns a data structure
  describing the lists of images for each label and their paths.

  Args:
    image_dir: String path to a folder containing subfolders of images.
    testing_percentage: Integer percentage of the images to reserve for tests.
    validation_percentage: Integer percentage of images reserved for validation.

  Returns:
    An OrderedDict containing an entry for each label subfolder, with images
    split into training, testing, and validation sets within each label.
    The order of items defines the class indices.
  """
  if not tf.gfile.Exists(image_dir):
    logging.error("Image directory '" + image_dir + "' not found.")
    return None
  result = collections.OrderedDict()
  sub_dirs = sorted(x[0] for x in tf.gfile.Walk(image_dir))
  # The root directory comes first, so skip it.
  is_root_dir = True
  for sub_dir in sub_dirs:
    if is_root_dir:
      is_root_dir = False
      continue
    extensions = sorted(set(os.path.normcase(ext)  # Smash case on Windows.
                            for ext in ['JPEG', 'JPG', 'jpeg', 'jpg', 'png']))
    file_list = []
    dir_name = os.path.basename(
        # tf.gfile.Walk() returns sub-directory with trailing '/' when it is in
        # Google Cloud Storage, which confuses os.path.basename().
        sub_dir[:-1] if sub_dir.endswith('/') else sub_dir)

    if dir_name == image_dir:
      continue
    logging.info("Looking for images in '%s'", dir_name)
    for extension in extensions:
      file_glob = os.path.join(image_dir, dir_name, '*.' + extension)
      file_list.extend(tf.gfile.Glob(file_glob))
    if not file_list:
      logging.warning('No files found')
      continue
    if len(file_list) < 20:
      logging.warning(
          'WARNING: Folder has fewer than 20 images, which may cause issues.')
    elif len(file_list) > MAX_NUM_IMAGES_PER_CLASS:
      logging.warning(
          'WARNING: Folder %s has more than %s images. Some images will '
          'never be selected.', dir_name, MAX_NUM_IMAGES_PER_CLASS)
    label_name = re.sub(r'[^a-z0-9]+', ' ', dir_name.lower())
    training_images = []
    testing_images = []
    validation_images = []
    for file_name in file_list:
      base_name = os.path.basename(file_name)
      # We want to ignore anything after '_nohash_' in the file name when
      # deciding which set to put an image in, so that the data set creator
      # has a way of grouping photos that are close variations of each other.
      # For example, this is used in the plant disease data set to group
      # multiple pictures of the same leaf.
      hash_name = re.sub(r'_nohash_.*$', '', file_name)
      # This looks a bit magical, but we need to decide whether this file should
      # go into the training, testing, or validation sets, and we want to keep
      # existing files in the same set even if more files are subsequently
      # added.
      # To do that, we need a stable way of deciding based on just the file name
      # itself, so we do a hash of that and then use that to generate a
      # probability value that we use to assign it.
      hash_name_hashed = hashlib.sha1(tf.compat.as_bytes(hash_name)).hexdigest()
      percentage_hash = ((int(hash_name_hashed, 16) %
                          (MAX_NUM_IMAGES_PER_CLASS + 1)) *
                         (100.0 / MAX_NUM_IMAGES_PER_CLASS))
      if percentage_hash < validation_percentage:
        validation_images.append(base_name)
      elif percentage_hash < (testing_percentage + validation_percentage):
        testing_images.append(base_name)
      else:
        training_images.append(base_name)
    result[label_name] = {
        'dir': dir_name,
        'training': training_images,
        'testing': testing_images,
        'validation': validation_images,
    }
  return result
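

# A minimal standalone sketch (hypothetical helper, not called anywhere in
# this script) of the stable split rule implemented above. Hashing the file
# name, rather than drawing a random number, keeps each image in the same
# set across reruns even as more images are added later.
def _which_set(file_name, testing_percentage, validation_percentage):
  hash_name = re.sub(r'_nohash_.*$', '', file_name)
  hash_name_hashed = hashlib.sha1(tf.compat.as_bytes(hash_name)).hexdigest()
  percentage_hash = ((int(hash_name_hashed, 16) %
                      (MAX_NUM_IMAGES_PER_CLASS + 1)) *
                     (100.0 / MAX_NUM_IMAGES_PER_CLASS))
  if percentage_hash < validation_percentage:
    return 'validation'
  elif percentage_hash < (testing_percentage + validation_percentage):
    return 'testing'
  return 'training'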


def get_image_path(image_lists, label_name, index, image_dir, category):
  """Returns a path to an image for a label at the given index.

  Args:
    image_lists: OrderedDict of training images for each label.
    label_name: Label string we want to get an image for.
    index: Int offset of the image we want. This will be moduloed by the
      available number of images for the label, so it can be arbitrarily large.
    image_dir: Root folder string of the subfolders containing the training
      images.
    category: Name string of set to pull images from - training, testing, or
      validation.

  Returns:
    File system path string to an image that meets the requested parameters.
  """
  if label_name not in image_lists:
    logging.fatal('Label does not exist %s.', label_name)
  label_lists = image_lists[label_name]
  if category not in label_lists:
    logging.fatal('Category does not exist %s.', category)
  category_list = label_lists[category]
  if not category_list:
    logging.fatal('Label %s has no images in the category %s.',
                  label_name, category)
  mod_index = index % len(category_list)
  base_name = category_list[mod_index]
  sub_dir = label_lists['dir']
  full_path = os.path.join(image_dir, sub_dir, base_name)
  return full_path


def get_bottleneck_path(image_lists, label_name, index, bottleneck_dir,
                        category, module_name):
  """Returns a path to a bottleneck file for a label at the given index.

  Args:
    image_lists: OrderedDict of training images for each label.
    label_name: Label string we want to get an image for.
    index: Integer offset of the image we want. This will be moduloed by the
      available number of images for the label, so it can be arbitrarily large.
    bottleneck_dir: Folder string holding cached files of bottleneck values.
    category: Name string of set to pull images from - training, testing, or
      validation.
    module_name: The name of the image module being used.

  Returns:
    File system path string to a bottleneck file for the requested image.
  """
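  # For example (illustrative), a module name like
  # 'https://tfhub.dev/google/imagenet/inception_v3/feature_vector/3'
  # is sanitized below to
  # 'https~tfhub.dev~google~imagenet~inception_v3~feature_vector~3'.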
  module_name = (module_name.replace('://', '~')  # URL scheme.
                 .replace('/', '~')  # URL and Unix paths.
                 .replace(':', '~').replace('\\', '~'))  # Windows paths.
  return get_image_path(image_lists, label_name, index, bottleneck_dir,
                        category) + '_' + module_name + '.txt'


def create_module_graph(module_spec):
  """Creates a graph and loads Hub Module into it.

  Args:
    module_spec: the hub.ModuleSpec for the image module being used.

  Returns:
    graph: the tf.Graph that was created.
    bottleneck_tensor: the bottleneck values output by the module.
    resized_input_tensor: the input images, resized as expected by the module.
    wants_quantization: a boolean, whether the module has been instrumented
      with fake quantization ops.
  """
  height, width = hub.get_expected_image_size(module_spec)
  with tf.Graph().as_default() as graph:
    resized_input_tensor = tf.placeholder(tf.float32, [None, height, width, 3])
    m = hub.Module(module_spec)
    bottleneck_tensor = m(resized_input_tensor)
    wants_quantization = any(node.op in FAKE_QUANT_OPS
                             for node in graph.as_graph_def().node)
  return graph, bottleneck_tensor, resized_input_tensor, wants_quantization


def run_bottleneck_on_image(sess, image_data, image_data_tensor,
                            decoded_image_tensor, resized_input_tensor,
                            bottleneck_tensor):
  """Runs inference on an image to extract the 'bottleneck' summary layer.

  Args:
    sess: Current active TensorFlow Session.
    image_data: String of raw JPEG data.
    image_data_tensor: Input data layer in the graph.
    decoded_image_tensor: Output of initial image resizing and preprocessing.
    resized_input_tensor: The input node of the recognition graph.
    bottleneck_tensor: Layer before the final softmax.

  Returns:
    Numpy array of bottleneck values.
  """
  # First decode the JPEG image, resize it, and rescale the pixel values.
  resized_input_values = sess.run(decoded_image_tensor,
                                  {image_data_tensor: image_data})
  # Then run it through the recognition network.
  bottleneck_values = sess.run(bottleneck_tensor,
                               {resized_input_tensor: resized_input_values})
  bottleneck_values = np.squeeze(bottleneck_values)
  return bottleneck_values


def ensure_dir_exists(dir_name):
  """Makes sure the folder exists on disk.

  Args:
    dir_name: Path string to the folder we want to create.
  """
  if not os.path.exists(dir_name):
    os.makedirs(dir_name)


def create_bottleneck_file(bottleneck_path, image_lists, label_name, index,
                           image_dir, category, sess, jpeg_data_tensor,
                           decoded_image_tensor, resized_input_tensor,
                           bottleneck_tensor):
  """Create a single bottleneck file."""
  logging.debug('Creating bottleneck at %s', bottleneck_path)
  image_path = get_image_path(image_lists, label_name, index,
                              image_dir, category)
  if not tf.gfile.Exists(image_path):
    logging.fatal('File does not exist %s', image_path)
  image_data = tf.gfile.GFile(image_path, 'rb').read()
  try:
    bottleneck_values = run_bottleneck_on_image(
        sess, image_data, jpeg_data_tensor, decoded_image_tensor,
        resized_input_tensor, bottleneck_tensor)
  except Exception as e:
    raise RuntimeError('Error during processing of file %s (%s)' %
                       (image_path, str(e)))
  bottleneck_string = ','.join(str(x) for x in bottleneck_values)
  with tf.gfile.GFile(bottleneck_path, 'w') as bottleneck_file:
    bottleneck_file.write(bottleneck_string)


def get_or_create_bottleneck(sess, image_lists, label_name, index, image_dir,
                             category, bottleneck_dir, jpeg_data_tensor,
                             decoded_image_tensor, resized_input_tensor,
                             bottleneck_tensor, module_name):
  """Retrieves or calculates bottleneck values for an image.

  If a cached version of the bottleneck data exists on-disk, return that,
  otherwise calculate the data and save it to disk for future use.

  Args:
    sess: The current active TensorFlow Session.
    image_lists: OrderedDict of training images for each label.
    label_name: Label string we want to get an image for.
    index: Integer offset of the image we want. This will be moduloed by the
      available number of images for the label, so it can be arbitrarily large.
    image_dir: Root folder string of the subfolders containing the training
      images.
    category: Name string of which set to pull images from - training, testing,
      or validation.
    bottleneck_dir: Folder string holding cached files of bottleneck values.
    jpeg_data_tensor: The tensor to feed loaded jpeg data into.
    decoded_image_tensor: The output of decoding and resizing the image.
    resized_input_tensor: The input node of the recognition graph.
    bottleneck_tensor: The output tensor for the bottleneck values.
    module_name: The name of the image module being used.

  Returns:
    Numpy array of values produced by the bottleneck layer for the image.
  """
  label_lists = image_lists[label_name]
  sub_dir = label_lists['dir']
  sub_dir_path = os.path.join(bottleneck_dir, sub_dir)
  ensure_dir_exists(sub_dir_path)
  bottleneck_path = get_bottleneck_path(image_lists, label_name, index,
                                        bottleneck_dir, category, module_name)
  if not os.path.exists(bottleneck_path):
    create_bottleneck_file(bottleneck_path, image_lists, label_name, index,
                           image_dir, category, sess, jpeg_data_tensor,
                           decoded_image_tensor, resized_input_tensor,
                           bottleneck_tensor)
  with tf.gfile.GFile(bottleneck_path, 'r') as bottleneck_file:
    bottleneck_string = bottleneck_file.read()
  did_hit_error = False
  try:
    bottleneck_values = [float(x) for x in bottleneck_string.split(',')]
  except ValueError:
    logging.warning('Invalid float found, recreating bottleneck')
    did_hit_error = True
  if did_hit_error:
    create_bottleneck_file(bottleneck_path, image_lists, label_name, index,
                           image_dir, category, sess, jpeg_data_tensor,
                           decoded_image_tensor, resized_input_tensor,
                           bottleneck_tensor)
    with tf.gfile.GFile(bottleneck_path, 'r') as bottleneck_file:
      bottleneck_string = bottleneck_file.read()
    # Allow exceptions to propagate here, since they shouldn't happen after a
    # fresh creation
    bottleneck_values = [float(x) for x in bottleneck_string.split(',')]
  return bottleneck_values


def cache_bottlenecks(sess, image_lists, image_dir, bottleneck_dir,
                      jpeg_data_tensor, decoded_image_tensor,
                      resized_input_tensor, bottleneck_tensor, module_name):
  """Ensures all the training, testing, and validation bottlenecks are cached.

  Because we're likely to read the same image multiple times (if there are no
  distortions applied during training) it can speed things up a lot if we
  calculate the bottleneck layer values once for each image during
  preprocessing, and then just read those cached values repeatedly during
  training. Here we go through all the images we've found, calculate those
  values, and save them off.

  Args:
    sess: The current active TensorFlow Session.
    image_lists: OrderedDict of training images for each label.
    image_dir: Root folder string of the subfolders containing the training
      images.
    bottleneck_dir: Folder string holding cached files of bottleneck values.
    jpeg_data_tensor: Input tensor for jpeg data from file.
    decoded_image_tensor: The output of decoding and resizing the image.
    resized_input_tensor: The input node of the recognition graph.
    bottleneck_tensor: The penultimate output layer of the graph.
    module_name: The name of the image module being used.

  Returns:
    Nothing.
  """
  how_many_bottlenecks = 0
  ensure_dir_exists(bottleneck_dir)
  for label_name, label_lists in image_lists.items():
    for category in ['training', 'testing', 'validation']:
      category_list = label_lists[category]
      for index, unused_base_name in enumerate(category_list):
        get_or_create_bottleneck(
            sess, image_lists, label_name, index, image_dir, category,
            bottleneck_dir, jpeg_data_tensor, decoded_image_tensor,
            resized_input_tensor, bottleneck_tensor, module_name)

        how_many_bottlenecks += 1
        if how_many_bottlenecks % 100 == 0:
          logging.info('%s bottleneck files created.', how_many_bottlenecks)


def get_random_cached_bottlenecks(sess, image_lists, how_many, category,
                                  bottleneck_dir, image_dir, jpeg_data_tensor,
                                  decoded_image_tensor, resized_input_tensor,
                                  bottleneck_tensor, module_name):
  """Retrieves bottleneck values for cached images.

  If no distortions are being applied, this function can retrieve the cached
  bottleneck values directly from disk for images. It picks a random set of
  images from the specified category.

  Args:
    sess: Current TensorFlow Session.
    image_lists: OrderedDict of training images for each label.
    how_many: If positive, a random sample of this size will be chosen.
      If negative, all bottlenecks will be retrieved.
    category: Name string of which set to pull from - training, testing, or
      validation.
    bottleneck_dir: Folder string holding cached files of bottleneck values.
    image_dir: Root folder string of the subfolders containing the training
      images.
    jpeg_data_tensor: The layer to feed jpeg image data into.
    decoded_image_tensor: The output of decoding and resizing the image.
    resized_input_tensor: The input node of the recognition graph.
    bottleneck_tensor: The bottleneck output layer of the CNN graph.
    module_name: The name of the image module being used.

  Returns:
    List of bottleneck arrays, their corresponding ground truths, and the
    relevant filenames.
  """
  class_count = len(image_lists.keys())
  bottlenecks = []
  ground_truths = []
  filenames = []
  if how_many >= 0:
    # Retrieve a random sample of bottlenecks.
    for unused_i in range(how_many):
      label_index = random.randrange(class_count)
      label_name = list(image_lists.keys())[label_index]
      image_index = random.randrange(MAX_NUM_IMAGES_PER_CLASS + 1)
      image_name = get_image_path(image_lists, label_name, image_index,
                                  image_dir, category)
      bottleneck = get_or_create_bottleneck(
          sess, image_lists, label_name, image_index, image_dir, category,
          bottleneck_dir, jpeg_data_tensor, decoded_image_tensor,
          resized_input_tensor, bottleneck_tensor, module_name)
      bottlenecks.append(bottleneck)
      ground_truths.append(label_index)
      filenames.append(image_name)
  else:
    # Retrieve all bottlenecks.
    for label_index, label_name in enumerate(image_lists.keys()):
      for image_index, image_name in enumerate(
          image_lists[label_name][category]):
        image_name = get_image_path(image_lists, label_name, image_index,
                                    image_dir, category)
        bottleneck = get_or_create_bottleneck(
            sess, image_lists, label_name, image_index, image_dir, category,
            bottleneck_dir, jpeg_data_tensor, decoded_image_tensor,
            resized_input_tensor, bottleneck_tensor, module_name)
        bottlenecks.append(bottleneck)
        ground_truths.append(label_index)
        filenames.append(image_name)
  return bottlenecks, ground_truths, filenames


def get_random_distorted_bottlenecks(
    sess, image_lists, how_many, category, image_dir, input_jpeg_tensor,
    distorted_image, resized_input_tensor, bottleneck_tensor):
  """Retrieves bottleneck values for training images, after distortions.

  If we're training with distortions like crops, scales, or flips, we have to
  recalculate the full model for every image, and so we can't use cached
  bottleneck values. Instead we find random images for the requested category,
  run them through the distortion graph, and then the full graph to get the
  bottleneck results for each.

  Args:
    sess: Current TensorFlow Session.
    image_lists: OrderedDict of training images for each label.
    how_many: The integer number of bottleneck values to return.
    category: Name string of which set of images to fetch - training, testing,
      or validation.
    image_dir: Root folder string of the subfolders containing the training
      images.
    input_jpeg_tensor: The input layer we feed the image data to.
    distorted_image: The output node of the distortion graph.
    resized_input_tensor: The input node of the recognition graph.
    bottleneck_tensor: The bottleneck output layer of the CNN graph.

  Returns:
    List of bottleneck arrays and their corresponding ground truths.
  """
  class_count = len(image_lists.keys())
  bottlenecks = []
  ground_truths = []
  for unused_i in range(how_many):
    label_index = random.randrange(class_count)
    label_name = list(image_lists.keys())[label_index]
    image_index = random.randrange(MAX_NUM_IMAGES_PER_CLASS + 1)
    image_path = get_image_path(image_lists, label_name, image_index, image_dir,
                                category)
    if not tf.gfile.Exists(image_path):
      logging.fatal('File does not exist %s', image_path)
    jpeg_data = tf.gfile.GFile(image_path, 'rb').read()
    # Note that we materialize the distorted_image_data as a numpy array before
    # running inference on the image. This involves two memory copies and might
    # be optimized in other implementations.
    distorted_image_data = sess.run(distorted_image,
                                    {input_jpeg_tensor: jpeg_data})
    bottleneck_values = sess.run(bottleneck_tensor,
                                 {resized_input_tensor: distorted_image_data})
    bottleneck_values = np.squeeze(bottleneck_values)
    bottlenecks.append(bottleneck_values)
    ground_truths.append(label_index)
  return bottlenecks, ground_truths


def should_distort_images(flip_left_right, random_crop, random_scale,
                          random_brightness):
  """Whether any distortions are enabled, from the input flags.

  Args:
    flip_left_right: Boolean whether to randomly mirror images horizontally.
    random_crop: Integer percentage setting the total margin used around the
      crop box.
    random_scale: Integer percentage of how much to vary the scale by.
    random_brightness: Integer range to randomly multiply the pixel values by.

  Returns:
    Boolean value indicating whether any distortions should be applied.
  """
  return (flip_left_right or (random_crop != 0) or (random_scale != 0) or
          (random_brightness != 0))


def add_input_distortions(flip_left_right, random_crop, random_scale,
                          random_brightness, module_spec):
  """Creates the operations to apply the specified distortions.

  During training it can help to improve the results if we run the images
  through simple distortions like crops, scales, and flips. These reflect the
  kind of variations we expect in the real world, and so can help train the
  model to cope with natural data more effectively. Here we take the supplied
  parameters and construct a network of operations to apply them to an image.

  Cropping
  ~~~~~~~~

  Cropping is done by placing a bounding box at a random position in the full
  image. The cropping parameter controls the size of that box relative to the
  input image. If it's zero, then the box is the same size as the input and no
  cropping is performed. If the value is 50%, then the crop box will be half the
  width and height of the input. In a diagram it looks like this:

  <       width         >
  +---------------------+
  |                     |
  |   width - crop%     |
  |    <      >         |
  |    +------+         |
  |    |      |         |
  |    |      |         |
  |    |      |         |
  |    +------+         |
  |                     |
  |                     |
  +---------------------+

  Scaling
  ~~~~~~~

  Scaling is a lot like cropping, except that the bounding box is always
  centered and its size varies randomly within the given range. For example if
  the scale percentage is zero, then the bounding box is the same size as the
  input and no scaling is applied. If it's 50%, then the bounding box will be
  drawn at a random size between half the input's width and height and full
  size.
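
  For example, with --random_crop=10 and --random_scale=20, the image is
  first resized to between 1.1x and 1.32x of the module's input size (a
  fixed 10% crop margin times a random scale factor drawn from [1.0, 1.2])
  and then randomly cropped back down to the input size.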

  Args:
    flip_left_right: Boolean whether to randomly mirror images horizontally.
    random_crop: Integer percentage setting the total margin used around the
      crop box.
    random_scale: Integer percentage of how much to vary the scale by.
    random_brightness: Integer range to randomly multiply the pixel values by.
    module_spec: The hub.ModuleSpec for the image module being used.

  Returns:
    The jpeg input layer and the distorted result tensor.
  """
  input_height, input_width = hub.get_expected_image_size(module_spec)
  input_depth = hub.get_num_image_channels(module_spec)
  jpeg_data = tf.placeholder(tf.string, name='DistortJPGInput')
  decoded_image = tf.image.decode_jpeg(jpeg_data, channels=input_depth)
  # Convert from full range of uint8 to range [0,1] of float32.
  decoded_image_as_float = tf.image.convert_image_dtype(decoded_image,
                                                        tf.float32)
  decoded_image_4d = tf.expand_dims(decoded_image_as_float, 0)
  margin_scale = 1.0 + (random_crop / 100.0)
  resize_scale = 1.0 + (random_scale / 100.0)
  margin_scale_value = tf.constant(margin_scale)
  resize_scale_value = tf.random_uniform(shape=[],
                                         minval=1.0,
                                         maxval=resize_scale)
  scale_value = tf.multiply(margin_scale_value, resize_scale_value)
  precrop_width = tf.multiply(scale_value, input_width)
  precrop_height = tf.multiply(scale_value, input_height)
  precrop_shape = tf.stack([precrop_height, precrop_width])
  precrop_shape_as_int = tf.cast(precrop_shape, dtype=tf.int32)
  precropped_image = tf.image.resize_bilinear(decoded_image_4d,
                                              precrop_shape_as_int)
  precropped_image_3d = tf.squeeze(precropped_image, axis=[0])
  cropped_image = tf.random_crop(precropped_image_3d,
                                 [input_height, input_width, input_depth])
  if flip_left_right:
    flipped_image = tf.image.random_flip_left_right(cropped_image)
  else:
    flipped_image = cropped_image
  brightness_min = 1.0 - (random_brightness / 100.0)
  brightness_max = 1.0 + (random_brightness / 100.0)
  brightness_value = tf.random_uniform(shape=[],
                                       minval=brightness_min,
                                       maxval=brightness_max)
  brightened_image = tf.multiply(flipped_image, brightness_value)
  distort_result = tf.expand_dims(brightened_image, 0, name='DistortResult')
  return jpeg_data, distort_result


def variable_summaries(var):
  """Attach a lot of summaries to a Tensor (for TensorBoard visualization)."""
  with tf.name_scope('summaries'):
    mean = tf.reduce_mean(var)
    tf.summary.scalar('mean', mean)
    with tf.name_scope('stddev'):
      stddev = tf.sqrt(tf.reduce_mean(tf.square(var - mean)))
    tf.summary.scalar('stddev', stddev)
    tf.summary.scalar('max', tf.reduce_max(var))
    tf.summary.scalar('min', tf.reduce_min(var))
    tf.summary.histogram('histogram', var)


def add_final_retrain_ops(class_count, final_tensor_name, bottleneck_tensor,
                          quantize_layer, is_training):
  """Adds a new softmax and fully-connected layer for training and eval.

  We need to retrain the top layer to identify our new classes, so this function
  adds the right operations to the graph, along with some variables to hold the
  weights, and then sets up all the gradients for the backward pass.

  The setup for the softmax and fully-connected layers is based on:
  https://www.tensorflow.org/tutorials/mnist/beginners/index.html

  Args:
    class_count: Integer of how many categories of things we're trying to
      recognize.
    final_tensor_name: Name string for the new final node that produces results.
    bottleneck_tensor: The output of the main CNN graph.
    quantize_layer: Boolean, specifying whether the newly added layer should be
      instrumented for quantization with TF-Lite.
    is_training: Boolean, specifying whether the newly added layer is for
      training or eval.

  Returns:
    The tensors for the training and cross entropy results, and tensors for the
    bottleneck input and ground truth input.
  """
  batch_size, bottleneck_tensor_size = bottleneck_tensor.get_shape().as_list()
  assert batch_size is None, 'We want to work with arbitrary batch size.'
  with tf.name_scope('input'):
    bottleneck_input = tf.placeholder_with_default(
        bottleneck_tensor,
        shape=[batch_size, bottleneck_tensor_size],
        name='BottleneckInputPlaceholder')

    ground_truth_input = tf.placeholder(
        tf.int64, [batch_size], name='GroundTruthInput')

  # Organizing the following ops so they are easier to see in TensorBoard.
  layer_name = 'final_retrain_ops'
  with tf.name_scope(layer_name):
    with tf.name_scope('weights'):
      initial_value = tf.truncated_normal(
          [bottleneck_tensor_size, class_count], stddev=0.001)
      layer_weights = tf.Variable(initial_value, name='final_weights')
      variable_summaries(layer_weights)

    with tf.name_scope('biases'):
      layer_biases = tf.Variable(tf.zeros([class_count]), name='final_biases')
      variable_summaries(layer_biases)

    with tf.name_scope('Wx_plus_b'):
      logits = tf.matmul(bottleneck_input, layer_weights) + layer_biases
      tf.summary.histogram('pre_activations', logits)

  final_tensor = tf.nn.softmax(logits, name=final_tensor_name)

  # The tf.contrib.quantize functions rewrite the graph in place for
  # quantization. The imported model graph has already been rewritten, so upon
  # calling these rewrites, only the newly added final layer will be
  # transformed.
  if quantize_layer:
    if is_training:
      contrib_quantize.create_training_graph()
    else:
      contrib_quantize.create_eval_graph()

  tf.summary.histogram('activations', final_tensor)

  # If this is an eval graph, we don't need to add loss ops or an optimizer.
  if not is_training:
    return None, None, bottleneck_input, ground_truth_input, final_tensor

  with tf.name_scope('cross_entropy'):
    cross_entropy_mean = tf.losses.sparse_softmax_cross_entropy(
        labels=ground_truth_input, logits=logits)

  tf.summary.scalar('cross_entropy', cross_entropy_mean)

  with tf.name_scope('train'):
    optimizer = tf.train.GradientDescentOptimizer(FLAGS.learning_rate)
    train_step = optimizer.minimize(cross_entropy_mean)

  return (train_step, cross_entropy_mean, bottleneck_input, ground_truth_input,
          final_tensor)


def add_evaluation_step(result_tensor, ground_truth_tensor):
  """Inserts the operations we need to evaluate the accuracy of our results.

  Args:
    result_tensor: The new final node that produces results.
    ground_truth_tensor: The node we feed ground truth data into.

  Returns:
    Tuple of (evaluation step, prediction).
  """
  with tf.name_scope('accuracy'):
    with tf.name_scope('correct_prediction'):
      prediction = tf.argmax(result_tensor, 1)
      correct_prediction = tf.equal(prediction, ground_truth_tensor)
    with tf.name_scope('accuracy'):
      evaluation_step = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
  tf.summary.scalar('accuracy', evaluation_step)
  return evaluation_step, prediction


def run_final_eval(train_session, module_spec, class_count, image_lists,
                   jpeg_data_tensor, decoded_image_tensor,
                   resized_image_tensor, bottleneck_tensor):
  """Runs a final evaluation on an eval graph using the test data set.

  Args:
    train_session: Session for the train graph with the tensors below.
    module_spec: The hub.ModuleSpec for the image module being used.
    class_count: Number of classes.
    image_lists: OrderedDict of training images for each label.
    jpeg_data_tensor: The layer to feed jpeg image data into.
    decoded_image_tensor: The output of decoding and resizing the image.
    resized_image_tensor: The input node of the recognition graph.
    bottleneck_tensor: The bottleneck output layer of the CNN graph.
  """
  test_bottlenecks, test_ground_truth, test_filenames = (
      get_random_cached_bottlenecks(train_session, image_lists,
                                    FLAGS.test_batch_size,
                                    'testing', FLAGS.bottleneck_dir,
                                    FLAGS.image_dir, jpeg_data_tensor,
                                    decoded_image_tensor, resized_image_tensor,
                                    bottleneck_tensor, FLAGS.tfhub_module))

  (eval_session, _, bottleneck_input, ground_truth_input, evaluation_step,
   prediction) = build_eval_session(module_spec, class_count)
  test_accuracy, predictions = eval_session.run(
      [evaluation_step, prediction],
      feed_dict={
          bottleneck_input: test_bottlenecks,
          ground_truth_input: test_ground_truth
      })
  logging.info('Final test accuracy = %.1f%% (N=%d)',
               test_accuracy * 100, len(test_bottlenecks))

  if FLAGS.print_misclassified_test_images:
    logging.info('=== MISCLASSIFIED TEST IMAGES ===')
    for i, test_filename in enumerate(test_filenames):
      if predictions[i] != test_ground_truth[i]:
        logging.info('%70s  %s', test_filename,
                     list(image_lists.keys())[predictions[i]])


def build_eval_session(module_spec, class_count):
  """Builds a restored eval session without train operations for exporting.

  Args:
    module_spec: The hub.ModuleSpec for the image module being used.
    class_count: Number of classes.

  Returns:
    Eval session containing the restored eval graph.
    The bottleneck input, ground truth, eval step, and prediction tensors.
  """
  # If quantized, we need to create the correct eval graph for exporting.
  eval_graph, bottleneck_tensor, resized_input_tensor, wants_quantization = (
      create_module_graph(module_spec))

  eval_sess = tf.Session(graph=eval_graph)
  with eval_graph.as_default():
    # Add the new layer for exporting.
    (_, _, bottleneck_input,
     ground_truth_input, final_tensor) = add_final_retrain_ops(
         class_count, FLAGS.final_tensor_name, bottleneck_tensor,
         wants_quantization, is_training=False)

    # Now we need to restore the values from the training graph to the eval
    # graph.
    tf.train.Saver().restore(eval_sess, FLAGS.checkpoint_path)

    evaluation_step, prediction = add_evaluation_step(final_tensor,
                                                      ground_truth_input)

  return (eval_sess, resized_input_tensor, bottleneck_input, ground_truth_input,
          evaluation_step, prediction)


def save_graph_to_file(graph_file_name, module_spec, class_count):
  """Saves a graph to file, creating a valid quantized one if necessary."""
  sess, _, _, _, _, _ = build_eval_session(module_spec, class_count)
  graph = sess.graph

  output_graph_def = tf.graph_util.convert_variables_to_constants(
      sess, graph.as_graph_def(), [FLAGS.final_tensor_name])

  with tf.gfile.GFile(graph_file_name, 'wb') as f:
    f.write(output_graph_def.SerializeToString())
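

# A minimal sketch (hypothetical helper, not called by this script) of
# loading the frozen graph written above, along the lines of the
# tensorflow/examples/label_image sample mentioned in the module docstring.
def _load_frozen_graph(graph_file_name):
  graph_def = tf.GraphDef()
  with tf.gfile.GFile(graph_file_name, 'rb') as f:
    graph_def.ParseFromString(f.read())
  with tf.Graph().as_default() as graph:
    tf.import_graph_def(graph_def, name='')
  return graph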


def prepare_file_system():
  # Set up the directory we'll write summaries to for TensorBoard.
  if tf.gfile.Exists(FLAGS.summaries_dir):
    tf.gfile.DeleteRecursively(FLAGS.summaries_dir)
  tf.gfile.MakeDirs(FLAGS.summaries_dir)
  if FLAGS.intermediate_store_frequency > 0:
    ensure_dir_exists(FLAGS.intermediate_output_graphs_dir)
  return


def add_jpeg_decoding(module_spec):
  """Adds operations that perform JPEG decoding and resizing to the graph.

  Args:
    module_spec: The hub.ModuleSpec for the image module being used.

  Returns:
    Tensors for the node to feed JPEG data into, and the output of the
    preprocessing steps.
  """
  input_height, input_width = hub.get_expected_image_size(module_spec)
  input_depth = hub.get_num_image_channels(module_spec)
  jpeg_data = tf.placeholder(tf.string, name='DecodeJPGInput')
  decoded_image = tf.image.decode_jpeg(jpeg_data, channels=input_depth)
  # Convert from full range of uint8 to range [0,1] of float32.
  decoded_image_as_float = tf.image.convert_image_dtype(decoded_image,
                                                        tf.float32)
  decoded_image_4d = tf.expand_dims(decoded_image_as_float, 0)
  resize_shape = tf.stack([input_height, input_width])
  resize_shape_as_int = tf.cast(resize_shape, dtype=tf.int32)
  resized_image = tf.image.resize_bilinear(decoded_image_4d,
                                           resize_shape_as_int)
  return jpeg_data, resized_image


def export_model(module_spec, class_count, saved_model_dir):
  """Exports model for serving.

  Args:
    module_spec: The hub.ModuleSpec for the image module being used.
    class_count: The number of classes.
    saved_model_dir: Directory in which to save exported model and variables.
  """
  # The SavedModel should hold the eval graph.
  sess, in_image, _, _, _, _ = build_eval_session(module_spec, class_count)
  with sess.graph.as_default() as graph:
    tf.saved_model.simple_save(
        sess,
        saved_model_dir,
        inputs={'image': in_image},
        outputs={'prediction': graph.get_tensor_by_name('final_result:0')},
        legacy_init_op=tf.group(tf.tables_initializer(), name='legacy_init_op')
    )
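

# A minimal sketch (hypothetical helper, not called by this script) of
# reloading the SavedModel exported above with the TF1 loader; the tag
# matches the one tf.saved_model.simple_save() writes.
def _load_exported_model(saved_model_dir):
  sess = tf.Session(graph=tf.Graph())
  with sess.graph.as_default():
    tf.saved_model.loader.load(
        sess, [tf.saved_model.tag_constants.SERVING], saved_model_dir)
  return sess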


def logging_level_verbosity(logging_verbosity):
  """Converts logging_level into TensorFlow logging verbosity value.

  Args:
    logging_verbosity: String value representing logging level: 'DEBUG',
      'INFO', 'WARN', 'ERROR', 'FATAL'.

  Returns:
    The TensorFlow logging verbosity value for the given level name.
  """
  name_to_level = {
      'FATAL': logging.FATAL,
      'ERROR': logging.ERROR,
      'WARN': logging.WARN,
      'INFO': logging.INFO,
      'DEBUG': logging.DEBUG
  }

  try:
    return name_to_level[logging_verbosity]
  except Exception as e:
    raise RuntimeError('Unsupported logging verbosity (%s). Use one of %s.' %
                       (str(e), list(name_to_level)))


def main(_):
  # Needed to make sure the logging output is visible.
  # See https://github.com/tensorflow/tensorflow/issues/3047
  logging_verbosity = logging_level_verbosity(FLAGS.logging_verbosity)
  logging.set_verbosity(logging_verbosity)

  if not FLAGS.image_dir:
    logging.error('Must set flag --image_dir.')
    return -1

  # Prepare necessary directories that can be used during training.
  prepare_file_system()

  # Look at the folder structure, and create lists of all the images.
  image_lists = create_image_lists(FLAGS.image_dir, FLAGS.testing_percentage,
                                   FLAGS.validation_percentage)
  class_count = len(image_lists.keys())
  if class_count == 0:
    logging.error('No valid folders of images found at %s', FLAGS.image_dir)
    return -1
  if class_count == 1:
    logging.error('Only one valid folder of images found at %s '
                  '- multiple classes are needed for classification.',
                  FLAGS.image_dir)
    return -1

  # See if the command-line flags mean we're applying any distortions.
  do_distort_images = should_distort_images(
      FLAGS.flip_left_right, FLAGS.random_crop, FLAGS.random_scale,
      FLAGS.random_brightness)

  # Set up the pre-trained graph.
  module_spec = hub.load_module_spec(FLAGS.tfhub_module)
  graph, bottleneck_tensor, resized_image_tensor, wants_quantization = (
      create_module_graph(module_spec))

  # Add the new layer that we'll be training.
  with graph.as_default():
    (train_step, cross_entropy, bottleneck_input,
     ground_truth_input, final_tensor) = add_final_retrain_ops(
         class_count, FLAGS.final_tensor_name, bottleneck_tensor,
         wants_quantization, is_training=True)

  with tf.Session(graph=graph) as sess:
    # Initialize all weights: for the module to their pretrained values,
    # and for the newly added retraining layer to random initial values.
    init = tf.global_variables_initializer()
    sess.run(init)

    # Set up the image decoding sub-graph.
    jpeg_data_tensor, decoded_image_tensor = add_jpeg_decoding(module_spec)

    if do_distort_images:
      # We will be applying distortions, so set up the operations we'll need.
      (distorted_jpeg_data_tensor,
       distorted_image_tensor) = add_input_distortions(
           FLAGS.flip_left_right, FLAGS.random_crop, FLAGS.random_scale,
           FLAGS.random_brightness, module_spec)
    else:
      # We'll make sure we've calculated the 'bottleneck' image summaries and
      # cached them on disk.
      cache_bottlenecks(sess, image_lists, FLAGS.image_dir,
                        FLAGS.bottleneck_dir, jpeg_data_tensor,
                        decoded_image_tensor, resized_image_tensor,
                        bottleneck_tensor, FLAGS.tfhub_module)

    # Create the operations we need to evaluate the accuracy of our new layer.
    evaluation_step, _ = add_evaluation_step(final_tensor, ground_truth_input)

    # Merge all the summaries and write them out to the summaries_dir.
    merged = tf.summary.merge_all()
    train_writer = tf.summary.FileWriter(FLAGS.summaries_dir + '/train',
                                         sess.graph)

    validation_writer = tf.summary.FileWriter(
        FLAGS.summaries_dir + '/validation')

    # Create a train saver that is used to restore values into an eval graph
    # when exporting models.
    train_saver = tf.train.Saver()

    # Run the training for as many cycles as requested on the command line.
    for i in range(FLAGS.how_many_training_steps):
      # Get a batch of input bottleneck values, either calculated fresh every
      # time with distortions applied, or from the cache stored on disk.
      if do_distort_images:
        (train_bottlenecks,
         train_ground_truth) = get_random_distorted_bottlenecks(
             sess, image_lists, FLAGS.train_batch_size, 'training',
             FLAGS.image_dir, distorted_jpeg_data_tensor,
             distorted_image_tensor, resized_image_tensor, bottleneck_tensor)
      else:
        (train_bottlenecks,
         train_ground_truth, _) = get_random_cached_bottlenecks(
             sess, image_lists, FLAGS.train_batch_size, 'training',
             FLAGS.bottleneck_dir, FLAGS.image_dir, jpeg_data_tensor,
             decoded_image_tensor, resized_image_tensor, bottleneck_tensor,
             FLAGS.tfhub_module)
      # Feed the bottlenecks and ground truth into the graph, and run a
      # training step. Capture training summaries for TensorBoard with the
      # `merged` op.
      train_summary, _ = sess.run(
          [merged, train_step],
          feed_dict={bottleneck_input: train_bottlenecks,
                     ground_truth_input: train_ground_truth})
      train_writer.add_summary(train_summary, i)

      # Every so often, print out how well the graph is training.
      is_last_step = (i + 1 == FLAGS.how_many_training_steps)
      if (i % FLAGS.eval_step_interval) == 0 or is_last_step:
        train_accuracy, cross_entropy_value = sess.run(
            [evaluation_step, cross_entropy],
            feed_dict={bottleneck_input: train_bottlenecks,
                       ground_truth_input: train_ground_truth})
        logging.info('%s: Step %d: Train accuracy = %.1f%%',
                     datetime.now(), i, train_accuracy * 100)
        logging.info('%s: Step %d: Cross entropy = %f',
                     datetime.now(), i, cross_entropy_value)
        # TODO: Make this use an eval graph, to avoid quantization
        # moving averages being updated by the validation set, though in
        # practice this makes a negligible difference.
1112 | + validation_bottlenecks, validation_ground_truth, _ = ( | ||
1113 | + get_random_cached_bottlenecks( | ||
1114 | + sess, image_lists, FLAGS.validation_batch_size, 'validation', | ||
1115 | + FLAGS.bottleneck_dir, FLAGS.image_dir, jpeg_data_tensor, | ||
1116 | + decoded_image_tensor, resized_image_tensor, bottleneck_tensor, | ||
1117 | + FLAGS.tfhub_module)) | ||
1118 | +        # Run a validation step and capture its summaries for TensorBoard ||
1119 | +        # with the `merged` op. ||
1120 | + validation_summary, validation_accuracy = sess.run( | ||
1121 | + [merged, evaluation_step], | ||
1122 | + feed_dict={bottleneck_input: validation_bottlenecks, | ||
1123 | + ground_truth_input: validation_ground_truth}) | ||
1124 | + validation_writer.add_summary(validation_summary, i) | ||
1125 | + logging.info('%s: Step %d: Validation accuracy = %.1f%% (N=%d)', | ||
1126 | + datetime.now(), i, validation_accuracy * 100, | ||
1127 | + len(validation_bottlenecks)) | ||
1128 | + | ||
1129 | + # Store intermediate results | ||
1130 | + intermediate_frequency = FLAGS.intermediate_store_frequency | ||
1131 | + | ||
1132 | + if (intermediate_frequency > 0 and (i % intermediate_frequency == 0) | ||
1133 | + and i > 0): | ||
1134 | + # If we want to do an intermediate save, save a checkpoint of the train | ||
1135 | + # graph, to restore into the eval graph. | ||
1136 | + train_saver.save(sess, FLAGS.checkpoint_path) | ||
1137 | + intermediate_file_name = (FLAGS.intermediate_output_graphs_dir + | ||
1138 | + 'intermediate_' + str(i) + '.pb') | ||
1139 | +        logging.info('Saving intermediate result to: %s', intermediate_file_name) ||
1140 | + save_graph_to_file(intermediate_file_name, module_spec, | ||
1141 | + class_count) | ||
1142 | + | ||
1143 | + # After training is complete, force one last save of the train checkpoint. | ||
1144 | + train_saver.save(sess, FLAGS.checkpoint_path) | ||
1145 | + | ||
1146 | + # We've completed all our training, so run a final test evaluation on | ||
1147 | + # some new images we haven't used before. | ||
1148 | + run_final_eval(sess, module_spec, class_count, image_lists, | ||
1149 | + jpeg_data_tensor, decoded_image_tensor, resized_image_tensor, | ||
1150 | + bottleneck_tensor) | ||
1151 | + | ||
1152 | + # Write out the trained graph and labels with the weights stored as | ||
1153 | + # constants. | ||
1154 | +  logging.info('Saving final result to: %s', FLAGS.output_graph) ||
1155 | + if wants_quantization: | ||
1156 | + logging.info('The model is instrumented for quantization with TF-Lite') | ||
1157 | + save_graph_to_file(FLAGS.output_graph, module_spec, class_count) | ||
1158 | + with tf.gfile.GFile(FLAGS.output_labels, 'w') as f: | ||
1159 | + f.write('\n'.join(image_lists.keys()) + '\n') | ||
1160 | + | ||
1161 | + if FLAGS.saved_model_dir: | ||
1162 | + export_model(module_spec, class_count, FLAGS.saved_model_dir) | ||
1163 | + | ||
1164 | + | ||
1165 | +if __name__ == '__main__': | ||
1166 | + parser = argparse.ArgumentParser() | ||
1167 | + parser.add_argument( | ||
1168 | + '--image_dir', | ||
1169 | + type=str, | ||
1170 | + default='', | ||
1171 | + help='Path to folders of labeled images.' | ||
1172 | + ) | ||
1173 | + parser.add_argument( | ||
1174 | + '--output_graph', | ||
1175 | + type=str, | ||
1176 | + default='/tmp/output_graph.pb', | ||
1177 | + help='Where to save the trained graph.' | ||
1178 | + ) | ||
1179 | + parser.add_argument( | ||
1180 | + '--intermediate_output_graphs_dir', | ||
1181 | + type=str, | ||
1182 | + default='/tmp/intermediate_graph/', | ||
1183 | + help='Where to save the intermediate graphs.' | ||
1184 | + ) | ||
1185 | + parser.add_argument( | ||
1186 | + '--intermediate_store_frequency', | ||
1187 | + type=int, | ||
1188 | + default=0, | ||
1189 | + help="""\ | ||
1190 | +      How often, in training steps, to store an intermediate graph. If ||
1191 | +      "0", intermediate graphs will not be stored.\ ||
1192 | + """ | ||
1193 | + ) | ||
1194 | + parser.add_argument( | ||
1195 | + '--output_labels', | ||
1196 | + type=str, | ||
1197 | + default='/tmp/output_labels.txt', | ||
1198 | + help='Where to save the trained graph\'s labels.' | ||
1199 | + ) | ||
1200 | + parser.add_argument( | ||
1201 | + '--summaries_dir', | ||
1202 | + type=str, | ||
1203 | + default='/tmp/retrain_logs', | ||
1204 | + help='Where to save summary logs for TensorBoard.' | ||
1205 | + ) | ||
1206 | + parser.add_argument( | ||
1207 | + '--how_many_training_steps', | ||
1208 | + type=int, | ||
1209 | + default=4000, | ||
1210 | + help='How many training steps to run before ending.' | ||
1211 | + ) | ||
1212 | + parser.add_argument( | ||
1213 | + '--learning_rate', | ||
1214 | + type=float, | ||
1215 | + default=0.01, | ||
1216 | + help='How large a learning rate to use when training.' | ||
1217 | + ) | ||
1218 | + parser.add_argument( | ||
1219 | + '--testing_percentage', | ||
1220 | + type=int, | ||
1221 | + default=10, | ||
1222 | + help='What percentage of images to use as a test set.' | ||
1223 | + ) | ||
1224 | + parser.add_argument( | ||
1225 | + '--validation_percentage', | ||
1226 | + type=int, | ||
1227 | + default=10, | ||
1228 | + help='What percentage of images to use as a validation set.' | ||
1229 | + ) | ||
1230 | + parser.add_argument( | ||
1231 | + '--eval_step_interval', | ||
1232 | + type=int, | ||
1233 | + default=10, | ||
1234 | + help='How often to evaluate the training results.' | ||
1235 | + ) | ||
1236 | + parser.add_argument( | ||
1237 | + '--train_batch_size', | ||
1238 | + type=int, | ||
1239 | + default=100, | ||
1240 | + help='How many images to train on at a time.' | ||
1241 | + ) | ||
1242 | + parser.add_argument( | ||
1243 | + '--test_batch_size', | ||
1244 | + type=int, | ||
1245 | + default=-1, | ||
1246 | + help="""\ | ||
1247 | + How many images to test on. This test set is only used once, to evaluate | ||
1248 | + the final accuracy of the model after training completes. | ||
1249 | + A value of -1 causes the entire test set to be used, which leads to more | ||
1250 | + stable results across runs.\ | ||
1251 | + """ | ||
1252 | + ) | ||
1253 | + parser.add_argument( | ||
1254 | + '--validation_batch_size', | ||
1255 | + type=int, | ||
1256 | + default=100, | ||
1257 | + help="""\ | ||
1258 | + How many images to use in an evaluation batch. This validation set is | ||
1259 | + used much more often than the test set, and is an early indicator of how | ||
1260 | + accurate the model is during training. | ||
1261 | + A value of -1 causes the entire validation set to be used, which leads to | ||
1262 | + more stable results across training iterations, but may be slower on large | ||
1263 | + training sets.\ | ||
1264 | + """ | ||
1265 | + ) | ||
1266 | + parser.add_argument( | ||
1267 | + '--print_misclassified_test_images', | ||
1268 | + default=False, | ||
1269 | + help="""\ | ||
1270 | + Whether to print out a list of all misclassified test images.\ | ||
1271 | + """, | ||
1272 | + action='store_true' | ||
1273 | + ) | ||
1274 | + parser.add_argument( | ||
1275 | + '--bottleneck_dir', | ||
1276 | + type=str, | ||
1277 | + default='/tmp/bottleneck', | ||
1278 | + help='Path to cache bottleneck layer values as files.' | ||
1279 | + ) | ||
1280 | + parser.add_argument( | ||
1281 | + '--final_tensor_name', | ||
1282 | + type=str, | ||
1283 | + default='final_result', | ||
1284 | + help="""\ | ||
1285 | + The name of the output classification layer in the retrained graph.\ | ||
1286 | + """ | ||
1287 | + ) | ||
1288 | + parser.add_argument( | ||
1289 | + '--flip_left_right', | ||
1290 | + default=False, | ||
1291 | + help="""\ | ||
1292 | + Whether to randomly flip half of the training images horizontally.\ | ||
1293 | + """, | ||
1294 | + action='store_true' | ||
1295 | + ) | ||
1296 | + parser.add_argument( | ||
1297 | + '--random_crop', | ||
1298 | + type=int, | ||
1299 | + default=0, | ||
1300 | + help="""\ | ||
1301 | + A percentage determining how much of a margin to randomly crop off the | ||
1302 | + training images.\ | ||
1303 | + """ | ||
1304 | + ) | ||
1305 | + parser.add_argument( | ||
1306 | + '--random_scale', | ||
1307 | + type=int, | ||
1308 | + default=0, | ||
1309 | + help="""\ | ||
1310 | + A percentage determining how much to randomly scale up the size of the | ||
1311 | + training images by.\ | ||
1312 | + """ | ||
1313 | + ) | ||
1314 | + parser.add_argument( | ||
1315 | + '--random_brightness', | ||
1316 | + type=int, | ||
1317 | + default=0, | ||
1318 | + help="""\ | ||
1319 | + A percentage determining how much to randomly multiply the training image | ||
1320 | + input pixels up or down by.\ | ||
1321 | + """ | ||
1322 | + ) | ||
1323 | + parser.add_argument( | ||
1324 | + '--tfhub_module', | ||
1325 | + type=str, | ||
1326 | + default=( | ||
1327 | + 'https://tfhub.dev/google/imagenet/inception_v3/feature_vector/3'), | ||
1328 | + help="""\ | ||
1329 | + Which TensorFlow Hub module to use. For more options, | ||
1330 | + search https://tfhub.dev for image feature vector modules.\ | ||
1331 | + """) | ||
1332 | + parser.add_argument( | ||
1333 | + '--saved_model_dir', | ||
1334 | + type=str, | ||
1335 | + default='', | ||
1336 | + help='Where to save the exported graph.') | ||
1337 | + parser.add_argument( | ||
1338 | + '--logging_verbosity', | ||
1339 | + type=str, | ||
1340 | + default='INFO', | ||
1341 | + choices=['DEBUG', 'INFO', 'WARN', 'ERROR', 'FATAL'], | ||
1342 | + help='How much logging output should be produced.') | ||
1343 | + parser.add_argument( | ||
1344 | + '--checkpoint_path', | ||
1345 | + type=str, | ||
1346 | + default='/tmp/_retrain_checkpoint', | ||
1347 | + help='Where to save checkpoint files.' | ||
1348 | + ) | ||
1349 | + FLAGS, unparsed = parser.parse_known_args() | ||
1350 | + tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) |
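After training, the frozen graph written to `--output_graph` can be converted to TF-Lite, which is what the "instrumented for quantization" message above refers to. The sketch below is illustrative only and not part of this commit: it assumes TensorFlow 1.13+ and the default Inception V3 module, whose exported graph is assumed to take a 299x299 input through a placeholder named `Placeholder` and to emit `final_result` (the `--final_tensor_name` default). Verify both tensor names against your own export before relying on this.

```python
# Hedged sketch: convert the retrained frozen graph to TF-Lite.
# Assumptions: TF 1.13+, input tensor 'Placeholder', output 'final_result'.
import tensorflow as tf

converter = tf.lite.TFLiteConverter.from_frozen_graph(
    '/tmp/output_graph.pb',                          # --output_graph default
    input_arrays=['Placeholder'],                    # assumed input name
    output_arrays=['final_result'],                  # --final_tensor_name default
    input_shapes={'Placeholder': [1, 299, 299, 3]})  # Inception V3 input size
tflite_model = converter.convert()
with open('/tmp/output_graph.tflite', 'wb') as f:
    f.write(tflite_model)
```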
tensorflow/retrain_run_inference.py
0 → 100644
1 | +# -*- coding: utf-8 -*- | ||
2 | + | ||
3 | +"""Inception v3 architecture 모델을 retraining한 모델을 이용해서 이미지에 대한 추론(inference)을 진행하는 예제""" | ||
4 | + | ||
5 | +import numpy as np | ||
6 | +import tensorflow as tf | ||
7 | + | ||
8 | +imagePath = '/tmp/test_chartreux.jpg'      # path of the image to run inference on ||
9 | +modelFullPath = '/tmp/output_graph.pb'     # path of the graph file to load ||
10 | +labelsFullPath = '/tmp/output_labels.txt'  # path of the labels file to load ||
11 | + | ||
12 | + | ||
13 | +def create_graph(): | ||
14 | +    """Creates a graph from a saved GraphDef file.""" ||
15 | +    # Create the graph from the saved graph_def.pb. ||
16 | + with tf.gfile.FastGFile(modelFullPath, 'rb') as f: | ||
17 | + graph_def = tf.GraphDef() | ||
18 | + graph_def.ParseFromString(f.read()) | ||
19 | + _ = tf.import_graph_def(graph_def, name='') | ||
20 | + | ||
21 | + | ||
22 | +def run_inference_on_image(): | ||
23 | + answer = None | ||
24 | + | ||
25 | + if not tf.gfile.Exists(imagePath): | ||
26 | + tf.logging.fatal('File does not exist %s', imagePath) | ||
27 | + return answer | ||
28 | + | ||
29 | + image_data = tf.gfile.FastGFile(imagePath, 'rb').read() | ||
30 | + | ||
31 | +    # Create the graph from the saved GraphDef file. ||
32 | + create_graph() | ||
33 | + | ||
34 | + with tf.Session() as sess: | ||
35 | + | ||
36 | +        softmax_tensor = sess.graph.get_tensor_by_name('final_result:0')  # output layer set by retrain.py's --final_tensor_name ||
37 | + predictions = sess.run(softmax_tensor, | ||
38 | +                               {'DecodeJpeg/contents:0': image_data})  # JPEG input tensor of the classic (non-Hub) Inception graph ||
39 | + predictions = np.squeeze(predictions) | ||
40 | + | ||
41 | +        top_k = predictions.argsort()[-5:][::-1]  # Get the top 5 predictions with the highest probabilities. ||
42 | +        # Read labels as text; 'rb' would leave b'...' wrappers under Python 3. ||
43 | +        with open(labelsFullPath, 'r') as f: ||
44 | +            labels = [line.strip() for line in f.readlines()] ||
45 | + for node_id in top_k: | ||
46 | + human_string = labels[node_id] | ||
47 | + score = predictions[node_id] | ||
48 | + print('%s (score = %.5f)' % (human_string, score)) | ||
49 | + | ||
50 | + answer = labels[top_k[0]] | ||
51 | + return answer | ||
52 | + | ||
53 | + | ||
54 | +if __name__ == '__main__': | ||
55 | + run_inference_on_image() | ||
\ No newline at end of file
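Note that the script above feeds JPEG bytes to `DecodeJpeg/contents:0`, a tensor that exists in graphs retrained with the classic (pre-Hub) retrain.py but not in graphs exported by the TF Hub-based retrain.py from this commit. The hedged sketch below shows one way to adapt the inference for a Hub-exported graph; it assumes the default Inception V3 module (299x299 input with pixel values scaled to [0, 1]) and the default tensor names `Placeholder:0` and `final_result:0`, all of which should be verified against your own export.

```python
# Hedged sketch: inference against a graph exported by the hub-based
# retrain.py, which has no 'DecodeJpeg/contents' tensor. Assumed names:
# input 'Placeholder:0', output 'final_result:0'; input size 299x299.
import numpy as np
import tensorflow as tf

with tf.Graph().as_default() as graph:
    with tf.gfile.GFile('/tmp/output_graph.pb', 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    tf.import_graph_def(graph_def, name='')
    # Small auxiliary subgraph that decodes and resizes the input JPEG.
    jpeg_input = tf.placeholder(tf.string, shape=[])
    decoded = tf.image.decode_jpeg(jpeg_input, channels=3)
    resized = tf.image.resize_images(
        tf.cast(decoded, tf.float32) / 255.0, [299, 299])

with tf.Session(graph=graph) as sess:
    image_data = tf.gfile.GFile('/tmp/test_chartreux.jpg', 'rb').read()
    image = sess.run(resized, {jpeg_input: image_data})
    predictions = sess.run('final_result:0',
                           {'Placeholder:0': np.expand_dims(image, 0)})
    print(np.squeeze(predictions))  # one probability per label
```

The [0, 1] scaling matches what TF Hub image modules generally expect; adjust it if your module documents a different input range.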