-
Notifications
You must be signed in to change notification settings - Fork 1
/
test.py
49 lines (37 loc) · 1.24 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thu Apr 16 15:23:34 2020
"""
import torch
import torchvision.transforms as transforms
from networks import Generator_nu
from PIL import Image
import os
import numpy as np
import sys
# Batch inference script: runs every image in an input directory through the
# generator and writes the result, under the same file name, to an output
# directory.
#
# Usage: python test.py <img_dir> <save_dir>
#   e.g. python test.py data/img/ data/pred/
if len(sys.argv) < 3:
    sys.exit('usage: {} <img_dir> <save_dir>'.format(sys.argv[0]))
img_path = sys.argv[1]
save_path = sys.argv[2]

# Create the output directory up front so the first save cannot fail on a
# missing folder.
os.makedirs(save_path, exist_ok=True)

# Use the GPU when one is present; otherwise fall back to the CPU so the
# script still runs on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

net_g = Generator_nu()
# map_location lets a checkpoint saved from a GPU be loaded on a CPU-only host.
net_g.load_state_dict(torch.load('latest_G.pth', map_location=device))
net_g.eval()
net_g.to(device)

# ToTensor scales pixels to [0, 1]; Normalize with mean 0.5 / std 0.5 then maps
# each channel to [-1, 1].
transform_list = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Only regular files are processed; sub-directories would make Image.open fail.
img_list = [
    name for name in os.listdir(img_path)
    if os.path.isfile(os.path.join(img_path, name))
]
total = len(img_list)

for idx, img_name in enumerate(img_list, start=1):
    # os.path.join works whether or not the directory arguments end in a
    # separator (plain string concatenation did not).
    with Image.open(os.path.join(img_path, img_name)) as src:
        rgb = src.convert('RGB')
    batch = transform_list(rgb).unsqueeze(0).to(device)

    with torch.no_grad():
        pred = net_g(batch)

    # Drop the batch dimension and move the result to host memory as float32.
    out_img = pred.detach().squeeze(0).cpu().float().numpy()

    # Channels-first -> channels-last, then undo the [-1, 1] normalisation
    # applied on input to get back to the 0..255 range.
    image_numpy = (np.transpose(out_img, (1, 2, 0)) + 1) / 2.0 * 255.0
    image_numpy = image_numpy.clip(0, 255).astype(np.uint8)

    # Hand PIL a C-contiguous buffer (the transpose above yields a strided
    # view layout).
    result = np.ascontiguousarray(image_numpy)
    Image.fromarray(result).save(os.path.join(save_path, img_name))

    print('{}/{} done'.format(idx, total))