-
Notifications
You must be signed in to change notification settings - Fork 61
/
utils.py
78 lines (62 loc) · 2.59 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import functools

import numpy as np
import torch
def calc_patch_size(func):
    """Decorate ``func`` so ``args.patch_size`` is derived from ``args.scale``.

    The returned wrapper stores a patch size on the ``args`` object it
    receives (10, 7 or 6 for a scale of 2, 3 or 4 respectively) and then
    calls ``func(args)``.

    Args:
        func: Callable taking a single ``args`` object that exposes a
            ``scale`` attribute and accepts a ``patch_size`` attribute.

    Returns:
        The wrapper; calling it returns whatever ``func(args)`` returns.

    Raises:
        ValueError: Raised by the wrapper when ``args.scale`` is not 2, 3
            or 4. The exception args are ``('Scale Error', args.scale)``.
    """
    # (scale, patch_size) pairs; compared with ``==`` so any scale value that
    # compares equal to 2, 3 or 4 is accepted, exactly as before.
    patch_sizes = ((2, 10), (3, 7), (4, 6))

    @functools.wraps(func)  # keep func's __name__ / __doc__ on the wrapper
    def wrapper(args):
        scale = args.scale
        for supported_scale, patch_size in patch_sizes:
            if scale == supported_scale:
                args.patch_size = patch_size
                break
        else:
            # No supported scale matched. ValueError is still an Exception,
            # so existing ``except Exception`` handlers keep working.
            raise ValueError('Scale Error', scale)
        return func(args)

    return wrapper
def convert_rgb_to_y(img, dim_order='hwc'):
    """Return the Y (luma) plane computed from the RGB channels of ``img``.

    Computes ``16 + (65.738*R + 129.057*G + 25.064*B) / 256``.

    Args:
        img: Indexable image (e.g. a NumPy array or tensor). Presumably
            holds values on a 0-255 scale -- confirm against callers.
        dim_order: ``'hwc'`` reads the channels from the last axis
            (``img[..., c]``); any other value reads them from the first
            axis (``img[c]``).

    Returns:
        The Y plane, i.e. ``img`` without its channel axis.
    """
    if dim_order == 'hwc':
        red, green, blue = img[..., 0], img[..., 1], img[..., 2]
    else:
        red, green, blue = img[0], img[1], img[2]
    # The red weight is 65.738 (BT.601 studio range), not 64.738: with the
    # corrected value pure white (255, 255, 255) maps to Y = 235 and the
    # result round-trips through the standard inverse transform.
    return 16. + (65.738 * red + 129.057 * green + 25.064 * blue) / 256.
def convert_rgb_to_ycbcr(img, dim_order='hwc'):
    """Convert an RGB image to YCbCr.

    Args:
        img: Indexable 3-D image whose channel planes are 2-D. Presumably
            holds values on a 0-255 scale -- confirm against callers.
        dim_order: ``'hwc'`` reads the channels from the last axis
            (``img[..., c]``); any other value reads them from the first
            axis (``img[c]``).

    Returns:
        A NumPy array with the Y, Cb and Cr planes on the LAST axis,
        regardless of ``dim_order`` (the stacked planes are always
        transposed to channels-last).
    """
    if dim_order == 'hwc':
        red, green, blue = img[..., 0], img[..., 1], img[..., 2]
    else:
        red, green, blue = img[0], img[1], img[2]

    # The red luma weight is 65.738 (BT.601 studio range), not 64.738: with
    # the corrected value pure white maps to Y = 235.
    y = 16. + (65.738 * red + 129.057 * green + 25.064 * blue) / 256.
    cb = 128. + (-37.945 * red - 74.494 * green + 112.439 * blue) / 256.
    cr = 128. + (112.439 * red - 94.154 * green - 18.285 * blue) / 256.

    # Stack as (3, rows, cols), then move the channel axis to the end.
    return np.array([y, cb, cr]).transpose([1, 2, 0])
def convert_ycbcr_to_rgb(img, dim_order='hwc'):
    """Convert a YCbCr image to RGB.

    Args:
        img: Indexable 3-D image whose channel planes are 2-D.
        dim_order: ``'hwc'`` reads the channels from the last axis
            (``img[..., c]``); any other value reads them from the first
            axis (``img[c]``).

    Returns:
        A NumPy array with the R, G and B planes on the LAST axis,
        regardless of ``dim_order``.
    """
    channels_last = dim_order == 'hwc'
    luma, cb, cr = (img[..., c] if channels_last else img[c] for c in range(3))

    red = 298.082 * luma / 256. + 408.583 * cr / 256. - 222.921
    green = 298.082 * luma / 256. - 100.291 * cb / 256. - 208.120 * cr / 256. + 135.576
    blue = 298.082 * luma / 256. + 516.412 * cb / 256. - 276.836

    # Stack as (3, rows, cols), then move the channel axis to the end.
    stacked = np.array([red, green, blue])
    return np.transpose(stacked, [1, 2, 0])
def preprocess(img, device):
    """Turn an RGB image into a normalised Y tensor plus its YCbCr array.

    Args:
        img: Anything ``np.array`` accepts; it is converted to float32 and
            passed to ``convert_rgb_to_ycbcr`` with its default
            channels-last ordering.
        device: Target passed to ``Tensor.to``.

    Returns:
        A tuple ``(x, ycbcr)``:
            * ``x`` -- the Y channel divided by 255, as a torch tensor on
              ``device`` with two leading singleton dimensions added.
            * ``ycbcr`` -- the array returned by ``convert_rgb_to_ycbcr``,
              left unmodified.
    """
    rgb = np.array(img).astype(np.float32)
    ycbcr = convert_rgb_to_ycbcr(rgb)
    # Divide out of place: ``ycbcr[..., 0]`` is a view, so an in-place
    # ``/=`` would also rescale the Y channel inside the returned ``ycbcr``
    # and leave it on a different scale from Cb/Cr.
    luma = ycbcr[..., 0] / 255.
    x = torch.from_numpy(luma).to(device)
    return x.unsqueeze(0).unsqueeze(0), ycbcr
def calc_psnr(img1, img2):
    """Return the PSNR in dB between two tensors, using a peak value of 1.

    Computed as ``10 * log10(1 / mean((img1 - img2) ** 2))``.

    Args:
        img1: First tensor.
        img2: Second tensor, subtracted element-wise from ``img1``.

    Returns:
        A tensor holding the PSNR value.
    """
    squared_error = (img1 - img2) ** 2
    mse = torch.mean(squared_error)
    return 10. * torch.log10(1. / mse)
class AverageMeter:
    """Track the most recent value and the weighted running mean of a series.

    Attributes (all reset to 0 by ``reset``):
        val: The value passed to the latest ``update`` call.
        sum: Running total of ``val * n`` over all updates.
        count: Running total of ``n`` over all updates.
        avg: ``sum / count``, or 0 while ``count`` is 0.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Set every statistic back to zero."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the mean.

        Args:
            val: The new observation.
            n: Weight of the observation (e.g. a batch size).
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # A zero total weight (e.g. the first update has n=0) would divide
        # by zero; report an average of 0 until something has been counted.
        self.avg = self.sum / self.count if self.count else 0