main.py (forked from GunhoChoi/FusionNet-Pytorch)
from FusionNet import *
import numpy as np
import matplotlib.pyplot as plt
# explicit torch/torchvision imports so this script does not rely on names
# re-exported by FusionNet's star import
import torch
import torch.nn as nn
import torch.utils.data as data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as v_utils
from torch.autograd import Variable
# hyperparameters
batch_size = 16
img_size = 512
lr = 0.0002
epoch = 1000
# input pipeline
img_dir = "./merged/"
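# Assumed data layout (pix2pix-style): each file under ./merged/ is a satellite photo and its
# map rendering concatenated side by side (satellite on the left, going by the variable names
# in the training loop), so each crop can later be split back into the two halves along width.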
img_data = dset.ImageFolder(root=img_dir, transform=transforms.Compose([
    transforms.Scale(size=img_size),  # Scale is the older name for transforms.Resize
    transforms.CenterCrop(size=(img_size, img_size * 2)),  # (h, w): keep the full side-by-side pair
    transforms.ToTensor(),
]))
img_batch = data.DataLoader(img_data, batch_size=batch_size,
                            shuffle=True, num_workers=2)
# initialize FusionNet
fusion = nn.DataParallel(FusionGenerator(3,3,16)).cuda()
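# FusionGenerator(3, 3, 16): the three arguments are presumably the RGB input channels, RGB
# output channels, and the base feature-map count defined in FusionNet.py; check that file for
# the exact signature. nn.DataParallel splits each batch across all visible GPUs.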
try:
    fusion = torch.load('./model/fusion.pkl')
    print("\n--------model restored--------\n")
except:
    print("\n--------model not restored--------\n")
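# If the checkpoint exists, torch.load returns the entire pickled module saved by
# torch.save(fusion, ...) at the bottom of the training loop, replacing the freshly built
# model; the bare except silently falls back to training from scratch on any failure.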
# loss function & optimizer
loss_func = nn.SmoothL1Loss()
optimizer = torch.optim.Adam(fusion.parameters(),lr=lr)
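# SmoothL1 (Huber) loss treats map generation as a per-pixel regression and is less sensitive
# to outlier pixels than plain L2; Adam uses the learning rate defined above (lr = 0.0002).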
# training
for i in range(epoch):
    for batch_idx, (image, label) in enumerate(img_batch):  # label (ImageFolder class index) is unused
        # split the side-by-side pair into its satellite half and map half
        satel_image, map_image = torch.chunk(image, chunks=2, dim=3)

        # quarter size image: cut each half into four tiles (2 x 2 grid)
        satel_left, satel_right = torch.chunk(satel_image, chunks=2, dim=3)
        satel_0, satel_1 = torch.chunk(satel_left, chunks=2, dim=2)
        satel_2, satel_3 = torch.chunk(satel_right, chunks=2, dim=2)
        map_left, map_right = torch.chunk(map_image, chunks=2, dim=3)
        map_0, map_1 = torch.chunk(map_left, chunks=2, dim=2)
        map_2, map_3 = torch.chunk(map_right, chunks=2, dim=2)

        h, w = satel_0.size()[2:]  # tile height/width (not used below)

        satel_list = [satel_0, satel_1, satel_2, satel_3]
        map_list = [map_0, map_1, map_2, map_3]
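        # With img_size = 512, each half is 512x512 and each tile 256x256; processing the four
        # tiles separately presumably reduces per-step GPU memory versus feeding the full crop.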
        for idx in range(4):
            optimizer.zero_grad()
            x = Variable(satel_list[idx]).cuda()
            y_ = Variable(map_list[idx]).cuda()
            y = fusion(x)  # call the module directly instead of fusion.forward(x)
            loss = loss_func(y, y_)
            loss.backward()
            optimizer.step()

        if batch_idx % 500 == 0:  # log/save using the last tile of this batch
            print(i)
            print(loss)
            #print(y.size())
            v_utils.save_image(x[0].cpu().data, "./result/original_image_{}_{}.png".format(i, batch_idx))
            v_utils.save_image(y_[0].cpu().data, "./result/label_image_{}_{}.png".format(i, batch_idx))
            v_utils.save_image(y[0].cpu().data, "./result/gen_image_{}_{}.png".format(i, batch_idx))
            torch.save(fusion, "./model/fusion.pkl")
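# Note: torch.save(fusion, ...) above pickles the whole DataParallel-wrapped module, which is
# what the torch.load('./model/fusion.pkl') call near the top expects on restart. Saving
# fusion.state_dict() instead would be more portable, but would also require changing the
# restore block accordingly.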