forked from awesome-davian/Text2Colors
-
Notifications
You must be signed in to change notification settings - Fork 0
/
util.py
171 lines (139 loc) · 5.79 KB
/
util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
import numpy as np
import torch
import torch.nn as nn
import warnings
from skimage.color import lab2rgb, rgb2lab
# ======================== For text embeddings ======================== #
# Reserved vocabulary indices for the start-of-sequence and end-of-sequence
# markers; Dictionary.index2word maps these same indices to "SOS" and "EOS".
SOS_token = 0
EOS_token = 1
# Torch device used by the sampling helpers below: CUDA when available,
# otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
class Dictionary:
    """Incrementally built vocabulary mapping elements to integer indices.

    Indices 0 and 1 are reserved for the "SOS" and "EOS" markers, so the
    first element that gets indexed receives index 2.
    """

    def __init__(self):
        self.word2index = {}
        self.word2count = {}
        self.index2word = {0: "SOS", 1: "EOS"}
        # Next free index; also the vocabulary size including SOS/EOS.
        self.n_words = 2
        # Length of the longest non-empty sequence given to index_elements().
        self.max_len = 0

    def index_elements(self, data):
        """Index every element of ``data`` and record its length.

        Args:
            data: A sized iterable (e.g. a list of tokens). ``max_len`` is
                raised to ``len(data)`` when this sequence is the longest
                one seen so far.
        """
        # len(data) does not change while iterating, so compare it once
        # instead of once per element.
        if len(data) > self.max_len:
            self.max_len = len(data)
        for element in data:
            self.index_element(element)

    def index_element(self, element):
        """Register one element: assign a new index or bump its count."""
        if element in self.word2index:
            self.word2count[element] += 1
            return
        self.word2index[element] = self.n_words
        self.word2count[element] = 1
        self.index2word[self.n_words] = element
        self.n_words += 1
def load_pretrained_embedding(dictionary, embed_file, embed_dim):
    """Build an embedding matrix initialised from a pre-trained text file.

    Each line of ``embed_file`` is expected to hold a word followed by its
    vector components, separated by single spaces. Rows for vocabulary words
    found in the file are copied from it; all other rows keep their random
    normal initialisation.

    Args:
        dictionary: Mapping from word to row index in the returned matrix.
        embed_file: Path to the embedding text file, or ``None``.
        embed_dim: Number of components per embedding vector.

    Returns:
        A float32 array of shape ``(len(dictionary) + 2, embed_dim)``, or
        ``None`` when ``embed_file`` is ``None``.

    Raises:
        ValueError: If a vocabulary word's vector in the file does not have
            exactly ``embed_dim`` numeric components.
    """
    if embed_file is None:
        return None

    # Two more rows than there are words in the mapping; presumably these
    # are the reserved SOS/EOS slots -- confirm against the caller.
    vocab_size = len(dictionary) + 2
    W_emb = np.random.randn(vocab_size, embed_dim).astype('float32')

    # Fill rows while streaming the file so vectors of out-of-vocabulary
    # words are never kept in memory.
    initialised = set()
    with open(embed_file, 'r', encoding='utf-8') as f:
        for line in f:
            # Strip the newline so the last component parses cleanly; split
            # on a single space only, as words may contain other whitespace.
            tokens = line.rstrip().split(' ')
            word = tokens[0]
            if word == '<unk>' or word not in dictionary:
                continue
            vector = np.array([float(t) for t in tokens[1:]], dtype='float32')
            if vector.shape != (embed_dim,):
                raise ValueError(
                    "Embedding for %r has %d components, expected %d."
                    % (word, vector.shape[0], embed_dim)
                )
            W_emb[dictionary[word], :] = vector
            initialised.add(word)

    n = len(initialised)
    print ("%d/%d vocabs are initialized with GloVe embeddings." % (n, vocab_size))
    return W_emb
class Embed(nn.Module):
    """Word-embedding layer with optional pre-trained, optionally frozen weights.

    Args:
        vocab_size: Number of rows in the embedding table.
        embed_dim: Size of each embedding vector.
        W_emb: Pre-trained weights (torch tensor or NumPy array) used as the
            embedding table, or ``None`` to keep the default initialisation.
        train_emb: When falsy, the embedding weights are excluded from
            gradient updates.
    """

    def __init__(self, vocab_size, embed_dim, W_emb, train_emb):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim)

        if W_emb is not None:
            print ("Using pre-trained word embeddings...")
            # as_tensor leaves tensors untouched and wraps NumPy arrays, so
            # the array returned by the loader can be passed in directly.
            self.embed.weight = nn.Parameter(torch.as_tensor(W_emb))

        if not train_emb:
            print ("Not training word embeddings...")
            # The flag must be set on the weight Parameter; setting it on the
            # nn.Embedding module only adds an unused attribute and leaves
            # the weights trainable.
            self.embed.weight.requires_grad = False

    def forward(self, doc):
        """Look up the embedding vector for every index in ``doc``."""
        return self.embed(doc)
# ======================== For processing data ======================== #
def process_image(image_data, batch_size, imsize):
    """Split a batch of RGB images into a scaled L input and scaled ab labels.

    Args:
        image_data: 4-D tensor of images; axis 1 is moved to the end before
            the Lab conversion, so it is presumably (batch, channel, H, W)
            RGB data -- confirm against the data loader.
        batch_size: Number of images to convert.
        imsize: Height and width of the output tensors.

    Returns:
        Tuple ``(l_channel, ab_channels)`` of float tensors shaped
        ``(batch_size, 1, imsize, imsize)`` and
        ``(batch_size, 2, imsize, imsize)``.
    """
    l_channel = torch.zeros(batch_size, 1, imsize, imsize)
    ab_channels = torch.zeros(batch_size, 2, imsize, imsize)

    # rgb2lab wants the colour channels on the last axis.
    channels_last = image_data.numpy().transpose((0, 2, 3, 1))

    for idx in range(batch_size):
        lab = rgb2lab(channels_last[idx], illuminant='D50')

        # L is divided by 100; a and b are shifted and scaled with fixed
        # constants (presumably mapping their D50 ranges into [0, 1]).
        l_channel[idx] = torch.from_numpy(lab[np.newaxis, :, :, 0] / 100)
        a_scaled = (lab[:, :, 1] + 88) / 185
        b_scaled = (lab[:, :, 2] + 127) / 212
        ab_channels[idx] = torch.from_numpy(np.stack((a_scaled, b_scaled), axis=0))

    return l_channel, ab_channels
def process_palette_ab(pal_data, batch_size):
    """Flatten the scaled a/b channels of Lab palettes into a global-hint tensor.

    Args:
        pal_data: NumPy array indexed as ``[batch, colour, channel]`` whose
            last axis holds at least three entries; entries 1 and 2 are used
            as the a and b channels.
        batch_size: Number of palettes; must match the first axis.

    Returns:
        Float tensor of shape ``(batch_size, n_colours * 2, 1, 1)`` with the
        values of each palette laid out as a1, b1, a2, b2, ...
    """
    # Fixed shift/scale constants (presumably mapping the a/b ranges under
    # the D50 illuminant into [0, 1]).
    a_scaled = (pal_data[:, :, 1:2] + 88) / 185
    b_scaled = (pal_data[:, :, 2:3] + 127) / 212
    ab_scaled = np.concatenate((a_scaled, b_scaled), axis=2)

    # Derive the flattened width from the palette itself instead of assuming
    # five colours; batch_size is still validated by the view.
    n_features = ab_scaled.shape[1] * ab_scaled.shape[2]
    ab_for_global = torch.from_numpy(ab_scaled).float()
    ab_for_global = ab_for_global.view(batch_size, n_features)
    return ab_for_global.unsqueeze(2).unsqueeze(2)
def process_palette_lab(pal_data, batch_size):
    """Flatten the scaled L/a/b channels of Lab palettes into a global-hint tensor.

    Args:
        pal_data: NumPy array indexed as ``[batch, colour, channel]`` whose
            last axis holds at least three entries, used as L, a and b.
        batch_size: Number of palettes; must match the first axis.

    Returns:
        Float tensor of shape ``(batch_size, n_colours * 3, 1, 1)`` with the
        values of each palette laid out as L1, a1, b1, L2, a2, b2, ...
    """
    # L is divided by 100; a and b use fixed shift/scale constants
    # (presumably mapping their ranges under D50 into [0, 1]).
    l_scaled = pal_data[:, :, 0:1] / 100
    a_scaled = (pal_data[:, :, 1:2] + 88) / 185
    b_scaled = (pal_data[:, :, 2:3] + 127) / 212
    lab_scaled = np.concatenate((l_scaled, a_scaled, b_scaled), axis=2)

    # Derive the flattened width from the palette itself instead of assuming
    # five colours; batch_size is still validated by the view.
    n_features = lab_scaled.shape[1] * lab_scaled.shape[2]
    lab_for_global = torch.from_numpy(lab_scaled).float()
    lab_for_global = lab_for_global.view(batch_size, n_features)
    return lab_for_global.unsqueeze(2).unsqueeze(2)
def process_global_ab(input_ab, batch_size, always_give_global_hint):
    """Append a per-sample "hint given" flag to the ab palette hint.

    When ``always_give_global_hint`` is falsy, each sample keeps its palette
    with probability 0.5 (flag 1); otherwise its palette is replaced by
    uniform random values in [0, 1) and the flag is 0.

    Args:
        input_ab: Float tensor of shape ``(batch_size, C, 1, 1)``. It is not
            modified.
        batch_size: Number of samples in ``input_ab``.
        always_give_global_hint: If truthy, every palette is kept.

    Returns:
        Tensor of shape ``(batch_size, C + 1, 1, 1)``: the (possibly
        randomised) palette followed by the flag channel.
    """
    if always_give_global_hint:
        hint_flag = torch.ones(batch_size, 1, 1, 1)
        palette = input_ab
    else:
        hint_flag = torch.round(torch.rand(batch_size, 1, 1, 1))
        # Noise shaped like the input: assigning a flat torch.rand(10) to a
        # (10, 1, 1) slice cannot broadcast, and writing into the slice
        # would also clobber the caller's tensor.
        noise = torch.rand_like(input_ab)
        palette = torch.where(hint_flag > 0.5, input_ab, noise)

    return torch.cat([palette, hint_flag], 1)
def process_global_lab(input_lab, batch_size, always_give_global_hint):
    """Append a per-sample "hint given" flag to the Lab palette hint.

    When ``always_give_global_hint`` is falsy, each sample keeps its palette
    with probability 0.5 (flag 1); otherwise its palette is replaced by
    uniform random values in [0, 1) and the flag is 0.

    Args:
        input_lab: Float tensor of shape ``(batch_size, C, 1, 1)``. It is
            not modified.
        batch_size: Number of samples in ``input_lab``.
        always_give_global_hint: If truthy, every palette is kept.

    Returns:
        Tensor of shape ``(batch_size, C + 1, 1, 1)``: the (possibly
        randomised) palette followed by the flag channel.
    """
    if always_give_global_hint:
        hint_flag = torch.ones(batch_size, 1, 1, 1)
        palette = input_lab
    else:
        hint_flag = torch.round(torch.rand(batch_size, 1, 1, 1))
        # Noise shaped like the input: assigning a flat torch.rand(15) to a
        # (15, 1, 1) slice cannot broadcast, and writing into the slice
        # would also clobber the caller's tensor.
        noise = torch.rand_like(input_lab)
        palette = torch.where(hint_flag > 0.5, input_lab, noise)

    return torch.cat([palette, hint_flag], 1)
def process_global_sampling_ab(palette, batch_size, imsize, hist_mean, hist_std):
    """Build the sampling-time global hint from an ab palette.

    The palette is moved to the module-level ``device`` and a channel of
    ones is appended along dim 1, so every sample is flagged as having a
    hint. ``imsize``, ``hist_mean`` and ``hist_std`` are accepted but not
    used.

    Returns:
        Tensor on ``device`` holding the palette followed by the flag channel.
    """
    hint_flag = torch.ones(batch_size, 1, 1, 1, device=device)
    return torch.cat([palette.to(device), hint_flag], 1)
def process_global_sampling_lab(palette, batch_size, imsize, hist_mean, hist_std):
    """Build the sampling-time global hint from a Lab palette.

    The palette is moved to the module-level ``device`` and a channel of
    ones is appended along dim 1, so every sample is flagged as having a
    hint. ``imsize``, ``hist_mean`` and ``hist_std`` are accepted but not
    used.

    Returns:
        Tensor on ``device`` holding the palette followed by the flag channel.
    """
    hint_flag = torch.ones(batch_size, 1, 1, 1, device=device)
    return torch.cat([palette.to(device), hint_flag], 1)
# ============================= Etc. ============================= #
def KL_loss(mu, logvar):
    """Return ``-0.5 * mean(1 + logvar - mu^2 - exp(logvar))``.

    Args:
        mu: Tensor of means.
        logvar: Tensor of log-variances, broadcastable with ``mu``.

    Returns:
        Scalar tensor; neither input is modified.
    """
    squared_mean = mu.pow(2)
    variance = logvar.exp()
    per_element = 1 - (squared_mean + variance) + logvar
    return torch.mean(per_element) * -0.5
def lab2rgb_1d(in_lab, clip=True):
    """Convert a single Lab colour to RGB using the D50 illuminant.

    Args:
        in_lab: 1-D NumPy array holding one Lab colour.
        clip: When true, clamp the resulting RGB values to [0, 1].

    Returns:
        Flat NumPy array with the RGB values.
    """
    # Silence conversion warnings for this call only; a bare
    # warnings.filterwarnings("ignore") would suppress every warning in the
    # whole process from the first call onwards.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        # lab2rgb works on images, so wrap the colour as a 1x1 image.
        tmp_rgb = lab2rgb(in_lab[np.newaxis, np.newaxis, :], illuminant='D50').flatten()
    if clip:
        tmp_rgb = np.clip(tmp_rgb, 0, 1)
    return tmp_rgb
def init_weights_normal(m):
    """Re-initialise the weights of Conv1d / Linear layers from N(0, 0.05).

    Intended for use with ``nn.Module.apply``. Only modules whose type is
    exactly ``nn.Conv1d`` or ``nn.Linear`` are touched (subclasses are
    skipped), and biases are left unchanged.
    """
    if type(m) in (nn.Conv1d, nn.Linear):
        m.weight.data.normal_(mean=0.0, std=0.05)