-
Notifications
You must be signed in to change notification settings - Fork 0
/
data_util.py
152 lines (136 loc) · 6.3 KB
/
data_util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
import torch
from torch_geometric.data import Data, HeteroData
from torch_geometric.typing import OptTensor
import numpy as np
def to_adj_nodes_with_times(data):
    '''Build per-node incoming/outgoing neighbour lists annotated with edge timestamps.

    Works for homogeneous graphs and for ``HeteroData`` (using the
    ``('node', 'to', 'node')`` edges). Timestamps are truncated with ``int``;
    when ``data`` has no ``timestamps`` every edge gets time 0.

    Returns:
        ``(adj_list_in, adj_list_out)``: dicts keyed by every node id in
        ``range(data.num_nodes)``. For each edge ``u -> v`` with time ``t``,
        ``adj_list_out[u]`` receives ``(v, t)`` and ``adj_list_in[v]`` receives
        ``(u, t)``.

    Raises:
        ValueError: if the number of timestamps differs from the number of edges.
    '''
    if isinstance(data, HeteroData):
        edge_index = data['node', 'to', 'node'].edge_index
    else:
        edge_index = data.edge_index
    num_edges = edge_index.shape[1]
    # getattr: an unset attribute on a PyG data/store object raises AttributeError
    # instead of returning None.
    timestamps = getattr(data, 'timestamps', None)
    times = [0] * num_edges if timestamps is None else timestamps.reshape(-1).tolist()
    if len(times) != num_edges:
        raise ValueError(f'expected {num_edges} timestamps, got {len(times)}')
    adj_list_out = {i: [] for i in range(data.num_nodes)}
    adj_list_in = {i: [] for i in range(data.num_nodes)}
    # Node ids are read directly as ints. Concatenating them with float timestamps
    # would round large ids.
    for u, v, t in zip(edge_index[0].tolist(), edge_index[1].tolist(), times):
        t = int(t)
        adj_list_out[u].append((v, t))
        adj_list_in[v].append((u, t))
    return adj_list_in, adj_list_out
def to_adj_edges_with_times(data):
    '''Build per-node incoming/outgoing edge lists annotated with edge ids and timestamps.

    Works for homogeneous graphs and for ``HeteroData`` (using the
    ``('node', 'to', 'node')`` edges). Timestamps are truncated with ``int``;
    when ``data`` has no ``timestamps`` every edge gets time 0.

    Returns:
        ``(adj_edges_in, adj_edges_out)``: dicts keyed by every node id in
        ``range(data.num_nodes)``. For edge number ``i`` going ``u -> v`` with
        time ``t``, ``adj_edges_out[u]`` receives ``(i, v, t)`` and
        ``adj_edges_in[v]`` receives ``(i, u, t)``.

    Raises:
        ValueError: if the number of timestamps differs from the number of edges.
    '''
    if isinstance(data, HeteroData):
        edge_index = data['node', 'to', 'node'].edge_index
    else:
        edge_index = data.edge_index
    num_edges = edge_index.shape[1]
    # getattr: an unset attribute on a PyG data/store object raises AttributeError
    # instead of returning None.
    timestamps = getattr(data, 'timestamps', None)
    times = [0] * num_edges if timestamps is None else timestamps.reshape(-1).tolist()
    if len(times) != num_edges:
        raise ValueError(f'expected {num_edges} timestamps, got {len(times)}')
    # calculate adjacent edges with times per node
    adj_edges_out = {i: [] for i in range(data.num_nodes)}
    adj_edges_in = {i: [] for i in range(data.num_nodes)}
    for i, (u, v, t) in enumerate(zip(edge_index[0].tolist(), edge_index[1].tolist(), times)):
        t = int(t)
        adj_edges_out[u].append((i, v, t))
        adj_edges_in[v].append((i, u, t))
    return adj_edges_in, adj_edges_out
def ports(edge_index, adj_list):
    '''Assign every edge a port number local to the node whose adjacency list holds it.

    For each node ``v`` in ``adj_list``, the entries ``(u, t)`` are ordered by
    ``t``. Ties keep their list order, because a stable sort makes the result
    deterministic. Distinct neighbours ``u`` are then numbered 0, 1, ... in order
    of first appearance, and that number becomes the port of the key ``(u, v)``.

    Args:
        edge_index: ``(2, E)`` integer tensor. Column ``i`` is looked up as the key
            ``(edge_index[0, i], edge_index[1, i])``.
        adj_list: mapping ``v -> [(u, t), ...]``, for example one of the dicts
            returned by ``to_adj_nodes_with_times``.

    Returns:
        ``(E, 1)`` float tensor of port numbers.
    '''
    port_of = {}
    for v, nbs in adj_list.items():
        if not nbs:
            continue
        a = np.array(nbs)
        a = a[a[:, -1].argsort(kind='stable')]
        _, first_idx = np.unique(a[:, 0], return_index=True)
        # Unique neighbours in order of their earliest interaction with v.
        for i, u in enumerate(a[np.sort(first_idx), 0].tolist()):
            port_of[(u, v)] = i
    port_tensor = torch.zeros(edge_index.shape[1], 1)
    # tolist() also works for CUDA tensors, unlike .numpy().
    port_values = [port_of[(src, dst)] for src, dst in edge_index.T.tolist()]
    port_tensor[:, 0] = torch.tensor(port_values, dtype=port_tensor.dtype)
    return port_tensor
def time_deltas(data, adj_edges_list):
    '''Give every edge the time gap to the previous edge in the same adjacency list.

    For each node's list of ``(edge_id, neighbour, t)`` tuples, as produced by
    ``to_adj_edges_with_times``, the entries are ordered by ``t``. Ties keep
    their list order. Each edge receives ``t - t_previous``, and the earliest edge
    in each list receives 0.

    Args:
        data: homogeneous graph or ``HeteroData``. Only its edge count (from the
            ``('node', 'to', 'node')`` edges for hetero) and ``timestamps`` are read.
        adj_edges_list: mapping ``node -> [(edge_id, neighbour, t), ...]``.

    Returns:
        ``(E, 1)`` float tensor. It is all zeros when ``data`` has no timestamps.
    '''
    if isinstance(data, HeteroData):
        num_edges = data['node', 'to', 'node'].edge_index.shape[1]
    else:
        num_edges = data.edge_index.shape[1]
    deltas = torch.zeros(num_edges, 1)
    # getattr: an unset attribute on a PyG data/store object raises AttributeError
    # instead of returning None.
    if getattr(data, 'timestamps', None) is None:
        return deltas
    for edges in adj_edges_list.values():
        if not edges:
            continue
        a = np.array(edges)
        a = a[a[:, -1].argsort(kind='stable')]
        # Prepending the first time makes the first gap in each list equal to 0.
        gaps = np.diff(a[:, -1], prepend=a[0, -1])
        edge_ids = torch.tensor(a[:, 0], dtype=torch.long)
        deltas[edge_ids, 0] = torch.tensor(gaps, dtype=deltas.dtype)
    return deltas
class GraphData(Data):
    '''This is the homogenous graph object we use for GNN training if reverse MP is not enabled.

    On top of the standard ``Data`` fields it stores:
      * ``readout``: the ``readout`` argument (default ``'edge'``).
      * ``loss_fn``: always ``'ce'``.
      * ``num_nodes``: ``x.shape[0]`` when ``x`` is given, otherwise the
        ``num_nodes`` argument, if any.
      * ``node_timestamps``: passed through unchanged.
      * ``timestamps``: the ``timestamps`` argument. Otherwise a copy of
        ``edge_attr[:, 0]``. Otherwise ``None``.
    '''
    def __init__(
        self, x: OptTensor = None, edge_index: OptTensor = None, edge_attr: OptTensor = None, y: OptTensor = None, pos: OptTensor = None,
        readout: str = 'edge',
        num_nodes: int = None,
        timestamps: OptTensor = None,
        node_timestamps: OptTensor = None,
        **kwargs
    ):
        super().__init__(x, edge_index, edge_attr, y, pos, **kwargs)
        self.readout = readout
        self.loss_fn = 'ce'
        if x is not None:
            # The node count comes from the feature matrix when one is provided.
            self.num_nodes = int(x.shape[0])
        elif num_nodes is not None:
            # Without x, honour the explicit argument instead of crashing on None.shape.
            self.num_nodes = num_nodes
        self.node_timestamps = node_timestamps
        if timestamps is not None:
            self.timestamps = timestamps
        elif edge_attr is not None:
            # Fallback: treat the first edge_attr column as the edge timestamps.
            self.timestamps = edge_attr[:, 0].clone()
        else:
            self.timestamps = None

    def add_ports(self, reverse_ports: bool = True):
        '''Append port-number columns to ``edge_attr`` and return ``self``.

        Always appends the in-port column, computed from the incoming adjacency
        lists. When ``reverse_ports`` is true, which is the default and the
        previous fixed behaviour, it also appends the out-port column, computed
        on the flipped ``edge_index``.
        '''
        adj_list_in, adj_list_out = to_adj_nodes_with_times(self)
        new_cols = [ports(self.edge_index, adj_list_in)]
        if reverse_ports:
            new_cols.append(ports(self.edge_index.flipud(), adj_list_out))
        self.edge_attr = torch.cat([self.edge_attr] + new_cols, dim=1)
        return self

    def add_time_deltas(self, reverse_tds: bool = True):
        '''Append time-delta columns to ``edge_attr`` and return ``self``.

        Always appends the column computed from the incoming edge lists. When
        ``reverse_tds`` is true, which is the default and the previous fixed
        behaviour, it also appends the column computed from the outgoing edge
        lists.
        '''
        adj_list_in, adj_list_out = to_adj_edges_with_times(self)
        new_cols = [time_deltas(self, adj_list_in)]
        if reverse_tds:
            new_cols.append(time_deltas(self, adj_list_out))
        self.edge_attr = torch.cat([self.edge_attr] + new_cols, dim=1)
        return self
class HeteroGraphData(HeteroData):
    '''Heterogeneous graph object used for GNN training when reverse message passing is enabled.

    Edges are stored under two edge types. ``('node', 'to', 'node')`` holds the
    original direction and ``('node', 'rev_to', 'node')`` holds the reversed copies.
    '''
    def __init__(
        self,
        readout: str = 'edge',
        **kwargs
    ):
        super().__init__(**kwargs)
        self.readout = readout

    @property
    def num_nodes(self):
        '''Row count of the ``'node'`` feature matrix.'''
        node_store = self['node']
        return node_store.x.shape[0]

    @property
    def timestamps(self):
        '''Timestamps stored on the forward ``('node', 'to', 'node')`` edges.'''
        forward = self['node', 'to', 'node']
        return forward.timestamps

    def add_ports(self):
        '''Append in-ports to forward edges and out-ports to reverse edges. Returns ``self``.'''
        incoming, outgoing = to_adj_nodes_with_times(self)
        forward = self['node', 'to', 'node']
        fwd_ports = ports(forward.edge_index, incoming)
        backward = self['node', 'rev_to', 'node']
        bwd_ports = ports(backward.edge_index, outgoing)
        forward.edge_attr = torch.cat([forward.edge_attr, fwd_ports], dim=1)
        backward.edge_attr = torch.cat([backward.edge_attr, bwd_ports], dim=1)
        return self

    def add_time_deltas(self):
        '''Append incoming time deltas to forward edges and outgoing ones to reverse edges. Returns ``self``.'''
        incoming, outgoing = to_adj_edges_with_times(self)
        fwd_deltas = time_deltas(self, incoming)
        bwd_deltas = time_deltas(self, outgoing)
        forward = self['node', 'to', 'node']
        forward.edge_attr = torch.cat([forward.edge_attr, fwd_deltas], dim=1)
        backward = self['node', 'rev_to', 'node']
        backward.edge_attr = torch.cat([backward.edge_attr, bwd_deltas], dim=1)
        return self
def z_norm(data):
    '''Z-score normalise ``data`` along dimension 0.

    The mean along dim 0 is subtracted, and the result is divided by the standard
    deviation along dim 0 (``Tensor.std`` default, i.e. unbiased). Entries with
    zero standard deviation are divided by 1 instead, so constant columns become 0.
    '''
    std = data.std(0, keepdim=True)
    # Fill in place on std itself so the replacement matches std's device and dtype.
    # The previous fill value was hard-coded as a float32 CPU tensor.
    std = std.masked_fill(std == 0, 1)
    return (data - data.mean(0, keepdim=True)) / std
def create_hetero_obj(x, y, edge_index, edge_attr, timestamps, args):
    '''This function creates a heterogenous graph object for reverse message passing.

    Builds a ``HeteroGraphData`` with:
      * ``'node'`` features ``x``;
      * forward edges ``('node', 'to', 'node')`` carrying ``edge_index``,
        ``edge_attr``, labels ``y`` and ``timestamps``;
      * reverse edges ``('node', 'rev_to', 'node')`` with the flipped
        ``edge_index`` and a copy of ``edge_attr``.

    When ``args.ports`` is set, the last two ``edge_attr`` columns of the reverse
    edges are swapped. This assumes they hold the in- and out-port numbers. The
    forward edges and the caller's ``edge_attr`` tensor are left unchanged.
    '''
    data = HeteroGraphData()
    data['node'].x = x
    data['node', 'to', 'node'].edge_index = edge_index
    data['node', 'rev_to', 'node'].edge_index = edge_index.flipud()
    data['node', 'to', 'node'].edge_attr = edge_attr
    # Give the reverse edges their own copy. Previously both edge types shared one
    # tensor, so the in-place swap below also swapped the forward edges' ports and
    # mutated the caller's edge_attr.
    rev_edge_attr = edge_attr.clone() if edge_attr is not None else None
    if args.ports:
        # swap the in- and outgoing port numberings for the reverse edges
        rev_edge_attr[:, [-1, -2]] = rev_edge_attr[:, [-2, -1]]
    data['node', 'rev_to', 'node'].edge_attr = rev_edge_attr
    data['node', 'to', 'node'].y = y
    data['node', 'to', 'node'].timestamps = timestamps
    return data