# ngransac_train_init.py
import numpy as np
import math

import torch
import torch.optim as optim

import ngransac

from network import CNNet
from dataset import SparseDataset
import util

# parse command line arguments
parser = util.create_parser(
    description="Train a neural guidance network using correspondence distance to a ground truth model to calculate target probabilities.")

parser.add_argument('--datasets', '-ds',
    default='brown_bm_3---brown_bm_3-maxpairs-10000-random---skip-10-dilate-25,st_peters_square',
    help='which datasets to use, separate multiple datasets by comma')

parser.add_argument('--variant', '-v', default='train',
    help='subfolder of the dataset to use')

parser.add_argument('--learningrate', '-lr', type=float, default=0.001,
    help='learning rate')

parser.add_argument('--epochs', '-e', type=int, default=1000,
    help='number of epochs')

parser.add_argument('--model', '-m', default='',
    help='load a model to continue training or leave empty to create a new model')

opt = parser.parse_args()
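
# Illustrative invocation (a sketch only; additional flags such as those behind
# opt.orb, opt.rootsift, opt.ratio, opt.fmat, opt.threshold etc. are presumably
# defined by util.create_parser, as suggested by the opt.* attributes used below):
#   python ngransac_train_init.py --datasets st_peters_square --variant train -lr 0.001 -e 1000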

# construct folder that should contain pre-calculated correspondences
data_folder = opt.variant + '_data'
if opt.orb:
    data_folder += '_orb'
if opt.rootsift:
    data_folder += '_rs'

train_data = opt.datasets.split(',')  # support multiple training datasets used jointly
train_data = ['traindata/' + ds + '/' + data_folder + '/' for ds in train_data]

print('Using datasets:')
for d in train_data:
    print(d)

trainset = SparseDataset(train_data, opt.ratio, opt.nfeatures, opt.fmat, opt.nosideinfo)
trainset_loader = torch.utils.data.DataLoader(
    trainset, shuffle=True, num_workers=6, batch_size=opt.batchsize)

print("\nImage pairs: ", len(trainset), "\n")

# create or load model
model = CNNet(opt.resblocks)
if len(opt.model) > 0:
    model.load_state_dict(torch.load(opt.model))
model = model.cuda()
model.train()
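
# The network apparently outputs one log-probability per correspondence (hence
# the torch.exp in the training loop), so its output can be fed to KLDivLoss
# directly.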

optimizer = optim.Adam(model.parameters(), lr=opt.learningrate)

iteration = 0

# keep track of the training progress
session_string = util.create_session_string('init', opt.fmat, opt.orb, opt.rootsift, opt.ratio, opt.session)
train_log = open('log_%s.txt' % (session_string), 'w', 1)

# in the initialization we optimize the KL divergence of the predicted
# distribution and the target distribution (see paper supplement A, Eq. 12)
distLoss = torch.nn.KLDivLoss(reduction='sum')
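
# PyTorch's KLDivLoss expects its input in log-space and its target as
# probabilities, which matches the log-probabilities predicted by the network
# and the normalized target distribution computed per batch below.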

# main training loop
for epoch in range(0, opt.epochs):

    print("=== Starting Epoch", epoch, "==================================")

    # store the network every epoch
    torch.save(model.state_dict(), './weights_%s.net' % (session_string))

    # main training loop in the current epoch
    for correspondences, gt_F, gt_E, gt_R, gt_t, K1, K2, im_size1, im_size2 in trainset_loader:

        log_probs = model(correspondences.cuda())
        probs = torch.exp(log_probs).cpu()
        target_probs = torch.zeros(probs.size())
        for b in range(0, correspondences.size(0)):

            # calculate the target distribution (see paper supplement A, Eq. 12)
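            # ngransac.gtdist belongs to the compiled NG-RANSAC C++ extension;
            # it presumably fills target_probs[b] with one probability per
            # correspondence, derived from that correspondence's distance to the
            # ground-truth model under the inlier threshold opt.threshold (the
            # last argument switching between fundamental and essential matrix).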
            if opt.fmat:
                # === CASE FUNDAMENTAL MATRIX =========================================

                util.denormalize_pts(correspondences[b, 0:2], im_size1[b])
                util.denormalize_pts(correspondences[b, 2:4], im_size2[b])

                ngransac.gtdist(correspondences[b], target_probs[b], gt_F[b], opt.threshold, True)
            else:
                # === CASE ESSENTIAL MATRIX ===========================================

                ngransac.gtdist(correspondences[b], target_probs[b], gt_E[b], opt.threshold, False)
        log_probs.squeeze_()
        target_probs.squeeze_()

        # KL divergence
        loss = distLoss(log_probs, target_probs.cuda()) / correspondences.size(0)
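        # with reduction='sum' the loss is summed over all correspondences;
        # dividing by the batch size yields the average KL divergence per
        # image pair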
        # update model
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
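        # gradients are cleared right after the update, so the next iteration's
        # backward pass starts from zeroed buffers (equivalent to calling
        # zero_grad() before backward())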
print("Iteration: ", iteration, "Loss: ", float(loss))
train_log.write('%d %f\n' % (iteration, loss))
iteration += 1
del log_probs, probs, target_probs, loss
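        # dropping the remaining references to this batch's tensors lets their
        # memory be reclaimed before the next iteration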