-
Notifications
You must be signed in to change notification settings - Fork 10
/
train_pfcands_adversarial_from_logpreds.py
80 lines (71 loc) · 2.79 KB
/
train_pfcands_adversarial_from_logpreds.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
from __future__ import print_function
import os
import shutil
import argparse
from common import data
from adversarial import fit_from_logpreds as fit
from importlib import import_module
def split_sample_counts(n_train_val, train_val_split):
    """Split a combined train+validation sample count into its two parts.

    The training share is truncated with ``int()``.  The validation share is
    the remainder, so the two counts always add up to ``n_train_val``.
    Computing it as ``int(n * (1 - split))`` instead can drop a sample to
    floating-point error (``1 - 0.8 == 0.19999999999999996``).

    :param n_train_val: total number of samples shared by training and validation.
    :param train_val_split: fraction of those samples assigned to training.
    :returns: ``(n_train, n_val)``.
    """
    n_train = int(n_train_val * train_val_split)
    return n_train, n_train_val - n_train


def prepare_save_dir(model_prefix, network):
    """Create the model output directory and copy the network definition into it.

    :param model_prefix: path prefix for saved model files; its directory part
        is the output directory.
    :param network: module name whose source ``symbols/<network>.py`` is copied
        into the output directory.  The path is relative to the current working
        directory, so this assumes the script is run from the repository root
        (the same assumption ``import_module('symbols.' + ...)`` makes).
    :returns: the output directory path.
    """
    # A bare prefix such as 'resnet' has no directory component: fall back to
    # the current directory rather than calling os.makedirs('') (which raises).
    save_dir = os.path.dirname(model_prefix) or os.curdir
    if not os.path.isdir(save_dir):
        os.makedirs(save_dir)
    shutil.copy('symbols/%s.py' % network, save_dir)
    return save_dir


if __name__ == '__main__':
    # location of data
    train_val_fname = '/data/hqu/ntuples/20170717/pfcands_minor_labels/train_file_*.h5'
    test_fname = '/data/hqu/ntuples/20170717/pfcands_minor_labels/testing/train_file_*.h5'
    example_fname = '/data/hqu/ntuples/20170717/pfcands_minor_labels/train_file_0.h5'

    # parse args: fit/data modules register their options, then script-specific defaults
    parser = argparse.ArgumentParser(description="train pfcands",
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    fit.add_fit_args(parser)
    data.add_data_args(parser)
    parser.set_defaults(
        # network
        network='resnet_adversarial',
        # config
        model_prefix='/data/hqu/training/mxnet/models/adversarial/pfcands_minor_labels-20170717/resnet-simple-adv/resnet',
        disp_batches=500,
        # data
        data_config='data_pfcands_adversarial',
        data_train=train_val_fname,
        train_val_split=0.8,
        data_test=test_fname,
        data_example=example_fname,
        data_names=None,  # filled in from the data config below
        label_names='softmax_label',
        weight_names='weight,class_weight',
        num_examples=-1,  # negative: derive from the training sample count below
        # train
        batch_size=1024,
        num_epochs=200,
        optimizer='sgd',  # NOT converging with ADAM !!!
        lr=0.01,
        top_k=2,
        lr_step_epochs='20,30,40,60',
    )
    args = parser.parse_args()

    # load data config; its input groups become the network's data names
    dd = import_module('data.' + args.data_config)
    args.data_names = ','.join(dd.train_groups)

    if args.dryrun:
        # quick sanity run on the single example file with a 50/50 split
        print('--DRY RUN--')
        args.data_train = example_fname
        args.train_val_split = 0.5

    if args.load_epoch:
        # NOTE(review): only prints a separator; presumably marks a resumed run
        print('-' * 50)

    # count samples; train and validation are both carved out of data_train
    n_train_val, n_test = dd.nb_samples([args.data_train, args.data_test])
    n_train, n_val = split_sample_counts(n_train_val, args.train_val_split)
    print(' --- Training sample size = %d, Validation sample size = %d, Test sample size = %d ---' % (n_train, n_val, n_test))
    if args.num_examples < 0:
        args.num_examples = n_train

    # load network
    sym = import_module('symbols.' + args.network)

    if args.predict:
        fit.predict(args, sym, dd.load_data)
    else:
        prepare_save_dir(args.model_prefix, args.network)
        # train
        fit.fit(args, sym, dd.load_data)