이현규

Convert submodule to files

youtube-8m @ e6f6bf68
1 -Subproject commit e6f6bf682d20bb21904ea9c081c15e070809d914
1 +# Copyright 2016 Google Inc. All Rights Reserved.
2 +#
3 +# Licensed under the Apache License, Version 2.0 (the "License");
4 +# you may not use this file except in compliance with the License.
5 +# You may obtain a copy of the License at
6 +#
7 +# http://www.apache.org/licenses/LICENSE-2.0
8 +#
9 +# Unless required by applicable law or agreed to in writing, software
10 +# distributed under the License is distributed on an "AS-IS" BASIS,
11 +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 +# See the License for the specific language governing permissions and
13 +# limitations under the License.
1 +# Copyright 2016 Google Inc. All Rights Reserved.
2 +#
3 +# Licensed under the Apache License, Version 2.0 (the "License");
4 +# you may not use this file except in compliance with the License.
5 +# You may obtain a copy of the License at
6 +#
7 +# http://www.apache.org/licenses/LICENSE-2.0
8 +#
9 +# Unless required by applicable law or agreed to in writing, software
10 +# distributed under the License is distributed on an "AS-IS" BASIS,
11 +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 +# See the License for the specific language governing permissions and
13 +# limitations under the License.
14 +"""Calculate or keep track of the interpolated average precision.
15 +
16 +It provides an interface for calculating interpolated average precision for an
17 +entire list or the top-n ranked items. For the definition of the
18 +(non-)interpolated average precision:
19 +http://trec.nist.gov/pubs/trec15/appendices/CE.MEASURES06.pdf
20 +
21 +Example usages:
22 +1) Use it as a static function call to directly calculate average precision for
23 +a short ranked list in memory.
24 +
25 +```
26 +import random
27 +import numpy
28 +p = numpy.array([random.random() for _ in range(10)])
29 +a = numpy.array([random.choice([0, 1]) for _ in range(10)])
30 +
31 +ap = average_precision_calculator.AveragePrecisionCalculator.ap(p, a)
32 +```
33 +
34 +2) Use it as an object for a long ranked list that cannot be stored in memory,
35 +or when only part of the predictions is available at a time (e.g. Tensorflow
36 +predictions). In this case, we first call accumulate many times to process
37 +parts of the ranked list. After processing all the parts, we call
38 +peek_ap_at_n.
39 +```
40 +p1 = numpy.array([random.random() for _ in range(5)])
41 +a1 = numpy.array([random.choice([0, 1]) for _ in range(5)])
42 +p2 = numpy.array([random.random() for _ in range(5)])
43 +a2 = numpy.array([random.choice([0, 1]) for _ in range(5)])
44 +
45 +# non-interpolated average precision at top 10
46 +calculator = average_precision_calculator.AveragePrecisionCalculator(10)
47 +calculator.accumulate(p1, a1)
48 +calculator.accumulate(p2, a2)
49 +ap3 = calculator.peek_ap_at_n()
50 +```
51 +"""
52 +
53 +import heapq
54 +import random
55 +import numbers
56 +
57 +import numpy
58 +
59 +
60 +class AveragePrecisionCalculator(object):
61 + """Calculate the average precision and average precision at n."""
62 +
63 + def __init__(self, top_n=None):
64 + """Construct an AveragePrecisionCalculator to calculate average precision.
65 +
66 + This class is used to calculate the average precision for a single label.
67 +
68 + Args:
69 + top_n: A positive Integer specifying the average precision at n, or None
70 + to use all provided data points.
71 +
72 + Raises:
73 + ValueError: An error occurred when the top_n is not a positive integer.
74 + """
75 + if not ((isinstance(top_n, int) and top_n >= 0) or top_n is None):
76 + raise ValueError("top_n must be a positive integer or None.")
77 +
78 + self._top_n = top_n # average precision at n
79 + self._total_positives = 0 # total number of positives have seen
80 + self._heap = [] # max heap of (prediction, actual)
81 +
82 + @property
83 + def heap_size(self):
84 + """Gets the heap size maintained in the class."""
85 + return len(self._heap)
86 +
87 + @property
88 + def num_accumulated_positives(self):
89 + """Gets the number of positive samples that have been accumulated."""
90 + return self._total_positives
91 +
92 + def accumulate(self, predictions, actuals, num_positives=None):
93 + """Accumulate the predictions and their ground truth labels.
94 +
95 + After the function call, we may call peek_ap_at_n to actually calculate
96 + the average precision.
97 + Note predictions and actuals must have the same shape.
98 +
99 + Args:
100 + predictions: a list storing the prediction scores.
101 + actuals: a list storing the ground truth labels. Any value larger than 0
102 + will be treated as positives, otherwise as negatives.
103 + num_positives: If the 'predictions' and 'actuals' inputs aren't complete,
104 + then it's possible some true positives were missed in them. In that case,
105 + you can provide 'num_positives' in order to accurately track recall.
106 +
107 + Raises:
108 + ValueError: An error occurred when the format of the input is not the
109 + numpy 1-D array or the shape of predictions and actuals does not match.
110 + """
111 + if len(predictions) != len(actuals):
112 + raise ValueError("the shape of predictions and actuals does not match.")
113 +
114 + if num_positives is not None:
115 + if not isinstance(num_positives, numbers.Number) or num_positives < 0:
116 + raise ValueError(
117 + "'num_positives' must be a non-negative number.")
118 +
119 + if num_positives is not None:
120 + self._total_positives += num_positives
121 + else:
122 + self._total_positives += numpy.size(
123 + numpy.where(numpy.array(actuals) > 1e-5))
124 + topk = self._top_n
125 + heap = self._heap
126 +
127 + for i in range(numpy.size(predictions)):
128 + if topk is None or len(heap) < topk:
129 + heapq.heappush(heap, (predictions[i], actuals[i]))
130 + else:
131 + if predictions[i] > heap[0][0]: # heap[0] is the smallest
132 + heapq.heappop(heap)
133 + heapq.heappush(heap, (predictions[i], actuals[i]))
134 +
135 + def clear(self):
136 + """Clear the accumulated predictions."""
137 + self._heap = []
138 + self._total_positives = 0
139 +
140 + def peek_ap_at_n(self):
141 + """Peek the non-interpolated average precision at n.
142 +
143 + Returns:
144 + The non-interpolated average precision at n (default 0).
145 + If n is larger than the length of the ranked list,
146 + the average precision will be returned.
147 + """
148 + if self.heap_size <= 0:
149 + return 0
150 + predlists = numpy.array(list(zip(*self._heap)))
151 +
152 + ap = self.ap_at_n(predlists[0],
153 + predlists[1],
154 + n=self._top_n,
155 + total_num_positives=self._total_positives)
156 + return ap
157 +
158 + @staticmethod
159 + def ap(predictions, actuals):
160 + """Calculate the non-interpolated average precision.
161 +
162 + Args:
163 + predictions: a numpy 1-D array storing the sparse prediction scores.
164 + actuals: a numpy 1-D array storing the ground truth labels. Any value
165 + larger than 0 will be treated as positives, otherwise as negatives.
166 +
167 + Returns:
168 + The non-interpolated average precision at n.
169 + If n is larger than the length of the ranked list,
170 + the average precision will be returned.
171 +
172 + Raises:
173 + ValueError: An error occurred when the format of the input is not the
174 + numpy 1-D array or the shape of predictions and actuals does not match.
175 + """
176 + return AveragePrecisionCalculator.ap_at_n(predictions, actuals, n=None)
177 +
178 + @staticmethod
179 + def ap_at_n(predictions, actuals, n=20, total_num_positives=None):
180 + """Calculate the non-interpolated average precision.
181 +
182 + Args:
183 + predictions: a numpy 1-D array storing the sparse prediction scores.
184 + actuals: a numpy 1-D array storing the ground truth labels. Any value
185 + larger than 0 will be treated as positives, otherwise as negatives.
186 + n: the top n items to be considered in ap@n.
187 + total_num_positives: (optional) the total number of positives in the list.
188 + If specified, it will be used in the calculation.
189 +
190 + Returns:
191 + The non-interpolated average precision at n.
192 + If n is larger than the length of the ranked list,
193 + the average precision will be returned.
194 +
195 + Raises:
196 + ValueError: An error occurred when
197 + 1) the format of the input is not the numpy 1-D array;
198 + 2) the shape of predictions and actuals does not match;
199 + 3) the input n is not a positive integer.
200 + """
201 + if len(predictions) != len(actuals):
202 + raise ValueError("the shape of predictions and actuals does not match.")
203 +
204 + if n is not None:
205 + if not isinstance(n, int) or n <= 0:
206 + raise ValueError("n must be 'None' or a positive integer."
207 + " It was '%s'." % n)
208 +
209 + ap = 0.0
210 +
211 + predictions = numpy.array(predictions)
212 + actuals = numpy.array(actuals)
213 +
214 + # Shuffle to break ties in prediction scores and avoid overestimating the ap.
215 + predictions, actuals = AveragePrecisionCalculator._shuffle(
216 + predictions, actuals)
217 + sortidx = sorted(range(len(predictions)),
218 + key=lambda k: predictions[k],
219 + reverse=True)
220 +
221 + if total_num_positives is None:
222 + numpos = numpy.size(numpy.where(actuals > 0))
223 + else:
224 + numpos = total_num_positives
225 +
226 + if numpos == 0:
227 + return 0
228 +
229 + if n is not None:
230 + numpos = min(numpos, n)
231 + delta_recall = 1.0 / numpos
232 + poscount = 0.0
233 +
234 + # calculate the ap
235 + r = len(sortidx)
236 + if n is not None:
237 + r = min(r, n)
238 + for i in range(r):
239 + if actuals[sortidx[i]] > 0:
240 + poscount += 1
241 + ap += poscount / (i + 1) * delta_recall
242 + return ap
243 +
244 + @staticmethod
245 + def _shuffle(predictions, actuals):
246 + random.seed(0)
247 + suffidx = random.sample(range(len(predictions)), len(predictions))
248 + predictions = predictions[suffidx]
249 + actuals = actuals[suffidx]
250 + return predictions, actuals
251 +
252 + @staticmethod
253 + def _zero_one_normalize(predictions, epsilon=1e-7):
254 + """Normalize the predictions to the range between 0.0 and 1.0.
255 +
256 + For some predictions like SVM predictions, we need to normalize them before
257 + calculate the interpolated average precision. The normalization will not
258 + change the rank in the original list and thus won't change the average
259 + precision.
260 +
261 + Args:
262 + predictions: a numpy 1-D array storing the sparse prediction scores.
263 + epsilon: a small constant to avoid denominator being zero.
264 +
265 + Returns:
266 + The normalized prediction.
267 + """
268 + denominator = numpy.max(predictions) - numpy.min(predictions)
269 + ret = (predictions - numpy.min(predictions)) / numpy.maximum(
270 + denominator, epsilon)
271 + return ret
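# A minimal usage sketch (hypothetical, hand-checkable example): with ranked
# actuals [1, 0, 1, 1] the precisions at the positive ranks are 1/1, 2/3 and
# 3/4, each weighted by delta_recall = 1/3, so `AveragePrecisionCalculator.ap`
# should return roughly 0.8056.
if __name__ == "__main__":
  example_predictions = numpy.array([0.9, 0.8, 0.7, 0.6])
  example_actuals = numpy.array([1, 0, 1, 1])
  # Expected: (1/1 + 2/3 + 3/4) / 3 ~= 0.8056
  print(AveragePrecisionCalculator.ap(example_predictions, example_actuals))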
1 +# Copyright 2016 Google Inc. All Rights Reserved.
2 +#
3 +# Licensed under the Apache License, Version 2.0 (the "License");
4 +# you may not use this file except in compliance with the License.
5 +# You may obtain a copy of the License at
6 +#
7 +# http://www.apache.org/licenses/LICENSE-2.0
8 +#
9 +# Unless required by applicable law or agreed to in writing, software
10 +# distributed under the License is distributed on an "AS-IS" BASIS,
11 +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 +# See the License for the specific language governing permissions and
13 +# limitations under the License.
14 +"""Utility to convert the output of batch prediction into a CSV submission.
15 +
16 +It converts the JSON files created by the command
17 +'gcloud beta ml jobs submit prediction' into a CSV file ready for submission.
18 +"""
19 +
20 +import json
21 +import tensorflow as tf
22 +
23 +from builtins import range
24 +from tensorflow import app
25 +from tensorflow import flags
26 +from tensorflow import gfile
27 +from tensorflow import logging
28 +
29 +FLAGS = flags.FLAGS
30 +
31 +if __name__ == "__main__":
32 +
33 + flags.DEFINE_string(
34 + "json_prediction_files_pattern", None,
35 + "Pattern specifying the list of JSON files that the command "
36 + "'gcloud beta ml jobs submit prediction' outputs. These files are "
37 + "located in the output path of the prediction command and are prefixed "
38 + "with 'prediction.results'.")
39 + flags.DEFINE_string(
40 + "csv_output_file", None,
41 + "The file to save the predictions converted to the CSV format.")
42 +
43 +
44 +def get_csv_header():
45 + return "VideoId,LabelConfidencePairs\n"
46 +
47 +
48 +def to_csv_row(json_data):
49 +
50 + video_id = json_data["video_id"]
51 +
52 + class_indexes = json_data["class_indexes"]
53 + predictions = json_data["predictions"]
54 +
55 + if isinstance(video_id, list):
56 + video_id = video_id[0]
57 + class_indexes = class_indexes[0]
58 + predictions = predictions[0]
59 +
60 + if len(class_indexes) != len(predictions):
61 + raise ValueError(
62 + "The number of indexes (%s) and predictions (%s) must be equal." %
63 + (len(class_indexes), len(predictions)))
64 +
65 + return (video_id.decode("utf-8") + "," +
66 + " ".join("%i %f" % (class_indexes[i], predictions[i])
67 + for i in range(len(class_indexes))) + "\n")
68 +
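# Illustrative sketch with hypothetical values: each JSON record carries
# "video_id", "class_indexes" and "predictions", and to_csv_row turns it into
# a "<video_id>,<index> <score> <index> <score> ..." line.
def _example_to_csv_row():
  example = {"video_id": [b"abcd1234"],
             "class_indexes": [[5, 12]],
             "predictions": [[0.9, 0.3]]}
  # Expected: "abcd1234,5 0.900000 12 0.300000\n"
  return to_csv_row(example)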
69 +
70 +def main(unused_argv):
71 + logging.set_verbosity(tf.logging.INFO)
72 +
73 + if not FLAGS.json_prediction_files_pattern:
74 + raise ValueError(
75 + "The flag --json_prediction_files_pattern must be specified.")
76 +
77 + if not FLAGS.csv_output_file:
78 + raise ValueError("The flag --csv_output_file must be specified.")
79 +
80 + logging.info("Looking for prediction files with pattern: %s",
81 + FLAGS.json_prediction_files_pattern)
82 +
83 + file_paths = gfile.Glob(FLAGS.json_prediction_files_pattern)
84 + logging.info("Found files: %s", file_paths)
85 +
86 + logging.info("Writing submission file to: %s", FLAGS.csv_output_file)
87 + with gfile.Open(FLAGS.csv_output_file, "w+") as output_file:
88 + output_file.write(get_csv_header())
89 +
90 + for file_path in file_paths:
91 + logging.info("processing file: %s", file_path)
92 +
93 + with gfile.Open(file_path) as input_file:
94 +
95 + for line in input_file:
96 + json_data = json.loads(line)
97 + output_file.write(to_csv_row(json_data))
98 +
99 + output_file.flush()
100 + logging.info("done")
101 +
102 +
103 +if __name__ == "__main__":
104 + app.run()
No preview for this file type
1 +import numpy as np
2 +import tensorflow as tf
3 +from tensorflow import logging
4 +from tensorflow import gfile
5 +import esot3ria.pbutil as pbutil
6 +
7 +
8 +def get_segments(batch_video_mtx, batch_num_frames, segment_size):
9 + """Get segment-level inputs from frame-level features."""
10 + video_batch_size = batch_video_mtx.shape[0]
11 + max_frame = batch_video_mtx.shape[1]
12 + feature_dim = batch_video_mtx.shape[-1]
13 + padded_segment_sizes = (batch_num_frames + segment_size - 1) // segment_size
14 + padded_segment_sizes *= segment_size
15 + segment_mask = (
16 + 0 < (padded_segment_sizes[:, np.newaxis] - np.arange(0, max_frame)))
17 +
18 + # Segment bags.
19 + frame_bags = batch_video_mtx.reshape((-1, feature_dim))
20 + segment_frames = frame_bags[segment_mask.reshape(-1)].reshape(
21 + (-1, segment_size, feature_dim))
22 +
23 + # Segment num frames.
24 + segment_start_times = np.arange(0, max_frame, segment_size)
25 + num_segments = batch_num_frames[:, np.newaxis] - segment_start_times
26 + num_segment_bags = num_segments.reshape((-1))
27 + valid_segment_mask = num_segment_bags > 0
28 + segment_num_frames = num_segment_bags[valid_segment_mask]
29 + segment_num_frames[segment_num_frames > segment_size] = segment_size
30 +
31 + max_segment_num = (max_frame + segment_size - 1) // segment_size
32 + video_idxs = np.tile(
33 + np.arange(0, video_batch_size)[:, np.newaxis], [1, max_segment_num])
34 + segment_idxs = np.tile(segment_start_times, [video_batch_size, 1])
35 + idx_bags = np.stack([video_idxs, segment_idxs], axis=-1).reshape((-1, 2))
36 + video_segment_ids = idx_bags[valid_segment_mask]
37 +
38 + return {
39 + "video_batch": segment_frames,
40 + "num_frames_batch": segment_num_frames,
41 + "video_segment_ids": video_segment_ids
42 + }
43 +
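# Illustrative sketch (hypothetical helper and shapes): one video with 7 valid
# frames padded to max_frame=10 and segment_size=5 yields two 5-frame segments
# with 5 and 2 valid frames, starting at frames 0 and 5.
def _example_get_segments():
  batch_video_mtx = np.zeros((1, 10, 2), dtype=np.float32)
  batch_num_frames = np.array([7])
  segments = get_segments(batch_video_mtx, batch_num_frames, segment_size=5)
  # segments["video_batch"].shape == (2, 5, 2)
  # segments["num_frames_batch"] -> [5, 2]
  # segments["video_segment_ids"] -> [[0, 0], [0, 5]]
  return segments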
44 +
45 +def format_prediction(video_ids, predictions, top_k, whitelisted_cls_mask=None):
46 + batch_size = len(video_ids)
47 + for video_index in range(batch_size):
48 + video_prediction = predictions[video_index]
49 + if whitelisted_cls_mask is not None:
50 + # Whitelist classes.
51 + video_prediction *= whitelisted_cls_mask
52 + top_indices = np.argpartition(video_prediction, -top_k)[-top_k:]
53 + line = [(class_index, predictions[video_index][class_index])
54 + for class_index in top_indices]
55 + line = sorted(line, key=lambda p: -p[1])
56 + return (video_ids[video_index] + "," +
57 + " ".join("%i %g" % (label, score) for (label, score) in line) +
58 + "\n").encode("utf8")
59 +
60 +
61 +def inference_pb(filename):
62 + with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
63 +
64 + # 200527 Esot3riA
65 + # 0. Import SequenceExample type target from pb.
66 + target_video = pbutil.convert_pb(filename)
67 +
68 + # 1. Load video features from pb.
69 + video_id_batch_val = np.array([b'video'])
70 + n_frames = len(target_video.feature_lists.feature_list['rgb'].feature)
71 + # Cap the number of frames at 300.
72 + if n_frames > 300:
73 + n_frames = 300
74 + video_batch_val = np.zeros((300, 1152))
75 + for i in range(n_frames):
76 + video_batch_rgb_raw = target_video.feature_lists.feature_list['rgb'].feature[i].bytes_list.value[0]
77 + video_batch_rgb = np.array(tf.cast(tf.decode_raw(video_batch_rgb_raw, tf.float32), tf.float32).eval())
78 + video_batch_audio_raw = target_video.feature_lists.feature_list['audio'].feature[i].bytes_list.value[0]
79 + video_batch_audio = np.array(tf.cast(tf.decode_raw(video_batch_audio_raw, tf.float32), tf.float32).eval())
80 + video_batch_val[i] = np.concatenate([video_batch_rgb, video_batch_audio], axis=0)
81 + video_batch_val = np.array([video_batch_val])
82 + num_frames_batch_val = np.array([n_frames])
83 + # 200527 Esot3riA End
84 +
85 + # Restore checkpoint and meta-graph file
86 + checkpoint_file = '/Users/esot3ria/PycharmProjects/yt8m/models/frame' \
87 + '/sample_model/inference_model/segment_inference_model'
88 + if not gfile.Exists(checkpoint_file + ".meta"):
89 + raise IOError("Cannot find %s. Did you run eval.py?" % checkpoint_file)
90 + meta_graph_location = checkpoint_file + ".meta"
91 + logging.info("loading meta-graph: " + meta_graph_location)
92 +
93 + with tf.device("/cpu:0"):
94 + saver = tf.train.import_meta_graph(meta_graph_location,
95 + clear_devices=True)
96 + logging.info("restoring variables from " + checkpoint_file)
97 + saver.restore(sess, checkpoint_file)
98 + input_tensor = tf.get_collection("input_batch_raw")[0]
99 + num_frames_tensor = tf.get_collection("num_frames")[0]
100 + predictions_tensor = tf.get_collection("predictions")[0]
101 +
102 + # Workaround for num_epochs issue.
103 + def set_up_init_ops(variables):
104 + init_op_list = []
105 + for variable in list(variables):
106 + if "train_input" in variable.name:
107 + init_op_list.append(tf.assign(variable, 1))
108 + variables.remove(variable)
109 + init_op_list.append(tf.variables_initializer(variables))
110 + return init_op_list
111 +
112 + sess.run(
113 + set_up_init_ops(tf.get_collection_ref(tf.GraphKeys.LOCAL_VARIABLES)))
114 +
115 + coord = tf.train.Coordinator()
116 + threads = tf.train.start_queue_runners(sess=sess, coord=coord)
117 + whitelisted_cls_mask = np.zeros((predictions_tensor.get_shape()[-1],),
118 + dtype=np.float32)
119 + segment_label_ids_file = '../segment_label_ids.csv'
120 + with tf.io.gfile.GFile(segment_label_ids_file) as fobj:
121 + for line in fobj:
122 + try:
123 + cls_id = int(line)
124 + whitelisted_cls_mask[cls_id] = 1.
125 + except ValueError:
126 + # Simply skip the non-integer line.
127 + continue
128 +
129 + # 200527 Esot3riA
130 + # 2. Make segment features.
131 + results = get_segments(video_batch_val, num_frames_batch_val, 5)
132 + video_segment_ids = results["video_segment_ids"]
133 + video_id_batch_val = video_id_batch_val[video_segment_ids[:, 0]]
134 + video_id_batch_val = np.array([
135 + "%s:%d" % (x.decode("utf8"), y)
136 + for x, y in zip(video_id_batch_val, video_segment_ids[:, 1])
137 + ])
138 + video_batch_val = results["video_batch"]
139 + num_frames_batch_val = results["num_frames_batch"]
140 + if input_tensor.get_shape()[1] != video_batch_val.shape[1]:
141 + raise ValueError("max_frames mismatch. Please re-run the eval.py "
142 + "with correct segment_labels settings.")
143 +
144 + predictions_val, = sess.run([predictions_tensor],
145 + feed_dict={
146 + input_tensor: video_batch_val,
147 + num_frames_tensor: num_frames_batch_val
148 + })
149 + logging.info(predictions_val)
150 + logging.info("profit :D")
151 +
152 + # result = format_prediction(video_id_batch_val, predictions_val, 10, whitelisted_cls_mask)
153 +
154 +
155 +if __name__ == '__main__':
156 + logging.set_verbosity(tf.logging.INFO)
157 +
158 + filename = 'features.pb'
159 + inference_pb(filename)
1 +import tensorflow as tf
2 +import numpy
3 +
4 +
5 +def _make_bytes(int_array):
6 + if bytes == str: # Python2
7 + return ''.join(map(chr, int_array))
8 + else:
9 + return bytes(int_array)
10 +
11 +
12 +def quantize(features, min_quantized_value=-2.0, max_quantized_value=2.0):
13 + """Quantizes float32 `features` into string."""
14 + assert features.dtype == 'float32'
15 + assert len(features.shape) == 1 # 1-D array
16 + features = numpy.clip(features, min_quantized_value, max_quantized_value)
17 + quantize_range = max_quantized_value - min_quantized_value
18 + features = (features - min_quantized_value) * (255.0 / quantize_range)
19 + features = [int(round(f)) for f in features]
20 +
21 + return _make_bytes(features)
22 +
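# Illustrative sketch (assumes Python 3 bytes semantics): quantize linearly
# maps the clipped range [-2.0, 2.0] onto [0, 255], so -2.0, 0.0 and 2.0
# become the bytes 0, 128 (127.5 rounded) and 255.
def _example_quantize():
  values = numpy.array([-2.0, 0.0, 2.0], dtype='float32')
  assert quantize(values) == bytes([0, 128, 255])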
23 +
24 +# Feature specs used to parse feature.pb.
25 +
26 +contexts = {
27 + 'AUDIO/feature/dimensions': tf.io.FixedLenFeature([], tf.int64),
28 + 'AUDIO/feature/rate': tf.io.FixedLenFeature([], tf.float32),
29 + 'RGB/feature/dimensions': tf.io.FixedLenFeature([], tf.int64),
30 + 'RGB/feature/rate': tf.io.FixedLenFeature([], tf.float32),
31 + 'clip/data_path': tf.io.FixedLenFeature([], tf.string),
32 + 'clip/end/timestamp': tf.io.FixedLenFeature([], tf.int64),
33 + 'clip/start/timestamp': tf.io.FixedLenFeature([], tf.int64)
34 +}
35 +
36 +features = {
37 + 'AUDIO/feature/floats': tf.io.VarLenFeature(dtype=tf.float32),
38 + 'AUDIO/feature/timestamp': tf.io.VarLenFeature(tf.int64),
39 + 'RGB/feature/floats': tf.io.VarLenFeature(dtype=tf.float32),
40 + 'RGB/feature/timestamp': tf.io.VarLenFeature(tf.int64)
41 +
42 +}
43 +
44 +
45 +def parse_exmp(serial_exmp):
46 + _, sequence_parsed = tf.io.parse_single_sequence_example(
47 + serialized=serial_exmp,
48 + context_features=contexts,
49 + sequence_features=features)
50 +
51 + sequence_parsed = tf.contrib.learn.run_n(sequence_parsed)[0]
52 +
53 + audio = sequence_parsed['AUDIO/feature/floats'].values
54 + rgb = sequence_parsed['RGB/feature/floats'].values
55 +
56 + # print(audio.values)
57 + # print(type(audio.values))
58 +
59 + # each second yields 128 audio values and 1024 rgb values
60 + audio_slices = [audio[128 * i: 128 * (i + 1)] for i in range(len(audio) // 128)]
61 + rgb_slices = [rgb[1024 * i: 1024 * (i + 1)] for i in range(len(rgb) // 1024)]
62 +
63 + byte_audio = []
64 + byte_rgb = []
65 +
66 + for seg in audio_slices:
67 + # audio_seg = quantize(seg)
68 + audio_seg = _make_bytes(seg)
69 + byte_audio.append(audio_seg)
70 +
71 + for seg in rgb_slices:
72 + # rgb_seg = quantize(seg)
73 + rgb_seg = _make_bytes(seg)
74 + byte_rgb.append(rgb_seg)
75 +
76 + return byte_audio, byte_rgb
77 +
78 +
79 +def make_exmp(id, audio, rgb):
80 + audio_features = []
81 + rgb_features = []
82 +
83 + for embedding in audio:
84 + embedding_feature = tf.train.Feature(
85 + bytes_list=tf.train.BytesList(value=[embedding]))
86 + audio_features.append(embedding_feature)
87 +
88 + for embedding in rgb:
89 + embedding_feature = tf.train.Feature(
90 + bytes_list=tf.train.BytesList(value=[embedding]))
91 + rgb_features.append(embedding_feature)
92 +
93 + # Construct a SequenceExample in the YT8M layout.
94 + seq_exmp = tf.train.SequenceExample(
95 + context=tf.train.Features(
96 + feature={
97 + 'id': tf.train.Feature(bytes_list=tf.train.BytesList(
98 + value=[id.encode('utf-8')]))
99 + }),
100 + feature_lists=tf.train.FeatureLists(
101 + feature_list={
102 + 'audio': tf.train.FeatureList(
103 + feature=audio_features
104 + ),
105 + 'rgb': tf.train.FeatureList(
106 + feature=rgb_features
107 + )
108 + })
109 + )
110 + serialized = seq_exmp.SerializeToString()
111 + return serialized
112 +
113 +
114 +def convert_pb(filename):
115 + sequence_example = open(filename, 'rb').read()
116 +
117 + audio, rgb = parse_exmp(sequence_example)
118 + tmp_example = make_exmp('video', audio, rgb)
119 +
120 + decoded = tf.train.SequenceExample.FromString(tmp_example)
121 + return decoded
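# Illustrative sketch (dummy values): make_exmp packs one bytes entry per
# second into the 'audio' and 'rgb' feature lists, keyed by a context 'id';
# this mirrors the SequenceExample layout that convert_pb returns.
def _example_make_exmp():
  dummy_audio = [b'\x00' * 128]   # one second of audio bytes
  dummy_rgb = [b'\x00' * 1024]    # one second of rgb bytes
  serialized = make_exmp('video', dummy_audio, dummy_rgb)
  decoded = tf.train.SequenceExample.FromString(serialized)
  # len(decoded.feature_lists.feature_list['rgb'].feature) == 1
  return decoded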
1 +import tensorflow as tf
2 +import numpy as np
3 +
4 +frame_lvl_record = "test0000.tfrecord"
5 +
6 +feat_rgb = []
7 +feat_audio = []
8 +
9 +for example in tf.python_io.tf_record_iterator(frame_lvl_record):
10 + tf_seq_example = tf.train.SequenceExample.FromString(example)
11 + test = tf_seq_example.SerializeToString()
12 + n_frames = len(tf_seq_example.feature_lists.feature_list['audio'].feature)
13 + sess = tf.InteractiveSession()
14 + rgb_frame = []
15 + audio_frame = []
16 + # iterate through frames
17 + for i in range(n_frames):
18 + rgb_frame.append(tf.cast(tf.decode_raw(
19 + tf_seq_example.feature_lists.feature_list['rgb']
20 + .feature[i].bytes_list.value[0], tf.uint8)
21 + , tf.float32).eval())
22 + audio_frame.append(tf.cast(tf.decode_raw(
23 + tf_seq_example.feature_lists.feature_list['audio']
24 + .feature[i].bytes_list.value[0], tf.uint8)
25 + , tf.float32).eval())
26 +
27 + sess.close()
28 +
29 + feat_audio.append(audio_frame)
30 + feat_rgb.append(rgb_frame)
31 + break
32 +
33 +print('The first video has %d frames' %len(feat_rgb[0]))
\ No newline at end of file
No preview for this file type
1 +# Copyright 2016 Google Inc. All Rights Reserved.
2 +#
3 +# Licensed under the Apache License, Version 2.0 (the "License");
4 +# you may not use this file except in compliance with the License.
5 +# You may obtain a copy of the License at
6 +#
7 +# http://www.apache.org/licenses/LICENSE-2.0
8 +#
9 +# Unless required by applicable law or agreed to in writing, software
10 +# distributed under the License is distributed on an "AS-IS" BASIS,
11 +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 +# See the License for the specific language governing permissions and
13 +# limitations under the License.
14 +"""Binary for evaluating Tensorflow models on the YouTube-8M dataset."""
15 +
16 +import json
17 +import os
18 +import time
19 +
20 +from absl import logging
21 +import eval_util
22 +import frame_level_models
23 +import losses
24 +import readers
25 +import tensorflow as tf
26 +from tensorflow import flags
27 +from tensorflow.python.lib.io import file_io
28 +import utils
29 +import video_level_models
30 +
31 +FLAGS = flags.FLAGS
32 +
33 +if __name__ == "__main__":
34 + # Dataset flags.
35 + flags.DEFINE_string(
36 + "train_dir", "/tmp/yt8m_model/",
37 + "The directory to load the model files from. "
38 + "The tensorboard metrics files are also saved to this "
39 + "directory.")
40 + flags.DEFINE_string(
41 + "eval_data_pattern", "",
42 + "File glob defining the evaluation dataset in tensorflow.SequenceExample "
43 + "format. The SequenceExamples are expected to have an 'rgb' byte array "
44 + "sequence feature as well as a 'labels' int64 context feature.")
45 + flags.DEFINE_bool(
46 + "segment_labels", False,
47 + "If set, then --eval_data_pattern must be frame-level features (but with"
48 + " segment_labels). Otherwise, --eval_data_pattern must be aggregated "
49 + "video-level features. The model must also be set appropriately (i.e. to "
50 + "read 3D batches vs. 4D batches).")
51 +
52 + # Other flags.
53 + flags.DEFINE_integer("batch_size", 1024,
54 + "How many examples to process per batch.")
55 + flags.DEFINE_integer("num_readers", 8,
56 + "How many threads to use for reading input files.")
57 + flags.DEFINE_boolean("run_once", False, "Whether to run eval only once.")
58 + flags.DEFINE_integer("top_k", 20, "How many predictions to output per video.")
59 +
60 +
61 +def find_class_by_name(name, modules):
62 + """Searches the provided modules for the named class and returns it."""
63 + modules = [getattr(module, name, None) for module in modules]
64 + return next(a for a in modules if a)
65 +
66 +
67 +def get_input_evaluation_tensors(reader,
68 + data_pattern,
69 + batch_size=1024,
70 + num_readers=1):
71 + """Creates the section of the graph which reads the evaluation data.
72 +
73 + Args:
74 + reader: A class which parses the training data.
75 + data_pattern: A 'glob' style path to the data files.
76 + batch_size: How many examples to process at a time.
77 + num_readers: How many I/O threads to use.
78 +
79 + Returns:
80 + A tuple containing the features tensor, labels tensor, and optionally a
81 + tensor containing the number of frames per video. The exact dimensions
82 + depend on the reader being used.
83 +
84 + Raises:
85 + IOError: If no files matching the given pattern were found.
86 + """
87 + logging.info("Using batch size of %d for evaluation.", batch_size)
88 + with tf.name_scope("eval_input"):
89 + files = tf.io.gfile.glob(data_pattern)
90 + if not files:
91 + raise IOError("Unable to find the evaluation files.")
92 + logging.info("number of evaluation files: %d", len(files))
93 + filename_queue = tf.train.string_input_producer(files,
94 + shuffle=False,
95 + num_epochs=1)
96 + eval_data = [
97 + reader.prepare_reader(filename_queue) for _ in range(num_readers)
98 + ]
99 + return tf.train.batch_join(eval_data,
100 + batch_size=batch_size,
101 + capacity=3 * batch_size,
102 + allow_smaller_final_batch=True,
103 + enqueue_many=True)
104 +
105 +
106 +def build_graph(reader,
107 + model,
108 + eval_data_pattern,
109 + label_loss_fn,
110 + batch_size=1024,
111 + num_readers=1):
112 + """Creates the Tensorflow graph for evaluation.
113 +
114 + Args:
115 + reader: The data file reader. It should inherit from BaseReader.
116 + model: The core model (e.g. logistic or neural net). It should inherit from
117 + BaseModel.
118 + eval_data_pattern: glob path to the evaluation data files.
119 + label_loss_fn: What kind of loss to apply to the model. It should inherit
120 + from BaseLoss.
121 + batch_size: How many examples to process at a time.
122 + num_readers: How many threads to use for I/O operations.
123 + """
124 +
125 + global_step = tf.Variable(0, trainable=False, name="global_step")
126 + input_data_dict = get_input_evaluation_tensors(reader,
127 + eval_data_pattern,
128 + batch_size=batch_size,
129 + num_readers=num_readers)
130 + video_id_batch = input_data_dict["video_ids"]
131 + model_input_raw = input_data_dict["video_matrix"]
132 + labels_batch = input_data_dict["labels"]
133 + num_frames = input_data_dict["num_frames"]
134 + tf.compat.v1.summary.histogram("model_input_raw", model_input_raw)
135 +
136 + feature_dim = len(model_input_raw.get_shape()) - 1
137 +
138 + # Normalize input features.
139 + model_input = tf.nn.l2_normalize(model_input_raw, feature_dim)
140 +
141 + with tf.compat.v1.variable_scope("tower"):
142 + result = model.create_model(model_input,
143 + num_frames=num_frames,
144 + vocab_size=reader.num_classes,
145 + labels=labels_batch,
146 + is_training=False)
147 +
148 + predictions = result["predictions"]
149 + tf.compat.v1.summary.histogram("model_activations", predictions)
150 + if "loss" in result.keys():
151 + label_loss = result["loss"]
152 + else:
153 + label_loss = label_loss_fn.calculate_loss(predictions, labels_batch)
154 +
155 + tf.compat.v1.add_to_collection("global_step", global_step)
156 + tf.compat.v1.add_to_collection("loss", label_loss)
157 + tf.compat.v1.add_to_collection("predictions", predictions)
158 + tf.compat.v1.add_to_collection("input_batch", model_input)
159 + tf.compat.v1.add_to_collection("input_batch_raw", model_input_raw)
160 + tf.compat.v1.add_to_collection("video_id_batch", video_id_batch)
161 + tf.compat.v1.add_to_collection("num_frames", num_frames)
162 + tf.compat.v1.add_to_collection("labels", tf.cast(labels_batch, tf.float32))
163 + if FLAGS.segment_labels:
164 + tf.compat.v1.add_to_collection("label_weights",
165 + input_data_dict["label_weights"])
166 + tf.compat.v1.add_to_collection("summary_op", tf.compat.v1.summary.merge_all())
167 +
168 +
169 +def evaluation_loop(fetches, saver, summary_writer, evl_metrics,
170 + last_global_step_val):
171 + """Run the evaluation loop once.
172 +
173 + Args:
174 + fetches: a dict of tensors to be run within Session.
175 + saver: a tensorflow saver to restore the model.
176 + summary_writer: a tensorflow summary_writer
177 + evl_metrics: an EvaluationMetrics object.
178 + last_global_step_val: the global step used in the previous evaluation.
179 +
180 + Returns:
181 + The global_step used in the latest model.
182 + """
183 +
184 + global_step_val = -1
185 + with tf.Session(config=tf.ConfigProto(gpu_options=tf.GPUOptions(
186 + allow_growth=True))) as sess:
187 + latest_checkpoint = tf.train.latest_checkpoint(FLAGS.train_dir)
188 + if latest_checkpoint:
189 + logging.info("Loading checkpoint for eval: %s", latest_checkpoint)
190 + # Restores from checkpoint
191 + saver.restore(sess, latest_checkpoint)
192 + # Assuming model_checkpoint_path looks something like:
193 + # /my-favorite-path/yt8m_train/model.ckpt-0, extract global_step from it.
194 + global_step_val = os.path.basename(latest_checkpoint).split("-")[-1]
195 +
196 + # Save model
197 + if FLAGS.segment_labels:
198 + inference_model_name = "segment_inference_model"
199 + else:
200 + inference_model_name = "inference_model"
201 + saver.save(
202 + sess,
203 + os.path.join(FLAGS.train_dir, "inference_model",
204 + inference_model_name))
205 + else:
206 + logging.info("No checkpoint file found.")
207 + return global_step_val
208 +
209 + if global_step_val == last_global_step_val:
210 + logging.info(
211 + "skip this checkpoint global_step_val=%s "
212 + "(same as the previous one).", global_step_val)
213 + return global_step_val
214 +
215 + sess.run([tf.local_variables_initializer()])
216 +
217 + # Start the queue runners.
218 + coord = tf.train.Coordinator()
219 + try:
220 + threads = []
221 + for qr in tf.compat.v1.get_collection(tf.GraphKeys.QUEUE_RUNNERS):
222 + threads.extend(
223 + qr.create_threads(sess, coord=coord, daemon=True, start=True))
224 + logging.info("enter eval_once loop global_step_val = %s. ",
225 + global_step_val)
226 +
227 + evl_metrics.clear()
228 +
229 + examples_processed = 0
230 + while not coord.should_stop():
231 + batch_start_time = time.time()
232 + output_data_dict = sess.run(fetches)
233 + seconds_per_batch = time.time() - batch_start_time
234 + labels_val = output_data_dict["labels"]
235 + summary_val = output_data_dict["summary"]
236 + example_per_second = labels_val.shape[0] / seconds_per_batch
237 + examples_processed += labels_val.shape[0]
238 +
239 + predictions = output_data_dict["predictions"]
240 + if FLAGS.segment_labels:
241 + # This is a workaround to ignore the unrated labels.
242 + predictions *= output_data_dict["label_weights"]
243 + iteration_info_dict = evl_metrics.accumulate(predictions, labels_val,
244 + output_data_dict["loss"])
245 + iteration_info_dict["examples_per_second"] = example_per_second
246 +
247 + iterinfo = utils.AddGlobalStepSummary(
248 + summary_writer,
249 + global_step_val,
250 + iteration_info_dict,
251 + summary_scope="SegEval" if FLAGS.segment_labels else "Eval")
252 + logging.info("examples_processed: %d | %s", examples_processed,
253 + iterinfo)
254 +
255 + except tf.errors.OutOfRangeError as e:
256 + logging.info(
257 + "Done with batched inference. Now calculating global performance "
258 + "metrics.")
259 + # calculate the metrics for the entire epoch
260 + epoch_info_dict = evl_metrics.get()
261 + epoch_info_dict["epoch_id"] = global_step_val
262 +
263 + summary_writer.add_summary(summary_val, global_step_val)
264 + epochinfo = utils.AddEpochSummary(
265 + summary_writer,
266 + global_step_val,
267 + epoch_info_dict,
268 + summary_scope="SegEval" if FLAGS.segment_labels else "Eval")
269 + logging.info(epochinfo)
270 + evl_metrics.clear()
271 + except Exception as e: # pylint: disable=broad-except
272 + logging.info("Unexpected exception: %s", str(e))
273 + coord.request_stop(e)
274 +
275 + coord.request_stop()
276 + coord.join(threads, stop_grace_period_secs=10)
277 + logging.info("Total: examples_processed: %d", examples_processed)
278 +
279 + return global_step_val
280 +
281 +
282 +def evaluate():
283 + """Starts main evaluation loop."""
284 + tf.compat.v1.set_random_seed(0) # for reproducibility
285 +
286 + # Write json of flags
287 + model_flags_path = os.path.join(FLAGS.train_dir, "model_flags.json")
288 + if not file_io.file_exists(model_flags_path):
289 + raise IOError(("Cannot find file %s. Did you run train.py on the same "
290 + "--train_dir?") % model_flags_path)
291 + flags_dict = json.loads(file_io.FileIO(model_flags_path, mode="r").read())
292 +
293 + with tf.Graph().as_default():
294 + # convert feature_names and feature_sizes to lists of values
295 + feature_names, feature_sizes = utils.GetListOfFeatureNamesAndSizes(
296 + flags_dict["feature_names"], flags_dict["feature_sizes"])
297 +
298 + if flags_dict["frame_features"]:
299 + reader = readers.YT8MFrameFeatureReader(
300 + feature_names=feature_names,
301 + feature_sizes=feature_sizes,
302 + segment_labels=FLAGS.segment_labels)
303 + else:
304 + reader = readers.YT8MAggregatedFeatureReader(feature_names=feature_names,
305 + feature_sizes=feature_sizes)
306 +
307 + model = find_class_by_name(flags_dict["model"],
308 + [frame_level_models, video_level_models])()
309 + label_loss_fn = find_class_by_name(flags_dict["label_loss"], [losses])()
310 +
311 + if not FLAGS.eval_data_pattern:
312 + raise IOError("'eval_data_pattern' was not specified. Nothing to "
313 + "evaluate.")
314 +
315 + build_graph(reader=reader,
316 + model=model,
317 + eval_data_pattern=FLAGS.eval_data_pattern,
318 + label_loss_fn=label_loss_fn,
319 + num_readers=FLAGS.num_readers,
320 + batch_size=FLAGS.batch_size)
321 + logging.info("built evaluation graph")
322 +
323 + # A dict of tensors to be run in Session.
324 + fetches = {
325 + "video_id": tf.compat.v1.get_collection("video_id_batch")[0],
326 + "predictions": tf.compat.v1.get_collection("predictions")[0],
327 + "labels": tf.compat.v1.get_collection("labels")[0],
328 + "loss": tf.compat.v1.get_collection("loss")[0],
329 + "summary": tf.compat.v1.get_collection("summary_op")[0]
330 + }
331 + if FLAGS.segment_labels:
332 + fetches["label_weights"] = tf.compat.v1.get_collection("label_weights")[0]
333 +
334 + saver = tf.compat.v1.train.Saver(tf.compat.v1.global_variables())
335 + summary_writer = tf.compat.v1.summary.FileWriter(
336 + os.path.join(FLAGS.train_dir, "eval"),
337 + graph=tf.compat.v1.get_default_graph())
338 +
339 + evl_metrics = eval_util.EvaluationMetrics(reader.num_classes, FLAGS.top_k,
340 + None)
341 +
342 + last_global_step_val = -1
343 + while True:
344 + last_global_step_val = evaluation_loop(fetches, saver, summary_writer,
345 + evl_metrics, last_global_step_val)
346 + if FLAGS.run_once:
347 + break
348 +
349 +
350 +def main(unused_argv):
351 + logging.set_verbosity(logging.INFO)
352 + logging.info("tensorflow version: %s", tf.__version__)
353 + evaluate()
354 +
355 +
356 +if __name__ == "__main__":
357 + tf.compat.v1.app.run()
1 +# Copyright 2016 Google Inc. All Rights Reserved.
2 +#
3 +# Licensed under the Apache License, Version 2.0 (the "License");
4 +# you may not use this file except in compliance with the License.
5 +# You may obtain a copy of the License at
6 +#
7 +# http://www.apache.org/licenses/LICENSE-2.0
8 +#
9 +# Unless required by applicable law or agreed to in writing, software
10 +# distributed under the License is distributed on an "AS-IS" BASIS,
11 +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 +# See the License for the specific language governing permissions and
13 +# limitations under the License.
14 +"""Provides functions to help with evaluating models."""
15 +import average_precision_calculator as ap_calculator
16 +import mean_average_precision_calculator as map_calculator
17 +import numpy
18 +from tensorflow.python.platform import gfile
19 +
20 +
21 +def flatten(l):
22 + """Merges a list of lists into a single list. """
23 + return [item for sublist in l for item in sublist]
24 +
25 +
26 +def calculate_hit_at_one(predictions, actuals):
27 + """Performs a local (numpy) calculation of the hit at one.
28 +
29 + Args:
30 + predictions: Matrix containing the outputs of the model. Dimensions are
31 + 'batch' x 'num_classes'.
32 + actuals: Matrix containing the ground truth labels. Dimensions are 'batch' x
33 + 'num_classes'.
34 +
35 + Returns:
36 + float: The average hit at one across the entire batch.
37 + """
38 + top_prediction = numpy.argmax(predictions, 1)
39 + hits = actuals[numpy.arange(actuals.shape[0]), top_prediction]
40 + return numpy.average(hits)
41 +
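# Illustrative sketch (hypothetical values): the top-scoring class is correct
# for the first video and wrong for the second, so hit@1 averages to 0.5.
def _example_hit_at_one():
  predictions = numpy.array([[0.1, 0.9], [0.8, 0.2]])
  actuals = numpy.array([[0.0, 1.0], [0.0, 1.0]])
  return calculate_hit_at_one(predictions, actuals)  # 0.5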
42 +
43 +def calculate_precision_at_equal_recall_rate(predictions, actuals):
44 + """Performs a local (numpy) calculation of the PERR.
45 +
46 + Args:
47 + predictions: Matrix containing the outputs of the model. Dimensions are
48 + 'batch' x 'num_classes'.
49 + actuals: Matrix containing the ground truth labels. Dimensions are 'batch' x
50 + 'num_classes'.
51 +
52 + Returns:
53 + float: The average precision at equal recall rate across the entire batch.
54 + """
55 + aggregated_precision = 0.0
56 + num_videos = actuals.shape[0]
57 + for row in numpy.arange(num_videos):
58 + num_labels = int(numpy.sum(actuals[row]))
59 + top_indices = numpy.argpartition(predictions[row],
60 + -num_labels)[-num_labels:]
61 + item_precision = 0.0
62 + for label_index in top_indices:
63 + if predictions[row][label_index] > 0:
64 + item_precision += actuals[row][label_index]
65 + item_precision /= top_indices.size
66 + aggregated_precision += item_precision
67 + aggregated_precision /= num_videos
68 + return aggregated_precision
69 +
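# Illustrative sketch (hypothetical values): the first video has two labels
# and only one of its two top-scored classes is correct (precision 0.5); the
# second has one label and its top class is correct (precision 1.0), so the
# PERR averages to 0.75.
def _example_perr():
  predictions = numpy.array([[0.9, 0.8, 0.1, 0.05], [0.2, 0.7, 0.1, 0.0]])
  actuals = numpy.array([[1.0, 0.0, 1.0, 0.0], [0.0, 1.0, 0.0, 0.0]])
  return calculate_precision_at_equal_recall_rate(predictions, actuals)  # 0.75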
70 +
71 +def calculate_gap(predictions, actuals, top_k=20):
72 + """Performs a local (numpy) calculation of the global average precision.
73 +
74 + Only the top_k predictions are taken for each of the videos.
75 +
76 + Args:
77 + predictions: Matrix containing the outputs of the model. Dimensions are
78 + 'batch' x 'num_classes'.
79 + actuals: Matrix containing the ground truth labels. Dimensions are 'batch' x
80 + 'num_classes'.
81 + top_k: How many predictions to use per video.
82 +
83 + Returns:
84 + float: The global average precision.
85 + """
86 + gap_calculator = ap_calculator.AveragePrecisionCalculator()
87 + sparse_predictions, sparse_labels, num_positives = top_k_by_class(
88 + predictions, actuals, top_k)
89 + gap_calculator.accumulate(flatten(sparse_predictions), flatten(sparse_labels),
90 + sum(num_positives))
91 + return gap_calculator.peek_ap_at_n()
92 +
93 +
94 +def top_k_by_class(predictions, labels, k=20):
95 + """Extracts the top k predictions for each video, sorted by class.
96 +
97 + Args:
98 + predictions: A numpy matrix of model outputs, 'batch' x 'num_classes'.
99 + labels: A numpy matrix of ground truth labels, 'batch' x 'num_classes'.
100 + k: the top k non-zero entries to preserve in each prediction.
101 +
102 + Returns:
103 + A tuple (predictions,labels, true_positives). 'predictions' and 'labels'
104 + are lists of lists of floats. 'true_positives' is a list of scalars. The
105 + length of the lists are equal to the number of classes. The entries in the
106 + predictions variable are probability predictions, and
107 + the corresponding entries in the labels variable are the ground truth for
108 + those predictions. The entries in 'true_positives' are the number of true
109 + positives for each class in the ground truth.
110 +
111 + Raises:
112 + ValueError: An error occurred when the k is not a positive integer.
113 + """
114 + if k <= 0:
115 + raise ValueError("k must be a positive integer.")
116 + k = min(k, predictions.shape[1])
117 + num_classes = predictions.shape[1]
118 + prediction_triplets = []
119 + for video_index in range(predictions.shape[0]):
120 + prediction_triplets.extend(
121 + top_k_triplets(predictions[video_index], labels[video_index], k))
122 + out_predictions = [[] for _ in range(num_classes)]
123 + out_labels = [[] for _ in range(num_classes)]
124 + for triplet in prediction_triplets:
125 + out_predictions[triplet[0]].append(triplet[1])
126 + out_labels[triplet[0]].append(triplet[2])
127 + out_true_positives = [numpy.sum(labels[:, i]) for i in range(num_classes)]
128 +
129 + return out_predictions, out_labels, out_true_positives
130 +
131 +
132 +def top_k_triplets(predictions, labels, k=20):
133 + """Get the top_k for a 1-d numpy array.
134 +
135 + Returns a sparse list of tuples in
136 + (prediction, class) format
137 + """
138 + m = len(predictions)
139 + k = min(k, m)
140 + indices = numpy.argpartition(predictions, -k)[-k:]
141 + return [(index, predictions[index], labels[index]) for index in indices]
142 +
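# Illustrative sketch (hypothetical values): with k=2 the two highest-scoring
# entries are kept as (class_index, prediction, label) triplets; argpartition
# does not guarantee their relative order.
def _example_top_k_triplets():
  predictions = numpy.array([0.1, 0.7, 0.4, 0.9])
  labels = numpy.array([0, 1, 0, 1])
  # e.g. [(1, 0.7, 1), (3, 0.9, 1)] in some order
  return top_k_triplets(predictions, labels, k=2)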
143 +
144 +class EvaluationMetrics(object):
145 + """A class to store the evaluation metrics."""
146 +
147 + def __init__(self, num_class, top_k, top_n):
148 + """Construct an EvaluationMetrics object to store the evaluation metrics.
149 +
150 + Args:
151 + num_class: A positive integer specifying the number of classes.
152 + top_k: A positive integer specifying how many predictions are considered
153 + per video.
154 + top_n: A positive Integer specifying the average precision at n, or None
155 + to use all provided data points.
156 +
157 + Raises:
158 + ValueError: An error occurred when the MeanAveragePrecisionCalculator
159 + cannot be constructed.
160 + """
161 + self.sum_hit_at_one = 0.0
162 + self.sum_perr = 0.0
163 + self.sum_loss = 0.0
164 + self.map_calculator = map_calculator.MeanAveragePrecisionCalculator(
165 + num_class, top_n=top_n)
166 + self.global_ap_calculator = ap_calculator.AveragePrecisionCalculator()
167 + self.top_k = top_k
168 + self.num_examples = 0
169 +
170 + def accumulate(self, predictions, labels, loss):
171 + """Accumulate the metrics calculated locally for this mini-batch.
172 +
173 + Args:
174 + predictions: A numpy matrix containing the outputs of the model.
175 + Dimensions are 'batch' x 'num_classes'.
176 + labels: A numpy matrix containing the ground truth labels. Dimensions are
177 + 'batch' x 'num_classes'.
178 + loss: A numpy array containing the loss for each sample.
179 +
180 + Returns:
181 + dictionary: A dictionary storing the metrics for the mini-batch.
182 +
183 + Raises:
184 + ValueError: An error occurred when the shape of predictions and actuals
185 + does not match.
186 + """
187 + batch_size = labels.shape[0]
188 + mean_hit_at_one = calculate_hit_at_one(predictions, labels)
189 + mean_perr = calculate_precision_at_equal_recall_rate(predictions, labels)
190 + mean_loss = numpy.mean(loss)
191 +
192 + # Take the top_k predictions.
193 + sparse_predictions, sparse_labels, num_positives = top_k_by_class(
194 + predictions, labels, self.top_k)
195 + self.map_calculator.accumulate(sparse_predictions, sparse_labels,
196 + num_positives)
197 + self.global_ap_calculator.accumulate(flatten(sparse_predictions),
198 + flatten(sparse_labels),
199 + sum(num_positives))
200 +
201 + self.num_examples += batch_size
202 + self.sum_hit_at_one += mean_hit_at_one * batch_size
203 + self.sum_perr += mean_perr * batch_size
204 + self.sum_loss += mean_loss * batch_size
205 +
206 + return {"hit_at_one": mean_hit_at_one, "perr": mean_perr, "loss": mean_loss}
207 +
208 + def get(self):
209 + """Calculate the evaluation metrics for the whole epoch.
210 +
211 + Raises:
212 + ValueError: If no examples were accumulated.
213 +
214 + Returns:
215 + dictionary: a dictionary storing the evaluation metrics for the epoch. The
216 + dictionary has the fields: avg_hit_at_one, avg_perr, avg_loss, and
217 + aps (default nan).
218 + """
219 + if self.num_examples <= 0:
220 + raise ValueError("total_sample must be positive.")
221 + avg_hit_at_one = self.sum_hit_at_one / self.num_examples
222 + avg_perr = self.sum_perr / self.num_examples
223 + avg_loss = self.sum_loss / self.num_examples
224 +
225 + aps = self.map_calculator.peek_map_at_n()
226 + gap = self.global_ap_calculator.peek_ap_at_n()
227 +
228 + epoch_info_dict = {
229 + "avg_hit_at_one": avg_hit_at_one,
230 + "avg_perr": avg_perr,
231 + "avg_loss": avg_loss,
232 + "aps": aps,
233 + "gap": gap
234 + }
235 + return epoch_info_dict
236 +
237 + def clear(self):
238 + """Clear the evaluation metrics and reset the EvaluationMetrics object."""
239 + self.sum_hit_at_one = 0.0
240 + self.sum_perr = 0.0
241 + self.sum_loss = 0.0
242 + self.map_calculator.clear()
243 + self.global_ap_calculator.clear()
244 + self.num_examples = 0
1 +# Copyright 2016 Google Inc. All Rights Reserved.
2 +#
3 +# Licensed under the Apache License, Version 2.0 (the "License");
4 +# you may not use this file except in compliance with the License.
5 +# You may obtain a copy of the License at
6 +#
7 +# http://www.apache.org/licenses/LICENSE-2.0
8 +#
9 +# Unless required by applicable law or agreed to in writing, software
10 +# distributed under the License is distributed on an "AS-IS" BASIS,
11 +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 +# See the License for the specific language governing permissions and
13 +# limitations under the License.
14 +"""Utilities to export a model for batch prediction."""
15 +
16 +import tensorflow as tf
17 +import tensorflow.contrib.slim as slim
18 +
19 +from tensorflow.python.saved_model import builder as saved_model_builder
20 +from tensorflow.python.saved_model import signature_constants
21 +from tensorflow.python.saved_model import signature_def_utils
22 +from tensorflow.python.saved_model import tag_constants
23 +from tensorflow.python.saved_model import utils as saved_model_utils
24 +
25 +_TOP_PREDICTIONS_IN_OUTPUT = 20
26 +
27 +
28 +class ModelExporter(object):
29 +
30 + def __init__(self, frame_features, model, reader):
31 + self.frame_features = frame_features
32 + self.model = model
33 + self.reader = reader
34 +
35 + with tf.Graph().as_default() as graph:
36 + self.inputs, self.outputs = self.build_inputs_and_outputs()
37 + self.graph = graph
38 + self.saver = tf.train.Saver(tf.trainable_variables(), sharded=True)
39 +
40 + def export_model(self, model_dir, global_step_val, last_checkpoint):
41 + """Exports the model so that it can be used for batch predictions."""
42 +
43 + with self.graph.as_default():
44 + with tf.Session() as session:
45 + session.run(tf.global_variables_initializer())
46 + self.saver.restore(session, last_checkpoint)
47 +
48 + signature = signature_def_utils.build_signature_def(
49 + inputs=self.inputs,
50 + outputs=self.outputs,
51 + method_name=signature_constants.PREDICT_METHOD_NAME)
52 +
53 + signature_map = {
54 + signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature
55 + }
56 +
57 + model_builder = saved_model_builder.SavedModelBuilder(model_dir)
58 + model_builder.add_meta_graph_and_variables(
59 + session,
60 + tags=[tag_constants.SERVING],
61 + signature_def_map=signature_map,
62 + clear_devices=True)
63 + model_builder.save()
64 +
65 + def build_inputs_and_outputs(self):
66 + if self.frame_features:
67 + serialized_examples = tf.placeholder(tf.string, shape=(None,))
68 +
69 + fn = lambda x: self.build_prediction_graph(x)
70 + video_id_output, top_indices_output, top_predictions_output = (tf.map_fn(
71 + fn, serialized_examples, dtype=(tf.string, tf.int32, tf.float32)))
72 +
73 + else:
74 + serialized_examples = tf.placeholder(tf.string, shape=(None,))
75 +
76 + video_id_output, top_indices_output, top_predictions_output = (
77 + self.build_prediction_graph(serialized_examples))
78 +
79 + inputs = {
80 + "example_bytes":
81 + saved_model_utils.build_tensor_info(serialized_examples)
82 + }
83 +
84 + outputs = {
85 + "video_id":
86 + saved_model_utils.build_tensor_info(video_id_output),
87 + "class_indexes":
88 + saved_model_utils.build_tensor_info(top_indices_output),
89 + "predictions":
90 + saved_model_utils.build_tensor_info(top_predictions_output)
91 + }
92 +
93 + return inputs, outputs
94 +
95 + def build_prediction_graph(self, serialized_examples):
96 + input_data_dict = (
97 + self.reader.prepare_serialized_examples(serialized_examples))
98 + video_id = input_data_dict["video_ids"]
99 + model_input_raw = input_data_dict["video_matrix"]
100 + labels_batch = input_data_dict["labels"]
101 + num_frames = input_data_dict["num_frames"]
102 +
103 + feature_dim = len(model_input_raw.get_shape()) - 1
104 + model_input = tf.nn.l2_normalize(model_input_raw, feature_dim)
105 +
106 + with tf.variable_scope("tower"):
107 + result = self.model.create_model(model_input,
108 + num_frames=num_frames,
109 + vocab_size=self.reader.num_classes,
110 + labels=labels_batch,
111 + is_training=False)
112 +
113 + for variable in slim.get_model_variables():
114 + tf.summary.histogram(variable.op.name, variable)
115 +
116 + predictions = result["predictions"]
117 +
118 + top_predictions, top_indices = tf.nn.top_k(predictions,
119 + _TOP_PREDICTIONS_IN_OUTPUT)
120 + return video_id, top_indices, top_predictions
1 +# Lint as: python3
2 +import numpy as np
3 +import tensorflow as tf
4 +from tensorflow import app
5 +from tensorflow import flags
6 +
7 +FLAGS = flags.FLAGS
8 +
9 +
10 +def main(unused_argv):
11 + # Get the input tensor names to be replaced.
12 + tf.reset_default_graph()
13 + meta_graph_location = FLAGS.checkpoint_file + ".meta"
14 + tf.train.import_meta_graph(meta_graph_location, clear_devices=True)
15 +
16 + input_tensor_name = tf.get_collection("input_batch_raw")[0].name
17 + num_frames_tensor_name = tf.get_collection("num_frames")[0].name
18 +
19 + # Create output graph.
20 + saver = tf.train.Saver()
21 + tf.reset_default_graph()
22 +
23 + input_feature_placeholder = tf.placeholder(
24 + tf.float32, shape=(None, None, 1152))
25 + num_frames_placeholder = tf.placeholder(tf.int32, shape=(None, 1))
26 +
27 + saver = tf.train.import_meta_graph(
28 + meta_graph_location,
29 + input_map={
30 + input_tensor_name: input_feature_placeholder,
31 + num_frames_tensor_name: tf.squeeze(num_frames_placeholder, axis=1)
32 + },
33 + clear_devices=True)
34 + predictions_tensor = tf.get_collection("predictions")[0]
35 +
36 + with tf.Session() as sess:
37 + print("restoring variables from " + FLAGS.checkpoint_file)
38 + saver.restore(sess, FLAGS.checkpoint_file)
39 + tf.saved_model.simple_save(
40 + sess,
41 + FLAGS.output_dir,
42 + inputs={'rgb_and_audio': input_feature_placeholder,
43 + 'num_frames': num_frames_placeholder},
44 + outputs={'predictions': predictions_tensor})
45 +
46 + # Try running inference.
47 + predictions = sess.run(
48 + [predictions_tensor],
49 + feed_dict={
50 + input_feature_placeholder: np.zeros((3, 7, 1152), dtype=np.float32),
51 + num_frames_placeholder: np.array([[7]], dtype=np.int32)})
52 + print('Test inference:', predictions)
53 +
54 + print('Model saved to ', FLAGS.output_dir)
55 +
56 +
57 +if __name__ == '__main__':
58 + flags.DEFINE_string('checkpoint_file', None, 'Path to the checkpoint file.')
59 + flags.DEFINE_string('output_dir', None, 'SavedModel output directory.')
60 + app.run(main)
1 +# Copyright 2016 Google Inc. All Rights Reserved.
2 +#
3 +# Licensed under the Apache License, Version 2.0 (the "License");
4 +# you may not use this file except in compliance with the License.
5 +# You may obtain a copy of the License at
6 +#
7 +# http://www.apache.org/licenses/LICENSE-2.0
8 +#
9 +# Unless required by applicable law or agreed to in writing, software
10 +# distributed under the License is distributed on an "AS-IS" BASIS,
11 +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 +# See the License for the specific language governing permissions and
13 +# limitations under the License.
14 +"""Contains a collection of models which operate on variable-length sequences."""
15 +import math
16 +
17 +import model_utils as utils
18 +import models
19 +import tensorflow as tf
20 +from tensorflow import flags
21 +import tensorflow.contrib.slim as slim
22 +import video_level_models
23 +
24 +FLAGS = flags.FLAGS
25 +flags.DEFINE_integer("iterations", 30, "Number of frames per batch for DBoF.")
26 +flags.DEFINE_bool("dbof_add_batch_norm", True,
27 + "Adds batch normalization to the DBoF model.")
28 +flags.DEFINE_bool(
29 + "sample_random_frames", True,
30 + "If true samples random frames (for frame level models). If false, a random "
31 + "sequence of frames is sampled instead.")
32 +flags.DEFINE_integer("dbof_cluster_size", 8192,
33 + "Number of units in the DBoF cluster layer.")
34 +flags.DEFINE_integer("dbof_hidden_size", 1024,
35 + "Number of units in the DBoF hidden layer.")
36 +flags.DEFINE_string(
37 + "dbof_pooling_method", "max",
38 + "The pooling method used in the DBoF cluster layer. "
39 + "Choices are 'average' and 'max'.")
40 +flags.DEFINE_string(
41 + "dbof_activation", "sigmoid",
42 + "The nonlinear activation method for cluster and hidden dense layer, e.g., "
43 + "sigmoid, relu6, etc.")
44 +flags.DEFINE_string(
45 + "video_level_classifier_model", "MoeModel",
46 + "Some Frame-Level models can be decomposed into a "
47 + "generalized pooling operation followed by a "
48 + "classifier layer")
49 +flags.DEFINE_integer("lstm_cells", 1024, "Number of LSTM cells.")
50 +flags.DEFINE_integer("lstm_layers", 2, "Number of LSTM layers.")
51 +
52 +
53 +class FrameLevelLogisticModel(models.BaseModel):
54 + """Creates a logistic classifier over the aggregated frame-level features."""
55 +
56 + def create_model(self, model_input, vocab_size, num_frames, **unused_params):
57 + """See base class.
58 +
59 + This class is intended to be an example for implementors of frame level
60 + models. If you want to train a model over averaged features it is more
61 + efficient to average them beforehand rather than on the fly.
62 +
63 + Args:
64 + model_input: A 'batch_size' x 'max_frames' x 'num_features' matrix of
65 + input features.
66 + vocab_size: The number of classes in the dataset.
67 + num_frames: A vector of length 'batch' which indicates the number of
68 + frames for each video (before padding).
69 +
70 + Returns:
71 + A dictionary with a tensor containing the probability predictions of the
72 + model in the 'predictions' key. The dimensions of the tensor are
73 + 'batch_size' x 'num_classes'.
74 + """
75 + num_frames = tf.cast(tf.expand_dims(num_frames, 1), tf.float32)
76 + feature_size = model_input.get_shape().as_list()[2]
77 +
78 + denominators = tf.reshape(tf.tile(num_frames, [1, feature_size]),
79 + [-1, feature_size])
80 + avg_pooled = tf.reduce_sum(model_input, axis=[1]) / denominators
81 +
82 + output = slim.fully_connected(avg_pooled,
83 + vocab_size,
84 + activation_fn=tf.nn.sigmoid,
85 + weights_regularizer=slim.l2_regularizer(1e-8))
86 + return {"predictions": output}
87 +
88 +
89 +class DbofModel(models.BaseModel):
90 + """Creates a Deep Bag of Frames model.
91 +
92 + The model projects the features for each frame into a higher dimensional
93 + 'clustering' space, pools across frames in that space, and then
94 + uses a configurable video-level model to classify the now aggregated features.
95 +
96 + The model will randomly sample either frames or sequences of frames during
97 + training to speed up convergence.
98 + """
99 +
100 + ACT_FN_MAP = {
101 + "sigmoid": tf.nn.sigmoid,
102 + "relu6": tf.nn.relu6,
103 + }
104 +
105 + def create_model(self,
106 + model_input,
107 + vocab_size,
108 + num_frames,
109 + iterations=None,
110 + add_batch_norm=None,
111 + sample_random_frames=None,
112 + cluster_size=None,
113 + hidden_size=None,
114 + is_training=True,
115 + **unused_params):
116 + """See base class.
117 +
118 + Args:
119 + model_input: A 'batch_size' x 'max_frames' x 'num_features' matrix of
120 + input features.
121 + vocab_size: The number of classes in the dataset.
122 + num_frames: A vector of length 'batch' which indicates the number of
123 + frames for each video (before padding).
124 + iterations: the number of frames to be sampled.
125 + add_batch_norm: whether to add batch norm during training.
126 + sample_random_frames: whether to sample random frames or random sequences.
127 + cluster_size: the output neuron number of the cluster layer.
128 + hidden_size: the output neuron number of the hidden layer.
129 + is_training: whether to build the graph in training mode.
130 +
131 + Returns:
132 + A dictionary with a tensor containing the probability predictions of the
133 + model in the 'predictions' key. The dimensions of the tensor are
134 + 'batch_size' x 'num_classes'.
135 + """
136 + iterations = iterations or FLAGS.iterations
137 + add_batch_norm = add_batch_norm or FLAGS.dbof_add_batch_norm
138 + random_frames = sample_random_frames or FLAGS.sample_random_frames
139 + cluster_size = cluster_size or FLAGS.dbof_cluster_size
140 + hidden1_size = hidden_size or FLAGS.dbof_hidden_size
141 + act_fn = self.ACT_FN_MAP.get(FLAGS.dbof_activation)
142 + assert act_fn is not None, ("dbof_activation is not valid: %s." %
143 + FLAGS.dbof_activation)
144 +
145 + num_frames = tf.cast(tf.expand_dims(num_frames, 1), tf.float32)
146 + if random_frames:
147 + model_input = utils.SampleRandomFrames(model_input, num_frames,
148 + iterations)
149 + else:
150 + model_input = utils.SampleRandomSequence(model_input, num_frames,
151 + iterations)
152 + max_frames = model_input.get_shape().as_list()[1]
153 + feature_size = model_input.get_shape().as_list()[2]
154 + reshaped_input = tf.reshape(model_input, [-1, feature_size])
155 + tf.compat.v1.summary.histogram("input_hist", reshaped_input)
156 +
157 + if add_batch_norm:
158 + reshaped_input = slim.batch_norm(reshaped_input,
159 + center=True,
160 + scale=True,
161 + is_training=is_training,
162 + scope="input_bn")
163 +
164 + cluster_weights = tf.compat.v1.get_variable(
165 + "cluster_weights", [feature_size, cluster_size],
166 + initializer=tf.random_normal_initializer(stddev=1 /
167 + math.sqrt(feature_size)))
168 + tf.compat.v1.summary.histogram("cluster_weights", cluster_weights)
169 + activation = tf.matmul(reshaped_input, cluster_weights)
170 + if add_batch_norm:
171 + activation = slim.batch_norm(activation,
172 + center=True,
173 + scale=True,
174 + is_training=is_training,
175 + scope="cluster_bn")
176 + else:
177 + cluster_biases = tf.compat.v1.get_variable(
178 + "cluster_biases", [cluster_size],
179 + initializer=tf.random_normal_initializer(stddev=1 /
180 + math.sqrt(feature_size)))
181 + tf.compat.v1.summary.histogram("cluster_biases", cluster_biases)
182 + activation += cluster_biases
183 + activation = act_fn(activation)
184 + tf.compat.v1.summary.histogram("cluster_output", activation)
185 +
186 + activation = tf.reshape(activation, [-1, max_frames, cluster_size])
187 + activation = utils.FramePooling(activation, FLAGS.dbof_pooling_method)
188 +
189 + hidden1_weights = tf.compat.v1.get_variable(
190 + "hidden1_weights", [cluster_size, hidden1_size],
191 + initializer=tf.random_normal_initializer(stddev=1 /
192 + math.sqrt(cluster_size)))
193 + tf.compat.v1.summary.histogram("hidden1_weights", hidden1_weights)
194 + activation = tf.matmul(activation, hidden1_weights)
195 + if add_batch_norm:
196 + activation = slim.batch_norm(activation,
197 + center=True,
198 + scale=True,
199 + is_training=is_training,
200 + scope="hidden1_bn")
201 + else:
202 + hidden1_biases = tf.compat.v1.get_variable(
203 + "hidden1_biases", [hidden1_size],
204 + initializer=tf.random_normal_initializer(stddev=0.01))
205 + tf.compat.v1.summary.histogram("hidden1_biases", hidden1_biases)
206 + activation += hidden1_biases
207 + activation = act_fn(activation)
208 + tf.compat.v1.summary.histogram("hidden1_output", activation)
209 +
210 + aggregated_model = getattr(video_level_models,
211 + FLAGS.video_level_classifier_model)
212 + return aggregated_model().create_model(model_input=activation,
213 + vocab_size=vocab_size,
214 + **unused_params)
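
A shape walk-through (illustrative assumption, with zero arrays standing in for the learned variables) of the DBoF pipeline above, using the default flag values:

```
import numpy as np

batch, iterations, feature_size = 4, 30, 1152
cluster_size, hidden_size = 8192, 1024

frames = np.zeros((batch, iterations, feature_size), dtype=np.float32)

# 1) Flatten so every frame is projected independently.
reshaped = frames.reshape(-1, feature_size)                      # (120, 1152)
# 2) Project each frame into the cluster space.
cluster_weights = np.zeros((feature_size, cluster_size), np.float32)
activation = reshaped @ cluster_weights                          # (120, 8192)
# 3) Restore the frame axis and pool across frames (max pooling by default).
activation = activation.reshape(batch, iterations, cluster_size)
pooled = activation.max(axis=1)                                  # (4, 8192)
# 4) Hidden layer, then the configured video-level classifier maps to vocab.
hidden_weights = np.zeros((cluster_size, hidden_size), np.float32)
hidden = pooled @ hidden_weights                                 # (4, 1024)
```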
215 +
216 +
217 +class LstmModel(models.BaseModel):
218 + """Creates a model which uses a stack of LSTMs to represent the video."""
219 +
220 + def create_model(self, model_input, vocab_size, num_frames, **unused_params):
221 + """See base class.
222 +
223 + Args:
224 + model_input: A 'batch_size' x 'max_frames' x 'num_features' matrix of
225 + input features.
226 + vocab_size: The number of classes in the dataset.
227 + num_frames: A vector of length 'batch' which indicates the number of
228 + frames for each video (before padding).
229 +
230 + Returns:
231 + A dictionary with a tensor containing the probability predictions of the
232 + model in the 'predictions' key. The dimensions of the tensor are
233 + 'batch_size' x 'num_classes'.
234 + """
235 + lstm_size = FLAGS.lstm_cells
236 + number_of_layers = FLAGS.lstm_layers
237 +
238 + stacked_lstm = tf.contrib.rnn.MultiRNNCell([
239 + tf.contrib.rnn.BasicLSTMCell(lstm_size, forget_bias=1.0)
240 + for _ in range(number_of_layers)
241 + ])
242 +
243 + _, state = tf.nn.dynamic_rnn(stacked_lstm,
244 + model_input,
245 + sequence_length=num_frames,
246 + dtype=tf.float32)
247 +
248 + aggregated_model = getattr(video_level_models,
249 + FLAGS.video_level_classifier_model)
250 +
251 + return aggregated_model().create_model(model_input=state[-1].h,
252 + vocab_size=vocab_size,
253 + **unused_params)
1 +# Copyright 2017 Google Inc. All Rights Reserved.
2 +#
3 +# Licensed under the Apache License, Version 2.0 (the "License");
4 +# you may not use this file except in compliance with the License.
5 +# You may obtain a copy of the License at
6 +#
7 +# http://www.apache.org/licenses/LICENSE-2.0
8 +#
9 +# Unless required by applicable law or agreed to in writing, software
10 +# distributed under the License is distributed on an "AS-IS" BASIS,
11 +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 +# See the License for the specific language governing permissions and
13 +# limitations under the License.
14 +"""Binary for generating predictions over a set of videos."""
15 +
16 +from __future__ import print_function
17 +
18 +import glob
19 +import heapq
20 +import json
21 +import os
22 +import tarfile
23 +import tempfile
24 +import time
25 +import numpy as np
26 +
27 +import readers
28 +from six.moves import urllib
29 +import tensorflow as tf
30 +from tensorflow import app
31 +from tensorflow import flags
32 +from tensorflow import gfile
33 +from tensorflow import logging
34 +from tensorflow.python.lib.io import file_io
35 +import utils
36 +
37 +FLAGS = flags.FLAGS
38 +
39 +if __name__ == "__main__":
40 + # Input
41 + flags.DEFINE_string(
42 + "train_dir", "", "The directory to load the model files from. We assume "
43 + "that you have already run eval.py onto this, such that "
44 + "inference_model.* files already exist.")
45 + flags.DEFINE_string(
46 + "input_data_pattern", "",
47 + "File glob defining the evaluation dataset in tensorflow.SequenceExample "
48 + "format. The SequenceExamples are expected to have an 'rgb' byte array "
49 + "sequence feature as well as a 'labels' int64 context feature.")
50 + flags.DEFINE_string(
51 + "input_model_tgz", "",
52 + "If given, must be path to a .tgz file that was written "
53 + "by this binary using flag --output_model_tgz. In this "
54 + "case, the .tgz file will be untarred to "
55 + "--untar_model_dir and the model will be used for "
56 + "inference.")
57 + flags.DEFINE_string(
58 + "untar_model_dir", "/tmp/yt8m-model",
59 + "If --input_model_tgz is given, then this directory will "
60 + "be created and the contents of the .tgz file will be "
61 + "untarred here.")
62 + flags.DEFINE_bool(
63 + "segment_labels", False,
64 + "If set, then --input_data_pattern must be frame-level features (but with"
65 + " segment_labels). Otherwise, --input_data_pattern must be aggregated "
66 + "video-level features. The model must also be set appropriately (i.e. to "
67 + "read 3D batches VS 4D batches.")
68 + flags.DEFINE_integer("segment_max_pred", 100000,
69 + "Limit total number of segment outputs per entity.")
70 + flags.DEFINE_string(
71 + "segment_label_ids_file",
72 + "https://raw.githubusercontent.com/google/youtube-8m/master/segment_label_ids.csv",
73 + "The file that contains the segment label ids.")
74 +
75 + # Output
76 + flags.DEFINE_string("output_file", "", "The file to save the predictions to.")
77 + flags.DEFINE_string(
78 + "output_model_tgz", "",
79 + "If given, should be a filename with a .tgz extension, "
80 + "the model graph and checkpoint will be bundled in this "
81 + "gzip tar. This file can be uploaded to Kaggle for the "
82 + "top 10 participants.")
83 + flags.DEFINE_integer("top_k", 20, "How many predictions to output per video.")
84 +
85 + # Other flags.
86 + flags.DEFINE_integer("batch_size", 512,
87 + "How many examples to process per batch.")
88 + flags.DEFINE_integer("num_readers", 1,
89 + "How many threads to use for reading input files.")
90 +
91 +
92 +def format_lines(video_ids, predictions, top_k, whitelisted_cls_mask=None):
93 + """Create an information line the submission file."""
94 + batch_size = len(video_ids)
95 + for video_index in range(batch_size):
96 + video_prediction = predictions[video_index]
97 + if whitelisted_cls_mask is not None:
98 + # Whitelist classes.
99 + video_prediction *= whitelisted_cls_mask
100 + top_indices = np.argpartition(video_prediction, -top_k)[-top_k:]
101 + line = [(class_index, predictions[video_index][class_index])
102 + for class_index in top_indices]
103 + line = sorted(line, key=lambda p: -p[1])
104 + yield (video_ids[video_index] + "," +
105 + " ".join("%i %g" % (label, score) for (label, score) in line) +
106 + "\n").encode("utf8")
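
A toy call (illustrative assumption) showing the line format produced by the function above: `<video_id>,<label> <score> ...`, sorted by descending score and encoded as UTF-8 bytes.

```
import numpy as np

preds = np.array([[0.10, 0.90, 0.30, 0.05]])
lines = list(format_lines(video_ids=["vid0"], predictions=preds, top_k=2))
print(lines[0])  # b'vid0,1 0.9 2 0.3\n'
```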
107 +
108 +
109 +def get_input_data_tensors(reader, data_pattern, batch_size, num_readers=1):
110 + """Creates the section of the graph which reads the input data.
111 +
112 + Args:
113 + reader: A class which parses the input data.
114 + data_pattern: A 'glob' style path to the data files.
115 + batch_size: How many examples to process at a time.
116 + num_readers: How many I/O threads to use.
117 +
118 + Returns:
119 + A tuple containing the features tensor, labels tensor, and optionally a
120 + tensor containing the number of frames per video. The exact dimensions
121 + depend on the reader being used.
122 +
123 + Raises:
124 + IOError: If no files matching the given pattern were found.
125 + """
126 + with tf.name_scope("input"):
127 + files = gfile.Glob(data_pattern)
128 + if not files:
129 + raise IOError("Unable to find input files. data_pattern='" +
130 + data_pattern + "'")
131 + logging.info("number of input files: " + str(len(files)))
132 + filename_queue = tf.train.string_input_producer(files,
133 + num_epochs=1,
134 + shuffle=False)
135 + examples_and_labels = [
136 + reader.prepare_reader(filename_queue) for _ in range(num_readers)
137 + ]
138 +
139 + input_data_dict = (tf.train.batch_join(examples_and_labels,
140 + batch_size=batch_size,
141 + allow_smaller_final_batch=True,
142 + enqueue_many=True))
143 + video_id_batch = input_data_dict["video_ids"]
144 + video_batch = input_data_dict["video_matrix"]
145 + num_frames_batch = input_data_dict["num_frames"]
146 + return video_id_batch, video_batch, num_frames_batch
147 +
148 +
149 +def get_segments(batch_video_mtx, batch_num_frames, segment_size):
150 + """Get segment-level inputs from frame-level features."""
151 + video_batch_size = batch_video_mtx.shape[0]
152 + max_frame = batch_video_mtx.shape[1]
153 + feature_dim = batch_video_mtx.shape[-1]
154 + padded_segment_sizes = (batch_num_frames + segment_size - 1) // segment_size
155 + padded_segment_sizes *= segment_size
156 + segment_mask = (
157 + 0 < (padded_segment_sizes[:, np.newaxis] - np.arange(0, max_frame)))
158 +
159 + # Segment bags.
160 + frame_bags = batch_video_mtx.reshape((-1, feature_dim))
161 + segment_frames = frame_bags[segment_mask.reshape(-1)].reshape(
162 + (-1, segment_size, feature_dim))
163 +
164 + # Segment num frames.
165 + segment_start_times = np.arange(0, max_frame, segment_size)
166 + num_segments = batch_num_frames[:, np.newaxis] - segment_start_times
167 + num_segment_bags = num_segments.reshape((-1))
168 + valid_segment_mask = num_segment_bags > 0
169 + segment_num_frames = num_segment_bags[valid_segment_mask]
170 + segment_num_frames[segment_num_frames > segment_size] = segment_size
171 +
172 + max_segment_num = (max_frame + segment_size - 1) // segment_size
173 + video_idxs = np.tile(
174 + np.arange(0, video_batch_size)[:, np.newaxis], [1, max_segment_num])
175 + segment_idxs = np.tile(segment_start_times, [video_batch_size, 1])
176 + idx_bags = np.stack([video_idxs, segment_idxs], axis=-1).reshape((-1, 2))
177 + video_segment_ids = idx_bags[valid_segment_mask]
178 +
179 + return {
180 + "video_batch": segment_frames,
181 + "num_frames_batch": segment_num_frames,
182 + "video_segment_ids": video_segment_ids
183 + }
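
A worked example (assumption, not repo code) of get_segments: one video with 12 real frames, padded to max_frame=15, is split into 5-frame segments, the last of which is partial.

```
import numpy as np

batch_video_mtx = np.random.rand(1, 15, 128).astype(np.float32)
batch_num_frames = np.array([12])

segments = get_segments(batch_video_mtx, batch_num_frames, segment_size=5)
print(segments["video_batch"].shape)    # (3, 5, 128): three segments
print(segments["num_frames_batch"])     # [5 5 2]: last segment has 2 frames
print(segments["video_segment_ids"])    # [[0 0] [0 5] [0 10]]: (video, start)
```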
184 +
185 +
186 +def inference(reader, train_dir, data_pattern, out_file_location, batch_size,
187 + top_k):
188 + """Inference function."""
189 + with tf.Session(config=tf.ConfigProto(
190 + allow_soft_placement=True)) as sess, gfile.Open(out_file_location,
191 + "w+") as out_file:
192 + video_id_batch, video_batch, num_frames_batch = get_input_data_tensors(
193 + reader, data_pattern, batch_size)
194 + inference_model_name = "segment_inference_model" if FLAGS.segment_labels else "inference_model"
195 + checkpoint_file = os.path.join(train_dir, "inference_model",
196 + inference_model_name)
197 + if not gfile.Exists(checkpoint_file + ".meta"):
198 + raise IOError("Cannot find %s. Did you run eval.py?" % checkpoint_file)
199 + meta_graph_location = checkpoint_file + ".meta"
200 + logging.info("loading meta-graph: " + meta_graph_location)
201 +
202 + if FLAGS.output_model_tgz:
203 + with tarfile.open(FLAGS.output_model_tgz, "w:gz") as tar:
204 + for model_file in glob.glob(checkpoint_file + ".*"):
205 + tar.add(model_file, arcname=os.path.basename(model_file))
206 + tar.add(os.path.join(train_dir, "model_flags.json"),
207 + arcname="model_flags.json")
208 + print("Tarred model onto " + FLAGS.output_model_tgz)
209 + with tf.device("/cpu:0"):
210 + saver = tf.train.import_meta_graph(meta_graph_location,
211 + clear_devices=True)
212 + logging.info("restoring variables from " + checkpoint_file)
213 + saver.restore(sess, checkpoint_file)
214 + input_tensor = tf.get_collection("input_batch_raw")[0]
215 + num_frames_tensor = tf.get_collection("num_frames")[0]
216 + predictions_tensor = tf.get_collection("predictions")[0]
217 +
218 + # Workaround for num_epochs issue.
219 + def set_up_init_ops(variables):
220 + init_op_list = []
221 + for variable in list(variables):
222 + if "train_input" in variable.name:
223 + init_op_list.append(tf.assign(variable, 1))
224 + variables.remove(variable)
225 + init_op_list.append(tf.variables_initializer(variables))
226 + return init_op_list
227 +
228 + sess.run(
229 + set_up_init_ops(tf.get_collection_ref(tf.GraphKeys.LOCAL_VARIABLES)))
230 +
231 + coord = tf.train.Coordinator()
232 + threads = tf.train.start_queue_runners(sess=sess, coord=coord)
233 + num_examples_processed = 0
234 + start_time = time.time()
235 + whitelisted_cls_mask = None
236 + if FLAGS.segment_labels:
237 + final_out_file = out_file
238 + out_file = tempfile.NamedTemporaryFile()
239 + logging.info(
240 + "Segment temp prediction output will be written to temp file: %s",
241 + out_file.name)
242 + if FLAGS.segment_label_ids_file:
243 + whitelisted_cls_mask = np.zeros((predictions_tensor.get_shape()[-1],),
244 + dtype=np.float32)
245 + segment_label_ids_file = FLAGS.segment_label_ids_file
246 + if segment_label_ids_file.startswith("http"):
247 + logging.info("Retrieving segment ID whitelist files from %s...",
248 + segment_label_ids_file)
249 + segment_label_ids_file, _ = urllib.request.urlretrieve(
250 + segment_label_ids_file)
251 + with tf.io.gfile.GFile(segment_label_ids_file) as fobj:
252 + for line in fobj:
253 + try:
254 + cls_id = int(line)
255 + whitelisted_cls_mask[cls_id] = 1.
256 + except ValueError:
257 + # Simply skip the non-integer line.
258 + continue
259 +
260 + out_file.write(u"VideoId,LabelConfidencePairs\n".encode("utf8"))
261 +
262 + try:
263 + while not coord.should_stop():
264 + video_id_batch_val, video_batch_val, num_frames_batch_val = sess.run(
265 + [video_id_batch, video_batch, num_frames_batch])
266 + if FLAGS.segment_labels:
267 + results = get_segments(video_batch_val, num_frames_batch_val, 5)
268 + video_segment_ids = results["video_segment_ids"]
269 + video_id_batch_val = video_id_batch_val[video_segment_ids[:, 0]]
270 + video_id_batch_val = np.array([
271 + "%s:%d" % (x.decode("utf8"), y)
272 + for x, y in zip(video_id_batch_val, video_segment_ids[:, 1])
273 + ])
274 + video_batch_val = results["video_batch"]
275 + num_frames_batch_val = results["num_frames_batch"]
276 + if input_tensor.get_shape()[1] != video_batch_val.shape[1]:
277 + raise ValueError("max_frames mismatch. Please re-run the eval.py "
278 + "with correct segment_labels settings.")
279 +
280 + predictions_val, = sess.run([predictions_tensor],
281 + feed_dict={
282 + input_tensor: video_batch_val,
283 + num_frames_tensor: num_frames_batch_val
284 + })
285 + now = time.time()
286 + num_examples_processed += len(video_batch_val)
287 + elapsed_time = now - start_time
288 + logging.info("num examples processed: " + str(num_examples_processed) +
289 + " elapsed seconds: " + "{0:.2f}".format(elapsed_time) +
290 + " examples/sec: %.2f" %
291 + (num_examples_processed / elapsed_time))
292 + for line in format_lines(video_id_batch_val, predictions_val, top_k,
293 + whitelisted_cls_mask):
294 + out_file.write(line)
295 + out_file.flush()
296 +
297 + except tf.errors.OutOfRangeError:
298 + logging.info("Done with inference. The output file was written to " +
299 + out_file.name)
300 + finally:
301 + coord.request_stop()
302 +
303 + if FLAGS.segment_labels:
304 + # Re-read the file and do heap sort.
305 + # Create multiple heaps.
306 + logging.info("Post-processing segment predictions...")
307 + heaps = {}
308 + out_file.seek(0, 0)
309 + for line in out_file:
310 + segment_id, preds = line.decode("utf8").split(",")
311 + if segment_id == "VideoId":
312 + # Skip the headline.
313 + continue
314 + preds = preds.split(" ")
315 + pred_cls_ids = [int(preds[idx]) for idx in range(0, len(preds), 2)]
316 + pred_cls_scores = [
317 + float(preds[idx]) for idx in range(1, len(preds), 2)
318 + ]
319 + for cls, score in zip(pred_cls_ids, pred_cls_scores):
320 + if not whitelisted_cls_mask[cls]:
321 + # Skip non-whitelisted classes.
322 + continue
323 + if cls not in heaps:
324 + heaps[cls] = []
325 + if len(heaps[cls]) >= FLAGS.segment_max_pred:
326 + heapq.heappushpop(heaps[cls], (score, segment_id))
327 + else:
328 + heapq.heappush(heaps[cls], (score, segment_id))
329 + logging.info("Writing sorted segment predictions to: %s",
330 + final_out_file.name)
331 + final_out_file.write("Class,Segments\n")
332 + for cls, cls_heap in heaps.items():
333 + cls_heap.sort(key=lambda x: x[0], reverse=True)
334 + final_out_file.write("%d,%s\n" %
335 + (cls, " ".join([x[1] for x in cls_heap])))
336 + final_out_file.close()
337 +
338 + out_file.close()
339 +
340 + coord.join(threads)
341 + sess.close()
342 +
343 +
344 +def main(unused_argv):
345 + logging.set_verbosity(tf.logging.INFO)
346 + if FLAGS.input_model_tgz:
347 + if FLAGS.train_dir:
348 + raise ValueError("You cannot supply --train_dir if supplying "
349 + "--input_model_tgz")
350 + # Untar.
351 + if not os.path.exists(FLAGS.untar_model_dir):
352 + os.makedirs(FLAGS.untar_model_dir)
353 + tarfile.open(FLAGS.input_model_tgz).extractall(FLAGS.untar_model_dir)
354 + FLAGS.train_dir = FLAGS.untar_model_dir
355 +
356 + flags_dict_file = os.path.join(FLAGS.train_dir, "model_flags.json")
357 + if not file_io.file_exists(flags_dict_file):
358 + raise IOError("Cannot find %s. Did you run eval.py?" % flags_dict_file)
359 + flags_dict = json.loads(file_io.FileIO(flags_dict_file, "r").read())
360 +
361 + # convert feature_names and feature_sizes to lists of values
362 + feature_names, feature_sizes = utils.GetListOfFeatureNamesAndSizes(
363 + flags_dict["feature_names"], flags_dict["feature_sizes"])
364 +
365 + if flags_dict["frame_features"]:
366 + reader = readers.YT8MFrameFeatureReader(feature_names=feature_names,
367 + feature_sizes=feature_sizes)
368 + else:
369 + reader = readers.YT8MAggregatedFeatureReader(feature_names=feature_names,
370 + feature_sizes=feature_sizes)
371 +
372 + if not FLAGS.output_file:
373 + raise ValueError("'output_file' was not specified. "
374 + "Unable to continue with inference.")
375 +
376 + if not FLAGS.input_data_pattern:
377 + raise ValueError("'input_data_pattern' was not specified. "
378 + "Unable to continue with inference.")
379 +
380 + inference(reader, FLAGS.train_dir, FLAGS.input_data_pattern,
381 + FLAGS.output_file, FLAGS.batch_size, FLAGS.top_k)
382 +
383 +
384 +if __name__ == "__main__":
385 + app.run()
1 +# Copyright 2017 Google Inc. All Rights Reserved.
2 +#
3 +# Licensed under the Apache License, Version 2.0 (the "License");
4 +# you may not use this file except in compliance with the License.
5 +# You may obtain a copy of the License at
6 +#
7 +# http://www.apache.org/licenses/LICENSE-2.0
8 +#
9 +# Unless required by applicable law or agreed to in writing, software
10 +# distributed under the License is distributed on an "AS-IS" BASIS,
11 +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 +# See the License for the specific language governing permissions and
13 +# limitations under the License.
14 +"""Binary for generating predictions over a set of videos."""
15 +
16 +from __future__ import print_function
17 +
18 +import glob
19 +import heapq
20 +import json
21 +import os
22 +import tarfile
23 +import tempfile
24 +import time
25 +import numpy as np
26 +
27 +import readers
28 +from six.moves import urllib
29 +import tensorflow as tf
30 +from tensorflow import app
31 +from tensorflow import flags
32 +from tensorflow import gfile
33 +from tensorflow import logging
34 +from tensorflow.python.lib.io import file_io
35 +import utils
36 +from collections import Counter
37 +
38 +FLAGS = flags.FLAGS
39 +
40 +if __name__ == "__main__":
41 + # Input
42 + flags.DEFINE_string(
43 + "train_dir", "", "The directory to load the model files from. We assume "
44 + "that you have already run eval.py onto this, such that "
45 + "inference_model.* files already exist.")
46 + flags.DEFINE_string(
47 + "input_data_pattern", "",
48 + "File glob defining the evaluation dataset in tensorflow.SequenceExample "
49 + "format. The SequenceExamples are expected to have an 'rgb' byte array "
50 + "sequence feature as well as a 'labels' int64 context feature.")
51 + flags.DEFINE_string(
52 + "input_model_tgz", "",
53 + "If given, must be path to a .tgz file that was written "
54 + "by this binary using flag --output_model_tgz. In this "
55 + "case, the .tgz file will be untarred to "
56 + "--untar_model_dir and the model will be used for "
57 + "inference.")
58 + flags.DEFINE_string(
59 + "untar_model_dir", "/tmp/yt8m-model",
60 + "If --input_model_tgz is given, then this directory will "
61 + "be created and the contents of the .tgz file will be "
62 + "untarred here.")
63 + flags.DEFINE_bool(
64 + "segment_labels", False,
65 + "If set, then --input_data_pattern must be frame-level features (but with"
66 + " segment_labels). Otherwise, --input_data_pattern must be aggregated "
67 + "video-level features. The model must also be set appropriately (i.e. to "
68 + "read 3D batches VS 4D batches.")
69 + flags.DEFINE_integer("segment_max_pred", 100000,
70 + "Limit total number of segment outputs per entity.")
71 + flags.DEFINE_string(
72 + "segment_label_ids_file",
73 + "https://raw.githubusercontent.com/google/youtube-8m/master/segment_label_ids.csv",
74 + "The file that contains the segment label ids.")
75 +
76 + # Output
77 + flags.DEFINE_string("output_file", "", "The file to save the predictions to.")
78 + flags.DEFINE_string(
79 + "output_model_tgz", "",
80 + "If given, should be a filename with a .tgz extension, "
81 + "the model graph and checkpoint will be bundled in this "
82 + "gzip tar. This file can be uploaded to Kaggle for the "
83 + "top 10 participants.")
84 + flags.DEFINE_integer("top_k", 1, "How many predictions to output per video.")
85 +
86 + # Other flags.
87 + flags.DEFINE_integer("batch_size", 512,
88 + "How many examples to process per batch.")
89 + flags.DEFINE_integer("num_readers", 1,
90 + "How many threads to use for reading input files.")
91 +
92 +
93 +def format_lines(video_ids, predictions, top_k, whitelisted_cls_mask=None):
94 + """Create an information line the submission file."""
95 + batch_size = len(video_ids)
96 + for video_index in range(batch_size):
97 + video_prediction = predictions[video_index]
98 + if whitelisted_cls_mask is not None:
99 + # Whitelist classes.
100 + video_prediction *= whitelisted_cls_mask
101 + top_indices = np.argpartition(video_prediction, -top_k)[-top_k:]
102 + line = [(class_index, predictions[video_index][class_index])
103 + for class_index in top_indices]
104 + line = sorted(line, key=lambda p: -p[1])
105 + yield (video_ids[video_index] + "," +
106 + " ".join("%i %g" % (label, score) for (label, score) in line) +
107 + "\n").encode("utf8")
108 +
109 +
110 +def get_input_data_tensors(reader, data_pattern, batch_size, num_readers=1):
111 + """Creates the section of the graph which reads the input data.
112 +
113 + Args:
114 + reader: A class which parses the input data.
115 + data_pattern: A 'glob' style path to the data files.
116 + batch_size: How many examples to process at a time.
117 + num_readers: How many I/O threads to use.
118 +
119 + Returns:
120 + A tuple containing the features tensor, labels tensor, and optionally a
121 + tensor containing the number of frames per video. The exact dimensions
122 + depend on the reader being used.
123 +
124 + Raises:
125 + IOError: If no files matching the given pattern were found.
126 + """
127 + with tf.name_scope("input"):
128 + files = gfile.Glob(data_pattern)
129 + if not files:
130 + raise IOError("Unable to find input files. data_pattern='" +
131 + data_pattern + "'")
132 + logging.info("number of input files: " + str(len(files)))
133 + filename_queue = tf.train.string_input_producer(files,
134 + num_epochs=1,
135 + shuffle=False)
136 + examples_and_labels = [
137 + reader.prepare_reader(filename_queue) for _ in range(num_readers)
138 + ]
139 +
140 + input_data_dict = (tf.train.batch_join(examples_and_labels,
141 + batch_size=batch_size,
142 + allow_smaller_final_batch=True,
143 + enqueue_many=True))
144 + video_id_batch = input_data_dict["video_ids"]
145 + video_batch = input_data_dict["video_matrix"]
146 + num_frames_batch = input_data_dict["num_frames"]
147 + return video_id_batch, video_batch, num_frames_batch
148 +
149 +
150 +def get_segments(batch_video_mtx, batch_num_frames, segment_size):
151 + """Get segment-level inputs from frame-level features."""
152 + video_batch_size = batch_video_mtx.shape[0]
153 + max_frame = batch_video_mtx.shape[1]
154 + feature_dim = batch_video_mtx.shape[-1]
155 + padded_segment_sizes = (batch_num_frames + segment_size - 1) // segment_size
156 + padded_segment_sizes *= segment_size
157 + segment_mask = (
158 + 0 < (padded_segment_sizes[:, np.newaxis] - np.arange(0, max_frame)))
159 +
160 + # Segment bags.
161 + frame_bags = batch_video_mtx.reshape((-1, feature_dim))
162 + segment_frames = frame_bags[segment_mask.reshape(-1)].reshape(
163 + (-1, segment_size, feature_dim))
164 +
165 + # Segment num frames.
166 + segment_start_times = np.arange(0, max_frame, segment_size)
167 + num_segments = batch_num_frames[:, np.newaxis] - segment_start_times
168 + num_segment_bags = num_segments.reshape((-1))
169 + valid_segment_mask = num_segment_bags > 0
170 + segment_num_frames = num_segment_bags[valid_segment_mask]
171 + segment_num_frames[segment_num_frames > segment_size] = segment_size
172 +
173 + max_segment_num = (max_frame + segment_size - 1) // segment_size
174 + video_idxs = np.tile(
175 + np.arange(0, video_batch_size)[:, np.newaxis], [1, max_segment_num])
176 + segment_idxs = np.tile(segment_start_times, [video_batch_size, 1])
177 + idx_bags = np.stack([video_idxs, segment_idxs], axis=-1).reshape((-1, 2))
178 + video_segment_ids = idx_bags[valid_segment_mask]
179 +
180 + return {
181 + "video_batch": segment_frames,
182 + "num_frames_batch": segment_num_frames,
183 + "video_segment_ids": video_segment_ids
184 + }
185 +
186 +
187 +def inference(reader, train_dir, data_pattern, out_file_location, batch_size,
188 + top_k):
189 + """Inference function."""
190 + with tf.Session(config=tf.ConfigProto(
191 + allow_soft_placement=True)) as sess, gfile.Open(out_file_location,
192 + "w+") as out_file:
193 + video_id_batch, video_batch, num_frames_batch = get_input_data_tensors(
194 + reader, data_pattern, batch_size)
195 + inference_model_name = "segment_inference_model" if FLAGS.segment_labels else "inference_model"
196 + checkpoint_file = os.path.join(train_dir, "inference_model",
197 + inference_model_name)
198 + if not gfile.Exists(checkpoint_file + ".meta"):
199 + raise IOError("Cannot find %s. Did you run eval.py?" % checkpoint_file)
200 + meta_graph_location = checkpoint_file + ".meta"
201 + logging.info("loading meta-graph: " + meta_graph_location)
202 +
203 + if FLAGS.output_model_tgz:
204 + with tarfile.open(FLAGS.output_model_tgz, "w:gz") as tar:
205 + for model_file in glob.glob(checkpoint_file + ".*"):
206 + tar.add(model_file, arcname=os.path.basename(model_file))
207 + tar.add(os.path.join(train_dir, "model_flags.json"),
208 + arcname="model_flags.json")
209 + print("Tarred model onto " + FLAGS.output_model_tgz)
210 + with tf.device("/cpu:0"):
211 + saver = tf.train.import_meta_graph(meta_graph_location,
212 + clear_devices=True)
213 + logging.info("restoring variables from " + checkpoint_file)
214 + saver.restore(sess, checkpoint_file)
215 + input_tensor = tf.get_collection("input_batch_raw")[0]
216 + num_frames_tensor = tf.get_collection("num_frames")[0]
217 + predictions_tensor = tf.get_collection("predictions")[0]
218 +
219 + # Workaround for num_epochs issue.
220 + def set_up_init_ops(variables):
221 + init_op_list = []
222 + for variable in list(variables):
223 + if "train_input" in variable.name:
224 + init_op_list.append(tf.assign(variable, 1))
225 + variables.remove(variable)
226 + init_op_list.append(tf.variables_initializer(variables))
227 + return init_op_list
228 +
229 + sess.run(
230 + set_up_init_ops(tf.get_collection_ref(tf.GraphKeys.LOCAL_VARIABLES)))
231 +
232 + coord = tf.train.Coordinator()
233 + threads = tf.train.start_queue_runners(sess=sess, coord=coord)
234 + num_examples_processed = 0
235 + start_time = time.time()
236 + whitelisted_cls_mask = None
237 + if FLAGS.segment_labels:
238 + final_out_file = out_file
239 + out_file = tempfile.NamedTemporaryFile()
240 + logging.info(
241 + "Segment temp prediction output will be written to temp file: %s",
242 + out_file.name)
243 + if FLAGS.segment_label_ids_file:
244 + whitelisted_cls_mask = np.zeros((predictions_tensor.get_shape()[-1],),
245 + dtype=np.float32)
246 + segment_label_ids_file = FLAGS.segment_label_ids_file
247 + if segment_label_ids_file.startswith("http"):
248 + logging.info("Retrieving segment ID whitelist files from %s...",
249 + segment_label_ids_file)
250 + segment_label_ids_file, _ = urllib.request.urlretrieve(
251 + segment_label_ids_file)
252 + with tf.io.gfile.GFile(segment_label_ids_file) as fobj:
253 + for line in fobj:
254 + try:
255 + cls_id = int(line)
256 + whitelisted_cls_mask[cls_id] = 1.
257 + except ValueError:
258 + # Simply skip the non-integer line.
259 + continue
260 +
261 + out_file.write(u"VideoId,LabelConfidencePairs\n".encode("utf8"))
262 +
263 + try:
264 + while not coord.should_stop():
265 + video_id_batch_val, video_batch_val, num_frames_batch_val = sess.run(
266 + [video_id_batch, video_batch, num_frames_batch])
267 + if FLAGS.segment_labels:
268 + results = get_segments(video_batch_val, num_frames_batch_val, 5)
269 + video_segment_ids = results["video_segment_ids"]
270 + video_id_batch_val = video_id_batch_val[video_segment_ids[:, 0]]
271 + video_id_batch_val = np.array([
272 + "%s:%d" % (x.decode("utf8"), y)
273 + for x, y in zip(video_id_batch_val, video_segment_ids[:, 1])
274 + ])
275 + video_batch_val = results["video_batch"]
276 + num_frames_batch_val = results["num_frames_batch"]
277 + if input_tensor.get_shape()[1] != video_batch_val.shape[1]:
278 + raise ValueError("max_frames mismatch. Please re-run the eval.py "
279 + "with correct segment_labels settings.")
280 +
281 + predictions_val, = sess.run([predictions_tensor],
282 + feed_dict={
283 + input_tensor: video_batch_val,
284 + num_frames_tensor: num_frames_batch_val
285 + })
286 + now = time.time()
287 + num_examples_processed += len(video_batch_val)
288 + elapsed_time = now - start_time
289 + logging.info("num examples processed: " + str(num_examples_processed) +
290 + " elapsed seconds: " + "{0:.2f}".format(elapsed_time) +
291 + " examples/sec: %.2f" %
292 + (num_examples_processed / elapsed_time))
293 + for line in format_lines(video_id_batch_val, predictions_val, top_k,
294 + whitelisted_cls_mask):
295 + out_file.write(line)
296 + out_file.flush()
297 +
298 + except tf.errors.OutOfRangeError:
299 + logging.info("Done with inference. The output file was written to " +
300 + out_file.name)
301 + finally:
302 + coord.request_stop()
303 +
304 + if FLAGS.segment_labels:
305 +      # Re-read the temp file and aggregate segment predictions per video,
306 +      # counting how often each class appears across that video's segments.
307 + logging.info("Post-processing segment predictions...")
308 + segment_id_list = []
309 + segment_classes = []
310 + cls_result_arr = []
311 + out_file.seek(0, 0)
312 + for line in out_file:
313 + segment_id, preds = line.decode("utf8").split(",")
314 + if segment_id == "VideoId":
315 + # Skip the headline.
316 + continue
317 +
318 + preds = preds.split(" ")
319 + pred_cls_ids = [int(preds[idx]) for idx in range(0, len(preds), 2)]
320 + # =======================================
321 + segment_id = str(segment_id.split(":")[0])
322 + if segment_id not in segment_id_list:
323 + segment_id_list.append(str(segment_id))
324 + segment_classes.append("")
325 +
326 + index = segment_id_list.index(segment_id)
327 + for classes in pred_cls_ids:
328 + segment_classes[index] = str(segment_classes[index]) + str(
329 + classes) + " " # append classes from new segment
330 +
331 + for segs, item in zip(segment_id_list, segment_classes):
332 + print('====== R E C O R D ======')
333 + cls_arr = item.split(" ")[:-1]
334 +
335 + cls_arr = list(map(int, cls_arr))
336 + cls_arr = sorted(cls_arr)
337 +
338 + result_string = ""
339 +
340 + temp = Counter(cls_arr)
341 + for item in temp:
342 + result_string = result_string + str(item) + ":" + str(temp[item]) + ","
343 +
344 + cls_result_arr.append(result_string[:-1])
345 + logging.info(segs + " : " + result_string[:-1])
346 + # =======================================
347 + final_out_file.write("vid_id,seg_classes\n")
348 +      for seg_id, class_indices in zip(segment_id_list, cls_result_arr):
349 +        final_out_file.write("%s,%s\n" % (seg_id, str(class_indices)))
350 + final_out_file.close()
351 +
352 + out_file.close()
353 +
354 + coord.join(threads)
355 + sess.close()
356 +
357 +
358 +def main(unused_argv):
359 + logging.set_verbosity(tf.logging.INFO)
360 + if FLAGS.input_model_tgz:
361 + if FLAGS.train_dir:
362 + raise ValueError("You cannot supply --train_dir if supplying "
363 + "--input_model_tgz")
364 + # Untar.
365 + if not os.path.exists(FLAGS.untar_model_dir):
366 + os.makedirs(FLAGS.untar_model_dir)
367 + tarfile.open(FLAGS.input_model_tgz).extractall(FLAGS.untar_model_dir)
368 + FLAGS.train_dir = FLAGS.untar_model_dir
369 +
370 + flags_dict_file = os.path.join(FLAGS.train_dir, "model_flags.json")
371 + if not file_io.file_exists(flags_dict_file):
372 + raise IOError("Cannot find %s. Did you run eval.py?" % flags_dict_file)
373 + flags_dict = json.loads(file_io.FileIO(flags_dict_file, "r").read())
374 +
375 + # convert feature_names and feature_sizes to lists of values
376 + feature_names, feature_sizes = utils.GetListOfFeatureNamesAndSizes(
377 + flags_dict["feature_names"], flags_dict["feature_sizes"])
378 +
379 + if flags_dict["frame_features"]:
380 + reader = readers.YT8MFrameFeatureReader(feature_names=feature_names,
381 + feature_sizes=feature_sizes)
382 + else:
383 + reader = readers.YT8MAggregatedFeatureReader(feature_names=feature_names,
384 + feature_sizes=feature_sizes)
385 +
386 + if not FLAGS.output_file:
387 + raise ValueError("'output_file' was not specified. "
388 + "Unable to continue with inference.")
389 +
390 + if not FLAGS.input_data_pattern:
391 + raise ValueError("'input_data_pattern' was not specified. "
392 + "Unable to continue with inference.")
393 +
394 + inference(reader, FLAGS.train_dir, FLAGS.input_data_pattern,
395 + FLAGS.output_file, FLAGS.batch_size, FLAGS.top_k)
396 +
397 +
398 +if __name__ == "__main__":
399 + app.run()
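
A standalone sketch (assumption) of the Counter-based post-processing in the segment branch of this modified binary: per-segment class lists keyed by "video:start" are collapsed into one "vid_id,class:count,..." row per video.

```
from collections import Counter

# Hypothetical per-segment predictions keyed by "video_id:start_frame".
segment_preds = {
    "vid0:0": [3, 7],
    "vid0:5": [3],
    "vid1:0": [2],
}

# Group class ids by video id.
per_video = {}
for seg_id, classes in segment_preds.items():
    vid = seg_id.split(":")[0]
    per_video.setdefault(vid, []).extend(classes)

# Count occurrences of each class across that video's segments.
for vid, classes in per_video.items():
    counts = Counter(sorted(classes))
    print("%s,%s" % (vid, ",".join("%d:%d" % kv for kv in counts.items())))
# vid0,3:2,7:1
# vid1,2:1
```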
1 +# Copyright 2016 Google Inc. All Rights Reserved.
2 +#
3 +# Licensed under the Apache License, Version 2.0 (the "License");
4 +# you may not use this file except in compliance with the License.
5 +# You may obtain a copy of the License at
6 +#
7 +# http://www.apache.org/licenses/LICENSE-2.0
8 +#
9 +# Unless required by applicable law or agreed to in writing, software
10 +# distributed under the License is distributed on an "AS-IS" BASIS,
11 +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 +# See the License for the specific language governing permissions and
13 +# limitations under the License.
14 +"""Provides definitions for non-regularized training or test losses."""
15 +
16 +import tensorflow as tf
17 +
18 +
19 +class BaseLoss(object):
20 + """Inherit from this class when implementing new losses."""
21 +
22 + def calculate_loss(self, unused_predictions, unused_labels, **unused_params):
23 + """Calculates the average loss of the examples in a mini-batch.
24 +
25 + Args:
26 + unused_predictions: a 2-d tensor storing the prediction scores, in which
27 + each row represents a sample in the mini-batch and each column
28 + represents a class.
29 + unused_labels: a 2-d tensor storing the labels, which has the same shape
30 + as the unused_predictions. The labels must be in the range of 0 and 1.
31 + unused_params: loss specific parameters.
32 +
33 + Returns:
34 + A scalar loss tensor.
35 + """
36 + raise NotImplementedError()
37 +
38 +
39 +class CrossEntropyLoss(BaseLoss):
40 + """Calculate the cross entropy loss between the predictions and labels."""
41 +
42 + def calculate_loss(self,
43 + predictions,
44 + labels,
45 + label_weights=None,
46 + **unused_params):
47 + with tf.name_scope("loss_xent"):
48 + epsilon = 1e-5
49 + float_labels = tf.cast(labels, tf.float32)
50 + cross_entropy_loss = float_labels * tf.math.log(predictions + epsilon) + (
51 + 1 - float_labels) * tf.math.log(1 - predictions + epsilon)
52 + cross_entropy_loss = tf.negative(cross_entropy_loss)
53 + if label_weights is not None:
54 + cross_entropy_loss *= label_weights
55 + return tf.reduce_mean(tf.reduce_sum(cross_entropy_loss, 1))
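
A small numeric sketch (illustrative assumption) of the loss above: per-class binary cross-entropy, summed over classes and averaged over the batch.

```
import numpy as np

epsilon = 1e-5
predictions = np.array([[0.9, 0.2]])
labels = np.array([[1.0, 0.0]])

cross_entropy = -(labels * np.log(predictions + epsilon) +
                  (1 - labels) * np.log(1 - predictions + epsilon))
print(cross_entropy.sum(axis=1).mean())  # ~0.328 (= -log(0.9) - log(0.8))
```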
56 +
57 +
58 +class HingeLoss(BaseLoss):
59 + """Calculate the hinge loss between the predictions and labels.
60 +
 61 +  Note that the subgradient is used in backpropagation, so the optimization
 62 +  may converge more slowly. The predictions trained with the hinge loss lie
 63 +  between -1 and +1.
64 + """
65 +
66 + def calculate_loss(self, predictions, labels, b=1.0, **unused_params):
67 + with tf.name_scope("loss_hinge"):
68 + float_labels = tf.cast(labels, tf.float32)
69 + all_zeros = tf.zeros(tf.shape(float_labels), dtype=tf.float32)
70 + all_ones = tf.ones(tf.shape(float_labels), dtype=tf.float32)
71 + sign_labels = tf.subtract(tf.scalar_mul(2, float_labels), all_ones)
72 + hinge_loss = tf.maximum(
73 + all_zeros,
74 + tf.scalar_mul(b, all_ones) - sign_labels * predictions)
75 + return tf.reduce_mean(tf.reduce_sum(hinge_loss, 1))
76 +
77 +
78 +class SoftmaxLoss(BaseLoss):
79 + """Calculate the softmax loss between the predictions and labels.
80 +
 81 +  The function calculates the loss in the following way: first we feed the
 82 +  predictions to the softmax activation function, and then we calculate
 83 +  the negative dot product between the log of the softmax activations and
 84 +  the normalized ground truth labels.
 85 +
 86 +  It is an extension of the one-hot label case that allows more than one
 87 +  positive label per sample.
88 + """
89 +
90 + def calculate_loss(self, predictions, labels, **unused_params):
91 + with tf.name_scope("loss_softmax"):
92 + epsilon = 10e-8
93 + float_labels = tf.cast(labels, tf.float32)
94 + # l1 normalization (labels are no less than 0)
95 + label_rowsum = tf.maximum(tf.reduce_sum(float_labels, 1, keep_dims=True),
96 + epsilon)
97 + norm_float_labels = tf.div(float_labels, label_rowsum)
98 + softmax_outputs = tf.nn.softmax(predictions)
99 + softmax_loss = tf.negative(
100 + tf.reduce_sum(tf.multiply(norm_float_labels, tf.log(softmax_outputs)),
101 + 1))
102 + return tf.reduce_mean(softmax_loss)
1 +# Copyright 2016 Google Inc. All Rights Reserved.
2 +#
3 +# Licensed under the Apache License, Version 2.0 (the "License");
4 +# you may not use this file except in compliance with the License.
5 +# You may obtain a copy of the License at
6 +#
7 +# http://www.apache.org/licenses/LICENSE-2.0
8 +#
9 +# Unless required by applicable law or agreed to in writing, software
10 +# distributed under the License is distributed on an "AS-IS" BASIS,
11 +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 +# See the License for the specific language governing permissions and
13 +# limitations under the License.
14 +"""Calculate the mean average precision.
15 +
16 +It provides an interface for calculating mean average precision
17 +for an entire list or the top-n ranked items.
18 +
19 +Example usages:
20 +We first call the function accumulate many times to process parts of the ranked
21 +list. After processing all the parts, we call peek_map_at_n
22 +to calculate the mean average precision.
23 +
24 +```
25 +import random
26 +
27 +p = np.array([[random.random() for _ in xrange(50)] for _ in xrange(1000)])
28 +a = np.array([[random.choice([0, 1]) for _ in xrange(50)]
29 + for _ in xrange(1000)])
30 +
31 +# mean average precision for 50 classes.
32 +calculator = mean_average_precision_calculator.MeanAveragePrecisionCalculator(
33 + num_class=50)
34 +calculator.accumulate(p, a)
35 +aps = calculator.peek_map_at_n()
36 +```
37 +"""
38 +
39 +import average_precision_calculator
40 +
41 +
42 +class MeanAveragePrecisionCalculator(object):
43 + """This class is to calculate mean average precision."""
44 +
45 + def __init__(self, num_class, filter_empty_classes=True, top_n=None):
46 + """Construct a calculator to calculate the (macro) average precision.
47 +
48 + Args:
49 + num_class: A positive Integer specifying the number of classes.
50 + filter_empty_classes: whether to filter classes without any positives.
51 + top_n: A positive Integer specifying the average precision at n, or None
52 + to use all provided data points.
53 +
54 + Raises:
55 + ValueError: An error occurred when num_class is not a positive integer;
56 + or the top_n_array is not a list of positive integers.
57 + """
58 + if not isinstance(num_class, int) or num_class <= 1:
 59 +      raise ValueError("num_class must be an integer greater than 1.")
60 +
61 + self._ap_calculators = [] # member of AveragePrecisionCalculator
62 + self._num_class = num_class # total number of classes
63 + self._filter_empty_classes = filter_empty_classes
64 + for _ in range(num_class):
65 + self._ap_calculators.append(
66 + average_precision_calculator.AveragePrecisionCalculator(top_n=top_n))
67 +
68 + def accumulate(self, predictions, actuals, num_positives=None):
69 + """Accumulate the predictions and their ground truth labels.
70 +
71 + Args:
72 + predictions: A list of lists storing the prediction scores. The outer
73 + dimension corresponds to classes.
74 + actuals: A list of lists storing the ground truth labels. The dimensions
75 + should correspond to the predictions input. Any value larger than 0 will
76 + be treated as positives, otherwise as negatives.
77 + num_positives: If provided, it is a list of numbers representing the
78 + number of true positives for each class. If not provided, the number of
79 + true positives will be inferred from the 'actuals' array.
80 +
81 + Raises:
82 + ValueError: An error occurred when the shape of predictions and actuals
83 + does not match.
84 + """
85 + if not num_positives:
86 + num_positives = [None for i in range(self._num_class)]
87 +
88 + calculators = self._ap_calculators
89 + for i in range(self._num_class):
90 + calculators[i].accumulate(predictions[i], actuals[i], num_positives[i])
91 +
92 + def clear(self):
93 + for calculator in self._ap_calculators:
94 + calculator.clear()
95 +
96 + def is_empty(self):
97 + return ([calculator.heap_size for calculator in self._ap_calculators
98 + ] == [0 for _ in range(self._num_class)])
99 +
100 + def peek_map_at_n(self):
101 + """Peek the non-interpolated mean average precision at n.
102 +
103 + Returns:
104 + An array of non-interpolated average precision at n (default 0) for each
105 + class.
106 + """
107 + aps = []
108 + for i in range(self._num_class):
109 + if (not self._filter_empty_classes or
110 + self._ap_calculators[i].num_accumulated_positives > 0):
111 + ap = self._ap_calculators[i].peek_ap_at_n()
112 + aps.append(ap)
113 + return aps
1 +# Copyright 2016 Google Inc. All Rights Reserved.
2 +#
3 +# Licensed under the Apache License, Version 2.0 (the "License");
4 +# you may not use this file except in compliance with the License.
5 +# You may obtain a copy of the License at
6 +#
7 +# http://www.apache.org/licenses/LICENSE-2.0
8 +#
9 +# Unless required by applicable law or agreed to in writing, software
10 +# distributed under the License is distributed on an "AS-IS" BASIS,
11 +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 +# See the License for the specific language governing permissions and
13 +# limitations under the License.
14 +"""Contains a collection of util functions for model construction."""
15 +import numpy
16 +import tensorflow as tf
17 +from tensorflow import logging
18 +from tensorflow import flags
19 +import tensorflow.contrib.slim as slim
20 +
21 +
22 +def SampleRandomSequence(model_input, num_frames, num_samples):
23 + """Samples a random sequence of frames of size num_samples.
24 +
25 + Args:
26 + model_input: A tensor of size batch_size x max_frames x feature_size
27 + num_frames: A tensor of size batch_size x 1
28 + num_samples: A scalar
29 +
30 + Returns:
31 + `model_input`: A tensor of size batch_size x num_samples x feature_size
32 + """
33 +
34 + batch_size = tf.shape(model_input)[0]
35 + frame_index_offset = tf.tile(tf.expand_dims(tf.range(num_samples), 0),
36 + [batch_size, 1])
37 + max_start_frame_index = tf.maximum(num_frames - num_samples, 0)
38 + start_frame_index = tf.cast(
39 + tf.multiply(tf.random_uniform([batch_size, 1]),
40 + tf.cast(max_start_frame_index + 1, tf.float32)), tf.int32)
41 + frame_index = tf.minimum(start_frame_index + frame_index_offset,
42 + tf.cast(num_frames - 1, tf.int32))
43 + batch_index = tf.tile(tf.expand_dims(tf.range(batch_size), 1),
44 + [1, num_samples])
45 + index = tf.stack([batch_index, frame_index], 2)
46 + return tf.gather_nd(model_input, index)
47 +
48 +
49 +def SampleRandomFrames(model_input, num_frames, num_samples):
50 + """Samples a random set of frames of size num_samples.
51 +
52 + Args:
53 + model_input: A tensor of size batch_size x max_frames x feature_size
54 + num_frames: A tensor of size batch_size x 1
55 + num_samples: A scalar
56 +
57 + Returns:
58 + `model_input`: A tensor of size batch_size x num_samples x feature_size
59 + """
60 + batch_size = tf.shape(model_input)[0]
61 + frame_index = tf.cast(
62 + tf.multiply(tf.random_uniform([batch_size, num_samples]),
63 + tf.tile(tf.cast(num_frames, tf.float32), [1, num_samples])),
64 + tf.int32)
65 + batch_index = tf.tile(tf.expand_dims(tf.range(batch_size), 1),
66 + [1, num_samples])
67 + index = tf.stack([batch_index, frame_index], 2)
68 + return tf.gather_nd(model_input, index)
69 +
70 +
71 +def FramePooling(frames, method, **unused_params):
72 + """Pools over the frames of a video.
73 +
74 + Args:
75 + frames: A tensor with shape [batch_size, num_frames, feature_size].
 76 +    method: "average", "max", or "none".
 77 +
 78 +  Returns:
 79 +    A tensor with shape [batch_size, feature_size] for average or max
 80 +    pooling, or a tensor with shape [batch_size*num_frames, feature_size]
 81 +    for none pooling.
 82 +
 83 +  Raises:
 84 +    ValueError: if method is other than "average", "max",
 85 +    or "none".
86 + """
87 + if method == "average":
88 + return tf.reduce_mean(frames, 1)
89 + elif method == "max":
90 + return tf.reduce_max(frames, 1)
91 + elif method == "none":
 92 +    feature_size = frames.shape.as_list()[2]
93 + return tf.reshape(frames, [-1, feature_size])
94 + else:
95 + raise ValueError("Unrecognized pooling method: %s" % method)
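
A minimal TF 1.x usage sketch (assumption) tying the helpers above together: sample 30 random frames per video and max-pool them.

```
import numpy as np
import tensorflow as tf

frames = tf.placeholder(tf.float32, [None, 300, 1152])
num_frames = tf.placeholder(tf.float32, [None, 1])

sampled = SampleRandomFrames(frames, num_frames, 30)  # (batch, 30, 1152)
pooled = FramePooling(sampled, "max")                 # (batch, 1152)

with tf.Session() as sess:
  out = sess.run(pooled, feed_dict={
      frames: np.zeros((2, 300, 1152), dtype=np.float32),
      num_frames: np.array([[120.0], [45.0]], dtype=np.float32)})
  print(out.shape)  # (2, 1152)
```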
1 +# Copyright 2016 Google Inc. All Rights Reserved.
2 +#
3 +# Licensed under the Apache License, Version 2.0 (the "License");
4 +# you may not use this file except in compliance with the License.
5 +# You may obtain a copy of the License at
6 +#
7 +# http://www.apache.org/licenses/LICENSE-2.0
8 +#
9 +# Unless required by applicable law or agreed to in writing, software
10 +# distributed under the License is distributed on an "AS-IS" BASIS,
11 +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 +# See the License for the specific language governing permissions and
13 +# limitations under the License.
14 +"""Contains the base class for models."""
15 +
16 +
17 +class BaseModel(object):
18 + """Inherit from this class when implementing new models."""
19 +
20 + def create_model(self, unused_model_input, **unused_params):
21 + raise NotImplementedError()
1 +# Copyright 2016 Google Inc. All Rights Reserved.
2 +#
3 +# Licensed under the Apache License, Version 2.0 (the "License");
4 +# you may not use this file except in compliance with the License.
5 +# You may obtain a copy of the License at
6 +#
7 +# http://www.apache.org/licenses/LICENSE-2.0
8 +#
9 +# Unless required by applicable law or agreed to in writing, software
10 +# distributed under the License is distributed on an "AS-IS" BASIS,
11 +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 +# See the License for the specific language governing permissions and
13 +# limitations under the License.
14 +"""Provides readers configured for different datasets."""
15 +
16 +import tensorflow as tf
17 +import utils
18 +
19 +
20 +def resize_axis(tensor, axis, new_size, fill_value=0):
21 + """Truncates or pads a tensor to new_size on on a given axis.
22 +
23 + Truncate or extend tensor such that tensor.shape[axis] == new_size. If the
24 + size increases, the padding will be performed at the end, using fill_value.
25 +
26 + Args:
27 + tensor: The tensor to be resized.
28 + axis: An integer representing the dimension to be sliced.
29 + new_size: An integer or 0d tensor representing the new value for
30 + tensor.shape[axis].
31 + fill_value: Value to use to fill any new entries in the tensor. Will be cast
32 + to the type of tensor.
33 +
34 + Returns:
35 + The resized tensor.
36 + """
37 + tensor = tf.convert_to_tensor(tensor)
38 + shape = tf.unstack(tf.shape(tensor))
39 +
40 + pad_shape = shape[:]
41 + pad_shape[axis] = tf.maximum(0, new_size - shape[axis])
42 +
43 + shape[axis] = tf.minimum(shape[axis], new_size)
44 + shape = tf.stack(shape)
45 +
46 + resized = tf.concat([
47 + tf.slice(tensor, tf.zeros_like(shape), shape),
48 + tf.fill(tf.stack(pad_shape), tf.cast(fill_value, tensor.dtype))
49 + ], axis)
50 +
51 + # Update shape.
52 + new_shape = tensor.get_shape().as_list() # A copy is being made.
53 + new_shape[axis] = new_size
54 + resized.set_shape(new_shape)
55 + return resized
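
A quick sketch (assumption) of resize_axis in a TF 1.x session: padding and truncating a 2x3 tensor along axis 1.

```
import tensorflow as tf

x = tf.reshape(tf.range(6), [2, 3])
padded = resize_axis(x, axis=1, new_size=5)     # zero-padded to shape (2, 5)
truncated = resize_axis(x, axis=1, new_size=2)  # truncated to shape (2, 2)

with tf.Session() as sess:
  print(sess.run(padded))     # [[0 1 2 0 0] [3 4 5 0 0]]
  print(sess.run(truncated))  # [[0 1] [3 4]]
```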
56 +
57 +
58 +class BaseReader(object):
59 + """Inherit from this class when implementing new readers."""
60 +
61 + def prepare_reader(self, unused_filename_queue):
62 + """Create a thread for generating prediction and label tensors."""
63 + raise NotImplementedError()
64 +
65 +
66 +class YT8MAggregatedFeatureReader(BaseReader):
67 + """Reads TFRecords of pre-aggregated Examples.
68 +
69 + The TFRecords must contain Examples with a sparse int64 'labels' feature and
70 + a fixed length float32 feature, obtained from the features in 'feature_name'.
71 + The float features are assumed to be an average of dequantized values.
72 + """
73 +
74 + def __init__( # pylint: disable=dangerous-default-value
75 + self,
76 + num_classes=3862,
77 + feature_sizes=[1024, 128],
78 + feature_names=["mean_rgb", "mean_audio"]):
79 + """Construct a YT8MAggregatedFeatureReader.
80 +
81 + Args:
82 + num_classes: a positive integer for the number of classes.
83 + feature_sizes: positive integer(s) for the feature dimensions as a list.
84 + feature_names: the feature name(s) in the tensorflow record as a list.
85 + """
86 +
87 + assert len(feature_names) == len(feature_sizes), (
88 + "length of feature_names (={}) != length of feature_sizes (={})".format(
89 + len(feature_names), len(feature_sizes)))
90 +
91 + self.num_classes = num_classes
92 + self.feature_sizes = feature_sizes
93 + self.feature_names = feature_names
94 +
95 + def prepare_reader(self, filename_queue, batch_size=1024):
96 + """Creates a single reader thread for pre-aggregated YouTube 8M Examples.
97 +
98 + Args:
99 + filename_queue: A tensorflow queue of filename locations.
100 + batch_size: batch size used for feature output.
101 +
102 + Returns:
103 + A dict of video indexes, features, labels, and frame counts.
104 + """
105 + reader = tf.TFRecordReader()
106 + _, serialized_examples = reader.read_up_to(filename_queue, batch_size)
107 +
108 + tf.add_to_collection("serialized_examples", serialized_examples)
109 + return self.prepare_serialized_examples(serialized_examples)
110 +
111 + def prepare_serialized_examples(self, serialized_examples):
112 + """Parse a batch of serialized video-level TF Examples."""
113 + # set the mapping from the fields to data types in the proto
114 + num_features = len(self.feature_names)
115 + assert num_features > 0, "self.feature_names is empty!"
116 + assert len(self.feature_names) == len(self.feature_sizes), \
117 + "length of feature_names (={}) != length of feature_sizes (={})".format(
118 + len(self.feature_names), len(self.feature_sizes))
119 +
120 + feature_map = {
121 + "id": tf.io.FixedLenFeature([], tf.string),
122 + "labels": tf.io.VarLenFeature(tf.int64)
123 + }
124 + for feature_index in range(num_features):
125 + feature_map[self.feature_names[feature_index]] = tf.FixedLenFeature(
126 + [self.feature_sizes[feature_index]], tf.float32)
127 +
128 + features = tf.parse_example(serialized_examples, features=feature_map)
129 + labels = tf.sparse_to_indicator(features["labels"], self.num_classes)
130 + labels.set_shape([None, self.num_classes])
131 + concatenated_features = tf.concat(
132 + [features[feature_name] for feature_name in self.feature_names], 1)
133 +
134 + output_dict = {
135 + "video_ids": features["id"],
136 + "video_matrix": concatenated_features,
137 + "labels": labels,
138 + "num_frames": tf.ones([tf.shape(serialized_examples)[0]])
139 + }
140 +
141 + return output_dict
142 +
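# Shape sketch (with the default constructor arguments): prepare_serialized_examples
# yields "video_matrix" of shape [batch, 1152] (mean_rgb 1024 + mean_audio 128
# concatenated), "labels" of shape [batch, 3862], "video_ids" of shape [batch], and
# "num_frames" as a vector of ones of shape [batch].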
143 +
144 +class YT8MFrameFeatureReader(BaseReader):
145 + """Reads TFRecords of SequenceExamples.
146 +
147 + The TFRecords must contain SequenceExamples with the sparse int64 'labels'
148 + context feature and a fixed length byte-quantized feature vector, obtained
149 + from the features in 'feature_names'. The quantized features will be mapped
150 + back into a range between min_quantized_value and max_quantized_value.
151 + """
152 +
153 + def __init__( # pylint: disable=dangerous-default-value
154 + self,
155 + num_classes=3862,
156 + feature_sizes=[1024, 128],
157 + feature_names=["rgb", "audio"],
158 + max_frames=300,
159 + segment_labels=False,
160 + segment_size=5):
161 + """Construct a YT8MFrameFeatureReader.
162 +
163 + Args:
164 + num_classes: a positive integer for the number of classes.
165 + feature_sizes: positive integer(s) for the feature dimensions as a list.
166 + feature_names: the feature name(s) in the tensorflow record as a list.
167 + max_frames: the maximum number of frames to process.
168 + segment_labels: if we read segment labels instead.
169 + segment_size: the segment_size used for reading segments.
170 + """
171 +
172 + assert len(feature_names) == len(feature_sizes), (
173 + "length of feature_names (={}) != length of feature_sizes (={})".format(
174 + len(feature_names), len(feature_sizes)))
175 +
176 + self.num_classes = num_classes
177 + self.feature_sizes = feature_sizes
178 + self.feature_names = feature_names
179 + self.max_frames = max_frames
180 + self.segment_labels = segment_labels
181 + self.segment_size = segment_size
182 +
183 + def get_video_matrix(self, features, feature_size, max_frames,
184 + max_quantized_value, min_quantized_value):
185 + """Decodes features from an input string and dequantizes them.
186 +
187 + Args:
188 + features: raw feature values
189 + feature_size: length of each frame feature vector
190 + max_frames: number of frames (rows) in the output feature_matrix
191 + max_quantized_value: the maximum of the quantized value.
192 + min_quantized_value: the minimum of the quantized value.
193 +
194 + Returns:
195 + feature_matrix: matrix of all frame-features
196 + num_frames: number of frames in the sequence
197 + """
198 + decoded_features = tf.reshape(
199 + tf.cast(tf.decode_raw(features, tf.uint8), tf.float32),
200 + [-1, feature_size])
201 +
202 + num_frames = tf.minimum(tf.shape(decoded_features)[0], max_frames)
203 + feature_matrix = utils.Dequantize(decoded_features, max_quantized_value,
204 + min_quantized_value)
205 + feature_matrix = resize_axis(feature_matrix, 0, max_frames)
206 + return feature_matrix, num_frames
207 +
208 + def prepare_reader(self,
209 + filename_queue,
210 + max_quantized_value=2,
211 + min_quantized_value=-2):
212 + """Creates a single reader thread for YouTube8M SequenceExamples.
213 +
214 + Args:
215 + filename_queue: A tensorflow queue of filename locations.
216 + max_quantized_value: the maximum of the quantized value.
217 + min_quantized_value: the minimum of the quantized value.
218 +
219 + Returns:
220 + A dict of video indexes, video features, labels, and frame counts.
221 + """
222 + reader = tf.TFRecordReader()
223 + _, serialized_example = reader.read(filename_queue)
224 +
225 + return self.prepare_serialized_examples(serialized_example,
226 + max_quantized_value,
227 + min_quantized_value)
228 +
229 + def prepare_serialized_examples(self,
230 + serialized_example,
231 + max_quantized_value=2,
232 + min_quantized_value=-2):
233 + """Parse single serialized SequenceExample from the TFRecords."""
234 +
235 + # Read/parse frame/segment-level labels.
236 + context_features = {
237 + "id": tf.io.FixedLenFeature([], tf.string),
238 + }
239 + if self.segment_labels:
240 + context_features.update({
241 + # There is no need to read the end time, since we assume every segment
242 + # has the same size.
243 + "segment_labels": tf.io.VarLenFeature(tf.int64),
244 + "segment_start_times": tf.io.VarLenFeature(tf.int64),
245 + "segment_scores": tf.io.VarLenFeature(tf.float32)
246 + })
247 + else:
248 + context_features.update({"labels": tf.io.VarLenFeature(tf.int64)})
249 + sequence_features = {
250 + feature_name: tf.io.FixedLenSequenceFeature([], dtype=tf.string)
251 + for feature_name in self.feature_names
252 + }
253 + contexts, features = tf.io.parse_single_sequence_example(
254 + serialized_example,
255 + context_features=context_features,
256 + sequence_features=sequence_features)
257 +
258 + # loads (potentially) different types of features and concatenates them
259 + num_features = len(self.feature_names)
260 + assert num_features > 0, "No feature selected: feature_names is empty!"
261 +
262 + assert len(self.feature_names) == len(self.feature_sizes), (
263 + "length of feature_names (={}) != length of feature_sizes (={})".format(
264 + len(self.feature_names), len(self.feature_sizes)))
265 +
266 + num_frames = -1 # the number of frames in the video
267 + feature_matrices = [None] * num_features # an array of different features
268 + for feature_index in range(num_features):
269 + feature_matrix, num_frames_in_this_feature = self.get_video_matrix(
270 + features[self.feature_names[feature_index]],
271 + self.feature_sizes[feature_index], self.max_frames,
272 + max_quantized_value, min_quantized_value)
273 + if num_frames == -1:
274 + num_frames = num_frames_in_this_feature
275 +
276 + feature_matrices[feature_index] = feature_matrix
277 +
278 + # cap the number of frames at self.max_frames
279 + num_frames = tf.minimum(num_frames, self.max_frames)
280 +
281 + # concatenate different features
282 + video_matrix = tf.concat(feature_matrices, 1)
283 +
284 + # Partition frame-level feature matrix to segment-level feature matrix.
285 + if self.segment_labels:
286 + start_times = contexts["segment_start_times"].values
287 + # Here we assume all the segments that start at the same start time have
288 + # the same segment_size.
289 + uniq_start_times, seg_idxs = tf.unique(start_times,
290 + out_idx=tf.dtypes.int64)
291 + # TODO(zhengxu): Ensure the segment_sizes are all the same.
292 + segment_size = self.segment_size
293 + # Range gather matrix, e.g., [[0,1,2],[1,2,3]] for segment_size == 3.
294 + range_mtx = tf.expand_dims(uniq_start_times, axis=-1) + tf.expand_dims(
295 + tf.range(0, segment_size, dtype=tf.int64), axis=0)
296 + # Shape: [num_segment, segment_size, feature_dim].
297 + batch_video_matrix = tf.gather_nd(video_matrix,
298 + tf.expand_dims(range_mtx, axis=-1))
299 + num_segment = tf.shape(batch_video_matrix)[0]
300 + batch_video_ids = tf.reshape(tf.tile([contexts["id"]], [num_segment]),
301 + (num_segment,))
302 + batch_frames = tf.reshape(tf.tile([segment_size], [num_segment]),
303 + (num_segment,))
304 +
305 + # For segment labels, not all labels are exhaustively rated, so we only
306 + # evaluate the rated labels.
307 +
308 + # Label indices for each segment, shape: [num_segment, 2].
309 + label_indices = tf.stack([seg_idxs, contexts["segment_labels"].values],
310 + axis=-1)
311 + label_values = contexts["segment_scores"].values
312 + sparse_labels = tf.sparse.SparseTensor(label_indices, label_values,
313 + (num_segment, self.num_classes))
314 + batch_labels = tf.sparse.to_dense(sparse_labels, validate_indices=False)
315 +
316 + sparse_label_weights = tf.sparse.SparseTensor(
317 + label_indices, tf.ones_like(label_values, dtype=tf.float32),
318 + (num_segment, self.num_classes))
319 + batch_label_weights = tf.sparse.to_dense(sparse_label_weights,
320 + validate_indices=False)
321 + else:
322 + # Process video-level labels.
323 + label_indices = contexts["labels"].values
324 + sparse_labels = tf.sparse.SparseTensor(
325 + tf.expand_dims(label_indices, axis=-1),
326 + tf.ones_like(contexts["labels"].values, dtype=tf.bool),
327 + (self.num_classes,))
328 + labels = tf.sparse.to_dense(sparse_labels,
329 + default_value=False,
330 + validate_indices=False)
331 + # convert to batch format.
332 + batch_video_ids = tf.expand_dims(contexts["id"], 0)
333 + batch_video_matrix = tf.expand_dims(video_matrix, 0)
334 + batch_labels = tf.expand_dims(labels, 0)
335 + batch_frames = tf.expand_dims(num_frames, 0)
336 + batch_label_weights = None
337 +
338 + output_dict = {
339 + "video_ids": batch_video_ids,
340 + "video_matrix": batch_video_matrix,
341 + "labels": batch_labels,
342 + "num_frames": batch_frames,
343 + }
344 + if batch_label_weights is not None:
345 + output_dict["label_weights"] = batch_label_weights
346 +
347 + return output_dict
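# Output shape sketch (default arguments, illustrative): in video-label mode the
# batch dimension is 1, e.g. "video_matrix" is [1, 300, 1152] and "labels" is
# [1, 3862]; in segment-label mode "video_matrix" is [num_segment, 5, 1152],
# while "labels" and "label_weights" are [num_segment, 3862], with the weights
# marking which classes were actually rated for each segment.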
1 +"""Eval mAP@N metric from inference file."""
2 +
3 +from __future__ import absolute_import
4 +from __future__ import division
5 +from __future__ import print_function
6 +
7 +from absl import app
8 +from absl import flags
9 +
10 +import mean_average_precision_calculator as map_calculator
11 +import numpy as np
12 +import tensorflow as tf
13 +
14 +flags.DEFINE_string(
15 + "eval_data_pattern", "",
16 + "File glob defining the evaluation dataset in tensorflow.SequenceExample "
17 + "format. The SequenceExamples are expected to have an 'rgb' byte array "
18 + "sequence feature as well as a 'labels' int64 context feature.")
19 +flags.DEFINE_string(
20 + "label_cache", "",
21 + "The path for the label cache file. Leave blank to disable caching.")
22 +flags.DEFINE_string("submission_file", "",
23 + "The segment submission file generated by inference.py.")
24 +flags.DEFINE_integer(
25 + "top_n", 0,
26 + "Cap per-class predictions at a maximum of N. Use 0 for no capping.")
27 +
28 +FLAGS = flags.FLAGS
29 +
30 +
31 +class Labels(object):
32 + """Holds the ground-truth label objects.
33 +
34 + This class can serialize and de-serialize the ground truth.
35 + The ground truth is a mapping from (segment_id, class_id) -> label_score.
36 + """
37 +
38 + def __init__(self, labels):
39 + """__init__ method."""
40 + self._labels = labels
41 +
42 + @property
43 + def labels(self):
44 + """Return the ground truth mapping. See class docstring for details."""
45 + return self._labels
46 +
47 + def to_file(self, file_name):
48 + """Materialize the GT mapping to file."""
49 + with tf.gfile.Open(file_name, "w") as fobj:
50 + for k, v in self._labels.items():
51 + seg_id, label = k
52 + line = "%s,%s,%s\n" % (seg_id, label, v)
53 + fobj.write(line)
54 +
55 + @classmethod
56 + def from_file(cls, file_name):
57 + """Read the GT mapping from cached file."""
58 + labels = {}
59 + with tf.gfile.Open(file_name) as fobj:
60 + for line in fobj:
61 + line = line.strip().strip("\n")
62 + seg_id, label, score = line.split(",")
63 + labels[(seg_id, int(label))] = float(score)
64 + return cls(labels)
65 +
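# Cache file sketch: one ground-truth entry per line in the form
# "<video_id>:<segment_start_time>,<class_index>,<score>", matching the
# (segment_id, class_id) -> label_score mapping described above
# (values are only illustrative).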
66 +
67 +def read_labels(data_pattern, cache_path=""):
68 + """Read labels from TFRecords.
69 +
70 + Args:
71 + data_pattern: the data pattern to the TFRecords.
72 + cache_path: the cache path for the label file.
73 +
74 + Returns:
75 + a Labels object.
76 + """
77 + if cache_path:
78 + if tf.gfile.Exists(cache_path):
79 + tf.logging.info("Reading cached labels from %s..." % cache_path)
80 + return Labels.from_file(cache_path)
81 + tf.enable_eager_execution()
82 + data_paths = tf.gfile.Glob(data_pattern)
83 + ds = tf.data.TFRecordDataset(data_paths, num_parallel_reads=50)
84 + context_features = {
85 + "id": tf.FixedLenFeature([], tf.string),
86 + "segment_labels": tf.VarLenFeature(tf.int64),
87 + "segment_start_times": tf.VarLenFeature(tf.int64),
88 + "segment_scores": tf.VarLenFeature(tf.float32)
89 + }
90 +
91 + def _parse_se_func(sequence_example):
92 + return tf.parse_single_sequence_example(sequence_example,
93 + context_features=context_features)
94 +
95 + ds = ds.map(_parse_se_func)
96 + rated_labels = {}
97 + tf.logging.info("Reading labels from TFRecords...")
98 + last_batch = 0
99 + batch_size = 5000
100 + for cxt_feature_val, _ in ds:
101 + video_id = cxt_feature_val["id"].numpy()
102 + segment_labels = cxt_feature_val["segment_labels"].values.numpy()
103 + segment_start_times = cxt_feature_val["segment_start_times"].values.numpy()
104 + segment_scores = cxt_feature_val["segment_scores"].values.numpy()
105 + for label, start_time, score in zip(segment_labels, segment_start_times,
106 + segment_scores):
107 + rated_labels[("%s:%d" % (video_id, start_time), label)] = score
108 + batch_id = len(rated_labels) // batch_size
109 + if batch_id != last_batch:
110 + tf.logging.info("%d examples processed.", len(rated_labels))
111 + last_batch = batch_id
112 + tf.logging.info("Finished reading labels from TFRecords...")
113 + labels_obj = Labels(rated_labels)
114 + if cache_path:
115 + tf.logging.info("Caching labels to %s..." % cache_path)
116 + labels_obj.to_file(cache_path)
117 + return labels_obj
118 +
119 +
120 +def read_segment_predictions(file_path, labels, top_n=None):
121 + """Read segment predictions.
122 +
123 + Args:
124 + file_path: the submission file path.
125 + labels: a Labels object containing the eval labels.
126 + top_n: the per-class prediction cap.
127 +
128 + Returns:
129 + a list of segment predictions for each class.
130 + """
131 + cls_preds = {} # A label_id to pred list mapping.
132 + with tf.gfile.Open(file_path) as fobj:
133 + tf.logging.info("Reading predictions from %s..." % file_path)
134 + for line in fobj:
135 + label_id, pred_ids_val = line.split(",")
136 + pred_ids = pred_ids_val.split(" ")
137 + if top_n:
138 + pred_ids = pred_ids[:top_n]
139 + pred_ids = [
140 + pred_id for pred_id in pred_ids
141 + if (pred_id, int(label_id)) in labels.labels
142 + ]
143 + cls_preds[int(label_id)] = pred_ids
144 + if len(cls_preds) % 50 == 0:
145 + tf.logging.info("Processed %d classes..." % len(cls_preds))
146 + tf.logging.info("Finished reading predictions.")
147 + return cls_preds
148 +
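# Submission file sketch (format inferred from the parsing below): each line is
# "<class_index>,<segment_id> <segment_id> ...", with segment ids ranked
# most-confident first; only segments that appear in the ground-truth labels are kept.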
149 +
150 +def main(unused_argv):
151 + """Entry function of the script."""
152 + if not FLAGS.submission_file:
153 + raise ValueError("You must provide a submission file.")
154 + eval_labels = read_labels(FLAGS.eval_data_pattern,
155 + cache_path=FLAGS.label_cache)
156 + tf.logging.info("Total rated segments: %d." % len(eval_labels.labels))
157 + positive_counter = {}
158 + for k, v in eval_labels.labels.items():
159 + _, label_id = k
160 + if v > 0:
161 + positive_counter[label_id] = positive_counter.get(label_id, 0) + 1
162 +
163 + seg_preds = read_segment_predictions(FLAGS.submission_file,
164 + eval_labels,
165 + top_n=FLAGS.top_n)
166 + map_cal = map_calculator.MeanAveragePrecisionCalculator(len(seg_preds))
167 + seg_labels = []
168 + seg_scored_preds = []
169 + num_positives = []
170 + for label_id in sorted(seg_preds):
171 + class_preds = seg_preds[label_id]
172 + seg_label = [eval_labels.labels[(pred, label_id)] for pred in class_preds]
173 + seg_labels.append(seg_label)
174 + seg_scored_pred = []
175 + if class_preds:
176 + seg_scored_pred = [
177 + float(x) / len(class_preds) for x in range(len(class_preds), 0, -1)
178 + ]
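      # e.g. four retained predictions get pseudo-scores [1.0, 0.75, 0.5, 0.25],
      # preserving the ranking order from the submission file.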
179 + seg_scored_preds.append(seg_scored_pred)
180 + num_positives.append(positive_counter[label_id])
181 + map_cal.accumulate(seg_scored_preds, seg_labels, num_positives)
182 + map_at_n = np.mean(map_cal.peek_map_at_n())
183 + tf.logging.info("Num classes: %d | mAP@%d: %.6f" %
184 + (len(seg_preds), FLAGS.top_n, map_at_n))
185 +
186 +
187 +if __name__ == "__main__":
188 + app.run(main)
1 +Index
2 +3
3 +7
4 +8
5 +11
6 +12
7 +17
8 +18
9 +19
10 +21
11 +22
12 +23
13 +28
14 +31
15 +30
16 +32
17 +33
18 +34
19 +41
20 +43
21 +45
22 +46
23 +48
24 +53
25 +54
26 +52
27 +55
28 +58
29 +59
30 +60
31 +61
32 +65
33 +68
34 +73
35 +71
36 +74
37 +75
38 +76
39 +77
40 +80
41 +83
42 +90
43 +88
44 +89
45 +92
46 +95
47 +100
48 +101
49 +99
50 +104
51 +105
52 +109
53 +113
54 +112
55 +115
56 +116
57 +118
58 +120
59 +121
60 +123
61 +125
62 +127
63 +131
64 +128
65 +129
66 +130
67 +137
68 +141
69 +143
70 +145
71 +148
72 +152
73 +151
74 +156
75 +155
76 +158
77 +160
78 +164
79 +163
80 +169
81 +170
82 +172
83 +171
84 +173
85 +174
86 +175
87 +176
88 +178
89 +182
90 +184
91 +186
92 +188
93 +187
94 +192
95 +191
96 +190
97 +194
98 +197
99 +196
100 +198
101 +201
102 +202
103 +200
104 +199
105 +205
106 +204
107 +209
108 +207
109 +206
110 +210
111 +213
112 +214
113 +220
114 +218
115 +217
116 +226
117 +227
118 +231
119 +232
120 +229
121 +233
122 +235
123 +237
124 +244
125 +240
126 +249
127 +246
128 +248
129 +239
130 +250
131 +245
132 +255
133 +253
134 +256
135 +261
136 +259
137 +263
138 +262
139 +266
140 +267
141 +268
142 +269
143 +271
144 +276
145 +273
146 +277
147 +274
148 +278
149 +279
150 +280
151 +288
152 +291
153 +295
154 +294
155 +293
156 +297
157 +296
158 +300
159 +299
160 +303
161 +302
162 +304
163 +305
164 +313
165 +307
166 +311
167 +310
168 +312
169 +316
170 +318
171 +321
172 +322
173 +331
174 +333
175 +329
176 +330
177 +334
178 +343
179 +349
180 +340
181 +344
182 +348
183 +358
184 +347
185 +359
186 +355
187 +361
188 +360
189 +364
190 +365
191 +368
192 +369
193 +366
194 +370
195 +374
196 +380
197 +373
198 +385
199 +384
200 +388
201 +389
202 +382
203 +393
204 +381
205 +390
206 +394
207 +399
208 +397
209 +396
210 +402
211 +400
212 +398
213 +401
214 +405
215 +406
216 +410
217 +408
218 +416
219 +415
220 +419
221 +422
222 +414
223 +421
224 +424
225 +429
226 +418
227 +427
228 +434
229 +428
230 +435
231 +430
232 +441
233 +439
234 +437
235 +443
236 +440
237 +442
238 +445
239 +446
240 +448
241 +454
242 +444
243 +453
244 +455
245 +451
246 +452
247 +458
248 +460
249 +465
250 +457
251 +463
252 +462
253 +461
254 +464
255 +469
256 +468
257 +472
258 +473
259 +471
260 +475
261 +474
262 +477
263 +485
264 +491
265 +488
266 +482
267 +490
268 +496
269 +494
270 +483
271 +495
272 +493
273 +507
274 +501
275 +499
276 +503
277 +498
278 +514
279 +504
280 +502
281 +506
282 +508
283 +511
284 +527
285 +526
286 +532
287 +513
288 +519
289 +525
290 +518
291 +528
292 +522
293 +523
294 +535
295 +539
296 +540
297 +533
298 +521
299 +541
300 +547
301 +550
302 +544
303 +549
304 +551
305 +554
306 +543
307 +548
308 +557
309 +560
310 +552
311 +559
312 +563
313 +565
314 +567
315 +555
316 +576
317 +568
318 +564
319 +573
320 +581
321 +580
322 +572
323 +571
324 +584
325 +590
326 +585
327 +587
328 +588
329 +592
330 +598
331 +597
332 +599
333 +603
334 +600
335 +604
336 +605
337 +614
338 +602
339 +610
340 +608
341 +611
342 +612
343 +613
344 +617
345 +620
346 +607
347 +624
348 +627
349 +625
350 +631
351 +629
352 +638
353 +632
354 +634
355 +644
356 +641
357 +642
358 +646
359 +652
360 +647
361 +637
362 +661
363 +635
364 +658
365 +648
366 +663
367 +668
368 +664
369 +656
370 +666
371 +671
372 +683
373 +675
374 +669
375 +676
376 +667
377 +691
378 +685
379 +673
380 +688
381 +702
382 +684
383 +679
384 +694
385 +686
386 +689
387 +680
388 +693
389 +703
390 +697
391 +698
392 +692
393 +705
394 +706
395 +712
396 +711
397 +709
398 +710
399 +726
400 +713
401 +721
402 +720
403 +715
404 +717
405 +730
406 +728
407 +723
408 +716
409 +722
410 +718
411 +732
412 +724
413 +736
414 +725
415 +742
416 +727
417 +735
418 +740
419 +748
420 +738
421 +746
422 +751
423 +749
424 +752
425 +754
426 +760
427 +763
428 +756
429 +758
430 +766
431 +764
432 +757
433 +780
434 +767
435 +769
436 +771
437 +786
438 +785
439 +781
440 +787
441 +778
442 +783
443 +792
444 +791
445 +795
446 +788
447 +805
448 +802
449 +801
450 +793
451 +796
452 +804
453 +803
454 +797
455 +814
456 +813
457 +789
458 +808
459 +818
460 +816
461 +817
462 +811
463 +820
464 +826
465 +829
466 +824
467 +821
468 +825
469 +822
470 +835
471 +833
472 +843
473 +823
474 +827
475 +830
476 +832
477 +837
478 +852
479 +844
480 +841
481 +812
482 +847
483 +862
484 +869
485 +860
486 +838
487 +870
488 +846
489 +858
490 +854
491 +880
492 +876
493 +857
494 +859
495 +877
496 +871
497 +855
498 +875
499 +861
500 +867
501 +892
502 +898
503 +888
504 +884
505 +887
506 +891
507 +906
508 +900
509 +878
510 +885
511 +883
512 +901
513 +903
514 +907
515 +930
516 +897
517 +914
518 +917
519 +910
520 +905
521 +909
522 +933
523 +932
524 +922
525 +913
526 +923
527 +931
528 +911
529 +937
530 +918
531 +955
532 +915
533 +944
534 +952
535 +945
536 +948
537 +946
538 +970
539 +974
540 +958
541 +925
542 +979
543 +942
544 +965
545 +975
546 +950
547 +982
548 +940
549 +973
550 +962
551 +972
552 +957
553 +984
554 +983
555 +964
556 +1007
557 +971
558 +981
559 +954
560 +993
561 +991
562 +996
563 +1005
564 +1015
565 +1009
566 +995
567 +986
568 +1000
569 +985
570 +980
571 +1016
572 +1011
573 +999
574 +1002
575 +994
576 +1013
577 +1010
578 +992
579 +1008
580 +1036
581 +1025
582 +1012
583 +990
584 +1037
585 +1040
586 +1031
587 +1019
588 +1052
589 +1001
590 +1055
591 +1032
592 +1069
593 +1058
594 +1014
595 +1023
596 +1030
597 +1061
598 +1035
599 +1034
600 +1053
601 +1045
602 +1046
603 +1067
604 +1060
605 +1049
606 +1056
607 +1074
608 +1066
609 +1044
610 +1038
611 +1073
612 +1077
613 +1068
614 +1057
615 +1072
616 +1104
617 +1083
618 +1089
619 +1087
620 +1099
621 +1076
622 +1086
623 +1098
624 +1094
625 +1095
626 +1096
627 +1101
628 +1107
629 +1105
630 +1117
631 +1093
632 +1106
633 +1122
634 +1119
635 +1103
636 +1128
637 +1120
638 +1126
639 +1102
640 +1115
641 +1124
642 +1123
643 +1131
644 +1136
645 +1144
646 +1121
647 +1137
648 +1132
649 +1133
650 +1157
651 +1134
652 +1143
653 +1159
654 +1164
655 +1155
656 +1142
657 +1150
658 +1148
659 +1161
660 +1165
661 +1147
662 +1162
663 +1152
664 +1174
665 +1160
666 +1166
667 +1190
668 +1175
669 +1167
670 +1156
671 +1180
672 +1171
673 +1179
674 +1172
675 +1186
676 +1188
677 +1201
678 +1177
679 +1208
680 +1183
681 +1189
682 +1192
683 +1209
684 +1214
685 +1197
686 +1168
687 +1202
688 +1205
689 +1203
690 +1199
691 +1219
692 +1217
693 +1187
694 +1206
695 +1210
696 +1241
697 +1221
698 +1218
699 +1223
700 +1236
701 +1212
702 +1237
703 +1195
704 +1216
705 +1247
706 +1234
707 +1240
708 +1257
709 +1224
710 +1243
711 +1259
712 +1242
713 +1282
714 +1222
715 +1254
716 +1227
717 +1235
718 +1269
719 +1258
720 +1290
721 +1275
722 +1262
723 +1252
724 +1248
725 +1272
726 +1246
727 +1225
728 +1245
729 +1277
730 +1298
731 +1288
732 +1271
733 +1265
734 +1286
735 +1260
736 +1266
737 +1296
738 +1280
739 +1285
740 +1293
741 +1276
742 +1287
743 +1289
744 +1261
745 +1264
746 +1295
747 +1291
748 +1283
749 +1311
750 +1303
751 +1330
752 +1315
753 +1300
754 +1333
755 +1307
756 +1325
757 +1334
758 +1316
759 +1314
760 +1317
761 +1310
762 +1329
763 +1324
764 +1339
765 +1346
766 +1342
767 +1352
768 +1321
769 +1376
770 +1366
771 +1308
772 +1345
773 +1348
774 +1386
775 +1383
776 +1372
777 +1367
778 +1400
779 +1382
780 +1375
781 +1392
782 +1380
783 +1371
784 +1393
785 +1389
786 +1353
787 +1387
788 +1374
789 +1379
790 +1381
791 +1359
792 +1360
793 +1396
794 +1399
795 +1365
796 +1424
797 +1373
798 +1411
799 +1401
800 +1397
801 +1395
802 +1412
803 +1394
804 +1368
805 +1423
806 +1391
807 +1435
808 +1409
809 +1443
810 +1402
811 +1425
812 +1415
813 +1421
814 +1426
815 +1433
816 +1420
817 +1452
818 +1436
819 +1430
820 +1408
821 +1458
822 +1429
823 +1453
824 +1454
825 +1447
826 +1472
827 +1486
828 +1468
829 +1461
830 +1467
831 +1484
832 +1457
833 +1444
834 +1450
835 +1451
836 +1459
837 +1462
838 +1449
839 +1476
840 +1470
841 +1471
842 +1498
843 +1488
844 +1442
845 +1480
846 +1456
847 +1466
848 +1505
849 +1517
850 +1464
851 +1503
852 +1490
853 +1519
854 +1481
855 +1493
856 +1463
857 +1532
858 +1487
859 +1501
860 +1500
861 +1495
862 +1509
863 +1535
864 +1506
865 +1521
866 +1580
867 +1540
868 +1502
869 +1520
870 +1496
871 +1569
872 +1515
873 +1489
874 +1507
875 +1527
876 +1545
877 +1560
878 +1510
879 +1514
880 +1526
881 +1594
882 +1511
883 +1572
884 +1548
885 +1584
886 +1556
887 +1588
888 +1628
889 +1555
890 +1568
891 +1550
892 +1622
893 +1563
894 +1603
895 +1616
896 +1576
897 +1549
898 +1537
899 +1593
900 +1618
901 +1645
902 +1624
903 +1617
904 +1634
905 +1595
906 +1597
907 +1590
908 +1632
909 +1575
910 +1559
911 +1625
912 +1615
913 +1591
914 +1630
915 +1608
916 +1621
917 +1589
918 +1646
919 +1643
920 +1652
921 +1627
922 +1611
923 +1626
924 +1613
925 +1639
926 +1655
927 +1620
928 +1602
929 +1651
930 +1653
931 +1669
932 +1638
933 +1696
934 +1649
935 +1675
936 +1660
937 +1683
938 +1666
939 +1671
940 +1703
941 +1716
942 +1637
943 +1672
944 +1676
945 +1692
946 +1711
947 +1680
948 +1641
949 +1688
950 +1708
951 +1704
952 +1690
953 +1674
954 +1718
955 +1699
956 +1723
957 +1756
958 +1700
959 +1662
960 +1715
961 +1657
962 +1733
963 +1728
964 +1670
965 +1712
966 +1685
967 +1724
968 +1735
969 +1714
970 +1730
971 +1747
972 +1656
973 +1737
974 +1705
975 +1693
976 +1713
977 +1689
978 +1753
979 +1739
980 +1721
981 +1725
982 +1749
983 +1732
984 +1743
985 +1731
986 +1767
987 +1738
988 +1831
989 +1771
990 +1726
991 +1746
992 +1776
993 +1775
994 +1799
995 +1774
996 +1780
997 +1781
998 +1769
999 +1805
1000 +1788
1001 +1801
1 +# Copyright 2016 Google Inc. All Rights Reserved.
2 +#
3 +# Licensed under the Apache License, Version 2.0 (the "License");
4 +# you may not use this file except in compliance with the License.
5 +# You may obtain a copy of the License at
6 +#
7 +# http://www.apache.org/licenses/LICENSE-2.0
8 +#
9 +# Unless required by applicable law or agreed to in writing, software
10 +# distributed under the License is distributed on an "AS-IS" BASIS,
11 +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 +# See the License for the specific language governing permissions and
13 +# limitations under the License.
14 +"""Binary for training Tensorflow models on the YouTube-8M dataset."""
15 +
16 +import json
17 +import os
18 +import time
19 +
20 +import eval_util
21 +import export_model
22 +import losses
23 +import frame_level_models
24 +import video_level_models
25 +import readers
26 +import tensorflow as tf
27 +import tensorflow.contrib.slim as slim
28 +from tensorflow.python.lib.io import file_io
29 +from tensorflow import app
30 +from tensorflow import flags
31 +from tensorflow import gfile
32 +from tensorflow import logging
33 +from tensorflow.python.client import device_lib
34 +import utils
35 +
36 +FLAGS = flags.FLAGS
37 +
38 +if __name__ == "__main__":
39 + # Dataset flags.
40 + flags.DEFINE_string("train_dir", "/tmp/yt8m_model/",
41 + "The directory to save the model files in.")
42 + flags.DEFINE_string(
43 + "train_data_pattern", "",
44 + "File glob for the training dataset. If the files refer to Frame Level "
45 + "features (i.e. tensorflow.SequenceExample), then set --frame_features. "
46 + "The (Sequence)Examples are expected to have an 'rgb' byte array "
47 + "sequence feature as well as a 'labels' int64 context feature.")
48 + flags.DEFINE_string("feature_names", "mean_rgb", "Name of the feature "
49 + "to use for training.")
50 + flags.DEFINE_string("feature_sizes", "1024", "Length of the feature vectors.")
51 +
52 + # Model flags.
53 + flags.DEFINE_bool(
54 + "frame_features", False,
55 + "If set, then --train_data_pattern must be frame-level features. "
56 + "Otherwise, --train_data_pattern must be aggregated video-level "
57 + "features. The model must also be set appropriately (i.e. to read 3D "
58 + "batches vs. 4D batches).")
59 + flags.DEFINE_bool(
60 + "segment_labels", False,
61 + "If set, then --train_data_pattern must be frame-level features (but with"
62 + " segment_labels). Otherwise, --train_data_pattern must be aggregated "
63 + "video-level features. The model must also be set appropriately (i.e. to "
64 + "read 3D batches vs. 4D batches).")
65 + flags.DEFINE_string(
66 + "model", "LogisticModel",
67 + "Which architecture to use for the model. Models are defined "
68 + "in models.py.")
69 + flags.DEFINE_bool(
70 + "start_new_model", False,
71 + "If set, this will not resume from a checkpoint and will instead create a"
72 + " new model instance.")
73 +
74 + # Training flags.
75 + flags.DEFINE_integer(
76 + "num_gpu", 1, "The maximum number of GPU devices to use for training. "
77 + "The flag only applies if GPUs are installed.")
78 + flags.DEFINE_integer("batch_size", 1024,
79 + "How many examples to process per batch for training.")
80 + flags.DEFINE_string("label_loss", "CrossEntropyLoss",
81 + "Which loss function to use for training the model.")
82 + flags.DEFINE_float(
83 + "regularization_penalty", 1.0,
84 + "How much weight to give to the regularization loss (the label loss has "
85 + "a weight of 1).")
86 + flags.DEFINE_float("base_learning_rate", 0.01,
87 + "Which learning rate to start with.")
88 + flags.DEFINE_float(
89 + "learning_rate_decay", 0.95,
90 + "Learning rate decay factor to be applied every "
91 + "learning_rate_decay_examples.")
92 + flags.DEFINE_float(
93 + "learning_rate_decay_examples", 4000000,
94 + "Multiply current learning rate by learning_rate_decay "
95 + "every learning_rate_decay_examples.")
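  # Decay arithmetic sketch (default flags, single tower): the learning rate is
  # multiplied by learning_rate_decay each time global_step * batch_size crosses
  # a multiple of learning_rate_decay_examples, i.e. roughly every
  # 4000000 / 1024 ~= 3906 steps with the defaults above.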
96 + flags.DEFINE_integer(
97 + "num_epochs", 5, "How many passes to make over the dataset before "
98 + "halting training.")
99 + flags.DEFINE_integer(
100 + "max_steps", None,
101 + "The maximum number of iterations of the training loop.")
102 + flags.DEFINE_integer(
103 + "export_model_steps", 1000,
104 + "The period, in number of steps, with which the model "
105 + "is exported for batch prediction.")
106 +
107 + # Other flags.
108 + flags.DEFINE_integer("num_readers", 8,
109 + "How many threads to use for reading input files.")
110 + flags.DEFINE_string("optimizer", "AdamOptimizer",
111 + "What optimizer class to use.")
112 + flags.DEFINE_float("clip_gradient_norm", 1.0, "Norm to clip gradients to.")
113 + flags.DEFINE_bool(
114 + "log_device_placement", False,
115 + "Whether to write the device on which every op will run into the "
116 + "logs on startup.")
117 +
118 +
119 +def validate_class_name(flag_value, category, modules, expected_superclass):
120 + """Checks that the given string matches a class of the expected type.
121 +
122 + Args:
123 + flag_value: A string naming the class to instantiate.
124 + category: A string used to further describe the class in error messages (e.g.
125 + 'model', 'reader', 'loss').
126 + modules: A list of modules to search for the given class.
127 + expected_superclass: A class that the given class should inherit from.
128 +
129 + Raises:
130 + FlagsError: If the given class could not be found or if the first class
131 + found with that name doesn't inherit from the expected superclass.
132 +
133 + Returns:
134 + True if a class was found that matches the given constraints.
135 + """
136 + candidates = [getattr(module, flag_value, None) for module in modules]
137 + for candidate in candidates:
138 + if not candidate:
139 + continue
140 + if not issubclass(candidate, expected_superclass):
141 + raise flags.FlagsError(
142 + "%s '%s' doesn't inherit from %s." %
143 + (category, flag_value, expected_superclass.__name__))
144 + return True
145 + raise flags.FlagsError("Unable to find %s '%s'." % (category, flag_value))
146 +
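# Illustrative call (assuming losses.BaseLoss is the loss base class):
#   validate_class_name("CrossEntropyLoss", "loss", [losses], losses.BaseLoss)
# returns True, while an unknown class name raises flags.FlagsError.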
147 +
148 +def get_input_data_tensors(reader,
149 + data_pattern,
150 + batch_size=1000,
151 + num_epochs=None,
152 + num_readers=1):
153 + """Creates the section of the graph which reads the training data.
154 +
155 + Args:
156 + reader: A class which parses the training data.
157 + data_pattern: A 'glob' style path to the data files.
158 + batch_size: How many examples to process at a time.
159 + num_epochs: How many passes to make over the training data. Set to 'None' to
160 + run indefinitely.
161 + num_readers: How many I/O threads to use.
162 +
163 + Returns:
164 + A tuple containing the features tensor, labels tensor, and optionally a
165 + tensor containing the number of frames per video. The exact dimensions
166 + depend on the reader being used.
167 +
168 + Raises:
169 + IOError: If no files matching the given pattern were found.
170 + """
171 + logging.info("Using batch size of " + str(batch_size) + " for training.")
172 + with tf.name_scope("train_input"):
173 + files = gfile.Glob(data_pattern)
174 + if not files:
175 + raise IOError("Unable to find training files. data_pattern='" +
176 + data_pattern + "'.")
177 + logging.info("Number of training files: %s.", str(len(files)))
178 + filename_queue = tf.train.string_input_producer(files,
179 + num_epochs=num_epochs,
180 + shuffle=True)
181 + training_data = [
182 + reader.prepare_reader(filename_queue) for _ in range(num_readers)
183 + ]
184 +
185 + return tf.train.shuffle_batch_join(training_data,
186 + batch_size=batch_size,
187 + capacity=batch_size * 5,
188 + min_after_dequeue=batch_size,
189 + allow_smaller_final_batch=True,
190 + enqueue_many=True)
191 +
192 +
193 +def find_class_by_name(name, modules):
194 + """Searches the provided modules for the named class and returns it."""
195 + modules = [getattr(module, name, None) for module in modules]
196 + return next(a for a in modules if a)
197 +
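# e.g. find_class_by_name("AdamOptimizer", [tf.train]) returns the
# tf.train.AdamOptimizer class itself (the first matching attribute found).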
198 +
199 +def build_graph(reader,
200 + model,
201 + train_data_pattern,
202 + label_loss_fn=losses.CrossEntropyLoss(),
203 + batch_size=1000,
204 + base_learning_rate=0.01,
205 + learning_rate_decay_examples=1000000,
206 + learning_rate_decay=0.95,
207 + optimizer_class=tf.train.AdamOptimizer,
208 + clip_gradient_norm=1.0,
209 + regularization_penalty=1,
210 + num_readers=1,
211 + num_epochs=None):
212 + """Creates the Tensorflow graph.
213 +
214 + This will only be called once in the life of
215 + a training model, because after the graph is created the model will be
216 + restored from a meta graph file rather than being recreated.
217 +
218 + Args:
219 + reader: The data file reader. It should inherit from BaseReader.
220 + model: The core model (e.g. logistic or neural net). It should inherit from
221 + BaseModel.
222 + train_data_pattern: glob path to the training data files.
223 + label_loss_fn: What kind of loss to apply to the model. It should inherit
224 + from BaseLoss.
225 + batch_size: How many examples to process at a time.
226 + base_learning_rate: What learning rate to initialize the optimizer with.
227 + optimizer_class: Which optimization algorithm to use.
228 + clip_gradient_norm: Magnitude of the gradient to clip to.
229 + regularization_penalty: How much weight to give the regularization loss
230 + compared to the label loss.
231 + num_readers: How many threads to use for I/O operations.
232 + num_epochs: How many passes to make over the data. 'None' means an unlimited
233 + number of passes.
234 + """
235 +
236 + global_step = tf.Variable(0, trainable=False, name="global_step")
237 +
238 + local_device_protos = device_lib.list_local_devices()
239 + gpus = [x.name for x in local_device_protos if x.device_type == "GPU"]
240 + gpus = gpus[:FLAGS.num_gpu]
241 + num_gpus = len(gpus)
242 +
243 + if num_gpus > 0:
244 + logging.info("Using the following GPUs to train: " + str(gpus))
245 + num_towers = num_gpus
246 + device_string = "/gpu:%d"
247 + else:
248 + logging.info("No GPUs found. Training on CPU.")
249 + num_towers = 1
250 + device_string = "/cpu:%d"
251 +
252 + learning_rate = tf.train.exponential_decay(base_learning_rate,
253 + global_step * batch_size *
254 + num_towers,
255 + learning_rate_decay_examples,
256 + learning_rate_decay,
257 + staircase=True)
258 + tf.summary.scalar("learning_rate", learning_rate)
259 +
260 + optimizer = optimizer_class(learning_rate)
261 + input_data_dict = (get_input_data_tensors(reader,
262 + train_data_pattern,
263 + batch_size=batch_size * num_towers,
264 + num_readers=num_readers,
265 + num_epochs=num_epochs))
266 + model_input_raw = input_data_dict["video_matrix"]
267 + labels_batch = input_data_dict["labels"]
268 + num_frames = input_data_dict["num_frames"]
269 + print("model_input_shape, ", model_input_raw.shape)
270 + tf.summary.histogram("model/input_raw", model_input_raw)
271 +
272 + feature_dim = len(model_input_raw.get_shape()) - 1
273 +
274 + model_input = tf.nn.l2_normalize(model_input_raw, feature_dim)
275 +
276 + tower_inputs = tf.split(model_input, num_towers)
277 + tower_labels = tf.split(labels_batch, num_towers)
278 + tower_num_frames = tf.split(num_frames, num_towers)
279 + tower_gradients = []
280 + tower_predictions = []
281 + tower_label_losses = []
282 + tower_reg_losses = []
283 + for i in range(num_towers):
284 + # For some reason these 'with' statements can't be combined onto the same
285 + # line. They have to be nested.
286 + with tf.device(device_string % i):
287 + with (tf.variable_scope(("tower"), reuse=True if i > 0 else None)):
288 + with (slim.arg_scope([slim.model_variable, slim.variable],
289 + device="/cpu:0" if num_gpus != 1 else "/gpu:0")):
290 + result = model.create_model(tower_inputs[i],
291 + num_frames=tower_num_frames[i],
292 + vocab_size=reader.num_classes,
293 + labels=tower_labels[i])
294 + for variable in slim.get_model_variables():
295 + tf.summary.histogram(variable.op.name, variable)
296 +
297 + predictions = result["predictions"]
298 + tower_predictions.append(predictions)
299 +
300 + if "loss" in result.keys():
301 + label_loss = result["loss"]
302 + else:
303 + label_loss = label_loss_fn.calculate_loss(predictions,
304 + tower_labels[i])
305 +
306 + if "regularization_loss" in result.keys():
307 + reg_loss = result["regularization_loss"]
308 + else:
309 + reg_loss = tf.constant(0.0)
310 +
311 + reg_losses = tf.losses.get_regularization_losses()
312 + if reg_losses:
313 + reg_loss += tf.add_n(reg_losses)
314 +
315 + tower_reg_losses.append(reg_loss)
316 +
317 + # Adds update_ops (e.g., moving average updates in batch normalization) as
318 + # a dependency to the train_op.
319 + update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
320 + if "update_ops" in result.keys():
321 + update_ops += result["update_ops"]
322 + if update_ops:
323 + with tf.control_dependencies(update_ops):
324 + barrier = tf.no_op(name="gradient_barrier")
325 + with tf.control_dependencies([barrier]):
326 + label_loss = tf.identity(label_loss)
327 +
328 + tower_label_losses.append(label_loss)
329 +
330 + # Incorporate the L2 weight penalties etc.
331 + final_loss = regularization_penalty * reg_loss + label_loss
332 + gradients = optimizer.compute_gradients(
333 + final_loss, colocate_gradients_with_ops=False)
334 + tower_gradients.append(gradients)
335 + label_loss = tf.reduce_mean(tf.stack(tower_label_losses))
336 + tf.summary.scalar("label_loss", label_loss)
337 + if regularization_penalty != 0:
338 + reg_loss = tf.reduce_mean(tf.stack(tower_reg_losses))
339 + tf.summary.scalar("reg_loss", reg_loss)
340 + merged_gradients = utils.combine_gradients(tower_gradients)
341 +
342 + if clip_gradient_norm > 0:
343 + with tf.name_scope("clip_grads"):
344 + merged_gradients = utils.clip_gradient_norms(merged_gradients,
345 + clip_gradient_norm)
346 +
347 + train_op = optimizer.apply_gradients(merged_gradients,
348 + global_step=global_step)
349 +
350 + tf.add_to_collection("global_step", global_step)
351 + tf.add_to_collection("loss", label_loss)
352 + tf.add_to_collection("predictions", tf.concat(tower_predictions, 0))
353 + tf.add_to_collection("input_batch_raw", model_input_raw)
354 + tf.add_to_collection("input_batch", model_input)
355 + tf.add_to_collection("num_frames", num_frames)
356 + tf.add_to_collection("labels", tf.cast(labels_batch, tf.float32))
357 + tf.add_to_collection("train_op", train_op)
358 +
359 +
360 +class Trainer(object):
361 + """A Trainer to train a Tensorflow graph."""
362 +
363 + def __init__(self,
364 + cluster,
365 + task,
366 + train_dir,
367 + model,
368 + reader,
369 + model_exporter,
370 + log_device_placement=True,
371 + max_steps=None,
372 + export_model_steps=1000):
373 + """Creates a Trainer.
374 +
375 + Args:
376 + cluster: A tf.train.ClusterSpec if the execution is distributed. None
377 + otherwise.
378 + task: A TaskSpec describing the job type and the task index.
379 + """
380 +
381 + self.cluster = cluster
382 + self.task = task
383 + self.is_master = (task.type == "master" and task.index == 0)
384 + self.train_dir = train_dir
385 + self.config = tf.ConfigProto(allow_soft_placement=True,
386 + log_device_placement=log_device_placement)
387 + self.config.gpu_options.allow_growth = True
388 + self.model = model
389 + self.reader = reader
390 + self.model_exporter = model_exporter
391 + self.max_steps = max_steps
392 + self.max_steps_reached = False
393 + self.export_model_steps = export_model_steps
394 + self.last_model_export_step = 0
395 +
396 +
397 +# if self.is_master and self.task.index > 0:
398 +# raise StandardError("%s: Only one replica of master expected",
399 +# task_as_string(self.task))
400 +
401 + def run(self, start_new_model=False):
402 + """Performs training on the currently defined Tensorflow graph.
403 +
404 + Returns:
405 + A tuple of the training Hit@1 and the training PERR.
406 + """
407 + if self.is_master and start_new_model:
408 + self.remove_training_directory(self.train_dir)
409 +
410 + if not os.path.exists(self.train_dir):
411 + os.makedirs(self.train_dir)
412 +
413 + model_flags_dict = {
414 + "model": FLAGS.model,
415 + "feature_sizes": FLAGS.feature_sizes,
416 + "feature_names": FLAGS.feature_names,
417 + "frame_features": FLAGS.frame_features,
418 + "label_loss": FLAGS.label_loss,
419 + }
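    # Sketch of the resulting model_flags.json with the default flags (illustrative):
    #   {"model": "LogisticModel", "feature_sizes": "1024", "feature_names": "mean_rgb",
    #    "frame_features": false, "label_loss": "CrossEntropyLoss"}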
420 + flags_json_path = os.path.join(FLAGS.train_dir, "model_flags.json")
421 + if file_io.file_exists(flags_json_path):
422 + existing_flags = json.load(file_io.FileIO(flags_json_path, mode="r"))
423 + if existing_flags != model_flags_dict:
424 + logging.error(
425 + "Model flags do not match existing file %s. Please "
426 + "delete the file, change --train_dir, or pass flag "
427 + "--start_new_model", flags_json_path)
428 + logging.error("Ran model with flags: %s", str(model_flags_dict))
429 + logging.error("Previously ran with flags: %s", str(existing_flags))
430 + exit(1)
431 + else:
432 + # Write the file.
433 + with file_io.FileIO(flags_json_path, mode="w") as fout:
434 + fout.write(json.dumps(model_flags_dict))
435 +
436 + target, device_fn = self.start_server_if_distributed()
437 +
438 + meta_filename = self.get_meta_filename(start_new_model, self.train_dir)
439 +
440 + with tf.Graph().as_default() as graph:
441 + if meta_filename:
442 + saver = self.recover_model(meta_filename)
443 +
444 + with tf.device(device_fn):
445 + if not meta_filename:
446 + saver = self.build_model(self.model, self.reader)
447 +
448 + global_step = tf.get_collection("global_step")[0]
449 + loss = tf.get_collection("loss")[0]
450 + predictions = tf.get_collection("predictions")[0]
451 + labels = tf.get_collection("labels")[0]
452 + train_op = tf.get_collection("train_op")[0]
453 + init_op = tf.global_variables_initializer()
454 +
455 + sv = tf.train.Supervisor(graph,
456 + logdir=self.train_dir,
457 + init_op=init_op,
458 + is_chief=self.is_master,
459 + global_step=global_step,
460 + save_model_secs=15 * 60,
461 + save_summaries_secs=120,
462 + saver=saver)
463 +
464 + logging.info("%s: Starting managed session.", task_as_string(self.task))
465 + with sv.managed_session(target, config=self.config) as sess:
466 + try:
467 + logging.info("%s: Entering training loop.", task_as_string(self.task))
468 + while (not sv.should_stop()) and (not self.max_steps_reached):
469 + batch_start_time = time.time()
470 + _, global_step_val, loss_val, predictions_val, labels_val = sess.run(
471 + [train_op, global_step, loss, predictions, labels])
472 + seconds_per_batch = time.time() - batch_start_time
473 + examples_per_second = labels_val.shape[0] / seconds_per_batch
474 +
475 + if self.max_steps and self.max_steps <= global_step_val:
476 + self.max_steps_reached = True
477 +
478 + if self.is_master and global_step_val % 10 == 0 and self.train_dir:
479 + eval_start_time = time.time()
480 + hit_at_one = eval_util.calculate_hit_at_one(predictions_val,
481 + labels_val)
482 + perr = eval_util.calculate_precision_at_equal_recall_rate(
483 + predictions_val, labels_val)
484 + gap = eval_util.calculate_gap(predictions_val, labels_val)
485 + eval_end_time = time.time()
486 + eval_time = eval_end_time - eval_start_time
487 +
488 + logging.info("training step " + str(global_step_val) + " | Loss: " +
489 + ("%.2f" % loss_val) + " Examples/sec: " +
490 + ("%.2f" % examples_per_second) + " | Hit@1: " +
491 + ("%.2f" % hit_at_one) + " PERR: " + ("%.2f" % perr) +
492 + " GAP: " + ("%.2f" % gap))
493 +
494 + sv.summary_writer.add_summary(
495 + utils.MakeSummary("model/Training_Hit@1", hit_at_one),
496 + global_step_val)
497 + sv.summary_writer.add_summary(
498 + utils.MakeSummary("model/Training_Perr", perr), global_step_val)
499 + sv.summary_writer.add_summary(
500 + utils.MakeSummary("model/Training_GAP", gap), global_step_val)
501 + sv.summary_writer.add_summary(
502 + utils.MakeSummary("global_step/Examples/Second",
503 + examples_per_second), global_step_val)
504 + sv.summary_writer.flush()
505 +
506 + # Exporting the model every x steps
507 + time_to_export = ((self.last_model_export_step == 0) or
508 + (global_step_val - self.last_model_export_step >=
509 + self.export_model_steps))
510 +
511 + if self.is_master and time_to_export:
512 + self.export_model(global_step_val, sv.saver, sv.save_path, sess)
513 + self.last_model_export_step = global_step_val
514 + else:
515 + logging.info("training step " + str(global_step_val) + " | Loss: " +
516 + ("%.2f" % loss_val) + " Examples/sec: " +
517 + ("%.2f" % examples_per_second))
518 + except tf.errors.OutOfRangeError:
519 + logging.info("%s: Done training -- epoch limit reached.",
520 + task_as_string(self.task))
521 +
522 + logging.info("%s: Exited training loop.", task_as_string(self.task))
523 + sv.Stop()
524 +
525 + def export_model(self, global_step_val, saver, save_path, session):
526 +
527 + # If the model has already been exported at this step, return.
528 + if global_step_val == self.last_model_export_step:
529 + return
530 +
531 + saver.save(session, save_path, global_step_val)
532 +
533 + def start_server_if_distributed(self):
534 + """Starts a server if the execution is distributed."""
535 +
536 + if self.cluster:
537 + logging.info("%s: Starting trainer within cluster %s.",
538 + task_as_string(self.task), self.cluster.as_dict())
539 + server = start_server(self.cluster, self.task)
540 + target = server.target
541 + device_fn = tf.train.replica_device_setter(
542 + ps_device="/job:ps",
543 + worker_device="/job:%s/task:%d" % (self.task.type, self.task.index),
544 + cluster=self.cluster)
545 + else:
546 + target = ""
547 + device_fn = ""
548 + return (target, device_fn)
549 +
550 + def remove_training_directory(self, train_dir):
551 + """Removes the training directory."""
552 + try:
553 + logging.info("%s: Removing existing train directory.",
554 + task_as_string(self.task))
555 + gfile.DeleteRecursively(train_dir)
556 + except:
557 + logging.error(
558 + "%s: Failed to delete directory " + train_dir +
559 + " when starting a new model. Please delete it manually and" +
560 + " try again.", task_as_string(self.task))
561 +
562 + def get_meta_filename(self, start_new_model, train_dir):
563 + if start_new_model:
564 + logging.info("%s: Flag 'start_new_model' is set. Building a new model.",
565 + task_as_string(self.task))
566 + return None
567 +
568 + latest_checkpoint = tf.train.latest_checkpoint(train_dir)
569 + if not latest_checkpoint:
570 + logging.info("%s: No checkpoint file found. Building a new model.",
571 + task_as_string(self.task))
572 + return None
573 +
574 + meta_filename = latest_checkpoint + ".meta"
575 + if not gfile.Exists(meta_filename):
576 + logging.info("%s: No meta graph file found. Building a new model.",
577 + task_as_string(self.task))
578 + return None
579 + else:
580 + return meta_filename
581 +
582 + def recover_model(self, meta_filename):
583 + logging.info("%s: Restoring from meta graph file %s",
584 + task_as_string(self.task), meta_filename)
585 + return tf.train.import_meta_graph(meta_filename)
586 +
587 + def build_model(self, model, reader):
588 + """Find the model and build the graph."""
589 +
590 + label_loss_fn = find_class_by_name(FLAGS.label_loss, [losses])()
591 + optimizer_class = find_class_by_name(FLAGS.optimizer, [tf.train])
592 +
593 + build_graph(reader=reader,
594 + model=model,
595 + optimizer_class=optimizer_class,
596 + clip_gradient_norm=FLAGS.clip_gradient_norm,
597 + train_data_pattern=FLAGS.train_data_pattern,
598 + label_loss_fn=label_loss_fn,
599 + base_learning_rate=FLAGS.base_learning_rate,
600 + learning_rate_decay=FLAGS.learning_rate_decay,
601 + learning_rate_decay_examples=FLAGS.learning_rate_decay_examples,
602 + regularization_penalty=FLAGS.regularization_penalty,
603 + num_readers=FLAGS.num_readers,
604 + batch_size=FLAGS.batch_size,
605 + num_epochs=FLAGS.num_epochs)
606 +
607 + return tf.train.Saver(max_to_keep=0, keep_checkpoint_every_n_hours=0.25)
608 +
609 +
610 +def get_reader():
611 + # Convert feature_names and feature_sizes to lists of values.
612 + feature_names, feature_sizes = utils.GetListOfFeatureNamesAndSizes(
613 + FLAGS.feature_names, FLAGS.feature_sizes)
614 +
615 + if FLAGS.frame_features:
616 + reader = readers.YT8MFrameFeatureReader(feature_names=feature_names,
617 + feature_sizes=feature_sizes,
618 + segment_labels=FLAGS.segment_labels)
619 + else:
620 + reader = readers.YT8MAggregatedFeatureReader(feature_names=feature_names,
621 + feature_sizes=feature_sizes)
622 +
623 + return reader
624 +
625 +
626 +class ParameterServer(object):
627 + """A parameter server to serve variables in a distributed execution."""
628 +
629 + def __init__(self, cluster, task):
630 + """Creates a ParameterServer.
631 +
632 + Args:
633 + cluster: A tf.train.ClusterSpec if the execution is distributed. None
634 + otherwise.
635 + task: A TaskSpec describing the job type and the task index.
636 + """
637 +
638 + self.cluster = cluster
639 + self.task = task
640 +
641 + def run(self):
642 + """Starts the parameter server."""
643 +
644 + logging.info("%s: Starting parameter server within cluster %s.",
645 + task_as_string(self.task), self.cluster.as_dict())
646 + server = start_server(self.cluster, self.task)
647 + server.join()
648 +
649 +
650 +def start_server(cluster, task):
651 + """Creates a Server.
652 +
653 + Args:
654 + cluster: A tf.train.ClusterSpec if the execution is distributed. None
655 + otherwise.
656 + task: A TaskSpec describing the job type and the task index.
657 + """
658 +
659 + if not task.type:
660 + raise ValueError("%s: The task type must be specified." %
661 + task_as_string(task))
662 + if task.index is None:
663 + raise ValueError("%s: The task index must be specified." %
664 + task_as_string(task))
665 +
666 + # Create and start a server.
667 + return tf.train.Server(tf.train.ClusterSpec(cluster),
668 + protocol="grpc",
669 + job_name=task.type,
670 + task_index=task.index)
671 +
672 +
673 +def task_as_string(task):
674 + return "/job:%s/task:%s" % (task.type, task.index)
675 +
676 +
677 +def main(unused_argv):
678 + # Load the environment.
679 + env = json.loads(os.environ.get("TF_CONFIG", "{}"))
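  # For a distributed run, TF_CONFIG is expected to look roughly like the
  # following sketch (made-up hosts):
  #   {"cluster": {"master": ["host0:2222"], "worker": ["host1:2222"],
  #                "ps": ["host2:2222"]},
  #    "task": {"type": "worker", "index": 0}}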
680 +
681 + # Load the cluster data from the environment.
682 + cluster_data = env.get("cluster", None)
683 + cluster = tf.train.ClusterSpec(cluster_data) if cluster_data else None
684 +
685 + # Load the task data from the environment.
686 + task_data = env.get("task", None) or {"type": "master", "index": 0}
687 + task = type("TaskSpec", (object,), task_data)
688 +
689 + # Logging the version.
690 + logging.set_verbosity(tf.logging.INFO)
691 + logging.info("%s: Tensorflow version: %s.", task_as_string(task),
692 + tf.__version__)
693 +
694 + # Dispatch to a master, a worker, or a parameter server.
695 + if not cluster or task.type == "master" or task.type == "worker":
696 + model = find_class_by_name(FLAGS.model,
697 + [frame_level_models, video_level_models])()
698 +
699 + reader = get_reader()
700 +
701 + model_exporter = export_model.ModelExporter(
702 + frame_features=FLAGS.frame_features, model=model, reader=reader)
703 +
704 + Trainer(cluster, task, FLAGS.train_dir, model, reader, model_exporter,
705 + FLAGS.log_device_placement, FLAGS.max_steps,
706 + FLAGS.export_model_steps).run(start_new_model=FLAGS.start_new_model)
707 +
708 + elif task.type == "ps":
709 + ParameterServer(cluster, task).run()
710 + else:
711 + raise ValueError("%s: Invalid task_type: %s." %
712 + (task_as_string(task), task.type))
713 +
714 +
715 +if __name__ == "__main__":
716 + app.run()
1 +# Copyright 2016 Google Inc. All Rights Reserved.
2 +#
3 +# Licensed under the Apache License, Version 2.0 (the "License");
4 +# you may not use this file except in compliance with the License.
5 +# You may obtain a copy of the License at
6 +#
7 +# http://www.apache.org/licenses/LICENSE-2.0
8 +#
9 +# Unless required by applicable law or agreed to in writing, software
10 +# distributed under the License is distributed on an "AS-IS" BASIS,
11 +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 +# See the License for the specific language governing permissions and
13 +# limitations under the License.
14 +"""Contains a collection of util functions for training and evaluating."""
15 +
16 +import numpy
17 +import tensorflow as tf
18 +from tensorflow import logging
19 +
20 +try:
21 + xrange # Python 2
22 +except NameError:
23 + xrange = range # Python 3
24 +
25 +
26 +def Dequantize(feat_vector, max_quantized_value=2, min_quantized_value=-2):
27 + """Dequantize the feature from the byte format to the float format.
28 +
29 + Args:
30 + feat_vector: the input 1-d vector.
31 + max_quantized_value: the maximum of the quantized value.
32 + min_quantized_value: the minimum of the quantized value.
33 +
34 + Returns:
35 + A float vector which has the same shape as feat_vector.
36 + """
37 + assert max_quantized_value > min_quantized_value
38 + quantized_range = max_quantized_value - min_quantized_value
39 + scalar = quantized_range / 255.0
40 + bias = (quantized_range / 512.0) + min_quantized_value
41 + return feat_vector * scalar + bias
42 +
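# Worked example with the defaults (max_quantized_value=2, min_quantized_value=-2):
# scalar = 4 / 255 ~= 0.0157 and bias = 4 / 512 - 2 = -1.9922, so byte 0 maps to
# about -1.992 and byte 255 to about 2.008.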
43 +
44 +def MakeSummary(name, value):
45 + """Creates a tf.Summary proto with the given name and value."""
46 + summary = tf.Summary()
47 + val = summary.value.add()
48 + val.tag = str(name)
49 + val.simple_value = float(value)
50 + return summary
51 +
52 +
53 +def AddGlobalStepSummary(summary_writer,
54 + global_step_val,
55 + global_step_info_dict,
56 + summary_scope="Eval"):
57 + """Add the global_step summary to the Tensorboard.
58 +
59 + Args:
60 + summary_writer: Tensorflow summary_writer.
61 + global_step_val: an int value of the global step.
62 + global_step_info_dict: a dictionary of the evaluation metrics calculated for
63 + a mini-batch.
64 + summary_scope: Train or Eval.
65 +
66 + Returns:
67 + A string of this global_step summary.
68 + """
69 + this_hit_at_one = global_step_info_dict["hit_at_one"]
70 + this_perr = global_step_info_dict["perr"]
71 + this_loss = global_step_info_dict["loss"]
72 + examples_per_second = global_step_info_dict.get("examples_per_second", -1)
73 +
74 + summary_writer.add_summary(
75 + MakeSummary("GlobalStep/" + summary_scope + "_Hit@1", this_hit_at_one),
76 + global_step_val)
77 + summary_writer.add_summary(
78 + MakeSummary("GlobalStep/" + summary_scope + "_Perr", this_perr),
79 + global_step_val)
80 + summary_writer.add_summary(
81 + MakeSummary("GlobalStep/" + summary_scope + "_Loss", this_loss),
82 + global_step_val)
83 +
84 + if examples_per_second != -1:
85 + summary_writer.add_summary(
86 + MakeSummary("GlobalStep/" + summary_scope + "_Example_Second",
87 + examples_per_second), global_step_val)
88 +
89 + summary_writer.flush()
90 + info = (
91 + "global_step {0} | Batch Hit@1: {1:.3f} | Batch PERR: {2:.3f} | Batch "
92 + "Loss: {3:.3f} | Examples_per_sec: {4:.3f}").format(
93 + global_step_val, this_hit_at_one, this_perr, this_loss,
94 + examples_per_second)
95 + return info
96 +
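# Sketch of a call (illustrative; the writer and metric values are
# hypothetical). The info dict must provide "hit_at_one", "perr" and "loss";
# "examples_per_second" is optional:
#
#   writer = tf.summary.FileWriter("/tmp/yt8m_train")
#   AddGlobalStepSummary(
#       writer, 100,
#       {"hit_at_one": 0.61, "perr": 0.48, "loss": 5.3,
#        "examples_per_second": 250.0},
#       summary_scope="Train")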
97 +
98 +def AddEpochSummary(summary_writer,
99 + global_step_val,
100 + epoch_info_dict,
101 + summary_scope="Eval"):
102 + """Add the epoch summary to the Tensorboard.
103 +
104 + Args:
105 + summary_writer: Tensorflow summary_writer.
106 + global_step_val: a int value of the global step.
107 + epoch_info_dict: a dictionary of the evaluation metrics calculated for the
108 + whole epoch.
109 + summary_scope: Train or Eval.
110 +
111 + Returns:
112 + A string of this global_step summary
113 + """
114 + epoch_id = epoch_info_dict["epoch_id"]
115 + avg_hit_at_one = epoch_info_dict["avg_hit_at_one"]
116 + avg_perr = epoch_info_dict["avg_perr"]
117 + avg_loss = epoch_info_dict["avg_loss"]
118 + aps = epoch_info_dict["aps"]
119 + gap = epoch_info_dict["gap"]
120 + mean_ap = numpy.mean(aps)
121 +
122 + summary_writer.add_summary(
123 + MakeSummary("Epoch/" + summary_scope + "_Avg_Hit@1", avg_hit_at_one),
124 + global_step_val)
125 + summary_writer.add_summary(
126 + MakeSummary("Epoch/" + summary_scope + "_Avg_Perr", avg_perr),
127 + global_step_val)
128 + summary_writer.add_summary(
129 + MakeSummary("Epoch/" + summary_scope + "_Avg_Loss", avg_loss),
130 + global_step_val)
131 + summary_writer.add_summary(
132 + MakeSummary("Epoch/" + summary_scope + "_MAP", mean_ap), global_step_val)
133 + summary_writer.add_summary(
134 + MakeSummary("Epoch/" + summary_scope + "_GAP", gap), global_step_val)
135 + summary_writer.flush()
136 +
137 + info = ("epoch/eval number {0} | Avg_Hit@1: {1:.3f} | Avg_PERR: {2:.3f} "
138 + "| MAP: {3:.3f} | GAP: {4:.3f} | Avg_Loss: {5:3f} | num_classes: {6}"
139 + ).format(epoch_id, avg_hit_at_one, avg_perr, mean_ap, gap, avg_loss,
140 + len(aps))
141 + return info
142 +
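# The epoch_info_dict is expected to carry "epoch_id", "avg_hit_at_one",
# "avg_perr", "avg_loss", "aps" (per-class average precisions) and "gap"
# (global average precision); MAP is derived here as numpy.mean(aps).
# Sketch of a call with hypothetical values:
#
#   AddEpochSummary(writer, 10000,
#                   {"epoch_id": 3, "avg_hit_at_one": 0.80, "avg_perr": 0.68,
#                    "avg_loss": 4.9, "aps": [0.5, 0.7], "gap": 0.74})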
143 +
144 +def GetListOfFeatureNamesAndSizes(feature_names, feature_sizes):
145 + """Extract the list of feature names and the dimensionality of each feature
146 +
147 + from string of comma separated values.
148 +
149 + Args:
150 + feature_names: string containing comma separated list of feature names
151 + feature_sizes: string containing comma separated list of feature sizes
152 +
153 + Returns:
154 + List of the feature names and list of the dimensionality of each feature.
155 + Elements in the first/second list are strings/integers.
156 + """
157 + list_of_feature_names = [
158 + feature_name.strip() for feature_name in feature_names.split(",")
159 + ]
160 + list_of_feature_sizes = [
161 + int(feature_size) for feature_size in feature_sizes.split(",")
162 + ]
163 + if len(list_of_feature_names) != len(list_of_feature_sizes):
164 + logging.error("length of the feature names (=" +
165 + str(len(list_of_feature_names)) + ") != length of feature "
166 + "sizes (=" + str(len(list_of_feature_sizes)) + ")")
167 +
168 + return list_of_feature_names, list_of_feature_sizes
169 +
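# A minimal sketch of how the parsing behaves (added for illustration; the
# feature names and sizes below are assumptions, not values defined here):
def _example_feature_name_parsing():
  names, sizes = GetListOfFeatureNamesAndSizes("rgb, audio", "1024,128")
  assert names == ["rgb", "audio"]  # whitespace around names is stripped
  assert sizes == [1024, 128]  # sizes are converted to ints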
170 +
171 +def clip_gradient_norms(gradients_to_variables, max_norm):
172 + """Clips the gradients by the given value.
173 +
174 + Args:
175 + gradients_to_variables: A list of gradient to variable pairs (tuples).
176 + max_norm: the maximum norm value.
177 +
178 + Returns:
179 + A list of clipped gradient to variable pairs.
180 + """
181 + clipped_grads_and_vars = []
182 + for grad, var in gradients_to_variables:
183 + if grad is not None:
184 + if isinstance(grad, tf.IndexedSlices):
185 + tmp = tf.clip_by_norm(grad.values, max_norm)
186 + grad = tf.IndexedSlices(tmp, grad.indices, grad.dense_shape)
187 + else:
188 + grad = tf.clip_by_norm(grad, max_norm)
189 + clipped_grads_and_vars.append((grad, var))
190 + return clipped_grads_and_vars
191 +
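# Worked example (illustrative): tf.clip_by_norm only rescales a gradient
# whose L2 norm exceeds max_norm. A gradient [3.0, 4.0] has norm 5.0, so with
# max_norm=1.0 it becomes [0.6, 0.8]; gradients with norm <= max_norm pass
# through unchanged. IndexedSlices gradients (e.g. from embedding lookups)
# have only their values clipped, keeping the original indices.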
192 +
193 +def combine_gradients(tower_grads):
194 + """Calculate the combined gradient for each shared variable across all towers.
195 +
196 + Note that this function provides a synchronization point across all towers.
197 +
198 + Args:
199 + tower_grads: List of lists of (gradient, variable) tuples. The outer list is
200 + over individual towers. The inner list is over the (gradient, variable)
201 + pairs computed by each tower for the shared variables.
202 +
203 + Returns:
204 + List of pairs of (gradient, variable) where the gradient has been summed
205 + across all towers.
206 + """
207 + filtered_grads = [
208 + [x for x in grad_list if x[0] is not None] for grad_list in tower_grads
209 + ]
210 + final_grads = []
211 + for i in xrange(len(filtered_grads[0])):
212 + grads = [filtered_grads[t][i] for t in xrange(len(filtered_grads))]
213 + grad = tf.stack([x[0] for x in grads], 0)
214 + grad = tf.reduce_sum(grad, 0)
215 + final_grads.append((
216 + grad,
217 + filtered_grads[0][i][1],
218 + ))
219 +
220 + return final_grads
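# Worked example (illustrative): with two towers sharing one variable v,
#   tower_grads = [[(g0, v)], [(g1, v)]]
# where g0 = [1., 2.] and g1 = [3., 4.], the stacked gradients reduce to
# [4., 6.], so combine_gradients returns [([4., 6.], v)]. Note the gradients
# are summed, not averaged, across towers.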
1 +# Copyright 2016 Google Inc. All Rights Reserved.
2 +#
3 +# Licensed under the Apache License, Version 2.0 (the "License");
4 +# you may not use this file except in compliance with the License.
5 +# You may obtain a copy of the License at
6 +#
7 +# http://www.apache.org/licenses/LICENSE-2.0
8 +#
9 +# Unless required by applicable law or agreed to in writing, software
10 +# distributed under the License is distributed on an "AS-IS" BASIS,
11 +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 +# See the License for the specific language governing permissions and
13 +# limitations under the License.
14 +"""Contains model definitions."""
15 +import math
16 +
17 +import models
18 +import tensorflow as tf
19 +import utils
20 +
21 +from tensorflow import flags
22 +import tensorflow.contrib.slim as slim
23 +
24 +FLAGS = flags.FLAGS
25 +flags.DEFINE_integer(
26 + "moe_num_mixtures", 2,
27 + "The number of mixtures (excluding the dummy 'expert') used for MoeModel.")
28 +
29 +
30 +class LogisticModel(models.BaseModel):
31 + """Logistic model with L2 regularization."""
32 +
33 + def create_model(self,
34 + model_input,
35 + vocab_size,
36 + l2_penalty=1e-8,
37 + **unused_params):
38 + """Creates a logistic model.
39 +
40 + Args:
41 + model_input: 'batch' x 'num_features' matrix of input features.
42 + vocab_size: The number of classes in the dataset.
43 + l2_penalty: How much to penalize the squared magnitudes of parameter values.
44 + Returns:
45 + A dictionary with a tensor containing the probability predictions of the
46 + model in the 'predictions' key. The dimensions of the tensor are
47 + batch_size x num_classes.
48 + """
49 + output = slim.fully_connected(
50 + model_input,
51 + vocab_size,
52 + activation_fn=tf.nn.sigmoid,
53 + weights_regularizer=slim.l2_regularizer(l2_penalty))
54 + return {"predictions": output}
55 +
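# Illustrative usage (the feature dimension and vocabulary size below are
# placeholders, not values defined in this file):
#
#   model_input = tf.placeholder(tf.float32, shape=[None, 1152])
#   output = LogisticModel().create_model(model_input, vocab_size=4716)
#   # output["predictions"] has shape [batch_size, 4716]; each entry is an
#   # independent per-class sigmoid probability.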
56 +
57 +class MoeModel(models.BaseModel):
58 + """A softmax over a mixture of logistic models (with L2 regularization)."""
59 +
60 + def create_model(self,
61 + model_input,
62 + vocab_size,
63 + num_mixtures=None,
64 + l2_penalty=1e-8,
65 + **unused_params):
66 + """Creates a Mixture of (Logistic) Experts model.
67 +
68 + The model consists of a per-class softmax distribution over a
69 + configurable number of logistic classifiers. One of the classifiers in the
70 + mixture is not trained, and always predicts 0.
71 +
72 + Args:
73 + model_input: 'batch_size' x 'num_features' matrix of input features.
74 + vocab_size: The number of classes in the dataset.
75 + num_mixtures: The number of mixtures (excluding a dummy 'expert' that
76 + always predicts the non-existence of an entity).
77 + l2_penalty: How much to penalize the squared magnitudes of parameter
78 + values.
79 +
80 + Returns:
81 + A dictionary with a tensor containing the probability predictions of the
82 + model in the 'predictions' key. The dimensions of the tensor are
83 + batch_size x num_classes.
84 + """
85 + num_mixtures = num_mixtures or FLAGS.moe_num_mixtures
86 +
87 + gate_activations = slim.fully_connected(
88 + model_input,
89 + vocab_size * (num_mixtures + 1),
90 + activation_fn=None,
91 + biases_initializer=None,
92 + weights_regularizer=slim.l2_regularizer(l2_penalty),
93 + scope="gates")
94 + expert_activations = slim.fully_connected(
95 + model_input,
96 + vocab_size * num_mixtures,
97 + activation_fn=None,
98 + weights_regularizer=slim.l2_regularizer(l2_penalty),
99 + scope="experts")
100 +
101 + gating_distribution = tf.nn.softmax(
102 + tf.reshape(
103 + gate_activations,
104 + [-1, num_mixtures + 1])) # (Batch * #Labels) x (num_mixtures + 1)
105 + expert_distribution = tf.nn.sigmoid(
106 + tf.reshape(expert_activations,
107 + [-1, num_mixtures])) # (Batch * #Labels) x num_mixtures
108 +
109 + final_probabilities_by_class_and_batch = tf.reduce_sum(
110 + gating_distribution[:, :num_mixtures] * expert_distribution, 1)
111 + final_probabilities = tf.reshape(final_probabilities_by_class_and_batch,
112 + [-1, vocab_size])
113 + return {"predictions": final_probabilities}
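# Worked example (illustrative): with num_mixtures=2, every (example, class)
# pair gets a softmax over 3 gate logits and sigmoids over 2 expert logits.
# If the gating distribution is [0.5, 0.3, 0.2] and the expert outputs are
# [0.9, 0.4], the predicted probability is 0.5 * 0.9 + 0.3 * 0.4 = 0.57; the
# remaining 0.2 of gate mass falls on the implicit dummy expert that always
# predicts 0. The final reshape to [-1, vocab_size] recovers the
# batch_size x num_classes prediction matrix.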