import sys
import torch
import argparse
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from torch.autograd import Variable
from torch.utils import data
# get_model is defined in the __init__.py file
from ptsemseg.models import get_model
from ptsemseg.loader import get_loader, get_data_path
from ptsemseg.loss import cross_entropy2d, cross_entropy3d, FocalCrossEntropyLoss3d, FocalCrossEntropyLoss2d, DiceLoss
from ptsemseg.metrics import scores
from lr_scheduling import *
from tensorboardX import SummaryWriter
from guotai_brats17.parse_config import parse_config
import random
from guotai_brats17.data_loader import DataLoader
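
# train.py: trains a 2D/3D semantic segmentation model (BraTS17 and other
# ptsemseg loaders). Optionally reads a guotai-style config file, builds the
# network via get_model, optimises with Adam, logs the per-epoch average loss
# to TensorBoard, and saves checkpoints under runs/<run id>/.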
DEBUG = False


def log(s):
    # Lightweight debug logger gated by the module-level DEBUG flag.
    if DEBUG:
        print(s)


def train(args, guotai_config):
    # Setup Dataloader
    print('###### Step One: Setup Dataloader')
    data_loader = get_loader(args.dataset)
    data_path = get_data_path(args.dataset)
    # For a 2D dataset keep is_transform=True:
    #   loader = data_loader(data_path, is_transform=True, img_size=(args.img_rows, args.img_cols))
    # For a 3D dataset keep is_transform=False:
    #   loader = data_loader(data_path, is_transform=False, img_size=(args.img_rows, args.img_cols))
    if args.dataset == 'brats17_loader_guotai':
        config_data = guotai_config['data']
        config_train = guotai_config['training']
        random.seed(config_train.get('random_seed', 1))
        assert config_data['with_ground_truth']
        dataloader_guotai = DataLoader(config_data)
        dataloader_guotai.load_data()
        loader = data_loader(dataloader_guotai)
    elif args.dataset == 'brats17_loader':
        loader = data_loader(data_path, is_transform=False, img_size=(args.img_rows, args.img_cols))
    else:
        loader = data_loader(data_path, is_transform=True, img_size=(args.img_rows, args.img_cols))
    n_classes = args.n_classes
    trainloader = data.DataLoader(loader, batch_size=args.batch_size, num_workers=4, shuffle=True)
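    # Each iteration of trainloader yields an (images, labels) batch. For the
    # 3D loaders these are expected to be 5D (N, C, D, H, W) volumes; the 2D
    # loaders yield 4D (N, C, H, W) images. num_workers=4 parallelises the
    # (potentially heavy) patch extraction.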
    # Setup Model
    print('###### Step Two: Setup Model')
    model = get_model(args.arch, n_classes)
    if args.pretrained_path != 'empty':
        # Resume from a whole-model checkpoint saved with torch.save(model, ...).
        model = torch.load(args.pretrained_path)
    if torch.cuda.is_available():
        model.cuda(0)
        test_image, test_segmap = loader[0]
        test_image = Variable(test_image.unsqueeze(0).cuda(0))
    else:
        test_image, test_segmap = loader[0]
        test_image = Variable(test_image.unsqueeze(0))
    # test_image / test_segmap are prepared as a sanity-check sample but are
    # not used further in this script.
    log('The optimizer is Adam')
    log('The learning rate is {}'.format(args.l_rate))
    # Alternative optimiser tried previously:
    #   optimizer = torch.optim.SGD(model.parameters(), lr=1e-2, momentum=0.99)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.l_rate)
    # Train Model
    print('###### Step Three: Training Model')
    # Rows of epoch_loss_array_total are (average loss, epoch); the initial
    # zero row is removed after training.
    epoch_loss_array_total = np.zeros([1, 2])
    for epoch in range(args.n_epoch):
        img_counter = 0  # number of batches seen this epoch
        loss_sum = 0.0
        for i, (images, labels) in enumerate(trainloader):
            img_counter = img_counter + 1
            if torch.cuda.is_available():
                images = Variable(images.cuda(0))
                labels = Variable(labels.cuda(0))
            else:
                images = Variable(images)
                labels = Variable(labels)
            optimizer.zero_grad()
            # log('The maximum value of input image is {}'.format(images.max()))
            outputs = model(images)
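            # Pick the loss by architecture: the 3D networks use a volumetric
            # cross-entropy, unet3d_res is trained as a regression with an L1
            # loss, and everything else falls back to 2D cross-entropy.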
            if args.arch in ('bisenet3Dbrain', 'unet3d_cls', 'FCDenseNet57', 'FCDenseNet103'):
                loss = cross_entropy3d(outputs, labels)
            elif args.arch == 'unet3d_res':
                # Rescale the integer labels before fitting them with L1.
                labels = labels * 40
                labels = labels + 1
                log('The unique values of labels are {}'.format(np.unique(labels.data.cpu().numpy())))
                log('The maximum of outputs is {}'.format(outputs.max()))
                log('The size of outputs is {}'.format(outputs.size()))
                log('The size of labels is {}'.format(labels.size()))
                criterion = nn.L1Loss()
                labels = labels.float()
                outputs = torch.squeeze(outputs, dim=1)
                loss = criterion(outputs, labels)
            else:
                loss = cross_entropy2d(outputs, labels)
            loss.backward()
            optimizer.step()
            loss_sum = loss_sum + loss.item()
        avg_loss = loss_sum / img_counter
        epoch_loss_array_total = np.concatenate((epoch_loss_array_total, [[avg_loss, epoch]]), axis=0)
        print('The current loss of epoch', epoch, 'is', avg_loss)
        # The training model is saved periodically; writer and rand_int are
        # module-level names defined under __main__.
        log('The variable avg_loss is {}'.format(avg_loss))
        writer.add_scalar('train_main_loss', avg_loss, epoch)
        if epoch % 10 == 0:
            torch.save(model, "runs/{}/{}_{}_{}_{}.pkl".format(rand_int, args.arch, args.dataset, args.feature_scale, epoch))
    log('epoch_loss_array_total is {}'.format(epoch_loss_array_total))
    # Shape here is (n_epoch + 1, 2): the initial zero row plus one row per epoch.
    log('the shape of epoch_loss_array_total is {}'.format(epoch_loss_array_total.shape))
    epoch_loss_array_total = np.delete(arr=epoch_loss_array_total, obj=0, axis=0)
    log('the shape of epoch_loss_array_total after removal is {}'.format(epoch_loss_array_total.shape))
    # argmin along axis 0 gives, per column, the row of the minimum value, so
    # loss_min_indice[0] is the epoch with the lowest average loss.
    loss_min_indice = np.argmin(epoch_loss_array_total, axis=0)
    log('The loss_min_indice is {}'.format(loss_min_indice))
    torch.save(model, "runs/{}/{}_{}_{}_{}_min.pkl".format(rand_int, args.arch, args.dataset, args.feature_scale,
                                                           loss_min_indice[0]))
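    # NOTE: the '_min' checkpoint name records the epoch that achieved the
    # lowest average loss, but the weights saved here are from the final epoch.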
    sys.stdout = orig_stdout
    f.close()


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Hyperparams')
    parser.add_argument('--arch', nargs='?', type=str, default='FCDenseNet103',
                        help='Architecture to use [\'fcn8s, unet, segnet etc\']')
    parser.add_argument('--dataset', nargs='?', type=str, default='brats17_loader',
                        help='Dataset to use [\'pascal, camvid, ade20k etc\']')
    parser.add_argument('--img_rows', nargs='?', type=int, default=256,
                        help='Height of the input image')
    parser.add_argument('--img_cols', nargs='?', type=int, default=256,
                        help='Width of the input image')
    parser.add_argument('--n_epoch', nargs='?', type=int, default=2000,
                        help='# of the epochs')
    parser.add_argument('--batch_size', nargs='?', type=int, default=1,
                        help='Batch Size')
    parser.add_argument('--l_rate', nargs='?', type=float, default=1e-4,
                        help='Learning Rate')
    parser.add_argument('--feature_scale', nargs='?', type=int, default=1,
                        help='Divider for # of features to use')
    parser.add_argument('--patch_size', nargs=3, type=int, default=[64, 64, 64],
                        help='Patch size (D H W) for training')
    parser.add_argument('--pretrained_path', nargs='?', type=str, default='empty',
                        help='Path to a pretrained model; \'empty\' trains from scratch')
    parser.add_argument('--n_classes', nargs='?', type=int, default=4,
                        help='Number of classes for classification')
    args = parser.parse_args()
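
    # Example invocation (assuming the ptsemseg datasets and the guotai config
    # file below are set up on your machine):
    #   python train.py --arch FCDenseNet103 --dataset brats17_loader \
    #       --n_epoch 2000 --batch_size 1 --l_rate 1e-4 --n_classes 4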
    rand_int = np.random.randint(10000)
    # 1. Load configuration parameters.
    # NOTE: this config path is machine-specific; adjust it for your setup.
    config_file_path = '/home/donghao/Desktop/donghao/isbi2019/code/fast_segmentation_code/runs/FCDenseNet57_train_87_wt_ax.txt'
    log('Load Configuration Parameters')
    config = parse_config(config_file_path)
    # Redirect stdout into a per-run log file under runs/<rand_int>/; the
    # SummaryWriter call also creates that directory.
    orig_stdout = sys.stdout
    writer = SummaryWriter('runs/' + str(rand_int))
    f = open("runs/{}/log.txt".format(rand_int), 'w')
    sys.stdout = f
    print('###### Step Zero: Log Number is', rand_int)
    print('The dataset is {}'.format(args.dataset))
    print('The nettype is {}'.format(args.arch))
    print('The patch size is {}'.format(args.patch_size))
    print('The learning rate is {}'.format(args.l_rate))
    print('The total training epoch is {}'.format(args.n_epoch))
    print('The batch_size is {}'.format(args.batch_size))
    print('The pretrained path is {}'.format(args.pretrained_path))
    train(args=args, guotai_config=config)