-
Notifications
You must be signed in to change notification settings - Fork 17
/
train_mnist.py
150 lines (130 loc) · 5.36 KB
/
train_mnist.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
from keras.datasets import mnist
from keras.utils import to_categorical
from os_elm import OS_ELM
import numpy as np
import tensorflow as tf
import tqdm
def softmax(a):
    """Return the softmax of `a` computed along its last axis.

    Args:
        a: array-like of raw scores. A 2-D input of shape
            (n_samples, n_classes) is the case used in this file; inputs
            of higher rank are normalized over their last axis as well.
            A 1-D input is treated as a batch containing one sample.

    Returns:
        numpy.ndarray of non-negative values that sum to 1 along the last
        axis. The shape equals the input's shape, except that a 1-D input
        of length n yields shape (1, n).
    """
    a = np.asarray(a)
    if a.ndim == 1:
        # a single vector is handled as a batch of one, so the result has
        # shape (1, n).
        a = a[np.newaxis, :]
    # subtracting the maximum along the last axis does not change the
    # result mathematically, but keeps np.exp from overflowing.
    # keepdims=True makes the reduction broadcast against `a` for any rank.
    shifted = a - np.max(a, axis=-1, keepdims=True)
    exp_a = np.exp(shifted)
    return exp_a / np.sum(exp_a, axis=-1, keepdims=True)
def main():
    """Train an OS-ELM classifier on MNIST, then evaluate, save and reload it.

    The steps are, in order: build the model, load and preprocess MNIST,
    run the initial and the sequential training phases, print the
    predictions for the first 10 test samples, evaluate on the test set,
    save the parameters to './checkpoint/model.ckpt', re-initialize the
    model's variables, restore the saved parameters and evaluate again.
    """
    # function-scope import: `os` is only needed to create the checkpoint
    # directory below.
    import os

    # ===========================================
    # Instantiate os-elm
    # ===========================================
    n_input_nodes = 784
    n_hidden_nodes = 1024
    n_output_nodes = 10

    os_elm = OS_ELM(
        # the number of input nodes.
        n_input_nodes=n_input_nodes,
        # the number of hidden nodes.
        n_hidden_nodes=n_hidden_nodes,
        # the number of output nodes.
        n_output_nodes=n_output_nodes,
        # loss function.
        # the default value is 'mean_squared_error'.
        # for the other functions, we support
        # 'mean_absolute_error', 'categorical_crossentropy', and 'binary_crossentropy'.
        loss='mean_squared_error',
        # activation function applied to the hidden nodes.
        # the default value is 'sigmoid'.
        # for the other functions, we support 'linear' and 'tanh'.
        # NOTE: OS-ELM can apply an activation function only to the hidden nodes.
        activation='sigmoid',
    )

    # ===========================================
    # Prepare dataset
    # ===========================================
    n_classes = n_output_nodes

    # load MNIST
    (x_train, t_train), (x_test, t_test) = mnist.load_data()
    # flatten each image to n_input_nodes values and normalize them
    # within [0, 1]
    x_train = x_train.reshape(-1, n_input_nodes) / 255.
    x_test = x_test.reshape(-1, n_input_nodes) / 255.
    x_train = x_train.astype(np.float32)
    x_test = x_test.astype(np.float32)

    # convert label data into one-hot-vector format data.
    t_train = to_categorical(t_train, num_classes=n_classes)
    t_test = to_categorical(t_test, num_classes=n_classes)
    t_train = t_train.astype(np.float32)
    t_test = t_test.astype(np.float32)

    # divide the training dataset into two datasets:
    # (1) for the initial training phase
    # (2) for the sequential training phase
    # NOTE: the number of training samples for the initial training phase
    # must be much greater than the number of the model's hidden nodes.
    # here, we assign int(1.5 * n_hidden_nodes) training samples
    # for the initial training phase.
    border = int(1.5 * n_hidden_nodes)
    x_train_init = x_train[:border]
    x_train_seq = x_train[border:]
    t_train_init = t_train[:border]
    t_train_seq = t_train[border:]

    # ===========================================
    # Training
    # ===========================================
    # the progress bar is used as a context manager so that it is closed
    # even if one of the training calls raises.
    with tqdm.tqdm(total=len(x_train), desc='initial training phase') as pbar:
        # the initial training phase
        os_elm.init_train(x_train_init, t_train_init)
        pbar.update(n=len(x_train_init))

        # the sequential training phase
        pbar.set_description('sequential training phase')
        batch_size = 64
        for i in range(0, len(x_train_seq), batch_size):
            x_batch = x_train_seq[i:i+batch_size]
            t_batch = t_train_seq[i:i+batch_size]
            os_elm.seq_train(x_batch, t_batch)
            # the last batch may be shorter than batch_size.
            pbar.update(n=len(x_batch))

    # ===========================================
    # Prediction
    # ===========================================
    # take the first 10 samples of x_test for a quick look at the outputs
    n = 10
    x = x_test[:n]
    t = t_test[:n]

    # 'predict' method returns raw values of output nodes.
    y = os_elm.predict(x)
    # apply softmax function to the output values.
    y = softmax(y)

    # check the answers.
    for i, (probs, target) in enumerate(zip(y, t)):
        max_ind = np.argmax(probs)
        print('========== sample index %d ==========' % i)
        print('estimated answer: class %d' % max_ind)
        print('estimated probability: %.3f' % probs[max_ind])
        print('true answer: class %d' % np.argmax(target))

    # ===========================================
    # Evaluation
    # ===========================================
    # we currently support 'loss' and 'accuracy' for 'metrics'.
    # NOTE: 'accuracy' is valid only if the model assumes
    # to deal with a classification problem, while 'loss' is always valid.
    # to request the loss only, pass metrics=['loss'].
    loss, accuracy = os_elm.evaluate(x_test, t_test, metrics=['loss', 'accuracy'])
    print('val_loss: %f, val_accuracy: %f' % (loss, accuracy))

    # ===========================================
    # Save model
    # ===========================================
    checkpoint_path = './checkpoint/model.ckpt'
    # make sure the target directory exists before saving.
    # NOTE(review): whether `OS_ELM.save` creates missing directories is not
    # visible from this file; creating it here is harmless either way.
    os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
    print('saving model parameters...')
    os_elm.save(checkpoint_path)

    # initialize weights of os_elm
    os_elm.initialize_variables()

    # ===========================================
    # Load model
    # ===========================================
    # If you want to load weights to a model,
    # the architecture of the model must be exactly the same
    # as the one when the weights were saved.
    print('restoring model parameters...')
    os_elm.restore(checkpoint_path)

    # ===========================================
    # ReEvaluation
    # ===========================================
    # same evaluation as before, run again after the parameters have been
    # re-initialized and restored from the checkpoint.
    loss, accuracy = os_elm.evaluate(x_test, t_test, metrics=['loss', 'accuracy'])
    print('val_loss: %f, val_accuracy: %f' % (loss, accuracy))
# run the training demo only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()