-
Notifications
You must be signed in to change notification settings - Fork 11
/
layers.py
191 lines (157 loc) · 8.03 KB
/
layers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
import torch
import torch.nn as nn
import torch.nn.functional as F
from sparse_softmax import Sparsemax
from torch.nn import Parameter
from torch_geometric.data import Data
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.pool.topk_pool import topk, filter_adj
from torch_geometric.utils import softmax, dense_to_sparse, add_remaining_self_loops
from torch_scatter import scatter_add
from torch_sparse import spspmm, coalesce
class TwoHopNeighborhood(object):
    """Transform that adds every two-hop connection to a graph.

    The sparsity pattern of ``A @ A`` (computed with ``spspmm`` on
    ``data.edge_index`` with itself) is appended to ``data.edge_index`` and the
    result is coalesced. When ``data.edge_attr`` is present, edges that exist
    only as two-hop connections receive an attribute of 0, while edges already
    in the graph keep their own attribute (assuming it is below the internal
    marker value of 1e16).
    """

    def __call__(self, data):
        """Extend ``data`` in place with its two-hop edges and return it."""
        base_index = data.edge_index
        base_attr = data.edge_attr
        num_nodes = data.num_nodes
        # Marker value: every spspmm product of markers is >= it, which lets
        # two-hop-only entries be recognised after a 'min' coalesce.
        marker = 1e16

        markers = base_index.new_full((base_index.size(1),), marker, dtype=torch.float)
        # NOTE(review): the trailing True is passed positionally as spspmm's
        # coalesced flag, presumably assuming edge_index is already coalesced.
        hop_index, hop_value = spspmm(base_index, markers, base_index, markers,
                                      num_nodes, num_nodes, num_nodes, True)
        combined_index = torch.cat([base_index, hop_index], dim=1)

        if base_attr is None:
            # No attributes to carry along: only the connectivity changes.
            data.edge_index, _ = coalesce(combined_index, None, num_nodes, num_nodes)
            return data

        # Broadcast the per-edge two-hop values to edge_attr's trailing shape.
        trailing_shape = list(base_attr.size())[1:]
        hop_value = hop_value.view(-1, *([1] * len(trailing_shape)))
        hop_value = hop_value.expand(-1, *trailing_shape)
        combined_attr = torch.cat([base_attr, hop_value], dim=0)
        # 'min' keeps an original edge's attribute over the large marker products.
        data.edge_index, combined_attr = coalesce(combined_index, combined_attr,
                                                  num_nodes, num_nodes, op='min')
        combined_attr[combined_attr >= marker] = 0
        data.edge_attr = combined_attr
        return data

    def __repr__(self):
        return '{}()'.format(type(self).__name__)
class NodeInformationScore(MessagePassing):
    """Message-passing layer computing ``(I - D^{-1/2} A D^{-1/2}) x``.

    HGPSLPool uses the per-node L1 norm of this output as its pooling score.
    Aggregation is ``'add'``; each message is the neighbouring node feature
    scaled by the edge weight produced by :meth:`norm`.

    Args:
        improved: Stored as an attribute; not read by this class.
        cached: If True, the normalised edges computed on the first call are
            reused on every later call (the edge count must stay the same).
        **kwargs: Forwarded to ``MessagePassing``.
    """

    def __init__(self, improved=False, cached=False, **kwargs):
        super(NodeInformationScore, self).__init__(aggr='add', **kwargs)
        self.improved = improved
        self.cached = cached
        self.cached_result = None
        self.cached_num_edges = None

    @staticmethod
    def norm(edge_index, num_nodes, edge_weight, dtype=None):
        """Return self-loop-augmented edges and weights ``[i == j] - w_ij / sqrt(d_i * d_j)``.

        Degrees ``d`` are summed from ``edge_weight`` over the source index
        (``row``) before self-loops are added; zero-degree nodes contribute 0
        instead of ``inf``. A missing ``edge_weight`` defaults to all ones.

        Returns:
            Tuple ``(edge_index, weights)`` where ``edge_index`` includes one
            self-loop per node.
        """
        if edge_weight is None:
            edge_weight = torch.ones((edge_index.size(1),), dtype=dtype, device=edge_index.device)
        row, col = edge_index
        deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes)
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
        # Self-loops added here get weight 0 (fill_value); per PyG's
        # add_remaining_self_loops, pre-existing self-loops keep their weight.
        edge_index, edge_weight = add_remaining_self_loops(edge_index, edge_weight, 0, num_nodes)
        row, col = edge_index
        # Identity term: 1 on self-loop entries, 0 elsewhere. Loops are located
        # by ``row == col`` instead of assuming they were appended last.
        identity = torch.zeros((edge_weight.size(0),), dtype=dtype, device=edge_index.device)
        identity[row == col] = 1
        return edge_index, identity - deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]

    def forward(self, x, edge_index, edge_weight=None):
        """Propagate ``x`` over the normalised edges.

        Args:
            x: Node features; ``x.size(0)`` is used as the number of nodes and
                ``x.dtype`` as the weight dtype.
            edge_index: Edge indices, unpacked as ``row, col``.
            edge_weight: Optional per-edge weights; ``None`` means all ones.

        Raises:
            RuntimeError: If ``cached`` is set and the edge count differs from
                the cached one.
        """
        if self.cached and self.cached_result is not None:
            if edge_index.size(1) != self.cached_num_edges:
                raise RuntimeError(
                    'Cached {} number of edges, but found {}'.format(self.cached_num_edges, edge_index.size(1)))
        if not self.cached or self.cached_result is None:
            self.cached_num_edges = edge_index.size(1)
            edge_index, norm = self.norm(edge_index, x.size(0), edge_weight, x.dtype)
            self.cached_result = edge_index, norm
        edge_index, norm = self.cached_result
        return self.propagate(edge_index, x=x, norm=norm)

    # NOTE: MessagePassing inspects these argument names (x_j, norm, aggr_out);
    # do not rename them.
    def message(self, x_j, norm):
        return norm.view(-1, 1) * x_j

    def update(self, aggr_out):
        return aggr_out
class HGPSLPool(torch.nn.Module):
    """Top-k graph pooling followed by optional structure learning.

    Nodes are scored by the L1 norm of their NodeInformationScore output,
    ``topk`` keeps the highest-scoring ``ratio`` of each graph, and, when
    ``sl`` is enabled, the edges of the pooled graph are re-learned. Each
    candidate edge ``(i, j)`` is scored with ``|cos(x_i, x_j)|`` plus
    ``lamb`` times its original edge weight. The scores are then normalised
    per source node with softmax, or with Sparsemax when ``sparse=True``.

    Args:
        in_channels: Node feature size; stored but not used by this class.
        ratio: Pooling ratio forwarded to ``topk``.
        sample: If True, candidate edges are restricted to an augmented
            multi-hop neighbourhood of the original graph (faster on large
            graphs). Otherwise every pair of kept nodes within the same graph
            is a candidate.
        sparse: Normalise learned weights with Sparsemax instead of softmax.
        sl: If falsy, skip structure learning and return the induced subgraph.
        lamb: Weight of the original edge weights in the learned scores.
    """

    def __init__(self, in_channels, ratio=0.8, sample=False, sparse=False, sl=True, lamb=1.0):
        super(HGPSLPool, self).__init__()
        self.in_channels = in_channels
        self.ratio = ratio
        self.sample = sample
        self.sparse = sparse
        self.sl = sl
        self.lamb = lamb
        self.sim = nn.CosineSimilarity(dim=1)
        self.sparse_attention = Sparsemax()
        self.neighbor_augment = TwoHopNeighborhood()
        self.calc_information_score = NodeInformationScore()

    def forward(self, x, edge_index, edge_attr, batch=None):
        """Pool the graph and, optionally, learn the pooled graph's edges.

        Args:
            x: Node features, indexed by node along dim 0.
            edge_index: Edge indices of the input graph.
            edge_attr: Optional edge weights; ``None`` is treated as all ones
                when structure learning needs them.
            batch: Graph assignment of each node; ``None`` means a single graph.

        Returns:
            Tuple ``(x, edge_index, edge_attr, batch)`` for the pooled graph.
        """
        if batch is None:
            batch = edge_index.new_zeros(x.size(0))
        x_information_score = self.calc_information_score(x, edge_index, edge_attr)
        score = torch.sum(torch.abs(x_information_score), dim=1)

        # Graph pooling: keep the top-scoring nodes of every graph.
        original_x = x
        perm = topk(score, self.ratio, batch)
        x = x[perm]
        batch = batch[perm]
        induced_edge_index, induced_edge_attr = filter_adj(
            edge_index, edge_attr, perm, num_nodes=score.size(0))

        # Structure learning disabled: return the induced subgraph directly.
        if not self.sl:
            return x, induced_edge_index, induced_edge_attr, batch

        if self.sample:
            new_edge_index, new_edge_attr = self._learn_sampled_structure(
                original_x, x, edge_index, edge_attr, perm, score.size(0))
        else:
            new_edge_index, new_edge_attr = self._learn_full_structure(
                x, batch, edge_attr, induced_edge_index, induced_edge_attr)

        # The dense N x N adjacency built by the helpers is out of scope now;
        # hand its cached blocks back to the CUDA allocator (GPU inputs only).
        if x.is_cuda:
            torch.cuda.empty_cache()
        return x, new_edge_index, new_edge_attr, batch

    def _learn_sampled_structure(self, original_x, x, edge_index, edge_attr, perm, num_nodes):
        """Score edges only between kept nodes that are close in the original graph."""
        # Learning weights for every node pair is slow on large graphs, so the
        # candidates are restricted to an augmented neighbourhood instead.
        # NOTE(review): each TwoHopNeighborhood pass adds the square of the
        # current graph, so k_hop - 1 = 2 passes reach nodes up to 4 hops apart,
        # not 3 — confirm this is intended.
        k_hop = 3
        if edge_attr is None:
            edge_attr = torch.ones((edge_index.size(1),), dtype=x.dtype, device=edge_index.device)
        hop_data = Data(x=original_x, edge_index=edge_index, edge_attr=edge_attr)
        for _ in range(k_hop - 1):
            hop_data = self.neighbor_augment(hop_data)
        new_edge_index, new_edge_attr = filter_adj(
            hop_data.edge_index, hop_data.edge_attr, perm, num_nodes=num_nodes)
        # Self-loops not already present get an original-edge weight of 0.
        new_edge_index, new_edge_attr = add_remaining_self_loops(
            new_edge_index, new_edge_attr, 0, x.size(0))
        row, col = new_edge_index
        weights = torch.abs(self.sim(x[row], x[col]))
        weights = weights + self.lamb * new_edge_attr.to(weights.dtype)
        adj = x.new_zeros((x.size(0), x.size(0)))
        adj[row, col] = weights
        # Round-trip through the dense matrix to drop zero-weight candidates.
        new_edge_index, weights = dense_to_sparse(adj)
        row, col = new_edge_index
        return self._normalize_edge_weights(adj, row, col, weights)

    def _learn_full_structure(self, x, batch, edge_attr, induced_edge_index, induced_edge_attr):
        """Score edges between every pair of kept nodes that belong to the same graph."""
        if edge_attr is None:
            induced_edge_attr = torch.ones((induced_edge_index.size(1),), dtype=x.dtype,
                                           device=induced_edge_index.device)
        # Per-graph kept-node counts and their [start, end) offsets.
        # NOTE(review): assumes ``batch`` is grouped by graph after topk.
        num_nodes = scatter_add(batch.new_ones(x.size(0)), batch, dim=0)
        cum_num_nodes = num_nodes.cumsum(dim=0)
        shift_cum_num_nodes = torch.cat([num_nodes.new_zeros(1), cum_num_nodes[:-1]], dim=0)
        adj = x.new_zeros((x.size(0), x.size(0)))
        # Block-diagonal mask: one fully connected block per graph in the batch.
        for idx_i, idx_j in zip(shift_cum_num_nodes.tolist(), cum_num_nodes.tolist()):
            adj[idx_i:idx_j, idx_i:idx_j] = 1.0
        new_edge_index, _ = dense_to_sparse(adj)
        row, col = new_edge_index
        weights = torch.abs(self.sim(x[row], x[col]))
        # NOTE(review): a uniform +lamb shift leaves a per-row softmax unchanged
        # (presumably Sparsemax too); kept to preserve existing behaviour.
        weights = weights + self.lamb
        adj[row, col] = weights
        induced_row, induced_col = induced_edge_index
        adj[induced_row, induced_col] += induced_edge_attr.to(adj.dtype) * self.lamb
        weights = adj[row, col]
        return self._normalize_edge_weights(adj, row, col, weights)

    def _normalize_edge_weights(self, adj, row, col, weights):
        """Normalise ``weights`` per source node and return the non-zero edges of ``adj``."""
        if self.sparse:
            new_edge_attr = self.sparse_attention(weights, row)
        else:
            # Pass num_nodes by keyword: newer PyG versions take ``ptr`` as
            # softmax's third positional parameter.
            new_edge_attr = softmax(weights, row, num_nodes=adj.size(0))
        # Write the normalised weights back and filter out zero-weight edges.
        adj[row, col] = new_edge_attr
        return dense_to_sparse(adj)