-
Notifications
You must be signed in to change notification settings - Fork 43
/
metrics.py
61 lines (52 loc) · 2.11 KB
/
metrics.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
'''
@Author: Zhou Kai
@GitHub: https://github.com/athon2
@Date: 2018-11-03 09:47:14
'''
import torch
from torch.nn.modules.loss import _Loss
import torch.nn as nn
class SoftDiceLoss(_Loss):
    '''
    Soft Dice loss computed over the whole input at once (no per-sample
    reduction):

        Soft_Dice = (2*|dot(A, B)| + eps) / (|dot(A, A)| + |dot(B, B)| + eps)
        Dice_loss = 1 - Soft_Dice

    eps is a small smoothing constant. It is added to BOTH numerator and
    denominator: this avoids zero division and also makes an all-zero
    prediction of an all-zero target score Dice = 1 (loss 0). With eps only in
    the denominator that perfect prediction would score Dice = 0 (loss 1).
    '''
    def __init__(self, *args, **kwargs):
        # Extra arguments are accepted for call-site compatibility but ignored.
        super(SoftDiceLoss, self).__init__()

    def forward(self, y_pred, y_true, eps=1e-8):
        '''
        Args:
            y_pred: predicted (soft) mask, multiplied elementwise with y_true.
            y_true: target mask.
            eps: smoothing constant added to numerator and denominator.

        Returns:
            Scalar tensor equal to ``1 - dice``.
        '''
        intersection = torch.sum(torch.mul(y_pred, y_true))
        denominator = torch.sum(torch.mul(y_pred, y_pred)) + torch.sum(torch.mul(y_true, y_true))
        dice = (2 * intersection + eps) / (denominator + eps)
        return 1 - dice
class CustomKLLoss(_Loss):
    '''
    KL-divergence-style regulariser for a latent (mean, std) estimate:

        KL_Loss = mean(mean^2) + mean(std^2) - mean(log(std^2)) - 1

    Each term is averaged over every element of its tensor, so the result is
    a scalar tensor. No absolute value is taken on the log term. No clamping
    is applied either: an std entry equal to 0 makes log(std^2) = -inf and
    the loss +inf.
    '''
    def __init__(self, *args, **kwargs):
        # Extra arguments are accepted for call-site compatibility but ignored.
        super().__init__()

    def forward(self, mean, std):
        variance = std * std
        mean_sq_term = (mean * mean).mean()
        variance_term = variance.mean()
        log_variance_term = torch.log(variance).mean()
        return mean_sq_term + variance_term - log_variance_term - 1
class CombinedLoss(_Loss):
    '''
    Weighted sum of a segmentation, a reconstruction and a latent-regularisation
    loss:

        Combined_loss = Dice_loss + k1 * L2_loss + k2 * KL_loss

    Defaults: k1=0.1, k2=0.1, latent_dim=128.

    Args:
        k1: weight of the L2 (MSE) reconstruction term.
        k2: weight of the KL term.
        latent_dim: number of leading channels (dim 1) of ``y_mid`` that hold
            the estimated mean; all remaining channels are treated as the std.
    '''
    def __init__(self, k1=0.1, k2=0.1, latent_dim=128):
        super(CombinedLoss, self).__init__()
        self.k1 = k1
        self.k2 = k2
        self.latent_dim = latent_dim
        self.dice_loss = SoftDiceLoss()
        self.l2_loss = nn.MSELoss()
        self.kl_loss = CustomKLLoss()

    def forward(self, y_pred, y_true, y_mid):
        '''
        Args:
            y_pred, y_true: tensors with a channel axis at dim 1. Channel 0 is
                the segmentation map; channels 1: are the reconstruction.
                Presumably (batch, channels, D, H, W) -- TODO confirm with caller.
            y_mid: tensor whose dim 1 is [mean | std], split at ``latent_dim``.

        Returns:
            Scalar tensor with the weighted total loss.
        '''
        est_mean = y_mid[:, :self.latent_dim]
        est_std = y_mid[:, self.latent_dim:]
        # Indexing only dim 1 is identical to the former [:,0,:,:,:] slicing
        # for 5-D input, and also works for inputs of other ranks.
        seg_pred, seg_truth = y_pred[:, 0], y_true[:, 0]
        vae_pred, vae_truth = y_pred[:, 1:], y_true[:, 1:]
        dice_loss = self.dice_loss(seg_pred, seg_truth)
        l2_loss = self.l2_loss(vae_pred, vae_truth)
        kl_div = self.kl_loss(est_mean, est_std)
        return dice_loss + self.k1 * l2_loss + self.k2 * kl_div