forked from jaywonchung/BERT4Rec-VAE-Pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
loggers.py
93 lines (69 loc) · 3.1 KB
/
loggers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import os
from abc import ABCMeta, abstractmethod
import torch
def save_state_dict(state_dict, path, filename):
    """Serialize ``state_dict`` with ``torch.save`` to the file ``path/filename``.

    Args:
        state_dict: Object handed directly to ``torch.save``.
        path: Directory that will contain the file.
        filename: Name of the file inside ``path``.
    """
    destination = os.path.join(path, filename)
    torch.save(state_dict, destination)
class LoggerService(object):
    """Dispatch log events to groups of training and validation loggers.

    Every logger must provide ``log(**data)`` and ``complete(**data)``; the
    ``log_data`` dict handed to each method below is unpacked as keyword
    arguments for those calls.
    """

    def __init__(self, train_loggers=None, val_loggers=None):
        # Any falsy argument (None or an empty sequence) becomes a new empty list.
        self.train_loggers = train_loggers or []
        self.val_loggers = val_loggers or []

    def complete(self, log_data):
        """Call ``complete`` on every train logger, then on every val logger."""
        for group in (self.train_loggers, self.val_loggers):
            for sink in group:
                sink.complete(**log_data)

    def log_train(self, log_data):
        """Forward ``log_data`` to ``log`` on each train logger, in order."""
        for sink in self.train_loggers:
            sink.log(**log_data)

    def log_val(self, log_data):
        """Forward ``log_data`` to ``log`` on each val logger, in order."""
        for sink in self.val_loggers:
            sink.log(**log_data)
class AbstractBaseLogger(metaclass=ABCMeta):
    """Interface for loggers that receive events as keyword arguments.

    Subclasses must implement ``log``; ``complete`` is optional and does
    nothing by default.
    """

    @abstractmethod
    def log(self, *args, **kwargs):
        """Record one event. Must be overridden by concrete loggers."""
        raise NotImplementedError

    def complete(self, *args, **kwargs):
        """Finalize the logger at the end of a run (default: no-op)."""
        return None
class RecentModelLogger(AbstractBaseLogger):
    """Save the most recent checkpoint, at most once per distinct epoch value.

    ``log`` overwrites ``<checkpoint_path>/<filename>`` the first time it sees an
    ``epoch`` different from the previous call. ``complete`` writes
    ``<checkpoint_path>/<filename>.final``.
    """

    def __init__(self, checkpoint_path, filename='checkpoint-recent.pth'):
        self.checkpoint_path = checkpoint_path
        # makedirs(exist_ok=True) creates missing parent directories and avoids
        # the race between an exists() check and a separate mkdir().
        os.makedirs(self.checkpoint_path, exist_ok=True)
        self.recent_epoch = None  # epoch of the last save; None before any save
        self.filename = filename

    def log(self, *args, **kwargs):
        """Save ``kwargs['state_dict']`` once per new ``kwargs['epoch']``.

        The epoch is also stored under the ``'epoch'`` key of the saved dict.
        """
        epoch = kwargs['epoch']
        if self.recent_epoch == epoch:
            return  # already saved for this epoch
        self.recent_epoch = epoch
        state_dict = kwargs['state_dict']
        # NOTE(review): this writes 'epoch' into the caller's dict in place, so any
        # other logger handed the same dict object also sees the key. Kept as-is
        # to preserve what other checkpoints end up containing.
        state_dict['epoch'] = epoch
        save_state_dict(state_dict, self.checkpoint_path, self.filename)

    def complete(self, *args, **kwargs):
        """Save ``kwargs['state_dict']`` as ``<filename>.final``."""
        save_state_dict(kwargs['state_dict'], self.checkpoint_path, self.filename + '.final')
class BestModelLogger(AbstractBaseLogger):
    """Save a checkpoint each time the tracked metric exceeds its best value so far.

    Higher values of ``kwargs[metric_key]`` are treated as better. The checkpoint
    is written to ``<checkpoint_path>/<filename>``.
    """

    def __init__(self, checkpoint_path, metric_key='mean_iou', filename='best_acc_model.pth'):
        self.checkpoint_path = checkpoint_path
        # makedirs(exist_ok=True) creates missing parent directories and avoids
        # the race between an exists() check and a separate mkdir().
        os.makedirs(self.checkpoint_path, exist_ok=True)
        # Start at -inf so the first evaluation is always saved. With 0., a metric
        # that is 0 or negative would never be saved, leaving no best checkpoint.
        self.best_metric = float('-inf')
        self.metric_key = metric_key
        self.filename = filename

    def log(self, *args, **kwargs):
        """Save ``kwargs['state_dict']`` if ``kwargs[metric_key]`` is a new maximum.

        Raises:
            KeyError: if ``metric_key`` is missing from the kwargs, or if
                ``'epoch'`` / ``'state_dict'`` is missing when an improvement occurs.
        """
        current_metric = kwargs[self.metric_key]
        if self.best_metric < current_metric:
            print("Update Best {} Model at {}".format(self.metric_key, kwargs['epoch']))
            self.best_metric = current_metric
            save_state_dict(kwargs['state_dict'], self.checkpoint_path, self.filename)
class MetricGraphPrinter(AbstractBaseLogger):
    """Plot one keyword value from each log event as a scalar series.

    Each ``log`` call sends ``kwargs[key]`` to ``writer.add_scalar`` under the tag
    ``"<group_name>/<graph_name>"``, using ``kwargs['accum_iter']`` as the step.
    When ``key`` is absent from the event, 0 is plotted at that step instead.
    ``writer`` is presumably a TensorBoard-style ``SummaryWriter`` (TODO confirm).
    """

    def __init__(self, writer, key='train_loss', graph_name='Train Loss', group_name='metric'):
        self.writer = writer
        self.key = key
        self.group_name = group_name
        self.graph_label = graph_name

    def log(self, *args, **kwargs):
        """Write the tracked value (or the placeholder 0) for this step."""
        tag = self.group_name + '/' + self.graph_label
        value = kwargs.get(self.key, 0)
        self.writer.add_scalar(tag, value, kwargs['accum_iter'])

    def complete(self, *args, **kwargs):
        """Close the underlying writer."""
        # NOTE(review): every printer closes its own writer; if several printers
        # share one writer, close() runs repeatedly — confirm the writer tolerates it.
        self.writer.close()