-
Notifications
You must be signed in to change notification settings - Fork 1
/
train.py
executable file
·384 lines (332 loc) · 19.4 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
# coding=utf-8
# Copyright (c) Microsoft. All rights reserved.
import argparse
import json
import os
import random
from datetime import datetime
from pprint import pprint
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader, BatchSampler
from pretrained_models import *
from tensorboardX import SummaryWriter
#from torch.utils.tensorboard import SummaryWriter
from experiments.exp_def import TaskDefs
from mt_dnn.inference import eval_model, extract_encoding
from data_utils.log_wrapper import create_logger
from data_utils.task_def import EncoderModelType
from data_utils.utils import set_environment
from mt_dnn.batcher import SingleTaskDataset, MultiTaskDataset, Collater, MultiTaskBatchSampler
from mt_dnn.model import MTDNNModel
def model_config(parser):
    """Register model/architecture command-line options on *parser* and return it.

    Args:
        parser: an ``argparse.ArgumentParser`` to extend in place.

    Returns:
        The same *parser*, so the config functions can be chained.
    """
    parser.add_argument('--update_bert_opt', default=0, type=int)
    parser.add_argument('--multi_gpu_on', action='store_true')
    parser.add_argument('--mem_cum_type', type=str, default='simple',
                        help='bilinear/simple/default')
    parser.add_argument('--answer_num_turn', type=int, default=5)
    parser.add_argument('--answer_mem_drop_p', type=float, default=0.1)
    parser.add_argument('--answer_att_hidden_size', type=int, default=128)
    parser.add_argument('--answer_att_type', type=str, default='bilinear',
                        help='bilinear/simple/default')
    parser.add_argument('--answer_rnn_type', type=str, default='gru',
                        help='rnn/gru/lstm')
    parser.add_argument('--answer_sum_att_type', type=str, default='bilinear',
                        help='bilinear/simple/default')
    parser.add_argument('--answer_merge_opt', type=int, default=1)
    parser.add_argument('--answer_mem_type', type=int, default=1)
    parser.add_argument('--max_answer_len', type=int, default=10)
    parser.add_argument('--answer_dropout_p', type=float, default=0.1)
    parser.add_argument('--answer_weight_norm_on', action='store_true')
    parser.add_argument('--dump_state_on', action='store_true')
    parser.add_argument('--answer_opt', type=int, default=0, help='0,1')
    parser.add_argument('--mtl_opt', type=int, default=0)
    # --ratio and --mix_opt are forwarded to MultiTaskBatchSampler by main().
    parser.add_argument('--ratio', type=float, default=0)
    parser.add_argument('--mix_opt', type=int, default=0)
    parser.add_argument('--max_seq_len', type=int, default=512)
    parser.add_argument('--init_ratio', type=float, default=1)
    parser.add_argument('--encoder_type', type=int, default=EncoderModelType.BERT)
    # -1 keeps the encoder config's own layer count; any other value overrides it.
    parser.add_argument('--num_hidden_layers', type=int, default=-1)

    # BERT pre-training
    parser.add_argument('--bert_model_type', type=str, default='bert-base-uncased')
    parser.add_argument('--do_lower_case', action='store_true')
    parser.add_argument('--masked_lm_prob', type=float, default=0.15)
    parser.add_argument('--short_seq_prob', type=float, default=0.2)
    parser.add_argument('--max_predictions_per_seq', type=int, default=128)
    return parser
def data_config(parser):
    """Add the data / logging / I/O option group to *parser*.

    Each entry in the table below is ``(flag, add_argument keyword args)`` and
    is registered in table order.

    Args:
        parser: an ``argparse.ArgumentParser`` to extend in place.

    Returns:
        The same *parser*, so the config functions can be chained.
    """
    option_table = (
        ('--log_file', {'default': 'mt-dnn-train.log', 'help': 'path for log file.'}),
        ('--tensorboard', {'action': 'store_true'}),
        ('--tensorboard_logdir', {'default': 'tensorboard_logdir'}),
        ('--init_checkpoint', {'default': 'mt_dnn_models/bert_model_base_uncased.pt', 'type': str}),
        ('--data_dir', {'default': 'data/canonical_data/bert_uncased_lower'}),
        ('--data_sort_on', {'action': 'store_true'}),
        ('--name', {'default': 'farmer'}),
        ('--task_def', {'type': str, 'default': "experiments/glue/glue_task_def.yml"}),
        ('--train_datasets', {'default': 'mnli'}),
        ('--test_datasets', {'default': 'mnli_matched,mnli_mismatched'}),
        ('--glue_format_on', {'action': 'store_true'}),
        # argparse maps the dashed flag to the attribute ``mkd_opt``.
        ('--mkd-opt', {'type': int, 'default': 0,
                       'help': ">0 to turn on knowledge distillation, requires 'softlabel' column in input data"}),
        ('--do_padding', {'action': 'store_true'}),
    )
    for flag, spec in option_table:
        parser.add_argument(flag, **spec)
    return parser
def train_config(parser):
    """Register training / optimization command-line options on *parser* and return it.

    Args:
        parser: an ``argparse.ArgumentParser`` to extend in place.

    Returns:
        The same *parser*, so the config functions can be chained.
    """
    def _str2bool(value):
        # argparse's ``type=bool`` calls bool() on the raw string, so any
        # non-empty string -- including "False" and "0" -- would become True.
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
            return True
        if lowered in ('0', 'false', 'f', 'no', 'n', 'off'):
            return False
        raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(value))

    parser.add_argument('--cuda', type=_str2bool, default=torch.cuda.is_available(),
                        help='whether to use GPU acceleration.')
    parser.add_argument('--log_per_updates', type=int, default=500)
    parser.add_argument('--save_per_updates', type=int, default=10000)
    parser.add_argument('--save_per_updates_on', action='store_true')
    parser.add_argument('--epochs', type=int, default=5)
    parser.add_argument('--batch_size', type=int, default=8)
    parser.add_argument('--batch_size_eval', type=int, default=8)
    parser.add_argument('--optimizer', default='adamax',
                        help='supported optimizer: adamax, sgd, adadelta, adam')
    parser.add_argument('--grad_clipping', type=float, default=0)
    parser.add_argument('--global_grad_clipping', type=float, default=1.0)
    parser.add_argument('--weight_decay', type=float, default=0)
    parser.add_argument('--learning_rate', type=float, default=5e-5)
    parser.add_argument('--momentum', type=float, default=0)
    parser.add_argument('--warmup', type=float, default=0.1)
    parser.add_argument('--warmup_schedule', type=str, default='warmup_linear')
    parser.add_argument('--adam_eps', type=float, default=1e-6)
    # store_false: args.vb_dropout is True by default; passing the flag sets it to False.
    parser.add_argument('--vb_dropout', action='store_false')
    parser.add_argument('--dropout_p', type=float, default=0.1)
    parser.add_argument('--dropout_w', type=float, default=0.000)
    parser.add_argument('--bert_dropout_p', type=float, default=0.1)

    # loading
    parser.add_argument("--model_ckpt", default='checkpoints/model_0.pt', type=str)
    parser.add_argument("--resume", action='store_true')

    # scheduler
    # store_false: args.have_lr_scheduler is True by default; passing
    # --have_lr_scheduler sets it to False (inverted sense kept for compatibility).
    parser.add_argument('--have_lr_scheduler', dest='have_lr_scheduler', action='store_false')
    parser.add_argument('--multi_step_lr', type=str, default='10,20,30')
    parser.add_argument('--lr_gamma', type=float, default=0.5)
    parser.add_argument('--scheduler_type', type=str, default='ms', help='ms/rop/exp')
    parser.add_argument('--output_dir', default='checkpoint')
    parser.add_argument('--seed', type=int, default=2018,
                        help='random seed for data shuffling, embedding init, etc.')
    parser.add_argument('--grad_accumulation_step', type=int, default=1)

    # fp16
    parser.add_argument('--fp16', action='store_true',
                        help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
    parser.add_argument('--fp16_opt_level', type=str, default='O1',
                        help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']. "
                             "See details at https://nvidia.github.io/apex/amp.html")

    # adv training
    parser.add_argument('--adv_train', action='store_true')
    # the current release only includes smart perturbation
    parser.add_argument('--adv_opt', default=0, type=int)
    parser.add_argument('--adv_p_norm', default='inf', type=str)
    parser.add_argument('--adv_alpha', default=1, type=float)
    parser.add_argument('--adv_k', default=1, type=int)
    parser.add_argument('--adv_step_size', default=1e-3, type=float)
    parser.add_argument('--adv_noise_var', default=1e-5, type=float)
    parser.add_argument('--adv_epsilon', default=1e-6, type=float)
    return parser
# Build the CLI from the data, model and train option groups (registered in
# that order), plus the encode-only switch, then parse sys.argv.
parser = train_config(model_config(data_config(argparse.ArgumentParser())))
parser.add_argument('--encode_mode', action='store_true', help="only encode test data")
args = parser.parse_args()

# Comma-separated dataset names become lists of names.
args.train_datasets, args.test_datasets = (
    args.train_datasets.split(','),
    args.test_datasets.split(','),
)
data_dir = args.data_dir
pprint(args)

# Create the output directory and keep its absolute path for later joins.
os.makedirs(args.output_dir, exist_ok=True)
output_dir = os.path.abspath(args.output_dir)
set_environment(args.seed, args.cuda)

log_path = args.log_file
logger = create_logger(__name__, to_disk=True, log_file=log_path)
logger.info(args.answer_opt)

task_defs = TaskDefs(args.task_def)
encoder_type = args.encoder_type
def dump(path, data):
    """Write *data* to *path* as JSON, overwriting any existing file.

    The JSON text is produced before the file is opened. A value that cannot
    be serialized therefore raises ``TypeError`` without truncating an existing
    file or leaving a half-written one behind.

    Args:
        path: destination file path.
        data: a JSON-serializable object.

    Raises:
        TypeError: if *data* contains a value ``json`` cannot serialize.
    """
    text = json.dumps(data)
    with open(path, 'w', encoding='utf-8') as f:
        f.write(text)
def _build_train_loader():
    """Build the multi-task training DataLoader from ``args.train_datasets``.

    The task prefix is the dataset name up to its first ``_``. Only the first
    dataset seen for each prefix is loaded, and prefixes get consecutive
    integer task ids in order of first appearance.

    Returns:
        (tasks, task_def_list, loader): ``tasks`` maps prefix -> task id,
        ``task_def_list[task_id]`` is that task's definition, and ``loader`` is
        the training DataLoader.
    """
    tasks = {}
    task_def_list = []
    train_datasets = []
    for dataset in args.train_datasets:
        prefix = dataset.split('_')[0]
        if prefix in tasks:
            continue
        task_id = len(tasks)
        tasks[prefix] = task_id
        task_def = task_defs.get_task_def(prefix)
        task_def_list.append(task_def)
        train_path = os.path.join(data_dir, '{}_train.json'.format(dataset))
        logger.info('Loading {} as task {}'.format(train_path, task_id))
        train_datasets.append(SingleTaskDataset(train_path, True, maxlen=args.max_seq_len,
                                                task_id=task_id, task_def=task_def))

    train_collater = Collater(dropout_w=args.dropout_w, encoder_type=encoder_type,
                              soft_label=args.mkd_opt > 0, max_seq_len=args.max_seq_len,
                              do_padding=args.do_padding)
    multi_task_train_dataset = MultiTaskDataset(train_datasets)
    multi_task_batch_sampler = MultiTaskBatchSampler(train_datasets, args.batch_size,
                                                     args.mix_opt, args.ratio)
    loader = DataLoader(multi_task_train_dataset, batch_sampler=multi_task_batch_sampler,
                        collate_fn=train_collater.collate_fn, pin_memory=args.cuda)
    return tasks, task_def_list, loader


def _build_eval_loaders(tasks):
    """Build dev/test DataLoaders for every name in ``args.test_datasets``.

    Args:
        tasks: prefix -> task id mapping produced by ``_build_train_loader``.

    Returns:
        (dev_data_list, test_data_list), both aligned index-for-index with
        ``args.test_datasets``. An entry is ``None`` when the corresponding
        ``<dataset>_dev.json`` / ``<dataset>_test.json`` is absent from ``data_dir``.

    Raises:
        ValueError: if a test dataset's prefix is not among the training tasks,
            because its task id would be undefined.
    """
    test_collater = Collater(is_train=False, encoder_type=encoder_type,
                             max_seq_len=args.max_seq_len, do_padding=args.do_padding)

    def _loader_if_exists(path, task_id, task_def):
        # Missing split files are allowed; consumers skip None entries.
        if not os.path.exists(path):
            return None
        data_set = SingleTaskDataset(path, False, maxlen=args.max_seq_len,
                                     task_id=task_id, task_def=task_def)
        return DataLoader(data_set, batch_size=args.batch_size_eval,
                          collate_fn=test_collater.collate_fn, pin_memory=args.cuda)

    dev_data_list = []
    test_data_list = []
    for dataset in args.test_datasets:
        prefix = dataset.split('_')[0]
        task_def = task_defs.get_task_def(prefix)
        if prefix not in tasks:
            raise ValueError('test dataset {!r} has task prefix {!r}, which is not among '
                             'the training datasets {}'.format(dataset, prefix, args.train_datasets))
        task_id = tasks[prefix]
        dev_path = os.path.join(data_dir, '{}_dev.json'.format(dataset))
        dev_data_list.append(_loader_if_exists(dev_path, task_id, task_def))
        test_path = os.path.join(data_dir, '{}_test.json'.format(dataset))
        test_data_list.append(_loader_if_exists(test_path, task_id, task_def))
    return dev_data_list, test_data_list


def _load_initial_config(opt):
    """Load the encoder config, and optionally its weights, named by ``args.init_checkpoint``.

    If the path exists, it is read as a local checkpoint: a BERT ``.pt`` file,
    or a RoBERTa directory containing ``model.pt``. Otherwise it is passed to
    the encoder config class's ``from_pretrained``.

    Returns:
        (config, state_dict): ``config`` is a plain dict; ``state_dict`` is
        ``None`` when no local checkpoint was loaded.

    Raises:
        ValueError: if ``encoder_type`` is not a known ``EncoderModelType``, or
            if a local checkpoint is given for an encoder other than BERT/RoBERTa.
    """
    init_model = args.init_checkpoint
    state_dict = None
    if os.path.exists(init_model):
        # NOTE(review): torch.load unpickles arbitrary objects; only load trusted checkpoints.
        if encoder_type == EncoderModelType.BERT:
            state_dict = torch.load(init_model)
            config = state_dict['config']
        elif encoder_type == EncoderModelType.ROBERTA:
            model_path = '{}/model.pt'.format(init_model)
            state_dict = torch.load(model_path)
            arch = state_dict['args'].arch.replace('_', '-')
            # convert model arch: remap checkpoint keys to this project's names
            from data_utils.roberta_utils import update_roberta_keys
            from data_utils.roberta_utils import patch_name_dict
            state = update_roberta_keys(state_dict['model'], nlayer=state_dict['args'].encoder_layers)
            state = patch_name_dict(state)
            literal_encoder_type = EncoderModelType(opt['encoder_type']).name.lower()
            config_class, _, _ = MODEL_CLASSES[literal_encoder_type]
            config = config_class.from_pretrained(arch).to_dict()
            state_dict = {'state': state}
        else:
            # Previously fell through with `config` unbound (UnboundLocalError later).
            raise ValueError('loading a local checkpoint is only supported for BERT and '
                             'RoBERTa encoders, got encoder_type={}'.format(encoder_type))
    else:
        if opt['encoder_type'] not in EncoderModelType._value2member_map_:
            raise ValueError("encoder_type is out of pre-defined types")
        literal_encoder_type = EncoderModelType(opt['encoder_type']).name.lower()
        config_class, _, _ = MODEL_CLASSES[literal_encoder_type]
        config = config_class.from_pretrained(init_model).to_dict()
    return config, state_dict


def _save_scores(dataset, split, epoch, results, label_dict):
    """Write *results* to ``<dataset>_<split>_scores_<epoch>.json`` in output_dir.

    A ``.tsv`` with the same stem is also written via ``glue_utils.submit``
    when ``--glue_format_on`` is set.
    """
    score_file = os.path.join(output_dir, '{}_{}_scores_{}.json'.format(dataset, split, epoch))
    dump(score_file, results)
    if args.glue_format_on:
        from experiments.glue.glue_utils import submit
        official_score_file = os.path.join(output_dir, '{}_{}_scores_{}.tsv'.format(dataset, split, epoch))
        submit(official_score_file, results, label_dict)


def _evaluate_epoch(model, epoch, dev_data_list, test_data_list, tensorboard=None):
    """Evaluate *model* on every dev/test set and write this epoch's score files.

    Args:
        model: the MTDNNModel being trained.
        epoch: current epoch index; used in file names and as the TensorBoard step.
        dev_data_list, test_data_list: loaders aligned with ``args.test_datasets``
            (``None`` entries are skipped).
        tensorboard: optional SummaryWriter; numeric dev metrics are logged to it.
    """
    for idx, dataset in enumerate(args.test_datasets):
        prefix = dataset.split('_')[0]
        task_def = task_defs.get_task_def(prefix)
        label_dict = task_def.label_vocab

        dev_data = dev_data_list[idx]
        if dev_data is not None:
            with torch.no_grad():
                dev_metrics, dev_predictions, scores, _, dev_ids = eval_model(
                    model,
                    dev_data,
                    metric_meta=task_def.metric_meta,
                    use_cuda=args.cuda,
                    label_mapper=label_dict,
                    task_type=task_def.task_type)
            for key, val in dev_metrics.items():
                if isinstance(val, str):
                    # String-valued metrics are not scalars, so they are only
                    # logged as text (add_scalar would reject them).
                    logger.warning('Task {0} -- epoch {1} -- Dev {2}:\n {3}'.format(dataset, epoch, key, val))
                else:
                    if tensorboard is not None:
                        tensorboard.add_scalar('dev/{}/{}'.format(dataset, key), val, global_step=epoch)
                    logger.warning('Task {0} -- epoch {1} -- Dev {2}: {3:.3f}'.format(dataset, epoch, key, val))
            results = {'metrics': dev_metrics, 'predictions': dev_predictions, 'uids': dev_ids, 'scores': scores}
            _save_scores(dataset, 'dev', epoch, results, label_dict)

        # test eval
        test_data = test_data_list[idx]
        if test_data is not None:
            with torch.no_grad():
                test_metrics, test_predictions, scores, _, test_ids = eval_model(
                    model,
                    test_data,
                    metric_meta=task_def.metric_meta,
                    use_cuda=args.cuda,
                    with_label=False,
                    label_mapper=label_dict,
                    task_type=task_def.task_type)
            results = {'metrics': test_metrics, 'predictions': test_predictions, 'uids': test_ids, 'scores': scores}
            _save_scores(dataset, 'test', epoch, results, label_dict)
            logger.info('[new test scores saved.]')


def _encode_test_sets(model, test_data_list):
    """Save ``extract_encoding`` output for each test set as ``<dataset>_encoding.pt``.

    Datasets without a test file (``None`` loader) are skipped with a warning.
    """
    for dataset, test_data in zip(args.test_datasets, test_data_list):
        if test_data is None:
            logger.warning('No test data for {}; skipping encoding'.format(dataset))
            continue
        with torch.no_grad():
            encoding = extract_encoding(model, test_data, use_cuda=args.cuda)
        torch.save(encoding, os.path.join(output_dir, '{}_encoding.pt'.format(dataset)))


def main():
    """Build the data loaders and MT-DNN model, then train (or only encode).

    All settings come from the module-level ``args``.

    With ``--encode_mode``, test-set encodings are written and the function
    returns without training. Otherwise, each epoch:
      * trains on the multi-task loader;
      * evaluates every dev/test set and writes the score files;
      * saves ``model_<epoch>.pt`` under ``output_dir``.

    A TensorBoard writer, if enabled, is always closed on exit.
    """
    logger.info('Launching the MT-DNN training')
    # vars() returns args' own __dict__, so every write to opt below
    # (including opt.update(config)) is also visible as an attribute of args.
    opt = vars(args)
    # update data dir
    opt['data_dir'] = data_dir

    tasks, task_def_list, multi_task_train_data = _build_train_loader()
    opt['task_def_list'] = task_def_list
    dev_data_list, test_data_list = _build_eval_loaders(tasks)

    logger.info('#' * 20)
    logger.info(opt)
    logger.info('#' * 20)

    # Total optimizer steps = batches over all epochs / grad accumulation steps.
    num_all_batches = args.epochs * len(multi_task_train_data) // args.grad_accumulation_step
    logger.info('############# Gradient Accumulation Info #############')
    logger.info('number of step: {}'.format(args.epochs * len(multi_task_train_data)))
    logger.info('number of grad grad_accumulation step: {}'.format(args.grad_accumulation_step))
    logger.info('adjusted number of step: {}'.format(num_all_batches))
    logger.info('############# Gradient Accumulation Info #############')

    config, state_dict = _load_initial_config(opt)
    config['attention_probs_dropout_prob'] = args.bert_dropout_p
    config['hidden_dropout_prob'] = args.bert_dropout_p
    config['multi_gpu_on'] = opt["multi_gpu_on"]
    if args.num_hidden_layers != -1:
        config['num_hidden_layers'] = args.num_hidden_layers
    opt.update(config)

    model = MTDNNModel(opt, state_dict=state_dict, num_train_step=num_all_batches)
    if args.resume and args.model_ckpt:
        logger.info('loading model from {}'.format(args.model_ckpt))
        model.load(args.model_ckpt)

    headline = '############# Model Arch of MT-DNN #############'
    logger.info('\n{}\n{}\n'.format(headline, model.network))

    # dump config
    config_file = os.path.join(output_dir, 'config.json')
    with open(config_file, 'w', encoding='utf-8') as writer:
        writer.write('{}\n'.format(json.dumps(opt)))
        writer.write('\n{}\n{}\n'.format(headline, model.network))

    logger.info("Total number of params: {}".format(model.total_param))

    tensorboard = None
    if args.tensorboard:
        args.tensorboard_logdir = os.path.join(args.output_dir, args.tensorboard_logdir)
        tensorboard = SummaryWriter(log_dir=args.tensorboard_logdir)

    try:
        if args.encode_mode:
            _encode_test_sets(model, test_data_list)
            return

        # local_updates counts micro-batches, so scale the intervals by the
        # number of accumulation steps.
        log_every = args.log_per_updates * args.grad_accumulation_step
        save_every = args.save_per_updates * args.grad_accumulation_step
        for epoch in range(args.epochs):
            logger.warning('At epoch {}'.format(epoch))
            start = datetime.now()

            for i, (batch_meta, batch_data) in enumerate(multi_task_train_data):
                batch_meta, batch_data = Collater.patch_data(args.cuda, batch_meta, batch_data)
                task_id = batch_meta['task_id']
                model.update(batch_meta, batch_data)
                if model.local_updates % log_every == 0 or model.local_updates == 1:
                    # Linear ETA for the rest of this epoch, with sub-second digits dropped.
                    remaining_time = str((datetime.now() - start) / (i + 1)
                                         * (len(multi_task_train_data) - i - 1)).split('.')[0]
                    logger.info('Task [{0:2}] updates[{1:6}] train loss[{2:.5f}] remaining[{3}]'.format(
                        task_id, model.updates, model.train_loss.avg, remaining_time))
                    if tensorboard is not None:
                        tensorboard.add_scalar('train/loss', model.train_loss.avg, global_step=model.updates)

                if args.save_per_updates_on and model.local_updates % save_every == 0:
                    model_file = os.path.join(output_dir, 'model_{}_{}.pt'.format(epoch, model.updates))
                    logger.info('Saving mt-dnn model to {}'.format(model_file))
                    model.save(model_file)

            _evaluate_epoch(model, epoch, dev_data_list, test_data_list, tensorboard)
            model_file = os.path.join(output_dir, 'model_{}.pt'.format(epoch))
            model.save(model_file)
    finally:
        if tensorboard is not None:
            tensorboard.close()


if __name__ == '__main__':
    main()