
bug fix
F committed Oct 25, 2018
1 parent 52eeeec commit 2bfec5a
Showing 5 changed files with 18 additions and 22 deletions.
25 changes: 9 additions & 16 deletions Learner.py
@@ -25,7 +25,7 @@ def __init__(self, conf, inference=False):
print('{}_{} model generated'.format(conf.net_mode, conf.net_depth))

if not inference:
-# self.milestones = conf.milestones
+self.milestones = conf.milestones
self.loader, self.class_num = get_train_loader(conf)

self.writer = SummaryWriter(conf.log_path)
@@ -48,7 +48,7 @@ def __init__(self, conf, inference=False):
{'params': paras_only_bn}
], lr = conf.lr, momentum = conf.momentum)
print(self.optimizer)
-self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, patience=40, verbose=True)
+# self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, patience=40, verbose=True)

print('optimizers generated')
self.board_loss_every = len(self.loader)//100
@@ -183,9 +183,15 @@ def find_lr(self,

def train(self, conf, epochs):
self.model.train()
-running_loss = 0.
+running_loss = 0.
for e in range(epochs):
print('epoch {} started'.format(e))
+if e == self.milestones[0]:
+    self.schedule_lr()
+if e == self.milestones[1]:
+    self.schedule_lr()
+if e == self.milestones[2]:
+    self.schedule_lr()
for imgs, labels in tqdm(iter(self.loader)):
imgs = imgs.to(conf.device)
labels = labels.to(conf.device)
@@ -200,7 +206,6 @@ def train(self, conf, epochs):
if self.step % self.board_loss_every == 0 and self.step != 0:
loss_board = running_loss / self.board_loss_every
self.writer.add_scalar('train_loss', loss_board, self.step)
-self.scheduler.step(loss_board)
running_loss = 0.

if self.step % self.evaluate_every == 0 and self.step != 0:
@@ -216,18 +221,6 @@ def train(self, conf, epochs):

self.step += 1

-# if e == self.milestones[0]:
-#     self.schedule_lr()
-#     self.save_state(conf, accuracy, to_save_folder=True, extra='{} epochs'.format(e))

-# if e == self.milestones[1]:
-#     self.schedule_lr()
-#     self.save_state(conf, accuracy, to_save_folder=True, extra='{} epochs'.format(e))

-# if e == self.milestones[2]:
-#     self.schedule_lr()
-#     self.save_state(conf, accuracy, to_save_folder=True, extra='{} epochs'.format(e))

self.save_state(conf, accuracy, to_save_folder=True, extra='final')

def schedule_lr(self):
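The Learner.py changes drop the ReduceLROnPlateau scheduler and instead step the learning rate manually whenever the epoch counter reaches one of conf.milestones, checking at the top of each epoch rather than at the end. A minimal sketch of that pattern, assuming schedule_lr simply divides every parameter group's learning rate by 10 (its body is not part of this diff) and using a stand-in model:

from torch import nn, optim

model = nn.Linear(512, 10)                  # stand-in for the face recognition model
optimizer = optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)
milestones = [10, 13, 15]                   # conf.milestones for mobilefacenet after this commit

def schedule_lr(optimizer):
    # assumption: the repo's Learner.schedule_lr divides each group's lr by 10
    for group in optimizer.param_groups:
        group['lr'] /= 10

for e in range(18):                         # train.py now defaults to 18 epochs
    if e in milestones:                     # equivalent to the three separate if-checks above
        schedule_lr(optimizer)
    # ... one epoch of training, loss logging and evaluation ...

The built-in optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones) would express the same schedule, but the commit keeps the explicit per-epoch checks.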
4 changes: 2 additions & 2 deletions config.py
@@ -34,8 +34,8 @@ def get_config(training = True):
conf.save_path = conf.work_path/'save'
# conf.weight_decay = 5e-4
conf.lr = 1e-3
-# conf.milestones = [3,4,5] # mobildefacenet
-conf.milestones = [4,6,7] # arcface
+conf.milestones = [10,13,15] # mobildefacenet
+# conf.milestones = [12,15,18] # arcface
conf.momentum = 0.9
conf.pin_memory = True
# conf.num_workers = 4 # when batchsize is 200
5 changes: 3 additions & 2 deletions model.py
@@ -269,12 +269,13 @@ def forward(self, embbedings, label):
# 0<=theta+m<=pi
# -m<=theta<=pi-m
cond_v = cos_theta - self.threshold
-cond_mask = cond_v > 0
+cond_mask = cond_v <= 0
keep_val = (cos_theta - self.mm) # when theta not in [0,pi], use cosface instead
cos_theta_m[cond_mask] = keep_val[cond_mask]
label = label.view(-1,1) #size=(B,1)
output = cos_theta * 1.0 # a little bit hacky way to prevent in_place operation on cos_theta
-output[torch.range(0,nB-1).to(torch.long),label] = cos_theta_m[torch.range(0,nB-1).to(torch.long),label]
+idx_ = torch.arange(0, nB, dtype=torch.long)
+output[idx_, label] = cos_theta_m[idx_, label]
output *= self.s # scale up in order to make softmax work, first introduced in normface
return output

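The model.py hunk fixes two issues in the ArcFace head. First, the fallback mask was inverted: the cosface-style substitute cos(theta) - mm should be used only where cos(theta) <= cos(pi - m), i.e. where theta + m would leave [0, pi] and cos(theta + m) stops being monotonic. Second, the deprecated, end-inclusive torch.range call is replaced by torch.arange for indexing the target-class logits. A self-contained sketch of the corrected logic, with toy tensors standing in for the real embeddings; the values of s and m here are common ArcFace defaults, not taken from this diff:

import math
import torch

nB, n_classes = 4, 10                        # toy batch size and class count
s, m = 64.0, 0.5                             # typical arcface scale and margin
cos_theta = torch.empty(nB, n_classes).uniform_(-1, 1)
label = torch.randint(0, n_classes, (nB,))

threshold = math.cos(math.pi - m)
mm = math.sin(math.pi - m) * m

sin_theta = torch.sqrt((1.0 - cos_theta ** 2).clamp(min=0))
cos_theta_m = cos_theta * math.cos(m) - sin_theta * math.sin(m)   # cos(theta + m)

# fixed mask: only where cos(theta) <= cos(pi - m), i.e. theta + m would pass pi,
# fall back to the monotonic cosface-style value cos(theta) - mm
cond_mask = (cos_theta - threshold) <= 0
cos_theta_m[cond_mask] = (cos_theta - mm)[cond_mask]

# fixed indexing: torch.arange(0, nB) yields the integers 0..nB-1;
# torch.range(0, nB-1) was end-inclusive, returned floats, and is deprecated
idx_ = torch.arange(0, nB, dtype=torch.long)
output = cos_theta.clone()                   # avoid an in-place edit of cos_theta
output[idx_, label] = cos_theta_m[idx_, label]
output *= s                                  # scale up so softmax can separate the classes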
4 changes: 3 additions & 1 deletion train.py
@@ -2,9 +2,11 @@
from Learner import face_learner
import argparse

+# python train.py -net mobilefacenet -b 200 -w 4

if __name__ == '__main__':
parser = argparse.ArgumentParser(description='for face verification')
parser.add_argument("-e", "--epochs", help="training epochs", default=8, type=int)
parser.add_argument("-e", "--epochs", help="training epochs", default=18, type=int)
parser.add_argument("-net", "--net_mode", help="which network, [ir, ir_se, mobilefacenet]",default='ir_se', type=str)
parser.add_argument("-depth", "--net_depth", help="how many layers [50,100,152]", default=50, type=int)
parser.add_argument('-lr','--lr',help='learning rate',default=1e-3, type=float)
2 changes: 1 addition & 1 deletion utils.py
@@ -13,7 +13,7 @@

def separate_bn_paras(modules):
if not isinstance(modules, list):
-modules = [*module.modules()]
+modules = [*modules.modules()]
paras_only_bn = []
paras_wo_bn = []
for layer in modules:
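The utils.py change is a plain NameError fix: the function's parameter is called modules, but the old line referenced an undefined name module. The split it returns feeds the two optimizer parameter groups built in Learner.py. A hedged sketch of how such a split is typically implemented and consumed; everything past the fixed line (the _BatchNorm check and the leaf-module filter) is illustrative rather than the repo's exact code:

from torch import nn, optim

def separate_bn_paras(modules):
    if not isinstance(modules, list):
        modules = [*modules.modules()]           # the fixed line: 'modules', not 'module'
    paras_only_bn, paras_wo_bn = [], []
    for layer in modules:
        if isinstance(layer, nn.modules.batchnorm._BatchNorm):
            paras_only_bn.extend(layer.parameters())
        elif len(list(layer.children())) == 0:   # leaf modules only, to avoid duplicates
            paras_wo_bn.extend(layer.parameters())
    return paras_only_bn, paras_wo_bn

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
paras_only_bn, paras_wo_bn = separate_bn_paras(model)
optimizer = optim.SGD([
    {'params': paras_wo_bn, 'weight_decay': 5e-4},  # decay ordinary weights only (illustrative value)
    {'params': paras_only_bn}                       # BN parameters in their own group, as in Learner.py
], lr=1e-3, momentum=0.9)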

