# netvlad.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
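
# NetVLAD pooling layer in the "loupe" style. NetVLAD is the trainable VLAD
# aggregation of Arandjelovic et al., "NetVLAD: CNN architecture for weakly
# supervised place recognition" (CVPR 2016); the context-gating variant follows
# Miech et al., "Learnable pooling with Context Gating for video classification"
# (their TensorFlow "loupe" library), which this module appears to port to PyTorch.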
class NetVLADLoupe(nn.Module):
    """NetVLAD pooling with optional context gating.

    Aggregates `max_samples` local descriptors of dimension `feature_size`
    into a single `output_dim` global descriptor via soft assignment to
    `cluster_size` learned clusters.
    """

    def __init__(self, feature_size, max_samples, cluster_size, output_dim,
                 gating=True, add_batch_norm=True, is_training=True):
        super(NetVLADLoupe, self).__init__()
        self.feature_size = feature_size
        self.max_samples = max_samples
        self.output_dim = output_dim
        self.is_training = is_training  # stored for API compatibility; unused
        self.gating = gating
        self.add_batch_norm = add_batch_norm
        self.cluster_size = cluster_size
        self.softmax = nn.Softmax(dim=-1)
        # Soft-assignment weights (feature_size x cluster_size) and cluster
        # centers (1 x feature_size x cluster_size), initialized with standard
        # deviation 1/sqrt(feature_size).
        self.cluster_weights = nn.Parameter(torch.randn(
            feature_size, cluster_size) * 1 / math.sqrt(feature_size))
        self.cluster_weights2 = nn.Parameter(torch.randn(
            1, feature_size, cluster_size) * 1 / math.sqrt(feature_size))
        # Final projection from the flattened VLAD vector to output_dim.
        self.hidden1_weights = nn.Parameter(torch.randn(
            cluster_size * feature_size, output_dim) * 1 / math.sqrt(feature_size))
        if add_batch_norm:
            self.cluster_biases = None
            self.bn1 = nn.BatchNorm1d(cluster_size)
        else:
            self.cluster_biases = nn.Parameter(torch.randn(
                cluster_size) * 1 / math.sqrt(feature_size))
            self.bn1 = None
        self.bn2 = nn.BatchNorm1d(output_dim)  # defined but never applied in forward
        if gating:
            self.context_gating = GatingContext(
                output_dim, add_batch_norm=add_batch_norm)
    def forward(self, x):
        # x: (batch, feature_size, max_samples, 1) -> (batch, max_samples, feature_size)
        x = x.transpose(1, 3).contiguous()
        x = x.view((-1, self.max_samples, self.feature_size))
        # Soft-assignment logits of each descriptor to each cluster.
        activation = torch.matmul(x, self.cluster_weights)
        if self.add_batch_norm:
            activation = activation.view(-1, self.cluster_size)
            activation = self.bn1(activation)
            activation = activation.view(-1, self.max_samples, self.cluster_size)
        else:
            activation = activation + self.cluster_biases
        activation = self.softmax(activation)
        activation = activation.view((-1, self.max_samples, self.cluster_size))
        # a_sum: total assignment mass per cluster; a: the mass-weighted
        # cluster centers subtracted below to form the residuals.
        a_sum = activation.sum(-2, keepdim=True)
        a = a_sum * self.cluster_weights2
        activation = torch.transpose(activation, 2, 1)
        x = x.view((-1, self.max_samples, self.feature_size))
        vlad = torch.matmul(activation, x)
        vlad = torch.transpose(vlad, 2, 1)
        vlad = vlad - a
        # Intra-normalize each cluster column, flatten, then L2-normalize again.
        vlad = F.normalize(vlad, dim=1, p=2)
        vlad = vlad.reshape((-1, self.cluster_size * self.feature_size))
        vlad = F.normalize(vlad, dim=1, p=2)
        vlad = torch.matmul(vlad, self.hidden1_weights)
        if self.gating:
            vlad = self.context_gating(vlad)
        return vlad
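
# Math sketch (an annotation, not part of the original file): with soft
# assignments a(i, k) and descriptors x_i, NetVLAD computes per cluster k
#     V(:, k) = sum_i a(i, k) * (x_i - c_k)
#             = sum_i a(i, k) * x_i  -  (sum_i a(i, k)) * c_k,
# which is exactly matmul(activation^T, x) minus a_sum * cluster_weights2 in
# forward() above, with cluster_weights2 playing the role of the centers c_k.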

class GatingContext(nn.Module):
    """Context gating: re-weights a descriptor as y = x * sigmoid(W x (+ b))."""

    def __init__(self, dim, add_batch_norm=True):
        super(GatingContext, self).__init__()
        self.dim = dim
        self.add_batch_norm = add_batch_norm
        self.gating_weights = nn.Parameter(
            torch.randn(dim, dim) * 1 / math.sqrt(dim))
        self.sigmoid = nn.Sigmoid()
        if add_batch_norm:
            self.gating_biases = None
            self.bn1 = nn.BatchNorm1d(dim)
        else:
            self.gating_biases = nn.Parameter(
                torch.randn(dim) * 1 / math.sqrt(dim))
            self.bn1 = None

    def forward(self, x):
        gates = torch.matmul(x, self.gating_weights)
        if self.add_batch_norm:
            gates = self.bn1(gates)
        else:
            gates = gates + self.gating_biases
        gates = self.sigmoid(gates)
        # Element-wise gate: emphasizes informative dimensions of x.
        activation = x * gates
        return activation

if __name__ == '__main__':
    net_vlad = NetVLADLoupe(feature_size=512, max_samples=224, cluster_size=64,
                            output_dim=256, gating=True, add_batch_norm=False,
                            is_training=True)
    inputs = torch.rand((1, 512, 224, 1))
    outputs_tea = net_vlad(inputs)
    print(outputs_tea.shape)  # torch.Size([1, 256])
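    # A minimal extra check (an assumption, not in the original file): the
    # module also handles batches; with add_batch_norm=True, call eval() first
    # so BatchNorm1d uses running statistics rather than small-batch ones.
    batch = torch.rand((4, 512, 224, 1))
    print(net_vlad(batch).shape)  # torch.Size([4, 256])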