'''
Source code for an attention-based image caption generation system described in:
Show, Attend and Tell: Neural Image Caption Generation with Visual Attention
International Conference on Machine Learning (ICML), 2015
http://arxiv.org/abs/1502.03044
'''
import torch
import torch.nn as nn
import torchvision.models as models
from torch.autograd import Variable  # Variable is deprecated since PyTorch 0.4; plain tensors now fill the same role

MAX_LENGTH = 10

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def cuda_variable(tensor):
    # Wrap a tensor in a Variable, placing it on the GPU when one is available.
    if torch.cuda.is_available():
        return Variable(tensor.cuda())
    else:
        return Variable(tensor)


class EncoderCNN(nn.Module):
    """CNN encoder: a pretrained ResNet-101 that yields a grid of feature vectors per image."""

    def __init__(self, encoded_image_size=14):
        super(EncoderCNN, self).__init__()
        resnet = models.resnet101(pretrained=True)  # ImageNet-pretrained backbone (newer torchvision uses weights=)
        # Drop the final average-pooling and fully-connected layers; keep the convolutional feature maps.
        modules = list(resnet.children())[:-2]
        self.resnet = nn.Sequential(*modules)
        # Resize the feature map to a fixed spatial size, independent of the input image size.
        self.adaptive_pool = nn.AdaptiveAvgPool2d((encoded_image_size, encoded_image_size))
        self.fine_tune()

    def forward(self, images):
        out = self.resnet(images)         # (batch_size, 2048, H/32, W/32)
        out = self.adaptive_pool(out)     # (batch_size, 2048, enc_size, enc_size)
        out = out.permute(0, 2, 3, 1)     # (batch_size, enc_size, enc_size, 2048)
        return out

    def fine_tune(self, fine_tune=True):
        # Freeze the whole backbone, then optionally unfreeze convolutional blocks layer2 through layer4.
        for p in self.resnet.parameters():
            p.requires_grad = False
        for c in list(self.resnet.children())[5:]:
            for p in c.parameters():
                p.requires_grad = fine_tune
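
# Illustrative usage sketch (not part of the original module); the 224x224 input size is assumed:
#   encoder = EncoderCNN(encoded_image_size=14)
#   features = encoder(torch.randn(4, 3, 224, 224))   # -> (4, 14, 14, 2048)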


class Attention(nn.Module):
    """Soft attention over the encoder's pixel features, conditioned on the decoder's hidden state."""

    def __init__(self, encoder_dim, decoder_dim, attention_dim):
        super(Attention, self).__init__()
        self.encoder_att = nn.Linear(encoder_dim, attention_dim)  # project encoder features
        self.decoder_att = nn.Linear(decoder_dim, attention_dim)  # project decoder hidden state
        self.full_att = nn.Linear(attention_dim, 1)               # one attention score per pixel
        self.relu = nn.ReLU()
        self.softmax = nn.Softmax(dim=1)

    def forward(self, encoder_out, decoder_hidden):
        att1 = self.encoder_att(encoder_out)     # (batch_size, num_pixels, attention_dim)
        att2 = self.decoder_att(decoder_hidden)  # (batch_size, attention_dim)
        att = self.full_att(self.relu(att1 + att2.unsqueeze(1))).squeeze(2)  # (batch_size, num_pixels)
        alpha = self.softmax(att)                # attention weights over pixels
        attention_weighted_encoding = (encoder_out * alpha.unsqueeze(2)).sum(dim=1)  # (batch_size, encoder_dim)
        return attention_weighted_encoding, alpha
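
# Illustrative usage sketch (not part of the original module); dimensions are assumed:
#   attn = Attention(encoder_dim=2048, decoder_dim=512, attention_dim=512)
#   context, alpha = attn(torch.randn(4, 196, 2048), torch.randn(4, 512))
#   # context: (4, 2048) attention-weighted encoding; alpha: (4, 196) weights summing to 1 per image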


class AttnDecoderRNN(nn.Module):
    """LSTM decoder with attention that generates a caption one word at a time."""

    def __init__(self, attention_dim, embed_dim, decoder_dim, vocab_size, encoder_dim=2048, dropout=0.5):
        super(AttnDecoderRNN, self).__init__()
        self.encoder_dim = encoder_dim
        self.attention_dim = attention_dim
        self.embed_dim = embed_dim
        self.decoder_dim = decoder_dim
        self.vocab_size = vocab_size
        self.dropout = dropout

        self.attention = Attention(encoder_dim, decoder_dim, attention_dim)
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.dropout = nn.Dropout(p=self.dropout)
        self.decode_step = nn.LSTMCell(embed_dim + encoder_dim, decoder_dim, bias=True)
        self.init_h = nn.Linear(encoder_dim, decoder_dim)  # initial hidden state from the mean encoder output
        self.init_c = nn.Linear(encoder_dim, decoder_dim)  # initial cell state from the mean encoder output
        self.f_beta = nn.Linear(decoder_dim, encoder_dim)  # linear layer to create a sigmoid-activated gate
        self.sigmoid = nn.Sigmoid()
        self.fc = nn.Linear(decoder_dim, vocab_size)       # scores over the vocabulary
        self.init_weights()

    def init_weights(self):
        self.embedding.weight.data.uniform_(-0.1, 0.1)
        self.fc.bias.data.fill_(0)
        self.fc.weight.data.uniform_(-0.1, 0.1)

    def load_pretrained_embeddings(self, embeddings):
        self.embedding.weight = nn.Parameter(embeddings)

    def fine_tune_embeddings(self, fine_tune=True):
        for p in self.embedding.parameters():
            p.requires_grad = fine_tune

    def init_hidden_state(self, encoder_out):
        mean_encoder_out = encoder_out.mean(dim=1)
        h = self.init_h(mean_encoder_out)
        c = self.init_c(mean_encoder_out)
        return h, c

    def forward(self, encoder_out, encoded_captions, caption_lengths):
        """
        Captions are assumed to be sorted by decreasing length, so at each time step only
        the still-active captions (the first batch_size_t rows) are decoded.

        :return: scores for vocabulary, encoded captions, decode lengths, attention weights
        """
        batch_size = encoder_out.size(0)
        encoder_dim = encoder_out.size(-1)
        vocab_size = self.vocab_size

        # Flatten the spatial grid: (batch_size, enc_size, enc_size, encoder_dim) -> (batch_size, num_pixels, encoder_dim)
        encoder_out = encoder_out.view(batch_size, -1, encoder_dim)
        num_pixels = encoder_out.size(1)

        embeddings = self.embedding(encoded_captions)  # (batch_size, max_caption_length, embed_dim)
        h, c = self.init_hidden_state(encoder_out)

        # The <end> token is not decoded, hence length - 1 decoding steps per caption.
        decode_lengths = [length - 1 for length in caption_lengths]

        predictions = torch.zeros(batch_size, max(decode_lengths), vocab_size).to(device)
        alphas = torch.zeros(batch_size, max(decode_lengths), num_pixels).to(device)

        for t in range(max(decode_lengths)):
            batch_size_t = sum([l > t for l in decode_lengths])
            attention_weighted_encoding, alpha = self.attention(encoder_out[:batch_size_t], h[:batch_size_t])
            gate = self.sigmoid(self.f_beta(h[:batch_size_t]))  # per-dimension gate on the context vector
            attention_weighted_encoding = gate * attention_weighted_encoding
            h, c = self.decode_step(
                torch.cat([embeddings[:batch_size_t, t, :], attention_weighted_encoding], dim=1),
                (h[:batch_size_t], c[:batch_size_t]))
            preds = self.fc(self.dropout(h))  # (batch_size_t, vocab_size)
            predictions[:batch_size_t, t, :] = preds
            alphas[:batch_size_t, t, :] = alpha

        return predictions, encoded_captions, decode_lengths, alphas
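

# ---------------------------------------------------------------------------
# Minimal smoke test (illustrative sketch only): the vocabulary size, hyper-
# parameters, and dummy tensors below are assumptions made for demonstration
# and are not part of the original training pipeline.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    vocab_size = 1000  # hypothetical vocabulary size
    decoder = AttnDecoderRNN(attention_dim=512, embed_dim=512,
                             decoder_dim=512, vocab_size=vocab_size).to(device)

    # Fake encoder output with the shape EncoderCNN produces: (batch_size, 14, 14, 2048).
    encoder_out = torch.randn(4, 14, 14, 2048).to(device)
    # Dummy captions, already sorted by decreasing length, and their lengths.
    captions = torch.randint(0, vocab_size, (4, 10)).to(device)
    caption_lengths = [10, 9, 8, 7]

    predictions, _, decode_lengths, alphas = decoder(encoder_out, captions, caption_lengths)
    print(predictions.shape)  # (4, max(decode_lengths), vocab_size)
    print(alphas.shape)       # (4, max(decode_lengths), 14 * 14)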