-
Notifications
You must be signed in to change notification settings - Fork 4
/
loss.py
43 lines (26 loc) · 1.05 KB
/
loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
from __future__ import division
import torch
import torch.nn as nn
class EuclideanLoss(nn.Module):
    """Masked mean Euclidean (L2) distance between prediction and target maps.

    Inputs are NCHW tensors; ``mask`` (one value per spatial position)
    selects which positions contribute. The channel-wise L2 distances of the
    selected positions are summed and divided by the number of selected
    positions.
    """

    def __init__(self):
        super(EuclideanLoss, self).__init__()
        # Row-wise L2 distance between corresponding (position, channel) rows.
        self.pdist = nn.PairwiseDistance(p=2)

    def forward(self, pred, target, mask):
        batch, channels, height, width = pred.size()
        positions = batch * height * width

        # NCHW -> (N*H*W, C): each row is the channel vector of one position.
        flat_pred = pred.transpose(1, 2).transpose(2, 3).contiguous().view(-1, channels)
        flat_target = target.transpose(1, 2).transpose(2, 3).contiguous().view(-1, channels)

        # Keep only rows whose mask value is exactly 1, re-shaped back to (K, C).
        keep = mask.view(positions, 1).repeat(1, channels) == 1.
        flat_pred = flat_pred[keep].view(-1, channels)
        flat_target = flat_target[keep].view(-1, channels)

        # Sum of per-position distances, normalized by the masked-in count.
        total = torch.sum(self.pdist(flat_pred, flat_target), 0)
        return total / mask.sum()
class CELoss(nn.Module):
    """Masked cross-entropy loss.

    Computes per-element cross entropy between ``pred`` (logits) and
    ``target`` (class indices), zeroes out entries where ``mask`` is 0, and
    normalizes by the number of masked-in entries (``mask.sum()``).
    """

    def __init__(self):
        super(CELoss, self).__init__()
        # reduction='none' keeps the per-element losses so the mask can be
        # applied before averaging. The old `reduce=False` flag has been
        # deprecated since PyTorch 0.4.1 and warns on every construction.
        self.celoss = nn.CrossEntropyLoss(reduction='none')

    def forward(self, pred, target, mask):
        # Per-element losses, weighted by the mask, averaged over the
        # masked-in count (NaN if mask is all zeros — same as before).
        loss = self.celoss(pred, target)
        return (loss * mask).sum() / mask.sum()