-
Notifications
You must be signed in to change notification settings - Fork 97
/
dataset_split.py
95 lines (71 loc) · 3.67 KB
/
dataset_split.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
# -*- coding: utf-8 -*-
"""
Created on Thu Jul 12 16:34:16 2018
@author: yy
"""
import random

try:
    # ``sklearn.cross_validation`` was removed in scikit-learn 0.20; the same
    # function lives in ``sklearn.model_selection`` (available since 0.18).
    from sklearn.model_selection import train_test_split
except ImportError:
    # Fallback for very old scikit-learn releases.
    from sklearn.cross_validation import train_test_split
from keras import backend as K

from dataset_load import load_dataset_cnn
# Default image size as (width, height): load() reads index 0 as the column
# count and index 1 as the row count.
image_size = (128, 32)


class Dataset:
    """Load an image dataset and hold its train/validation/test splits.

    After ``load()`` the ``*_images`` attributes hold the reshaped, scaled
    image arrays, the ``*_labels`` attributes hold the matching labels and
    ``input_shape`` holds the per-sample shape in the Keras backend's layout.
    """

    def __init__(self, path_name):
        """Remember ``path_name``; nothing is read until ``load()`` is called."""
        # Training set
        self.train_images = None
        self.train_labels = None

        # Validation set
        self.valid_images = None
        self.valid_labels = None

        # Test set
        self.test_images = None
        self.test_labels = None

        # Path the dataset is loaded from
        self.path_name = path_name

        # Per-sample input shape in the dimension order used by the backend
        self.input_shape = None

    @staticmethod
    def _channels_first():
        """Return True when the Keras backend expects (channels, rows, cols)."""
        # ``image_dim_ordering`` is the legacy spelling and is absent from
        # newer Keras releases, so prefer ``image_data_format`` when it exists.
        if hasattr(K, 'image_data_format'):
            return K.image_data_format() == 'channels_first'
        return K.image_dim_ordering() == 'th'

    @staticmethod
    def _prepare_images(images, channels_first, img_rows, img_cols,
                        img_channels):
        """Reshape ``images`` to the backend layout and scale them by 1/255.

        The result is a float32 array; dividing by 255 maps 8-bit pixel
        values into the 0..1 range.
        """
        count = images.shape[0]
        if channels_first:
            target_shape = (count, img_channels, img_rows, img_cols)
        else:
            target_shape = (count, img_rows, img_cols, img_channels)
        return images.reshape(target_shape).astype('float32') / 255

    def load(self, img_rows=image_size[1], img_cols=image_size[0],
             img_channels=1, random_state=None):
        """Load the dataset, split it three ways and preprocess the images.

        The data is split into 70% training and 30% held-out samples, and the
        held-out part is then halved into validation and test sets. The three
        sets are therefore disjoint, so the test set never contains samples
        the model was trained or validated on.

        Args:
            img_rows: Image height used when reshaping.
            img_cols: Image width used when reshaping.
            img_channels: Number of channels used when reshaping.
            random_state: Seed passed to both splits. When None, a random
                integer in [0, 100] is drawn, so each call splits differently.

        Side effects:
            Prints the size of each split and stores the images, labels and
            ``input_shape`` on the instance.
        """
        # Load the whole dataset into memory.
        # NOTE(review): the arrays are assumed to be numpy arrays, since
        # ``.shape``/``.reshape`` are used on the split results - confirm
        # against ``load_dataset_cnn``.
        images, labels = load_dataset_cnn(self.path_name)

        if random_state is None:
            random_state = random.randint(0, 100)

        # First carve off the held-out 30%, then split *that* portion in half.
        # Splitting the full data a second time would let test samples leak
        # into the training set.
        train_images, held_images, train_labels, held_labels = train_test_split(
            images, labels, test_size=0.3, random_state=random_state)
        valid_images, test_images, valid_labels, test_labels = train_test_split(
            held_images, held_labels, test_size=0.5, random_state=random_state)

        # Rearrange the data into the dimension order the backend requires:
        # (channels, rows, cols) for channels-first, else (rows, cols, channels).
        channels_first = self._channels_first()
        if channels_first:
            self.input_shape = (img_channels, img_rows, img_cols)
        else:
            self.input_shape = (img_rows, img_cols, img_channels)

        train_images = self._prepare_images(
            train_images, channels_first, img_rows, img_cols, img_channels)
        valid_images = self._prepare_images(
            valid_images, channels_first, img_rows, img_cols, img_channels)
        test_images = self._prepare_images(
            test_images, channels_first, img_rows, img_cols, img_channels)

        # Report the number of samples in each split.
        print(train_images.shape[0], 'train samples')
        print(valid_images.shape[0], 'valid samples')
        print(test_images.shape[0], 'test samples')

        self.train_images = train_images
        self.valid_images = valid_images
        self.test_images = test_images
        self.train_labels = train_labels
        self.valid_labels = valid_labels
        self.test_labels = test_labels
if __name__ == '__main__':
    # Script entry point: load and split the full image directory, which
    # prints the size of each split.
    data_dir = '../img_data/all/'
    dataset = Dataset(data_dir)
    dataset.load()

    # Disabled evaluation snippet, kept as a no-op string literal. ``Model``
    # is neither defined nor imported in this file.
    '''
    #评估模型
    model = Model()
    model.load_model(file_path = './model/face_model')
    model.evaluate(dataset)
    '''