-
Notifications
You must be signed in to change notification settings - Fork 0
/
test.py
36 lines (25 loc) · 848 Bytes
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import inspect

import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import numpy as np
import torch
from torchvision import transforms
# Global plotting defaults: a 10x8 inch figure, no pixel smoothing when
# images are drawn, and a grayscale colormap for the 2-D images shown below.
plt.rcParams.update({
    'figure.figsize': (10.0, 8.0),
    'image.interpolation': 'nearest',
    'image.cmap': 'gray',
})
def show_images(images, image_shape=(28, 28)):
    """Draw a batch of images as a square grid of subplots.

    The grid has ``ceil(sqrt(N))`` subplots per side, where ``N`` is
    ``images.shape[0]``. Subplots are added to the current matplotlib figure;
    call ``plt.show()`` afterwards to display them.

    Args:
        images: Array-like with a ``shape`` attribute whose first axis indexes
            the images. Each element is reshaped to ``image_shape``, so it
            must contain exactly that many values.
        image_shape: Shape each image is reshaped to before plotting.
            Defaults to ``(28, 28)``, the previously hard-coded value.
    """
    n_images = images.shape[0]
    grid_side = int(np.ceil(np.sqrt(n_images)))
    for index, image in enumerate(images):
        ax = plt.subplot(grid_side, grid_side, index + 1)
        # Tick labels only add clutter to a grid of small sample images.
        ax.axis('off')
        plt.imshow(image.reshape(image_shape))
# Run on the first GPU when one is available, otherwise on the CPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
print('GPU State:', device)

# Model
# The checkpoint is a whole pickled nn.Module (it is called directly below),
# not a state_dict. map_location places its tensors on `device`, so a model
# saved on a GPU can still be loaded on a CPU-only machine.
_load_kwargs = {'map_location': device}
if 'weights_only' in inspect.signature(torch.load).parameters:
    # Recent PyTorch releases default to weights_only=True, which refuses to
    # unpickle a full nn.Module. Unpickling can run arbitrary code, so only
    # load checkpoints that come from a trusted source.
    _load_kwargs['weights_only'] = False
G = torch.load('Generator_epoch_100.pth', **_load_kwargs)
G = G.to(device)  # model and input must live on the same device
G.eval()

# Generator
# 16 latent vectors of length 128, drawn uniformly from [-1, 1).
noise = (torch.rand(16, 128) - 0.5) / 0.5
noise = noise.to(device)
with torch.no_grad():  # inference only: no autograd graph is needed
    fake_image = G(noise)
# Rescale generator output from [-1, 1] to [0, 1] for display
# (assumes the generator ends in tanh — TODO confirm).
imgs_numpy = (fake_image.cpu().numpy() + 1.0) / 2.0
show_images(imgs_numpy)
plt.show()