loss.py

import torch
import torch.nn as nn
import torch.nn.functional as F


class SIG_LOSS(nn.Module):
    """Binary cross-entropy over raw logits (sigmoid applied internally)."""

    def __init__(self, device):
        super(SIG_LOSS, self).__init__()
        self.m_device = device
        self.m_criterion = nn.BCEWithLogitsLoss(reduction="mean")

    def forward(self, preds, targets):
        # BCEWithLogitsLoss fuses the sigmoid and the BCE term for numerical
        # stability, so `preds` must be raw, unsquashed logits.
        loss = self.m_criterion(preds, targets)
        return loss
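
# A minimal usage sketch for SIG_LOSS. Assumed shapes (not stated in the
# source): `preds` and `targets` are float tensors of identical shape, with
# targets in {0.0, 1.0}.
#
#     criterion = SIG_LOSS(device=torch.device("cpu"))
#     preds = torch.randn(8, 5)                 # raw logits, no sigmoid applied
#     targets = torch.randint(0, 2, (8, 5)).float()
#     loss = criterion(preds, targets)          # scalar, mean over all elements

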
class XE_LOSS(nn.Module):
    """Cross-entropy computed against a one-hot encoding of the targets."""

    def __init__(self, item_num, device):
        super(XE_LOSS, self).__init__()
        self.m_item_num = item_num
        self.m_device = device

    def forward(self, preds, targets):
        # `targets` holds integer item indices; expand them to one-hot
        # vectors so they align with the per-item prediction scores.
        targets = F.one_hot(targets, self.m_item_num)
        preds = F.log_softmax(preds, 1)
        # Negative log-likelihood of the target item, averaged over the batch.
        xe_loss = torch.sum(preds * targets, dim=-1)
        xe_loss = -torch.mean(xe_loss)
        return xe_loss
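
# A minimal usage sketch for XE_LOSS. Assumption: `targets` holds integer
# item indices in [0, item_num), one per example; F.one_hot requires an
# integer dtype such as torch.long.
#
#     criterion = XE_LOSS(item_num=100, device=torch.device("cpu"))
#     preds = torch.randn(8, 100)               # unnormalized score per item
#     targets = torch.randint(0, 100, (8,))     # one ground-truth item each
#     loss = criterion(preds, targets)

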
class BPR_LOSS(nn.Module):
    """Pairwise BPR-style ranking loss over per-graph sentence scores.

    Labels are soft relevance levels in {0, ..., 3}; for each level, every
    sentence at that level should be ranked above all sentences at lower
    levels within the same graph.
    """

    def __init__(self, device):
        super(BPR_LOSS, self).__init__()
        self.m_device = device

    def forward(self, graph_batch, logits, labels):
        batch_size = graph_batch.num_graphs
        batch_snum = graph_batch.s_num
        # Cumulative sentence counts give the slice boundaries of each
        # graph inside the flat `logits`/`labels` tensors.
        batch_cumsum_snum = torch.cumsum(batch_snum, dim=0)
        last_cumsum_snum_i = 0
        soft_label_num = 4
        loss_list = []
        for i in range(batch_size):
            cumsum_snum_i = batch_cumsum_snum[i]
            labels_i = labels[last_cumsum_snum_i:cumsum_snum_i]
            logits_i = logits[last_cumsum_snum_i:cumsum_snum_i]
            # Advance the slice cursor immediately, so that graphs skipped
            # below do not misalign the slices of the graphs that follow.
            last_cumsum_snum_i = cumsum_snum_i
            # A single-sentence graph yields no ranking pairs.
            if labels_i.shape[0] == 1:
                continue
            log_prob_i = []
            for soft_label_idx in range(1, soft_label_num):
                pos_mask_i = (labels_i == soft_label_idx)
                neg_mask_i = (labels_i < soft_label_idx)
                pos_logits_i = logits_i[pos_mask_i]
                neg_logits_i = logits_i[neg_mask_i]
                # Skip levels with no positive or no negative sentences.
                if pos_logits_i.size(0) == 0 or neg_logits_i.size(0) == 0:
                    continue
                # All positive-negative score gaps for this level; a large
                # gap means the positive is correctly ranked above.
                delta_logits_i = pos_logits_i.unsqueeze(1) - neg_logits_i
                log_prob_i.append(F.logsigmoid(delta_logits_i).mean().unsqueeze(-1))
            # Guard against graphs where every label level was skipped,
            # which would otherwise make torch.cat fail on an empty list.
            if len(log_prob_i) == 0:
                continue
            log_prob_i = torch.cat(log_prob_i, dim=-1)
            loss_list.append(log_prob_i.mean().unsqueeze(-1))
        loss = -torch.cat(loss_list, dim=-1)
        loss = loss.mean()
        return loss
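
# A minimal usage sketch for BPR_LOSS. `graph_batch` is assumed to be a
# PyTorch Geometric-style batch exposing `num_graphs` (int) and `s_num`
# (per-graph sentence counts, a 1-D LongTensor); the SimpleNamespace below
# is a hypothetical stand-in that only mimics those two attributes.
#
#     from types import SimpleNamespace
#
#     graph_batch = SimpleNamespace(num_graphs=2, s_num=torch.tensor([3, 4]))
#     logits = torch.randn(7)                          # one score per sentence
#     labels = torch.tensor([0, 1, 2, 0, 1, 2, 3])     # soft labels in {0..3}
#     criterion = BPR_LOSS(device=torch.device("cpu"))
#     loss = criterion(graph_batch, logits, labels)    # scalar ranking loss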