forked from SidHard/tfAlexNet
-
Notifications
You must be signed in to change notification settings - Fork 21
/
importData.py
129 lines (109 loc) · 4.11 KB
/
importData.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import numpy as np
import os
import cv2
class Dataset:
    """Shuffled, labelled image dataset built from a directory tree.

    Records come from ``createImageList``: each non-hidden immediate
    sub-directory of ``imagePath`` is a category, and the matching images
    found under it are labelled with that category name.  Category names are
    mapped to integer labels in sorted order.  Records are served one at a
    time (``nextRecord``) or in batches (``nextBatch``).  Once every record
    has been served, the data is reshuffled and reading starts over.
    """

    def __init__(self, imagePath, extensions):
        """Load and shuffle all images under ``imagePath``.

        Args:
            imagePath: root directory whose sub-directories are the categories.
            extensions: container of accepted extensions, compared against the
                lower-cased ``os.path.splitext`` suffix (so include the dot,
                e.g. ``['.jpg', '.png']``).

        Raises:
            ValueError: if no matching images were found.
        """
        self.data = createImageList(imagePath, extensions)
        self.num_records = len(self.data)
        self._shuffle()

    def _shuffle(self):
        """Shuffle ``self.data`` in place and rebuild the derived label/input views.

        Resets the read cursor to the first record.  Because categories are
        numbered in ``np.unique`` (sorted) order, the category->label mapping
        is identical after every reshuffle.

        Raises:
            ValueError: if ``self.data`` is empty.
        """
        if not self.data:
            raise ValueError("Dataset contains no images")
        np.random.shuffle(self.data)
        self.next_record = 0
        categories, self.inputs = zip(*self.data)
        unique = np.unique(categories)
        self.num_labels = len(unique)
        self.category2label = {c: i for i, c in enumerate(unique)}
        self.label2category = {i: c for c, i in self.category2label.items()}
        # Convert the category names to integer labels.
        self.labels = [self.category2label[c] for c in categories]

    def __len__(self):
        """Return the total number of records."""
        return self.num_records

    def onehot(self, label):
        """Return a float64 one-hot vector of length ``num_labels`` for ``label``."""
        v = np.zeros(self.num_labels)
        v[label] = 1
        return v

    def recordsRemaining(self):
        """Return how many records are left before the next reshuffle."""
        return len(self) - self.next_record

    def hasNextRecord(self):
        """Return True if the current pass still has unread records."""
        return self.next_record < self.num_records

    def preprocess(self, img):
        """Resize ``img`` to 227x227 and convert it to float32 divided by 255.

        The final reshape fails unless the image has exactly 3 channels
        (``cv2.imread`` with default flags returns 3-channel images).
        """
        pp = cv2.resize(img, (227, 227))
        pp = np.asarray(pp, dtype=np.float32)
        pp /= 255
        pp = pp.reshape((pp.shape[0], pp.shape[1], 3))
        return pp

    def nextRecord(self):
        """Return the next ``(onehot_label, image)`` pair.

        When the current pass is exhausted, the data is reshuffled and
        reading restarts, so this never runs out of records.
        """
        if not self.hasNextRecord():
            self._shuffle()
        label = self.onehot(self.labels[self.next_record])
        image = self.preprocess(cv2.imread(self.inputs[self.next_record]))
        self.next_record += 1
        return label, image

    def nextBatch(self, batch_size):
        """Return ``(labels, images)`` tuples holding ``batch_size`` records.

        A batch may span a reshuffle (see ``nextRecord``).  A non-positive
        ``batch_size`` returns two empty tuples instead of crashing.
        """
        records = [self.nextRecord() for _ in range(batch_size)]
        if not records:
            return (), ()
        labels, images = zip(*records)
        return labels, images
def createImageList(imagePath, extensions):
    """Collect ``[category, path]`` pairs for every matching image under ``imagePath``.

    Each non-hidden immediate sub-directory of ``imagePath`` is a category.
    It is walked recursively, skipping hidden directories, and every non-hidden
    file whose lower-cased extension is in ``extensions`` is recorded with
    that category's name.  Results are ordered by category, then by a
    sorted depth-first walk of its directories, then by file name.

    Args:
        imagePath: root directory of the dataset.
        extensions: container of accepted extensions, compared against
            ``os.path.splitext(name)[1].lower()`` (so include the leading dot).

    Returns:
        A list of ``[category, file_path]`` lists; empty if nothing matched.
    """
    imageFilenames = []
    categoryList = [c for c in sorted(os.listdir(imagePath))
                    if not c.startswith('.') and
                    os.path.isdir(os.path.join(imagePath, c))]
    for category in categoryList:
        walkPath = os.path.join(imagePath, category)
        # Iterate the generator directly: ``gen.next()`` only exists on
        # Python 2 and raises AttributeError on Python 3.
        for dirpath, dirnames, filenames in _walk(walkPath):
            # Mutate dirnames in place: _walk recurses into this same list
            # after resuming, so this prunes hidden dirs and fixes the order.
            dirnames[:] = sorted(d for d in dirnames if not d.startswith('.'))
            # Ignore hidden files and keep only the requested extensions.
            images = sorted(
                f for f in filenames
                if not f.startswith('.')
                and os.path.splitext(f)[1].lower() in extensions)
            for f in images:
                imageFilenames.append([category, os.path.join(dirpath, f)])
    return imageFilenames
def _walk(top):
"""
Directory tree generator lifted from python 2.6 and then
stripped down. It improves on the 2.5 os.walk() by adding
the 'followlinks' capability.
GLU: copied from image sensor
"""
names = os.listdir(top)
dirs, nondirs = [], []
for name in names:
if os.path.isdir(os.path.join(top, name)):
dirs.append(name)
else:
nondirs.append(name)
yield top, dirs, nondirs
for name in dirs:
path = os.path.join(top, name)
for x in _walk(path):
yield x