forked from zhangks98/eeg-adapt
-
Notifications
You must be signed in to change notification settings - Fork 0
/
train_base.py
138 lines (119 loc) · 5.07 KB
/
train_base.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
#!/usr/bin/env python
# coding: utf-8
'''Subject-independent classification with KU Data,
using Deep ConvNet model from [1].
References
----------
.. [1] Schirrmeister, R. T., Springenberg, J. T., Fiederer, L. D. J.,
Glasstetter, M., Eggensperger, K., Tangermann, M., Hutter, F. & Ball, T. (2017).
Deep learning with convolutional neural networks for EEG decoding and
visualization.
Human Brain Mapping , Aug. 2017. Online: http://dx.doi.org/10.1002/hbm.23730
'''
import argparse
import json
import logging
import sys
from os import makedirs
from os.path import join as pjoin
from shutil import copy2, move
import h5py
import numpy as np
import torch
import torch.nn.functional as F
from braindecode.datautil.signal_target import SignalAndTarget
from braindecode.models.deep4 import Deep4Net
from braindecode.torch_ext.optimizers import AdamW
from braindecode.torch_ext.util import set_random_seeds
from sklearn.model_selection import KFold
# Configure root logging to stdout so progress is visible when run as a script.
logging.basicConfig(format='%(asctime)s %(levelname)s : %(message)s',
                    level=logging.INFO, stream=sys.stdout)
parser = argparse.ArgumentParser(
    description='Subject-independent classification with KU Data')
parser.add_argument('datapath', type=str, help='Path to the h5 data file')
parser.add_argument('outpath', type=str, help='Path to the result folder')
parser.add_argument('-fold', type=int,
                    help='k-fold index, starts with 0', required=True)
parser.add_argument('-gpu', type=int, help='The gpu device to use', default=0)
args = parser.parse_args()
datapath = args.datapath
outpath = args.outpath
fold = args.fold
# Validate CLI input explicitly: a bare `assert` is silently stripped when
# Python runs with -O, so an out-of-range fold would slip through.
# parser.error() prints the usage message and exits with status 2.
if not 0 <= fold < 54:
    parser.error('-fold must be in [0, 54), got {}'.format(fold))
# Fixed random permutation of the 54 KU subjects; fold k holds out
# subjs[k] as the test subject.
subjs = [35, 47, 46, 37, 13, 27, 12, 32, 53, 54, 4, 40, 19, 41, 18, 42, 34, 7,
         49, 9, 5, 48, 29, 15, 21, 17, 31, 45, 1, 38, 51, 8, 11, 16, 28, 44, 24,
         52, 3, 26, 39, 50, 6, 23, 2, 14, 25, 20, 10, 33, 22, 43, 36, 30]
test_subj = subjs[fold]
# Remaining 53 subjects, rotated by fold so each fold sees a deterministic
# ordering; split into 6 train/validation groups by KFold below.
cv_set = np.array(subjs[fold+1:] + subjs[:fold])
kf = KFold(n_splits=6)
dfile = h5py.File(datapath, 'r')
torch.cuda.set_device(args.gpu)
set_random_seeds(seed=20200205, cuda=True)
BATCH_SIZE = 16
TRAIN_EPOCH = 200  # upper bound on epochs; early stopping remembers the best epoch
# Get data from single subject.
def get_data(subj):
    """Return the (X, Y) h5py datasets for one subject, without loading them.

    The datasets live under the HDF5 group ``/s<subj>`` of the module-level
    file handle ``dfile``; callers slice them (e.g. ``X[:]``) to materialize.
    """
    subj_group = '/s' + str(subj)
    x_dset = dfile[pjoin(subj_group, 'X')]
    y_dset = dfile[pjoin(subj_group, 'Y')]
    return x_dset, y_dset
def get_multi_data(subjs):
    """Load and stack the trials of several subjects.

    Fetches each subject's lazy h5py datasets via ``get_data``, materializes
    them with ``[:]``, and concatenates along the trial axis (axis 0).
    Returns the stacked ``(X, Y)`` numpy arrays.
    """
    pairs = [get_data(s) for s in subjs]
    X = np.concatenate([x[:] for x, _ in pairs], axis=0)
    Y = np.concatenate([y[:] for _, y in pairs], axis=0)
    return X, Y
# Lowest validation loss achieved in each of the 6 CV folds; used after the
# loop to pick the best base model for this held-out test subject.
cv_loss = []
for cv_index, (train_index, test_index) in enumerate(kf.split(cv_set)):
    # Split the 53 non-test subjects into training and validation subjects.
    train_subjs = cv_set[train_index]
    valid_subjs = cv_set[test_index]
    X_train, Y_train = get_multi_data(train_subjs)
    X_val, Y_val = get_multi_data(valid_subjs)
    X_test, Y_test = get_data(test_subj)
    train_set = SignalAndTarget(X_train, y=Y_train)
    valid_set = SignalAndTarget(X_val, y=Y_val)
    # Only trials from index 200 onward of the held-out subject are used for
    # testing — presumably the evaluation portion of that subject's session;
    # TODO(review): confirm against the dataset layout.
    test_set = SignalAndTarget(X_test[200:], y=Y_test[200:])
    n_classes = 2
    # Channel count comes from the data: X is (trials, channels, time).
    in_chans = train_set.X.shape[1]
    # final_conv_length = auto ensures we only get a single output in the time dimension
    model = Deep4Net(in_chans=in_chans, n_classes=n_classes,
                     input_time_length=train_set.X.shape[2],
                     final_conv_length='auto').cuda()
    # these are good values for the deep model
    optimizer = AdamW(model.parameters(), lr=1*0.01, weight_decay=0.5*0.001)
    model.compile(loss=F.nll_loss, optimizer=optimizer, iterator_seed=1, )
    # Fit the base model for transfer learning, use early stopping as a hack to remember the model
    exp = model.fit(train_set.X, train_set.y, epochs=TRAIN_EPOCH, batch_size=BATCH_SIZE, scheduler='cosine',
                    validation_data=(valid_set.X, valid_set.y), remember_best_column='valid_loss')
    # The rememberer holds the state captured at the epoch with the lowest
    # validation loss; save it so transfer-learning runs can restore it.
    rememberer = exp.rememberer
    base_model_param = {
        'epoch': rememberer.best_epoch,
        'model_state_dict': rememberer.model_state_dict,
        'optimizer_state_dict': rememberer.optimizer_state_dict,
        'loss': rememberer.lowest_val
    }
    torch.save(base_model_param, pjoin(
        outpath, 'model_f{}_cv{}.pt'.format(fold, cv_index)))
    # Persist the per-epoch training metrics for later inspection.
    model.epochs_df.to_csv(
        pjoin(outpath, 'epochs_f{}_cv{}.csv'.format(fold, cv_index)))
    cv_loss.append(rememberer.lowest_val)
    # Evaluate the remembered model on the held-out subject and record the
    # result per CV fold.
    test_loss = model.evaluate(test_set.X, test_set.y)
    with open(pjoin(outpath, 'test_base_s{}_f{}_cv{}.json'.format(test_subj, fold, cv_index)), 'w') as f:
        json.dump(test_loss, f)
# Pick the CV fold whose remembered model reached the lowest validation loss
# and promote its artifacts into the "best" directory for this fold.
best_cv = np.argmin(cv_loss)
best_dir = pjoin(outpath, "best")
makedirs(best_dir, exist_ok=True)
# Append (fold, best_cv) so repeated runs accumulate one record per fold.
with open(pjoin(best_dir, "fold_bestcv.txt"), 'a') as f:
    f.write("{}, {}\n".format(fold, best_cv))
# Copy the winning checkpoint, its epoch log, and its test losses under
# fold-only names (the cv index is dropped from the filename).
artifacts = [
    ('model_f{}_cv{}.pt'.format(fold, best_cv),
     'model_f{}.pt'.format(fold)),
    ('epochs_f{}_cv{}.csv'.format(fold, best_cv),
     'epochs_f{}.csv'.format(fold)),
    ('test_base_s{}_f{}_cv{}.json'.format(test_subj, fold, best_cv),
     'test_base_s{}_f{}.json'.format(test_subj, fold)),
]
for src_name, dst_name in artifacts:
    copy2(pjoin(outpath, src_name), pjoin(best_dir, dst_name))
dfile.close()