-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
145 lines (114 loc) · 5.22 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import time
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
from utils import *
from dataset import PascalVOCDataset
from model import *
from tqdm import tqdm
torch.cuda.empty_cache()  # release cached CUDA memory left over from a previous run (no-op on CPU)

# Data parameters
data_folder: str = r'D:\ObjectDetection\PascalVOC'  # folder with data files
keep_difficult: bool = True  # include objects marked "difficult" in the annotations
n_classes: int = len(label_map)  # number of object classes (label_map comes from utils)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Learning parameters
# Path to a model checkpoint to resume from; set to None to train from scratch.
checkpoint = 'checkpoints/checkpoint_ssd300.pt'
batch_size: int = 25  # batch size
iterations: int = 120000  # number of iterations to train
workers: int = 0  # number of workers for loading data in the DataLoader
print_freq: int = 100  # print training status every __ batches
lr: float = 1e-4  # learning rate
decay_lr_at: list[int] = [80000, 100000]  # decay learning rate after these many iterations
decay_lr_to: float = 0.1  # decay learning rate to this fraction of the existing learning rate
momentum: float = 0.9  # momentum, when using SGD
weight_decay: float = 5e-4  # weight decay
grad_clip: float = 0.0  # clip gradients to this value when non-zero; 0.0 disables clipping
betas: tuple[float, float] = (0.9, 0.999)  # betas, when using Adam

# Let cuDNN pick the fastest convolution algorithms for the (fixed) input size
cudnn.benchmark = True
def main():
    """
    Training entry point.

    Builds a fresh SSD300 model and Adam optimizer (or restores both from a
    checkpoint), then trains for the number of epochs equivalent to the
    paper's 120k iterations, decaying the learning rate at the scheduled
    points and saving a checkpoint after every epoch.

    Uses module-level configuration (data_folder, checkpoint, lr, ...) and
    mutates the globals listed below.
    """
    global start_epoch, label_map, epoch, checkpoint, decay_lr_at

    # Initialize model or load checkpoint
    if checkpoint is None:
        start_epoch = 0
        print("Initializing model...")
        model = SSD300(n_classes=n_classes)
        # Biases are trained at twice the base learning rate (standard SSD
        # training trick); collect them separately from the other parameters.
        biases, not_biases = [], []
        for param_name, param in model.named_parameters():
            if param.requires_grad:
                if param_name.endswith('.bias'):
                    biases.append(param)
                else:
                    not_biases.append(param)
        optimizer = torch.optim.Adam(params=[{'params': biases, 'lr': 2 * lr}, {'params': not_biases}],
                                     lr=lr, betas=betas, weight_decay=weight_decay, amsgrad=True)
    else:
        # map_location=device makes a GPU-saved checkpoint loadable on a
        # CPU-only machine (and vice versa); without it torch.load tries to
        # restore tensors onto the device they were saved from.
        # NOTE: this rebinds the global `checkpoint` from a path to a dict.
        checkpoint = torch.load(checkpoint, map_location=device)
        start_epoch = checkpoint['epoch'] + 1
        print('\nLoaded checkpoint from epoch %d.\n' % start_epoch)
        model = checkpoint['model']
        optimizer = checkpoint['optimizer']

    # Move to default device
    model = model.to(device)
    criterion = MultiBoxLoss(priors_cxcy=model.priors_cxcy).to(device)

    # Custom dataloaders
    train_dataset = PascalVOCDataset(data_folder, split='train', keep_difficult=keep_difficult)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                                               num_workers=workers, collate_fn=train_dataset.collate_fn)

    # Convert the paper's iteration-based schedule to epochs, assuming the
    # paper's batch size of 32 (the actual batch_size above may differ).
    epochs = iterations // (len(train_dataset) // 32)
    print('Training for %d epochs...' % epochs)
    decay_lr_at = [it // (len(train_dataset) // 32) for it in decay_lr_at]

    for epoch in range(start_epoch, epochs):
        # One epoch's training
        train(train_loader=train_loader,
              model=model,
              criterion=criterion,
              optimizer=optimizer,
              epoch=epoch)
        # Decay learning rate at particular epochs
        if epoch in decay_lr_at:
            adjust_learning_rate(optimizer, decay_lr_to)
        # Save checkpoint
        save_checkpoint(epoch, model, optimizer)
def train(train_loader, model, criterion, optimizer, epoch):
    """
    Run one epoch of training.

    :param train_loader: DataLoader yielding (images, boxes, labels, difficulties)
    :param model: the SSD300 model being trained
    :param criterion: MultiBox loss
    :param optimizer: optimizer updating the model's parameters
    :param epoch: current epoch index (used for logging only)
    """
    model.train()  # enable training-mode behavior (dropout, batchnorm updates)

    # Running averages for timing and loss
    batch_timer, load_timer, loss_meter = AverageMeter(), AverageMeter(), AverageMeter()

    tick = time.time()
    for step, (images, boxes, labels, _) in tqdm(enumerate(train_loader), total=len(train_loader)):
        load_timer.update(time.time() - tick)  # time spent waiting on data

        # Move the batch to the training device
        images = images.to(device)  # (batch_size (N), 3, 300, 300)
        boxes = [box.to(device) for box in boxes]
        labels = [lab.to(device) for lab in labels]

        # Forward pass and loss
        predicted_locs, predicted_scores = model(images)
        loss = criterion(predicted_locs, predicted_scores, boxes, labels)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        if grad_clip != 0.0:  # only clip when a non-zero threshold is configured
            torch.nn.utils.clip_grad_value_(model.parameters(), grad_clip)
        optimizer.step()

        # Bookkeeping
        loss_meter.update(loss.item(), images.size(0))
        batch_timer.update(time.time() - tick)
        tick = time.time()

        # Periodic status line
        if step % print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Batch Time: Value = {batch_time.val:.3f} (Average = {batch_time.avg:.3f})\t'
                  'Data Time: Value = {data_time.val:.3f} (Average = {data_time.avg:.3f})\t'
                  'Loss = {loss.val:.4f}'.format(epoch, step, len(train_loader),
                                                 batch_time=batch_timer,
                                                 data_time=load_timer,
                                                 loss=loss_meter))
    del predicted_locs, predicted_scores, images, boxes, labels  # free some memory since their histories may be stored
# Run training only when executed as a script (not on import).
if __name__ == '__main__':
    main()