-
Notifications
You must be signed in to change notification settings - Fork 7
/
test.py
executable file
·48 lines (38 loc) · 1.47 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import time, os
import numpy as np
import torch
from data.dataset import save_image
from functions import *
from loss import *
def test(data_loader, opt, model, epoch, path, log):
    """Evaluate ``model`` on ``data_loader``, save outputs and log timing/QNR.

    For every batch the model is called as ``model(pan, lr_u, lr)``.  Its
    output goes through ``trim_image`` and is scored with the project helpers
    ``D_lambda`` and ``D_s``, which are combined as
    ``QNR = (1 - D_lambda) * (1 - D_s)``.  Each output sample is written to
    ``<path>/<filename>_mul_hat.tif`` via ``save_image``.

    Args:
        data_loader: iterable of batches.  Elements 0, 2, 3, 5 and 6 are read
            as ``input_pan``, ``input_lr``, ``input_lr_u``, ``filename`` and
            ``input_pan_l``.  ``filename`` must support ``.size()`` and
            integer ``%d`` formatting per element (presumably an integer id
            tensor; confirm against the dataset).
        opt: options object; only ``opt.cuda`` is read.
        model: network to evaluate; it is switched to eval mode and left there.
        epoch: unused; kept for interface compatibility with callers.
        path: output directory, created if it does not exist.
        log: object with a ``write`` method that receives the start/end
            timestamps, the accumulated forward-pass time and the mean QNR.
    """
    model.eval()
    # exist_ok avoids the race between an exists() check and makedirs().
    os.makedirs(path, exist_ok=True)
    log.write("start on:{}\n".format(time.strftime("%Y-%m-%d::%H:%M")))
    timecost = 0
    qnr_values = []
    # Inference only: no_grad stops autograd from recording the forward pass,
    # which would otherwise keep activations alive for every batch.
    with torch.no_grad():
        for batch in data_loader:
            input_pan = batch[0]
            input_lr = batch[2]
            input_lr_u = batch[3]
            filename = batch[5]
            input_pan_l = batch[6]
            if opt.cuda:
                input_pan = input_pan.cuda()
                input_lr = input_lr.cuda()
                input_lr_u = input_lr_u.cuda()
                input_pan_l = input_pan_l.cuda()
                # CUDA work is asynchronous: drain anything pending so the
                # timer below covers only the model's forward pass.
                torch.cuda.synchronize()
            start_time = time.perf_counter()
            output = model(input_pan, input_lr_u, input_lr)
            if opt.cuda:
                # Without this only kernel *launch* time would be measured.
                torch.cuda.synchronize()
            timecost += time.perf_counter() - start_time
            output = trim_image(output)
            D_lambda_val = D_lambda(output, input_lr)
            D_s_val = D_s(output, input_lr, input_pan, input_pan_l)
            QNR_val = (1 - D_lambda_val) * (1 - D_s_val)
            # Flatten so scalar-per-batch and per-sample results (including a
            # smaller final batch) can all be concatenated for the mean.
            qnr_values.append(np.ravel(QNR_val.cpu().detach().numpy()))
            for i in range(filename.size()[0]):
                save_image('%s/%d_mul_hat.tif' % (path, filename[i]), output[i].cpu().detach().numpy(), 4)
    log.write("Time Cost: {}s\n".format(timecost))
    # np.mean([]) warns and yields nan; report nan explicitly for an empty loader.
    mean_qnr = np.mean(np.concatenate(qnr_values)) if qnr_values else float("nan")
    log.write("QNR:{}\n".format(mean_qnr))
    log.write("end on:{}\n".format(time.strftime("%Y-%m-%d::%H:%M")))