# run_gmvae.py
#
# Train (or load and evaluate) a GMVAE -- a VAE with a mixture-of-Gaussians
# prior -- on MNIST, then display a grid of samples drawn from the model.
import argparse
import numpy as np
import torch
import tqdm
from codebase import utils as ut
from codebase.models.gmvae import GMVAE
from codebase.train import train
from pprint import pprint
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
# Command-line configuration. Note that --train is an integer *mode*, not a
# boolean: 0 loads the model saved at --iter_max and evaluates it, any
# non-zero value trains from scratch, and 2 additionally computes the IWAE
# bound during the post-training evaluation.
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--z', type=int, default=10, help="Number of latent dimensions")
parser.add_argument('--k', type=int, default=500, help="Number of mixture components in MoG prior")
parser.add_argument('--iter_max', type=int, default=20000, help="Number of training iterations")
parser.add_argument('--iter_save', type=int, default=10000, help="Save model every n iterations")
parser.add_argument('--run', type=int, default=0, help="Run ID. In case you want to run replicates")
parser.add_argument('--train', type=int, default=1,
                    help="Run mode: 0 = load the model saved at --iter_max and evaluate it "
                         "(including IWAE); 1 = train, then evaluate; "
                         "2 = train, then evaluate including IWAE")
args = parser.parse_args()
# Each (format template, value) pair becomes one "key=value" component of the
# run name; with the default arguments this yields
# "model=gmvae_z=10_k=500_run=0000".
layout = [
    ('model={:s}', 'gmvae'),
    ('z={:02d}', args.z),
    ('k={:03d}', args.k),
    ('run={:04d}', args.run),
]
name_parts = (template.format(value) for template, value in layout)
model_name = '_'.join(name_parts)

# Echo the parsed configuration and the derived name before any heavy work.
pprint(vars(args))
print('Model name:', model_name)
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# get_mnist_data returns three values; this script only needs the training
# loader and the labeled subset used for evaluation.
train_loader, labeled_subset, _ = ut.get_mnist_data(device, use_test_subset=True)

# Build the model from the CLI hyperparameters and move it onto `device`.
gmvae = GMVAE(z_dim=args.z, k=args.k, name=model_name)
gmvae = gmvae.to(device)
def _show_sample_grid(model, n_rows=20, n_cols=10, image_shape=(28, 28)):
    """Draw ``n_rows * n_cols`` samples from *model* and show them as a grid.

    Args:
        model: object exposing ``sample_x(n)``. Its result is reshaped to
            ``(n_rows, n_cols, *image_shape)``, so it must contain exactly
            ``n_rows * n_cols * prod(image_shape)`` values (28x28 by default,
            matching MNIST).
        n_rows: number of grid rows.
        n_cols: number of grid columns.
        image_shape: height and width of a single sample image.
    """
    # Sampling is inference only; no autograd graph is needed.
    with torch.no_grad():
        samples = model.sample_x(n_rows * n_cols)
    grid = samples.detach().view(n_rows, n_cols, *image_shape).cpu().numpy()

    # squeeze=False keeps `axes` 2-D even when n_rows or n_cols is 1.
    _, axes = plt.subplots(n_rows, n_cols, squeeze=False)
    # Iterate over *every* row: the grid has n_rows rows, not n_cols.
    for row in range(n_rows):
        for col in range(n_cols):
            ax = axes[row, col]
            ax.imshow(grid[row, col], cmap='gray')
            ax.set_xticks([])
            ax.set_yticks([])
    plt.show()


if args.train:
    # Training mode. Presumably a logging writer for this run; existing output
    # under the same model name is overwritten.
    writer = ut.prepare_writer(model_name, overwrite_existing=True)
    train(model=gmvae,
          train_loader=train_loader,
          labeled_subset=labeled_subset,
          device=device,
          tqdm=tqdm.tqdm,
          writer=writer,
          iter_max=args.iter_max,
          iter_save=args.iter_save)
    # The (more expensive) IWAE bound is only computed when --train == 2.
    ut.evaluate_lower_bound(gmvae, labeled_subset, run_iwae=args.train == 2)
else:
    # Evaluation-only mode: load the saved model identified by its name and
    # global step == iter_max, then evaluate it, always including IWAE.
    ut.load_model_by_name(gmvae, global_step=args.iter_max)
    ut.evaluate_lower_bound(gmvae, labeled_subset, run_iwae=True)

# NOTE(review): the source's indentation was lost, so it is unclear whether the
# sample grid originally ran only after loading or in both modes; it is shown
# after either branch here -- confirm intended placement.
_show_sample_grid(gmvae)