-
Notifications
You must be signed in to change notification settings - Fork 1
/
finetune_boolq.py
executable file
·112 lines (89 loc) · 5.87 KB
/
finetune_boolq.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import argparse
import os

import torch
from torch import distributed as dist
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from transformers import (
    GPT2ForSequenceClassification,
    GPT2Config,
    RobertaForSequenceClassification,
    RobertaConfig,
    BertForSequenceClassification,
    BertConfig,
    ElectraForSequenceClassification,
    ElectraConfig,
)

from util.trainer import Trainer
from util.dataset import LoadDataset_boolq
def _str2bool(value):
    """Convert a command-line token into a real boolean.

    ``argparse`` with ``type=bool`` simply calls ``bool()`` on the raw string,
    so every non-empty value -- including "False" and "0" -- evaluates to
    ``True``. This converter maps the common spellings explicitly and rejects
    anything it does not recognise.

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ("true", "t", "yes", "y", "1"):
        return True
    # The empty string was the only falsy input under ``type=bool``; keep it falsy.
    if token in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got {!r}".format(value))


# Command-line interface of the BoolQ fine-tuning script.
parser = argparse.ArgumentParser()

# Data and output locations.
parser.add_argument("-c", "--train_dataset", required=True, type=str, help="train dataset")
parser.add_argument("-t", "--test_dataset", type=str, default=None, help="test set for evaluate train set")
parser.add_argument("-o", "--output_path", required=True, type=str, help="ex)output/bert.model")

# Model choice; the accepted names are the ones handled by the model-selection code below.
parser.add_argument("--model", type=str, required=True,
                    help="model (base, trinity, roberta, electra, electra_tunib)")

# Distributed training. Both spellings of the local-rank flag are accepted
# because launchers differ in which one they pass; the value is stored as
# ``args.local_rank`` either way.
parser.add_argument("--ddp", type=_str2bool, default=False, help="for distrbuted data parrerel")
parser.add_argument("--local_rank", "--local-rank", type=int, help="for distrbuted data parrerel")

# Previously ``required=True`` made the default of 512 unreachable; the flag is
# now optional so the documented default actually applies.
parser.add_argument("--input_seq_len", type=int, default=512, help="maximum sequence input len")

# Loader and runtime settings.
parser.add_argument("-b", "--batch_size", type=int, default=8, help="number of batch_size")
parser.add_argument("-e", "--epochs", type=int, default=10, help="number of epochs")
parser.add_argument("-w", "--num_workers", type=int, default=5, help="dataloader worker size")
parser.add_argument("--with_cuda", type=_str2bool, default=True, help="training with CUDA: true, or false")
parser.add_argument("--log_freq", type=int, default=1, help="printing loss every n iter: setting n")
parser.add_argument("--cuda_devices", type=int, nargs='+', default=None, help="CUDA device ids")
parser.add_argument("--on_memory", type=_str2bool, default=True, help="Loading on memory: true or false")

# Optimizer hyper-parameters.
parser.add_argument("--lr", type=float, default=1e-4, help="learning rate of adam")
parser.add_argument("--adam_weight_decay", type=float, default=0.01, help="weight_decay of adam")
parser.add_argument("--adam_beta1", type=float, default=0.9, help="adam first beta value")
parser.add_argument("--adam_beta2", type=float, default=0.999, help="adam second beta value")
parser.add_argument("--accumulate", type=int, default=1, help="accumulation step")
parser.add_argument("--seed", type=int, default=42, help="seed")

args = parser.parse_args()
# Distributed setup: join the process group and pin this process to its GPU.
if args.ddp:
    if args.local_rank is None:
        # Launchers that export LOCAL_RANK instead of passing --local_rank
        # would otherwise reach torch.cuda.set_device(None) and fail there.
        env_local_rank = os.environ.get("LOCAL_RANK")
        if env_local_rank is None:
            raise RuntimeError(
                "--ddp requires --local_rank or the LOCAL_RANK environment variable"
            )
        # Store it on args so later code that reads args.local_rank sees the same value.
        args.local_rank = int(env_local_rank)
    dist.init_process_group(backend='nccl', init_method='env://')
    torch.cuda.set_device(args.local_rank)

# Seed the CPU generator and every CUDA generator with the same user-supplied seed.
torch.manual_seed(args.seed)
torch.cuda.manual_seed_all(args.seed)
# Both splits are built with the same sequence length and model name.
dataset_options = dict(seq_len=args.input_seq_len, model=args.model)

print("Loading Train Dataset", args.train_dataset)
train_dataset = LoadDataset_boolq(args.train_dataset, **dataset_options)

# The test split is optional; test_dataset stays None when no path was given.
print("Loading Test Dataset", args.test_dataset)
if args.test_dataset is None:
    test_dataset = None
else:
    test_dataset = LoadDataset_boolq(args.test_dataset, **dataset_options)
# Training loader: under DDP a DistributedSampler drives the loader (and is
# kept in train_sampler so the epoch loop can call set_epoch on it);
# otherwise the loader shuffles on its own.
print("Creating Dataloader")
if args.ddp:
    train_sampler = DistributedSampler(train_dataset)
    train_data_loader = DataLoader(train_dataset, sampler=train_sampler,
                                   batch_size=args.batch_size, num_workers=args.num_workers)
else:
    train_data_loader = DataLoader(train_dataset, batch_size=args.batch_size,
                                   num_workers=args.num_workers, shuffle=True)

# Test loader: check for a missing test set *before* touching any sampler.
# The original built DistributedSampler(test_dataset) unconditionally in DDP
# mode, which fails when no test dataset was supplied.
if test_dataset is None:
    test_data_loader = None
elif args.ddp:
    print("Creating Dataloader")
    test_sampler = DistributedSampler(test_dataset)
    test_data_loader = DataLoader(test_dataset, sampler=test_sampler,
                                  batch_size=args.batch_size, num_workers=args.num_workers)
else:
    # Single-process evaluation uses a fixed batch size of 50.
    test_data_loader = DataLoader(test_dataset, batch_size=50, num_workers=args.num_workers)
# Maps each supported --model value to (config class, model class, checkpoint name).
_PRETRAINED_MODELS = {
    'base': (GPT2Config, GPT2ForSequenceClassification, "skt/kogpt2-base-v2"),
    'trinity': (GPT2Config, GPT2ForSequenceClassification, "skt/ko-gpt-trinity-1.2B-v0.5"),
    'roberta': (RobertaConfig, RobertaForSequenceClassification, "klue/roberta-large"),
    'electra': (ElectraConfig, ElectraForSequenceClassification,
                "monologg/koelectra-base-v3-discriminator"),
    'electra_tunib': (ElectraConfig, ElectraForSequenceClassification, "tunib/electra-ko-base"),
}

# An unrecognised name used to fall through every branch and leave `model`
# undefined, surfacing later as a NameError; fail here with a clear message.
if args.model not in _PRETRAINED_MODELS:
    raise ValueError(
        "unknown --model {!r}; expected one of: {}".format(
            args.model, ", ".join(sorted(_PRETRAINED_MODELS))
        )
    )

config_cls, model_cls, checkpoint_name = _PRETRAINED_MODELS[args.model]

# Every variant is loaded as a two-label sequence classifier.
model_config = config_cls.from_pretrained(pretrained_model_name_or_path=checkpoint_name, num_labels=2)
model = model_cls.from_pretrained(checkpoint_name, config=model_config)
print("Creating Trainer")
adam_betas = (args.adam_beta1, args.adam_beta2)
trainer = Trainer(
    task='boolq',
    model=model,
    train_dataloader=train_data_loader,
    test_dataloader=test_data_loader,
    lr=args.lr,
    betas=adam_betas,
    weight_decay=args.adam_weight_decay,
    with_cuda=args.with_cuda,
    cuda_devices=args.cuda_devices,
    log_freq=args.log_freq,
    distributed=args.ddp,
    local_rank=args.local_rank,
    accum_iter=args.accumulate,
    seed=args.seed,
    model_name=args.model,
)

print("Training Start")
for epoch in range(args.epochs):
    # Under DDP the sampler is told the current epoch before each pass.
    if args.ddp:
        train_sampler.set_epoch(epoch)
    trainer.train(epoch)
    # NOTE: checkpoint saving is disabled here. The previous (commented-out)
    # call was trainer.save(epoch, args.output_path), restricted to
    # local_rank 0 when running under DDP.