-
Notifications
You must be signed in to change notification settings - Fork 184
/
main_test_bsrgan.py
123 lines (93 loc) · 4.18 KB
/
main_test_bsrgan.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import os.path
import logging
import torch
from utils import utils_logger
from utils import utils_image as util
# from utils import utils_model
from models.network_rrdbnet import RRDBNet as net
"""
Spyder (Python 3.6-3.7)
PyTorch 1.4.0-1.8.1
Windows 10 or Linux
Kai Zhang ([email protected])
github: https://github.com/cszn/BSRGAN
https://github.com/cszn/KAIR
If you have any question, please feel free to contact with me.
Kai Zhang (e-mail: [email protected])
by Kai Zhang ( March/2020 --> March/2021 --> )
This work was previously submitted to CVPR2021.
# --------------------------------------------
@inproceedings{zhang2021designing,
title={Designing a Practical Degradation Model for Deep Blind Image Super-Resolution},
author={Zhang, Kai and Liang, Jingyun and Van Gool, Luc and Timofte, Radu},
booktitle={arxiv},
year={2021}
}
# --------------------------------------------
"""
def _load_model(model_path, sf, device):
    """Build the RRDBNet generator for scale ``sf`` and load its weights.

    Args:
        model_path: path of the ``.pth`` state dict to load (keys must match
            exactly, ``strict=True``).
        sf: upscaling factor passed to the network constructor.
        device: ``torch.device`` the checkpoint is mapped to and the model is
            moved to.

    Returns:
        The network in eval mode with every parameter frozen
        (``requires_grad = False``), placed on ``device``.
    """
    model = net(in_nc=3, out_nc=3, nf=64, nb=23, gc=32, sf=sf)  # define network
    # map_location lets a checkpoint saved from a GPU run load on a CPU-only
    # machine; without it torch.load tries to restore tensors onto CUDA.
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict, strict=True)
    model.eval()
    for param in model.parameters():
        param.requires_grad = False
    model = model.to(device)
    torch.cuda.empty_cache()
    return model


def main():
    """Run the blind super-resolution models over the configured test sets.

    For every name in ``model_names`` the weights are loaded from
    ``model_zoo/<model_name>.pth`` into an RRDBNet. Every image returned by
    ``util.get_image_paths('testsets/<testset>')`` is then upscaled. When
    ``save_results`` is true, each output is written to
    ``testsets/<testset>_results_x<sf>/<image>_<model_name>.png``.

    Progress is logged through the ``blind_sr_log`` logger, which is configured
    via ``utils_logger.logger_info`` with log file ``blind_sr_log.log``.
    Inference runs on CUDA when available and on the CPU otherwise.
    """
    utils_logger.logger_info('blind_sr_log', log_path='blind_sr_log.log')
    logger = logging.getLogger('blind_sr_log')

    # print(torch.__version__)               # pytorch version
    # print(torch.version.cuda)              # cuda version
    # print(torch.backends.cudnn.version())  # cudnn version

    testsets = 'testsets'        # fixed, set path of testsets
    testset_Ls = ['RealSRSet']   # ['RealSRSet','DPED']

    # Other model names previously listed here (weights are looked up in model_zoo/):
    # ['RRDB','ESRGAN','FSSR_DPED','FSSR_JPEG','RealSR_DPED','RealSR_JPEG']
    model_names = ['BSRGAN']     # 'BSRGANx2' for scale factor 2

    save_results = True
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    for model_name in model_names:
        # Chosen per model so that a model listed after 'BSRGANx2' is not
        # left running at scale 2.
        sf = 2 if model_name == 'BSRGANx2' else 4
        model_path = os.path.join('model_zoo', model_name + '.pth')  # set model path
        logger.info('{:>16s} : {:s}'.format('Model Name', model_name))

        if device.type == 'cuda':
            # torch.cuda.set_device(0)  # set GPU ID
            logger.info('{:>16s} : {:<d}'.format('GPU ID', torch.cuda.current_device()))
        else:
            # torch.cuda.current_device() raises when CUDA is unavailable.
            logger.info('{:>16s} : {:s}'.format('Device', str(device)))
        torch.cuda.empty_cache()

        # --------------------------------
        # define network and load model
        # --------------------------------
        model = _load_model(model_path, sf, device)

        for testset_L in testset_Ls:
            L_path = os.path.join(testsets, testset_L)
            # E_path = os.path.join(testsets, testset_L+'_'+model_name)
            E_path = os.path.join(testsets, testset_L + '_results_x' + str(sf))
            util.mkdir(E_path)

            logger.info('{:>16s} : {:s}'.format('Input Path', L_path))
            logger.info('{:>16s} : {:s}'.format('Output Path', E_path))

            for idx, img in enumerate(util.get_image_paths(L_path), start=1):
                # --------------------------------
                # (1) img_L: read the low-resolution input (3 channels)
                # --------------------------------
                img_name, ext = os.path.splitext(os.path.basename(img))
                logger.info('{:->4d} --> {:<s} --> x{:<d}--> {:<s}'.format(idx, model_name, sf, img_name + ext))

                img_L = util.imread_uint(img, n_channels=3)
                # presumably a batched float tensor (1, C, H, W) -- confirm in utils_image
                img_L = util.uint2tensor4(img_L).to(device)

                # --------------------------------
                # (2) inference (no autograd bookkeeping needed)
                # --------------------------------
                with torch.no_grad():
                    img_E = model(img_L)

                # --------------------------------
                # (3) img_E: convert back to an image array and save
                # --------------------------------
                img_E = util.tensor2uint(img_E)
                if save_results:
                    util.imsave(img_E, os.path.join(E_path, img_name + '_' + model_name + '.png'))
if __name__ == '__main__':
    # Script entry point: run the test driver only when executed directly,
    # so importing this module performs no inference.
    main()