data_generator.py
import tensorflow as tf
import os

# Avoid OOM errors by enabling GPU memory growth
gpus = tf.config.experimental.list_physical_devices('GPU')
for gpu in gpus:
    print(f"Setting memory growth for gpu: {gpu}")
    tf.config.experimental.set_memory_growth(gpu, True)

vggface2_train = './datasets/VGG-Face2/data/train'  # Total nº files: 3141890
vggface2_test = './datasets/VGG-Face2/data/test'    # Total nº files: 169396
dataset_path = vggface2_train

labels_list = [tf.constant(x) for x in sorted(os.listdir(dataset_path))]

AUTO = tf.data.AUTOTUNE

def get_label(file_path):
    # The class label is the name of the image's parent directory
    return tf.strings.split(file_path, os.path.sep)[-2]

def process_image(file_path, target_size, labels=True):
    # Map the class-folder name to its index in labels_list (sparse label)
    label = get_label(file_path)
    label = tf.argmax(label == labels_list)  # Sparse labels
    # Read, decode, normalise to [0, 1] and resize the image
    img = tf.io.read_file(file_path)
    img = tf.image.decode_jpeg(img, channels=3)
    img = tf.image.convert_image_dtype(img, dtype=tf.float32)
    img = tf.image.resize(img, [target_size, target_size])
    if labels:
        return img, label
    else:
        return img

def create_data_generators(target_size, batch_size, ret_labels=True, seed=0):
    # List every image under dataset_path/<class>/<image> and shuffle the file paths
    dataset = tf.data.Dataset.list_files(dataset_path + '/*/*', shuffle=True, seed=seed)
    total_samples = len(dataset)

    # 90% / 5% / 5% train / validation / test split
    train_split = 0.9
    val_split = 0.05
    test_split = 0.05
    train_len = int(total_samples * train_split)
    val_len = int(total_samples * val_split)
    test_len = int(total_samples * test_split)

    train_ds = dataset.take(train_len)
    val_ds = dataset.skip(train_len).take(val_len)
    test_ds = dataset.skip(train_len).skip(val_len)

    print(f"[INFO] Num images for train: {train_len} -> train_ds: {len(train_ds)}")
    print(f"[INFO] Num images for validation: {val_len} -> val_ds: {len(val_ds)}")
    print(f"[INFO] Num images for test: {test_len} -> test_ds: {len(test_ds)}")

    train_ds = (
        train_ds
        .shuffle(buffer_size=2 * batch_size, seed=seed)
        .map(lambda x: process_image(x, target_size, labels=ret_labels), num_parallel_calls=AUTO)
        .batch(batch_size=batch_size)
        .prefetch(buffer_size=AUTO)
    )
    val_ds = (
        val_ds
        .map(lambda x: process_image(x, target_size, labels=ret_labels), num_parallel_calls=AUTO)
        .batch(batch_size=batch_size)
        .prefetch(buffer_size=AUTO)
    )
    test_ds = (
        test_ds
        .map(lambda x: process_image(x, target_size, labels=ret_labels), num_parallel_calls=AUTO)
        .batch(batch_size=batch_size)
        .prefetch(buffer_size=AUTO)
    )
    return train_ds, val_ds, test_ds
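
# Example usage: a minimal sketch, not part of the original script. The
# target_size and batch_size values below are illustrative assumptions.
if __name__ == "__main__":
    train_ds, val_ds, test_ds = create_data_generators(target_size=224, batch_size=32)
    # Inspect one batch to check shapes: images (32, 224, 224, 3), labels (32,)
    for images, sparse_labels in train_ds.take(1):
        print(images.shape, sparse_labels.shape)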