modeling.py (forked from WasifurRahman/BERT_multimodal_transformer)
import torch
import torch.nn as nn
import torch.nn.functional as F

from global_configs import *  # provides TEXT_DIM, VISUAL_DIM, ACOUSTIC_DIM, DEVICE


class MAG(nn.Module):
    """Multimodal Adaptation Gate: shifts a text embedding by a gated
    combination of visual and acoustic features."""

    def __init__(self, hidden_size, beta_shift, dropout_prob):
        super(MAG, self).__init__()
        print(
            "Initializing MAG with beta_shift:{} dropout_prob:{}".format(
                beta_shift, dropout_prob
            )
        )

        # Gating layers: each conditions a nonverbal modality on the text.
        self.W_hv = nn.Linear(VISUAL_DIM + TEXT_DIM, TEXT_DIM)
        self.W_ha = nn.Linear(ACOUSTIC_DIM + TEXT_DIM, TEXT_DIM)

        # Projections from each nonverbal modality into the text space.
        self.W_v = nn.Linear(VISUAL_DIM, TEXT_DIM)
        self.W_a = nn.Linear(ACOUSTIC_DIM, TEXT_DIM)

        self.beta_shift = beta_shift
        self.LayerNorm = nn.LayerNorm(hidden_size)
        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, text_embedding, visual, acoustic):
        eps = 1e-6

        # Gates computed from each nonverbal modality concatenated with the text.
        weight_v = F.relu(self.W_hv(torch.cat((visual, text_embedding), dim=-1)))
        weight_a = F.relu(self.W_ha(torch.cat((acoustic, text_embedding), dim=-1)))

        # Nonverbal displacement vector h_m.
        h_m = weight_v * self.W_v(visual) + weight_a * self.W_a(acoustic)

        # Cap the displacement: its norm may not exceed beta_shift times
        # the norm of the text embedding.
        em_norm = text_embedding.norm(2, dim=-1)
        hm_norm = h_m.norm(2, dim=-1)

        # Guard against division by zero when h_m is the zero vector.
        hm_norm_ones = torch.ones(hm_norm.shape, requires_grad=True).to(DEVICE)
        hm_norm = torch.where(hm_norm == 0, hm_norm_ones, hm_norm)

        thresh_hold = (em_norm / (hm_norm + eps)) * self.beta_shift

        ones = torch.ones(thresh_hold.shape, requires_grad=True).to(DEVICE)
        alpha = torch.min(thresh_hold, ones)
        alpha = alpha.unsqueeze(dim=-1)

        # Shifted embedding, followed by LayerNorm and dropout as in BERT.
        acoustic_vis_embedding = alpha * h_m
        embedding_output = self.dropout(
            self.LayerNorm(acoustic_vis_embedding + text_embedding)
        )

        return embedding_output
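A minimal usage sketch follows. The dimension constants live in global_configs.py, which is not shown here; the values below (TEXT_DIM = 768, VISUAL_DIM = 47, ACOUSTIC_DIM = 74, DEVICE = CPU) are illustrative assumptions, as are the beta_shift and dropout_prob settings.

# Usage sketch, not part of modeling.py. Assumes global_configs provides:
#   TEXT_DIM = 768, VISUAL_DIM = 47, ACOUSTIC_DIM = 74   (illustrative values)
#   DEVICE = torch.device("cpu")
import torch
from modeling import MAG

batch_size, seq_len = 2, 16
text = torch.randn(batch_size, seq_len, 768)      # per-token text embeddings
visual = torch.randn(batch_size, seq_len, 47)     # per-token visual features
acoustic = torch.randn(batch_size, seq_len, 74)   # per-token acoustic features

mag = MAG(hidden_size=768, beta_shift=1.0, dropout_prob=0.5)
fused = mag(text, visual, acoustic)
print(fused.shape)  # torch.Size([2, 16, 768]) -- same shape as the text input

The output keeps the text embedding's shape, so the module can sit between two layers of a pretrained transformer and inject the nonverbal information without changing the rest of the network.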