forked from dmlc/dgl
-
Notifications
You must be signed in to change notification settings - Fork 0
/
DGLDigitCapsule.py
39 lines (34 loc) · 1.88 KB
/
DGLDigitCapsule.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import dgl
import torch
from torch import nn
from torch.nn import functional as F
import dgl.function as fn
from DGLRoutingLayer import DGLRoutingLayer
class DGLDigitCapsuleLayer(nn.Module):
    """Capsule layer that projects input capsules and runs DGL-based routing.

    Each (input capsule, output capsule) pair owns a learnable matrix of
    shape ``[out_nodes_dim, in_nodes_dim]``. ``compute_uhat`` applies those
    matrices to the input capsules, and ``forward`` hands the result to a
    freshly built ``DGLRoutingLayer`` and reads the output-node feature
    ``'v'`` back from its graph.

    Args:
        in_nodes_dim: Length of each input capsule vector.
        in_nodes: Number of input capsules.
        out_nodes: Number of output capsules.
        out_nodes_dim: Length of each output capsule vector.
        device: Device identifier forwarded to ``DGLRoutingLayer``.
        routing_num: Value passed as ``routing_num`` to the routing layer
            on every forward pass (defaults to 3, the previously
            hard-coded value).
    """

    def __init__(self, in_nodes_dim=8, in_nodes=1152, out_nodes=10, out_nodes_dim=16, device='cpu',
                 routing_num=3):
        super().__init__()
        self.device = device
        self.in_nodes_dim, self.out_nodes_dim = in_nodes_dim, out_nodes_dim
        self.in_nodes, self.out_nodes = in_nodes, out_nodes
        self.routing_num = routing_num
        # One [out_nodes_dim, in_nodes_dim] projection per (input, output) capsule pair.
        self.weight = nn.Parameter(torch.randn(in_nodes, out_nodes, out_nodes_dim, in_nodes_dim))

    def forward(self, x):
        """Project ``x`` and run routing over it.

        Args:
            x: Tensor of shape ``[batch_size, in_nodes_dim, in_nodes]``.

        Returns:
            The routing graph's output-node feature ``'v'`` with its first
            two dims swapped and a size-1 dim inserted at index 3.
            NOTE(review): presumably ``[batch_size, out_nodes,
            out_nodes_dim, 1]`` -- confirm against ``DGLRoutingLayer``.
        """
        # Kept as an attribute for backward compatibility with code that reads it.
        self.batch_size = x.size(0)
        u_hat = self.compute_uhat(x)
        # A new routing layer is built per call because it is sized by the batch.
        routing = DGLRoutingLayer(self.in_nodes, self.out_nodes, self.out_nodes_dim,
                                  batch_size=self.batch_size, device=self.device)
        routing(u_hat, routing_num=self.routing_num)
        out_nodes_feature = routing.g.nodes[routing.out_indx].data['v']
        # Shape transformation is for further classification.
        return out_nodes_feature.transpose(0, 1).unsqueeze(3)

    def compute_uhat(self, x):
        """Apply the per-pair projection matrices to the input capsules.

        The batch size is taken from ``x`` itself, so this method can be
        called without a preceding ``forward`` call.

        Args:
            x: Tensor of shape ``[batch_size, in_nodes_dim, in_nodes]``.

        Returns:
            Contiguous tensor of shape
            ``[in_nodes * out_nodes, batch_size, out_nodes_dim]``.
        """
        batch_size = x.size(0)
        # [batch_size, in_nodes_dim, in_nodes] -> [batch_size, in_nodes, 1, in_nodes_dim, 1]
        x = x.transpose(1, 2).unsqueeze(2).unsqueeze(4)
        # weight is [in_nodes, out_nodes, out_nodes_dim, in_nodes_dim]; matmul
        # broadcasts the batch dims, so neither operand has to be copied
        # out_nodes / batch_size times. Result:
        # [batch_size, in_nodes, out_nodes, out_nodes_dim, 1]
        u_hat = torch.matmul(self.weight, x)
        # -> [in_nodes, out_nodes, batch_size, out_nodes_dim]; squeeze only the
        # trailing matmul dim so that size-1 batch/node dims are preserved.
        u_hat = u_hat.permute(1, 2, 0, 3, 4).squeeze(4).contiguous()
        return u_hat.view(-1, batch_size, self.out_nodes_dim)