최예리

tensorflow-inception file upload

# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# NOTICE: This work was derived from tensorflow/examples/image_retraining
# and modified to use TensorFlow Hub modules.

# pylint: disable=line-too-long
19 +r"""Simple transfer learning with image modules from TensorFlow Hub.
20 +
21 +This example shows how to train an image classifier based on any
22 +TensorFlow Hub module that computes image feature vectors. By default,
23 +it uses the feature vectors computed by Inception V3 trained on ImageNet.
24 +For more options, search https://tfhub.dev for image feature vector modules.
25 +
26 +The top layer receives as input a 2048-dimensional vector (assuming
27 +Inception V3) for each image. We train a softmax layer on top of this
28 +representation. If the softmax layer contains N labels, this corresponds
29 +to learning N + 2048*N model parameters for the biases and weights.
30 +
31 +Here's an example, which assumes you have a folder containing class-named
32 +subfolders, each full of images for each label. The example folder flower_photos
33 +should have a structure like this:
34 +
35 +~/flower_photos/daisy/photo1.jpg
36 +~/flower_photos/daisy/photo2.jpg
37 +...
38 +~/flower_photos/rose/anotherphoto77.jpg
39 +...
40 +~/flower_photos/sunflower/somepicture.jpg
41 +
42 +The subfolder names are important, since they define what label is applied to
43 +each image, but the filenames themselves don't matter. (For a working example,
44 +download http://download.tensorflow.org/example_images/flower_photos.tgz
45 +and run tar xzf flower_photos.tgz to unpack it.)
46 +
47 +Once your images are prepared, and you have pip-installed tensorflow-hub and
48 +a sufficiently recent version of tensorflow, you can run the training with a
49 +command like this:
50 +
51 +```bash
52 +python retrain.py --image_dir ~/flower_photos
53 +```
54 +
55 +You can replace the image_dir argument with any folder containing subfolders of
56 +images. The label for each image is taken from the name of the subfolder it's
57 +in.
58 +
59 +This produces a new model file that can be loaded and run by any TensorFlow
60 +program, for example the tensorflow/examples/label_image sample code.
61 +
62 +By default this script will use the highly accurate, but comparatively large and
63 +slow Inception V3 model architecture. It's recommended that you start with this
64 +to validate that you have gathered good training data, but if you want to deploy
65 +on resource-limited platforms, you can try the `--tfhub_module` flag with a
66 +Mobilenet model. For more information on Mobilenet, see
67 +https://research.googleblog.com/2017/06/mobilenets-open-source-models-for.html
68 +
69 +For example:
70 +
71 +Run floating-point version of Mobilenet:
72 +
73 +```bash
74 +python retrain.py --image_dir ~/flower_photos \
75 + --tfhub_module https://tfhub.dev/google/imagenet/mobilenet_v1_100_224/feature_vector/3
76 +```
77 +
78 +Run Mobilenet, instrumented for quantization:
79 +
80 +```bash
81 +python retrain.py --image_dir ~/flower_photos/ \
82 + --tfhub_module https://tfhub.dev/google/imagenet/mobilenet_v1_100_224/quantops/feature_vector/3
83 +```
84 +
85 +These instrumented models can be converted to fully quantized mobile models via
86 +TensorFlow Lite.
87 +
88 +There are different Mobilenet models to choose from, with a variety of file
89 +size and latency options.
90 + - The first number can be '100', '075', '050', or '025' to control the number
91 + of neurons (activations of hidden layers); the number of weights (and hence
92 + to some extent the file size and speed) shrinks with the square of that
93 + fraction.
94 + - The second number is the input image size. You can choose '224', '192',
95 + '160', or '128', with smaller sizes giving faster speeds.
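
For instance, a sketch of picking a smaller, faster variant (assuming a module
with this name is published on tfhub.dev, following the naming scheme above):

```bash
python retrain.py --image_dir ~/flower_photos \
    --tfhub_module https://tfhub.dev/google/imagenet/mobilenet_v1_050_160/feature_vector/3
```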

To use with TensorBoard:

By default, this script will log summaries to the /tmp/retrain_logs directory.

Visualize the summaries with this command:

tensorboard --logdir /tmp/retrain_logs

To use with TensorFlow Serving, run this tool with --saved_model_dir set
to some increasingly numbered export location under the model base path, e.g.:

```bash
python retrain.py (... other args as before ...) \
    --saved_model_dir=/tmp/saved_models/$(date +%s)/
tensorflow_model_server --port=9000 --model_name=my_image_classifier \
    --model_base_path=/tmp/saved_models/
```
"""
# pylint: enable=line-too-long

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl import logging

import argparse
import collections
from datetime import datetime
import hashlib
import os.path
import random
import re
import sys

import numpy as np
import tensorflow as tf
import tensorflow_hub as hub
from tensorflow.contrib import quantize as contrib_quantize

FLAGS = None

MAX_NUM_IMAGES_PER_CLASS = 2 ** 27 - 1  # ~134M

# A module is understood as instrumented for quantization with TF-Lite
# if it contains any of these ops.
FAKE_QUANT_OPS = ('FakeQuantWithMinMaxVars',
                  'FakeQuantWithMinMaxVarsPerChannel')


def create_image_lists(image_dir, testing_percentage, validation_percentage):
  """Builds a list of training images from the file system.

  Analyzes the sub folders in the image directory, splits them into stable
  training, testing, and validation sets, and returns a data structure
  describing the lists of images for each label and their paths.

  Args:
    image_dir: String path to a folder containing subfolders of images.
    testing_percentage: Integer percentage of the images to reserve for tests.
    validation_percentage: Integer percentage of images reserved for validation.

  Returns:
    An OrderedDict containing an entry for each label subfolder, with images
    split into training, testing, and validation sets within each label.
    The order of items defines the class indices.
  """
  if not tf.gfile.Exists(image_dir):
    logging.error("Image directory '" + image_dir + "' not found.")
    return None
  result = collections.OrderedDict()
  sub_dirs = sorted(x[0] for x in tf.gfile.Walk(image_dir))
  # The root directory comes first, so skip it.
  is_root_dir = True
  for sub_dir in sub_dirs:
    if is_root_dir:
      is_root_dir = False
      continue
    extensions = sorted(set(os.path.normcase(ext)  # Smash case on Windows.
                            for ext in ['JPEG', 'JPG', 'jpeg', 'jpg', 'png']))
    file_list = []
    dir_name = os.path.basename(
        # tf.gfile.Walk() returns sub-directory with trailing '/' when it is in
        # Google Cloud Storage, which confuses os.path.basename().
        sub_dir[:-1] if sub_dir.endswith('/') else sub_dir)

    if dir_name == image_dir:
      continue
    logging.info("Looking for images in '%s'", dir_name)
    for extension in extensions:
      file_glob = os.path.join(image_dir, dir_name, '*.' + extension)
      file_list.extend(tf.gfile.Glob(file_glob))
    if not file_list:
      logging.warning('No files found')
      continue
    if len(file_list) < 20:
      logging.warning(
          'WARNING: Folder has fewer than 20 images, which may cause issues.')
    elif len(file_list) > MAX_NUM_IMAGES_PER_CLASS:
      logging.warning(
          'WARNING: Folder %s has more than %s images. Some images will '
          'never be selected.', dir_name, MAX_NUM_IMAGES_PER_CLASS)
    label_name = re.sub(r'[^a-z0-9]+', ' ', dir_name.lower())
    training_images = []
    testing_images = []
    validation_images = []
    for file_name in file_list:
      base_name = os.path.basename(file_name)
      # We want to ignore anything after '_nohash_' in the file name when
      # deciding which set to put an image in, so that the data set creator
      # has a way of grouping photos that are close variations of each other.
      # For example, this is used in the plant disease data set to group
      # multiple pictures of the same leaf.
      hash_name = re.sub(r'_nohash_.*$', '', file_name)
      # This looks a bit magical, but we need to decide whether this file
      # should go into the training, testing, or validation sets, and we want
      # to keep existing files in the same set even if more files are
      # subsequently added.
      # To do that, we need a stable way of deciding based on just the file
      # name itself, so we hash it and use the result to generate a
      # probability value that we use to assign it.
      hash_name_hashed = hashlib.sha1(tf.compat.as_bytes(hash_name)).hexdigest()
      percentage_hash = ((int(hash_name_hashed, 16) %
                          (MAX_NUM_IMAGES_PER_CLASS + 1)) *
                         (100.0 / MAX_NUM_IMAGES_PER_CLASS))
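      # For example, with validation_percentage=10 and testing_percentage=10,
      # a file whose hash maps to 7.3 lands in validation, one mapping to 15.0
      # lands in testing, and anything at or above 20.0 stays in training
      # (illustrative values only).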
      if percentage_hash < validation_percentage:
        validation_images.append(base_name)
      elif percentage_hash < (testing_percentage + validation_percentage):
        testing_images.append(base_name)
      else:
        training_images.append(base_name)
    result[label_name] = {
        'dir': dir_name,
        'training': training_images,
        'testing': testing_images,
        'validation': validation_images,
    }
  return result


def get_image_path(image_lists, label_name, index, image_dir, category):
  """Returns a path to an image for a label at the given index.

  Args:
    image_lists: OrderedDict of training images for each label.
    label_name: Label string we want to get an image for.
    index: Int offset of the image we want. This will be moduloed by the
      available number of images for the label, so it can be arbitrarily large.
    image_dir: Root folder string of the subfolders containing the training
      images.
    category: Name string of set to pull images from - training, testing, or
      validation.

  Returns:
    File system path string to an image that meets the requested parameters.
  """
  if label_name not in image_lists:
    logging.fatal('Label does not exist %s.', label_name)
  label_lists = image_lists[label_name]
  if category not in label_lists:
    logging.fatal('Category does not exist %s.', category)
  category_list = label_lists[category]
  if not category_list:
    logging.fatal('Label %s has no images in the category %s.',
                  label_name, category)
  mod_index = index % len(category_list)
  base_name = category_list[mod_index]
  sub_dir = label_lists['dir']
  full_path = os.path.join(image_dir, sub_dir, base_name)
  return full_path


def get_bottleneck_path(image_lists, label_name, index, bottleneck_dir,
                        category, module_name):
  """Returns a path to a bottleneck file for a label at the given index.

  Args:
    image_lists: OrderedDict of training images for each label.
    label_name: Label string we want to get an image for.
    index: Integer offset of the image we want. This will be moduloed by the
      available number of images for the label, so it can be arbitrarily large.
    bottleneck_dir: Folder string holding cached files of bottleneck values.
    category: Name string of set to pull images from - training, testing, or
      validation.
    module_name: The name of the image module being used.

  Returns:
    File system path string to a bottleneck file that meets the requested
    parameters.
  """
  module_name = (module_name.replace('://', '~')  # URL scheme.
                 .replace('/', '~')  # URL and Unix paths.
                 .replace(':', '~').replace('\\', '~'))  # Windows paths.
  return get_image_path(image_lists, label_name, index, bottleneck_dir,
                        category) + '_' + module_name + '.txt'
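

# Illustration of the escaping above: with the default Inception V3 module, a
# bottleneck cached under /tmp/bottleneck/daisy/ for photo1.jpg would be named
#   photo1.jpg_https~tfhub.dev~google~imagenet~inception_v3~feature_vector~3.txt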


def create_module_graph(module_spec):
  """Creates a graph and loads Hub Module into it.

  Args:
    module_spec: the hub.ModuleSpec for the image module being used.

  Returns:
    graph: the tf.Graph that was created.
    bottleneck_tensor: the bottleneck values output by the module.
    resized_input_tensor: the input images, resized as expected by the module.
    wants_quantization: a boolean, whether the module has been instrumented
      with fake quantization ops.
  """
  height, width = hub.get_expected_image_size(module_spec)
  with tf.Graph().as_default() as graph:
    resized_input_tensor = tf.placeholder(tf.float32, [None, height, width, 3])
    m = hub.Module(module_spec)
    bottleneck_tensor = m(resized_input_tensor)
    wants_quantization = any(node.op in FAKE_QUANT_OPS
                             for node in graph.as_graph_def().node)
  return graph, bottleneck_tensor, resized_input_tensor, wants_quantization
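

# Typical use, as in main() below (sketch):
#   module_spec = hub.load_module_spec(FLAGS.tfhub_module)
#   graph, bottleneck_tensor, resized_input, quantized = (
#       create_module_graph(module_spec))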


def run_bottleneck_on_image(sess, image_data, image_data_tensor,
                            decoded_image_tensor, resized_input_tensor,
                            bottleneck_tensor):
  """Runs inference on an image to extract the 'bottleneck' summary layer.

  Args:
    sess: Current active TensorFlow Session.
    image_data: String of raw JPEG data.
    image_data_tensor: Input data layer in the graph.
    decoded_image_tensor: Output of initial image resizing and preprocessing.
    resized_input_tensor: The input node of the recognition graph.
    bottleneck_tensor: Layer before the final softmax.

  Returns:
    Numpy array of bottleneck values.
  """
  # First decode the JPEG image, resize it, and rescale the pixel values.
  resized_input_values = sess.run(decoded_image_tensor,
                                  {image_data_tensor: image_data})
  # Then run it through the recognition network.
  bottleneck_values = sess.run(bottleneck_tensor,
                               {resized_input_tensor: resized_input_values})
  bottleneck_values = np.squeeze(bottleneck_values)
  return bottleneck_values


def ensure_dir_exists(dir_name):
  """Makes sure the folder exists on disk.

  Args:
    dir_name: Path string to the folder we want to create.
  """
  if not os.path.exists(dir_name):
    os.makedirs(dir_name)


def create_bottleneck_file(bottleneck_path, image_lists, label_name, index,
                           image_dir, category, sess, jpeg_data_tensor,
                           decoded_image_tensor, resized_input_tensor,
                           bottleneck_tensor):
  """Create a single bottleneck file."""
  logging.debug('Creating bottleneck at %s', bottleneck_path)
  image_path = get_image_path(image_lists, label_name, index,
                              image_dir, category)
  if not tf.gfile.Exists(image_path):
    logging.fatal('File does not exist %s', image_path)
  image_data = tf.gfile.GFile(image_path, 'rb').read()
  try:
    bottleneck_values = run_bottleneck_on_image(
        sess, image_data, jpeg_data_tensor, decoded_image_tensor,
        resized_input_tensor, bottleneck_tensor)
  except Exception as e:
    raise RuntimeError('Error during processing file %s (%s)' % (image_path,
                                                                 str(e)))
  bottleneck_string = ','.join(str(x) for x in bottleneck_values)
  with tf.gfile.GFile(bottleneck_path, 'w') as bottleneck_file:
    bottleneck_file.write(bottleneck_string)


def get_or_create_bottleneck(sess, image_lists, label_name, index, image_dir,
                             category, bottleneck_dir, jpeg_data_tensor,
                             decoded_image_tensor, resized_input_tensor,
                             bottleneck_tensor, module_name):
  """Retrieves or calculates bottleneck values for an image.

  If a cached version of the bottleneck data exists on-disk, return that,
  otherwise calculate the data and save it to disk for future use.

  Args:
    sess: The current active TensorFlow Session.
    image_lists: OrderedDict of training images for each label.
    label_name: Label string we want to get an image for.
    index: Integer offset of the image we want. This will be moduloed by the
      available number of images for the label, so it can be arbitrarily large.
    image_dir: Root folder string of the subfolders containing the training
      images.
    category: Name string of which set to pull images from - training, testing,
      or validation.
    bottleneck_dir: Folder string holding cached files of bottleneck values.
    jpeg_data_tensor: The tensor to feed loaded jpeg data into.
    decoded_image_tensor: The output of decoding and resizing the image.
    resized_input_tensor: The input node of the recognition graph.
    bottleneck_tensor: The output tensor for the bottleneck values.
    module_name: The name of the image module being used.

  Returns:
    Numpy array of values produced by the bottleneck layer for the image.
  """
  label_lists = image_lists[label_name]
  sub_dir = label_lists['dir']
  sub_dir_path = os.path.join(bottleneck_dir, sub_dir)
  ensure_dir_exists(sub_dir_path)
  bottleneck_path = get_bottleneck_path(image_lists, label_name, index,
                                        bottleneck_dir, category, module_name)
  if not os.path.exists(bottleneck_path):
    create_bottleneck_file(bottleneck_path, image_lists, label_name, index,
                           image_dir, category, sess, jpeg_data_tensor,
                           decoded_image_tensor, resized_input_tensor,
                           bottleneck_tensor)
  with tf.gfile.GFile(bottleneck_path, 'r') as bottleneck_file:
    bottleneck_string = bottleneck_file.read()
  did_hit_error = False
  try:
    bottleneck_values = [float(x) for x in bottleneck_string.split(',')]
  except ValueError:
    logging.warning('Invalid float found, recreating bottleneck')
    did_hit_error = True
  if did_hit_error:
    create_bottleneck_file(bottleneck_path, image_lists, label_name, index,
                           image_dir, category, sess, jpeg_data_tensor,
                           decoded_image_tensor, resized_input_tensor,
                           bottleneck_tensor)
    with tf.gfile.GFile(bottleneck_path, 'r') as bottleneck_file:
      bottleneck_string = bottleneck_file.read()
    # Allow exceptions to propagate here, since they shouldn't happen after a
    # fresh creation.
    bottleneck_values = [float(x) for x in bottleneck_string.split(',')]
  return bottleneck_values


def cache_bottlenecks(sess, image_lists, image_dir, bottleneck_dir,
                      jpeg_data_tensor, decoded_image_tensor,
                      resized_input_tensor, bottleneck_tensor, module_name):
  """Ensures all the training, testing, and validation bottlenecks are cached.

  Because we're likely to read the same image multiple times (if there are no
  distortions applied during training) it can speed things up a lot if we
  calculate the bottleneck layer values once for each image during
  preprocessing, and then just read those cached values repeatedly during
  training. Here we go through all the images we've found, calculate those
  values, and save them off.

  Args:
    sess: The current active TensorFlow Session.
    image_lists: OrderedDict of training images for each label.
    image_dir: Root folder string of the subfolders containing the training
      images.
    bottleneck_dir: Folder string holding cached files of bottleneck values.
    jpeg_data_tensor: Input tensor for jpeg data from file.
    decoded_image_tensor: The output of decoding and resizing the image.
    resized_input_tensor: The input node of the recognition graph.
    bottleneck_tensor: The penultimate output layer of the graph.
    module_name: The name of the image module being used.

  Returns:
    Nothing.
  """
  how_many_bottlenecks = 0
  ensure_dir_exists(bottleneck_dir)
  for label_name, label_lists in image_lists.items():
    for category in ['training', 'testing', 'validation']:
      category_list = label_lists[category]
      for index, unused_base_name in enumerate(category_list):
        get_or_create_bottleneck(
            sess, image_lists, label_name, index, image_dir, category,
            bottleneck_dir, jpeg_data_tensor, decoded_image_tensor,
            resized_input_tensor, bottleneck_tensor, module_name)

        how_many_bottlenecks += 1
        if how_many_bottlenecks % 100 == 0:
          logging.info('%s bottleneck files created.', how_many_bottlenecks)


def get_random_cached_bottlenecks(sess, image_lists, how_many, category,
                                  bottleneck_dir, image_dir, jpeg_data_tensor,
                                  decoded_image_tensor, resized_input_tensor,
                                  bottleneck_tensor, module_name):
  """Retrieves bottleneck values for cached images.

  If no distortions are being applied, this function can retrieve the cached
  bottleneck values directly from disk for images. It picks a random set of
  images from the specified category.

  Args:
    sess: Current TensorFlow Session.
    image_lists: OrderedDict of training images for each label.
    how_many: If positive, a random sample of this size will be chosen.
      If negative, all bottlenecks will be retrieved.
    category: Name string of which set to pull from - training, testing, or
      validation.
    bottleneck_dir: Folder string holding cached files of bottleneck values.
    image_dir: Root folder string of the subfolders containing the training
      images.
    jpeg_data_tensor: The layer to feed jpeg image data into.
    decoded_image_tensor: The output of decoding and resizing the image.
    resized_input_tensor: The input node of the recognition graph.
    bottleneck_tensor: The bottleneck output layer of the CNN graph.
    module_name: The name of the image module being used.

  Returns:
    List of bottleneck arrays, their corresponding ground truths, and the
    relevant filenames.
  """
  class_count = len(image_lists.keys())
  bottlenecks = []
  ground_truths = []
  filenames = []
  if how_many >= 0:
    # Retrieve a random sample of bottlenecks.
    for unused_i in range(how_many):
      label_index = random.randrange(class_count)
      label_name = list(image_lists.keys())[label_index]
      image_index = random.randrange(MAX_NUM_IMAGES_PER_CLASS + 1)
      image_name = get_image_path(image_lists, label_name, image_index,
                                  image_dir, category)
      bottleneck = get_or_create_bottleneck(
          sess, image_lists, label_name, image_index, image_dir, category,
          bottleneck_dir, jpeg_data_tensor, decoded_image_tensor,
          resized_input_tensor, bottleneck_tensor, module_name)
      bottlenecks.append(bottleneck)
      ground_truths.append(label_index)
      filenames.append(image_name)
  else:
    # Retrieve all bottlenecks.
    for label_index, label_name in enumerate(image_lists.keys()):
      for image_index, image_name in enumerate(
          image_lists[label_name][category]):
        image_name = get_image_path(image_lists, label_name, image_index,
                                    image_dir, category)
        bottleneck = get_or_create_bottleneck(
            sess, image_lists, label_name, image_index, image_dir, category,
            bottleneck_dir, jpeg_data_tensor, decoded_image_tensor,
            resized_input_tensor, bottleneck_tensor, module_name)
        bottlenecks.append(bottleneck)
        ground_truths.append(label_index)
        filenames.append(image_name)
  return bottlenecks, ground_truths, filenames


def get_random_distorted_bottlenecks(
    sess, image_lists, how_many, category, image_dir, input_jpeg_tensor,
    distorted_image, resized_input_tensor, bottleneck_tensor):
  """Retrieves bottleneck values for training images, after distortions.

  If we're training with distortions like crops, scales, or flips, we have to
  recalculate the full model for every image, and so we can't use cached
  bottleneck values. Instead we find random images for the requested category,
  run them through the distortion graph, and then the full graph to get the
  bottleneck results for each.

  Args:
    sess: Current TensorFlow Session.
    image_lists: OrderedDict of training images for each label.
    how_many: The integer number of bottleneck values to return.
    category: Name string of which set of images to fetch - training, testing,
      or validation.
    image_dir: Root folder string of the subfolders containing the training
      images.
    input_jpeg_tensor: The input layer we feed the image data to.
    distorted_image: The output node of the distortion graph.
    resized_input_tensor: The input node of the recognition graph.
    bottleneck_tensor: The bottleneck output layer of the CNN graph.

  Returns:
    List of bottleneck arrays and their corresponding ground truths.
  """
  class_count = len(image_lists.keys())
  bottlenecks = []
  ground_truths = []
  for unused_i in range(how_many):
    label_index = random.randrange(class_count)
    label_name = list(image_lists.keys())[label_index]
    image_index = random.randrange(MAX_NUM_IMAGES_PER_CLASS + 1)
    image_path = get_image_path(image_lists, label_name, image_index, image_dir,
                                category)
    if not tf.gfile.Exists(image_path):
      logging.fatal('File does not exist %s', image_path)
    jpeg_data = tf.gfile.GFile(image_path, 'rb').read()
    # Note that we materialize the distorted_image_data as a numpy array before
    # running inference on the image. This involves 2 memory copies and might
    # be optimized in other implementations.
    distorted_image_data = sess.run(distorted_image,
                                    {input_jpeg_tensor: jpeg_data})
    bottleneck_values = sess.run(bottleneck_tensor,
                                 {resized_input_tensor: distorted_image_data})
    bottleneck_values = np.squeeze(bottleneck_values)
    bottlenecks.append(bottleneck_values)
    ground_truths.append(label_index)
  return bottlenecks, ground_truths


def should_distort_images(flip_left_right, random_crop, random_scale,
                          random_brightness):
  """Whether any distortions are enabled, from the input flags.

  Args:
    flip_left_right: Boolean whether to randomly mirror images horizontally.
    random_crop: Integer percentage setting the total margin used around the
      crop box.
    random_scale: Integer percentage of how much to vary the scale by.
    random_brightness: Integer range to randomly multiply the pixel values by.

  Returns:
    Boolean value indicating whether any distortions should be applied.
  """
  return (flip_left_right or (random_crop != 0) or (random_scale != 0) or
          (random_brightness != 0))
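

# Note: enabling any single distortion (e.g. just --flip_left_right) makes this
# return True, which disables the on-disk bottleneck cache in main() below.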


def add_input_distortions(flip_left_right, random_crop, random_scale,
                          random_brightness, module_spec):
  """Creates the operations to apply the specified distortions.

  During training it can help to improve the results if we run the images
  through simple distortions like crops, scales, and flips. These reflect the
  kind of variations we expect in the real world, and so can help train the
  model to cope with natural data more effectively. Here we take the supplied
  parameters and construct a network of operations to apply them to an image.

  Cropping
  ~~~~~~~~

  Cropping is done by placing a bounding box at a random position in the full
  image. The cropping parameter controls the size of that box relative to the
  input image. If it's zero, then the box is the same size as the input and no
  cropping is performed. If the value is 50%, then the crop box will be half the
  width and height of the input. In a diagram it looks like this:

  <       width         >
  +---------------------+
  |                     |
  |   width - crop%     |
  |    <      >         |
  |    +------+         |
  |    |      |         |
  |    |      |         |
  |    |      |         |
  |    +------+         |
  |                     |
  |                     |
  +---------------------+

  Scaling
  ~~~~~~~

  Scaling is a lot like cropping, except that the bounding box is always
  centered and its size varies randomly within the given range. For example if
  the scale percentage is zero, then the bounding box is the same size as the
  input and no scaling is applied. If it's 50%, then the bounding box will be in
  a random range between half the width and height and full size.

  Args:
    flip_left_right: Boolean whether to randomly mirror images horizontally.
    random_crop: Integer percentage setting the total margin used around the
      crop box.
    random_scale: Integer percentage of how much to vary the scale by.
    random_brightness: Integer range to randomly multiply the pixel values by.
    module_spec: The hub.ModuleSpec for the image module being used.

  Returns:
    The jpeg input layer and the distorted result tensor.
  """
  input_height, input_width = hub.get_expected_image_size(module_spec)
  input_depth = hub.get_num_image_channels(module_spec)
  jpeg_data = tf.placeholder(tf.string, name='DistortJPGInput')
  decoded_image = tf.image.decode_jpeg(jpeg_data, channels=input_depth)
  # Convert from full range of uint8 to range [0,1] of float32.
  decoded_image_as_float = tf.image.convert_image_dtype(decoded_image,
                                                        tf.float32)
  decoded_image_4d = tf.expand_dims(decoded_image_as_float, 0)
  margin_scale = 1.0 + (random_crop / 100.0)
  resize_scale = 1.0 + (random_scale / 100.0)
  margin_scale_value = tf.constant(margin_scale)
  resize_scale_value = tf.random_uniform(shape=[],
                                         minval=1.0,
                                         maxval=resize_scale)
  scale_value = tf.multiply(margin_scale_value, resize_scale_value)
  precrop_width = tf.multiply(scale_value, input_width)
  precrop_height = tf.multiply(scale_value, input_height)
  precrop_shape = tf.stack([precrop_height, precrop_width])
  precrop_shape_as_int = tf.cast(precrop_shape, dtype=tf.int32)
  precropped_image = tf.image.resize_bilinear(decoded_image_4d,
                                              precrop_shape_as_int)
  precropped_image_3d = tf.squeeze(precropped_image, axis=[0])
  cropped_image = tf.random_crop(precropped_image_3d,
                                 [input_height, input_width, input_depth])
  if flip_left_right:
    flipped_image = tf.image.random_flip_left_right(cropped_image)
  else:
    flipped_image = cropped_image
  brightness_min = 1.0 - (random_brightness / 100.0)
  brightness_max = 1.0 + (random_brightness / 100.0)
  brightness_value = tf.random_uniform(shape=[],
                                       minval=brightness_min,
                                       maxval=brightness_max)
  brightened_image = tf.multiply(flipped_image, brightness_value)
  distort_result = tf.expand_dims(brightened_image, 0, name='DistortResult')
  return jpeg_data, distort_result
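

# Worked example of the scaling math above: with random_crop=10 and
# random_scale=10, margin_scale is 1.1 and resize_scale_value is drawn
# uniformly from [1.0, 1.1), so the image is first resized to between 1.1x and
# ~1.21x the module's input size before being randomly cropped back down to
# input_height x input_width.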


def variable_summaries(var):
  """Attach a lot of summaries to a Tensor (for TensorBoard visualization)."""
  with tf.name_scope('summaries'):
    mean = tf.reduce_mean(var)
    tf.summary.scalar('mean', mean)
    with tf.name_scope('stddev'):
      stddev = tf.sqrt(tf.reduce_mean(tf.square(var - mean)))
    tf.summary.scalar('stddev', stddev)
    tf.summary.scalar('max', tf.reduce_max(var))
    tf.summary.scalar('min', tf.reduce_min(var))
    tf.summary.histogram('histogram', var)


def add_final_retrain_ops(class_count, final_tensor_name, bottleneck_tensor,
                          quantize_layer, is_training):
  """Adds a new softmax and fully-connected layer for training and eval.

  We need to retrain the top layer to identify our new classes, so this
  function adds the right operations to the graph, along with some variables
  to hold the weights, and then sets up all the gradients for the backward
  pass.

  The set up for the softmax and fully-connected layers is based on:
  https://www.tensorflow.org/tutorials/mnist/beginners/index.html

  Args:
    class_count: Integer of how many categories of things we're trying to
      recognize.
    final_tensor_name: Name string for the new final node that produces
      results.
    bottleneck_tensor: The output of the main CNN graph.
    quantize_layer: Boolean, specifying whether the newly added layer should be
      instrumented for quantization with TF-Lite.
    is_training: Boolean, specifying whether the newly added layer is for
      training or eval.

  Returns:
    The tensors for the training and cross entropy results, and tensors for the
    bottleneck input and ground truth input.
  """
  batch_size, bottleneck_tensor_size = bottleneck_tensor.get_shape().as_list()
  assert batch_size is None, 'We want to work with arbitrary batch size.'
  with tf.name_scope('input'):
    bottleneck_input = tf.placeholder_with_default(
        bottleneck_tensor,
        shape=[batch_size, bottleneck_tensor_size],
        name='BottleneckInputPlaceholder')

    ground_truth_input = tf.placeholder(
        tf.int64, [batch_size], name='GroundTruthInput')

  # Organizing the following ops so they are easier to see in TensorBoard.
  layer_name = 'final_retrain_ops'
  with tf.name_scope(layer_name):
    with tf.name_scope('weights'):
      initial_value = tf.truncated_normal(
          [bottleneck_tensor_size, class_count], stddev=0.001)
      layer_weights = tf.Variable(initial_value, name='final_weights')
      variable_summaries(layer_weights)

    with tf.name_scope('biases'):
      layer_biases = tf.Variable(tf.zeros([class_count]), name='final_biases')
      variable_summaries(layer_biases)

    with tf.name_scope('Wx_plus_b'):
      logits = tf.matmul(bottleneck_input, layer_weights) + layer_biases
      tf.summary.histogram('pre_activations', logits)

  final_tensor = tf.nn.softmax(logits, name=final_tensor_name)

  # The tf.contrib.quantize functions rewrite the graph in place for
  # quantization. The imported model graph has already been rewritten, so upon
  # calling these rewrites, only the newly added final layer will be
  # transformed.
  if quantize_layer:
    if is_training:
      contrib_quantize.create_training_graph()
    else:
      contrib_quantize.create_eval_graph()

  tf.summary.histogram('activations', final_tensor)

  # If this is an eval graph, we don't need to add loss ops or an optimizer.
  if not is_training:
    return None, None, bottleneck_input, ground_truth_input, final_tensor

  with tf.name_scope('cross_entropy'):
    cross_entropy_mean = tf.losses.sparse_softmax_cross_entropy(
        labels=ground_truth_input, logits=logits)

  tf.summary.scalar('cross_entropy', cross_entropy_mean)

  with tf.name_scope('train'):
    optimizer = tf.train.GradientDescentOptimizer(FLAGS.learning_rate)
    train_step = optimizer.minimize(cross_entropy_mean)

  return (train_step, cross_entropy_mean, bottleneck_input, ground_truth_input,
          final_tensor)
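

# The trainable parameters added above are exactly the N biases plus
# bottleneck_size * N weights mentioned in the module docstring (2048 * N for
# Inception V3).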


def add_evaluation_step(result_tensor, ground_truth_tensor):
  """Inserts the operations we need to evaluate the accuracy of our results.

  Args:
    result_tensor: The new final node that produces results.
    ground_truth_tensor: The node we feed ground truth data into.

  Returns:
    Tuple of (evaluation step, prediction).
  """
  with tf.name_scope('accuracy'):
    with tf.name_scope('correct_prediction'):
      prediction = tf.argmax(result_tensor, 1)
      correct_prediction = tf.equal(prediction, ground_truth_tensor)
    with tf.name_scope('accuracy'):
      evaluation_step = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
  tf.summary.scalar('accuracy', evaluation_step)
  return evaluation_step, prediction


def run_final_eval(train_session, module_spec, class_count, image_lists,
                   jpeg_data_tensor, decoded_image_tensor,
                   resized_image_tensor, bottleneck_tensor):
  """Runs a final evaluation on an eval graph using the test data set.

  Args:
    train_session: Session for the train graph with the tensors below.
    module_spec: The hub.ModuleSpec for the image module being used.
    class_count: Number of classes.
    image_lists: OrderedDict of training images for each label.
    jpeg_data_tensor: The layer to feed jpeg image data into.
    decoded_image_tensor: The output of decoding and resizing the image.
    resized_image_tensor: The input node of the recognition graph.
    bottleneck_tensor: The bottleneck output layer of the CNN graph.
  """
  test_bottlenecks, test_ground_truth, test_filenames = (
      get_random_cached_bottlenecks(train_session, image_lists,
                                    FLAGS.test_batch_size,
                                    'testing', FLAGS.bottleneck_dir,
                                    FLAGS.image_dir, jpeg_data_tensor,
                                    decoded_image_tensor, resized_image_tensor,
                                    bottleneck_tensor, FLAGS.tfhub_module))

  (eval_session, _, bottleneck_input, ground_truth_input, evaluation_step,
   prediction) = build_eval_session(module_spec, class_count)
  test_accuracy, predictions = eval_session.run(
      [evaluation_step, prediction],
      feed_dict={
          bottleneck_input: test_bottlenecks,
          ground_truth_input: test_ground_truth
      })
  logging.info('Final test accuracy = %.1f%% (N=%d)',
               test_accuracy * 100, len(test_bottlenecks))

  if FLAGS.print_misclassified_test_images:
    logging.info('=== MISCLASSIFIED TEST IMAGES ===')
    for i, test_filename in enumerate(test_filenames):
      if predictions[i] != test_ground_truth[i]:
        logging.info('%70s  %s', test_filename,
                     list(image_lists.keys())[predictions[i]])


def build_eval_session(module_spec, class_count):
  """Builds a restored eval session without train operations for exporting.

  Args:
    module_spec: The hub.ModuleSpec for the image module being used.
    class_count: Number of classes.

  Returns:
    Eval session containing the restored eval graph.
    The bottleneck input, ground truth, eval step, and prediction tensors.
  """
  # If quantized, we need to create the correct eval graph for exporting.
  eval_graph, bottleneck_tensor, resized_input_tensor, wants_quantization = (
      create_module_graph(module_spec))

  eval_sess = tf.Session(graph=eval_graph)
  with eval_graph.as_default():
    # Add the new layer for exporting.
    (_, _, bottleneck_input,
     ground_truth_input, final_tensor) = add_final_retrain_ops(
         class_count, FLAGS.final_tensor_name, bottleneck_tensor,
         wants_quantization, is_training=False)

    # Now we need to restore the values from the training graph to the eval
    # graph.
    tf.train.Saver().restore(eval_sess, FLAGS.checkpoint_path)

    evaluation_step, prediction = add_evaluation_step(final_tensor,
                                                      ground_truth_input)

  return (eval_sess, resized_input_tensor, bottleneck_input, ground_truth_input,
          evaluation_step, prediction)


def save_graph_to_file(graph_file_name, module_spec, class_count):
  """Saves a graph to file, creating a valid quantized one if necessary."""
  sess, _, _, _, _, _ = build_eval_session(module_spec, class_count)
  graph = sess.graph

  output_graph_def = tf.graph_util.convert_variables_to_constants(
      sess, graph.as_graph_def(), [FLAGS.final_tensor_name])

  with tf.gfile.GFile(graph_file_name, 'wb') as f:
    f.write(output_graph_def.SerializeToString())


def prepare_file_system():
  # Set up the directory we'll write summaries to for TensorBoard.
  if tf.gfile.Exists(FLAGS.summaries_dir):
    tf.gfile.DeleteRecursively(FLAGS.summaries_dir)
  tf.gfile.MakeDirs(FLAGS.summaries_dir)
  if FLAGS.intermediate_store_frequency > 0:
    ensure_dir_exists(FLAGS.intermediate_output_graphs_dir)
  return


def add_jpeg_decoding(module_spec):
  """Adds operations that perform JPEG decoding and resizing to the graph.

  Args:
    module_spec: The hub.ModuleSpec for the image module being used.

  Returns:
    Tensors for the node to feed JPEG data into, and the output of the
    preprocessing steps.
  """
  input_height, input_width = hub.get_expected_image_size(module_spec)
  input_depth = hub.get_num_image_channels(module_spec)
  jpeg_data = tf.placeholder(tf.string, name='DecodeJPGInput')
  decoded_image = tf.image.decode_jpeg(jpeg_data, channels=input_depth)
  # Convert from full range of uint8 to range [0,1] of float32.
  decoded_image_as_float = tf.image.convert_image_dtype(decoded_image,
                                                        tf.float32)
  decoded_image_4d = tf.expand_dims(decoded_image_as_float, 0)
  resize_shape = tf.stack([input_height, input_width])
  resize_shape_as_int = tf.cast(resize_shape, dtype=tf.int32)
  resized_image = tf.image.resize_bilinear(decoded_image_4d,
                                           resize_shape_as_int)
  return jpeg_data, resized_image
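

# Sketch of how this sub-graph is driven elsewhere in this file: raw JPEG
# bytes are fed into jpeg_data and the resized float image is read back, e.g.
#   resized = sess.run(resized_image, {jpeg_data: image_bytes})
# (see run_bottleneck_on_image above).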


def export_model(module_spec, class_count, saved_model_dir):
  """Exports model for serving.

  Args:
    module_spec: The hub.ModuleSpec for the image module being used.
    class_count: The number of classes.
    saved_model_dir: Directory in which to save exported model and variables.
  """
  # The SavedModel should hold the eval graph.
  sess, in_image, _, _, _, _ = build_eval_session(module_spec, class_count)
  with sess.graph.as_default() as graph:
    tf.saved_model.simple_save(
        sess,
        saved_model_dir,
        inputs={'image': in_image},
        outputs={'prediction': graph.get_tensor_by_name('final_result:0')},
        legacy_init_op=tf.group(tf.tables_initializer(), name='legacy_init_op')
    )


def logging_level_verbosity(logging_verbosity):
  """Converts logging_level into TensorFlow logging verbosity value.

  Args:
    logging_verbosity: String value representing logging level: 'DEBUG',
      'INFO', 'WARN', 'ERROR', 'FATAL'.
  """
  name_to_level = {
      'FATAL': logging.FATAL,
      'ERROR': logging.ERROR,
      'WARN': logging.WARN,
      'INFO': logging.INFO,
      'DEBUG': logging.DEBUG
  }

  try:
    return name_to_level[logging_verbosity]
  except Exception as e:
    raise RuntimeError('Unsupported logging verbosity (%s). Use one of %s.' %
                       (str(e), list(name_to_level)))


def main(_):
  # Needed to make sure the logging output is visible.
  # See https://github.com/tensorflow/tensorflow/issues/3047
  logging_verbosity = logging_level_verbosity(FLAGS.logging_verbosity)
  logging.set_verbosity(logging_verbosity)

  if not FLAGS.image_dir:
    logging.error('Must set flag --image_dir.')
    return -1

  # Prepare necessary directories that can be used during training.
  prepare_file_system()

  # Look at the folder structure, and create lists of all the images.
  image_lists = create_image_lists(FLAGS.image_dir, FLAGS.testing_percentage,
                                   FLAGS.validation_percentage)
  class_count = len(image_lists.keys())
  if class_count == 0:
    logging.error('No valid folders of images found at %s', FLAGS.image_dir)
    return -1
  if class_count == 1:
    logging.error('Only one valid folder of images found at %s - multiple '
                  'classes are needed for classification.', FLAGS.image_dir)
    return -1

  # See if the command-line flags mean we're applying any distortions.
  do_distort_images = should_distort_images(
      FLAGS.flip_left_right, FLAGS.random_crop, FLAGS.random_scale,
      FLAGS.random_brightness)

  # Set up the pre-trained graph.
  module_spec = hub.load_module_spec(FLAGS.tfhub_module)
  graph, bottleneck_tensor, resized_image_tensor, wants_quantization = (
      create_module_graph(module_spec))

  # Add the new layer that we'll be training.
  with graph.as_default():
    (train_step, cross_entropy, bottleneck_input,
     ground_truth_input, final_tensor) = add_final_retrain_ops(
         class_count, FLAGS.final_tensor_name, bottleneck_tensor,
         wants_quantization, is_training=True)

  with tf.Session(graph=graph) as sess:
    # Initialize all weights: for the module to their pretrained values,
    # and for the newly added retraining layer to random initial values.
    init = tf.global_variables_initializer()
    sess.run(init)

    # Set up the image decoding sub-graph.
    jpeg_data_tensor, decoded_image_tensor = add_jpeg_decoding(module_spec)

    if do_distort_images:
      # We will be applying distortions, so set up the operations we'll need.
      (distorted_jpeg_data_tensor,
       distorted_image_tensor) = add_input_distortions(
           FLAGS.flip_left_right, FLAGS.random_crop, FLAGS.random_scale,
           FLAGS.random_brightness, module_spec)
    else:
      # We'll make sure we've calculated the 'bottleneck' image summaries and
      # cached them on disk.
      cache_bottlenecks(sess, image_lists, FLAGS.image_dir,
                        FLAGS.bottleneck_dir, jpeg_data_tensor,
                        decoded_image_tensor, resized_image_tensor,
                        bottleneck_tensor, FLAGS.tfhub_module)

    # Create the operations we need to evaluate the accuracy of our new layer.
    evaluation_step, _ = add_evaluation_step(final_tensor, ground_truth_input)

    # Merge all the summaries and write them out to the summaries_dir.
    merged = tf.summary.merge_all()
    train_writer = tf.summary.FileWriter(FLAGS.summaries_dir + '/train',
                                         sess.graph)

    validation_writer = tf.summary.FileWriter(
        FLAGS.summaries_dir + '/validation')

    # Create a train saver that is used to restore values into an eval graph
    # when exporting models.
    train_saver = tf.train.Saver()

    # Run the training for as many cycles as requested on the command line.
    for i in range(FLAGS.how_many_training_steps):
      # Get a batch of input bottleneck values, either calculated fresh every
      # time with distortions applied, or from the cache stored on disk.
      if do_distort_images:
        (train_bottlenecks,
         train_ground_truth) = get_random_distorted_bottlenecks(
             sess, image_lists, FLAGS.train_batch_size, 'training',
             FLAGS.image_dir, distorted_jpeg_data_tensor,
             distorted_image_tensor, resized_image_tensor, bottleneck_tensor)
      else:
        (train_bottlenecks,
         train_ground_truth, _) = get_random_cached_bottlenecks(
             sess, image_lists, FLAGS.train_batch_size, 'training',
             FLAGS.bottleneck_dir, FLAGS.image_dir, jpeg_data_tensor,
             decoded_image_tensor, resized_image_tensor, bottleneck_tensor,
             FLAGS.tfhub_module)
      # Feed the bottlenecks and ground truth into the graph, and run a
      # training step. Capture training summaries for TensorBoard with the
      # `merged` op.
      train_summary, _ = sess.run(
          [merged, train_step],
          feed_dict={bottleneck_input: train_bottlenecks,
                     ground_truth_input: train_ground_truth})
      train_writer.add_summary(train_summary, i)

      # Every so often, print out how well the graph is training.
      is_last_step = (i + 1 == FLAGS.how_many_training_steps)
      if (i % FLAGS.eval_step_interval) == 0 or is_last_step:
        train_accuracy, cross_entropy_value = sess.run(
            [evaluation_step, cross_entropy],
            feed_dict={bottleneck_input: train_bottlenecks,
                       ground_truth_input: train_ground_truth})
        logging.info('%s: Step %d: Train accuracy = %.1f%%',
                     datetime.now(), i, train_accuracy * 100)
        logging.info('%s: Step %d: Cross entropy = %f',
                     datetime.now(), i, cross_entropy_value)
        # TODO: Make this use an eval graph, to avoid quantization
        # moving averages being updated by the validation set, though in
        # practice this makes a negligible difference.
        validation_bottlenecks, validation_ground_truth, _ = (
            get_random_cached_bottlenecks(
                sess, image_lists, FLAGS.validation_batch_size, 'validation',
                FLAGS.bottleneck_dir, FLAGS.image_dir, jpeg_data_tensor,
                decoded_image_tensor, resized_image_tensor, bottleneck_tensor,
                FLAGS.tfhub_module))
        # Run a validation step and capture training summaries for TensorBoard
        # with the `merged` op.
        validation_summary, validation_accuracy = sess.run(
            [merged, evaluation_step],
            feed_dict={bottleneck_input: validation_bottlenecks,
                       ground_truth_input: validation_ground_truth})
        validation_writer.add_summary(validation_summary, i)
        logging.info('%s: Step %d: Validation accuracy = %.1f%% (N=%d)',
                     datetime.now(), i, validation_accuracy * 100,
                     len(validation_bottlenecks))

      # Store intermediate results.
      intermediate_frequency = FLAGS.intermediate_store_frequency

      if (intermediate_frequency > 0 and (i % intermediate_frequency == 0)
          and i > 0):
        # If we want to do an intermediate save, save a checkpoint of the
        # train graph, to restore into the eval graph.
        train_saver.save(sess, FLAGS.checkpoint_path)
        intermediate_file_name = (FLAGS.intermediate_output_graphs_dir +
                                  'intermediate_' + str(i) + '.pb')
        logging.info('Save intermediate result to : %s',
                     intermediate_file_name)
        save_graph_to_file(intermediate_file_name, module_spec, class_count)

    # After training is complete, force one last save of the train checkpoint.
    train_saver.save(sess, FLAGS.checkpoint_path)

    # We've completed all our training, so run a final test evaluation on
    # some new images we haven't used before.
    run_final_eval(sess, module_spec, class_count, image_lists,
                   jpeg_data_tensor, decoded_image_tensor,
                   resized_image_tensor, bottleneck_tensor)

    # Write out the trained graph and labels with the weights stored as
    # constants.
    logging.info('Save final result to : %s', FLAGS.output_graph)
    if wants_quantization:
      logging.info('The model is instrumented for quantization with TF-Lite')
    save_graph_to_file(FLAGS.output_graph, module_spec, class_count)
    with tf.gfile.GFile(FLAGS.output_labels, 'w') as f:
      f.write('\n'.join(image_lists.keys()) + '\n')

    if FLAGS.saved_model_dir:
      export_model(module_spec, class_count, FLAGS.saved_model_dir)


if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  parser.add_argument(
      '--image_dir',
      type=str,
      default='',
      help='Path to folders of labeled images.'
  )
  parser.add_argument(
      '--output_graph',
      type=str,
      default='/tmp/output_graph.pb',
      help='Where to save the trained graph.'
  )
  parser.add_argument(
      '--intermediate_output_graphs_dir',
      type=str,
      default='/tmp/intermediate_graph/',
      help='Where to save the intermediate graphs.'
  )
  parser.add_argument(
      '--intermediate_store_frequency',
      type=int,
      default=0,
      help="""\
      How often (in steps) to store the intermediate graph. If "0", it will
      not be stored.\
      """
  )
  parser.add_argument(
      '--output_labels',
      type=str,
      default='/tmp/output_labels.txt',
      help='Where to save the trained graph\'s labels.'
  )
  parser.add_argument(
      '--summaries_dir',
      type=str,
      default='/tmp/retrain_logs',
      help='Where to save summary logs for TensorBoard.'
  )
  parser.add_argument(
      '--how_many_training_steps',
      type=int,
      default=4000,
      help='How many training steps to run before ending.'
  )
  parser.add_argument(
      '--learning_rate',
      type=float,
      default=0.01,
      help='How large a learning rate to use when training.'
  )
  parser.add_argument(
      '--testing_percentage',
      type=int,
      default=10,
      help='What percentage of images to use as a test set.'
  )
  parser.add_argument(
      '--validation_percentage',
      type=int,
      default=10,
      help='What percentage of images to use as a validation set.'
  )
  parser.add_argument(
      '--eval_step_interval',
      type=int,
      default=10,
      help='How often to evaluate the training results.'
  )
  parser.add_argument(
      '--train_batch_size',
      type=int,
      default=100,
      help='How many images to train on at a time.'
  )
  parser.add_argument(
      '--test_batch_size',
      type=int,
      default=-1,
      help="""\
      How many images to test on. This test set is only used once, to evaluate
      the final accuracy of the model after training completes.
      A value of -1 causes the entire test set to be used, which leads to more
      stable results across runs.\
      """
  )
  parser.add_argument(
      '--validation_batch_size',
      type=int,
      default=100,
      help="""\
      How many images to use in an evaluation batch. This validation set is
      used much more often than the test set, and is an early indicator of how
      accurate the model is during training.
      A value of -1 causes the entire validation set to be used, which leads to
      more stable results across training iterations, but may be slower on
      large training sets.\
      """
  )
  parser.add_argument(
      '--print_misclassified_test_images',
      default=False,
      help="""\
      Whether to print out a list of all misclassified test images.\
      """,
      action='store_true'
  )
  parser.add_argument(
      '--bottleneck_dir',
      type=str,
      default='/tmp/bottleneck',
      help='Path to cache bottleneck layer values as files.'
  )
  parser.add_argument(
      '--final_tensor_name',
      type=str,
      default='final_result',
      help="""\
      The name of the output classification layer in the retrained graph.\
      """
  )
  parser.add_argument(
      '--flip_left_right',
      default=False,
      help="""\
      Whether to randomly flip half of the training images horizontally.\
      """,
      action='store_true'
  )
  parser.add_argument(
      '--random_crop',
      type=int,
      default=0,
      help="""\
      A percentage determining how much of a margin to randomly crop off the
      training images.\
      """
  )
  parser.add_argument(
      '--random_scale',
      type=int,
      default=0,
      help="""\
      A percentage determining how much to randomly scale up the size of the
      training images by.\
      """
  )
  parser.add_argument(
      '--random_brightness',
      type=int,
      default=0,
      help="""\
      A percentage determining how much to randomly multiply the training image
      input pixels up or down by.\
      """
  )
  parser.add_argument(
      '--tfhub_module',
      type=str,
      default=(
          'https://tfhub.dev/google/imagenet/inception_v3/feature_vector/3'),
      help="""\
      Which TensorFlow Hub module to use. For more options,
      search https://tfhub.dev for image feature vector modules.\
      """)
  parser.add_argument(
      '--saved_model_dir',
      type=str,
      default='',
      help='Where to save the exported graph.')
  parser.add_argument(
      '--logging_verbosity',
      type=str,
      default='INFO',
      choices=['DEBUG', 'INFO', 'WARN', 'ERROR', 'FATAL'],
      help='How much logging output should be produced.')
  parser.add_argument(
      '--checkpoint_path',
      type=str,
      default='/tmp/_retrain_checkpoint',
      help='Where to save checkpoint files.'
  )
  FLAGS, unparsed = parser.parse_known_args()
  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
# -*- coding: utf-8 -*-

"""Example that runs inference on an image using a model retrained from the
Inception v3 architecture."""

import numpy as np
import tensorflow as tf

imagePath = '/tmp/test_chartreux.jpg'  # Path of the image to run inference on.
modelFullPath = '/tmp/output_graph.pb'  # Path of the graph file to load.
labelsFullPath = '/tmp/output_labels.txt'  # Path of the labels file to load.
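# Note: these paths match the default --output_graph and --output_labels
# locations written by retrain.py above; test_chartreux.jpg is just a sample
# image name and can be any JPEG.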


def create_graph():
    """Creates a graph from a saved GraphDef file."""
    # Create the graph from the saved graph_def.pb.
    with tf.gfile.FastGFile(modelFullPath, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        _ = tf.import_graph_def(graph_def, name='')


def run_inference_on_image():
    answer = None

    if not tf.gfile.Exists(imagePath):
        tf.logging.fatal('File does not exist %s', imagePath)
        return answer

    image_data = tf.gfile.FastGFile(imagePath, 'rb').read()

    # Create the graph from the saved GraphDef file.
    create_graph()

    with tf.Session() as sess:

        softmax_tensor = sess.graph.get_tensor_by_name('final_result:0')
        # Note: feeding 'DecodeJpeg/contents:0' assumes a graph retrained with
        # the older tensorflow/examples/image_retraining script; a graph
        # written by the TF-Hub retrain.py above takes a resized float image
        # through its input placeholder instead, so this feed may need to be
        # adapted.
        predictions = sess.run(softmax_tensor,
                               {'DecodeJpeg/contents:0': image_data})
        predictions = np.squeeze(predictions)

        top_k = predictions.argsort()[-5:][::-1]  # Get the top 5 predictions.
        with open(labelsFullPath, 'r') as f:
            lines = f.readlines()
        labels = [line.rstrip('\n') for line in lines]
        for node_id in top_k:
            human_string = labels[node_id]
            score = predictions[node_id]
            print('%s (score = %.5f)' % (human_string, score))

        answer = labels[top_k[0]]
    return answer


if __name__ == '__main__':
    run_inference_on_image()