-
Notifications
You must be signed in to change notification settings - Fork 2
/
utils.py
54 lines (43 loc) · 1.52 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
# Copyright (c) 2020 Tongzhou Wang
import torch
class AverageMeter(object):
    r"""
    Computes and stores the average and current value.

    Adapted from
    https://github.com/pytorch/examples/blob/ec10eee2d55379f0b9c87f4b36fcf8d0723f45fc/imagenet/main.py#L359-L380

    Args:
        name (str, optional): label prepended (followed by a space) to the
            formatted output of ``str(meter)``.
        fmt (str): format spec applied to both the latest value and the
            running average, e.g. ``'.6f'``.
    """

    def __init__(self, name=None, fmt='.6f'):
        # The doubled braces survive the f-string as literal braces, leaving a
        # template like '{val:.6f} ({avg:.6f})' for str.format in __str__.
        fmtstr = f'{{val:{fmt}}} ({{avg:{fmt}}})'
        if name is not None:
            fmtstr = name + ' ' + fmtstr
        self.fmtstr = fmtstr
        self.reset()

    def reset(self):
        """Clear the latest value, the running sum and the sample count."""
        self.val = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the latest value, weighted as ``n`` samples.

        ``val`` contributes ``val * n`` to the running sum and ``n`` to the
        count, so ``avg`` is the ``n``-weighted mean of all recorded values.
        """
        self.val = val
        self.sum += val * n
        self.count += n

    @property
    def avg(self):
        """Weighted running average as a Python ``float``.

        Returns 0.0 while ``count`` is 0 (e.g. right after construction or
        ``reset()``), matching the reset state of the upstream example.
        """
        if self.count == 0:
            # Avoid ZeroDivisionError so the meter can be printed before
            # any update has been recorded.
            return 0.0
        return float(self.sum / self.count)

    def __str__(self):
        # float() normalizes the stored value so the numeric format spec
        # applies; presumably callers may also pass 0-dim tensors or numpy
        # scalars here (the module imports torch) -- confirm with callers.
        val = float(self.val)
        return self.fmtstr.format(val=val, avg=self.avg)
class ProgressMeter:
    """Print a tab-separated progress line for a batch.

    A line consists of ``prefix`` immediately followed by ``[batch/total]``,
    then ``str(meter)`` for each entry of ``meters``, all joined by tabs.

    Args:
        num_batches (int): total number of batches; sets the field width of
            the batch counter so that successive lines align.
        meters (iterable): objects whose ``str()`` is appended to each line.
        prefix (str): text placed before the batch counter.
    """

    BR = '\n'

    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        """Print the progress line for ``batch`` to stdout."""
        counter = self.prefix + self.batch_fmtstr.format(batch)
        print('\t'.join([counter, *map(str, self.meters)]))

    def _get_batch_fmtstr(self, num_batches):
        """Return a template such as ``'[{:3d}/100]'`` for ``num_batches``.

        The batch slot is an integer field as wide as ``num_batches`` printed
        in decimal; the total is filled in immediately.
        """
        width = len(str(num_batches // 1))
        slot = f'{{:{width}d}}'
        return f'[{slot}/{slot.format(num_batches)}]'