finetune_copa.py
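# Fine-tunes a pretrained Korean encoder (KLUE RoBERTa or one of several
# Korean ELECTRA variants) on the COPA (Choice of Plausible Alternatives) task,
# with optional distributed data parallel (DDP) training; the GPT-2 options
# ('base', 'trinity') are currently stubbed out.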
import argparse
from torch.utils.data import DataLoader
from transformers import GPT2ForSequenceClassification, GPT2Config, RobertaForSequenceClassification, RobertaConfig, \
    BertForSequenceClassification, BertConfig, ElectraForSequenceClassification, ElectraConfig, RobertaForMultipleChoice, ElectraForMultipleChoice
from util.trainer import Trainer
from util.dataset import LoadDataset_copa
import torch
from torch import distributed as dist
from torch.utils.data.distributed import DistributedSampler
parser = argparse.ArgumentParser()
parser.add_argument("-c", "--train_dataset", required=True, type=str, help="train dataset")
parser.add_argument("-t", "--test_dataset", type=str, default=None, help="test set for evaluate train set")
parser.add_argument("-o", "--output_path", required=True, type=str, help="ex)output/bert.model")
parser.add_argument("--model", type=str, required=True, help="model (base,trinity)")
parser.add_argument("--ddp", type=bool, default=False, help="for distrbuted data parrerel")
parser.add_argument("--local_rank", type=int, help="for distrbuted data parrerel")
parser.add_argument("--input_seq_len", required=True, type=int, default=512, help="maximum sequence input len")
parser.add_argument("-b", "--batch_size", type=int, default=8, help="number of batch_size")
parser.add_argument("-e", "--epochs", type=int, default=10, help="number of epochs")
parser.add_argument("-w", "--num_workers", type=int, default=5, help="dataloader worker size")
parser.add_argument("--with_cuda", type=bool, default=True, help="training with CUDA: true, or false")
parser.add_argument("--log_freq", type=int, default=1, help="printing loss every n iter: setting n")
parser.add_argument("--cuda_devices", type=int, nargs='+', default=None, help="CUDA device ids")
parser.add_argument("--on_memory", type=bool, default=True, help="Loading on memory: true or false")
parser.add_argument("--lr", type=float, default=1e-4, help="learning rate of adam")
parser.add_argument("--adam_weight_decay", type=float, default=0.01, help="weight_decay of adam")
parser.add_argument("--adam_beta1", type=float, default=0.9, help="adam first beta value")
parser.add_argument("--adam_beta2", type=float, default=0.999, help="adam first beta value")
parser.add_argument("--accumulate", type=int, default=1, help="accumulation step")
parser.add_argument("--seed", type=int, default=42, help="seed")
args = parser.parse_args()
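# Example invocation (file paths below are illustrative, not from the repo):
#   python finetune_copa.py -c data/copa_train.tsv -t data/copa_dev.tsv \
#       -o output/copa_electra.model --model electra --input_seq_len 128 -b 16 -e 10
#
# With --ddp the script is expected to be started by a distributed launcher
# (e.g. `python -m torch.distributed.launch`), which supplies the environment
# variables read by init_method='env://' and passes --local_rank to each process.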
if args.ddp:
    dist.init_process_group(backend='nccl', init_method='env://')
    torch.cuda.set_device(args.local_rank)
torch.manual_seed(args.seed)
torch.cuda.manual_seed_all(args.seed)
print("Loading Train Dataset", args.train_dataset)
train_dataset = LoadDataset_copa(args.train_dataset, seq_len=args.input_seq_len, model=args.model)
print("Loading Test Dataset", args.test_dataset)
test_dataset = LoadDataset_copa(args.test_dataset, seq_len=args.input_seq_len, model=args.model) \
    if args.test_dataset is not None else None
if args.ddp:
    print("Creating Dataloader")
    train_sampler = DistributedSampler(train_dataset)
    train_data_loader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.batch_size, num_workers=args.num_workers)
else:
    print("Creating Dataloader")
    train_data_loader = DataLoader(train_dataset, batch_size=args.batch_size, num_workers=args.num_workers, shuffle=True)
if args.ddp:
    print("Creating Dataloader")
    # build the distributed test sampler only when a test set was actually given
    if test_dataset is not None:
        test_sampler = DistributedSampler(test_dataset)
        test_data_loader = DataLoader(test_dataset, sampler=test_sampler, batch_size=args.batch_size, num_workers=args.num_workers)
    else:
        test_data_loader = None
else:
    test_data_loader = DataLoader(test_dataset, batch_size=200, num_workers=args.num_workers) \
        if test_dataset is not None else None
if args.model == 'base':
    raise NotImplementedError
    # model_config = GPT2Config.from_pretrained(pretrained_model_name_or_path="skt/kogpt2-base-v2", num_labels=2)
    # model = GPT2ForSequenceClassification.from_pretrained("skt/kogpt2-base-v2", config=model_config)
elif args.model == 'trinity':
    raise NotImplementedError
    # model_config = GPT2Config.from_pretrained(pretrained_model_name_or_path="skt/ko-gpt-trinity-1.2B-v0.5", num_labels=2)
    # model = GPT2ForSequenceClassification.from_pretrained("skt/ko-gpt-trinity-1.2B-v0.5", config=model_config)
elif args.model == 'roberta':
    model_config = RobertaConfig.from_pretrained(pretrained_model_name_or_path="klue/roberta-large")  # , num_labels=2)
    model = RobertaForMultipleChoice.from_pretrained("klue/roberta-large", config=model_config)
elif args.model == 'electra':
    model_config = ElectraConfig.from_pretrained(pretrained_model_name_or_path="monologg/koelectra-base-v3-discriminator",
                                                 num_labels=2)
    model = ElectraForMultipleChoice.from_pretrained("monologg/koelectra-base-v3-discriminator", config=model_config)
elif args.model == 'electra_tunib':
    model_config = ElectraConfig.from_pretrained(pretrained_model_name_or_path="tunib/electra-ko-base",
                                                 num_labels=2)
    model = ElectraForMultipleChoice.from_pretrained("tunib/electra-ko-base", config=model_config)
elif args.model == 'electra_kor':
    model_config = ElectraConfig.from_pretrained(pretrained_model_name_or_path="kykim/electra-kor-base",
                                                 num_labels=2)
    model = ElectraForMultipleChoice.from_pretrained("kykim/electra-kor-base", config=model_config)
else:
    raise ValueError(f"unknown model: {args.model}")
print("Creating Trainer")
trainer = Trainer(task='copa', model=model, train_dataloader=train_data_loader, test_dataloader=test_data_loader,
                  lr=args.lr, betas=(args.adam_beta1, args.adam_beta2), weight_decay=args.adam_weight_decay,
                  with_cuda=args.with_cuda, cuda_devices=args.cuda_devices, log_freq=args.log_freq,
                  distributed=args.ddp, local_rank=args.local_rank, accum_iter=args.accumulate, seed=args.seed, model_name=args.model)
print("Training Start")
for epoch in range(args.epochs):
    if args.ddp:
        train_sampler.set_epoch(epoch)
        trainer.train(epoch)
        # if args.local_rank == 0:
        #     trainer.save(epoch, args.output_path)
    else:
        trainer.train(epoch)
        # trainer.save(epoch, args.output_path)