Showing 26 changed files with 5435 additions and 1 deletion
youtube-8m @ e6f6bf68
1 | -Subproject commit e6f6bf682d20bb21904ea9c081c15e070809d914 |
yt8m/__init__.py
0 → 100644
1 | +# Copyright 2016 Google Inc. All Rights Reserved. | ||
2 | +# | ||
3 | +# Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | +# you may not use this file except in compliance with the License. | ||
5 | +# You may obtain a copy of the License at | ||
6 | +# | ||
7 | +# http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | +# | ||
9 | +# Unless required by applicable law or agreed to in writing, software | ||
10 | +# distributed under the License is distributed on an "AS-IS" BASIS, | ||
11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | +# See the License for the specific language governing permissions and | ||
13 | +# limitations under the License. |
yt8m/average_precision_calculator.py
0 → 100644
1 | +# Copyright 2016 Google Inc. All Rights Reserved. | ||
2 | +# | ||
3 | +# Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | +# you may not use this file except in compliance with the License. | ||
5 | +# You may obtain a copy of the License at | ||
6 | +# | ||
7 | +# http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | +# | ||
9 | +# Unless required by applicable law or agreed to in writing, software | ||
10 | +# distributed under the License is distributed on an "AS-IS" BASIS, | ||
11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | +# See the License for the specific language governing permissions and | ||
13 | +# limitations under the License. | ||
14 | +"""Calculate or keep track of the interpolated average precision. | ||
15 | + | ||
16 | +It provides an interface for calculating interpolated average precision for an | ||
17 | +entire list or the top-n ranked items. For the definition of the | ||
18 | +(non-)interpolated average precision: | ||
19 | +http://trec.nist.gov/pubs/trec15/appendices/CE.MEASURES06.pdf | ||
20 | + | ||
21 | +Example usages: | ||
22 | +1) Use it as a static function call to directly calculate average precision for | ||
23 | +a short ranked list in the memory. | ||
24 | + | ||
25 | +``` | ||
26 | +import random | ||
27 | +import numpy as np | ||
28 | +p = np.array([random.random() for _ in range(10)]) | ||
29 | +a = np.array([random.choice([0, 1]) for _ in range(10)]) | ||
30 | + | ||
31 | +ap = average_precision_calculator.AveragePrecisionCalculator.ap(p, a) | ||
32 | +``` | ||
33 | + | ||
34 | +2) Use it as an object for a long ranked list that cannot be stored in memory, | ||
35 | +or for the case where partial predictions are observed one batch at a time | ||
36 | +(e.g. Tensorflow predictions). In this case, we first call accumulate many | ||
37 | +times to process parts of the ranked list. After processing all the parts, we | ||
38 | +call peek_ap_at_n. | ||
39 | +``` | ||
40 | +p1 = np.array([random.random() for _ in range(5)]) | ||
41 | +a1 = np.array([random.choice([0, 1]) for _ in range(5)]) | ||
42 | +p2 = np.array([random.random() for _ in range(5)]) | ||
43 | +a2 = np.array([random.choice([0, 1]) for _ in range(5)]) | ||
44 | + | ||
45 | +# non-interpolated average precision at 10 | ||
46 | +calculator = average_precision_calculator.AveragePrecisionCalculator(10) | ||
47 | +calculator.accumulate(p1, a1) | ||
48 | +calculator.accumulate(p2, a2) | ||
49 | +ap3 = calculator.peek_ap_at_n() | ||
50 | +``` | ||
51 | +""" | ||
52 | + | ||
53 | +import heapq | ||
54 | +import random | ||
55 | +import numbers | ||
56 | + | ||
57 | +import numpy | ||
58 | + | ||
59 | + | ||
60 | +class AveragePrecisionCalculator(object): | ||
61 | + """Calculate the average precision and average precision at n.""" | ||
62 | + | ||
63 | + def __init__(self, top_n=None): | ||
64 | + """Construct an AveragePrecisionCalculator to calculate average precision. | ||
65 | + | ||
66 | + This class is used to calculate the average precision for a single label. | ||
67 | + | ||
68 | + Args: | ||
69 | + top_n: A positive Integer specifying the average precision at n, or None | ||
70 | + to use all provided data points. | ||
71 | + | ||
72 | + Raises: | ||
73 | + ValueError: An error occurred when the top_n is not a positive integer. | ||
74 | + """ | ||
75 | + if not ((isinstance(top_n, int) and top_n > 0) or top_n is None): | ||
76 | + raise ValueError("top_n must be a positive integer or None.") | ||
77 | + | ||
78 | + self._top_n = top_n # average precision at n | ||
79 | + self._total_positives = 0 # total number of positives seen so far | ||
80 | + self._heap = [] # max heap of (prediction, actual) | ||
81 | + | ||
82 | + @property | ||
83 | + def heap_size(self): | ||
84 | + """Gets the heap size maintained in the class.""" | ||
85 | + return len(self._heap) | ||
86 | + | ||
87 | + @property | ||
88 | + def num_accumulated_positives(self): | ||
89 | + """Gets the number of positive samples that have been accumulated.""" | ||
90 | + return self._total_positives | ||
91 | + | ||
92 | + def accumulate(self, predictions, actuals, num_positives=None): | ||
93 | + """Accumulate the predictions and their ground truth labels. | ||
94 | + | ||
95 | + After the function call, we may call peek_ap_at_n to actually calculate | ||
96 | + the average precision. | ||
97 | + Note predictions and actuals must have the same shape. | ||
98 | + | ||
99 | + Args: | ||
100 | + predictions: a list storing the prediction scores. | ||
101 | + actuals: a list storing the ground truth labels. Any value larger than 0 | ||
102 | + will be treated as positives, otherwise as negatives. | ||
103 | + num_positives: If the 'predictions' and 'actuals' inputs aren't complete, | ||
104 | + then it's possible some true positives were missed in them. In that | ||
105 | + case, you can provide 'num_positives' in order to accurately track recall. | ||
106 | + | ||
107 | + Raises: | ||
108 | + ValueError: An error occurred when the format of the input is not the | ||
109 | + numpy 1-D array or the shape of predictions and actuals does not match. | ||
110 | + """ | ||
111 | + if len(predictions) != len(actuals): | ||
112 | + raise ValueError("the shape of predictions and actuals does not match.") | ||
113 | + | ||
114 | + if num_positives is not None: | ||
115 | + if not isinstance(num_positives, numbers.Number) or num_positives < 0: | ||
116 | + raise ValueError( | ||
117 | + "'num_positives' was provided but it was a negative number.") | ||
118 | + | ||
119 | + if num_positives is not None: | ||
120 | + self._total_positives += num_positives | ||
121 | + else: | ||
122 | + self._total_positives += numpy.size( | ||
123 | + numpy.where(numpy.array(actuals) > 1e-5)) | ||
124 | + topk = self._top_n | ||
125 | + heap = self._heap | ||
126 | + | ||
127 | + for i in range(numpy.size(predictions)): | ||
128 | + if topk is None or len(heap) < topk: | ||
129 | + heapq.heappush(heap, (predictions[i], actuals[i])) | ||
130 | + else: | ||
131 | + if predictions[i] > heap[0][0]: # heap[0] is the smallest | ||
132 | + heapq.heappop(heap) | ||
133 | + heapq.heappush(heap, (predictions[i], actuals[i])) | ||
134 | + | ||
135 | + def clear(self): | ||
136 | + """Clear the accumulated predictions.""" | ||
137 | + self._heap = [] | ||
138 | + self._total_positives = 0 | ||
139 | + | ||
140 | + def peek_ap_at_n(self): | ||
141 | + """Peek the non-interpolated average precision at n. | ||
142 | + | ||
143 | + Returns: | ||
144 | + The non-interpolated average precision at n, or 0 if nothing has | ||
145 | + been accumulated. If n is larger than the length of the ranked list, | ||
146 | + the average precision over the whole list is returned. | ||
147 | + """ | ||
148 | + if self.heap_size <= 0: | ||
149 | + return 0 | ||
150 | + predlists = numpy.array(list(zip(*self._heap))) | ||
151 | + | ||
152 | + ap = self.ap_at_n(predlists[0], | ||
153 | + predlists[1], | ||
154 | + n=self._top_n, | ||
155 | + total_num_positives=self._total_positives) | ||
156 | + return ap | ||
157 | + | ||
158 | + @staticmethod | ||
159 | + def ap(predictions, actuals): | ||
160 | + """Calculate the non-interpolated average precision. | ||
161 | + | ||
162 | + Args: | ||
163 | + predictions: a numpy 1-D array storing the sparse prediction scores. | ||
164 | + actuals: a numpy 1-D array storing the ground truth labels. Any value | ||
165 | + larger than 0 will be treated as positives, otherwise as negatives. | ||
166 | + | ||
167 | + Returns: | ||
168 | + The non-interpolated average precision at n. | ||
169 | + If n is larger than the length of the ranked list, | ||
170 | + the average precision will be returned. | ||
171 | + | ||
172 | + Raises: | ||
173 | + ValueError: An error occurred when the format of the input is not the | ||
174 | + numpy 1-D array or the shape of predictions and actuals does not match. | ||
175 | + """ | ||
176 | + return AveragePrecisionCalculator.ap_at_n(predictions, actuals, n=None) | ||
177 | + | ||
178 | + @staticmethod | ||
179 | + def ap_at_n(predictions, actuals, n=20, total_num_positives=None): | ||
180 | + """Calculate the non-interpolated average precision. | ||
181 | + | ||
182 | + Args: | ||
183 | + predictions: a numpy 1-D array storing the sparse prediction scores. | ||
184 | + actuals: a numpy 1-D array storing the ground truth labels. Any value | ||
185 | + larger than 0 will be treated as positives, otherwise as negatives. | ||
186 | + n: the top n items to be considered in ap@n. | ||
187 | + total_num_positives: (optional) the total number of positives in the | ||
188 | + list. If specified, it will be used in the calculation. | ||
189 | + | ||
190 | + Returns: | ||
191 | + The non-interpolated average precision at n. | ||
192 | + If n is larger than the length of the ranked list, | ||
193 | + the average precision will be returned. | ||
194 | + | ||
195 | + Raises: | ||
196 | + ValueError: An error occurred when | ||
197 | + 1) the format of the input is not the numpy 1-D array; | ||
198 | + 2) the shape of predictions and actuals does not match; | ||
199 | + 3) the input n is not a positive integer. | ||
200 | + """ | ||
201 | + if len(predictions) != len(actuals): | ||
202 | + raise ValueError("the shape of predictions and actuals does not match.") | ||
203 | + | ||
204 | + if n is not None: | ||
205 | + if not isinstance(n, int) or n <= 0: | ||
206 | + raise ValueError("n must be 'None' or a positive integer." | ||
207 | + " It was '%s'." % n) | ||
208 | + | ||
209 | + ap = 0.0 | ||
210 | + | ||
211 | + predictions = numpy.array(predictions) | ||
212 | + actuals = numpy.array(actuals) | ||
213 | + | ||
214 | + # add a shuffler to avoid overestimating the ap | ||
215 | + predictions, actuals = AveragePrecisionCalculator._shuffle( | ||
216 | + predictions, actuals) | ||
217 | + sortidx = sorted(range(len(predictions)), | ||
218 | + key=lambda k: predictions[k], | ||
219 | + reverse=True) | ||
220 | + | ||
221 | + if total_num_positives is None: | ||
222 | + numpos = numpy.size(numpy.where(actuals > 0)) | ||
223 | + else: | ||
224 | + numpos = total_num_positives | ||
225 | + | ||
226 | + if numpos == 0: | ||
227 | + return 0 | ||
228 | + | ||
229 | + if n is not None: | ||
230 | + numpos = min(numpos, n) | ||
231 | + delta_recall = 1.0 / numpos | ||
232 | + poscount = 0.0 | ||
233 | + | ||
234 | + # calculate the ap | ||
235 | + r = len(sortidx) | ||
236 | + if n is not None: | ||
237 | + r = min(r, n) | ||
238 | + for i in range(r): | ||
239 | + if actuals[sortidx[i]] > 0: | ||
240 | + poscount += 1 | ||
241 | + ap += poscount / (i + 1) * delta_recall | ||
242 | + return ap | ||
243 | + | ||
244 | + @staticmethod | ||
245 | + def _shuffle(predictions, actuals): | ||
246 | + random.seed(0) | ||
247 | + suffidx = random.sample(range(len(predictions)), len(predictions)) | ||
248 | + predictions = predictions[suffidx] | ||
249 | + actuals = actuals[suffidx] | ||
250 | + return predictions, actuals | ||
251 | + | ||
252 | + @staticmethod | ||
253 | + def _zero_one_normalize(predictions, epsilon=1e-7): | ||
254 | + """Normalize the predictions to the range between 0.0 and 1.0. | ||
255 | + | ||
256 | + For some predictions like SVM predictions, we need to normalize them before | ||
257 | + calculate the interpolated average precision. The normalization will not | ||
258 | + change the rank in the original list and thus won't change the average | ||
259 | + precision. | ||
260 | + | ||
261 | + Args: | ||
262 | + predictions: a numpy 1-D array storing the sparse prediction scores. | ||
263 | + epsilon: a small constant to avoid denominator being zero. | ||
264 | + | ||
265 | + Returns: | ||
266 | + The normalized prediction. | ||
267 | + """ | ||
268 | + denominator = numpy.max(predictions) - numpy.min(predictions) | ||
269 | + ret = (predictions - numpy.min(predictions)) / numpy.maximum( | ||
270 | + denominator, epsilon) | ||
271 | + return ret |
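A minimal usage sketch for the calculator above (outside the diff), assuming the new file is importable as `average_precision_calculator`; the scores and labels are hypothetical:

```python
import numpy as np

import average_precision_calculator  # the module added above

# Toy ranked list: 6 hypothetical scores and their binary labels.
predictions = np.array([0.9, 0.8, 0.7, 0.6, 0.5, 0.4])
actuals = np.array([1, 0, 1, 0, 0, 1])

# Full-list AP: positives land at ranks 1, 3 and 6, so
# AP = (1/1 + 2/3 + 3/6) / 3 ≈ 0.7222.
print(average_precision_calculator.AveragePrecisionCalculator.ap(predictions, actuals))

# The same value computed in streaming fashion via accumulate().
calc = average_precision_calculator.AveragePrecisionCalculator()
calc.accumulate(predictions[:3], actuals[:3])
calc.accumulate(predictions[3:], actuals[3:])
print(calc.peek_ap_at_n())  # ≈ 0.7222
```

The streaming path is the one eval_util.py (added below) uses to compute the GAP metric.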
yt8m/convert_prediction_from_json_to_csv.py
0 → 100644
1 | +# Copyright 2016 Google Inc. All Rights Reserved. | ||
2 | +# | ||
3 | +# Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | +# you may not use this file except in compliance with the License. | ||
5 | +# You may obtain a copy of the License at | ||
6 | +# | ||
7 | +# http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | +# | ||
9 | +# Unless required by applicable law or agreed to in writing, software | ||
10 | +# distributed under the License is distributed on an "AS-IS" BASIS, | ||
11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | +# See the License for the specific language governing permissions and | ||
13 | +# limitations under the License. | ||
14 | +"""Utility to convert the output of batch prediction into a CSV submission. | ||
15 | + | ||
16 | +It converts the JSON files created by the command | ||
17 | +'gcloud beta ml jobs submit prediction' into a CSV file ready for submission. | ||
18 | +""" | ||
19 | + | ||
20 | +import json | ||
21 | +import tensorflow as tf | ||
22 | + | ||
23 | +from builtins import range | ||
24 | +from tensorflow import app | ||
25 | +from tensorflow import flags | ||
26 | +from tensorflow import gfile | ||
27 | +from tensorflow import logging | ||
28 | + | ||
29 | +FLAGS = flags.FLAGS | ||
30 | + | ||
31 | +if __name__ == "__main__": | ||
32 | + | ||
33 | + flags.DEFINE_string( | ||
34 | + "json_prediction_files_pattern", None, | ||
35 | + "Pattern specifying the list of JSON files that the command " | ||
36 | + "'gcloud beta ml jobs submit prediction' outputs. These files are " | ||
37 | + "located in the output path of the prediction command and are prefixed " | ||
38 | + "with 'prediction.results'.") | ||
39 | + flags.DEFINE_string( | ||
40 | + "csv_output_file", None, | ||
41 | + "The file to save the predictions converted to the CSV format.") | ||
42 | + | ||
43 | + | ||
44 | +def get_csv_header(): | ||
45 | + return "VideoId,LabelConfidencePairs\n" | ||
46 | + | ||
47 | + | ||
48 | +def to_csv_row(json_data): | ||
49 | + | ||
50 | + video_id = json_data["video_id"] | ||
51 | + | ||
52 | + class_indexes = json_data["class_indexes"] | ||
53 | + predictions = json_data["predictions"] | ||
54 | + | ||
55 | + if isinstance(video_id, list): | ||
56 | + video_id = video_id[0] | ||
57 | + class_indexes = class_indexes[0] | ||
58 | + predictions = predictions[0] | ||
59 | + | ||
60 | + if len(class_indexes) != len(predictions): | ||
61 | + raise ValueError( | ||
62 | + "The number of indexes (%s) and predictions (%s) must be equal." % | ||
63 | + (len(class_indexes), len(predictions))) | ||
64 | + | ||
65 | + return (video_id.decode("utf-8") + "," + | ||
66 | + " ".join("%i %f" % (class_indexes[i], predictions[i]) | ||
67 | + for i in range(len(class_indexes))) + "\n") | ||
68 | + | ||
69 | + | ||
70 | +def main(unused_argv): | ||
71 | + logging.set_verbosity(tf.logging.INFO) | ||
72 | + | ||
73 | + if not FLAGS.json_prediction_files_pattern: | ||
74 | + raise ValueError( | ||
75 | + "The flag --json_prediction_files_pattern must be specified.") | ||
76 | + | ||
77 | + if not FLAGS.csv_output_file: | ||
78 | + raise ValueError("The flag --csv_output_file must be specified.") | ||
79 | + | ||
80 | + logging.info("Looking for prediction files with pattern: %s", | ||
81 | + FLAGS.json_prediction_files_pattern) | ||
82 | + | ||
83 | + file_paths = gfile.Glob(FLAGS.json_prediction_files_pattern) | ||
84 | + logging.info("Found files: %s", file_paths) | ||
85 | + | ||
86 | + logging.info("Writing submission file to: %s", FLAGS.csv_output_file) | ||
87 | + with gfile.Open(FLAGS.csv_output_file, "w+") as output_file: | ||
88 | + output_file.write(get_csv_header()) | ||
89 | + | ||
90 | + for file_path in file_paths: | ||
91 | + logging.info("processing file: %s", file_path) | ||
92 | + | ||
93 | + with gfile.Open(file_path) as input_file: | ||
94 | + | ||
95 | + for line in input_file: | ||
96 | + json_data = json.loads(line) | ||
97 | + output_file.write(to_csv_row(json_data)) | ||
98 | + | ||
99 | + output_file.flush() | ||
100 | + logging.info("done") | ||
101 | + | ||
102 | + | ||
103 | +if __name__ == "__main__": | ||
104 | + app.run() |
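For reference, a sketch of the per-line transformation this script performs (hypothetical record; the original `to_csv_row` additionally expects a bytes `video_id`, hence its `.decode("utf-8")` call):

```python
import json

# One line of a 'prediction.results*' file (hypothetical values).
line = '{"video_id": "AbCd", "class_indexes": [3, 47], "predictions": [0.91, 0.12]}'
record = json.loads(line)

# Equivalent of to_csv_row() for a str video_id.
row = record["video_id"] + "," + " ".join(
    "%i %f" % (i, p)
    for i, p in zip(record["class_indexes"], record["predictions"]))
print(row)  # AbCd,3 0.910000 47 0.120000
```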
yt8m/esot3ria/features.pb
0 → 100644
No preview for this file type
yt8m/esot3ria/inference_pb.py
0 → 100644
1 | +import numpy as np | ||
2 | +import tensorflow as tf | ||
3 | +from tensorflow import logging | ||
4 | +from tensorflow import gfile | ||
5 | +import esot3ria.pbutil as pbutil | ||
6 | + | ||
7 | + | ||
8 | +def get_segments(batch_video_mtx, batch_num_frames, segment_size): | ||
9 | + """Get segment-level inputs from frame-level features.""" | ||
10 | + video_batch_size = batch_video_mtx.shape[0] | ||
11 | + max_frame = batch_video_mtx.shape[1] | ||
12 | + feature_dim = batch_video_mtx.shape[-1] | ||
13 | + padded_segment_sizes = (batch_num_frames + segment_size - 1) // segment_size | ||
14 | + padded_segment_sizes *= segment_size | ||
15 | + segment_mask = ( | ||
16 | + 0 < (padded_segment_sizes[:, np.newaxis] - np.arange(0, max_frame))) | ||
17 | + | ||
18 | + # Segment bags. | ||
19 | + frame_bags = batch_video_mtx.reshape((-1, feature_dim)) | ||
20 | + segment_frames = frame_bags[segment_mask.reshape(-1)].reshape( | ||
21 | + (-1, segment_size, feature_dim)) | ||
22 | + | ||
23 | + # Segment num frames. | ||
24 | + segment_start_times = np.arange(0, max_frame, segment_size) | ||
25 | + num_segments = batch_num_frames[:, np.newaxis] - segment_start_times | ||
26 | + num_segment_bags = num_segments.reshape((-1)) | ||
27 | + valid_segment_mask = num_segment_bags > 0 | ||
28 | + segment_num_frames = num_segment_bags[valid_segment_mask] | ||
29 | + segment_num_frames[segment_num_frames > segment_size] = segment_size | ||
30 | + | ||
31 | + max_segment_num = (max_frame + segment_size - 1) // segment_size | ||
32 | + video_idxs = np.tile( | ||
33 | + np.arange(0, video_batch_size)[:, np.newaxis], [1, max_segment_num]) | ||
34 | + segment_idxs = np.tile(segment_start_times, [video_batch_size, 1]) | ||
35 | + idx_bags = np.stack([video_idxs, segment_idxs], axis=-1).reshape((-1, 2)) | ||
36 | + video_segment_ids = idx_bags[valid_segment_mask] | ||
37 | + | ||
38 | + return { | ||
39 | + "video_batch": segment_frames, | ||
40 | + "num_frames_batch": segment_num_frames, | ||
41 | + "video_segment_ids": video_segment_ids | ||
42 | + } | ||
43 | + | ||
44 | + | ||
45 | +def format_prediction(video_ids, predictions, top_k, whitelisted_cls_mask=None): | ||
46 | + batch_size = len(video_ids) | ||
47 | + for video_index in range(batch_size): | ||
48 | + video_prediction = predictions[video_index] | ||
49 | + if whitelisted_cls_mask is not None: | ||
50 | + # Whitelist classes. | ||
51 | + video_prediction *= whitelisted_cls_mask | ||
52 | + top_indices = np.argpartition(video_prediction, -top_k)[-top_k:] | ||
53 | + line = [(class_index, predictions[video_index][class_index]) | ||
54 | + for class_index in top_indices] | ||
55 | + line = sorted(line, key=lambda p: -p[1]) | ||
56 | + return (video_ids[video_index] + "," + | ||
57 | + " ".join("%i %g" % (label, score) for (label, score) in line) + | ||
58 | + "\n").encode("utf8") | ||
59 | + | ||
60 | + | ||
61 | +def inference_pb(filename): | ||
62 | + with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess: | ||
63 | + | ||
64 | + # 200527 Esot3riA | ||
65 | + # 0. Import SequenceExample type target from pb. | ||
66 | + target_video = pbutil.convert_pb(filename) | ||
67 | + | ||
68 | + # 1. Load video features from pb. | ||
69 | + video_id_batch_val = np.array([b'video']) | ||
70 | + n_frames = len(target_video.feature_lists.feature_list['rgb'].feature) | ||
71 | + # Restrict frame size to 300 | ||
72 | + if n_frames > 300: | ||
73 | + n_frames = 300 | ||
74 | + video_batch_val = np.zeros((300, 1152)) | ||
75 | + for i in range(n_frames): | ||
76 | + video_batch_rgb_raw = target_video.feature_lists.feature_list['rgb'].feature[i].bytes_list.value[0] | ||
77 | + video_batch_rgb = np.array(tf.cast(tf.decode_raw(video_batch_rgb_raw, tf.float32), tf.float32).eval()) | ||
78 | + video_batch_audio_raw = target_video.feature_lists.feature_list['audio'].feature[i].bytes_list.value[0] | ||
79 | + video_batch_audio = np.array(tf.cast(tf.decode_raw(video_batch_audio_raw, tf.float32), tf.float32).eval()) | ||
80 | + video_batch_val[i] = np.concatenate([video_batch_rgb, video_batch_audio], axis=0) | ||
81 | + video_batch_val = np.array([video_batch_val]) | ||
82 | + num_frames_batch_val = np.array([n_frames]) | ||
83 | + # 200527 Esot3riA End | ||
84 | + | ||
85 | + # Restore checkpoint and meta-graph file | ||
86 | + checkpoint_file = '/Users/esot3ria/PycharmProjects/yt8m/models/frame' \ | ||
87 | + '/sample_model/inference_model/segment_inference_model' | ||
88 | + if not gfile.Exists(checkpoint_file + ".meta"): | ||
89 | + raise IOError("Cannot find %s. Did you run eval.py?" % checkpoint_file) | ||
90 | + meta_graph_location = checkpoint_file + ".meta" | ||
91 | + logging.info("loading meta-graph: " + meta_graph_location) | ||
92 | + | ||
93 | + with tf.device("/cpu:0"): | ||
94 | + saver = tf.train.import_meta_graph(meta_graph_location, | ||
95 | + clear_devices=True) | ||
96 | + logging.info("restoring variables from " + checkpoint_file) | ||
97 | + saver.restore(sess, checkpoint_file) | ||
98 | + input_tensor = tf.get_collection("input_batch_raw")[0] | ||
99 | + num_frames_tensor = tf.get_collection("num_frames")[0] | ||
100 | + predictions_tensor = tf.get_collection("predictions")[0] | ||
101 | + | ||
102 | + # Workaround for num_epochs issue. | ||
103 | + def set_up_init_ops(variables): | ||
104 | + init_op_list = [] | ||
105 | + for variable in list(variables): | ||
106 | + if "train_input" in variable.name: | ||
107 | + init_op_list.append(tf.assign(variable, 1)) | ||
108 | + variables.remove(variable) | ||
109 | + init_op_list.append(tf.variables_initializer(variables)) | ||
110 | + return init_op_list | ||
111 | + | ||
112 | + sess.run( | ||
113 | + set_up_init_ops(tf.get_collection_ref(tf.GraphKeys.LOCAL_VARIABLES))) | ||
114 | + | ||
115 | + coord = tf.train.Coordinator() | ||
116 | + threads = tf.train.start_queue_runners(sess=sess, coord=coord) | ||
117 | + whitelisted_cls_mask = np.zeros((predictions_tensor.get_shape()[-1],), | ||
118 | + dtype=np.float32) | ||
119 | + segment_label_ids_file = '../segment_label_ids.csv' | ||
120 | + with tf.io.gfile.GFile(segment_label_ids_file) as fobj: | ||
121 | + for line in fobj: | ||
122 | + try: | ||
123 | + cls_id = int(line) | ||
124 | + whitelisted_cls_mask[cls_id] = 1. | ||
125 | + except ValueError: | ||
126 | + # Simply skip the non-integer line. | ||
127 | + continue | ||
128 | + | ||
129 | + # 200527 Esot3riA | ||
130 | + # 2. Make segment features. | ||
131 | + results = get_segments(video_batch_val, num_frames_batch_val, 5) | ||
132 | + video_segment_ids = results["video_segment_ids"] | ||
133 | + video_id_batch_val = video_id_batch_val[video_segment_ids[:, 0]] | ||
134 | + video_id_batch_val = np.array([ | ||
135 | + "%s:%d" % (x.decode("utf8"), y) | ||
136 | + for x, y in zip(video_id_batch_val, video_segment_ids[:, 1]) | ||
137 | + ]) | ||
138 | + video_batch_val = results["video_batch"] | ||
139 | + num_frames_batch_val = results["num_frames_batch"] | ||
140 | + if input_tensor.get_shape()[1] != video_batch_val.shape[1]: | ||
141 | + raise ValueError("max_frames mismatch. Please re-run the eval.py " | ||
142 | + "with correct segment_labels settings.") | ||
143 | + | ||
144 | + predictions_val, = sess.run([predictions_tensor], | ||
145 | + feed_dict={ | ||
146 | + input_tensor: video_batch_val, | ||
147 | + num_frames_tensor: num_frames_batch_val | ||
148 | + }) | ||
149 | + logging.info(predictions_val) | ||
150 | + logging.info("profit :D") | ||
151 | + | ||
152 | + # result = format_prediction(video_id_batch_val, predictions_val, 10, whitelisted_cls_mask) | ||
153 | + | ||
154 | + | ||
155 | +if __name__ == '__main__': | ||
156 | + logging.set_verbosity(tf.logging.INFO) | ||
157 | + | ||
158 | + filename = 'features.pb' | ||
159 | + inference_pb(filename) |
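A small sketch of what `get_segments` produces for a synthetic batch (hypothetical shapes; assumes the function above can be imported from this file):

```python
import numpy as np

from inference_pb import get_segments  # defined above

# One hypothetical video: 12 padded frames of 2-dim features, 7 real frames,
# split into 5-frame segments.
video = np.zeros((1, 12, 2), dtype=np.float32)
num_frames = np.array([7])

segs = get_segments(video, num_frames, segment_size=5)
print(segs["video_batch"].shape)   # (2, 5, 2): segments starting at frames 0 and 5
print(segs["num_frames_batch"])    # [5 2]: the second segment has only 2 real frames
print(segs["video_segment_ids"])   # [[0 0] [0 5]]: (video index, segment start frame)
```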
yt8m/esot3ria/pbutil.py
0 → 100644
1 | +import tensorflow as tf | ||
2 | +import numpy | ||
3 | + | ||
4 | + | ||
5 | +def _make_bytes(int_array): | ||
6 | + if bytes == str: # Python2 | ||
7 | + return ''.join(map(chr, int_array)) | ||
8 | + else: | ||
9 | + return bytes(int_array) | ||
10 | + | ||
11 | + | ||
12 | +def quantize(features, min_quantized_value=-2.0, max_quantized_value=2.0): | ||
13 | + """Quantizes float32 `features` into string.""" | ||
14 | + assert features.dtype == 'float32' | ||
15 | + assert len(features.shape) == 1 # 1-D array | ||
16 | + features = numpy.clip(features, min_quantized_value, max_quantized_value) | ||
17 | + quantize_range = max_quantized_value - min_quantized_value | ||
18 | + features = (features - min_quantized_value) * (255.0 / quantize_range) | ||
19 | + features = [int(round(f)) for f in features] | ||
20 | + | ||
21 | + return _make_bytes(features) | ||
22 | + | ||
23 | + | ||
24 | +# for parse feature.pb | ||
25 | + | ||
26 | +contexts = { | ||
27 | + 'AUDIO/feature/dimensions': tf.io.FixedLenFeature([], tf.int64), | ||
28 | + 'AUDIO/feature/rate': tf.io.FixedLenFeature([], tf.float32), | ||
29 | + 'RGB/feature/dimensions': tf.io.FixedLenFeature([], tf.int64), | ||
30 | + 'RGB/feature/rate': tf.io.FixedLenFeature([], tf.float32), | ||
31 | + 'clip/data_path': tf.io.FixedLenFeature([], tf.string), | ||
32 | + 'clip/end/timestamp': tf.io.FixedLenFeature([], tf.int64), | ||
33 | + 'clip/start/timestamp': tf.io.FixedLenFeature([], tf.int64) | ||
34 | +} | ||
35 | + | ||
36 | +features = { | ||
37 | + 'AUDIO/feature/floats': tf.io.VarLenFeature(dtype=tf.float32), | ||
38 | + 'AUDIO/feature/timestamp': tf.io.VarLenFeature(tf.int64), | ||
39 | + 'RGB/feature/floats': tf.io.VarLenFeature(dtype=tf.float32), | ||
40 | + 'RGB/feature/timestamp': tf.io.VarLenFeature(tf.int64) | ||
41 | + | ||
42 | +} | ||
43 | + | ||
44 | + | ||
45 | +def parse_exmp(serial_exmp): | ||
46 | + _, sequence_parsed = tf.io.parse_single_sequence_example( | ||
47 | + serialized=serial_exmp, | ||
48 | + context_features=contexts, | ||
49 | + sequence_features=features) | ||
50 | + | ||
51 | + sequence_parsed = tf.contrib.learn.run_n(sequence_parsed)[0] | ||
52 | + | ||
53 | + audio = sequence_parsed['AUDIO/feature/floats'].values | ||
54 | + rgb = sequence_parsed['RGB/feature/floats'].values | ||
55 | + | ||
56 | + # print(audio.values) | ||
57 | + # print(type(audio.values)) | ||
58 | + | ||
59 | + # each 1-second frame has 128 audio floats and 1024 rgb floats | ||
60 | + audio_slices = [audio[128 * i: 128 * (i + 1)] for i in range(len(audio) // 128)] | ||
61 | + rgb_slices = [rgb[1024 * i: 1024 * (i + 1)] for i in range(len(rgb) // 1024)] | ||
62 | + | ||
63 | + byte_audio = [] | ||
64 | + byte_rgb = [] | ||
65 | + | ||
66 | + for seg in audio_slices: | ||
67 | + # audio_seg = quantize(seg) | ||
68 | + audio_seg = _make_bytes(seg) | ||
69 | + byte_audio.append(audio_seg) | ||
70 | + | ||
71 | + for seg in rgb_slices: | ||
72 | + # rgb_seg = quantize(seg) | ||
73 | + rgb_seg = _make_bytes(seg) | ||
74 | + byte_rgb.append(rgb_seg) | ||
75 | + | ||
76 | + return byte_audio, byte_rgb | ||
77 | + | ||
78 | + | ||
79 | +def make_exmp(id, audio, rgb): | ||
80 | + audio_features = [] | ||
81 | + rgb_features = [] | ||
82 | + | ||
83 | + for embedding in audio: | ||
84 | + embedding_feature = tf.train.Feature( | ||
85 | + bytes_list=tf.train.BytesList(value=[embedding])) | ||
86 | + audio_features.append(embedding_feature) | ||
87 | + | ||
88 | + for embedding in rgb: | ||
89 | + embedding_feature = tf.train.Feature( | ||
90 | + bytes_list=tf.train.BytesList(value=[embedding])) | ||
91 | + rgb_features.append(embedding_feature) | ||
92 | + | ||
93 | + # for construct yt8m data | ||
94 | + seq_exmp = tf.train.SequenceExample( | ||
95 | + context=tf.train.Features( | ||
96 | + feature={ | ||
97 | + 'id': tf.train.Feature(bytes_list=tf.train.BytesList( | ||
98 | + value=[id.encode('utf-8')])) | ||
99 | + }), | ||
100 | + feature_lists=tf.train.FeatureLists( | ||
101 | + feature_list={ | ||
102 | + 'audio': tf.train.FeatureList( | ||
103 | + feature=audio_features | ||
104 | + ), | ||
105 | + 'rgb': tf.train.FeatureList( | ||
106 | + feature=rgb_features | ||
107 | + ) | ||
108 | + }) | ||
109 | + ) | ||
110 | + serialized = seq_exmp.SerializeToString() | ||
111 | + return serialized | ||
112 | + | ||
113 | + | ||
114 | +def convert_pb(filename): | ||
115 | + sequence_example = open(filename, 'rb').read() | ||
116 | + | ||
117 | + audio, rgb = parse_exmp(sequence_example) | ||
118 | + tmp_example = make_exmp('video', audio, rgb) | ||
119 | + | ||
120 | + decoded = tf.train.SequenceExample.FromString(tmp_example) | ||
121 | + return decoded |
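A round-trip sketch for `make_exmp` (hypothetical feature values; assumes this file is importable as `pbutil`):

```python
import numpy as np
import tensorflow as tf

import pbutil  # the module added above

# Two hypothetical 1-second frames: 128 audio floats and 1024 rgb floats each,
# serialized as raw float32 bytes, mirroring what _make_bytes() produces for
# the float slices in parse_exmp().
audio = [np.random.rand(128).astype(np.float32).tobytes() for _ in range(2)]
rgb = [np.random.rand(1024).astype(np.float32).tobytes() for _ in range(2)]

serialized = pbutil.make_exmp('video', audio, rgb)
example = tf.train.SequenceExample.FromString(serialized)
print(example.context.feature['id'].bytes_list.value)          # [b'video']
print(len(example.feature_lists.feature_list['rgb'].feature))  # 2 frames
```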
yt8m/esot3ria/readpb.py
0 → 100644
1 | +import tensorflow as tf | ||
2 | +import numpy as np | ||
3 | + | ||
4 | +frame_lvl_record = "test0000.tfrecord" | ||
5 | + | ||
6 | +feat_rgb = [] | ||
7 | +feat_audio = [] | ||
8 | + | ||
9 | +for example in tf.python_io.tf_record_iterator(frame_lvl_record): | ||
10 | + tf_seq_example = tf.train.SequenceExample.FromString(example) | ||
11 | + test = tf_seq_example.SerializeToString() | ||
12 | + n_frames = len(tf_seq_example.feature_lists.feature_list['audio'].feature) | ||
13 | + sess = tf.InteractiveSession() | ||
14 | + rgb_frame = [] | ||
15 | + audio_frame = [] | ||
16 | + # iterate through frames | ||
17 | + for i in range(n_frames): | ||
18 | + rgb_frame.append(tf.cast(tf.decode_raw( | ||
19 | + tf_seq_example.feature_lists.feature_list['rgb'] | ||
20 | + .feature[i].bytes_list.value[0], tf.uint8) | ||
21 | + , tf.float32).eval()) | ||
22 | + audio_frame.append(tf.cast(tf.decode_raw( | ||
23 | + tf_seq_example.feature_lists.feature_list['audio'] | ||
24 | + .feature[i].bytes_list.value[0], tf.uint8) | ||
25 | + , tf.float32).eval()) | ||
26 | + | ||
27 | + sess.close() | ||
28 | + | ||
29 | + feat_audio.append(audio_frame) | ||
30 | + feat_rgb.append(rgb_frame) | ||
31 | + break | ||
32 | + | ||
33 | +print('The first video has %d frames' %len(feat_rgb[0])) | ||
\ No newline at end of file
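The frame features in the standard YT8M tfrecords are stored as 8-bit quantized values (hence the `tf.uint8` decode above). A small helper to map them back to approximate floats, mirroring the defaults of `quantize()` in pbutil.py, might look like this sketch:

```python
import numpy as np

def dequantize(feat_vector, max_quantized_value=2.0, min_quantized_value=-2.0):
    """Approximate inverse of pbutil.quantize(): maps uint8 values back to floats."""
    quantize_range = max_quantized_value - min_quantized_value
    scalar = quantize_range / 255.0
    return feat_vector.astype(np.float32) * scalar + min_quantized_value

# Hypothetical usage on one decoded frame from the loop above:
# rgb_floats = dequantize(np.array(rgb_frame[0]))
```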
yt8m/esot3ria/test0000.tfrecord
0 → 100644
No preview for this file type
yt8m/eval.py
0 → 100644
1 | +# Copyright 2016 Google Inc. All Rights Reserved. | ||
2 | +# | ||
3 | +# Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | +# you may not use this file except in compliance with the License. | ||
5 | +# You may obtain a copy of the License at | ||
6 | +# | ||
7 | +# http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | +# | ||
9 | +# Unless required by applicable law or agreed to in writing, software | ||
10 | +# distributed under the License is distributed on an "AS-IS" BASIS, | ||
11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | +# See the License for the specific language governing permissions and | ||
13 | +# limitations under the License. | ||
14 | +"""Binary for evaluating Tensorflow models on the YouTube-8M dataset.""" | ||
15 | + | ||
16 | +import json | ||
17 | +import os | ||
18 | +import time | ||
19 | + | ||
20 | +from absl import logging | ||
21 | +import eval_util | ||
22 | +import frame_level_models | ||
23 | +import losses | ||
24 | +import readers | ||
25 | +import tensorflow as tf | ||
26 | +from tensorflow import flags | ||
27 | +from tensorflow.python.lib.io import file_io | ||
28 | +import utils | ||
29 | +import video_level_models | ||
30 | + | ||
31 | +FLAGS = flags.FLAGS | ||
32 | + | ||
33 | +if __name__ == "__main__": | ||
34 | + # Dataset flags. | ||
35 | + flags.DEFINE_string( | ||
36 | + "train_dir", "/tmp/yt8m_model/", | ||
37 | + "The directory to load the model files from. " | ||
38 | + "The tensorboard metrics files are also saved to this " | ||
39 | + "directory.") | ||
40 | + flags.DEFINE_string( | ||
41 | + "eval_data_pattern", "", | ||
42 | + "File glob defining the evaluation dataset in tensorflow.SequenceExample " | ||
43 | + "format. The SequenceExamples are expected to have an 'rgb' byte array " | ||
44 | + "sequence feature as well as a 'labels' int64 context feature.") | ||
45 | + flags.DEFINE_bool( | ||
46 | + "segment_labels", False, | ||
47 | + "If set, then --eval_data_pattern must be frame-level features (but with" | ||
48 | + " segment_labels). Otherwise, --eval_data_pattern must be aggregated " | ||
49 | + "video-level features. The model must also be set appropriately (i.e. to " | ||
50 | + "read 3D batches VS 4D batches.") | ||
51 | + | ||
52 | + # Other flags. | ||
53 | + flags.DEFINE_integer("batch_size", 1024, | ||
54 | + "How many examples to process per batch.") | ||
55 | + flags.DEFINE_integer("num_readers", 8, | ||
56 | + "How many threads to use for reading input files.") | ||
57 | + flags.DEFINE_boolean("run_once", False, "Whether to run eval only once.") | ||
58 | + flags.DEFINE_integer("top_k", 20, "How many predictions to output per video.") | ||
59 | + | ||
60 | + | ||
61 | +def find_class_by_name(name, modules): | ||
62 | + """Searches the provided modules for the named class and returns it.""" | ||
63 | + modules = [getattr(module, name, None) for module in modules] | ||
64 | + return next(a for a in modules if a) | ||
65 | + | ||
66 | + | ||
67 | +def get_input_evaluation_tensors(reader, | ||
68 | + data_pattern, | ||
69 | + batch_size=1024, | ||
70 | + num_readers=1): | ||
71 | + """Creates the section of the graph which reads the evaluation data. | ||
72 | + | ||
73 | + Args: | ||
74 | + reader: A class which parses the training data. | ||
75 | + data_pattern: A 'glob' style path to the data files. | ||
76 | + batch_size: How many examples to process at a time. | ||
77 | + num_readers: How many I/O threads to use. | ||
78 | + | ||
79 | + Returns: | ||
80 | + A tuple containing the features tensor, labels tensor, and optionally a | ||
81 | + tensor containing the number of frames per video. The exact dimensions | ||
82 | + depend on the reader being used. | ||
83 | + | ||
84 | + Raises: | ||
85 | + IOError: If no files matching the given pattern were found. | ||
86 | + """ | ||
87 | + logging.info("Using batch size of %d for evaluation.", batch_size) | ||
88 | + with tf.name_scope("eval_input"): | ||
89 | + files = tf.io.gfile.glob(data_pattern) | ||
90 | + if not files: | ||
91 | + raise IOError("Unable to find the evaluation files.") | ||
92 | + logging.info("number of evaluation files: %d", len(files)) | ||
93 | + filename_queue = tf.train.string_input_producer(files, | ||
94 | + shuffle=False, | ||
95 | + num_epochs=1) | ||
96 | + eval_data = [ | ||
97 | + reader.prepare_reader(filename_queue) for _ in range(num_readers) | ||
98 | + ] | ||
99 | + return tf.train.batch_join(eval_data, | ||
100 | + batch_size=batch_size, | ||
101 | + capacity=3 * batch_size, | ||
102 | + allow_smaller_final_batch=True, | ||
103 | + enqueue_many=True) | ||
104 | + | ||
105 | + | ||
106 | +def build_graph(reader, | ||
107 | + model, | ||
108 | + eval_data_pattern, | ||
109 | + label_loss_fn, | ||
110 | + batch_size=1024, | ||
111 | + num_readers=1): | ||
112 | + """Creates the Tensorflow graph for evaluation. | ||
113 | + | ||
114 | + Args: | ||
115 | + reader: The data file reader. It should inherit from BaseReader. | ||
116 | + model: The core model (e.g. logistic or neural net). It should inherit from | ||
117 | + BaseModel. | ||
118 | + eval_data_pattern: glob path to the evaluation data files. | ||
119 | + label_loss_fn: What kind of loss to apply to the model. It should inherit | ||
120 | + from BaseLoss. | ||
121 | + batch_size: How many examples to process at a time. | ||
122 | + num_readers: How many threads to use for I/O operations. | ||
123 | + """ | ||
124 | + | ||
125 | + global_step = tf.Variable(0, trainable=False, name="global_step") | ||
126 | + input_data_dict = get_input_evaluation_tensors(reader, | ||
127 | + eval_data_pattern, | ||
128 | + batch_size=batch_size, | ||
129 | + num_readers=num_readers) | ||
130 | + video_id_batch = input_data_dict["video_ids"] | ||
131 | + model_input_raw = input_data_dict["video_matrix"] | ||
132 | + labels_batch = input_data_dict["labels"] | ||
133 | + num_frames = input_data_dict["num_frames"] | ||
134 | + tf.compat.v1.summary.histogram("model_input_raw", model_input_raw) | ||
135 | + | ||
136 | + feature_dim = len(model_input_raw.get_shape()) - 1 | ||
137 | + | ||
138 | + # Normalize input features. | ||
139 | + model_input = tf.nn.l2_normalize(model_input_raw, feature_dim) | ||
140 | + | ||
141 | + with tf.compat.v1.variable_scope("tower"): | ||
142 | + result = model.create_model(model_input, | ||
143 | + num_frames=num_frames, | ||
144 | + vocab_size=reader.num_classes, | ||
145 | + labels=labels_batch, | ||
146 | + is_training=False) | ||
147 | + | ||
148 | + predictions = result["predictions"] | ||
149 | + tf.compat.v1.summary.histogram("model_activations", predictions) | ||
150 | + if "loss" in result.keys(): | ||
151 | + label_loss = result["loss"] | ||
152 | + else: | ||
153 | + label_loss = label_loss_fn.calculate_loss(predictions, labels_batch) | ||
154 | + | ||
155 | + tf.compat.v1.add_to_collection("global_step", global_step) | ||
156 | + tf.compat.v1.add_to_collection("loss", label_loss) | ||
157 | + tf.compat.v1.add_to_collection("predictions", predictions) | ||
158 | + tf.compat.v1.add_to_collection("input_batch", model_input) | ||
159 | + tf.compat.v1.add_to_collection("input_batch_raw", model_input_raw) | ||
160 | + tf.compat.v1.add_to_collection("video_id_batch", video_id_batch) | ||
161 | + tf.compat.v1.add_to_collection("num_frames", num_frames) | ||
162 | + tf.compat.v1.add_to_collection("labels", tf.cast(labels_batch, tf.float32)) | ||
163 | + if FLAGS.segment_labels: | ||
164 | + tf.compat.v1.add_to_collection("label_weights", | ||
165 | + input_data_dict["label_weights"]) | ||
166 | + tf.compat.v1.add_to_collection("summary_op", tf.compat.v1.summary.merge_all()) | ||
167 | + | ||
168 | + | ||
169 | +def evaluation_loop(fetches, saver, summary_writer, evl_metrics, | ||
170 | + last_global_step_val): | ||
171 | + """Run the evaluation loop once. | ||
172 | + | ||
173 | + Args: | ||
174 | + fetches: a dict of tensors to be run within Session. | ||
175 | + saver: a tensorflow saver to restore the model. | ||
176 | + summary_writer: a tensorflow summary_writer | ||
177 | + evl_metrics: an EvaluationMetrics object. | ||
178 | + last_global_step_val: the global step used in the previous evaluation. | ||
179 | + | ||
180 | + Returns: | ||
181 | + The global_step used in the latest model. | ||
182 | + """ | ||
183 | + | ||
184 | + global_step_val = -1 | ||
185 | + with tf.Session(config=tf.ConfigProto(gpu_options=tf.GPUOptions( | ||
186 | + allow_growth=True))) as sess: | ||
187 | + latest_checkpoint = tf.train.latest_checkpoint(FLAGS.train_dir) | ||
188 | + if latest_checkpoint: | ||
189 | + logging.info("Loading checkpoint for eval: %s", latest_checkpoint) | ||
190 | + # Restores from checkpoint | ||
191 | + saver.restore(sess, latest_checkpoint) | ||
192 | + # Assuming model_checkpoint_path looks something like: | ||
193 | + # /my-favorite-path/yt8m_train/model.ckpt-0, extract global_step from it. | ||
194 | + global_step_val = os.path.basename(latest_checkpoint).split("-")[-1] | ||
195 | + | ||
196 | + # Save model | ||
197 | + if FLAGS.segment_labels: | ||
198 | + inference_model_name = "segment_inference_model" | ||
199 | + else: | ||
200 | + inference_model_name = "inference_model" | ||
201 | + saver.save( | ||
202 | + sess, | ||
203 | + os.path.join(FLAGS.train_dir, "inference_model", | ||
204 | + inference_model_name)) | ||
205 | + else: | ||
206 | + logging.info("No checkpoint file found.") | ||
207 | + return global_step_val | ||
208 | + | ||
209 | + if global_step_val == last_global_step_val: | ||
210 | + logging.info( | ||
211 | + "skip this checkpoint global_step_val=%s " | ||
212 | + "(same as the previous one).", global_step_val) | ||
213 | + return global_step_val | ||
214 | + | ||
215 | + sess.run([tf.local_variables_initializer()]) | ||
216 | + | ||
217 | + # Start the queue runners. | ||
218 | + coord = tf.train.Coordinator() | ||
219 | + try: | ||
220 | + threads = [] | ||
221 | + for qr in tf.compat.v1.get_collection(tf.GraphKeys.QUEUE_RUNNERS): | ||
222 | + threads.extend( | ||
223 | + qr.create_threads(sess, coord=coord, daemon=True, start=True)) | ||
224 | + logging.info("enter eval_once loop global_step_val = %s. ", | ||
225 | + global_step_val) | ||
226 | + | ||
227 | + evl_metrics.clear() | ||
228 | + | ||
229 | + examples_processed = 0 | ||
230 | + while not coord.should_stop(): | ||
231 | + batch_start_time = time.time() | ||
232 | + output_data_dict = sess.run(fetches) | ||
233 | + seconds_per_batch = time.time() - batch_start_time | ||
234 | + labels_val = output_data_dict["labels"] | ||
235 | + summary_val = output_data_dict["summary"] | ||
236 | + example_per_second = labels_val.shape[0] / seconds_per_batch | ||
237 | + examples_processed += labels_val.shape[0] | ||
238 | + | ||
239 | + predictions = output_data_dict["predictions"] | ||
240 | + if FLAGS.segment_labels: | ||
241 | + # This is a workaround to ignore the unrated labels. | ||
242 | + predictions *= output_data_dict["label_weights"] | ||
243 | + iteration_info_dict = evl_metrics.accumulate(predictions, labels_val, | ||
244 | + output_data_dict["loss"]) | ||
245 | + iteration_info_dict["examples_per_second"] = example_per_second | ||
246 | + | ||
247 | + iterinfo = utils.AddGlobalStepSummary( | ||
248 | + summary_writer, | ||
249 | + global_step_val, | ||
250 | + iteration_info_dict, | ||
251 | + summary_scope="SegEval" if FLAGS.segment_labels else "Eval") | ||
252 | + logging.info("examples_processed: %d | %s", examples_processed, | ||
253 | + iterinfo) | ||
254 | + | ||
255 | + except tf.errors.OutOfRangeError as e: | ||
256 | + logging.info( | ||
257 | + "Done with batched inference. Now calculating global performance " | ||
258 | + "metrics.") | ||
259 | + # calculate the metrics for the entire epoch | ||
260 | + epoch_info_dict = evl_metrics.get() | ||
261 | + epoch_info_dict["epoch_id"] = global_step_val | ||
262 | + | ||
263 | + summary_writer.add_summary(summary_val, global_step_val) | ||
264 | + epochinfo = utils.AddEpochSummary( | ||
265 | + summary_writer, | ||
266 | + global_step_val, | ||
267 | + epoch_info_dict, | ||
268 | + summary_scope="SegEval" if FLAGS.segment_labels else "Eval") | ||
269 | + logging.info(epochinfo) | ||
270 | + evl_metrics.clear() | ||
271 | + except Exception as e: # pylint: disable=broad-except | ||
272 | + logging.info("Unexpected exception: %s", str(e)) | ||
273 | + coord.request_stop(e) | ||
274 | + | ||
275 | + coord.request_stop() | ||
276 | + coord.join(threads, stop_grace_period_secs=10) | ||
277 | + logging.info("Total: examples_processed: %d", examples_processed) | ||
278 | + | ||
279 | + return global_step_val | ||
280 | + | ||
281 | + | ||
282 | +def evaluate(): | ||
283 | + """Starts main evaluation loop.""" | ||
284 | + tf.compat.v1.set_random_seed(0) # for reproducibility | ||
285 | + | ||
286 | + # Write json of flags | ||
287 | + model_flags_path = os.path.join(FLAGS.train_dir, "model_flags.json") | ||
288 | + if not file_io.file_exists(model_flags_path): | ||
289 | + raise IOError(("Cannot find file %s. Did you run train.py on the same " | ||
290 | + "--train_dir?") % model_flags_path) | ||
291 | + flags_dict = json.loads(file_io.FileIO(model_flags_path, mode="r").read()) | ||
292 | + | ||
293 | + with tf.Graph().as_default(): | ||
294 | + # convert feature_names and feature_sizes to lists of values | ||
295 | + feature_names, feature_sizes = utils.GetListOfFeatureNamesAndSizes( | ||
296 | + flags_dict["feature_names"], flags_dict["feature_sizes"]) | ||
297 | + | ||
298 | + if flags_dict["frame_features"]: | ||
299 | + reader = readers.YT8MFrameFeatureReader( | ||
300 | + feature_names=feature_names, | ||
301 | + feature_sizes=feature_sizes, | ||
302 | + segment_labels=FLAGS.segment_labels) | ||
303 | + else: | ||
304 | + reader = readers.YT8MAggregatedFeatureReader(feature_names=feature_names, | ||
305 | + feature_sizes=feature_sizes) | ||
306 | + | ||
307 | + model = find_class_by_name(flags_dict["model"], | ||
308 | + [frame_level_models, video_level_models])() | ||
309 | + label_loss_fn = find_class_by_name(flags_dict["label_loss"], [losses])() | ||
310 | + | ||
311 | + if not FLAGS.eval_data_pattern: | ||
312 | + raise IOError("'eval_data_pattern' was not specified. Nothing to " | ||
313 | + "evaluate.") | ||
314 | + | ||
315 | + build_graph(reader=reader, | ||
316 | + model=model, | ||
317 | + eval_data_pattern=FLAGS.eval_data_pattern, | ||
318 | + label_loss_fn=label_loss_fn, | ||
319 | + num_readers=FLAGS.num_readers, | ||
320 | + batch_size=FLAGS.batch_size) | ||
321 | + logging.info("built evaluation graph") | ||
322 | + | ||
323 | + # A dict of tensors to be run in Session. | ||
324 | + fetches = { | ||
325 | + "video_id": tf.compat.v1.get_collection("video_id_batch")[0], | ||
326 | + "predictions": tf.compat.v1.get_collection("predictions")[0], | ||
327 | + "labels": tf.compat.v1.get_collection("labels")[0], | ||
328 | + "loss": tf.compat.v1.get_collection("loss")[0], | ||
329 | + "summary": tf.compat.v1.get_collection("summary_op")[0] | ||
330 | + } | ||
331 | + if FLAGS.segment_labels: | ||
332 | + fetches["label_weights"] = tf.compat.v1.get_collection("label_weights")[0] | ||
333 | + | ||
334 | + saver = tf.compat.v1.train.Saver(tf.compat.v1.global_variables()) | ||
335 | + summary_writer = tf.compat.v1.summary.FileWriter( | ||
336 | + os.path.join(FLAGS.train_dir, "eval"), | ||
337 | + graph=tf.compat.v1.get_default_graph()) | ||
338 | + | ||
339 | + evl_metrics = eval_util.EvaluationMetrics(reader.num_classes, FLAGS.top_k, | ||
340 | + None) | ||
341 | + | ||
342 | + last_global_step_val = -1 | ||
343 | + while True: | ||
344 | + last_global_step_val = evaluation_loop(fetches, saver, summary_writer, | ||
345 | + evl_metrics, last_global_step_val) | ||
346 | + if FLAGS.run_once: | ||
347 | + break | ||
348 | + | ||
349 | + | ||
350 | +def main(unused_argv): | ||
351 | + logging.set_verbosity(logging.INFO) | ||
352 | + logging.info("tensorflow version: %s", tf.__version__) | ||
353 | + evaluate() | ||
354 | + | ||
355 | + | ||
356 | +if __name__ == "__main__": | ||
357 | + tf.compat.v1.app.run() |
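One detail worth noting from `evaluation_loop` above: the global step is recovered from the checkpoint filename rather than from the graph. A minimal illustration with a hypothetical path:

```python
import os

# Checkpoints are named like .../model.ckpt-<global_step>, so the step is the
# suffix after the last '-'.
latest_checkpoint = "/tmp/yt8m_model/model.ckpt-3000"  # hypothetical
global_step_val = os.path.basename(latest_checkpoint).split("-")[-1]
print(global_step_val)  # '3000'
```

The value stays a string and is compared against the previous evaluation's step so that the same checkpoint is not evaluated twice.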
yt8m/eval_util.py
0 → 100644
1 | +# Copyright 2016 Google Inc. All Rights Reserved. | ||
2 | +# | ||
3 | +# Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | +# you may not use this file except in compliance with the License. | ||
5 | +# You may obtain a copy of the License at | ||
6 | +# | ||
7 | +# http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | +# | ||
9 | +# Unless required by applicable law or agreed to in writing, software | ||
10 | +# distributed under the License is distributed on an "AS-IS" BASIS, | ||
11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | +# See the License for the specific language governing permissions and | ||
13 | +# limitations under the License. | ||
14 | +"""Provides functions to help with evaluating models.""" | ||
15 | +import average_precision_calculator as ap_calculator | ||
16 | +import mean_average_precision_calculator as map_calculator | ||
17 | +import numpy | ||
18 | +from tensorflow.python.platform import gfile | ||
19 | + | ||
20 | + | ||
21 | +def flatten(l): | ||
22 | + """Merges a list of lists into a single list. """ | ||
23 | + return [item for sublist in l for item in sublist] | ||
24 | + | ||
25 | + | ||
26 | +def calculate_hit_at_one(predictions, actuals): | ||
27 | + """Performs a local (numpy) calculation of the hit at one. | ||
28 | + | ||
29 | + Args: | ||
30 | + predictions: Matrix containing the outputs of the model. Dimensions are | ||
31 | + 'batch' x 'num_classes'. | ||
32 | + actuals: Matrix containing the ground truth labels. Dimensions are 'batch' x | ||
33 | + 'num_classes'. | ||
34 | + | ||
35 | + Returns: | ||
36 | + float: The average hit at one across the entire batch. | ||
37 | + """ | ||
38 | + top_prediction = numpy.argmax(predictions, 1) | ||
39 | + hits = actuals[numpy.arange(actuals.shape[0]), top_prediction] | ||
40 | + return numpy.average(hits) | ||
41 | + | ||
42 | + | ||
43 | +def calculate_precision_at_equal_recall_rate(predictions, actuals): | ||
44 | + """Performs a local (numpy) calculation of the PERR. | ||
45 | + | ||
46 | + Args: | ||
47 | + predictions: Matrix containing the outputs of the model. Dimensions are | ||
48 | + 'batch' x 'num_classes'. | ||
49 | + actuals: Matrix containing the ground truth labels. Dimensions are 'batch' x | ||
50 | + 'num_classes'. | ||
51 | + | ||
52 | + Returns: | ||
53 | + float: The average precision at equal recall rate across the entire batch. | ||
54 | + """ | ||
55 | + aggregated_precision = 0.0 | ||
56 | + num_videos = actuals.shape[0] | ||
57 | + for row in numpy.arange(num_videos): | ||
58 | + num_labels = int(numpy.sum(actuals[row])) | ||
59 | + top_indices = numpy.argpartition(predictions[row], | ||
60 | + -num_labels)[-num_labels:] | ||
61 | + item_precision = 0.0 | ||
62 | + for label_index in top_indices: | ||
63 | + if predictions[row][label_index] > 0: | ||
64 | + item_precision += actuals[row][label_index] | ||
65 | + item_precision /= top_indices.size | ||
66 | + aggregated_precision += item_precision | ||
67 | + aggregated_precision /= num_videos | ||
68 | + return aggregated_precision | ||
69 | + | ||
70 | + | ||
71 | +def calculate_gap(predictions, actuals, top_k=20): | ||
72 | + """Performs a local (numpy) calculation of the global average precision. | ||
73 | + | ||
74 | + Only the top_k predictions are taken for each of the videos. | ||
75 | + | ||
76 | + Args: | ||
77 | + predictions: Matrix containing the outputs of the model. Dimensions are | ||
78 | + 'batch' x 'num_classes'. | ||
79 | + actuals: Matrix containing the ground truth labels. Dimensions are 'batch' x | ||
80 | + 'num_classes'. | ||
81 | + top_k: How many predictions to use per video. | ||
82 | + | ||
83 | + Returns: | ||
84 | + float: The global average precision. | ||
85 | + """ | ||
86 | + gap_calculator = ap_calculator.AveragePrecisionCalculator() | ||
87 | + sparse_predictions, sparse_labels, num_positives = top_k_by_class( | ||
88 | + predictions, actuals, top_k) | ||
89 | + gap_calculator.accumulate(flatten(sparse_predictions), flatten(sparse_labels), | ||
90 | + sum(num_positives)) | ||
91 | + return gap_calculator.peek_ap_at_n() | ||
92 | + | ||
93 | + | ||
94 | +def top_k_by_class(predictions, labels, k=20): | ||
95 | + """Extracts the top k predictions for each video, sorted by class. | ||
96 | + | ||
97 | + Args: | ||
98 | + predictions: A numpy matrix containing the outputs of the model. Dimensions | ||
99 | + are 'batch' x 'num_classes'. | ||
100 | + k: the top k non-zero entries to preserve in each prediction. | ||
101 | + | ||
102 | + Returns: | ||
103 | + A tuple (predictions,labels, true_positives). 'predictions' and 'labels' | ||
104 | + are lists of lists of floats. 'true_positives' is a list of scalars. The | ||
105 | + length of the lists are equal to the number of classes. The entries in the | ||
106 | + predictions variable are probability predictions, and | ||
107 | + the corresponding entries in the labels variable are the ground truth for | ||
108 | + those predictions. The entries in 'true_positives' are the number of true | ||
109 | + positives for each class in the ground truth. | ||
110 | + | ||
111 | + Raises: | ||
112 | + ValueError: An error occurred when the k is not a positive integer. | ||
113 | + """ | ||
114 | + if k <= 0: | ||
115 | + raise ValueError("k must be a positive integer.") | ||
116 | + k = min(k, predictions.shape[1]) | ||
117 | + num_classes = predictions.shape[1] | ||
118 | + prediction_triplets = [] | ||
119 | + for video_index in range(predictions.shape[0]): | ||
120 | + prediction_triplets.extend( | ||
121 | + top_k_triplets(predictions[video_index], labels[video_index], k)) | ||
122 | + out_predictions = [[] for _ in range(num_classes)] | ||
123 | + out_labels = [[] for _ in range(num_classes)] | ||
124 | + for triplet in prediction_triplets: | ||
125 | + out_predictions[triplet[0]].append(triplet[1]) | ||
126 | + out_labels[triplet[0]].append(triplet[2]) | ||
127 | + out_true_positives = [numpy.sum(labels[:, i]) for i in range(num_classes)] | ||
128 | + | ||
129 | + return out_predictions, out_labels, out_true_positives | ||
130 | + | ||
131 | + | ||
132 | +def top_k_triplets(predictions, labels, k=20): | ||
133 | + """Get the top_k for a 1-d numpy array. | ||
134 | + | ||
135 | + Returns a sparse list of triplets in | ||
136 | + (class index, prediction, label) format. | ||
137 | + """ | ||
138 | + m = len(predictions) | ||
139 | + k = min(k, m) | ||
140 | + indices = numpy.argpartition(predictions, -k)[-k:] | ||
141 | + return [(index, predictions[index], labels[index]) for index in indices] | ||
142 | + | ||
143 | + | ||
144 | +class EvaluationMetrics(object): | ||
145 | + """A class to store the evaluation metrics.""" | ||
146 | + | ||
147 | + def __init__(self, num_class, top_k, top_n): | ||
148 | + """Construct an EvaluationMetrics object to store the evaluation metrics. | ||
149 | + | ||
150 | + Args: | ||
151 | + num_class: A positive integer specifying the number of classes. | ||
152 | + top_k: A positive integer specifying how many predictions are considered | ||
153 | + per video. | ||
154 | + top_n: A positive Integer specifying the average precision at n, or None | ||
155 | + to use all provided data points. | ||
156 | + | ||
157 | + Raises: | ||
158 | + ValueError: An error occurred when MeanAveragePrecisionCalculator cannot | ||
159 | + not be constructed. | ||
160 | + """ | ||
161 | + self.sum_hit_at_one = 0.0 | ||
162 | + self.sum_perr = 0.0 | ||
163 | + self.sum_loss = 0.0 | ||
164 | + self.map_calculator = map_calculator.MeanAveragePrecisionCalculator( | ||
165 | + num_class, top_n=top_n) | ||
166 | + self.global_ap_calculator = ap_calculator.AveragePrecisionCalculator() | ||
167 | + self.top_k = top_k | ||
168 | + self.num_examples = 0 | ||
169 | + | ||
170 | + def accumulate(self, predictions, labels, loss): | ||
171 | + """Accumulate the metrics calculated locally for this mini-batch. | ||
172 | + | ||
173 | + Args: | ||
174 | + predictions: A numpy matrix containing the outputs of the model. | ||
175 | + Dimensions are 'batch' x 'num_classes'. | ||
176 | + labels: A numpy matrix containing the ground truth labels. Dimensions are | ||
177 | + 'batch' x 'num_classes'. | ||
178 | + loss: A numpy array containing the loss for each sample. | ||
179 | + | ||
180 | + Returns: | ||
181 | + dictionary: A dictionary storing the metrics for the mini-batch. | ||
182 | + | ||
183 | + Raises: | ||
184 | + ValueError: An error occurred when the shape of predictions and actuals | ||
185 | + does not match. | ||
186 | + """ | ||
187 | + batch_size = labels.shape[0] | ||
188 | + mean_hit_at_one = calculate_hit_at_one(predictions, labels) | ||
189 | + mean_perr = calculate_precision_at_equal_recall_rate(predictions, labels) | ||
190 | + mean_loss = numpy.mean(loss) | ||
191 | + | ||
192 | + # Take the top 20 predictions. | ||
193 | + sparse_predictions, sparse_labels, num_positives = top_k_by_class( | ||
194 | + predictions, labels, self.top_k) | ||
195 | + self.map_calculator.accumulate(sparse_predictions, sparse_labels, | ||
196 | + num_positives) | ||
197 | + self.global_ap_calculator.accumulate(flatten(sparse_predictions), | ||
198 | + flatten(sparse_labels), | ||
199 | + sum(num_positives)) | ||
200 | + | ||
201 | + self.num_examples += batch_size | ||
202 | + self.sum_hit_at_one += mean_hit_at_one * batch_size | ||
203 | + self.sum_perr += mean_perr * batch_size | ||
204 | + self.sum_loss += mean_loss * batch_size | ||
205 | + | ||
206 | + return {"hit_at_one": mean_hit_at_one, "perr": mean_perr, "loss": mean_loss} | ||
207 | + | ||
208 | + def get(self): | ||
209 | + """Calculate the evaluation metrics for the whole epoch. | ||
210 | + | ||
211 | + Raises: | ||
212 | + ValueError: If no examples were accumulated. | ||
213 | + | ||
214 | + Returns: | ||
215 | + dictionary: a dictionary storing the evaluation metrics for the epoch. The | ||
216 | + dictionary has the fields: avg_hit_at_one, avg_perr, avg_loss, aps | ||
217 | + (default nan), and gap. | ||
218 | + """ | ||
219 | + if self.num_examples <= 0: | ||
220 | + raise ValueError("num_examples must be positive.") | ||
221 | + avg_hit_at_one = self.sum_hit_at_one / self.num_examples | ||
222 | + avg_perr = self.sum_perr / self.num_examples | ||
223 | + avg_loss = self.sum_loss / self.num_examples | ||
224 | + | ||
225 | + aps = self.map_calculator.peek_map_at_n() | ||
226 | + gap = self.global_ap_calculator.peek_ap_at_n() | ||
227 | + | ||
228 | + epoch_info_dict = { | ||
229 | + "avg_hit_at_one": avg_hit_at_one, | ||
230 | + "avg_perr": avg_perr, | ||
231 | + "avg_loss": avg_loss, | ||
232 | + "aps": aps, | ||
233 | + "gap": gap | ||
234 | + } | ||
235 | + return epoch_info_dict | ||
236 | + | ||
237 | + def clear(self): | ||
238 | + """Clear the evaluation metrics and reset the EvaluationMetrics object.""" | ||
239 | + self.sum_hit_at_one = 0.0 | ||
240 | + self.sum_perr = 0.0 | ||
241 | + self.sum_loss = 0.0 | ||
242 | + self.map_calculator.clear() | ||
243 | + self.global_ap_calculator.clear() | ||
244 | + self.num_examples = 0 |
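For orientation, here is a minimal sketch of how this accumulator might be driven from an evaluation loop. The class count, batch size, and random inputs below are illustrative placeholders, not values taken from this change.

```
import numpy as np

# Hypothetical toy setup: 4716 classes, two mini-batches of 8 videos each.
num_classes = 4716
metrics = EvaluationMetrics(num_class=num_classes, top_k=20, top_n=None)

for _ in range(2):
  predictions = np.random.rand(8, num_classes)                          # model scores in [0, 1]
  labels = (np.random.rand(8, num_classes) > 0.999).astype(np.float32)  # sparse multi-hot labels
  loss = np.random.rand(8)                                              # per-example loss values
  batch_info = metrics.accumulate(predictions, labels, loss)            # hit_at_one, perr, loss

epoch_info = metrics.get()   # avg_hit_at_one, avg_perr, avg_loss, aps, gap
metrics.clear()              # reset the accumulators before the next epoch
```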
yt8m/export_model.py
0 → 100644
1 | +# Copyright 2016 Google Inc. All Rights Reserved. | ||
2 | +# | ||
3 | +# Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | +# you may not use this file except in compliance with the License. | ||
5 | +# You may obtain a copy of the License at | ||
6 | +# | ||
7 | +# http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | +# | ||
9 | +# Unless required by applicable law or agreed to in writing, software | ||
10 | +# distributed under the License is distributed on an "AS-IS" BASIS, | ||
11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | +# See the License for the specific language governing permissions and | ||
13 | +# limitations under the License. | ||
14 | +"""Utilities to export a model for batch prediction.""" | ||
15 | + | ||
16 | +import tensorflow as tf | ||
17 | +import tensorflow.contrib.slim as slim | ||
18 | + | ||
19 | +from tensorflow.python.saved_model import builder as saved_model_builder | ||
20 | +from tensorflow.python.saved_model import signature_constants | ||
21 | +from tensorflow.python.saved_model import signature_def_utils | ||
22 | +from tensorflow.python.saved_model import tag_constants | ||
23 | +from tensorflow.python.saved_model import utils as saved_model_utils | ||
24 | + | ||
25 | +_TOP_PREDICTIONS_IN_OUTPUT = 20 | ||
26 | + | ||
27 | + | ||
28 | +class ModelExporter(object): | ||
29 | + | ||
30 | + def __init__(self, frame_features, model, reader): | ||
31 | + self.frame_features = frame_features | ||
32 | + self.model = model | ||
33 | + self.reader = reader | ||
34 | + | ||
35 | + with tf.Graph().as_default() as graph: | ||
36 | + self.inputs, self.outputs = self.build_inputs_and_outputs() | ||
37 | + self.graph = graph | ||
38 | + self.saver = tf.train.Saver(tf.trainable_variables(), sharded=True) | ||
39 | + | ||
40 | + def export_model(self, model_dir, global_step_val, last_checkpoint): | ||
41 | + """Exports the model so that it can be used for batch predictions.""" | ||
42 | + | ||
43 | + with self.graph.as_default(): | ||
44 | + with tf.Session() as session: | ||
45 | + session.run(tf.global_variables_initializer()) | ||
46 | + self.saver.restore(session, last_checkpoint) | ||
47 | + | ||
48 | + signature = signature_def_utils.build_signature_def( | ||
49 | + inputs=self.inputs, | ||
50 | + outputs=self.outputs, | ||
51 | + method_name=signature_constants.PREDICT_METHOD_NAME) | ||
52 | + | ||
53 | + signature_map = { | ||
54 | + signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature | ||
55 | + } | ||
56 | + | ||
57 | + model_builder = saved_model_builder.SavedModelBuilder(model_dir) | ||
58 | + model_builder.add_meta_graph_and_variables( | ||
59 | + session, | ||
60 | + tags=[tag_constants.SERVING], | ||
61 | + signature_def_map=signature_map, | ||
62 | + clear_devices=True) | ||
63 | + model_builder.save() | ||
64 | + | ||
65 | + def build_inputs_and_outputs(self): | ||
66 | + if self.frame_features: | ||
67 | + serialized_examples = tf.placeholder(tf.string, shape=(None,)) | ||
68 | + | ||
69 | + fn = lambda x: self.build_prediction_graph(x) | ||
70 | + video_id_output, top_indices_output, top_predictions_output = (tf.map_fn( | ||
71 | + fn, serialized_examples, dtype=(tf.string, tf.int32, tf.float32))) | ||
72 | + | ||
73 | + else: | ||
74 | + serialized_examples = tf.placeholder(tf.string, shape=(None,)) | ||
75 | + | ||
76 | + video_id_output, top_indices_output, top_predictions_output = ( | ||
77 | + self.build_prediction_graph(serialized_examples)) | ||
78 | + | ||
79 | + inputs = { | ||
80 | + "example_bytes": | ||
81 | + saved_model_utils.build_tensor_info(serialized_examples) | ||
82 | + } | ||
83 | + | ||
84 | + outputs = { | ||
85 | + "video_id": | ||
86 | + saved_model_utils.build_tensor_info(video_id_output), | ||
87 | + "class_indexes": | ||
88 | + saved_model_utils.build_tensor_info(top_indices_output), | ||
89 | + "predictions": | ||
90 | + saved_model_utils.build_tensor_info(top_predictions_output) | ||
91 | + } | ||
92 | + | ||
93 | + return inputs, outputs | ||
94 | + | ||
95 | + def build_prediction_graph(self, serialized_examples): | ||
96 | + input_data_dict = ( | ||
97 | + self.reader.prepare_serialized_examples(serialized_examples)) | ||
98 | + video_id = input_data_dict["video_ids"] | ||
99 | + model_input_raw = input_data_dict["video_matrix"] | ||
100 | + labels_batch = input_data_dict["labels"] | ||
101 | + num_frames = input_data_dict["num_frames"] | ||
102 | + | ||
103 | + feature_dim = len(model_input_raw.get_shape()) - 1 | ||
104 | + model_input = tf.nn.l2_normalize(model_input_raw, feature_dim) | ||
105 | + | ||
106 | + with tf.variable_scope("tower"): | ||
107 | + result = self.model.create_model(model_input, | ||
108 | + num_frames=num_frames, | ||
109 | + vocab_size=self.reader.num_classes, | ||
110 | + labels=labels_batch, | ||
111 | + is_training=False) | ||
112 | + | ||
113 | + for variable in slim.get_model_variables(): | ||
114 | + tf.summary.histogram(variable.op.name, variable) | ||
115 | + | ||
116 | + predictions = result["predictions"] | ||
117 | + | ||
118 | + top_predictions, top_indices = tf.nn.top_k(predictions, | ||
119 | + _TOP_PREDICTIONS_IN_OUTPUT) | ||
120 | + return video_id, top_indices, top_predictions |
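As a rough usage sketch (not part of this change), the exporter is constructed with a reader and a frame-level model and then pointed at an existing checkpoint. The feature names, feature sizes, and paths below are placeholders.

```
import frame_level_models
import readers

# Hypothetical wiring; feature names/sizes and paths are placeholders.
reader = readers.YT8MFrameFeatureReader(feature_names=["rgb", "audio"],
                                        feature_sizes=[1024, 128])
model = frame_level_models.FrameLevelLogisticModel()
exporter = ModelExporter(frame_features=True, model=model, reader=reader)
exporter.export_model(model_dir="/tmp/yt8m_export",
                      global_step_val=100000,
                      last_checkpoint="/tmp/yt8m_train/model.ckpt-100000")
```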
yt8m/export_model_mediapipe.py
0 → 100644
1 | +# Lint as: python3 | ||
2 | +import numpy as np | ||
3 | +import tensorflow as tf | ||
4 | +from tensorflow import app | ||
5 | +from tensorflow import flags | ||
6 | + | ||
7 | +FLAGS = flags.FLAGS | ||
8 | + | ||
9 | + | ||
10 | +def main(unused_argv): | ||
11 | + # Get the input tensor names to be replaced. | ||
12 | + tf.reset_default_graph() | ||
13 | + meta_graph_location = FLAGS.checkpoint_file + ".meta" | ||
14 | + tf.train.import_meta_graph(meta_graph_location, clear_devices=True) | ||
15 | + | ||
16 | + input_tensor_name = tf.get_collection("input_batch_raw")[0].name | ||
17 | + num_frames_tensor_name = tf.get_collection("num_frames")[0].name | ||
18 | + | ||
19 | + # Create output graph. | ||
20 | + saver = tf.train.Saver() | ||
21 | + tf.reset_default_graph() | ||
22 | + | ||
23 | + input_feature_placeholder = tf.placeholder( | ||
24 | + tf.float32, shape=(None, None, 1152)) | ||
25 | + num_frames_placeholder = tf.placeholder(tf.int32, shape=(None, 1)) | ||
26 | + | ||
27 | + saver = tf.train.import_meta_graph( | ||
28 | + meta_graph_location, | ||
29 | + input_map={ | ||
30 | + input_tensor_name: input_feature_placeholder, | ||
31 | + num_frames_tensor_name: tf.squeeze(num_frames_placeholder, axis=1) | ||
32 | + }, | ||
33 | + clear_devices=True) | ||
34 | + predictions_tensor = tf.get_collection("predictions")[0] | ||
35 | + | ||
36 | + with tf.Session() as sess: | ||
37 | + print("restoring variables from " + FLAGS.checkpoint_file) | ||
38 | + saver.restore(sess, FLAGS.checkpoint_file) | ||
39 | + tf.saved_model.simple_save( | ||
40 | + sess, | ||
41 | + FLAGS.output_dir, | ||
42 | + inputs={'rgb_and_audio': input_feature_placeholder, | ||
43 | + 'num_frames': num_frames_placeholder}, | ||
44 | + outputs={'predictions': predictions_tensor}) | ||
45 | + | ||
46 | + # Try running inference. | ||
47 | + predictions = sess.run( | ||
48 | + [predictions_tensor], | ||
49 | + feed_dict={ | ||
50 | + input_feature_placeholder: np.zeros((3, 7, 1152), dtype=np.float32), | ||
51 | + num_frames_placeholder: np.array([[7]], dtype=np.int32)}) | ||
52 | + print('Test inference:', predictions) | ||
53 | + | ||
54 | + print('Model saved to ', FLAGS.output_dir) | ||
55 | + | ||
56 | + | ||
57 | +if __name__ == '__main__': | ||
58 | + flags.DEFINE_string('checkpoint_file', None, 'Path to the checkpoint file.') | ||
59 | + flags.DEFINE_string('output_dir', None, 'SavedModel output directory.') | ||
60 | + app.run(main) |
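If it helps, the resulting SavedModel can be loaded back and exercised in a fresh process roughly like this. This is a sketch only: the export directory is a placeholder, and the signature keys simply mirror the `simple_save` call above.

```
import numpy as np
import tensorflow as tf

# Hypothetical sketch of consuming the exported SavedModel; the path is a placeholder.
export_dir = "/tmp/yt8m_savedmodel"
with tf.Session(graph=tf.Graph()) as sess:
  meta_graph = tf.saved_model.loader.load(
      sess, [tf.saved_model.tag_constants.SERVING], export_dir)
  signature = meta_graph.signature_def[
      tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
  features = sess.graph.get_tensor_by_name(signature.inputs["rgb_and_audio"].name)
  num_frames = sess.graph.get_tensor_by_name(signature.inputs["num_frames"].name)
  predictions = sess.graph.get_tensor_by_name(signature.outputs["predictions"].name)
  out = sess.run(predictions,
                 feed_dict={features: np.zeros((1, 7, 1152), dtype=np.float32),
                            num_frames: np.array([[7]], dtype=np.int32)})
  print(out.shape)  # (1, num_classes)
```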
yt8m/frame_level_models.py
0 → 100644
1 | +# Copyright 2016 Google Inc. All Rights Reserved. | ||
2 | +# | ||
3 | +# Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | +# you may not use this file except in compliance with the License. | ||
5 | +# You may obtain a copy of the License at | ||
6 | +# | ||
7 | +# http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | +# | ||
9 | +# Unless required by applicable law or agreed to in writing, software | ||
10 | +# distributed under the License is distributed on an "AS-IS" BASIS, | ||
11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | +# See the License for the specific language governing permissions and | ||
13 | +# limitations under the License. | ||
14 | +"""Contains a collection of models which operate on variable-length sequences.""" | ||
15 | +import math | ||
16 | + | ||
17 | +import model_utils as utils | ||
18 | +import models | ||
19 | +import tensorflow as tf | ||
20 | +from tensorflow import flags | ||
21 | +import tensorflow.contrib.slim as slim | ||
22 | +import video_level_models | ||
23 | + | ||
24 | +FLAGS = flags.FLAGS | ||
25 | +flags.DEFINE_integer("iterations", 30, "Number of frames per batch for DBoF.") | ||
26 | +flags.DEFINE_bool("dbof_add_batch_norm", True, | ||
27 | + "Adds batch normalization to the DBoF model.") | ||
28 | +flags.DEFINE_bool( | ||
29 | + "sample_random_frames", True, | ||
30 | + "If true, samples random frames (for frame level models). If false, a random " | ||
31 | + "sequence of frames is sampled instead.") | ||
32 | +flags.DEFINE_integer("dbof_cluster_size", 8192, | ||
33 | + "Number of units in the DBoF cluster layer.") | ||
34 | +flags.DEFINE_integer("dbof_hidden_size", 1024, | ||
35 | + "Number of units in the DBoF hidden layer.") | ||
36 | +flags.DEFINE_string( | ||
37 | + "dbof_pooling_method", "max", | ||
38 | + "The pooling method used in the DBoF cluster layer. " | ||
39 | + "Choices are 'average' and 'max'.") | ||
40 | +flags.DEFINE_string( | ||
41 | + "dbof_activation", "sigmoid", | ||
42 | + "The nonlinear activation method for cluster and hidden dense layer, e.g., " | ||
43 | + "sigmoid, relu6, etc.") | ||
44 | +flags.DEFINE_string( | ||
45 | + "video_level_classifier_model", "MoeModel", | ||
46 | + "Some Frame-Level models can be decomposed into a " | ||
47 | + "generalized pooling operation followed by a " | ||
48 | + "classifier layer.") | ||
49 | +flags.DEFINE_integer("lstm_cells", 1024, "Number of LSTM cells.") | ||
50 | +flags.DEFINE_integer("lstm_layers", 2, "Number of LSTM layers.") | ||
51 | + | ||
52 | + | ||
53 | +class FrameLevelLogisticModel(models.BaseModel): | ||
54 | + """Creates a logistic classifier over the aggregated frame-level features.""" | ||
55 | + | ||
56 | + def create_model(self, model_input, vocab_size, num_frames, **unused_params): | ||
57 | + """See base class. | ||
58 | + | ||
59 | + This class is intended to be an example for implementors of frame level | ||
60 | + models. If you want to train a model over averaged features it is more | ||
61 | + efficient to average them beforehand rather than on the fly. | ||
62 | + | ||
63 | + Args: | ||
64 | + model_input: A 'batch_size' x 'max_frames' x 'num_features' matrix of | ||
65 | + input features. | ||
66 | + vocab_size: The number of classes in the dataset. | ||
67 | + num_frames: A vector of length 'batch' which indicates the number of | ||
68 | + frames for each video (before padding). | ||
69 | + | ||
70 | + Returns: | ||
71 | + A dictionary with a tensor containing the probability predictions of the | ||
72 | + model in the 'predictions' key. The dimensions of the tensor are | ||
73 | + 'batch_size' x 'num_classes'. | ||
74 | + """ | ||
75 | + num_frames = tf.cast(tf.expand_dims(num_frames, 1), tf.float32) | ||
76 | + feature_size = model_input.get_shape().as_list()[2] | ||
77 | + | ||
78 | + denominators = tf.reshape(tf.tile(num_frames, [1, feature_size]), | ||
79 | + [-1, feature_size]) | ||
80 | + avg_pooled = tf.reduce_sum(model_input, axis=[1]) / denominators | ||
81 | + | ||
82 | + output = slim.fully_connected(avg_pooled, | ||
83 | + vocab_size, | ||
84 | + activation_fn=tf.nn.sigmoid, | ||
85 | + weights_regularizer=slim.l2_regularizer(1e-8)) | ||
86 | + return {"predictions": output} | ||
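The averaging above divides the per-video frame sum by the true frame count rather than by `max_frames`, which assumes padded frames are zero-valued. A tiny numpy sketch of the same computation, with toy shapes that are not taken from the code:

```
import numpy as np

# Toy shapes: 2 videos, 4 padded frames, 3-dimensional features.
model_input = np.random.rand(2, 4, 3).astype(np.float32)
model_input[1, 2:] = 0.0                        # video 1 has only 2 real frames; rest is padding
num_frames = np.array([[4.0], [2.0]], dtype=np.float32)

denominators = np.tile(num_frames, [1, 3])             # batch x feature_size
avg_pooled = model_input.sum(axis=1) / denominators    # mean over the real frames only
```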
87 | + | ||
88 | + | ||
89 | +class DbofModel(models.BaseModel): | ||
90 | + """Creates a Deep Bag of Frames model. | ||
91 | + | ||
92 | + The model projects the features for each frame into a higher dimensional | ||
93 | + 'clustering' space, pools across frames in that space, and then | ||
94 | + uses a configurable video-level model to classify the now aggregated features. | ||
95 | + | ||
96 | + The model will randomly sample either frames or sequences of frames during | ||
97 | + training to speed up convergence. | ||
98 | + """ | ||
99 | + | ||
100 | + ACT_FN_MAP = { | ||
101 | + "sigmoid": tf.nn.sigmoid, | ||
102 | + "relu6": tf.nn.relu6, | ||
103 | + } | ||
104 | + | ||
105 | + def create_model(self, | ||
106 | + model_input, | ||
107 | + vocab_size, | ||
108 | + num_frames, | ||
109 | + iterations=None, | ||
110 | + add_batch_norm=None, | ||
111 | + sample_random_frames=None, | ||
112 | + cluster_size=None, | ||
113 | + hidden_size=None, | ||
114 | + is_training=True, | ||
115 | + **unused_params): | ||
116 | + """See base class. | ||
117 | + | ||
118 | + Args: | ||
119 | + model_input: A 'batch_size' x 'max_frames' x 'num_features' matrix of | ||
120 | + input features. | ||
121 | + vocab_size: The number of classes in the dataset. | ||
122 | + num_frames: A vector of length 'batch' which indicates the number of | ||
123 | + frames for each video (before padding). | ||
124 | + iterations: the number of frames to be sampled. | ||
125 | + add_batch_norm: whether to add batch norm during training. | ||
126 | + sample_random_frames: whether to sample random frames or random sequences. | ||
127 | + cluster_size: the output neuron number of the cluster layer. | ||
128 | + hidden_size: the output neuron number of the hidden layer. | ||
129 | + is_training: whether to build the graph in training mode. | ||
130 | + | ||
131 | + Returns: | ||
132 | + A dictionary with a tensor containing the probability predictions of the | ||
133 | + model in the 'predictions' key. The dimensions of the tensor are | ||
134 | + 'batch_size' x 'num_classes'. | ||
135 | + """ | ||
136 | + iterations = iterations or FLAGS.iterations | ||
137 | + add_batch_norm = add_batch_norm or FLAGS.dbof_add_batch_norm | ||
138 | + random_frames = sample_random_frames or FLAGS.sample_random_frames | ||
139 | + cluster_size = cluster_size or FLAGS.dbof_cluster_size | ||
140 | + hidden1_size = hidden_size or FLAGS.dbof_hidden_size | ||
141 | + act_fn = self.ACT_FN_MAP.get(FLAGS.dbof_activation) | ||
142 | + assert act_fn is not None, ("dbof_activation is not valid: %s." % | ||
143 | + FLAGS.dbof_activation) | ||
144 | + | ||
145 | + num_frames = tf.cast(tf.expand_dims(num_frames, 1), tf.float32) | ||
146 | + if random_frames: | ||
147 | + model_input = utils.SampleRandomFrames(model_input, num_frames, | ||
148 | + iterations) | ||
149 | + else: | ||
150 | + model_input = utils.SampleRandomSequence(model_input, num_frames, | ||
151 | + iterations) | ||
152 | + max_frames = model_input.get_shape().as_list()[1] | ||
153 | + feature_size = model_input.get_shape().as_list()[2] | ||
154 | + reshaped_input = tf.reshape(model_input, [-1, feature_size]) | ||
155 | + tf.compat.v1.summary.histogram("input_hist", reshaped_input) | ||
156 | + | ||
157 | + if add_batch_norm: | ||
158 | + reshaped_input = slim.batch_norm(reshaped_input, | ||
159 | + center=True, | ||
160 | + scale=True, | ||
161 | + is_training=is_training, | ||
162 | + scope="input_bn") | ||
163 | + | ||
164 | + cluster_weights = tf.compat.v1.get_variable( | ||
165 | + "cluster_weights", [feature_size, cluster_size], | ||
166 | + initializer=tf.random_normal_initializer(stddev=1 / | ||
167 | + math.sqrt(feature_size))) | ||
168 | + tf.compat.v1.summary.histogram("cluster_weights", cluster_weights) | ||
169 | + activation = tf.matmul(reshaped_input, cluster_weights) | ||
170 | + if add_batch_norm: | ||
171 | + activation = slim.batch_norm(activation, | ||
172 | + center=True, | ||
173 | + scale=True, | ||
174 | + is_training=is_training, | ||
175 | + scope="cluster_bn") | ||
176 | + else: | ||
177 | + cluster_biases = tf.compat.v1.get_variable( | ||
178 | + "cluster_biases", [cluster_size], | ||
179 | + initializer=tf.random_normal_initializer(stddev=1 / | ||
180 | + math.sqrt(feature_size))) | ||
181 | + tf.compat.v1.summary.histogram("cluster_biases", cluster_biases) | ||
182 | + activation += cluster_biases | ||
183 | + activation = act_fn(activation) | ||
184 | + tf.compat.v1.summary.histogram("cluster_output", activation) | ||
185 | + | ||
186 | + activation = tf.reshape(activation, [-1, max_frames, cluster_size]) | ||
187 | + activation = utils.FramePooling(activation, FLAGS.dbof_pooling_method) | ||
188 | + | ||
189 | + hidden1_weights = tf.compat.v1.get_variable( | ||
190 | + "hidden1_weights", [cluster_size, hidden1_size], | ||
191 | + initializer=tf.random_normal_initializer(stddev=1 / | ||
192 | + math.sqrt(cluster_size))) | ||
193 | + tf.compat.v1.summary.histogram("hidden1_weights", hidden1_weights) | ||
194 | + activation = tf.matmul(activation, hidden1_weights) | ||
195 | + if add_batch_norm: | ||
196 | + activation = slim.batch_norm(activation, | ||
197 | + center=True, | ||
198 | + scale=True, | ||
199 | + is_training=is_training, | ||
200 | + scope="hidden1_bn") | ||
201 | + else: | ||
202 | + hidden1_biases = tf.compat.v1.get_variable( | ||
203 | + "hidden1_biases", [hidden1_size], | ||
204 | + initializer=tf.random_normal_initializer(stddev=0.01)) | ||
205 | + tf.compat.v1.summary.histogram("hidden1_biases", hidden1_biases) | ||
206 | + activation += hidden1_biases | ||
207 | + activation = act_fn(activation) | ||
208 | + tf.compat.v1.summary.histogram("hidden1_output", activation) | ||
209 | + | ||
210 | + aggregated_model = getattr(video_level_models, | ||
211 | + FLAGS.video_level_classifier_model) | ||
212 | + return aggregated_model().create_model(model_input=activation, | ||
213 | + vocab_size=vocab_size, | ||
214 | + **unused_params) | ||
215 | + | ||
216 | + | ||
217 | +class LstmModel(models.BaseModel): | ||
218 | + """Creates a model which uses a stack of LSTMs to represent the video.""" | ||
219 | + | ||
220 | + def create_model(self, model_input, vocab_size, num_frames, **unused_params): | ||
221 | + """See base class. | ||
222 | + | ||
223 | + Args: | ||
224 | + model_input: A 'batch_size' x 'max_frames' x 'num_features' matrix of | ||
225 | + input features. | ||
226 | + vocab_size: The number of classes in the dataset. | ||
227 | + num_frames: A vector of length 'batch' which indicates the number of | ||
228 | + frames for each video (before padding). | ||
229 | + | ||
230 | + Returns: | ||
231 | + A dictionary with a tensor containing the probability predictions of the | ||
232 | + model in the 'predictions' key. The dimensions of the tensor are | ||
233 | + 'batch_size' x 'num_classes'. | ||
234 | + """ | ||
235 | + lstm_size = FLAGS.lstm_cells | ||
236 | + number_of_layers = FLAGS.lstm_layers | ||
237 | + | ||
238 | + stacked_lstm = tf.contrib.rnn.MultiRNNCell([ | ||
239 | + tf.contrib.rnn.BasicLSTMCell(lstm_size, forget_bias=1.0) | ||
240 | + for _ in range(number_of_layers) | ||
241 | + ]) | ||
242 | + | ||
243 | + _, state = tf.nn.dynamic_rnn(stacked_lstm, | ||
244 | + model_input, | ||
245 | + sequence_length=num_frames, | ||
246 | + dtype=tf.float32) | ||
247 | + | ||
248 | + aggregated_model = getattr(video_level_models, | ||
249 | + FLAGS.video_level_classifier_model) | ||
250 | + | ||
251 | + return aggregated_model().create_model(model_input=state[-1].h, | ||
252 | + vocab_size=vocab_size, | ||
253 | + **unused_params) |
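A minimal graph-construction sketch for these frame-level models follows; the shapes and vocabulary size are illustrative and nothing here is taken from the training pipeline.

```
import tensorflow as tf

# Hypothetical smoke test: build the logistic frame-level model over padded frames.
model = FrameLevelLogisticModel()
model_input = tf.placeholder(tf.float32, shape=(None, 300, 1152))  # batch x max_frames x features
num_frames = tf.placeholder(tf.int32, shape=(None,))               # real frame count per video
outputs = model.create_model(model_input, vocab_size=3862, num_frames=num_frames)
predictions = outputs["predictions"]                                # batch x vocab_size sigmoid scores
```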
yt8m/inference.py
0 → 100644
1 | +# Copyright 2017 Google Inc. All Rights Reserved. | ||
2 | +# | ||
3 | +# Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | +# you may not use this file except in compliance with the License. | ||
5 | +# You may obtain a copy of the License at | ||
6 | +# | ||
7 | +# http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | +# | ||
9 | +# Unless required by applicable law or agreed to in writing, software | ||
10 | +# distributed under the License is distributed on an "AS-IS" BASIS, | ||
11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | +# See the License for the specific language governing permissions and | ||
13 | +# limitations under the License. | ||
14 | +"""Binary for generating predictions over a set of videos.""" | ||
15 | + | ||
16 | +from __future__ import print_function | ||
17 | + | ||
18 | +import glob | ||
19 | +import heapq | ||
20 | +import json | ||
21 | +import os | ||
22 | +import tarfile | ||
23 | +import tempfile | ||
24 | +import time | ||
25 | +import numpy as np | ||
26 | + | ||
27 | +import readers | ||
28 | +from six.moves import urllib | ||
29 | +import tensorflow as tf | ||
30 | +from tensorflow import app | ||
31 | +from tensorflow import flags | ||
32 | +from tensorflow import gfile | ||
33 | +from tensorflow import logging | ||
34 | +from tensorflow.python.lib.io import file_io | ||
35 | +import utils | ||
36 | + | ||
37 | +FLAGS = flags.FLAGS | ||
38 | + | ||
39 | +if __name__ == "__main__": | ||
40 | + # Input | ||
41 | + flags.DEFINE_string( | ||
42 | + "train_dir", "", "The directory to load the model files from. We assume " | ||
43 | + "that you have already run eval.py onto this, such that " | ||
44 | + "inference_model.* files already exist.") | ||
45 | + flags.DEFINE_string( | ||
46 | + "input_data_pattern", "", | ||
47 | + "File glob defining the evaluation dataset in tensorflow.SequenceExample " | ||
48 | + "format. The SequenceExamples are expected to have an 'rgb' byte array " | ||
49 | + "sequence feature as well as a 'labels' int64 context feature.") | ||
50 | + flags.DEFINE_string( | ||
51 | + "input_model_tgz", "", | ||
52 | + "If given, must be path to a .tgz file that was written " | ||
53 | + "by this binary using flag --output_model_tgz. In this " | ||
54 | + "case, the .tgz file will be untarred to " | ||
55 | + "--untar_model_dir and the model will be used for " | ||
56 | + "inference.") | ||
57 | + flags.DEFINE_string( | ||
58 | + "untar_model_dir", "/tmp/yt8m-model", | ||
59 | + "If --input_model_tgz is given, then this directory will " | ||
60 | + "be created and the contents of the .tgz file will be " | ||
61 | + "untarred here.") | ||
62 | + flags.DEFINE_bool( | ||
63 | + "segment_labels", False, | ||
64 | + "If set, then --input_data_pattern must be frame-level features (but with" | ||
65 | + " segment_labels). Otherwise, --input_data_pattern must be aggregated " | ||
66 | + "video-level features. The model must also be set appropriately (i.e. to " | ||
67 | + "read 3D batches vs. 4D batches).") | ||
68 | + flags.DEFINE_integer("segment_max_pred", 100000, | ||
69 | + "Limit total number of segment outputs per entity.") | ||
70 | + flags.DEFINE_string( | ||
71 | + "segment_label_ids_file", | ||
72 | + "https://raw.githubusercontent.com/google/youtube-8m/master/segment_label_ids.csv", | ||
73 | + "The file that contains the segment label ids.") | ||
74 | + | ||
75 | + # Output | ||
76 | + flags.DEFINE_string("output_file", "", "The file to save the predictions to.") | ||
77 | + flags.DEFINE_string( | ||
78 | + "output_model_tgz", "", | ||
79 | + "If given, should be a filename with a .tgz extension, " | ||
80 | + "the model graph and checkpoint will be bundled in this " | ||
81 | + "gzip tar. This file can be uploaded to Kaggle for the " | ||
82 | + "top 10 participants.") | ||
83 | + flags.DEFINE_integer("top_k", 20, "How many predictions to output per video.") | ||
84 | + | ||
85 | + # Other flags. | ||
86 | + flags.DEFINE_integer("batch_size", 512, | ||
87 | + "How many examples to process per batch.") | ||
88 | + flags.DEFINE_integer("num_readers", 1, | ||
89 | + "How many threads to use for reading input files.") | ||
90 | + | ||
91 | + | ||
92 | +def format_lines(video_ids, predictions, top_k, whitelisted_cls_mask=None): | ||
93 | + """Create an information line for the submission file.""" | ||
94 | + batch_size = len(video_ids) | ||
95 | + for video_index in range(batch_size): | ||
96 | + video_prediction = predictions[video_index] | ||
97 | + if whitelisted_cls_mask is not None: | ||
98 | + # Whitelist classes. | ||
99 | + video_prediction *= whitelisted_cls_mask | ||
100 | + top_indices = np.argpartition(video_prediction, -top_k)[-top_k:] | ||
101 | + line = [(class_index, predictions[video_index][class_index]) | ||
102 | + for class_index in top_indices] | ||
103 | + line = sorted(line, key=lambda p: -p[1]) | ||
104 | + yield (video_ids[video_index] + "," + | ||
105 | + " ".join("%i %g" % (label, score) for (label, score) in line) + | ||
106 | + "\n").encode("utf8") | ||
107 | + | ||
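To make the line format concrete, here is a tiny toy call; the video id and scores are invented.

```
import numpy as np

# Toy check of the submission line format.
video_ids = ["abc123"]
predictions = np.array([[0.1, 0.9, 0.3]])
for line in format_lines(video_ids, predictions, top_k=2):
  print(line)   # b'abc123,1 0.9 2 0.3\n'
```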
108 | + | ||
109 | +def get_input_data_tensors(reader, data_pattern, batch_size, num_readers=1): | ||
110 | + """Creates the section of the graph which reads the input data. | ||
111 | + | ||
112 | + Args: | ||
113 | + reader: A class which parses the input data. | ||
114 | + data_pattern: A 'glob' style path to the data files. | ||
115 | + batch_size: How many examples to process at a time. | ||
116 | + num_readers: How many I/O threads to use. | ||
117 | + | ||
118 | + Returns: | ||
119 | + A tuple containing the features tensor, labels tensor, and optionally a | ||
120 | + tensor containing the number of frames per video. The exact dimensions | ||
121 | + depend on the reader being used. | ||
122 | + | ||
123 | + Raises: | ||
124 | + IOError: If no files matching the given pattern were found. | ||
125 | + """ | ||
126 | + with tf.name_scope("input"): | ||
127 | + files = gfile.Glob(data_pattern) | ||
128 | + if not files: | ||
129 | + raise IOError("Unable to find input files. data_pattern='" + | ||
130 | + data_pattern + "'") | ||
131 | + logging.info("number of input files: " + str(len(files))) | ||
132 | + filename_queue = tf.train.string_input_producer(files, | ||
133 | + num_epochs=1, | ||
134 | + shuffle=False) | ||
135 | + examples_and_labels = [ | ||
136 | + reader.prepare_reader(filename_queue) for _ in range(num_readers) | ||
137 | + ] | ||
138 | + | ||
139 | + input_data_dict = (tf.train.batch_join(examples_and_labels, | ||
140 | + batch_size=batch_size, | ||
141 | + allow_smaller_final_batch=True, | ||
142 | + enqueue_many=True)) | ||
143 | + video_id_batch = input_data_dict["video_ids"] | ||
144 | + video_batch = input_data_dict["video_matrix"] | ||
145 | + num_frames_batch = input_data_dict["num_frames"] | ||
146 | + return video_id_batch, video_batch, num_frames_batch | ||
147 | + | ||
148 | + | ||
149 | +def get_segments(batch_video_mtx, batch_num_frames, segment_size): | ||
150 | + """Get segment-level inputs from frame-level features.""" | ||
151 | + video_batch_size = batch_video_mtx.shape[0] | ||
152 | + max_frame = batch_video_mtx.shape[1] | ||
153 | + feature_dim = batch_video_mtx.shape[-1] | ||
154 | + padded_segment_sizes = (batch_num_frames + segment_size - 1) // segment_size | ||
155 | + padded_segment_sizes *= segment_size | ||
156 | + segment_mask = ( | ||
157 | + 0 < (padded_segment_sizes[:, np.newaxis] - np.arange(0, max_frame))) | ||
158 | + | ||
159 | + # Segment bags. | ||
160 | + frame_bags = batch_video_mtx.reshape((-1, feature_dim)) | ||
161 | + segment_frames = frame_bags[segment_mask.reshape(-1)].reshape( | ||
162 | + (-1, segment_size, feature_dim)) | ||
163 | + | ||
164 | + # Segment num frames. | ||
165 | + segment_start_times = np.arange(0, max_frame, segment_size) | ||
166 | + num_segments = batch_num_frames[:, np.newaxis] - segment_start_times | ||
167 | + num_segment_bags = num_segments.reshape((-1)) | ||
168 | + valid_segment_mask = num_segment_bags > 0 | ||
169 | + segment_num_frames = num_segment_bags[valid_segment_mask] | ||
170 | + segment_num_frames[segment_num_frames > segment_size] = segment_size | ||
171 | + | ||
172 | + max_segment_num = (max_frame + segment_size - 1) // segment_size | ||
173 | + video_idxs = np.tile( | ||
174 | + np.arange(0, video_batch_size)[:, np.newaxis], [1, max_segment_num]) | ||
175 | + segment_idxs = np.tile(segment_start_times, [video_batch_size, 1]) | ||
176 | + idx_bags = np.stack([video_idxs, segment_idxs], axis=-1).reshape((-1, 2)) | ||
177 | + video_segment_ids = idx_bags[valid_segment_mask] | ||
178 | + | ||
179 | + return { | ||
180 | + "video_batch": segment_frames, | ||
181 | + "num_frames_batch": segment_num_frames, | ||
182 | + "video_segment_ids": video_segment_ids | ||
183 | + } | ||
184 | + | ||
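For intuition about the returned structure, consider a toy call; the shapes and frame counts are invented for the example.

```
import numpy as np

# Toy illustration: 1 video with 10 padded frames of which 7 are real, 5-frame segments.
batch_video_mtx = np.zeros((1, 10, 4), dtype=np.float32)
batch_num_frames = np.array([7])
out = get_segments(batch_video_mtx, batch_num_frames, segment_size=5)

print(out["video_batch"].shape)    # (2, 5, 4): two 5-frame segments
print(out["num_frames_batch"])     # [5 2]: the second segment has only 2 valid frames
print(out["video_segment_ids"])    # [[0 0], [0 5]]: (video index, segment start frame)
```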
185 | + | ||
186 | +def inference(reader, train_dir, data_pattern, out_file_location, batch_size, | ||
187 | + top_k): | ||
188 | + """Inference function.""" | ||
189 | + with tf.Session(config=tf.ConfigProto( | ||
190 | + allow_soft_placement=True)) as sess, gfile.Open(out_file_location, | ||
191 | + "w+") as out_file: | ||
192 | + video_id_batch, video_batch, num_frames_batch = get_input_data_tensors( | ||
193 | + reader, data_pattern, batch_size) | ||
194 | + inference_model_name = "segment_inference_model" if FLAGS.segment_labels else "inference_model" | ||
195 | + checkpoint_file = os.path.join(train_dir, "inference_model", | ||
196 | + inference_model_name) | ||
197 | + if not gfile.Exists(checkpoint_file + ".meta"): | ||
198 | + raise IOError("Cannot find %s. Did you run eval.py?" % checkpoint_file) | ||
199 | + meta_graph_location = checkpoint_file + ".meta" | ||
200 | + logging.info("loading meta-graph: " + meta_graph_location) | ||
201 | + | ||
202 | + if FLAGS.output_model_tgz: | ||
203 | + with tarfile.open(FLAGS.output_model_tgz, "w:gz") as tar: | ||
204 | + for model_file in glob.glob(checkpoint_file + ".*"): | ||
205 | + tar.add(model_file, arcname=os.path.basename(model_file)) | ||
206 | + tar.add(os.path.join(train_dir, "model_flags.json"), | ||
207 | + arcname="model_flags.json") | ||
208 | + print("Tarred model onto " + FLAGS.output_model_tgz) | ||
209 | + with tf.device("/cpu:0"): | ||
210 | + saver = tf.train.import_meta_graph(meta_graph_location, | ||
211 | + clear_devices=True) | ||
212 | + logging.info("restoring variables from " + checkpoint_file) | ||
213 | + saver.restore(sess, checkpoint_file) | ||
214 | + input_tensor = tf.get_collection("input_batch_raw")[0] | ||
215 | + num_frames_tensor = tf.get_collection("num_frames")[0] | ||
216 | + predictions_tensor = tf.get_collection("predictions")[0] | ||
217 | + | ||
218 | + # Workaround for num_epochs issue. | ||
219 | + def set_up_init_ops(variables): | ||
220 | + init_op_list = [] | ||
221 | + for variable in list(variables): | ||
222 | + if "train_input" in variable.name: | ||
223 | + init_op_list.append(tf.assign(variable, 1)) | ||
224 | + variables.remove(variable) | ||
225 | + init_op_list.append(tf.variables_initializer(variables)) | ||
226 | + return init_op_list | ||
227 | + | ||
228 | + sess.run( | ||
229 | + set_up_init_ops(tf.get_collection_ref(tf.GraphKeys.LOCAL_VARIABLES))) | ||
230 | + | ||
231 | + coord = tf.train.Coordinator() | ||
232 | + threads = tf.train.start_queue_runners(sess=sess, coord=coord) | ||
233 | + num_examples_processed = 0 | ||
234 | + start_time = time.time() | ||
235 | + whitelisted_cls_mask = None | ||
236 | + if FLAGS.segment_labels: | ||
237 | + final_out_file = out_file | ||
238 | + out_file = tempfile.NamedTemporaryFile() | ||
239 | + logging.info( | ||
240 | + "Segment temp prediction output will be written to temp file: %s", | ||
241 | + out_file.name) | ||
242 | + if FLAGS.segment_label_ids_file: | ||
243 | + whitelisted_cls_mask = np.zeros((predictions_tensor.get_shape()[-1],), | ||
244 | + dtype=np.float32) | ||
245 | + segment_label_ids_file = FLAGS.segment_label_ids_file | ||
246 | + if segment_label_ids_file.startswith("http"): | ||
247 | + logging.info("Retrieving segment ID whitelist files from %s...", | ||
248 | + segment_label_ids_file) | ||
249 | + segment_label_ids_file, _ = urllib.request.urlretrieve( | ||
250 | + segment_label_ids_file) | ||
251 | + with tf.io.gfile.GFile(segment_label_ids_file) as fobj: | ||
252 | + for line in fobj: | ||
253 | + try: | ||
254 | + cls_id = int(line) | ||
255 | + whitelisted_cls_mask[cls_id] = 1. | ||
256 | + except ValueError: | ||
257 | + # Simply skip the non-integer line. | ||
258 | + continue | ||
259 | + | ||
260 | + out_file.write(u"VideoId,LabelConfidencePairs\n".encode("utf8")) | ||
261 | + | ||
262 | + try: | ||
263 | + while not coord.should_stop(): | ||
264 | + video_id_batch_val, video_batch_val, num_frames_batch_val = sess.run( | ||
265 | + [video_id_batch, video_batch, num_frames_batch]) | ||
266 | + if FLAGS.segment_labels: | ||
267 | + results = get_segments(video_batch_val, num_frames_batch_val, 5) | ||
268 | + video_segment_ids = results["video_segment_ids"] | ||
269 | + video_id_batch_val = video_id_batch_val[video_segment_ids[:, 0]] | ||
270 | + video_id_batch_val = np.array([ | ||
271 | + "%s:%d" % (x.decode("utf8"), y) | ||
272 | + for x, y in zip(video_id_batch_val, video_segment_ids[:, 1]) | ||
273 | + ]) | ||
274 | + video_batch_val = results["video_batch"] | ||
275 | + num_frames_batch_val = results["num_frames_batch"] | ||
276 | + if input_tensor.get_shape()[1] != video_batch_val.shape[1]: | ||
277 | + raise ValueError("max_frames mismatch. Please re-run the eval.py " | ||
278 | + "with correct segment_labels settings.") | ||
279 | + | ||
280 | + predictions_val, = sess.run([predictions_tensor], | ||
281 | + feed_dict={ | ||
282 | + input_tensor: video_batch_val, | ||
283 | + num_frames_tensor: num_frames_batch_val | ||
284 | + }) | ||
285 | + now = time.time() | ||
286 | + num_examples_processed += len(video_batch_val) | ||
287 | + elapsed_time = now - start_time | ||
288 | + logging.info("num examples processed: " + str(num_examples_processed) + | ||
289 | + " elapsed seconds: " + "{0:.2f}".format(elapsed_time) + | ||
290 | + " examples/sec: %.2f" % | ||
291 | + (num_examples_processed / elapsed_time)) | ||
292 | + for line in format_lines(video_id_batch_val, predictions_val, top_k, | ||
293 | + whitelisted_cls_mask): | ||
294 | + out_file.write(line) | ||
295 | + out_file.flush() | ||
296 | + | ||
297 | + except tf.errors.OutOfRangeError: | ||
298 | + logging.info("Done with inference. The output file was written to " + | ||
299 | + out_file.name) | ||
300 | + finally: | ||
301 | + coord.request_stop() | ||
302 | + | ||
303 | + if FLAGS.segment_labels: | ||
304 | + # Re-read the file and do heap sort. | ||
305 | + # Create multiple heaps. | ||
306 | + logging.info("Post-processing segment predictions...") | ||
307 | + heaps = {} | ||
308 | + out_file.seek(0, 0) | ||
309 | + for line in out_file: | ||
310 | + segment_id, preds = line.decode("utf8").split(",") | ||
311 | + if segment_id == "VideoId": | ||
312 | + # Skip the headline. | ||
313 | + continue | ||
314 | + preds = preds.split(" ") | ||
315 | + pred_cls_ids = [int(preds[idx]) for idx in range(0, len(preds), 2)] | ||
316 | + pred_cls_scores = [ | ||
317 | + float(preds[idx]) for idx in range(1, len(preds), 2) | ||
318 | + ] | ||
319 | + for cls, score in zip(pred_cls_ids, pred_cls_scores): | ||
320 | + if not whitelisted_cls_mask[cls]: | ||
321 | + # Skip non-whitelisted classes. | ||
322 | + continue | ||
323 | + if cls not in heaps: | ||
324 | + heaps[cls] = [] | ||
325 | + if len(heaps[cls]) >= FLAGS.segment_max_pred: | ||
326 | + heapq.heappushpop(heaps[cls], (score, segment_id)) | ||
327 | + else: | ||
328 | + heapq.heappush(heaps[cls], (score, segment_id)) | ||
329 | + logging.info("Writing sorted segment predictions to: %s", | ||
330 | + final_out_file.name) | ||
331 | + final_out_file.write("Class,Segments\n") | ||
332 | + for cls, cls_heap in heaps.items(): | ||
333 | + cls_heap.sort(key=lambda x: x[0], reverse=True) | ||
334 | + final_out_file.write("%d,%s\n" % | ||
335 | + (cls, " ".join([x[1] for x in cls_heap]))) | ||
336 | + final_out_file.close() | ||
337 | + | ||
338 | + out_file.close() | ||
339 | + | ||
340 | + coord.join(threads) | ||
341 | + sess.close() | ||
342 | + | ||
343 | + | ||
344 | +def main(unused_argv): | ||
345 | + logging.set_verbosity(tf.logging.INFO) | ||
346 | + if FLAGS.input_model_tgz: | ||
347 | + if FLAGS.train_dir: | ||
348 | + raise ValueError("You cannot supply --train_dir if supplying " | ||
349 | + "--input_model_tgz") | ||
350 | + # Untar. | ||
351 | + if not os.path.exists(FLAGS.untar_model_dir): | ||
352 | + os.makedirs(FLAGS.untar_model_dir) | ||
353 | + tarfile.open(FLAGS.input_model_tgz).extractall(FLAGS.untar_model_dir) | ||
354 | + FLAGS.train_dir = FLAGS.untar_model_dir | ||
355 | + | ||
356 | + flags_dict_file = os.path.join(FLAGS.train_dir, "model_flags.json") | ||
357 | + if not file_io.file_exists(flags_dict_file): | ||
358 | + raise IOError("Cannot find %s. Did you run eval.py?" % flags_dict_file) | ||
359 | + flags_dict = json.loads(file_io.FileIO(flags_dict_file, "r").read()) | ||
360 | + | ||
361 | + # convert feature_names and feature_sizes to lists of values | ||
362 | + feature_names, feature_sizes = utils.GetListOfFeatureNamesAndSizes( | ||
363 | + flags_dict["feature_names"], flags_dict["feature_sizes"]) | ||
364 | + | ||
365 | + if flags_dict["frame_features"]: | ||
366 | + reader = readers.YT8MFrameFeatureReader(feature_names=feature_names, | ||
367 | + feature_sizes=feature_sizes) | ||
368 | + else: | ||
369 | + reader = readers.YT8MAggregatedFeatureReader(feature_names=feature_names, | ||
370 | + feature_sizes=feature_sizes) | ||
371 | + | ||
372 | + if not FLAGS.output_file: | ||
373 | + raise ValueError("'output_file' was not specified. " | ||
374 | + "Unable to continue with inference.") | ||
375 | + | ||
376 | + if not FLAGS.input_data_pattern: | ||
377 | + raise ValueError("'input_data_pattern' was not specified. " | ||
378 | + "Unable to continue with inference.") | ||
379 | + | ||
380 | + inference(reader, FLAGS.train_dir, FLAGS.input_data_pattern, | ||
381 | + FLAGS.output_file, FLAGS.batch_size, FLAGS.top_k) | ||
382 | + | ||
383 | + | ||
384 | +if __name__ == "__main__": | ||
385 | + app.run() |
yt8m/inference_per_segment.py
0 → 100644
1 | +# Copyright 2017 Google Inc. All Rights Reserved. | ||
2 | +# | ||
3 | +# Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | +# you may not use this file except in compliance with the License. | ||
5 | +# You may obtain a copy of the License at | ||
6 | +# | ||
7 | +# http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | +# | ||
9 | +# Unless required by applicable law or agreed to in writing, software | ||
10 | +# distributed under the License is distributed on an "AS-IS" BASIS, | ||
11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | +# See the License for the specific language governing permissions and | ||
13 | +# limitations under the License. | ||
14 | +"""Binary for generating predictions over a set of videos.""" | ||
15 | + | ||
16 | +from __future__ import print_function | ||
17 | + | ||
18 | +import glob | ||
19 | +import heapq | ||
20 | +import json | ||
21 | +import os | ||
22 | +import tarfile | ||
23 | +import tempfile | ||
24 | +import time | ||
25 | +import numpy as np | ||
26 | + | ||
27 | +import readers | ||
28 | +from six.moves import urllib | ||
29 | +import tensorflow as tf | ||
30 | +from tensorflow import app | ||
31 | +from tensorflow import flags | ||
32 | +from tensorflow import gfile | ||
33 | +from tensorflow import logging | ||
34 | +from tensorflow.python.lib.io import file_io | ||
35 | +import utils | ||
36 | +from collections import Counter | ||
37 | + | ||
38 | +FLAGS = flags.FLAGS | ||
39 | + | ||
40 | +if __name__ == "__main__": | ||
41 | + # Input | ||
42 | + flags.DEFINE_string( | ||
43 | + "train_dir", "", "The directory to load the model files from. We assume " | ||
44 | + "that you have already run eval.py onto this, such that " | ||
45 | + "inference_model.* files already exist.") | ||
46 | + flags.DEFINE_string( | ||
47 | + "input_data_pattern", "", | ||
48 | + "File glob defining the evaluation dataset in tensorflow.SequenceExample " | ||
49 | + "format. The SequenceExamples are expected to have an 'rgb' byte array " | ||
50 | + "sequence feature as well as a 'labels' int64 context feature.") | ||
51 | + flags.DEFINE_string( | ||
52 | + "input_model_tgz", "", | ||
53 | + "If given, must be path to a .tgz file that was written " | ||
54 | + "by this binary using flag --output_model_tgz. In this " | ||
55 | + "case, the .tgz file will be untarred to " | ||
56 | + "--untar_model_dir and the model will be used for " | ||
57 | + "inference.") | ||
58 | + flags.DEFINE_string( | ||
59 | + "untar_model_dir", "/tmp/yt8m-model", | ||
60 | + "If --input_model_tgz is given, then this directory will " | ||
61 | + "be created and the contents of the .tgz file will be " | ||
62 | + "untarred here.") | ||
63 | + flags.DEFINE_bool( | ||
64 | + "segment_labels", False, | ||
65 | + "If set, then --input_data_pattern must be frame-level features (but with" | ||
66 | + " segment_labels). Otherwise, --input_data_pattern must be aggregated " | ||
67 | + "video-level features. The model must also be set appropriately (i.e. to " | ||
68 | + "read 3D batches vs. 4D batches).") | ||
69 | + flags.DEFINE_integer("segment_max_pred", 100000, | ||
70 | + "Limit total number of segment outputs per entity.") | ||
71 | + flags.DEFINE_string( | ||
72 | + "segment_label_ids_file", | ||
73 | + "https://raw.githubusercontent.com/google/youtube-8m/master/segment_label_ids.csv", | ||
74 | + "The file that contains the segment label ids.") | ||
75 | + | ||
76 | + # Output | ||
77 | + flags.DEFINE_string("output_file", "", "The file to save the predictions to.") | ||
78 | + flags.DEFINE_string( | ||
79 | + "output_model_tgz", "", | ||
80 | + "If given, should be a filename with a .tgz extension, " | ||
81 | + "the model graph and checkpoint will be bundled in this " | ||
82 | + "gzip tar. This file can be uploaded to Kaggle for the " | ||
83 | + "top 10 participants.") | ||
84 | + flags.DEFINE_integer("top_k", 1, "How many predictions to output per video.") | ||
85 | + | ||
86 | + # Other flags. | ||
87 | + flags.DEFINE_integer("batch_size", 512, | ||
88 | + "How many examples to process per batch.") | ||
89 | + flags.DEFINE_integer("num_readers", 1, | ||
90 | + "How many threads to use for reading input files.") | ||
91 | + | ||
92 | + | ||
93 | +def format_lines(video_ids, predictions, top_k, whitelisted_cls_mask=None): | ||
94 | + """Create an information line for the submission file.""" | ||
95 | + batch_size = len(video_ids) | ||
96 | + for video_index in range(batch_size): | ||
97 | + video_prediction = predictions[video_index] | ||
98 | + if whitelisted_cls_mask is not None: | ||
99 | + # Whitelist classes. | ||
100 | + video_prediction *= whitelisted_cls_mask | ||
101 | + top_indices = np.argpartition(video_prediction, -top_k)[-top_k:] | ||
102 | + line = [(class_index, predictions[video_index][class_index]) | ||
103 | + for class_index in top_indices] | ||
104 | + line = sorted(line, key=lambda p: -p[1]) | ||
105 | + yield (video_ids[video_index] + "," + | ||
106 | + " ".join("%i %g" % (label, score) for (label, score) in line) + | ||
107 | + "\n").encode("utf8") | ||
108 | + | ||
109 | + | ||
110 | +def get_input_data_tensors(reader, data_pattern, batch_size, num_readers=1): | ||
111 | + """Creates the section of the graph which reads the input data. | ||
112 | + | ||
113 | + Args: | ||
114 | + reader: A class which parses the input data. | ||
115 | + data_pattern: A 'glob' style path to the data files. | ||
116 | + batch_size: How many examples to process at a time. | ||
117 | + num_readers: How many I/O threads to use. | ||
118 | + | ||
119 | + Returns: | ||
120 | + A tuple containing the features tensor, labels tensor, and optionally a | ||
121 | + tensor containing the number of frames per video. The exact dimensions | ||
122 | + depend on the reader being used. | ||
123 | + | ||
124 | + Raises: | ||
125 | + IOError: If no files matching the given pattern were found. | ||
126 | + """ | ||
127 | + with tf.name_scope("input"): | ||
128 | + files = gfile.Glob(data_pattern) | ||
129 | + if not files: | ||
130 | + raise IOError("Unable to find input files. data_pattern='" + | ||
131 | + data_pattern + "'") | ||
132 | + logging.info("number of input files: " + str(len(files))) | ||
133 | + filename_queue = tf.train.string_input_producer(files, | ||
134 | + num_epochs=1, | ||
135 | + shuffle=False) | ||
136 | + examples_and_labels = [ | ||
137 | + reader.prepare_reader(filename_queue) for _ in range(num_readers) | ||
138 | + ] | ||
139 | + | ||
140 | + input_data_dict = (tf.train.batch_join(examples_and_labels, | ||
141 | + batch_size=batch_size, | ||
142 | + allow_smaller_final_batch=True, | ||
143 | + enqueue_many=True)) | ||
144 | + video_id_batch = input_data_dict["video_ids"] | ||
145 | + video_batch = input_data_dict["video_matrix"] | ||
146 | + num_frames_batch = input_data_dict["num_frames"] | ||
147 | + return video_id_batch, video_batch, num_frames_batch | ||
148 | + | ||
149 | + | ||
150 | +def get_segments(batch_video_mtx, batch_num_frames, segment_size): | ||
151 | + """Get segment-level inputs from frame-level features.""" | ||
152 | + video_batch_size = batch_video_mtx.shape[0] | ||
153 | + max_frame = batch_video_mtx.shape[1] | ||
154 | + feature_dim = batch_video_mtx.shape[-1] | ||
155 | + padded_segment_sizes = (batch_num_frames + segment_size - 1) // segment_size | ||
156 | + padded_segment_sizes *= segment_size | ||
157 | + segment_mask = ( | ||
158 | + 0 < (padded_segment_sizes[:, np.newaxis] - np.arange(0, max_frame))) | ||
159 | + | ||
160 | + # Segment bags. | ||
161 | + frame_bags = batch_video_mtx.reshape((-1, feature_dim)) | ||
162 | + segment_frames = frame_bags[segment_mask.reshape(-1)].reshape( | ||
163 | + (-1, segment_size, feature_dim)) | ||
164 | + | ||
165 | + # Segment num frames. | ||
166 | + segment_start_times = np.arange(0, max_frame, segment_size) | ||
167 | + num_segments = batch_num_frames[:, np.newaxis] - segment_start_times | ||
168 | + num_segment_bags = num_segments.reshape((-1)) | ||
169 | + valid_segment_mask = num_segment_bags > 0 | ||
170 | + segment_num_frames = num_segment_bags[valid_segment_mask] | ||
171 | + segment_num_frames[segment_num_frames > segment_size] = segment_size | ||
172 | + | ||
173 | + max_segment_num = (max_frame + segment_size - 1) // segment_size | ||
174 | + video_idxs = np.tile( | ||
175 | + np.arange(0, video_batch_size)[:, np.newaxis], [1, max_segment_num]) | ||
176 | + segment_idxs = np.tile(segment_start_times, [video_batch_size, 1]) | ||
177 | + idx_bags = np.stack([video_idxs, segment_idxs], axis=-1).reshape((-1, 2)) | ||
178 | + video_segment_ids = idx_bags[valid_segment_mask] | ||
179 | + | ||
180 | + return { | ||
181 | + "video_batch": segment_frames, | ||
182 | + "num_frames_batch": segment_num_frames, | ||
183 | + "video_segment_ids": video_segment_ids | ||
184 | + } | ||
185 | + | ||
186 | + | ||
187 | +def inference(reader, train_dir, data_pattern, out_file_location, batch_size, | ||
188 | + top_k): | ||
189 | + """Inference function.""" | ||
190 | + with tf.Session(config=tf.ConfigProto( | ||
191 | + allow_soft_placement=True)) as sess, gfile.Open(out_file_location, | ||
192 | + "w+") as out_file: | ||
193 | + video_id_batch, video_batch, num_frames_batch = get_input_data_tensors( | ||
194 | + reader, data_pattern, batch_size) | ||
195 | + inference_model_name = "segment_inference_model" if FLAGS.segment_labels else "inference_model" | ||
196 | + checkpoint_file = os.path.join(train_dir, "inference_model", | ||
197 | + inference_model_name) | ||
198 | + if not gfile.Exists(checkpoint_file + ".meta"): | ||
199 | + raise IOError("Cannot find %s. Did you run eval.py?" % checkpoint_file) | ||
200 | + meta_graph_location = checkpoint_file + ".meta" | ||
201 | + logging.info("loading meta-graph: " + meta_graph_location) | ||
202 | + | ||
203 | + if FLAGS.output_model_tgz: | ||
204 | + with tarfile.open(FLAGS.output_model_tgz, "w:gz") as tar: | ||
205 | + for model_file in glob.glob(checkpoint_file + ".*"): | ||
206 | + tar.add(model_file, arcname=os.path.basename(model_file)) | ||
207 | + tar.add(os.path.join(train_dir, "model_flags.json"), | ||
208 | + arcname="model_flags.json") | ||
209 | + print("Tarred model onto " + FLAGS.output_model_tgz) | ||
210 | + with tf.device("/cpu:0"): | ||
211 | + saver = tf.train.import_meta_graph(meta_graph_location, | ||
212 | + clear_devices=True) | ||
213 | + logging.info("restoring variables from " + checkpoint_file) | ||
214 | + saver.restore(sess, checkpoint_file) | ||
215 | + input_tensor = tf.get_collection("input_batch_raw")[0] | ||
216 | + num_frames_tensor = tf.get_collection("num_frames")[0] | ||
217 | + predictions_tensor = tf.get_collection("predictions")[0] | ||
218 | + | ||
219 | + # Workaround for num_epochs issue. | ||
220 | + def set_up_init_ops(variables): | ||
221 | + init_op_list = [] | ||
222 | + for variable in list(variables): | ||
223 | + if "train_input" in variable.name: | ||
224 | + init_op_list.append(tf.assign(variable, 1)) | ||
225 | + variables.remove(variable) | ||
226 | + init_op_list.append(tf.variables_initializer(variables)) | ||
227 | + return init_op_list | ||
228 | + | ||
229 | + sess.run( | ||
230 | + set_up_init_ops(tf.get_collection_ref(tf.GraphKeys.LOCAL_VARIABLES))) | ||
231 | + | ||
232 | + coord = tf.train.Coordinator() | ||
233 | + threads = tf.train.start_queue_runners(sess=sess, coord=coord) | ||
234 | + num_examples_processed = 0 | ||
235 | + start_time = time.time() | ||
236 | + whitelisted_cls_mask = None | ||
237 | + if FLAGS.segment_labels: | ||
238 | + final_out_file = out_file | ||
239 | + out_file = tempfile.NamedTemporaryFile() | ||
240 | + logging.info( | ||
241 | + "Segment temp prediction output will be written to temp file: %s", | ||
242 | + out_file.name) | ||
243 | + if FLAGS.segment_label_ids_file: | ||
244 | + whitelisted_cls_mask = np.zeros((predictions_tensor.get_shape()[-1],), | ||
245 | + dtype=np.float32) | ||
246 | + segment_label_ids_file = FLAGS.segment_label_ids_file | ||
247 | + if segment_label_ids_file.startswith("http"): | ||
248 | + logging.info("Retrieving segment ID whitelist files from %s...", | ||
249 | + segment_label_ids_file) | ||
250 | + segment_label_ids_file, _ = urllib.request.urlretrieve( | ||
251 | + segment_label_ids_file) | ||
252 | + with tf.io.gfile.GFile(segment_label_ids_file) as fobj: | ||
253 | + for line in fobj: | ||
254 | + try: | ||
255 | + cls_id = int(line) | ||
256 | + whitelisted_cls_mask[cls_id] = 1. | ||
257 | + except ValueError: | ||
258 | + # Simply skip the non-integer line. | ||
259 | + continue | ||
260 | + | ||
261 | + out_file.write(u"VideoId,LabelConfidencePairs\n".encode("utf8")) | ||
262 | + | ||
263 | + try: | ||
264 | + while not coord.should_stop(): | ||
265 | + video_id_batch_val, video_batch_val, num_frames_batch_val = sess.run( | ||
266 | + [video_id_batch, video_batch, num_frames_batch]) | ||
267 | + if FLAGS.segment_labels: | ||
268 | + results = get_segments(video_batch_val, num_frames_batch_val, 5) | ||
269 | + video_segment_ids = results["video_segment_ids"] | ||
270 | + video_id_batch_val = video_id_batch_val[video_segment_ids[:, 0]] | ||
271 | + video_id_batch_val = np.array([ | ||
272 | + "%s:%d" % (x.decode("utf8"), y) | ||
273 | + for x, y in zip(video_id_batch_val, video_segment_ids[:, 1]) | ||
274 | + ]) | ||
275 | + video_batch_val = results["video_batch"] | ||
276 | + num_frames_batch_val = results["num_frames_batch"] | ||
277 | + if input_tensor.get_shape()[1] != video_batch_val.shape[1]: | ||
278 | + raise ValueError("max_frames mismatch. Please re-run the eval.py " | ||
279 | + "with correct segment_labels settings.") | ||
280 | + | ||
281 | + predictions_val, = sess.run([predictions_tensor], | ||
282 | + feed_dict={ | ||
283 | + input_tensor: video_batch_val, | ||
284 | + num_frames_tensor: num_frames_batch_val | ||
285 | + }) | ||
286 | + now = time.time() | ||
287 | + num_examples_processed += len(video_batch_val) | ||
288 | + elapsed_time = now - start_time | ||
289 | + logging.info("num examples processed: " + str(num_examples_processed) + | ||
290 | + " elapsed seconds: " + "{0:.2f}".format(elapsed_time) + | ||
291 | + " examples/sec: %.2f" % | ||
292 | + (num_examples_processed / elapsed_time)) | ||
293 | + for line in format_lines(video_id_batch_val, predictions_val, top_k, | ||
294 | + whitelisted_cls_mask): | ||
295 | + out_file.write(line) | ||
296 | + out_file.flush() | ||
297 | + | ||
298 | + except tf.errors.OutOfRangeError: | ||
299 | + logging.info("Done with inference. The output file was written to " + | ||
300 | + out_file.name) | ||
301 | + finally: | ||
302 | + coord.request_stop() | ||
303 | + | ||
304 | + if FLAGS.segment_labels: | ||
305 | + # Re-read the file and do heap sort. | ||
306 | + # Create multiple heaps. | ||
307 | + logging.info("Post-processing segment predictions...") | ||
308 | + segment_id_list = [] | ||
309 | + segment_classes = [] | ||
310 | + cls_result_arr = [] | ||
311 | + out_file.seek(0, 0) | ||
312 | + for line in out_file: | ||
313 | + segment_id, preds = line.decode("utf8").split(",") | ||
314 | + if segment_id == "VideoId": | ||
315 | + # Skip the headline. | ||
316 | + continue | ||
317 | + | ||
318 | + preds = preds.split(" ") | ||
319 | + pred_cls_ids = [int(preds[idx]) for idx in range(0, len(preds), 2)] | ||
320 | + # ======================================= | ||
321 | + segment_id = str(segment_id.split(":")[0]) | ||
322 | + if segment_id not in segment_id_list: | ||
323 | + segment_id_list.append(str(segment_id)) | ||
324 | + segment_classes.append("") | ||
325 | + | ||
326 | + index = segment_id_list.index(segment_id) | ||
327 | + for classes in pred_cls_ids: | ||
328 | + segment_classes[index] = str(segment_classes[index]) + str( | ||
329 | + classes) + " " # append classes from new segment | ||
330 | + | ||
331 | + for segs, item in zip(segment_id_list, segment_classes): | ||
332 | + print('====== R E C O R D ======') | ||
333 | + cls_arr = item.split(" ")[:-1] | ||
334 | + | ||
335 | + cls_arr = list(map(int, cls_arr)) | ||
336 | + cls_arr = sorted(cls_arr) | ||
337 | + | ||
338 | + result_string = "" | ||
339 | + | ||
340 | + class_counts = Counter(cls_arr) | ||
341 | + for cls_id in class_counts: # avoid shadowing the outer 'item' loop variable | ||
342 | + result_string = result_string + str(cls_id) + ":" + str(class_counts[cls_id]) + "," | ||
343 | + | ||
344 | + cls_result_arr.append(result_string[:-1]) | ||
345 | + logging.info(segs + " : " + result_string[:-1]) | ||
346 | + # ======================================= | ||
347 | + final_out_file.write("vid_id,seg_classes\n") | ||
348 | + for seg_id, class_indices in zip(segment_id_list, cls_result_arr): | ||
349 | + final_out_file.write("%s,%s\n" % (seg_id, str(class_indices))) | ||
350 | + final_out_file.close() | ||
351 | + | ||
352 | + out_file.close() | ||
353 | + | ||
354 | + coord.join(threads) | ||
355 | + sess.close() | ||
356 | + | ||
357 | + | ||
358 | +def main(unused_argv): | ||
359 | + logging.set_verbosity(tf.logging.INFO) | ||
360 | + if FLAGS.input_model_tgz: | ||
361 | + if FLAGS.train_dir: | ||
362 | + raise ValueError("You cannot supply --train_dir if supplying " | ||
363 | + "--input_model_tgz") | ||
364 | + # Untar. | ||
365 | + if not os.path.exists(FLAGS.untar_model_dir): | ||
366 | + os.makedirs(FLAGS.untar_model_dir) | ||
367 | + tarfile.open(FLAGS.input_model_tgz).extractall(FLAGS.untar_model_dir) | ||
368 | + FLAGS.train_dir = FLAGS.untar_model_dir | ||
369 | + | ||
370 | + flags_dict_file = os.path.join(FLAGS.train_dir, "model_flags.json") | ||
371 | + if not file_io.file_exists(flags_dict_file): | ||
372 | + raise IOError("Cannot find %s. Did you run train.py?" % flags_dict_file) | ||
373 | + flags_dict = json.loads(file_io.FileIO(flags_dict_file, "r").read()) | ||
374 | + | ||
375 | + # convert feature_names and feature_sizes to lists of values | ||
376 | + feature_names, feature_sizes = utils.GetListOfFeatureNamesAndSizes( | ||
377 | + flags_dict["feature_names"], flags_dict["feature_sizes"]) | ||
378 | + | ||
379 | + if flags_dict["frame_features"]: | ||
380 | + reader = readers.YT8MFrameFeatureReader(feature_names=feature_names, | ||
381 | + feature_sizes=feature_sizes) | ||
382 | + else: | ||
383 | + reader = readers.YT8MAggregatedFeatureReader(feature_names=feature_names, | ||
384 | + feature_sizes=feature_sizes) | ||
385 | + | ||
386 | + if not FLAGS.output_file: | ||
387 | + raise ValueError("'output_file' was not specified. " | ||
388 | + "Unable to continue with inference.") | ||
389 | + | ||
390 | + if not FLAGS.input_data_pattern: | ||
391 | + raise ValueError("'input_data_pattern' was not specified. " | ||
392 | + "Unable to continue with inference.") | ||
393 | + | ||
394 | + inference(reader, FLAGS.train_dir, FLAGS.input_data_pattern, | ||
395 | + FLAGS.output_file, FLAGS.batch_size, FLAGS.top_k) | ||
396 | + | ||
397 | + | ||
398 | +if __name__ == "__main__": | ||
399 | + app.run() |
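
The segment post-processing block above collapses the per-segment prediction rows written during inference into one row per video, counting how many segments predicted each class. The sketch below is a minimal standalone illustration of that transformation; the intermediate rows are hypothetical examples of the "video_id:start_time,class score class score ..." format, not real output.

```
from collections import Counter

# Hypothetical intermediate rows of the form
# "<video_id>:<segment_start>,<class> <score> <class> <score> ..."
intermediate_rows = [
    "vid1:0,7 0.91 12 0.80",
    "vid1:5,7 0.72 3 0.64",
    "vid2:0,12 0.95 7 0.51",
]

per_video_classes = {}
for row in intermediate_rows:
  segment_id, preds = row.split(",")
  video_id = segment_id.split(":")[0]
  # Keep only the class ids (every other token); the scores are dropped.
  cls_ids = [int(tok) for tok in preds.split(" ")[0::2]]
  per_video_classes.setdefault(video_id, []).extend(cls_ids)

# Final rows mirror the "vid_id,seg_classes" output: "class:count" pairs.
for video_id, cls_ids in per_video_classes.items():
  counts = Counter(sorted(cls_ids))
  seg_classes = ",".join("%d:%d" % (c, n) for c, n in counts.items())
  print("%s,%s" % (video_id, seg_classes))
# vid1,3:1,7:2,12:1
# vid2,7:1,12:1
```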
yt8m/losses.py
0 → 100644
1 | +# Copyright 2016 Google Inc. All Rights Reserved. | ||
2 | +# | ||
3 | +# Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | +# you may not use this file except in compliance with the License. | ||
5 | +# You may obtain a copy of the License at | ||
6 | +# | ||
7 | +# http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | +# | ||
9 | +# Unless required by applicable law or agreed to in writing, software | ||
10 | +# distributed under the License is distributed on an "AS-IS" BASIS, | ||
11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | +# See the License for the specific language governing permissions and | ||
13 | +# limitations under the License. | ||
14 | +"""Provides definitions for non-regularized training or test losses.""" | ||
15 | + | ||
16 | +import tensorflow as tf | ||
17 | + | ||
18 | + | ||
19 | +class BaseLoss(object): | ||
20 | + """Inherit from this class when implementing new losses.""" | ||
21 | + | ||
22 | + def calculate_loss(self, unused_predictions, unused_labels, **unused_params): | ||
23 | + """Calculates the average loss of the examples in a mini-batch. | ||
24 | + | ||
25 | + Args: | ||
26 | + unused_predictions: a 2-d tensor storing the prediction scores, in which | ||
27 | + each row represents a sample in the mini-batch and each column | ||
28 | + represents a class. | ||
29 | + unused_labels: a 2-d tensor storing the labels, which has the same shape | ||
30 | + as the unused_predictions. The labels must be in the range [0, 1]. | ||
31 | + unused_params: loss specific parameters. | ||
32 | + | ||
33 | + Returns: | ||
34 | + A scalar loss tensor. | ||
35 | + """ | ||
36 | + raise NotImplementedError() | ||
37 | + | ||
38 | + | ||
39 | +class CrossEntropyLoss(BaseLoss): | ||
40 | + """Calculate the cross entropy loss between the predictions and labels.""" | ||
41 | + | ||
42 | + def calculate_loss(self, | ||
43 | + predictions, | ||
44 | + labels, | ||
45 | + label_weights=None, | ||
46 | + **unused_params): | ||
47 | + with tf.name_scope("loss_xent"): | ||
48 | + epsilon = 1e-5 | ||
49 | + float_labels = tf.cast(labels, tf.float32) | ||
50 | + cross_entropy_loss = float_labels * tf.math.log(predictions + epsilon) + ( | ||
51 | + 1 - float_labels) * tf.math.log(1 - predictions + epsilon) | ||
52 | + cross_entropy_loss = tf.negative(cross_entropy_loss) | ||
53 | + if label_weights is not None: | ||
54 | + cross_entropy_loss *= label_weights | ||
55 | + return tf.reduce_mean(tf.reduce_sum(cross_entropy_loss, 1)) | ||
56 | + | ||
57 | + | ||
58 | +class HingeLoss(BaseLoss): | ||
59 | + """Calculate the hinge loss between the predictions and labels. | ||
60 | + | ||
61 | + Note the subgradient is used in the backpropagation, and thus the optimization | ||
62 | + may converge more slowly. The predictions trained by the hinge loss are between -1 | ||
63 | + and +1. | ||
64 | + """ | ||
65 | + | ||
66 | + def calculate_loss(self, predictions, labels, b=1.0, **unused_params): | ||
67 | + with tf.name_scope("loss_hinge"): | ||
68 | + float_labels = tf.cast(labels, tf.float32) | ||
69 | + all_zeros = tf.zeros(tf.shape(float_labels), dtype=tf.float32) | ||
70 | + all_ones = tf.ones(tf.shape(float_labels), dtype=tf.float32) | ||
71 | + sign_labels = tf.subtract(tf.scalar_mul(2, float_labels), all_ones) | ||
72 | + hinge_loss = tf.maximum( | ||
73 | + all_zeros, | ||
74 | + tf.scalar_mul(b, all_ones) - sign_labels * predictions) | ||
75 | + return tf.reduce_mean(tf.reduce_sum(hinge_loss, 1)) | ||
76 | + | ||
77 | + | ||
78 | +class SoftmaxLoss(BaseLoss): | ||
79 | + """Calculate the softmax loss between the predictions and labels. | ||
80 | + | ||
81 | + The function calculates the loss in the following way: first the | ||
82 | + predictions are passed through the softmax activation, and then we compute | ||
83 | + the negative dot product between the log of the softmax activations and the | ||
84 | + L1-normalized ground truth labels. | ||
85 | + | ||
86 | + It is an extension of the one-hot label: it allows more than one positive | ||
87 | + label per sample. | ||
88 | + """ | ||
89 | + | ||
90 | + def calculate_loss(self, predictions, labels, **unused_params): | ||
91 | + with tf.name_scope("loss_softmax"): | ||
92 | + epsilon = 10e-8 | ||
93 | + float_labels = tf.cast(labels, tf.float32) | ||
94 | + # l1 normalization (labels are no less than 0) | ||
95 | + label_rowsum = tf.maximum(tf.reduce_sum(float_labels, 1, keep_dims=True), | ||
96 | + epsilon) | ||
97 | + norm_float_labels = tf.div(float_labels, label_rowsum) | ||
98 | + softmax_outputs = tf.nn.softmax(predictions) | ||
99 | + softmax_loss = tf.negative( | ||
100 | + tf.reduce_sum(tf.multiply(norm_float_labels, tf.log(softmax_outputs)), | ||
101 | + 1)) | ||
102 | + return tf.reduce_mean(softmax_loss) |
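
As a quick illustration of how these loss classes are consumed, here is a minimal sketch (assuming TF 1.x graph mode, like the rest of this code) that evaluates CrossEntropyLoss on a toy batch. Note that the predictions are expected to already be probabilities in (0, 1), for example sigmoid outputs, since the loss takes their log directly rather than working from logits.

```
import tensorflow as tf  # assumes TF 1.x, matching the rest of the repo

import losses

# Toy batch: 2 examples, 3 classes. Predictions are probabilities in (0, 1).
predictions = tf.constant([[0.9, 0.2, 0.7],
                           [0.1, 0.8, 0.4]])
labels = tf.constant([[1.0, 0.0, 1.0],
                      [0.0, 1.0, 0.0]])

loss_fn = losses.CrossEntropyLoss()
loss = loss_fn.calculate_loss(predictions, labels)

with tf.Session() as sess:
  # Prints the mean (over examples) of the summed per-class losses.
  print(sess.run(loss))
```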
yt8m/mean_average_precision_calculator.py
0 → 100644
1 | +# Copyright 2016 Google Inc. All Rights Reserved. | ||
2 | +# | ||
3 | +# Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | +# you may not use this file except in compliance with the License. | ||
5 | +# You may obtain a copy of the License at | ||
6 | +# | ||
7 | +# http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | +# | ||
9 | +# Unless required by applicable law or agreed to in writing, software | ||
10 | +# distributed under the License is distributed on an "AS-IS" BASIS, | ||
11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | +# See the License for the specific language governing permissions and | ||
13 | +# limitations under the License. | ||
14 | +"""Calculate the mean average precision. | ||
15 | + | ||
16 | +It provides an interface for calculating mean average precision | ||
17 | +for an entire list or the top-n ranked items. | ||
18 | + | ||
19 | +Example usages: | ||
20 | +We first call the function accumulate many times to process parts of the ranked | ||
21 | +list. After processing all the parts, we call peek_map_at_n | ||
22 | +to calculate the mean average precision. | ||
23 | + | ||
24 | +``` | ||
25 | +import random | ||
26 | + | ||
27 | +p = np.array([[random.random() for _ in xrange(50)] for _ in xrange(1000)]) | ||
28 | +a = np.array([[random.choice([0, 1]) for _ in xrange(50)] | ||
29 | + for _ in xrange(1000)]) | ||
30 | + | ||
31 | +# mean average precision for 50 classes. | ||
32 | +calculator = mean_average_precision_calculator.MeanAveragePrecisionCalculator( | ||
33 | + num_class=50) | ||
34 | +calculator.accumulate(p, a) | ||
35 | +aps = calculator.peek_map_at_n() | ||
36 | +``` | ||
37 | +""" | ||
38 | + | ||
39 | +import average_precision_calculator | ||
40 | + | ||
41 | + | ||
42 | +class MeanAveragePrecisionCalculator(object): | ||
43 | + """This class calculates the mean average precision.""" | ||
44 | + | ||
45 | + def __init__(self, num_class, filter_empty_classes=True, top_n=None): | ||
46 | + """Construct a calculator to calculate the (macro) average precision. | ||
47 | + | ||
48 | + Args: | ||
49 | + num_class: A positive Integer specifying the number of classes. | ||
50 | + filter_empty_classes: whether to filter classes without any positives. | ||
51 | + top_n: A positive Integer specifying the average precision at n, or None | ||
52 | + to use all provided data points. | ||
53 | + | ||
54 | + Raises: | ||
55 | + ValueError: An error occurred when num_class is not a positive integer, | ||
56 | + or top_n is not a positive integer or None. | ||
57 | + """ | ||
58 | + if not isinstance(num_class, int) or num_class <= 1: | ||
59 | + raise ValueError("num_class must be an integer greater than 1.") | ||
60 | + | ||
61 | + self._ap_calculators = [] # member of AveragePrecisionCalculator | ||
62 | + self._num_class = num_class # total number of classes | ||
63 | + self._filter_empty_classes = filter_empty_classes | ||
64 | + for _ in range(num_class): | ||
65 | + self._ap_calculators.append( | ||
66 | + average_precision_calculator.AveragePrecisionCalculator(top_n=top_n)) | ||
67 | + | ||
68 | + def accumulate(self, predictions, actuals, num_positives=None): | ||
69 | + """Accumulate the predictions and their ground truth labels. | ||
70 | + | ||
71 | + Args: | ||
72 | + predictions: A list of lists storing the prediction scores. The outer | ||
73 | + dimension corresponds to classes. | ||
74 | + actuals: A list of lists storing the ground truth labels. The dimensions | ||
75 | + should correspond to the predictions input. Any value larger than 0 will | ||
76 | + be treated as positives, otherwise as negatives. | ||
77 | + num_positives: If provided, it is a list of numbers representing the | ||
78 | + number of true positives for each class. If not provided, the number of | ||
79 | + true positives will be inferred from the 'actuals' array. | ||
80 | + | ||
81 | + Raises: | ||
82 | + ValueError: An error occurred when the shape of predictions and actuals | ||
83 | + does not match. | ||
84 | + """ | ||
85 | + if not num_positives: | ||
86 | + num_positives = [None for i in range(self._num_class)] | ||
87 | + | ||
88 | + calculators = self._ap_calculators | ||
89 | + for i in range(self._num_class): | ||
90 | + calculators[i].accumulate(predictions[i], actuals[i], num_positives[i]) | ||
91 | + | ||
92 | + def clear(self): | ||
93 | + for calculator in self._ap_calculators: | ||
94 | + calculator.clear() | ||
95 | + | ||
96 | + def is_empty(self): | ||
97 | + return ([calculator.heap_size for calculator in self._ap_calculators | ||
98 | + ] == [0 for _ in range(self._num_class)]) | ||
99 | + | ||
100 | + def peek_map_at_n(self): | ||
101 | + """Peek the non-interpolated mean average precision at n. | ||
102 | + | ||
103 | + Returns: | ||
104 | + An array of non-interpolated average precision at n (default 0) for each | ||
105 | + class. | ||
106 | + """ | ||
107 | + aps = [] | ||
108 | + for i in range(self._num_class): | ||
109 | + if (not self._filter_empty_classes or | ||
110 | + self._ap_calculators[i].num_accumulated_positives > 0): | ||
111 | + ap = self._ap_calculators[i].peek_ap_at_n() | ||
112 | + aps.append(ap) | ||
113 | + return aps |
yt8m/model_utils.py
0 → 100644
1 | +# Copyright 2016 Google Inc. All Rights Reserved. | ||
2 | +# | ||
3 | +# Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | +# you may not use this file except in compliance with the License. | ||
5 | +# You may obtain a copy of the License at | ||
6 | +# | ||
7 | +# http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | +# | ||
9 | +# Unless required by applicable law or agreed to in writing, software | ||
10 | +# distributed under the License is distributed on an "AS-IS" BASIS, | ||
11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | +# See the License for the specific language governing permissions and | ||
13 | +# limitations under the License. | ||
14 | +"""Contains a collection of util functions for model construction.""" | ||
15 | +import numpy | ||
16 | +import tensorflow as tf | ||
17 | +from tensorflow import logging | ||
18 | +from tensorflow import flags | ||
19 | +import tensorflow.contrib.slim as slim | ||
20 | + | ||
21 | + | ||
22 | +def SampleRandomSequence(model_input, num_frames, num_samples): | ||
23 | + """Samples a random sequence of frames of size num_samples. | ||
24 | + | ||
25 | + Args: | ||
26 | + model_input: A tensor of size batch_size x max_frames x feature_size | ||
27 | + num_frames: A tensor of size batch_size x 1 | ||
28 | + num_samples: A scalar | ||
29 | + | ||
30 | + Returns: | ||
31 | + `model_input`: A tensor of size batch_size x num_samples x feature_size | ||
32 | + """ | ||
33 | + | ||
34 | + batch_size = tf.shape(model_input)[0] | ||
35 | + frame_index_offset = tf.tile(tf.expand_dims(tf.range(num_samples), 0), | ||
36 | + [batch_size, 1]) | ||
37 | + max_start_frame_index = tf.maximum(num_frames - num_samples, 0) | ||
38 | + start_frame_index = tf.cast( | ||
39 | + tf.multiply(tf.random_uniform([batch_size, 1]), | ||
40 | + tf.cast(max_start_frame_index + 1, tf.float32)), tf.int32) | ||
41 | + frame_index = tf.minimum(start_frame_index + frame_index_offset, | ||
42 | + tf.cast(num_frames - 1, tf.int32)) | ||
43 | + batch_index = tf.tile(tf.expand_dims(tf.range(batch_size), 1), | ||
44 | + [1, num_samples]) | ||
45 | + index = tf.stack([batch_index, frame_index], 2) | ||
46 | + return tf.gather_nd(model_input, index) | ||
47 | + | ||
48 | + | ||
49 | +def SampleRandomFrames(model_input, num_frames, num_samples): | ||
50 | + """Samples a random set of frames of size num_samples. | ||
51 | + | ||
52 | + Args: | ||
53 | + model_input: A tensor of size batch_size x max_frames x feature_size | ||
54 | + num_frames: A tensor of size batch_size x 1 | ||
55 | + num_samples: A scalar | ||
56 | + | ||
57 | + Returns: | ||
58 | + `model_input`: A tensor of size batch_size x num_samples x feature_size | ||
59 | + """ | ||
60 | + batch_size = tf.shape(model_input)[0] | ||
61 | + frame_index = tf.cast( | ||
62 | + tf.multiply(tf.random_uniform([batch_size, num_samples]), | ||
63 | + tf.tile(tf.cast(num_frames, tf.float32), [1, num_samples])), | ||
64 | + tf.int32) | ||
65 | + batch_index = tf.tile(tf.expand_dims(tf.range(batch_size), 1), | ||
66 | + [1, num_samples]) | ||
67 | + index = tf.stack([batch_index, frame_index], 2) | ||
68 | + return tf.gather_nd(model_input, index) | ||
69 | + | ||
70 | + | ||
71 | +def FramePooling(frames, method, **unused_params): | ||
72 | + """Pools over the frames of a video. | ||
73 | + | ||
74 | + Args: | ||
75 | + frames: A tensor with shape [batch_size, num_frames, feature_size]. | ||
76 | + method: "average", "max", or "none". | ||
77 | + | ||
78 | + Returns: | ||
79 | + A tensor with shape [batch_size, feature_size] for average or max | ||
80 | + pooling. A tensor with shape [batch_size*num_frames, feature_size] | ||
81 | + for none pooling. | ||
82 | + | ||
83 | + Raises: | ||
84 | + ValueError: if method is other than "average", "max", or | ||
85 | + "none". | ||
86 | + """ | ||
87 | + if method == "average": | ||
88 | + return tf.reduce_mean(frames, 1) | ||
89 | + elif method == "max": | ||
90 | + return tf.reduce_max(frames, 1) | ||
91 | + elif method == "none": | ||
92 | + feature_size = frames.shape_as_list()[2] | ||
93 | + return tf.reshape(frames, [-1, feature_size]) | ||
94 | + else: | ||
95 | + raise ValueError("Unrecognized pooling method: %s" % method) |
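
A short sketch of how these helpers are typically wired together when building a frame-level model (again assuming TF 1.x graph mode); the shapes below are illustrative toy values, not the YT8M feature sizes.

```
import tensorflow as tf  # assumes TF 1.x graph mode

import model_utils

# Toy batch: 2 videos, up to 4 frames each, 3-dimensional frame features.
frames = tf.random_uniform([2, 4, 3])
num_frames = tf.constant([[4], [2]], dtype=tf.int32)

# Sample 3 frames per video, then average-pool the full sequences.
sampled = model_utils.SampleRandomFrames(frames, num_frames, num_samples=3)
pooled = model_utils.FramePooling(frames, "average")

with tf.Session() as sess:
  sampled_val, pooled_val = sess.run([sampled, pooled])
  print(sampled_val.shape)  # (2, 3, 3): 3 sampled frames per video
  print(pooled_val.shape)   # (2, 3): one pooled feature vector per video
```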
yt8m/models.py
0 → 100644
1 | +# Copyright 2016 Google Inc. All Rights Reserved. | ||
2 | +# | ||
3 | +# Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | +# you may not use this file except in compliance with the License. | ||
5 | +# You may obtain a copy of the License at | ||
6 | +# | ||
7 | +# http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | +# | ||
9 | +# Unless required by applicable law or agreed to in writing, software | ||
10 | +# distributed under the License is distributed on an "AS-IS" BASIS, | ||
11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | +# See the License for the specific language governing permissions and | ||
13 | +# limitations under the License. | ||
14 | +"""Contains the base class for models.""" | ||
15 | + | ||
16 | + | ||
17 | +class BaseModel(object): | ||
18 | + """Inherit from this class when implementing new models.""" | ||
19 | + | ||
20 | + def create_model(self, unused_model_input, **unused_params): | ||
21 | + raise NotImplementedError() |
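
For context, a hypothetical minimal subclass shows the create_model contract that concrete models in this repo (for example in video_level_models.py and frame_level_models.py) follow: take the input batch plus a vocab_size and return a dict containing a "predictions" tensor. The layer choice below is purely illustrative and not one of the models shipped here.

```
import tensorflow as tf  # assumes TF 1.x
import tensorflow.contrib.slim as slim

import models


class TinySigmoidModel(models.BaseModel):
  """Illustrative example only; not part of this repo's model zoo."""

  def create_model(self, model_input, vocab_size, **unused_params):
    # One fully connected layer with per-class sigmoid outputs.
    output = slim.fully_connected(model_input, vocab_size,
                                  activation_fn=tf.nn.sigmoid)
    return {"predictions": output}
```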
yt8m/readers.py
0 → 100644
1 | +# Copyright 2016 Google Inc. All Rights Reserved. | ||
2 | +# | ||
3 | +# Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | +# you may not use this file except in compliance with the License. | ||
5 | +# You may obtain a copy of the License at | ||
6 | +# | ||
7 | +# http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | +# | ||
9 | +# Unless required by applicable law or agreed to in writing, software | ||
10 | +# distributed under the License is distributed on an "AS-IS" BASIS, | ||
11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | +# See the License for the specific language governing permissions and | ||
13 | +# limitations under the License. | ||
14 | +"""Provides readers configured for different datasets.""" | ||
15 | + | ||
16 | +import tensorflow as tf | ||
17 | +import utils | ||
18 | + | ||
19 | + | ||
20 | +def resize_axis(tensor, axis, new_size, fill_value=0): | ||
21 | + """Truncates or pads a tensor to new_size on a given axis. | ||
22 | + | ||
23 | + Truncate or extend tensor such that tensor.shape[axis] == new_size. If the | ||
24 | + size increases, the padding will be performed at the end, using fill_value. | ||
25 | + | ||
26 | + Args: | ||
27 | + tensor: The tensor to be resized. | ||
28 | + axis: An integer representing the dimension to be sliced. | ||
29 | + new_size: An integer or 0d tensor representing the new value for | ||
30 | + tensor.shape[axis]. | ||
31 | + fill_value: Value to use to fill any new entries in the tensor. Will be cast | ||
32 | + to the type of tensor. | ||
33 | + | ||
34 | + Returns: | ||
35 | + The resized tensor. | ||
36 | + """ | ||
37 | + tensor = tf.convert_to_tensor(tensor) | ||
38 | + shape = tf.unstack(tf.shape(tensor)) | ||
39 | + | ||
40 | + pad_shape = shape[:] | ||
41 | + pad_shape[axis] = tf.maximum(0, new_size - shape[axis]) | ||
42 | + | ||
43 | + shape[axis] = tf.minimum(shape[axis], new_size) | ||
44 | + shape = tf.stack(shape) | ||
45 | + | ||
46 | + resized = tf.concat([ | ||
47 | + tf.slice(tensor, tf.zeros_like(shape), shape), | ||
48 | + tf.fill(tf.stack(pad_shape), tf.cast(fill_value, tensor.dtype)) | ||
49 | + ], axis) | ||
50 | + | ||
51 | + # Update shape. | ||
52 | + new_shape = tensor.get_shape().as_list() # A copy is being made. | ||
53 | + new_shape[axis] = new_size | ||
54 | + resized.set_shape(new_shape) | ||
55 | + return resized | ||
56 | + | ||
57 | + | ||
58 | +class BaseReader(object): | ||
59 | + """Inherit from this class when implementing new readers.""" | ||
60 | + | ||
61 | + def prepare_reader(self, unused_filename_queue): | ||
62 | + """Create a thread for generating prediction and label tensors.""" | ||
63 | + raise NotImplementedError() | ||
64 | + | ||
65 | + | ||
66 | +class YT8MAggregatedFeatureReader(BaseReader): | ||
67 | + """Reads TFRecords of pre-aggregated Examples. | ||
68 | + | ||
69 | + The TFRecords must contain Examples with a sparse int64 'labels' feature and | ||
70 | + a fixed length float32 feature, obtained from the features in 'feature_names'. | ||
71 | + The float features are assumed to be an average of dequantized values. | ||
72 | + """ | ||
73 | + | ||
74 | + def __init__( # pylint: disable=dangerous-default-value | ||
75 | + self, | ||
76 | + num_classes=3862, | ||
77 | + feature_sizes=[1024, 128], | ||
78 | + feature_names=["mean_rgb", "mean_audio"]): | ||
79 | + """Construct a YT8MAggregatedFeatureReader. | ||
80 | + | ||
81 | + Args: | ||
82 | + num_classes: a positive integer for the number of classes. | ||
83 | + feature_sizes: positive integer(s) for the feature dimensions as a list. | ||
84 | + feature_names: the feature name(s) in the tensorflow record as a list. | ||
85 | + """ | ||
86 | + | ||
87 | + assert len(feature_names) == len(feature_sizes), ( | ||
88 | + "length of feature_names (={}) != length of feature_sizes (={})".format( | ||
89 | + len(feature_names), len(feature_sizes))) | ||
90 | + | ||
91 | + self.num_classes = num_classes | ||
92 | + self.feature_sizes = feature_sizes | ||
93 | + self.feature_names = feature_names | ||
94 | + | ||
95 | + def prepare_reader(self, filename_queue, batch_size=1024): | ||
96 | + """Creates a single reader thread for pre-aggregated YouTube 8M Examples. | ||
97 | + | ||
98 | + Args: | ||
99 | + filename_queue: A tensorflow queue of filename locations. | ||
100 | + batch_size: batch size used for feature output. | ||
101 | + | ||
102 | + Returns: | ||
103 | + A dict of video indexes, features, labels, and frame counts. | ||
104 | + """ | ||
105 | + reader = tf.TFRecordReader() | ||
106 | + _, serialized_examples = reader.read_up_to(filename_queue, batch_size) | ||
107 | + | ||
108 | + tf.add_to_collection("serialized_examples", serialized_examples) | ||
109 | + return self.prepare_serialized_examples(serialized_examples) | ||
110 | + | ||
111 | + def prepare_serialized_examples(self, serialized_examples): | ||
112 | + """Parses a batch of serialized video-level TF Examples.""" | ||
113 | + # set the mapping from the fields to data types in the proto | ||
114 | + num_features = len(self.feature_names) | ||
115 | + assert num_features > 0, "self.feature_names is empty!" | ||
116 | + assert len(self.feature_names) == len(self.feature_sizes), \ | ||
117 | + "length of feature_names (={}) != length of feature_sizes (={})".format( | ||
118 | + len(self.feature_names), len(self.feature_sizes)) | ||
119 | + | ||
120 | + feature_map = { | ||
121 | + "id": tf.io.FixedLenFeature([], tf.string), | ||
122 | + "labels": tf.io.VarLenFeature(tf.int64) | ||
123 | + } | ||
124 | + for feature_index in range(num_features): | ||
125 | + feature_map[self.feature_names[feature_index]] = tf.FixedLenFeature( | ||
126 | + [self.feature_sizes[feature_index]], tf.float32) | ||
127 | + | ||
128 | + features = tf.parse_example(serialized_examples, features=feature_map) | ||
129 | + labels = tf.sparse_to_indicator(features["labels"], self.num_classes) | ||
130 | + labels.set_shape([None, self.num_classes]) | ||
131 | + concatenated_features = tf.concat( | ||
132 | + [features[feature_name] for feature_name in self.feature_names], 1) | ||
133 | + | ||
134 | + output_dict = { | ||
135 | + "video_ids": features["id"], | ||
136 | + "video_matrix": concatenated_features, | ||
137 | + "labels": labels, | ||
138 | + "num_frames": tf.ones([tf.shape(serialized_examples)[0]]) | ||
139 | + } | ||
140 | + | ||
141 | + return output_dict | ||
142 | + | ||
143 | + | ||
144 | +class YT8MFrameFeatureReader(BaseReader): | ||
145 | + """Reads TFRecords of SequenceExamples. | ||
146 | + | ||
147 | + The TFRecords must contain SequenceExamples with the sparse int64 'labels' | ||
148 | + context feature and a fixed length byte-quantized feature vector, obtained | ||
149 | + from the features in 'feature_names'. The quantized features will be mapped | ||
150 | + back into a range between min_quantized_value and max_quantized_value. | ||
151 | + """ | ||
152 | + | ||
153 | + def __init__( # pylint: disable=dangerous-default-value | ||
154 | + self, | ||
155 | + num_classes=3862, | ||
156 | + feature_sizes=[1024, 128], | ||
157 | + feature_names=["rgb", "audio"], | ||
158 | + max_frames=300, | ||
159 | + segment_labels=False, | ||
160 | + segment_size=5): | ||
161 | + """Construct a YT8MFrameFeatureReader. | ||
162 | + | ||
163 | + Args: | ||
164 | + num_classes: a positive integer for the number of classes. | ||
165 | + feature_sizes: positive integer(s) for the feature dimensions as a list. | ||
166 | + feature_names: the feature name(s) in the tensorflow record as a list. | ||
167 | + max_frames: the maximum number of frames to process. | ||
168 | + segment_labels: whether to read segment-level labels instead. | ||
169 | + segment_size: the segment_size used for reading segments. | ||
170 | + """ | ||
171 | + | ||
172 | + assert len(feature_names) == len(feature_sizes), ( | ||
173 | + "length of feature_names (={}) != length of feature_sizes (={})".format( | ||
174 | + len(feature_names), len(feature_sizes))) | ||
175 | + | ||
176 | + self.num_classes = num_classes | ||
177 | + self.feature_sizes = feature_sizes | ||
178 | + self.feature_names = feature_names | ||
179 | + self.max_frames = max_frames | ||
180 | + self.segment_labels = segment_labels | ||
181 | + self.segment_size = segment_size | ||
182 | + | ||
183 | + def get_video_matrix(self, features, feature_size, max_frames, | ||
184 | + max_quantized_value, min_quantized_value): | ||
185 | + """Decodes features from an input string and dequantizes them. | ||
186 | + | ||
187 | + Args: | ||
188 | + features: raw feature values | ||
189 | + feature_size: length of each frame feature vector | ||
190 | + max_frames: number of frames (rows) in the output feature_matrix | ||
191 | + max_quantized_value: the maximum of the quantized value. | ||
192 | + min_quantized_value: the minimum of the quantized value. | ||
193 | + | ||
194 | + Returns: | ||
195 | + feature_matrix: matrix of all frame-features | ||
196 | + num_frames: number of frames in the sequence | ||
197 | + """ | ||
198 | + decoded_features = tf.reshape( | ||
199 | + tf.cast(tf.decode_raw(features, tf.uint8), tf.float32), | ||
200 | + [-1, feature_size]) | ||
201 | + | ||
202 | + num_frames = tf.minimum(tf.shape(decoded_features)[0], max_frames) | ||
203 | + feature_matrix = utils.Dequantize(decoded_features, max_quantized_value, | ||
204 | + min_quantized_value) | ||
205 | + feature_matrix = resize_axis(feature_matrix, 0, max_frames) | ||
206 | + return feature_matrix, num_frames | ||
207 | + | ||
208 | + def prepare_reader(self, | ||
209 | + filename_queue, | ||
210 | + max_quantized_value=2, | ||
211 | + min_quantized_value=-2): | ||
212 | + """Creates a single reader thread for YouTube8M SequenceExamples. | ||
213 | + | ||
214 | + Args: | ||
215 | + filename_queue: A tensorflow queue of filename locations. | ||
216 | + max_quantized_value: the maximum of the quantized value. | ||
217 | + min_quantized_value: the minimum of the quantized value. | ||
218 | + | ||
219 | + Returns: | ||
220 | + A dict of video indexes, video features, labels, and frame counts. | ||
221 | + """ | ||
222 | + reader = tf.TFRecordReader() | ||
223 | + _, serialized_example = reader.read(filename_queue) | ||
224 | + | ||
225 | + return self.prepare_serialized_examples(serialized_example, | ||
226 | + max_quantized_value, | ||
227 | + min_quantized_value) | ||
228 | + | ||
229 | + def prepare_serialized_examples(self, | ||
230 | + serialized_example, | ||
231 | + max_quantized_value=2, | ||
232 | + min_quantized_value=-2): | ||
233 | + """Parse single serialized SequenceExample from the TFRecords.""" | ||
234 | + | ||
235 | + # Read/parse frame/segment-level labels. | ||
236 | + context_features = { | ||
237 | + "id": tf.io.FixedLenFeature([], tf.string), | ||
238 | + } | ||
239 | + if self.segment_labels: | ||
240 | + context_features.update({ | ||
241 | + # There is no need to read the end time, since we assume all segments | ||
242 | + # have the same size. | ||
243 | + "segment_labels": tf.io.VarLenFeature(tf.int64), | ||
244 | + "segment_start_times": tf.io.VarLenFeature(tf.int64), | ||
245 | + "segment_scores": tf.io.VarLenFeature(tf.float32) | ||
246 | + }) | ||
247 | + else: | ||
248 | + context_features.update({"labels": tf.io.VarLenFeature(tf.int64)}) | ||
249 | + sequence_features = { | ||
250 | + feature_name: tf.io.FixedLenSequenceFeature([], dtype=tf.string) | ||
251 | + for feature_name in self.feature_names | ||
252 | + } | ||
253 | + contexts, features = tf.io.parse_single_sequence_example( | ||
254 | + serialized_example, | ||
255 | + context_features=context_features, | ||
256 | + sequence_features=sequence_features) | ||
257 | + | ||
258 | + # loads (potentially) different types of features and concatenates them | ||
259 | + num_features = len(self.feature_names) | ||
260 | + assert num_features > 0, "No feature selected: feature_names is empty!" | ||
261 | + | ||
262 | + assert len(self.feature_names) == len(self.feature_sizes), ( | ||
263 | + "length of feature_names (={}) != length of feature_sizes (={})".format( | ||
264 | + len(self.feature_names), len(self.feature_sizes))) | ||
265 | + | ||
266 | + num_frames = -1 # the number of frames in the video | ||
267 | + feature_matrices = [None] * num_features # an array of different features | ||
268 | + for feature_index in range(num_features): | ||
269 | + feature_matrix, num_frames_in_this_feature = self.get_video_matrix( | ||
270 | + features[self.feature_names[feature_index]], | ||
271 | + self.feature_sizes[feature_index], self.max_frames, | ||
272 | + max_quantized_value, min_quantized_value) | ||
273 | + if num_frames == -1: | ||
274 | + num_frames = num_frames_in_this_feature | ||
275 | + | ||
276 | + feature_matrices[feature_index] = feature_matrix | ||
277 | + | ||
278 | + # cap the number of frames at self.max_frames | ||
279 | + num_frames = tf.minimum(num_frames, self.max_frames) | ||
280 | + | ||
281 | + # concatenate different features | ||
282 | + video_matrix = tf.concat(feature_matrices, 1) | ||
283 | + | ||
284 | + # Partition frame-level feature matrix to segment-level feature matrix. | ||
285 | + if self.segment_labels: | ||
286 | + start_times = contexts["segment_start_times"].values | ||
287 | + # Here we assume all the segments that start at the same start time have | ||
288 | + # the same segment_size. | ||
289 | + uniq_start_times, seg_idxs = tf.unique(start_times, | ||
290 | + out_idx=tf.dtypes.int64) | ||
291 | + # TODO(zhengxu): Ensure the segment_sizes are all same. | ||
292 | + segment_size = self.segment_size | ||
293 | + # Range gather matrix, e.g., [[0,1,2],[1,2,3]] for segment_size == 3. | ||
294 | + range_mtx = tf.expand_dims(uniq_start_times, axis=-1) + tf.expand_dims( | ||
295 | + tf.range(0, segment_size, dtype=tf.int64), axis=0) | ||
296 | + # Shape: [num_segment, segment_size, feature_dim]. | ||
297 | + batch_video_matrix = tf.gather_nd(video_matrix, | ||
298 | + tf.expand_dims(range_mtx, axis=-1)) | ||
299 | + num_segment = tf.shape(batch_video_matrix)[0] | ||
300 | + batch_video_ids = tf.reshape(tf.tile([contexts["id"]], [num_segment]), | ||
301 | + (num_segment,)) | ||
302 | + batch_frames = tf.reshape(tf.tile([segment_size], [num_segment]), | ||
303 | + (num_segment,)) | ||
304 | + | ||
305 | + # For segment labels, not all labels are exhaustively rated, so we only | ||
306 | + # evaluate the rated labels. | ||
307 | + | ||
308 | + # Label indices for each segment, shape: [num_segment, 2]. | ||
309 | + label_indices = tf.stack([seg_idxs, contexts["segment_labels"].values], | ||
310 | + axis=-1) | ||
311 | + label_values = contexts["segment_scores"].values | ||
312 | + sparse_labels = tf.sparse.SparseTensor(label_indices, label_values, | ||
313 | + (num_segment, self.num_classes)) | ||
314 | + batch_labels = tf.sparse.to_dense(sparse_labels, validate_indices=False) | ||
315 | + | ||
316 | + sparse_label_weights = tf.sparse.SparseTensor( | ||
317 | + label_indices, tf.ones_like(label_values, dtype=tf.float32), | ||
318 | + (num_segment, self.num_classes)) | ||
319 | + batch_label_weights = tf.sparse.to_dense(sparse_label_weights, | ||
320 | + validate_indices=False) | ||
321 | + else: | ||
322 | + # Process video-level labels. | ||
323 | + label_indices = contexts["labels"].values | ||
324 | + sparse_labels = tf.sparse.SparseTensor( | ||
325 | + tf.expand_dims(label_indices, axis=-1), | ||
326 | + tf.ones_like(contexts["labels"].values, dtype=tf.bool), | ||
327 | + (self.num_classes,)) | ||
328 | + labels = tf.sparse.to_dense(sparse_labels, | ||
329 | + default_value=False, | ||
330 | + validate_indices=False) | ||
331 | + # convert to batch format. | ||
332 | + batch_video_ids = tf.expand_dims(contexts["id"], 0) | ||
333 | + batch_video_matrix = tf.expand_dims(video_matrix, 0) | ||
334 | + batch_labels = tf.expand_dims(labels, 0) | ||
335 | + batch_frames = tf.expand_dims(num_frames, 0) | ||
336 | + batch_label_weights = None | ||
337 | + | ||
338 | + output_dict = { | ||
339 | + "video_ids": batch_video_ids, | ||
340 | + "video_matrix": batch_video_matrix, | ||
341 | + "labels": batch_labels, | ||
342 | + "num_frames": batch_frames, | ||
343 | + } | ||
344 | + if batch_label_weights is not None: | ||
345 | + output_dict["label_weights"] = batch_label_weights | ||
346 | + | ||
347 | + return output_dict |
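
The behavior of resize_axis is easiest to see on a tiny example. The sketch below assumes TF 1.x graph mode: truncation keeps the leading slice along the axis, and padding appends fill_value entries at the end.

```
import tensorflow as tf  # assumes TF 1.x graph mode

import readers

frames = tf.constant([[1., 1.], [2., 2.], [3., 3.]])  # 3 frames, 2 features

truncated = readers.resize_axis(frames, axis=0, new_size=2)
padded = readers.resize_axis(frames, axis=0, new_size=5)

with tf.Session() as sess:
  print(sess.run(truncated))  # [[1. 1.] [2. 2.]]
  print(sess.run(padded))     # [[1. 1.] [2. 2.] [3. 3.] [0. 0.] [0. 0.]]
```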
yt8m/segment_eval_inference.py
0 → 100644
1 | +"""Eval mAP@N metric from inference file.""" | ||
2 | + | ||
3 | +from __future__ import absolute_import | ||
4 | +from __future__ import division | ||
5 | +from __future__ import print_function | ||
6 | + | ||
7 | +from absl import app | ||
8 | +from absl import flags | ||
9 | + | ||
10 | +import mean_average_precision_calculator as map_calculator | ||
11 | +import numpy as np | ||
12 | +import tensorflow as tf | ||
13 | + | ||
14 | +flags.DEFINE_string( | ||
15 | + "eval_data_pattern", "", | ||
16 | + "File glob defining the evaluation dataset in tensorflow.SequenceExample " | ||
17 | + "format. The SequenceExamples are expected to have an 'rgb' byte array " | ||
18 | + "sequence feature as well as a 'labels' int64 context feature.") | ||
19 | +flags.DEFINE_string( | ||
20 | + "label_cache", "", | ||
21 | + "The path for the label cache file. Leave blank to disable caching.") | ||
22 | +flags.DEFINE_string("submission_file", "", | ||
23 | + "The segment submission file generated by inference.py.") | ||
24 | +flags.DEFINE_integer( | ||
25 | + "top_n", 0, | ||
26 | + "Cap per-class predictions at a maximum of N. Use 0 for no capping.") | ||
27 | + | ||
28 | +FLAGS = flags.FLAGS | ||
29 | + | ||
30 | + | ||
31 | +class Labels(object): | ||
32 | + """Holds the ground truth label objects. | ||
33 | + | ||
34 | + This class can serialize and de-serialize the ground truths. | ||
35 | + The ground truth is in a mapping from (segment_id, class_id) -> label_score. | ||
36 | + """ | ||
37 | + | ||
38 | + def __init__(self, labels): | ||
39 | + """__init__ method.""" | ||
40 | + self._labels = labels | ||
41 | + | ||
42 | + @property | ||
43 | + def labels(self): | ||
44 | + """Return the ground truth mapping. See class docstring for details.""" | ||
45 | + return self._labels | ||
46 | + | ||
47 | + def to_file(self, file_name): | ||
48 | + """Materialize the GT mapping to file.""" | ||
49 | + with tf.gfile.Open(file_name, "w") as fobj: | ||
50 | + for k, v in self._labels.items(): | ||
51 | + seg_id, label = k | ||
52 | + line = "%s,%s,%s\n" % (seg_id, label, v) | ||
53 | + fobj.write(line) | ||
54 | + | ||
55 | + @classmethod | ||
56 | + def from_file(cls, file_name): | ||
57 | + """Read the GT mapping from cached file.""" | ||
58 | + labels = {} | ||
59 | + with tf.gfile.Open(file_name) as fobj: | ||
60 | + for line in fobj: | ||
61 | + line = line.strip().strip("\n") | ||
62 | + seg_id, label, score = line.split(",") | ||
63 | + labels[(seg_id, int(label))] = float(score) | ||
64 | + return cls(labels) | ||
65 | + | ||
66 | + | ||
67 | +def read_labels(data_pattern, cache_path=""): | ||
68 | + """Read labels from TFRecords. | ||
69 | + | ||
70 | + Args: | ||
71 | + data_pattern: the data pattern to the TFRecords. | ||
72 | + cache_path: the cache path for the label file. | ||
73 | + | ||
74 | + Returns: | ||
75 | + a Labels object. | ||
76 | + """ | ||
77 | + if cache_path: | ||
78 | + if tf.gfile.Exists(cache_path): | ||
79 | + tf.logging.info("Reading cached labels from %s..." % cache_path) | ||
80 | + return Labels.from_file(cache_path) | ||
81 | + tf.enable_eager_execution() | ||
82 | + data_paths = tf.gfile.Glob(data_pattern) | ||
83 | + ds = tf.data.TFRecordDataset(data_paths, num_parallel_reads=50) | ||
84 | + context_features = { | ||
85 | + "id": tf.FixedLenFeature([], tf.string), | ||
86 | + "segment_labels": tf.VarLenFeature(tf.int64), | ||
87 | + "segment_start_times": tf.VarLenFeature(tf.int64), | ||
88 | + "segment_scores": tf.VarLenFeature(tf.float32) | ||
89 | + } | ||
90 | + | ||
91 | + def _parse_se_func(sequence_example): | ||
92 | + return tf.parse_single_sequence_example(sequence_example, | ||
93 | + context_features=context_features) | ||
94 | + | ||
95 | + ds = ds.map(_parse_se_func) | ||
96 | + rated_labels = {} | ||
97 | + tf.logging.info("Reading labels from TFRecords...") | ||
98 | + last_batch = 0 | ||
99 | + batch_size = 5000 | ||
100 | + for cxt_feature_val, _ in ds: | ||
101 | + video_id = cxt_feature_val["id"].numpy() | ||
102 | + segment_labels = cxt_feature_val["segment_labels"].values.numpy() | ||
103 | + segment_start_times = cxt_feature_val["segment_start_times"].values.numpy() | ||
104 | + segment_scores = cxt_feature_val["segment_scores"].values.numpy() | ||
105 | + for label, start_time, score in zip(segment_labels, segment_start_times, | ||
106 | + segment_scores): | ||
107 | + rated_labels[("%s:%d" % (video_id, start_time), label)] = score | ||
108 | + batch_id = len(rated_labels) // batch_size | ||
109 | + if batch_id != last_batch: | ||
110 | + tf.logging.info("%d examples processed.", len(rated_labels)) | ||
111 | + last_batch = batch_id | ||
112 | + tf.logging.info("Finished reading labels from TFRecords.") | ||
113 | + labels_obj = Labels(rated_labels) | ||
114 | + if cache_path: | ||
115 | + tf.logging.info("Caching labels to %s..." % cache_path) | ||
116 | + labels_obj.to_file(cache_path) | ||
117 | + return labels_obj | ||
118 | + | ||
119 | + | ||
120 | +def read_segment_predictions(file_path, labels, top_n=None): | ||
121 | + """Read segment predictions. | ||
122 | + | ||
123 | + Args: | ||
124 | + file_path: the submission file path. | ||
125 | + labels: a Labels object containing the eval labels. | ||
126 | + top_n: the per-class prediction cap. | ||
127 | + | ||
128 | + Returns: | ||
129 | + a segment prediction list for each class. | ||
130 | + """ | ||
131 | + cls_preds = {} # A label_id to pred list mapping. | ||
132 | + with tf.gfile.Open(file_path) as fobj: | ||
133 | + tf.logging.info("Reading predictions from %s..." % file_path) | ||
134 | + for line in fobj: | ||
135 | + label_id, pred_ids_val = line.split(",") | ||
136 | + pred_ids = pred_ids_val.split(" ") | ||
137 | + if top_n: | ||
138 | + pred_ids = pred_ids[:top_n] | ||
139 | + pred_ids = [ | ||
140 | + pred_id for pred_id in pred_ids | ||
141 | + if (pred_id, int(label_id)) in labels.labels | ||
142 | + ] | ||
143 | + cls_preds[int(label_id)] = pred_ids | ||
144 | + if len(cls_preds) % 50 == 0: | ||
145 | + tf.logging.info("Processed %d classes..." % len(cls_preds)) | ||
146 | + tf.logging.info("Finished reading predictions.") | ||
147 | + return cls_preds | ||
148 | + | ||
149 | + | ||
150 | +def main(unused_argv): | ||
151 | + """Entry function of the script.""" | ||
152 | + if not FLAGS.submission_file: | ||
153 | + raise ValueError("You must provide a submission file.") | ||
154 | + eval_labels = read_labels(FLAGS.eval_data_pattern, | ||
155 | + cache_path=FLAGS.label_cache) | ||
156 | + tf.logging.info("Total rated segments: %d." % len(eval_labels.labels)) | ||
157 | + positive_counter = {} | ||
158 | + for k, v in eval_labels.labels.items(): | ||
159 | + _, label_id = k | ||
160 | + if v > 0: | ||
161 | + positive_counter[label_id] = positive_counter.get(label_id, 0) + 1 | ||
162 | + | ||
163 | + seg_preds = read_segment_predictions(FLAGS.submission_file, | ||
164 | + eval_labels, | ||
165 | + top_n=FLAGS.top_n) | ||
166 | + map_cal = map_calculator.MeanAveragePrecisionCalculator(len(seg_preds)) | ||
167 | + seg_labels = [] | ||
168 | + seg_scored_preds = [] | ||
169 | + num_positives = [] | ||
170 | + for label_id in sorted(seg_preds): | ||
171 | + class_preds = seg_preds[label_id] | ||
172 | + seg_label = [eval_labels.labels[(pred, label_id)] for pred in class_preds] | ||
173 | + seg_labels.append(seg_label) | ||
174 | + seg_scored_pred = [] | ||
175 | + if class_preds: | ||
176 | + seg_scored_pred = [ | ||
177 | + float(x) / len(class_preds) for x in range(len(class_preds), 0, -1) | ||
178 | + ] | ||
179 | + seg_scored_preds.append(seg_scored_pred) | ||
180 | + num_positives.append(positive_counter[label_id]) | ||
181 | + map_cal.accumulate(seg_scored_preds, seg_labels, num_positives) | ||
182 | + map_at_n = np.mean(map_cal.peek_map_at_n()) | ||
183 | + tf.logging.info("Num classes: %d | mAP@%d: %.6f" % | ||
184 | + (len(seg_preds), FLAGS.top_n, map_at_n)) | ||
185 | + | ||
186 | + | ||
187 | +if __name__ == "__main__": | ||
188 | + app.run(main) |
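
One detail worth spelling out: because the submission file contains only a ranked list of segment ids per class (no scores), the script synthesizes descending pseudo-scores so that the average precision calculator sees the same ordering. A tiny sketch of that scoring, using a hypothetical ranked list:

```
# Hypothetical ranked segment ids for one class, best first.
class_preds = ["vid1:0", "vid2:5", "vid1:10", "vid3:0"]

# The first-ranked segment gets score 1.0 and the last gets 1/N,
# which preserves the submitted ranking.
scores = [float(x) / len(class_preds) for x in range(len(class_preds), 0, -1)]
print(scores)  # [1.0, 0.75, 0.5, 0.25]
```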
yt8m/segment_label_ids.csv
0 → 100644
1 | +Index | ||
2 | +3 | ||
3 | +7 | ||
4 | +8 | ||
5 | +11 | ||
6 | +12 | ||
7 | +17 | ||
8 | +18 | ||
9 | +19 | ||
10 | +21 | ||
11 | +22 | ||
12 | +23 | ||
13 | +28 | ||
14 | +31 | ||
15 | +30 | ||
16 | +32 | ||
17 | +33 | ||
18 | +34 | ||
19 | +41 | ||
20 | +43 | ||
21 | +45 | ||
22 | +46 | ||
23 | +48 | ||
24 | +53 | ||
25 | +54 | ||
26 | +52 | ||
27 | +55 | ||
28 | +58 | ||
29 | +59 | ||
30 | +60 | ||
31 | +61 | ||
32 | +65 | ||
33 | +68 | ||
34 | +73 | ||
35 | +71 | ||
36 | +74 | ||
37 | +75 | ||
38 | +76 | ||
39 | +77 | ||
40 | +80 | ||
41 | +83 | ||
42 | +90 | ||
43 | +88 | ||
44 | +89 | ||
45 | +92 | ||
46 | +95 | ||
47 | +100 | ||
48 | +101 | ||
49 | +99 | ||
50 | +104 | ||
51 | +105 | ||
52 | +109 | ||
53 | +113 | ||
54 | +112 | ||
55 | +115 | ||
56 | +116 | ||
57 | +118 | ||
58 | +120 | ||
59 | +121 | ||
60 | +123 | ||
61 | +125 | ||
62 | +127 | ||
63 | +131 | ||
64 | +128 | ||
65 | +129 | ||
66 | +130 | ||
67 | +137 | ||
68 | +141 | ||
69 | +143 | ||
70 | +145 | ||
71 | +148 | ||
72 | +152 | ||
73 | +151 | ||
74 | +156 | ||
75 | +155 | ||
76 | +158 | ||
77 | +160 | ||
78 | +164 | ||
79 | +163 | ||
80 | +169 | ||
81 | +170 | ||
82 | +172 | ||
83 | +171 | ||
84 | +173 | ||
85 | +174 | ||
86 | +175 | ||
87 | +176 | ||
88 | +178 | ||
89 | +182 | ||
90 | +184 | ||
91 | +186 | ||
92 | +188 | ||
93 | +187 | ||
94 | +192 | ||
95 | +191 | ||
96 | +190 | ||
97 | +194 | ||
98 | +197 | ||
99 | +196 | ||
100 | +198 | ||
101 | +201 | ||
102 | +202 | ||
103 | +200 | ||
104 | +199 | ||
105 | +205 | ||
106 | +204 | ||
107 | +209 | ||
108 | +207 | ||
109 | +206 | ||
110 | +210 | ||
111 | +213 | ||
112 | +214 | ||
113 | +220 | ||
114 | +218 | ||
115 | +217 | ||
116 | +226 | ||
117 | +227 | ||
118 | +231 | ||
119 | +232 | ||
120 | +229 | ||
121 | +233 | ||
122 | +235 | ||
123 | +237 | ||
124 | +244 | ||
125 | +240 | ||
126 | +249 | ||
127 | +246 | ||
128 | +248 | ||
129 | +239 | ||
130 | +250 | ||
131 | +245 | ||
132 | +255 | ||
133 | +253 | ||
134 | +256 | ||
135 | +261 | ||
136 | +259 | ||
137 | +263 | ||
138 | +262 | ||
139 | +266 | ||
140 | +267 | ||
141 | +268 | ||
142 | +269 | ||
143 | +271 | ||
144 | +276 | ||
145 | +273 | ||
146 | +277 | ||
147 | +274 | ||
148 | +278 | ||
149 | +279 | ||
150 | +280 | ||
151 | +288 | ||
152 | +291 | ||
153 | +295 | ||
154 | +294 | ||
155 | +293 | ||
156 | +297 | ||
157 | +296 | ||
158 | +300 | ||
159 | +299 | ||
160 | +303 | ||
161 | +302 | ||
162 | +304 | ||
163 | +305 | ||
164 | +313 | ||
165 | +307 | ||
166 | +311 | ||
167 | +310 | ||
168 | +312 | ||
169 | +316 | ||
170 | +318 | ||
171 | +321 | ||
172 | +322 | ||
173 | +331 | ||
174 | +333 | ||
175 | +329 | ||
176 | +330 | ||
177 | +334 | ||
178 | +343 | ||
179 | +349 | ||
180 | +340 | ||
181 | +344 | ||
182 | +348 | ||
183 | +358 | ||
184 | +347 | ||
185 | +359 | ||
186 | +355 | ||
187 | +361 | ||
188 | +360 | ||
189 | +364 | ||
190 | +365 | ||
191 | +368 | ||
192 | +369 | ||
193 | +366 | ||
194 | +370 | ||
195 | +374 | ||
196 | +380 | ||
197 | +373 | ||
198 | +385 | ||
199 | +384 | ||
200 | +388 | ||
201 | +389 | ||
202 | +382 | ||
203 | +393 | ||
204 | +381 | ||
205 | +390 | ||
206 | +394 | ||
207 | +399 | ||
208 | +397 | ||
209 | +396 | ||
210 | +402 | ||
211 | +400 | ||
212 | +398 | ||
213 | +401 | ||
214 | +405 | ||
215 | +406 | ||
216 | +410 | ||
217 | +408 | ||
218 | +416 | ||
219 | +415 | ||
220 | +419 | ||
221 | +422 | ||
222 | +414 | ||
223 | +421 | ||
224 | +424 | ||
225 | +429 | ||
226 | +418 | ||
227 | +427 | ||
228 | +434 | ||
229 | +428 | ||
230 | +435 | ||
231 | +430 | ||
232 | +441 | ||
233 | +439 | ||
234 | +437 | ||
235 | +443 | ||
236 | +440 | ||
237 | +442 | ||
238 | +445 | ||
239 | +446 | ||
240 | +448 | ||
241 | +454 | ||
242 | +444 | ||
243 | +453 | ||
244 | +455 | ||
245 | +451 | ||
246 | +452 | ||
247 | +458 | ||
248 | +460 | ||
249 | +465 | ||
250 | +457 | ||
251 | +463 | ||
252 | +462 | ||
253 | +461 | ||
254 | +464 | ||
255 | +469 | ||
256 | +468 | ||
257 | +472 | ||
258 | +473 | ||
259 | +471 | ||
260 | +475 | ||
261 | +474 | ||
262 | +477 | ||
263 | +485 | ||
264 | +491 | ||
265 | +488 | ||
266 | +482 | ||
267 | +490 | ||
268 | +496 | ||
269 | +494 | ||
270 | +483 | ||
271 | +495 | ||
272 | +493 | ||
273 | +507 | ||
274 | +501 | ||
275 | +499 | ||
276 | +503 | ||
277 | +498 | ||
278 | +514 | ||
279 | +504 | ||
280 | +502 | ||
281 | +506 | ||
282 | +508 | ||
283 | +511 | ||
284 | +527 | ||
285 | +526 | ||
286 | +532 | ||
287 | +513 | ||
288 | +519 | ||
289 | +525 | ||
290 | +518 | ||
291 | +528 | ||
292 | +522 | ||
293 | +523 | ||
294 | +535 | ||
295 | +539 | ||
296 | +540 | ||
297 | +533 | ||
298 | +521 | ||
299 | +541 | ||
300 | +547 | ||
301 | +550 | ||
302 | +544 | ||
303 | +549 | ||
304 | +551 | ||
305 | +554 | ||
306 | +543 | ||
307 | +548 | ||
308 | +557 | ||
309 | +560 | ||
310 | +552 | ||
311 | +559 | ||
312 | +563 | ||
313 | +565 | ||
314 | +567 | ||
315 | +555 | ||
316 | +576 | ||
317 | +568 | ||
318 | +564 | ||
319 | +573 | ||
320 | +581 | ||
321 | +580 | ||
322 | +572 | ||
323 | +571 | ||
324 | +584 | ||
325 | +590 | ||
326 | +585 | ||
327 | +587 | ||
328 | +588 | ||
329 | +592 | ||
330 | +598 | ||
331 | +597 | ||
332 | +599 | ||
333 | +603 | ||
334 | +600 | ||
335 | +604 | ||
336 | +605 | ||
337 | +614 | ||
338 | +602 | ||
339 | +610 | ||
340 | +608 | ||
341 | +611 | ||
342 | +612 | ||
343 | +613 | ||
344 | +617 | ||
345 | +620 | ||
346 | +607 | ||
347 | +624 | ||
348 | +627 | ||
349 | +625 | ||
350 | +631 | ||
351 | +629 | ||
352 | +638 | ||
353 | +632 | ||
354 | +634 | ||
355 | +644 | ||
356 | +641 | ||
357 | +642 | ||
358 | +646 | ||
359 | +652 | ||
360 | +647 | ||
361 | +637 | ||
362 | +661 | ||
363 | +635 | ||
364 | +658 | ||
365 | +648 | ||
366 | +663 | ||
367 | +668 | ||
368 | +664 | ||
369 | +656 | ||
370 | +666 | ||
371 | +671 | ||
372 | +683 | ||
373 | +675 | ||
374 | +669 | ||
375 | +676 | ||
376 | +667 | ||
377 | +691 | ||
378 | +685 | ||
379 | +673 | ||
380 | +688 | ||
381 | +702 | ||
382 | +684 | ||
383 | +679 | ||
384 | +694 | ||
385 | +686 | ||
386 | +689 | ||
387 | +680 | ||
388 | +693 | ||
389 | +703 | ||
390 | +697 | ||
391 | +698 | ||
392 | +692 | ||
393 | +705 | ||
394 | +706 | ||
395 | +712 | ||
396 | +711 | ||
397 | +709 | ||
398 | +710 | ||
399 | +726 | ||
400 | +713 | ||
401 | +721 | ||
402 | +720 | ||
403 | +715 | ||
404 | +717 | ||
405 | +730 | ||
406 | +728 | ||
407 | +723 | ||
408 | +716 | ||
409 | +722 | ||
410 | +718 | ||
411 | +732 | ||
412 | +724 | ||
413 | +736 | ||
414 | +725 | ||
415 | +742 | ||
416 | +727 | ||
417 | +735 | ||
418 | +740 | ||
419 | +748 | ||
420 | +738 | ||
421 | +746 | ||
422 | +751 | ||
423 | +749 | ||
424 | +752 | ||
425 | +754 | ||
426 | +760 | ||
427 | +763 | ||
428 | +756 | ||
429 | +758 | ||
430 | +766 | ||
431 | +764 | ||
432 | +757 | ||
433 | +780 | ||
434 | +767 | ||
435 | +769 | ||
436 | +771 | ||
437 | +786 | ||
438 | +785 | ||
439 | +781 | ||
440 | +787 | ||
441 | +778 | ||
442 | +783 | ||
443 | +792 | ||
444 | +791 | ||
445 | +795 | ||
446 | +788 | ||
447 | +805 | ||
448 | +802 | ||
449 | +801 | ||
450 | +793 | ||
451 | +796 | ||
452 | +804 | ||
453 | +803 | ||
454 | +797 | ||
455 | +814 | ||
456 | +813 | ||
457 | +789 | ||
458 | +808 | ||
459 | +818 | ||
460 | +816 | ||
461 | +817 | ||
462 | +811 | ||
463 | +820 | ||
464 | +826 | ||
465 | +829 | ||
466 | +824 | ||
467 | +821 | ||
468 | +825 | ||
469 | +822 | ||
470 | +835 | ||
471 | +833 | ||
472 | +843 | ||
473 | +823 | ||
474 | +827 | ||
475 | +830 | ||
476 | +832 | ||
477 | +837 | ||
478 | +852 | ||
479 | +844 | ||
480 | +841 | ||
481 | +812 | ||
482 | +847 | ||
483 | +862 | ||
484 | +869 | ||
485 | +860 | ||
486 | +838 | ||
487 | +870 | ||
488 | +846 | ||
489 | +858 | ||
490 | +854 | ||
491 | +880 | ||
492 | +876 | ||
493 | +857 | ||
494 | +859 | ||
495 | +877 | ||
496 | +871 | ||
497 | +855 | ||
498 | +875 | ||
499 | +861 | ||
500 | +867 | ||
501 | +892 | ||
502 | +898 | ||
503 | +888 | ||
504 | +884 | ||
505 | +887 | ||
506 | +891 | ||
507 | +906 | ||
508 | +900 | ||
509 | +878 | ||
510 | +885 | ||
511 | +883 | ||
512 | +901 | ||
513 | +903 | ||
514 | +907 | ||
515 | +930 | ||
516 | +897 | ||
517 | +914 | ||
518 | +917 | ||
519 | +910 | ||
520 | +905 | ||
521 | +909 | ||
522 | +933 | ||
523 | +932 | ||
524 | +922 | ||
525 | +913 | ||
526 | +923 | ||
527 | +931 | ||
528 | +911 | ||
529 | +937 | ||
530 | +918 | ||
531 | +955 | ||
532 | +915 | ||
533 | +944 | ||
534 | +952 | ||
535 | +945 | ||
536 | +948 | ||
537 | +946 | ||
538 | +970 | ||
539 | +974 | ||
540 | +958 | ||
541 | +925 | ||
542 | +979 | ||
543 | +942 | ||
544 | +965 | ||
545 | +975 | ||
546 | +950 | ||
547 | +982 | ||
548 | +940 | ||
549 | +973 | ||
550 | +962 | ||
551 | +972 | ||
552 | +957 | ||
553 | +984 | ||
554 | +983 | ||
555 | +964 | ||
556 | +1007 | ||
557 | +971 | ||
558 | +981 | ||
559 | +954 | ||
560 | +993 | ||
561 | +991 | ||
562 | +996 | ||
563 | +1005 | ||
564 | +1015 | ||
565 | +1009 | ||
566 | +995 | ||
567 | +986 | ||
568 | +1000 | ||
569 | +985 | ||
570 | +980 | ||
571 | +1016 | ||
572 | +1011 | ||
573 | +999 | ||
574 | +1002 | ||
575 | +994 | ||
576 | +1013 | ||
577 | +1010 | ||
578 | +992 | ||
579 | +1008 | ||
580 | +1036 | ||
581 | +1025 | ||
582 | +1012 | ||
583 | +990 | ||
584 | +1037 | ||
585 | +1040 | ||
586 | +1031 | ||
587 | +1019 | ||
588 | +1052 | ||
589 | +1001 | ||
590 | +1055 | ||
591 | +1032 | ||
592 | +1069 | ||
593 | +1058 | ||
594 | +1014 | ||
595 | +1023 | ||
596 | +1030 | ||
597 | +1061 | ||
598 | +1035 | ||
599 | +1034 | ||
600 | +1053 | ||
601 | +1045 | ||
602 | +1046 | ||
603 | +1067 | ||
604 | +1060 | ||
605 | +1049 | ||
606 | +1056 | ||
607 | +1074 | ||
608 | +1066 | ||
609 | +1044 | ||
610 | +1038 | ||
611 | +1073 | ||
612 | +1077 | ||
613 | +1068 | ||
614 | +1057 | ||
615 | +1072 | ||
616 | +1104 | ||
617 | +1083 | ||
618 | +1089 | ||
619 | +1087 | ||
620 | +1099 | ||
621 | +1076 | ||
622 | +1086 | ||
623 | +1098 | ||
624 | +1094 | ||
625 | +1095 | ||
626 | +1096 | ||
627 | +1101 | ||
628 | +1107 | ||
629 | +1105 | ||
630 | +1117 | ||
631 | +1093 | ||
632 | +1106 | ||
633 | +1122 | ||
634 | +1119 | ||
635 | +1103 | ||
636 | +1128 | ||
637 | +1120 | ||
638 | +1126 | ||
639 | +1102 | ||
640 | +1115 | ||
641 | +1124 | ||
642 | +1123 | ||
643 | +1131 | ||
644 | +1136 | ||
645 | +1144 | ||
646 | +1121 | ||
647 | +1137 | ||
648 | +1132 | ||
649 | +1133 | ||
650 | +1157 | ||
651 | +1134 | ||
652 | +1143 | ||
653 | +1159 | ||
654 | +1164 | ||
655 | +1155 | ||
656 | +1142 | ||
657 | +1150 | ||
658 | +1148 | ||
659 | +1161 | ||
660 | +1165 | ||
661 | +1147 | ||
662 | +1162 | ||
663 | +1152 | ||
664 | +1174 | ||
665 | +1160 | ||
666 | +1166 | ||
667 | +1190 | ||
668 | +1175 | ||
669 | +1167 | ||
670 | +1156 | ||
671 | +1180 | ||
672 | +1171 | ||
673 | +1179 | ||
674 | +1172 | ||
675 | +1186 | ||
676 | +1188 | ||
677 | +1201 | ||
678 | +1177 | ||
679 | +1208 | ||
680 | +1183 | ||
681 | +1189 | ||
682 | +1192 | ||
683 | +1209 | ||
684 | +1214 | ||
685 | +1197 | ||
686 | +1168 | ||
687 | +1202 | ||
688 | +1205 | ||
689 | +1203 | ||
690 | +1199 | ||
691 | +1219 | ||
692 | +1217 | ||
693 | +1187 | ||
694 | +1206 | ||
695 | +1210 | ||
696 | +1241 | ||
697 | +1221 | ||
698 | +1218 | ||
699 | +1223 | ||
700 | +1236 | ||
701 | +1212 | ||
702 | +1237 | ||
703 | +1195 | ||
704 | +1216 | ||
705 | +1247 | ||
706 | +1234 | ||
707 | +1240 | ||
708 | +1257 | ||
709 | +1224 | ||
710 | +1243 | ||
711 | +1259 | ||
712 | +1242 | ||
713 | +1282 | ||
714 | +1222 | ||
715 | +1254 | ||
716 | +1227 | ||
717 | +1235 | ||
718 | +1269 | ||
719 | +1258 | ||
720 | +1290 | ||
721 | +1275 | ||
722 | +1262 | ||
723 | +1252 | ||
724 | +1248 | ||
725 | +1272 | ||
726 | +1246 | ||
727 | +1225 | ||
728 | +1245 | ||
729 | +1277 | ||
730 | +1298 | ||
731 | +1288 | ||
732 | +1271 | ||
733 | +1265 | ||
734 | +1286 | ||
735 | +1260 | ||
736 | +1266 | ||
737 | +1296 | ||
738 | +1280 | ||
739 | +1285 | ||
740 | +1293 | ||
741 | +1276 | ||
742 | +1287 | ||
743 | +1289 | ||
744 | +1261 | ||
745 | +1264 | ||
746 | +1295 | ||
747 | +1291 | ||
748 | +1283 | ||
749 | +1311 | ||
750 | +1303 | ||
751 | +1330 | ||
752 | +1315 | ||
753 | +1300 | ||
754 | +1333 | ||
755 | +1307 | ||
756 | +1325 | ||
757 | +1334 | ||
758 | +1316 | ||
759 | +1314 | ||
760 | +1317 | ||
761 | +1310 | ||
762 | +1329 | ||
763 | +1324 | ||
764 | +1339 | ||
765 | +1346 | ||
766 | +1342 | ||
767 | +1352 | ||
768 | +1321 | ||
769 | +1376 | ||
770 | +1366 | ||
771 | +1308 | ||
772 | +1345 | ||
773 | +1348 | ||
774 | +1386 | ||
775 | +1383 | ||
776 | +1372 | ||
777 | +1367 | ||
778 | +1400 | ||
779 | +1382 | ||
780 | +1375 | ||
781 | +1392 | ||
782 | +1380 | ||
783 | +1371 | ||
784 | +1393 | ||
785 | +1389 | ||
786 | +1353 | ||
787 | +1387 | ||
788 | +1374 | ||
789 | +1379 | ||
790 | +1381 | ||
791 | +1359 | ||
792 | +1360 | ||
793 | +1396 | ||
794 | +1399 | ||
795 | +1365 | ||
796 | +1424 | ||
797 | +1373 | ||
798 | +1411 | ||
799 | +1401 | ||
800 | +1397 | ||
801 | +1395 | ||
802 | +1412 | ||
803 | +1394 | ||
804 | +1368 | ||
805 | +1423 | ||
806 | +1391 | ||
807 | +1435 | ||
808 | +1409 | ||
809 | +1443 | ||
810 | +1402 | ||
811 | +1425 | ||
812 | +1415 | ||
813 | +1421 | ||
814 | +1426 | ||
815 | +1433 | ||
816 | +1420 | ||
817 | +1452 | ||
818 | +1436 | ||
819 | +1430 | ||
820 | +1408 | ||
821 | +1458 | ||
822 | +1429 | ||
823 | +1453 | ||
824 | +1454 | ||
825 | +1447 | ||
826 | +1472 | ||
827 | +1486 | ||
828 | +1468 | ||
829 | +1461 | ||
830 | +1467 | ||
831 | +1484 | ||
832 | +1457 | ||
833 | +1444 | ||
834 | +1450 | ||
835 | +1451 | ||
836 | +1459 | ||
837 | +1462 | ||
838 | +1449 | ||
839 | +1476 | ||
840 | +1470 | ||
841 | +1471 | ||
842 | +1498 | ||
843 | +1488 | ||
844 | +1442 | ||
845 | +1480 | ||
846 | +1456 | ||
847 | +1466 | ||
848 | +1505 | ||
849 | +1517 | ||
850 | +1464 | ||
851 | +1503 | ||
852 | +1490 | ||
853 | +1519 | ||
854 | +1481 | ||
855 | +1493 | ||
856 | +1463 | ||
857 | +1532 | ||
858 | +1487 | ||
859 | +1501 | ||
860 | +1500 | ||
861 | +1495 | ||
862 | +1509 | ||
863 | +1535 | ||
864 | +1506 | ||
865 | +1521 | ||
866 | +1580 | ||
867 | +1540 | ||
868 | +1502 | ||
869 | +1520 | ||
870 | +1496 | ||
871 | +1569 | ||
872 | +1515 | ||
873 | +1489 | ||
874 | +1507 | ||
875 | +1527 | ||
876 | +1545 | ||
877 | +1560 | ||
878 | +1510 | ||
879 | +1514 | ||
880 | +1526 | ||
881 | +1594 | ||
882 | +1511 | ||
883 | +1572 | ||
884 | +1548 | ||
885 | +1584 | ||
886 | +1556 | ||
887 | +1588 | ||
888 | +1628 | ||
889 | +1555 | ||
890 | +1568 | ||
891 | +1550 | ||
892 | +1622 | ||
893 | +1563 | ||
894 | +1603 | ||
895 | +1616 | ||
896 | +1576 | ||
897 | +1549 | ||
898 | +1537 | ||
899 | +1593 | ||
900 | +1618 | ||
901 | +1645 | ||
902 | +1624 | ||
903 | +1617 | ||
904 | +1634 | ||
905 | +1595 | ||
906 | +1597 | ||
907 | +1590 | ||
908 | +1632 | ||
909 | +1575 | ||
910 | +1559 | ||
911 | +1625 | ||
912 | +1615 | ||
913 | +1591 | ||
914 | +1630 | ||
915 | +1608 | ||
916 | +1621 | ||
917 | +1589 | ||
918 | +1646 | ||
919 | +1643 | ||
920 | +1652 | ||
921 | +1627 | ||
922 | +1611 | ||
923 | +1626 | ||
924 | +1613 | ||
925 | +1639 | ||
926 | +1655 | ||
927 | +1620 | ||
928 | +1602 | ||
929 | +1651 | ||
930 | +1653 | ||
931 | +1669 | ||
932 | +1638 | ||
933 | +1696 | ||
934 | +1649 | ||
935 | +1675 | ||
936 | +1660 | ||
937 | +1683 | ||
938 | +1666 | ||
939 | +1671 | ||
940 | +1703 | ||
941 | +1716 | ||
942 | +1637 | ||
943 | +1672 | ||
944 | +1676 | ||
945 | +1692 | ||
946 | +1711 | ||
947 | +1680 | ||
948 | +1641 | ||
949 | +1688 | ||
950 | +1708 | ||
951 | +1704 | ||
952 | +1690 | ||
953 | +1674 | ||
954 | +1718 | ||
955 | +1699 | ||
956 | +1723 | ||
957 | +1756 | ||
958 | +1700 | ||
959 | +1662 | ||
960 | +1715 | ||
961 | +1657 | ||
962 | +1733 | ||
963 | +1728 | ||
964 | +1670 | ||
965 | +1712 | ||
966 | +1685 | ||
967 | +1724 | ||
968 | +1735 | ||
969 | +1714 | ||
970 | +1730 | ||
971 | +1747 | ||
972 | +1656 | ||
973 | +1737 | ||
974 | +1705 | ||
975 | +1693 | ||
976 | +1713 | ||
977 | +1689 | ||
978 | +1753 | ||
979 | +1739 | ||
980 | +1721 | ||
981 | +1725 | ||
982 | +1749 | ||
983 | +1732 | ||
984 | +1743 | ||
985 | +1731 | ||
986 | +1767 | ||
987 | +1738 | ||
988 | +1831 | ||
989 | +1771 | ||
990 | +1726 | ||
991 | +1746 | ||
992 | +1776 | ||
993 | +1775 | ||
994 | +1799 | ||
995 | +1774 | ||
996 | +1780 | ||
997 | +1781 | ||
998 | +1769 | ||
999 | +1805 | ||
1000 | +1788 | ||
1001 | +1801 |
yt8m/train.py
0 → 100644
1 | +# Copyright 2016 Google Inc. All Rights Reserved. | ||
2 | +# | ||
3 | +# Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | +# you may not use this file except in compliance with the License. | ||
5 | +# You may obtain a copy of the License at | ||
6 | +# | ||
7 | +# http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | +# | ||
9 | +# Unless required by applicable law or agreed to in writing, software | ||
10 | +# distributed under the License is distributed on an "AS-IS" BASIS, | ||
11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | +# See the License for the specific language governing permissions and | ||
13 | +# limitations under the License. | ||
14 | +"""Binary for training Tensorflow models on the YouTube-8M dataset.""" | ||
15 | + | ||
16 | +import json | ||
17 | +import os | ||
18 | +import time | ||
19 | + | ||
20 | +import eval_util | ||
21 | +import export_model | ||
22 | +import losses | ||
23 | +import frame_level_models | ||
24 | +import video_level_models | ||
25 | +import readers | ||
26 | +import tensorflow as tf | ||
27 | +import tensorflow.contrib.slim as slim | ||
28 | +from tensorflow.python.lib.io import file_io | ||
29 | +from tensorflow import app | ||
30 | +from tensorflow import flags | ||
31 | +from tensorflow import gfile | ||
32 | +from tensorflow import logging | ||
33 | +from tensorflow.python.client import device_lib | ||
34 | +import utils | ||
35 | + | ||
36 | +FLAGS = flags.FLAGS | ||
37 | + | ||
38 | +if __name__ == "__main__": | ||
39 | + # Dataset flags. | ||
40 | + flags.DEFINE_string("train_dir", "/tmp/yt8m_model/", | ||
41 | + "The directory to save the model files in.") | ||
42 | + flags.DEFINE_string( | ||
43 | + "train_data_pattern", "", | ||
44 | +    "File glob for the training dataset. If the files refer to frame-level "	|  ||
45 | +    "features (i.e. tensorflow.SequenceExample), then also set "	|  ||
46 | +    "--frame_features. The (Sequence)Examples are expected to have an 'rgb' "	|  ||
47 | +    "byte array sequence feature as well as a 'labels' int64 context feature.")	|  ||
48 | + flags.DEFINE_string("feature_names", "mean_rgb", "Name of the feature " | ||
49 | + "to use for training.") | ||
50 | + flags.DEFINE_string("feature_sizes", "1024", "Length of the feature vectors.") | ||
51 | + | ||
52 | + # Model flags. | ||
53 | + flags.DEFINE_bool( | ||
54 | + "frame_features", False, | ||
55 | + "If set, then --train_data_pattern must be frame-level features. " | ||
56 | + "Otherwise, --train_data_pattern must be aggregated video-level " | ||
57 | + "features. The model must also be set appropriately (i.e. to read 3D " | ||
58 | +      "batches vs. 4D batches).")	|  ||
59 | + flags.DEFINE_bool( | ||
60 | + "segment_labels", False, | ||
61 | + "If set, then --train_data_pattern must be frame-level features (but with" | ||
62 | + " segment_labels). Otherwise, --train_data_pattern must be aggregated " | ||
63 | + "video-level features. The model must also be set appropriately (i.e. to " | ||
64 | +      "read 3D batches vs. 4D batches).")	|  ||
65 | + flags.DEFINE_string( | ||
66 | + "model", "LogisticModel", | ||
67 | + "Which architecture to use for the model. Models are defined " | ||
68 | + "in models.py.") | ||
69 | + flags.DEFINE_bool( | ||
70 | + "start_new_model", False, | ||
71 | + "If set, this will not resume from a checkpoint and will instead create a" | ||
72 | + " new model instance.") | ||
73 | + | ||
74 | + # Training flags. | ||
75 | + flags.DEFINE_integer( | ||
76 | + "num_gpu", 1, "The maximum number of GPU devices to use for training. " | ||
77 | + "Flag only applies if GPUs are installed") | ||
78 | + flags.DEFINE_integer("batch_size", 1024, | ||
79 | + "How many examples to process per batch for training.") | ||
80 | + flags.DEFINE_string("label_loss", "CrossEntropyLoss", | ||
81 | + "Which loss function to use for training the model.") | ||
82 | + flags.DEFINE_float( | ||
83 | + "regularization_penalty", 1.0, | ||
84 | + "How much weight to give to the regularization loss (the label loss has " | ||
85 | + "a weight of 1).") | ||
86 | + flags.DEFINE_float("base_learning_rate", 0.01, | ||
87 | + "Which learning rate to start with.") | ||
88 | + flags.DEFINE_float( | ||
89 | + "learning_rate_decay", 0.95, | ||
90 | + "Learning rate decay factor to be applied every " | ||
91 | + "learning_rate_decay_examples.") | ||
92 | + flags.DEFINE_float( | ||
93 | + "learning_rate_decay_examples", 4000000, | ||
94 | + "Multiply current learning rate by learning_rate_decay " | ||
95 | + "every learning_rate_decay_examples.") | ||
96 | + flags.DEFINE_integer( | ||
97 | + "num_epochs", 5, "How many passes to make over the dataset before " | ||
98 | + "halting training.") | ||
99 | + flags.DEFINE_integer( | ||
100 | + "max_steps", None, | ||
101 | + "The maximum number of iterations of the training loop.") | ||
102 | + flags.DEFINE_integer( | ||
103 | + "export_model_steps", 1000, | ||
104 | + "The period, in number of steps, with which the model " | ||
105 | + "is exported for batch prediction.") | ||
106 | + | ||
107 | + # Other flags. | ||
108 | + flags.DEFINE_integer("num_readers", 8, | ||
109 | + "How many threads to use for reading input files.") | ||
110 | + flags.DEFINE_string("optimizer", "AdamOptimizer", | ||
111 | + "What optimizer class to use.") | ||
112 | + flags.DEFINE_float("clip_gradient_norm", 1.0, "Norm to clip gradients to.") | ||
113 | + flags.DEFINE_bool( | ||
114 | + "log_device_placement", False, | ||
115 | + "Whether to write the device on which every op will run into the " | ||
116 | + "logs on startup.") | ||
117 | + | ||
118 | + | ||
119 | +def validate_class_name(flag_value, category, modules, expected_superclass): | ||
120 | + """Checks that the given string matches a class of the expected type. | ||
121 | + | ||
122 | + Args: | ||
123 | + flag_value: A string naming the class to instantiate. | ||
124 | +    category: A string used to further describe the class in error messages (e.g.	|  ||
125 | + 'model', 'reader', 'loss'). | ||
126 | + modules: A list of modules to search for the given class. | ||
127 | + expected_superclass: A class that the given class should inherit from. | ||
128 | + | ||
129 | + Raises: | ||
130 | + FlagsError: If the given class could not be found or if the first class | ||
131 | + found with that name doesn't inherit from the expected superclass. | ||
132 | + | ||
133 | + Returns: | ||
134 | + True if a class was found that matches the given constraints. | ||
135 | + """ | ||
136 | + candidates = [getattr(module, flag_value, None) for module in modules] | ||
137 | + for candidate in candidates: | ||
138 | + if not candidate: | ||
139 | + continue | ||
140 | + if not issubclass(candidate, expected_superclass): | ||
141 | + raise flags.FlagsError( | ||
142 | + "%s '%s' doesn't inherit from %s." % | ||
143 | + (category, flag_value, expected_superclass.__name__)) | ||
144 | + return True | ||
145 | + raise flags.FlagsError("Unable to find %s '%s'." % (category, flag_value)) | ||
146 | + | ||
147 | + | ||
148 | +def get_input_data_tensors(reader, | ||
149 | + data_pattern, | ||
150 | + batch_size=1000, | ||
151 | + num_epochs=None, | ||
152 | + num_readers=1): | ||
153 | + """Creates the section of the graph which reads the training data. | ||
154 | + | ||
155 | + Args: | ||
156 | + reader: A class which parses the training data. | ||
157 | + data_pattern: A 'glob' style path to the data files. | ||
158 | + batch_size: How many examples to process at a time. | ||
159 | + num_epochs: How many passes to make over the training data. Set to 'None' to | ||
160 | + run indefinitely. | ||
161 | + num_readers: How many I/O threads to use. | ||
162 | + | ||
163 | + Returns: | ||
164 | +    A dict of batched tensors produced by the reader (for these readers:	|  ||
165 | +    'video_matrix', 'labels', and 'num_frames'). The exact keys and	|  ||
166 | +    dimensions depend on the reader being used.	|  ||
167 | + | ||
168 | + Raises: | ||
169 | + IOError: If no files matching the given pattern were found. | ||
170 | + """ | ||
171 | + logging.info("Using batch size of " + str(batch_size) + " for training.") | ||
172 | + with tf.name_scope("train_input"): | ||
173 | + files = gfile.Glob(data_pattern) | ||
174 | + if not files: | ||
175 | + raise IOError("Unable to find training files. data_pattern='" + | ||
176 | + data_pattern + "'.") | ||
177 | + logging.info("Number of training files: %s.", str(len(files))) | ||
178 | + filename_queue = tf.train.string_input_producer(files, | ||
179 | + num_epochs=num_epochs, | ||
180 | + shuffle=True) | ||
181 | + training_data = [ | ||
182 | + reader.prepare_reader(filename_queue) for _ in range(num_readers) | ||
183 | + ] | ||
184 | + | ||
185 | + return tf.train.shuffle_batch_join(training_data, | ||
186 | + batch_size=batch_size, | ||
187 | + capacity=batch_size * 5, | ||
188 | + min_after_dequeue=batch_size, | ||
189 | + allow_smaller_final_batch=True, | ||
190 | + enqueue_many=True) | ||
191 | + | ||
192 | + | ||
193 | +def find_class_by_name(name, modules): | ||
194 | + """Searches the provided modules for the named class and returns it.""" | ||
195 | + modules = [getattr(module, name, None) for module in modules] | ||
196 | + return next(a for a in modules if a) | ||
197 | + | ||
198 | + | ||
199 | +def build_graph(reader, | ||
200 | + model, | ||
201 | + train_data_pattern, | ||
202 | + label_loss_fn=losses.CrossEntropyLoss(), | ||
203 | + batch_size=1000, | ||
204 | + base_learning_rate=0.01, | ||
205 | + learning_rate_decay_examples=1000000, | ||
206 | + learning_rate_decay=0.95, | ||
207 | + optimizer_class=tf.train.AdamOptimizer, | ||
208 | + clip_gradient_norm=1.0, | ||
209 | + regularization_penalty=1, | ||
210 | + num_readers=1, | ||
211 | + num_epochs=None): | ||
212 | + """Creates the Tensorflow graph. | ||
213 | + | ||
214 | + This will only be called once in the life of | ||
215 | + a training model, because after the graph is created the model will be | ||
216 | + restored from a meta graph file rather than being recreated. | ||
217 | + | ||
218 | + Args: | ||
219 | + reader: The data file reader. It should inherit from BaseReader. | ||
220 | + model: The core model (e.g. logistic or neural net). It should inherit from | ||
221 | + BaseModel. | ||
222 | + train_data_pattern: glob path to the training data files. | ||
223 | + label_loss_fn: What kind of loss to apply to the model. It should inherit | ||
224 | + from BaseLoss. | ||
225 | + batch_size: How many examples to process at a time. | ||
226 | + base_learning_rate: What learning rate to initialize the optimizer with. | ||
227 | + optimizer_class: Which optimization algorithm to use. | ||
228 | + clip_gradient_norm: Magnitude of the gradient to clip to. | ||
229 | + regularization_penalty: How much weight to give the regularization loss | ||
230 | + compared to the label loss. | ||
231 | + num_readers: How many threads to use for I/O operations. | ||
232 | + num_epochs: How many passes to make over the data. 'None' means an unlimited | ||
233 | + number of passes. | ||
234 | + """ | ||
235 | + | ||
236 | + global_step = tf.Variable(0, trainable=False, name="global_step") | ||
237 | + | ||
238 | + local_device_protos = device_lib.list_local_devices() | ||
239 | + gpus = [x.name for x in local_device_protos if x.device_type == "GPU"] | ||
240 | + gpus = gpus[:FLAGS.num_gpu] | ||
241 | + num_gpus = len(gpus) | ||
242 | + | ||
243 | + if num_gpus > 0: | ||
244 | + logging.info("Using the following GPUs to train: " + str(gpus)) | ||
245 | + num_towers = num_gpus | ||
246 | + device_string = "/gpu:%d" | ||
247 | + else: | ||
248 | + logging.info("No GPUs found. Training on CPU.") | ||
249 | + num_towers = 1 | ||
250 | + device_string = "/cpu:%d" | ||
251 | + | ||
252 | + learning_rate = tf.train.exponential_decay(base_learning_rate, | ||
253 | + global_step * batch_size * | ||
254 | + num_towers, | ||
255 | + learning_rate_decay_examples, | ||
256 | + learning_rate_decay, | ||
257 | + staircase=True) | ||
258 | + tf.summary.scalar("learning_rate", learning_rate) | ||
259 | + | ||
260 | + optimizer = optimizer_class(learning_rate) | ||
261 | + input_data_dict = (get_input_data_tensors(reader, | ||
262 | + train_data_pattern, | ||
263 | + batch_size=batch_size * num_towers, | ||
264 | + num_readers=num_readers, | ||
265 | + num_epochs=num_epochs)) | ||
266 | + model_input_raw = input_data_dict["video_matrix"] | ||
267 | + labels_batch = input_data_dict["labels"] | ||
268 | + num_frames = input_data_dict["num_frames"] | ||
269 | + print("model_input_shape, ", model_input_raw.shape) | ||
270 | + tf.summary.histogram("model/input_raw", model_input_raw) | ||
271 | + | ||
272 | + feature_dim = len(model_input_raw.get_shape()) - 1 | ||
273 | + | ||
274 | + model_input = tf.nn.l2_normalize(model_input_raw, feature_dim) | ||
275 | + | ||
276 | + tower_inputs = tf.split(model_input, num_towers) | ||
277 | + tower_labels = tf.split(labels_batch, num_towers) | ||
278 | + tower_num_frames = tf.split(num_frames, num_towers) | ||
279 | + tower_gradients = [] | ||
280 | + tower_predictions = [] | ||
281 | + tower_label_losses = [] | ||
282 | + tower_reg_losses = [] | ||
283 | + for i in range(num_towers): | ||
284 | + # For some reason these 'with' statements can't be combined onto the same | ||
285 | + # line. They have to be nested. | ||
286 | + with tf.device(device_string % i): | ||
287 | + with (tf.variable_scope(("tower"), reuse=True if i > 0 else None)): | ||
288 | + with (slim.arg_scope([slim.model_variable, slim.variable], | ||
289 | + device="/cpu:0" if num_gpus != 1 else "/gpu:0")): | ||
290 | + result = model.create_model(tower_inputs[i], | ||
291 | + num_frames=tower_num_frames[i], | ||
292 | + vocab_size=reader.num_classes, | ||
293 | + labels=tower_labels[i]) | ||
294 | + for variable in slim.get_model_variables(): | ||
295 | + tf.summary.histogram(variable.op.name, variable) | ||
296 | + | ||
297 | + predictions = result["predictions"] | ||
298 | + tower_predictions.append(predictions) | ||
299 | + | ||
300 | + if "loss" in result.keys(): | ||
301 | + label_loss = result["loss"] | ||
302 | + else: | ||
303 | + label_loss = label_loss_fn.calculate_loss(predictions, | ||
304 | + tower_labels[i]) | ||
305 | + | ||
306 | + if "regularization_loss" in result.keys(): | ||
307 | + reg_loss = result["regularization_loss"] | ||
308 | + else: | ||
309 | + reg_loss = tf.constant(0.0) | ||
310 | + | ||
311 | + reg_losses = tf.losses.get_regularization_losses() | ||
312 | + if reg_losses: | ||
313 | + reg_loss += tf.add_n(reg_losses) | ||
314 | + | ||
315 | + tower_reg_losses.append(reg_loss) | ||
316 | + | ||
317 | + # Adds update_ops (e.g., moving average updates in batch normalization) as | ||
318 | + # a dependency to the train_op. | ||
319 | + update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) | ||
320 | + if "update_ops" in result.keys(): | ||
321 | + update_ops += result["update_ops"] | ||
322 | + if update_ops: | ||
323 | + with tf.control_dependencies(update_ops): | ||
324 | + barrier = tf.no_op(name="gradient_barrier") | ||
325 | + with tf.control_dependencies([barrier]): | ||
326 | + label_loss = tf.identity(label_loss) | ||
327 | + | ||
328 | + tower_label_losses.append(label_loss) | ||
329 | + | ||
330 | + # Incorporate the L2 weight penalties etc. | ||
331 | + final_loss = regularization_penalty * reg_loss + label_loss | ||
332 | + gradients = optimizer.compute_gradients( | ||
333 | + final_loss, colocate_gradients_with_ops=False) | ||
334 | + tower_gradients.append(gradients) | ||
335 | + label_loss = tf.reduce_mean(tf.stack(tower_label_losses)) | ||
336 | + tf.summary.scalar("label_loss", label_loss) | ||
337 | + if regularization_penalty != 0: | ||
338 | + reg_loss = tf.reduce_mean(tf.stack(tower_reg_losses)) | ||
339 | + tf.summary.scalar("reg_loss", reg_loss) | ||
340 | + merged_gradients = utils.combine_gradients(tower_gradients) | ||
341 | + | ||
342 | + if clip_gradient_norm > 0: | ||
343 | + with tf.name_scope("clip_grads"): | ||
344 | + merged_gradients = utils.clip_gradient_norms(merged_gradients, | ||
345 | + clip_gradient_norm) | ||
346 | + | ||
347 | + train_op = optimizer.apply_gradients(merged_gradients, | ||
348 | + global_step=global_step) | ||
349 | + | ||
350 | + tf.add_to_collection("global_step", global_step) | ||
351 | + tf.add_to_collection("loss", label_loss) | ||
352 | + tf.add_to_collection("predictions", tf.concat(tower_predictions, 0)) | ||
353 | + tf.add_to_collection("input_batch_raw", model_input_raw) | ||
354 | + tf.add_to_collection("input_batch", model_input) | ||
355 | + tf.add_to_collection("num_frames", num_frames) | ||
356 | + tf.add_to_collection("labels", tf.cast(labels_batch, tf.float32)) | ||
357 | + tf.add_to_collection("train_op", train_op) | ||
358 | + | ||
359 | + | ||
360 | +class Trainer(object): | ||
361 | + """A Trainer to train a Tensorflow graph.""" | ||
362 | + | ||
363 | + def __init__(self, | ||
364 | + cluster, | ||
365 | + task, | ||
366 | + train_dir, | ||
367 | + model, | ||
368 | + reader, | ||
369 | + model_exporter, | ||
370 | + log_device_placement=True, | ||
371 | + max_steps=None, | ||
372 | + export_model_steps=1000): | ||
373 | +    """Creates a Trainer.	|  ||
374 | + | ||
375 | + Args: | ||
376 | + cluster: A tf.train.ClusterSpec if the execution is distributed. None | ||
377 | + otherwise. | ||
378 | + task: A TaskSpec describing the job type and the task index. | ||
379 | + """ | ||
380 | + | ||
381 | + self.cluster = cluster | ||
382 | + self.task = task | ||
383 | + self.is_master = (task.type == "master" and task.index == 0) | ||
384 | + self.train_dir = train_dir | ||
385 | + self.config = tf.ConfigProto(allow_soft_placement=True, | ||
386 | + log_device_placement=log_device_placement) | ||
387 | + self.config.gpu_options.allow_growth = True | ||
388 | + self.model = model | ||
389 | + self.reader = reader | ||
390 | + self.model_exporter = model_exporter | ||
391 | + self.max_steps = max_steps | ||
392 | + self.max_steps_reached = False | ||
393 | + self.export_model_steps = export_model_steps | ||
394 | + self.last_model_export_step = 0 | ||
395 | + | ||
396 | + | ||
397 | +# if self.is_master and self.task.index > 0: | ||
398 | +# raise StandardError("%s: Only one replica of master expected", | ||
399 | +# task_as_string(self.task)) | ||
400 | + | ||
401 | + def run(self, start_new_model=False): | ||
402 | + """Performs training on the currently defined Tensorflow graph. | ||
403 | + | ||
404 | + Returns: | ||
405 | + A tuple of the training Hit@1 and the training PERR. | ||
406 | + """ | ||
407 | + if self.is_master and start_new_model: | ||
408 | + self.remove_training_directory(self.train_dir) | ||
409 | + | ||
410 | + if not os.path.exists(self.train_dir): | ||
411 | + os.makedirs(self.train_dir) | ||
412 | + | ||
413 | + model_flags_dict = { | ||
414 | + "model": FLAGS.model, | ||
415 | + "feature_sizes": FLAGS.feature_sizes, | ||
416 | + "feature_names": FLAGS.feature_names, | ||
417 | + "frame_features": FLAGS.frame_features, | ||
418 | + "label_loss": FLAGS.label_loss, | ||
419 | + } | ||
420 | + flags_json_path = os.path.join(FLAGS.train_dir, "model_flags.json") | ||
421 | + if file_io.file_exists(flags_json_path): | ||
422 | + existing_flags = json.load(file_io.FileIO(flags_json_path, mode="r")) | ||
423 | + if existing_flags != model_flags_dict: | ||
424 | + logging.error( | ||
425 | + "Model flags do not match existing file %s. Please " | ||
426 | + "delete the file, change --train_dir, or pass flag " | ||
427 | + "--start_new_model", flags_json_path) | ||
428 | + logging.error("Ran model with flags: %s", str(model_flags_dict)) | ||
429 | + logging.error("Previously ran with flags: %s", str(existing_flags)) | ||
430 | + exit(1) | ||
431 | + else: | ||
432 | + # Write the file. | ||
433 | + with file_io.FileIO(flags_json_path, mode="w") as fout: | ||
434 | + fout.write(json.dumps(model_flags_dict)) | ||
435 | + | ||
436 | + target, device_fn = self.start_server_if_distributed() | ||
437 | + | ||
438 | + meta_filename = self.get_meta_filename(start_new_model, self.train_dir) | ||
439 | + | ||
440 | + with tf.Graph().as_default() as graph: | ||
441 | + if meta_filename: | ||
442 | + saver = self.recover_model(meta_filename) | ||
443 | + | ||
444 | + with tf.device(device_fn): | ||
445 | + if not meta_filename: | ||
446 | + saver = self.build_model(self.model, self.reader) | ||
447 | + | ||
448 | + global_step = tf.get_collection("global_step")[0] | ||
449 | + loss = tf.get_collection("loss")[0] | ||
450 | + predictions = tf.get_collection("predictions")[0] | ||
451 | + labels = tf.get_collection("labels")[0] | ||
452 | + train_op = tf.get_collection("train_op")[0] | ||
453 | + init_op = tf.global_variables_initializer() | ||
454 | + | ||
455 | + sv = tf.train.Supervisor(graph, | ||
456 | + logdir=self.train_dir, | ||
457 | + init_op=init_op, | ||
458 | + is_chief=self.is_master, | ||
459 | + global_step=global_step, | ||
460 | + save_model_secs=15 * 60, | ||
461 | + save_summaries_secs=120, | ||
462 | + saver=saver) | ||
463 | + | ||
464 | + logging.info("%s: Starting managed session.", task_as_string(self.task)) | ||
465 | + with sv.managed_session(target, config=self.config) as sess: | ||
466 | + try: | ||
467 | + logging.info("%s: Entering training loop.", task_as_string(self.task)) | ||
468 | + while (not sv.should_stop()) and (not self.max_steps_reached): | ||
469 | + batch_start_time = time.time() | ||
470 | + _, global_step_val, loss_val, predictions_val, labels_val = sess.run( | ||
471 | + [train_op, global_step, loss, predictions, labels]) | ||
472 | + seconds_per_batch = time.time() - batch_start_time | ||
473 | + examples_per_second = labels_val.shape[0] / seconds_per_batch | ||
474 | + | ||
475 | + if self.max_steps and self.max_steps <= global_step_val: | ||
476 | + self.max_steps_reached = True | ||
477 | + | ||
478 | + if self.is_master and global_step_val % 10 == 0 and self.train_dir: | ||
479 | + eval_start_time = time.time() | ||
480 | + hit_at_one = eval_util.calculate_hit_at_one(predictions_val, | ||
481 | + labels_val) | ||
482 | + perr = eval_util.calculate_precision_at_equal_recall_rate( | ||
483 | + predictions_val, labels_val) | ||
484 | + gap = eval_util.calculate_gap(predictions_val, labels_val) | ||
485 | + eval_end_time = time.time() | ||
486 | + eval_time = eval_end_time - eval_start_time | ||
487 | + | ||
488 | + logging.info("training step " + str(global_step_val) + " | Loss: " + | ||
489 | + ("%.2f" % loss_val) + " Examples/sec: " + | ||
490 | + ("%.2f" % examples_per_second) + " | Hit@1: " + | ||
491 | + ("%.2f" % hit_at_one) + " PERR: " + ("%.2f" % perr) + | ||
492 | + " GAP: " + ("%.2f" % gap)) | ||
493 | + | ||
494 | + sv.summary_writer.add_summary( | ||
495 | + utils.MakeSummary("model/Training_Hit@1", hit_at_one), | ||
496 | + global_step_val) | ||
497 | + sv.summary_writer.add_summary( | ||
498 | + utils.MakeSummary("model/Training_Perr", perr), global_step_val) | ||
499 | + sv.summary_writer.add_summary( | ||
500 | + utils.MakeSummary("model/Training_GAP", gap), global_step_val) | ||
501 | + sv.summary_writer.add_summary( | ||
502 | + utils.MakeSummary("global_step/Examples/Second", | ||
503 | + examples_per_second), global_step_val) | ||
504 | + sv.summary_writer.flush() | ||
505 | + | ||
506 | + # Exporting the model every x steps | ||
507 | + time_to_export = ((self.last_model_export_step == 0) or | ||
508 | + (global_step_val - self.last_model_export_step >= | ||
509 | + self.export_model_steps)) | ||
510 | + | ||
511 | + if self.is_master and time_to_export: | ||
512 | + self.export_model(global_step_val, sv.saver, sv.save_path, sess) | ||
513 | + self.last_model_export_step = global_step_val | ||
514 | + else: | ||
515 | + logging.info("training step " + str(global_step_val) + " | Loss: " + | ||
516 | + ("%.2f" % loss_val) + " Examples/sec: " + | ||
517 | + ("%.2f" % examples_per_second)) | ||
518 | + except tf.errors.OutOfRangeError: | ||
519 | + logging.info("%s: Done training -- epoch limit reached.", | ||
520 | + task_as_string(self.task)) | ||
521 | + | ||
522 | + logging.info("%s: Exited training loop.", task_as_string(self.task)) | ||
523 | + sv.Stop() | ||
524 | + | ||
525 | + def export_model(self, global_step_val, saver, save_path, session): | ||
526 | + | ||
527 | + # If the model has already been exported at this step, return. | ||
528 | + if global_step_val == self.last_model_export_step: | ||
529 | + return | ||
530 | + | ||
531 | + saver.save(session, save_path, global_step_val) | ||
532 | + | ||
533 | + def start_server_if_distributed(self): | ||
534 | + """Starts a server if the execution is distributed.""" | ||
535 | + | ||
536 | + if self.cluster: | ||
537 | + logging.info("%s: Starting trainer within cluster %s.", | ||
538 | + task_as_string(self.task), self.cluster.as_dict()) | ||
539 | + server = start_server(self.cluster, self.task) | ||
540 | + target = server.target | ||
541 | + device_fn = tf.train.replica_device_setter( | ||
542 | + ps_device="/job:ps", | ||
543 | + worker_device="/job:%s/task:%d" % (self.task.type, self.task.index), | ||
544 | + cluster=self.cluster) | ||
545 | + else: | ||
546 | + target = "" | ||
547 | + device_fn = "" | ||
548 | + return (target, device_fn) | ||
549 | + | ||
550 | + def remove_training_directory(self, train_dir): | ||
551 | + """Removes the training directory.""" | ||
552 | + try: | ||
553 | + logging.info("%s: Removing existing train directory.", | ||
554 | + task_as_string(self.task)) | ||
555 | + gfile.DeleteRecursively(train_dir) | ||
556 | +    except Exception:	|  ||
557 | +      logging.error(	|  ||
558 | +          "%s: Failed to delete directory %s when starting a new model. "	|  ||
559 | +          "Please delete it manually and try again.",	|  ||
560 | +          task_as_string(self.task), train_dir)	|  ||
561 | + | ||
562 | + def get_meta_filename(self, start_new_model, train_dir): | ||
563 | + if start_new_model: | ||
564 | + logging.info("%s: Flag 'start_new_model' is set. Building a new model.", | ||
565 | + task_as_string(self.task)) | ||
566 | + return None | ||
567 | + | ||
568 | + latest_checkpoint = tf.train.latest_checkpoint(train_dir) | ||
569 | + if not latest_checkpoint: | ||
570 | + logging.info("%s: No checkpoint file found. Building a new model.", | ||
571 | + task_as_string(self.task)) | ||
572 | + return None | ||
573 | + | ||
574 | + meta_filename = latest_checkpoint + ".meta" | ||
575 | + if not gfile.Exists(meta_filename): | ||
576 | + logging.info("%s: No meta graph file found. Building a new model.", | ||
577 | + task_as_string(self.task)) | ||
578 | + return None | ||
579 | + else: | ||
580 | + return meta_filename | ||
581 | + | ||
582 | + def recover_model(self, meta_filename): | ||
583 | + logging.info("%s: Restoring from meta graph file %s", | ||
584 | + task_as_string(self.task), meta_filename) | ||
585 | + return tf.train.import_meta_graph(meta_filename) | ||
586 | + | ||
587 | + def build_model(self, model, reader): | ||
588 | + """Find the model and build the graph.""" | ||
589 | + | ||
590 | + label_loss_fn = find_class_by_name(FLAGS.label_loss, [losses])() | ||
591 | + optimizer_class = find_class_by_name(FLAGS.optimizer, [tf.train]) | ||
592 | + | ||
593 | + build_graph(reader=reader, | ||
594 | + model=model, | ||
595 | + optimizer_class=optimizer_class, | ||
596 | + clip_gradient_norm=FLAGS.clip_gradient_norm, | ||
597 | + train_data_pattern=FLAGS.train_data_pattern, | ||
598 | + label_loss_fn=label_loss_fn, | ||
599 | + base_learning_rate=FLAGS.base_learning_rate, | ||
600 | + learning_rate_decay=FLAGS.learning_rate_decay, | ||
601 | + learning_rate_decay_examples=FLAGS.learning_rate_decay_examples, | ||
602 | + regularization_penalty=FLAGS.regularization_penalty, | ||
603 | + num_readers=FLAGS.num_readers, | ||
604 | + batch_size=FLAGS.batch_size, | ||
605 | + num_epochs=FLAGS.num_epochs) | ||
606 | + | ||
607 | + return tf.train.Saver(max_to_keep=0, keep_checkpoint_every_n_hours=0.25) | ||
608 | + | ||
609 | + | ||
610 | +def get_reader(): | ||
611 | + # Convert feature_names and feature_sizes to lists of values. | ||
612 | + feature_names, feature_sizes = utils.GetListOfFeatureNamesAndSizes( | ||
613 | + FLAGS.feature_names, FLAGS.feature_sizes) | ||
614 | + | ||
615 | + if FLAGS.frame_features: | ||
616 | + reader = readers.YT8MFrameFeatureReader(feature_names=feature_names, | ||
617 | + feature_sizes=feature_sizes, | ||
618 | + segment_labels=FLAGS.segment_labels) | ||
619 | + else: | ||
620 | + reader = readers.YT8MAggregatedFeatureReader(feature_names=feature_names, | ||
621 | + feature_sizes=feature_sizes) | ||
622 | + | ||
623 | + return reader | ||
624 | + | ||
625 | + | ||
626 | +class ParameterServer(object): | ||
627 | + """A parameter server to serve variables in a distributed execution.""" | ||
628 | + | ||
629 | + def __init__(self, cluster, task): | ||
630 | + """Creates a ParameterServer. | ||
631 | + | ||
632 | + Args: | ||
633 | + cluster: A tf.train.ClusterSpec if the execution is distributed. None | ||
634 | + otherwise. | ||
635 | + task: A TaskSpec describing the job type and the task index. | ||
636 | + """ | ||
637 | + | ||
638 | + self.cluster = cluster | ||
639 | + self.task = task | ||
640 | + | ||
641 | + def run(self): | ||
642 | + """Starts the parameter server.""" | ||
643 | + | ||
644 | + logging.info("%s: Starting parameter server within cluster %s.", | ||
645 | + task_as_string(self.task), self.cluster.as_dict()) | ||
646 | + server = start_server(self.cluster, self.task) | ||
647 | + server.join() | ||
648 | + | ||
649 | + | ||
650 | +def start_server(cluster, task): | ||
651 | + """Creates a Server. | ||
652 | + | ||
653 | + Args: | ||
654 | + cluster: A tf.train.ClusterSpec if the execution is distributed. None | ||
655 | + otherwise. | ||
656 | + task: A TaskSpec describing the job type and the task index. | ||
657 | + """ | ||
658 | + | ||
659 | + if not task.type: | ||
660 | + raise ValueError("%s: The task type must be specified." % | ||
661 | + task_as_string(task)) | ||
662 | + if task.index is None: | ||
663 | + raise ValueError("%s: The task index must be specified." % | ||
664 | + task_as_string(task)) | ||
665 | + | ||
666 | + # Create and start a server. | ||
667 | + return tf.train.Server(tf.train.ClusterSpec(cluster), | ||
668 | + protocol="grpc", | ||
669 | + job_name=task.type, | ||
670 | + task_index=task.index) | ||
671 | + | ||
672 | + | ||
673 | +def task_as_string(task): | ||
674 | + return "/job:%s/task:%s" % (task.type, task.index) | ||
675 | + | ||
676 | + | ||
677 | +def main(unused_argv): | ||
678 | + # Load the environment. | ||
679 | + env = json.loads(os.environ.get("TF_CONFIG", "{}")) | ||
680 | + | ||
681 | + # Load the cluster data from the environment. | ||
682 | + cluster_data = env.get("cluster", None) | ||
683 | + cluster = tf.train.ClusterSpec(cluster_data) if cluster_data else None | ||
684 | + | ||
685 | + # Load the task data from the environment. | ||
686 | + task_data = env.get("task", None) or {"type": "master", "index": 0} | ||
687 | + task = type("TaskSpec", (object,), task_data) | ||
688 | + | ||
689 | + # Logging the version. | ||
690 | + logging.set_verbosity(tf.logging.INFO) | ||
691 | + logging.info("%s: Tensorflow version: %s.", task_as_string(task), | ||
692 | + tf.__version__) | ||
693 | + | ||
694 | + # Dispatch to a master, a worker, or a parameter server. | ||
695 | + if not cluster or task.type == "master" or task.type == "worker": | ||
696 | + model = find_class_by_name(FLAGS.model, | ||
697 | + [frame_level_models, video_level_models])() | ||
698 | + | ||
699 | + reader = get_reader() | ||
700 | + | ||
701 | + model_exporter = export_model.ModelExporter( | ||
702 | + frame_features=FLAGS.frame_features, model=model, reader=reader) | ||
703 | + | ||
704 | + Trainer(cluster, task, FLAGS.train_dir, model, reader, model_exporter, | ||
705 | + FLAGS.log_device_placement, FLAGS.max_steps, | ||
706 | + FLAGS.export_model_steps).run(start_new_model=FLAGS.start_new_model) | ||
707 | + | ||
708 | + elif task.type == "ps": | ||
709 | + ParameterServer(cluster, task).run() | ||
710 | + else: | ||
711 | + raise ValueError("%s: Invalid task_type: %s." % | ||
712 | + (task_as_string(task), task.type)) | ||
713 | + | ||
714 | + | ||
715 | +if __name__ == "__main__": | ||
716 | + app.run() |
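
A quick back-of-the-envelope sketch of the staircase learning-rate schedule that `build_graph` wires up with `tf.train.exponential_decay`: the decay exponent is the number of examples seen (`global_step * batch_size * num_towers`) divided by `learning_rate_decay_examples`, floored. The plain-Python sketch below mirrors that formula using the flag defaults from this file; the step values printed are arbitrary illustrations, not recommendations.

```python
# Minimal sketch (no TensorFlow) of the staircase schedule built by build_graph().
# Defaults mirror the flags in train.py; the steps printed are arbitrary examples.
def effective_learning_rate(global_step,
                            base_learning_rate=0.01,
                            learning_rate_decay=0.95,
                            learning_rate_decay_examples=4000000,
                            batch_size=1024,
                            num_towers=1):
  examples_seen = global_step * batch_size * num_towers
  num_decays = examples_seen // learning_rate_decay_examples  # staircase=True floors this
  return base_learning_rate * (learning_rate_decay ** num_decays)


for step in (0, 4000, 40000):
  print(step, effective_learning_rate(step))  # 0.01, then 0.0095, then ~0.006
```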
yt8m/utils.py
0 → 100644
1 | +# Copyright 2016 Google Inc. All Rights Reserved. | ||
2 | +# | ||
3 | +# Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | +# you may not use this file except in compliance with the License. | ||
5 | +# You may obtain a copy of the License at | ||
6 | +# | ||
7 | +# http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | +# | ||
9 | +# Unless required by applicable law or agreed to in writing, software | ||
10 | +# distributed under the License is distributed on an "AS-IS" BASIS, | ||
11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | +# See the License for the specific language governing permissions and | ||
13 | +# limitations under the License. | ||
14 | +"""Contains a collection of util functions for training and evaluating.""" | ||
15 | + | ||
16 | +import numpy | ||
17 | +import tensorflow as tf | ||
18 | +from tensorflow import logging | ||
19 | + | ||
20 | +try: | ||
21 | + xrange # Python 2 | ||
22 | +except NameError: | ||
23 | + xrange = range # Python 3 | ||
24 | + | ||
25 | + | ||
26 | +def Dequantize(feat_vector, max_quantized_value=2, min_quantized_value=-2): | ||
27 | + """Dequantize the feature from the byte format to the float format. | ||
28 | + | ||
29 | + Args: | ||
30 | + feat_vector: the input 1-d vector. | ||
31 | + max_quantized_value: the maximum of the quantized value. | ||
32 | + min_quantized_value: the minimum of the quantized value. | ||
33 | + | ||
34 | + Returns: | ||
35 | + A float vector which has the same shape as feat_vector. | ||
36 | + """ | ||
37 | + assert max_quantized_value > min_quantized_value | ||
38 | + quantized_range = max_quantized_value - min_quantized_value | ||
39 | + scalar = quantized_range / 255.0 | ||
40 | + bias = (quantized_range / 512.0) + min_quantized_value | ||
41 | + return feat_vector * scalar + bias | ||
42 | + | ||
43 | + | ||
44 | +def MakeSummary(name, value): | ||
45 | + """Creates a tf.Summary proto with the given name and value.""" | ||
46 | + summary = tf.Summary() | ||
47 | + val = summary.value.add() | ||
48 | + val.tag = str(name) | ||
49 | + val.simple_value = float(value) | ||
50 | + return summary | ||
51 | + | ||
52 | + | ||
53 | +def AddGlobalStepSummary(summary_writer, | ||
54 | + global_step_val, | ||
55 | + global_step_info_dict, | ||
56 | + summary_scope="Eval"): | ||
57 | + """Add the global_step summary to the Tensorboard. | ||
58 | + | ||
59 | + Args: | ||
60 | + summary_writer: Tensorflow summary_writer. | ||
61 | +    global_step_val: an int value of the global step.	|  ||
62 | + global_step_info_dict: a dictionary of the evaluation metrics calculated for | ||
63 | + a mini-batch. | ||
64 | + summary_scope: Train or Eval. | ||
65 | + | ||
66 | + Returns: | ||
67 | + A string of this global_step summary | ||
68 | + """ | ||
69 | + this_hit_at_one = global_step_info_dict["hit_at_one"] | ||
70 | + this_perr = global_step_info_dict["perr"] | ||
71 | + this_loss = global_step_info_dict["loss"] | ||
72 | + examples_per_second = global_step_info_dict.get("examples_per_second", -1) | ||
73 | + | ||
74 | + summary_writer.add_summary( | ||
75 | + MakeSummary("GlobalStep/" + summary_scope + "_Hit@1", this_hit_at_one), | ||
76 | + global_step_val) | ||
77 | + summary_writer.add_summary( | ||
78 | + MakeSummary("GlobalStep/" + summary_scope + "_Perr", this_perr), | ||
79 | + global_step_val) | ||
80 | + summary_writer.add_summary( | ||
81 | + MakeSummary("GlobalStep/" + summary_scope + "_Loss", this_loss), | ||
82 | + global_step_val) | ||
83 | + | ||
84 | + if examples_per_second != -1: | ||
85 | + summary_writer.add_summary( | ||
86 | + MakeSummary("GlobalStep/" + summary_scope + "_Example_Second", | ||
87 | + examples_per_second), global_step_val) | ||
88 | + | ||
89 | + summary_writer.flush() | ||
90 | + info = ( | ||
91 | + "global_step {0} | Batch Hit@1: {1:.3f} | Batch PERR: {2:.3f} | Batch " | ||
92 | + "Loss: {3:.3f} | Examples_per_sec: {4:.3f}").format( | ||
93 | + global_step_val, this_hit_at_one, this_perr, this_loss, | ||
94 | + examples_per_second) | ||
95 | + return info | ||
96 | + | ||
97 | + | ||
98 | +def AddEpochSummary(summary_writer, | ||
99 | + global_step_val, | ||
100 | + epoch_info_dict, | ||
101 | + summary_scope="Eval"): | ||
102 | + """Add the epoch summary to the Tensorboard. | ||
103 | + | ||
104 | + Args: | ||
105 | + summary_writer: Tensorflow summary_writer. | ||
106 | +    global_step_val: an int value of the global step.	|  ||
107 | + epoch_info_dict: a dictionary of the evaluation metrics calculated for the | ||
108 | + whole epoch. | ||
109 | + summary_scope: Train or Eval. | ||
110 | + | ||
111 | + Returns: | ||
112 | + A string of this global_step summary | ||
113 | + """ | ||
114 | + epoch_id = epoch_info_dict["epoch_id"] | ||
115 | + avg_hit_at_one = epoch_info_dict["avg_hit_at_one"] | ||
116 | + avg_perr = epoch_info_dict["avg_perr"] | ||
117 | + avg_loss = epoch_info_dict["avg_loss"] | ||
118 | + aps = epoch_info_dict["aps"] | ||
119 | + gap = epoch_info_dict["gap"] | ||
120 | + mean_ap = numpy.mean(aps) | ||
121 | + | ||
122 | + summary_writer.add_summary( | ||
123 | + MakeSummary("Epoch/" + summary_scope + "_Avg_Hit@1", avg_hit_at_one), | ||
124 | + global_step_val) | ||
125 | + summary_writer.add_summary( | ||
126 | + MakeSummary("Epoch/" + summary_scope + "_Avg_Perr", avg_perr), | ||
127 | + global_step_val) | ||
128 | + summary_writer.add_summary( | ||
129 | + MakeSummary("Epoch/" + summary_scope + "_Avg_Loss", avg_loss), | ||
130 | + global_step_val) | ||
131 | + summary_writer.add_summary( | ||
132 | + MakeSummary("Epoch/" + summary_scope + "_MAP", mean_ap), global_step_val) | ||
133 | + summary_writer.add_summary( | ||
134 | + MakeSummary("Epoch/" + summary_scope + "_GAP", gap), global_step_val) | ||
135 | + summary_writer.flush() | ||
136 | + | ||
137 | + info = ("epoch/eval number {0} | Avg_Hit@1: {1:.3f} | Avg_PERR: {2:.3f} " | ||
138 | +          "| MAP: {3:.3f} | GAP: {4:.3f} | Avg_Loss: {5:.3f} | num_classes: {6}"	|  ||
139 | + ).format(epoch_id, avg_hit_at_one, avg_perr, mean_ap, gap, avg_loss, | ||
140 | + len(aps)) | ||
141 | + return info | ||
142 | + | ||
143 | + | ||
144 | +def GetListOfFeatureNamesAndSizes(feature_names, feature_sizes): | ||
145 | + """Extract the list of feature names and the dimensionality of each feature | ||
146 | + | ||
147 | +  from strings of comma-separated values.	|  ||
148 | + | ||
149 | + Args: | ||
150 | + feature_names: string containing comma separated list of feature names | ||
151 | + feature_sizes: string containing comma separated list of feature sizes | ||
152 | + | ||
153 | + Returns: | ||
154 | + List of the feature names and list of the dimensionality of each feature. | ||
155 | + Elements in the first/second list are strings/integers. | ||
156 | + """ | ||
157 | + list_of_feature_names = [ | ||
158 | + feature_names.strip() for feature_names in feature_names.split(",") | ||
159 | + ] | ||
160 | + list_of_feature_sizes = [ | ||
161 | + int(feature_sizes) for feature_sizes in feature_sizes.split(",") | ||
162 | + ] | ||
163 | + if len(list_of_feature_names) != len(list_of_feature_sizes): | ||
164 | + logging.error("length of the feature names (=" + | ||
165 | + str(len(list_of_feature_names)) + ") != length of feature " | ||
166 | + "sizes (=" + str(len(list_of_feature_sizes)) + ")") | ||
167 | + | ||
168 | + return list_of_feature_names, list_of_feature_sizes | ||
169 | + | ||
170 | + | ||
171 | +def clip_gradient_norms(gradients_to_variables, max_norm): | ||
172 | + """Clips the gradients by the given value. | ||
173 | + | ||
174 | + Args: | ||
175 | + gradients_to_variables: A list of gradient to variable pairs (tuples). | ||
176 | + max_norm: the maximum norm value. | ||
177 | + | ||
178 | + Returns: | ||
179 | + A list of clipped gradient to variable pairs. | ||
180 | + """ | ||
181 | + clipped_grads_and_vars = [] | ||
182 | + for grad, var in gradients_to_variables: | ||
183 | + if grad is not None: | ||
184 | + if isinstance(grad, tf.IndexedSlices): | ||
185 | + tmp = tf.clip_by_norm(grad.values, max_norm) | ||
186 | + grad = tf.IndexedSlices(tmp, grad.indices, grad.dense_shape) | ||
187 | + else: | ||
188 | + grad = tf.clip_by_norm(grad, max_norm) | ||
189 | + clipped_grads_and_vars.append((grad, var)) | ||
190 | + return clipped_grads_and_vars | ||
191 | + | ||
192 | + | ||
193 | +def combine_gradients(tower_grads): | ||
194 | + """Calculate the combined gradient for each shared variable across all towers. | ||
195 | + | ||
196 | + Note that this function provides a synchronization point across all towers. | ||
197 | + | ||
198 | + Args: | ||
199 | + tower_grads: List of lists of (gradient, variable) tuples. The outer list is | ||
200 | + over individual gradients. The inner list is over the gradient calculation | ||
201 | + for each tower. | ||
202 | + | ||
203 | + Returns: | ||
204 | + List of pairs of (gradient, variable) where the gradient has been summed | ||
205 | + across all towers. | ||
206 | + """ | ||
207 | + filtered_grads = [ | ||
208 | + [x for x in grad_list if x[0] is not None] for grad_list in tower_grads | ||
209 | + ] | ||
210 | + final_grads = [] | ||
211 | + for i in xrange(len(filtered_grads[0])): | ||
212 | + grads = [filtered_grads[t][i] for t in xrange(len(filtered_grads))] | ||
213 | + grad = tf.stack([x[0] for x in grads], 0) | ||
214 | + grad = tf.reduce_sum(grad, 0) | ||
215 | + final_grads.append(( | ||
216 | + grad, | ||
217 | + filtered_grads[0][i][1], | ||
218 | + )) | ||
219 | + | ||
220 | + return final_grads |
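
For orientation, here is a small usage sketch of two of the helpers above. It assumes it is run from the `yt8m/` directory so that `utils` imports; the feature names and sizes ("rgb,audio" / "1024,128") are illustrative values, not requirements.

```python
import numpy as np

import utils  # yt8m/utils.py (assumes this script runs from the yt8m/ directory)

# Parse comma-separated flag strings into parallel lists.
names, sizes = utils.GetListOfFeatureNamesAndSizes("rgb,audio", "1024,128")
# names == ["rgb", "audio"], sizes == [1024, 128]

# Frame-level features are stored quantized as bytes; Dequantize maps them back
# onto the default [-2, 2] range.
quantized = np.array([0, 128, 255], dtype=np.float32)
print(utils.Dequantize(quantized))  # approximately [-1.99, 0.02, 2.01]
```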
yt8m/video_level_models.py
0 → 100644
1 | +# Copyright 2016 Google Inc. All Rights Reserved. | ||
2 | +# | ||
3 | +# Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | +# you may not use this file except in compliance with the License. | ||
5 | +# You may obtain a copy of the License at | ||
6 | +# | ||
7 | +# http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | +# | ||
9 | +# Unless required by applicable law or agreed to in writing, software | ||
10 | +# distributed under the License is distributed on an "AS-IS" BASIS, | ||
11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | +# See the License for the specific language governing permissions and | ||
13 | +# limitations under the License. | ||
14 | +"""Contains model definitions.""" | ||
15 | +import math | ||
16 | + | ||
17 | +import models | ||
18 | +import tensorflow as tf | ||
19 | +import utils | ||
20 | + | ||
21 | +from tensorflow import flags | ||
22 | +import tensorflow.contrib.slim as slim | ||
23 | + | ||
24 | +FLAGS = flags.FLAGS | ||
25 | +flags.DEFINE_integer( | ||
26 | + "moe_num_mixtures", 2, | ||
27 | + "The number of mixtures (excluding the dummy 'expert') used for MoeModel.") | ||
28 | + | ||
29 | + | ||
30 | +class LogisticModel(models.BaseModel): | ||
31 | + """Logistic model with L2 regularization.""" | ||
32 | + | ||
33 | + def create_model(self, | ||
34 | + model_input, | ||
35 | + vocab_size, | ||
36 | + l2_penalty=1e-8, | ||
37 | + **unused_params): | ||
38 | + """Creates a logistic model. | ||
39 | + | ||
40 | + Args: | ||
41 | + model_input: 'batch' x 'num_features' matrix of input features. | ||
42 | + vocab_size: The number of classes in the dataset. | ||
43 | + | ||
44 | + Returns: | ||
45 | + A dictionary with a tensor containing the probability predictions of the | ||
46 | + model in the 'predictions' key. The dimensions of the tensor are | ||
47 | + batch_size x num_classes. | ||
48 | + """ | ||
49 | + output = slim.fully_connected( | ||
50 | + model_input, | ||
51 | + vocab_size, | ||
52 | + activation_fn=tf.nn.sigmoid, | ||
53 | + weights_regularizer=slim.l2_regularizer(l2_penalty)) | ||
54 | + return {"predictions": output} | ||
55 | + | ||
56 | + | ||
57 | +class MoeModel(models.BaseModel): | ||
58 | + """A softmax over a mixture of logistic models (with L2 regularization).""" | ||
59 | + | ||
60 | + def create_model(self, | ||
61 | + model_input, | ||
62 | + vocab_size, | ||
63 | + num_mixtures=None, | ||
64 | + l2_penalty=1e-8, | ||
65 | + **unused_params): | ||
66 | + """Creates a Mixture of (Logistic) Experts model. | ||
67 | + | ||
68 | + The model consists of a per-class softmax distribution over a | ||
69 | + configurable number of logistic classifiers. One of the classifiers in the | ||
70 | + mixture is not trained, and always predicts 0. | ||
71 | + | ||
72 | + Args: | ||
73 | + model_input: 'batch_size' x 'num_features' matrix of input features. | ||
74 | + vocab_size: The number of classes in the dataset. | ||
75 | + num_mixtures: The number of mixtures (excluding a dummy 'expert' that | ||
76 | + always predicts the non-existence of an entity). | ||
77 | + l2_penalty: How much to penalize the squared magnitudes of parameter | ||
78 | + values. | ||
79 | + | ||
80 | + Returns: | ||
81 | + A dictionary with a tensor containing the probability predictions of the | ||
82 | + model in the 'predictions' key. The dimensions of the tensor are | ||
83 | + batch_size x num_classes. | ||
84 | + """ | ||
85 | + num_mixtures = num_mixtures or FLAGS.moe_num_mixtures | ||
86 | + | ||
87 | + gate_activations = slim.fully_connected( | ||
88 | + model_input, | ||
89 | + vocab_size * (num_mixtures + 1), | ||
90 | + activation_fn=None, | ||
91 | + biases_initializer=None, | ||
92 | + weights_regularizer=slim.l2_regularizer(l2_penalty), | ||
93 | + scope="gates") | ||
94 | + expert_activations = slim.fully_connected( | ||
95 | + model_input, | ||
96 | + vocab_size * num_mixtures, | ||
97 | + activation_fn=None, | ||
98 | + weights_regularizer=slim.l2_regularizer(l2_penalty), | ||
99 | + scope="experts") | ||
100 | + | ||
101 | + gating_distribution = tf.nn.softmax( | ||
102 | + tf.reshape( | ||
103 | + gate_activations, | ||
104 | + [-1, num_mixtures + 1])) # (Batch * #Labels) x (num_mixtures + 1) | ||
105 | + expert_distribution = tf.nn.sigmoid( | ||
106 | + tf.reshape(expert_activations, | ||
107 | + [-1, num_mixtures])) # (Batch * #Labels) x num_mixtures | ||
108 | + | ||
109 | + final_probabilities_by_class_and_batch = tf.reduce_sum( | ||
110 | + gating_distribution[:, :num_mixtures] * expert_distribution, 1) | ||
111 | + final_probabilities = tf.reshape(final_probabilities_by_class_and_batch, | ||
112 | + [-1, vocab_size]) | ||
113 | + return {"predictions": final_probabilities} |
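
To make the mixture-of-experts arithmetic in `MoeModel.create_model` concrete, here is a numpy sketch of the per-(example, class) combination: a softmax gate over `num_mixtures + 1` logits (the extra slot is the untrained dummy expert that always predicts 0) weights the sigmoid expert outputs. The logits below are made-up numbers for illustration.

```python
import numpy as np


def moe_probability(gate_logits, expert_logits):
  """Combines experts for one (example, class) pair, as MoeModel does per class."""
  gate_logits = np.asarray(gate_logits, dtype=np.float64)
  expert_logits = np.asarray(expert_logits, dtype=np.float64)
  gates = np.exp(gate_logits) / np.sum(np.exp(gate_logits))  # softmax over num_mixtures + 1
  experts = 1.0 / (1.0 + np.exp(-expert_logits))              # sigmoid, one per real expert
  num_mixtures = expert_logits.shape[0]
  # The dummy expert's probability is 0, so only the first num_mixtures gates contribute.
  return float(np.sum(gates[:num_mixtures] * experts))


print(moe_probability([1.0, 0.5, -2.0], [2.0, -1.0]))  # roughly 0.63
```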