loss.py
import torch
import torch.nn as nn


class LDAMHingeLoss(nn.Module):
    # Inspired by Cao, Kaidi, Wei, Colin, Gaidon, Adrien, et al. "Learning imbalanced
    # datasets with label-distribution-aware margin loss." Advances in Neural
    # Information Processing Systems, 2019, vol. 32.
    def __init__(self, C, weights, n_data):
        super(LDAMHingeLoss, self).__init__()
        self.C = C
        # Per-class sample counts, derived from the class proportions.
        self.n0 = weights[0] * n_data
        self.n1 = weights[1] * n_data

    def forward(self, output, target):
        loss = 0
        for i, y in enumerate(target):
            # Class-dependent margin delta_j = C / n_j^(1/4) from the LDAM paper.
            if y == 0:
                delta = self.C / self.n0 ** (1 / 4)
            else:
                delta = self.C / self.n1 ** (1 / 4)
            # Subtract the margin from the true-class logit before the softmax,
            # then accumulate the negative log-likelihood for this sample.
            exp = torch.exp(output[i, y] - delta)
            loss -= torch.log(exp / (exp + torch.exp(output[i, 1 - y])))
        return loss
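

# Minimal usage sketch (illustrative, not part of the original file): the margin
# constant C=0.5, class proportions [0.9, 0.1], and dataset size 1000 below are
# assumed values for a binary, imbalanced setup with logits of shape [batch, 2].
if __name__ == "__main__":
    torch.manual_seed(0)
    criterion = LDAMHingeLoss(C=0.5, weights=[0.9, 0.1], n_data=1000)
    logits = torch.randn(4, 2, requires_grad=True)
    labels = torch.tensor([0, 1, 0, 0])
    loss = criterion(logits, labels)
    loss.backward()  # gradients flow through the margin-adjusted softmax
    print(loss.item())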