-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
82 lines (58 loc) · 2.03 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
''' Credit: https://github.com/fanyun-sun/InfoGraph '''
import torch as th
import torch.nn.functional as F
import math
def get_positive_expectation(p_samples, average=True):
    """Return the positive-pair term of the Jensen-Shannon MI estimator.

    Each score ``p`` is mapped element-wise to ``log(2) - softplus(-p)``.

    Args:
        p_samples: Tensor of scores for positive pairs.
        average: When True, reduce to the scalar mean; when False, return
            the element-wise tensor unchanged in shape.

    Returns:
        th.Tensor: scalar mean if ``average`` else a tensor shaped like
        ``p_samples``.
    """
    expectation = math.log(2.) - F.softplus(-p_samples)
    return expectation.mean() if average else expectation
def get_negative_expectation(q_samples, average=True):
    """Return the negative-pair term of the Jensen-Shannon MI estimator.

    Each score ``q`` is mapped element-wise to
    ``softplus(-q) + q - log(2)``.

    Args:
        q_samples: Tensor of scores for negative pairs.
        average: When True, reduce to the scalar mean; when False, return
            the element-wise tensor unchanged in shape.

    Returns:
        th.Tensor: scalar mean if ``average`` else a tensor shaped like
        ``q_samples``.
    """
    shifted = F.softplus(-q_samples) + q_samples
    expectation = shifted - math.log(2.)
    return expectation.mean() if average else expectation
def local_global_loss_(l_enc, g_enc, graph_id):
    """Computes the Jensen-Shannon local/global (node-vs-graph) loss.

    Every node embedding is scored against every graph embedding by a dot
    product. A (node, graph) pair is positive when ``graph_id`` assigns the
    node to that graph and negative otherwise.

    Args:
        l_enc: Node (local) embeddings, a 2-D tensor ``(num_nodes, dim)``.
        g_enc: Graph (global) embeddings, a 2-D tensor ``(num_graphs, dim)``.
        graph_id: Graph index for each node (sequence of ints or a 1-D
            integer tensor); entry ``i`` is the graph node ``i`` belongs to.
    Returns:
        th.Tensor: scalar ``E_neg - E_pos``. When the batch holds a single
        graph there are no negative pairs, and the negative term is 0
        instead of NaN.
    """
    num_graphs = g_enc.shape[0]
    num_nodes = l_enc.shape[0]
    device = g_enc.device

    # One vectorized index assignment marks every (node, own-graph) pair.
    # This avoids a Python loop that issues one tiny tensor write per node,
    # which is very slow for large batches, especially on GPU.
    node_graph = th.as_tensor(graph_id, dtype=th.long, device=device)
    node_idx = th.arange(node_graph.shape[0], device=device)
    pos_mask = th.zeros((num_nodes, num_graphs), device=device)
    pos_mask[node_idx, node_graph] = 1.
    neg_mask = 1. - pos_mask

    res = th.mm(l_enc, g_enc.t())  # (num_nodes, num_graphs) pair scores

    # Re-apply the mask after the non-linearity so excluded pairs contribute
    # exactly 0, rather than relying on f(0) rounding to 0.
    E_pos = get_positive_expectation(res * pos_mask, average=False)
    E_pos = (E_pos * pos_mask).sum() / max(num_nodes, 1)

    # Each node has (num_graphs - 1) negatives. Clamp the count so a batch
    # with a single graph (or no nodes) yields 0 instead of 0/0 = NaN.
    E_neg = get_negative_expectation(res * neg_mask, average=False)
    E_neg = (E_neg * neg_mask).sum() / max(num_nodes * (num_graphs - 1), 1)
    return E_neg - E_pos
def global_global_loss_(sup_enc, unsup_enc):
    """Computes the Jensen-Shannon global/global loss between two encoders.

    Row ``i`` of ``sup_enc`` paired with row ``i`` of ``unsup_enc`` is a
    positive pair; every other cross pairing in the batch is a negative.

    Args:
        sup_enc: 2-D tensor ``(num_graphs, dim)`` from the first encoder.
        unsup_enc: 2-D tensor with the same number of rows as ``sup_enc``
            and the same ``dim``, from the second encoder.
    Returns:
        th.Tensor: scalar ``E_neg - E_pos``. With a single graph there are
        no negative pairs, and the negative term is 0 instead of NaN.
    """
    num_graphs = sup_enc.shape[0]
    device = sup_enc.device
    pos_mask = th.eye(num_graphs, device=device)
    neg_mask = 1 - pos_mask

    res = th.mm(sup_enc, unsup_enc.t())  # (num_graphs, num_graphs) scores

    # Exact integer pair counts: G positives (the diagonal) and G*(G-1)
    # negatives. Summing a float32 mask is inexact for large G. Clamping to 1
    # turns the empty-set case into 0 rather than 0/0 = NaN.
    num_pos = max(num_graphs, 1)
    num_neg = max(num_graphs * (num_graphs - 1), 1)

    # Mask again after the non-linearity so excluded pairs add exactly 0.
    E_pos = get_positive_expectation(res * pos_mask, average=False)
    E_pos = (E_pos * pos_mask).sum() / num_pos
    E_neg = get_negative_expectation(res * neg_mask, average=False)
    E_neg = (E_neg * neg_mask).sum() / num_neg
    return E_neg - E_pos