trainer.py
import logging
import torch
import torch.optim as optim
from datetime import datetime
from torch.utils.tensorboard import SummaryWriter
from distiller import Distiller


class DistillTrainer():
    """Teacher-student distillation loop with TensorBoard logging and
    best-validation-loss checkpointing."""

    def __init__(
        self,
        run_name,
        teacher,
        student,
        teacher_tokenizer,
        student_tokenizer,
        same_vocab=True,
        vocab_prob_map=None,
        train_dataloader=None,
        val_dataloader=None,
        test_dataloader=None,
        lr=1e-4,
        epochs=20
    ):
        self.run_name = run_name
        self.same_vocab = same_vocab
        self.epochs = epochs
        self.distiller = Distiller(
            teacher=teacher,
            student=student,
            teacher_tokenizer=teacher_tokenizer,
            student_tokenizer=student_tokenizer,
            same_vocab=same_vocab,
            vocab_prob_map=vocab_prob_map,
        )
        self.train_set = train_dataloader
        self.val_set = val_dataloader
        self.test_set = test_dataloader
        self.optimizer = optim.AdamW(self.distiller.parameters(), lr=lr)
        self.logger = logging.getLogger(__name__)
        self.lowest_val_loss = float('inf')  # best validation loss seen so far
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        self.writer = SummaryWriter(f'runs/{run_name}_{timestamp}')

    def step(self, inputs, train=True):
        """Run one forward pass through the distiller; backpropagate and update
        the student only when `train` is True."""
        # With a shared vocabulary the same batch feeds both models;
        # otherwise the dataloader yields a (teacher, student) pair.
        if self.same_vocab:
            teacher_inputs = inputs
            student_inputs = inputs
        else:
            teacher_inputs, student_inputs = inputs
        self.optimizer.zero_grad()
        teacher_outputs, student_outputs = self.distiller(teacher_inputs, student_inputs)
        teacher_logits = teacher_outputs["logits"]
        teacher_preds = teacher_outputs["sentences"]
        student_logits = student_outputs["logits"]
        student_preds = student_outputs["sentences"]
        loss, student_loss, distill_loss = self.distiller.loss_fn(
            teacher_logits, teacher_preds, student_logits, student_preds
        )
        if train:
            loss.backward()
            self.optimizer.step()
        return loss, student_loss, distill_loss

    def epoch(self, epoch_num, log_every=1, val_every=0.20):
        # Validate after every `val_every` fraction of the training set.
        val_at = max(1, int(len(self.train_set) * val_every))
        for step_num, inputs in enumerate(self.train_set):
            train_loss, train_student_loss, train_distill_loss = self.step(inputs)
            if step_num % log_every == 0:
                self.loss_logger(epoch_num, step_num, train_loss, train_student_loss, train_distill_loss)
            if step_num % val_at == 0:
                self.distiller.train(False)
                running_loss = 0
                running_student_loss = 0
                running_distill_loss = 0
                with torch.no_grad():
                    for val_inputs in self.val_set:
                        loss, student_loss, distill_loss = self.step(val_inputs, train=False)
                        # Accumulate plain Python floats so no graph is retained.
                        running_loss += loss.item()
                        running_student_loss += student_loss.item()
                        running_distill_loss += distill_loss.item()
                val_loss = running_loss / len(self.val_set)
                val_student_loss = running_student_loss / len(self.val_set)
                val_distill_loss = running_distill_loss / len(self.val_set)
                # Checkpoint whenever the validation loss improves.
                if val_loss < self.lowest_val_loss:
                    self.save(epoch_num, step_num, val_loss, val_student_loss, val_distill_loss)
                    self.lowest_val_loss = val_loss
                self.loss_logger(epoch_num, step_num, val_loss, val_student_loss, val_distill_loss, train=False)
                self.distiller.train(True)
        self.writer.flush()

    def loss_logger(self, epoch_num, step_num, loss, student_loss, distill_loss, train=True):
        stage = "Train"
        if not train:
            stage = "Val"
        self.logger.info(
            f"Epoch: {epoch_num} | Train Step Num {step_num} | "
            f"Total {stage} Loss: {loss} | Student {stage} Loss: {student_loss} | "
            f"Distill {stage} Loss: {distill_loss}"
        )
        global_step_num = (epoch_num * len(self.train_set)) + step_num + 1
        self.writer.add_scalar(f'Total_Loss/{stage}', loss, global_step_num)
        self.writer.add_scalar(f'Student_Loss/{stage}', student_loss, global_step_num)
        self.writer.add_scalar(f'Distill_Loss/{stage}', distill_loss, global_step_num)

    def train(self):
        for epoch_num in range(self.epochs):
            self.epoch(epoch_num)

    def save(self, epoch_num, step_num, val_loss, val_student_loss, val_distill_loss):
        # Save the student weights; the filename records when and how well it validated.
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        model_path = (
            f'{self.run_name}_valloss{val_loss:.4f}_stloss{val_student_loss:.4f}'
            f'_dsloss{val_distill_loss:.4f}_epoch{epoch_num}_step{step_num}_{timestamp}'
        )
        torch.save(self.distiller.student.state_dict(), model_path)

    def test(self):
        pass
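

# Minimal usage sketch (not part of the original file): the models, tokenizers,
# and dataloaders below are hypothetical placeholders and would need to be
# replaced with real objects in whatever form `Distiller` expects.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    # Placeholders only; substitute a pretrained teacher, a smaller student,
    # their tokenizers, and torch.utils.data.DataLoader instances.
    teacher_model, student_model = ..., ...
    teacher_tok, student_tok = ..., ...
    train_dl, val_dl = ..., ...

    trainer = DistillTrainer(
        run_name="distill_run",        # hypothetical run name
        teacher=teacher_model,
        student=student_model,
        teacher_tokenizer=teacher_tok,
        student_tokenizer=student_tok,
        same_vocab=True,               # both models assumed to share one vocabulary
        train_dataloader=train_dl,
        val_dataloader=val_dl,
        lr=1e-4,
        epochs=20,
    )
    trainer.train()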