Skip to content

Commit

Permalink
20180809
Browse files Browse the repository at this point in the history
  • Loading branch information
F committed Aug 9, 2018
1 parent 188fb72 commit d7b70d4
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 8 deletions.
11 changes: 3 additions & 8 deletions Learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from tensorboardX import SummaryWriter
from matplotlib import pyplot as plt
plt.switch_backend('agg')
from utils import get_time, gen_plot, hflip_batch
from utils import get_time, gen_plot, hflip_batch, separate_bn_paras
from PIL import Image
from torchvision import transforms as trans
import math
Expand All @@ -34,13 +34,8 @@ def __init__(self, conf, inference=False):

print('two model heads generated')

paras_only_bn = []
paras_wo_bn = []
for para in self.model.named_parameters():
if 'bn' in para[0]:
paras_only_bn.append(para[1])
else:
paras_wo_bn.append(para[1])
paras_only_bn, paras_wo_bn = separate_bn_paras(self.model)

if conf.use_mobilfacenet:
self.optimizer = optim.SGD([
{'params': paras_wo_bn[:-1], 'weight_decay': 4e-5},
Expand Down
17 changes: 17 additions & 0 deletions prepare_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from pathlib import Path
from config import get_config
from data.data_pipe import load_bin, load_mx_rec
import argparse

if __name__ == '__main__':
parser = argparse.ArgumentParser(description='for face verification')
parser.add_argument("-r", "--rec_path", help="mxnet record file path",default='faces_emore', type=str)
args = parser.parse_args()
conf = get_config()
rec_path = conf.data_path/args.rec_path
load_mx_rec(rec_path)

bin_files = ['agedb_30', 'cfp_fp', 'lfw', 'calfw', 'cfp_ff', 'cplfw', 'vgg2_fp']

for i in range(len(bin_files)):
load_bin(rec_path/(bin_files[i]+'.bin'), rec_path/bin_files[i], conf.test_transform)
17 changes: 17 additions & 0 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,23 @@
import pdb
import cv2

def separate_bn_paras(modules):
if not isinstance(modules, list):
modules = [*module.modules()]
paras_only_bn = []
paras_wo_bn = []
for layer in modules:
if 'model' in str(layer.__class__):
continue
if 'container' in str(layer.__class__):
continue
else:
if 'batchnorm' in str(layer.__class__):
paras_only_bn.extend([*layer.parameters()])
else:
paras_wo_bn.extend([*layer.parameters()])
return paras_only_bn, paras_wo_bn

def prepare_facebank(conf, model, mtcnn, tta = True):
model.eval()
embeddings = []
Expand Down

0 comments on commit d7b70d4

Please sign in to comment.