build_data.py (forked from vanhuyz/CycleGAN-TensorFlow)
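The script below collects the .jpg images in the X and Y input directories and serializes each set into its own TFRecords file. A minimal invocation sketch, assuming TensorFlow 1.x and the default apple2orange paths declared by the flags in the script (swap in your own dataset paths as needed):

    python build_data.py \
      --X_input_dir=data/apple2orange/trainA \
      --Y_input_dir=data/apple2orange/trainB \
      --X_output_file=data/tfrecords/apple.tfrecords \
      --Y_output_file=data/tfrecords/orange.tfrecords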
import tensorflow as tf
import random
import os

try:
  from os import scandir
except ImportError:
  # Python 2 polyfill module
  from scandir import scandir

FLAGS = tf.flags.FLAGS

tf.flags.DEFINE_string('X_input_dir', 'data/apple2orange/trainA',
                       'X input directory, default: data/apple2orange/trainA')
tf.flags.DEFINE_string('Y_input_dir', 'data/apple2orange/trainB',
                       'Y input directory, default: data/apple2orange/trainB')
tf.flags.DEFINE_string('X_output_file', 'data/tfrecords/apple.tfrecords',
                       'X output tfrecords file, default: data/tfrecords/apple.tfrecords')
tf.flags.DEFINE_string('Y_output_file', 'data/tfrecords/orange.tfrecords',
                       'Y output tfrecords file, default: data/tfrecords/orange.tfrecords')

def data_reader(input_dir, shuffle=True):
  """Collect the .jpg file paths under input_dir and optionally shuffle them.
  Args:
    input_dir: string, path of input dir, e.g., /path/to/dir
    shuffle: bool, whether to shuffle the returned paths (seeded, so repeatable)
  Returns:
    file_paths: list of strings
  """
  file_paths = []

  for img_file in scandir(input_dir):
    if img_file.name.endswith('.jpg') and img_file.is_file():
      file_paths.append(img_file.path)

  if shuffle:
    # Shuffle the ordering of all image files in order to guarantee
    # random ordering of the images with respect to label in the
    # saved TFRecord files. Make the randomization repeatable.
    shuffled_index = list(range(len(file_paths)))
    random.seed(12345)
    random.shuffle(shuffled_index)

    file_paths = [file_paths[i] for i in shuffled_index]

  return file_paths

def _int64_feature(value):
  """Wrapper for inserting int64 features into Example proto."""
  if not isinstance(value, list):
    value = [value]
  return tf.train.Feature(int64_list=tf.train.Int64List(value=value))


def _bytes_feature(value):
  """Wrapper for inserting bytes features into Example proto."""
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def _convert_to_example(file_path, image_buffer):
  """Build an Example proto for an example.
  Args:
    file_path: string, path to an image file, e.g., '/path/to/example.JPG'
    image_buffer: string, JPEG encoding of RGB image
  Returns:
    Example proto
  """
  file_name = file_path.split('/')[-1]

  example = tf.train.Example(features=tf.train.Features(feature={
      'image/file_name': _bytes_feature(tf.compat.as_bytes(os.path.basename(file_name))),
      'image/encoded_image': _bytes_feature(image_buffer)
    }))
  return example
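
# A minimal sketch of reading these Examples back, assuming a TF 1.x input
# pipeline. The feature keys match the proto built above; the decode step is
# an assumption about how a consumer would use the JPEG bytes, not something
# this script itself does:
#
#   def _parse_example(serialized):
#     features = tf.parse_single_example(serialized, features={
#         'image/file_name': tf.FixedLenFeature([], tf.string),
#         'image/encoded_image': tf.FixedLenFeature([], tf.string),
#     })
#     return tf.image.decode_jpeg(features['image/encoded_image'], channels=3)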

def data_writer(input_dir, output_file):
  """Write data to tfrecords
  """
  file_paths = data_reader(input_dir)

  # create tfrecords dir if not exists
  output_dir = os.path.dirname(output_file)
  try:
    os.makedirs(output_dir)
  except os.error:
    pass

  images_num = len(file_paths)

  # dump to tfrecords file
  writer = tf.python_io.TFRecordWriter(output_file)

  for i in range(len(file_paths)):
    file_path = file_paths[i]

    with tf.gfile.FastGFile(file_path, 'rb') as f:
      image_data = f.read()

    example = _convert_to_example(file_path, image_data)
    writer.write(example.SerializeToString())

    if i % 500 == 0:
      print("Processed {}/{}.".format(i, images_num))
  print("Done.")
  writer.close()

def main(unused_argv):
  print("Convert X data to tfrecords...")
  data_writer(FLAGS.X_input_dir, FLAGS.X_output_file)
  print("Convert Y data to tfrecords...")
  data_writer(FLAGS.Y_input_dir, FLAGS.Y_output_file)


if __name__ == '__main__':
  tf.app.run()
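
# A quick sanity check on an output file, as a sketch (tf.python_io.tf_record_iterator
# is the TF 1.x counterpart of the TFRecordWriter used above; the path here is the
# default X output, and the record count depends on your dataset):
#
#   count = sum(1 for _ in tf.python_io.tf_record_iterator('data/tfrecords/apple.tfrecords'))
#   print('{} records written'.format(count))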