boq.py
import torch


class BoQBlock(torch.nn.Module):
    def __init__(self, in_dim, num_queries, nheads=8):
        super(BoQBlock, self).__init__()

        self.encoder = torch.nn.TransformerEncoderLayer(d_model=in_dim, nhead=nheads, dim_feedforward=4*in_dim, batch_first=True, dropout=0.)
        self.queries = torch.nn.Parameter(torch.randn(1, num_queries, in_dim))

        # the following two layers are used during training only; their output can be cached in eval.
        self.self_attn = torch.nn.MultiheadAttention(in_dim, num_heads=nheads, batch_first=True)
        self.norm_q = torch.nn.LayerNorm(in_dim)

        self.cross_attn = torch.nn.MultiheadAttention(in_dim, num_heads=nheads, batch_first=True)
        self.norm_out = torch.nn.LayerNorm(in_dim)

    def forward(self, x):
        B = x.size(0)
        x = self.encoder(x)  # refine the input tokens, shape (B, N, in_dim)

        q = self.queries.repeat(B, 1, 1)

        # self-attention over the learned queries; used during training for stability.
        q = q + self.self_attn(q, q, q)[0]
        q = self.norm_q(q)

        # cross-attention: the queries aggregate information from the input tokens.
        out, attn = self.cross_attn(q, x, x)
        out = self.norm_out(out)
        return x, out, attn.detach()
class BoQ(torch.nn.Module):
    def __init__(self, in_channels=1024, proj_channels=512, num_queries=32, num_layers=2, row_dim=32):
        super().__init__()
        self.proj_c = torch.nn.Conv2d(in_channels, proj_channels, kernel_size=3, padding=1)
        self.norm_input = torch.nn.LayerNorm(proj_channels)

        in_dim = proj_channels
        self.boqs = torch.nn.ModuleList([
            BoQBlock(in_dim, num_queries, nheads=in_dim//64) for _ in range(num_layers)])

        self.fc = torch.nn.Linear(num_layers*num_queries, row_dim)

    def forward(self, x):
        # reduce the input channel dimension using a 3x3 conv (e.g. when using a ResNet backbone)
        x = self.proj_c(x)
        x = x.flatten(2).permute(0, 2, 1)  # (B, C, H, W) -> (B, H*W, C)
        x = self.norm_input(x)

        outs = []
        attns = []
        for boq in self.boqs:
            x, out, attn = boq(x)
            outs.append(out)
            attns.append(attn)

        out = torch.cat(outs, dim=1)          # concatenate the query outputs of all BoQ blocks
        out = self.fc(out.permute(0, 2, 1))   # mix the queries across blocks
        out = out.flatten(1)
        out = torch.nn.functional.normalize(out, p=2, dim=-1)  # L2-normalized global descriptor
        return out, attns
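
A minimal usage sketch (not part of the original boq.py): it assumes the input to BoQ is a CNN feature map of shape (B, 1024, H, W), e.g. from a ResNet backbone, and shows the shapes of the returned global descriptor and per-block attention maps.

# --- usage sketch; the 20x20 spatial size and the dummy input are assumptions ---
model = BoQ(in_channels=1024, proj_channels=512, num_queries=32, num_layers=2, row_dim=32)
feats = torch.randn(2, 1024, 20, 20)   # dummy backbone feature map (B, C, H, W)
desc, attns = model(feats)
print(desc.shape)                      # torch.Size([2, 16384]) -> proj_channels * row_dim, L2-normalized
print(len(attns), attns[0].shape)      # 2 torch.Size([2, 32, 400]) -> one (B, num_queries, H*W) map per BoQ block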