Showing
1 changed file
with
91 additions
and
0 deletions
for_dataset/csv_to_tfrecord.py
0 → 100644
| 1 | +from __future__ import division | ||
| 2 | +from __future__ import print_function | ||
| 3 | +from __future__ import absolute_import | ||
| 4 | + | ||
| 5 | +import os | ||
| 6 | +import io | ||
| 7 | +import pandas as pd | ||
| 8 | +import tensorflow as tf | ||
| 9 | + | ||
| 10 | +from PIL import Image | ||
| 11 | +from object_detection.utils import dataset_util | ||
| 12 | +from collections import namedtuple, OrderedDict | ||
| 13 | + | ||
| 14 | +flags = tf.app.flags | ||
| 15 | +flags.DEFINE_string('csv_input', '', 'Path to the CSV input') | ||
| 16 | +flags.DEFINE_string('output_path', '', 'Path to output TFRecord') | ||
| 17 | +flags.DEFINE_string('image_dir', '', 'Path to images') | ||
| 18 | +FLAGS = flags.FLAGS | ||
| 19 | + | ||
| 20 | + | ||
| 21 | +# TO-DO replace this with label map | ||
| 22 | +def class_text_to_int(row_label): | ||
| 23 | + if row_label == 'fire': | ||
| 24 | + return 1 | ||
| 25 | + else: | ||
| 26 | + None | ||
| 27 | + | ||
| 28 | + | ||
| 29 | +def split(df, group): | ||
| 30 | + data = namedtuple('data', ['filename', 'object']) | ||
| 31 | + gb = df.groupby(group) | ||
| 32 | + return [data(filename, gb.get_group(x)) for filename, x in zip(gb.groups.keys(), gb.groups)] | ||
| 33 | + | ||
| 34 | + | ||
| 35 | +def create_tf_example(group, path): | ||
| 36 | + with tf.gfile.GFile(os.path.join(path, '{}'.format(group.filename)), 'rb') as fid: | ||
| 37 | + encoded_jpg = fid.read() | ||
| 38 | + encoded_jpg_io = io.BytesIO(encoded_jpg) | ||
| 39 | + image = Image.open(encoded_jpg_io) | ||
| 40 | + width, height = image.size | ||
| 41 | + | ||
| 42 | + filename = group.filename.encode('utf8') | ||
| 43 | + image_format = b'jpg' | ||
| 44 | + xmins = [] | ||
| 45 | + xmaxs = [] | ||
| 46 | + ymins = [] | ||
| 47 | + ymaxs = [] | ||
| 48 | + classes_text = [] | ||
| 49 | + classes = [] | ||
| 50 | + | ||
| 51 | + for index, row in group.object.iterrows(): | ||
| 52 | + xmins.append(row['xmin'] / width) | ||
| 53 | + xmaxs.append(row['xmax'] / width) | ||
| 54 | + ymins.append(row['ymin'] / height) | ||
| 55 | + ymaxs.append(row['ymax'] / height) | ||
| 56 | + classes_text.append(row['class'].encode('utf8')) | ||
| 57 | + classes.append(class_text_to_int(row['class'])) | ||
| 58 | + | ||
| 59 | + tf_example = tf.train.Example(features=tf.train.Features(feature={ | ||
| 60 | + 'image/height': dataset_util.int64_feature(height), | ||
| 61 | + 'image/width': dataset_util.int64_feature(width), | ||
| 62 | + 'image/filename': dataset_util.bytes_feature(filename), | ||
| 63 | + 'image/source_id': dataset_util.bytes_feature(filename), | ||
| 64 | + 'image/encoded': dataset_util.bytes_feature(encoded_jpg), | ||
| 65 | + 'image/format': dataset_util.bytes_feature(image_format), | ||
| 66 | + 'image/object/bbox/xmin': dataset_util.float_list_feature(xmins), | ||
| 67 | + 'image/object/bbox/xmax': dataset_util.float_list_feature(xmaxs), | ||
| 68 | + 'image/object/bbox/ymin': dataset_util.float_list_feature(ymins), | ||
| 69 | + 'image/object/bbox/ymax': dataset_util.float_list_feature(ymaxs), | ||
| 70 | + 'image/object/class/text': dataset_util.bytes_list_feature(classes_text), | ||
| 71 | + 'image/object/class/label': dataset_util.int64_list_feature(classes), | ||
| 72 | + })) | ||
| 73 | + return tf_example | ||
| 74 | + | ||
| 75 | + | ||
| 76 | +def main(_): | ||
| 77 | + writer = tf.python_io.TFRecordWriter(FLAGS.output_path) | ||
| 78 | + path = os.path.join(FLAGS.image_dir) | ||
| 79 | + examples = pd.read_csv(FLAGS.csv_input) | ||
| 80 | + grouped = split(examples, 'filename') | ||
| 81 | + for group in grouped: | ||
| 82 | + tf_example = create_tf_example(group, path) | ||
| 83 | + writer.write(tf_example.SerializeToString()) | ||
| 84 | + | ||
| 85 | + writer.close() | ||
| 86 | + output_path = os.path.join(os.getcwd(), FLAGS.output_path) | ||
| 87 | + print('Successfully created the TFRecords: {}'.format(output_path)) | ||
| 88 | + | ||
| 89 | + | ||
| 90 | +if __name__ == '__main__': | ||
| 91 | + tf.app.run() | ||
| ... | \ No newline at end of file | ... | \ No newline at end of file |
-
Please register or login to post a comment