-
Notifications
You must be signed in to change notification settings - Fork 13
/
loss.py
56 lines (44 loc) · 2.21 KB
/
loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import tensorflow as tf
from tensorflow.keras import backend as K
def Tanimoto_loss(label, pred):
    """
    Class-weighted Tanimoto coefficient in tensorflow 2.x
    -------------------------------------------------------------------------
    Tanimoto coefficient with dual from: Diakogiannis et al 2019 (https://arxiv.org/abs/1904.00592)

    Computes, per entry of the leading axis,

        sum_c w_c * sum(pred * label)
        ------------------------------------------------------
        sum_c w_c * sum(pred**2 + label**2 - pred * label)

    where the inner sums run over axes 1 and 2, ``c`` indexes the last axis
    and ``w_c`` is the inverse squared mean volume of ``label`` for class
    ``c``.  A small constant is added to numerator and denominator.

    NOTE: despite the name, the returned value grows with the overlap between
    ``label`` and ``pred``; subtract it from 1 to obtain a quantity to
    minimise during training.

    Args:
        label: ground-truth tensor; the class weights are derived from it.
            Presumably shaped (batch, height, width, classes) -- TODO confirm
            against callers.
        pred: prediction tensor with the same shape as ``label``.

    Returns:
        Tensor holding the weighted coefficient, reduced over axes 1, 2 and
        the last axis.
    """
    eps = 1e-5
    spatial_axes = [1, 2]

    # Per-class weights: reciprocal of the squared batch-averaged label volume.
    class_volume = tf.reduce_mean(tf.reduce_sum(label, axis=spatial_axes), axis=0)
    raw_weights = tf.math.reciprocal(tf.pow(class_volume, 2))

    # ---------------------This step is taken from niftyNet package --------------
    # ref: https://github.com/NifTK/NiftyNet/blob/dev/niftynet/layer/loss_segmentation.py, lines:170 -- 172
    # A class with zero label volume gets an infinite weight; replace each of
    # those by the largest finite weight found among the other classes.
    inf_mask = tf.math.is_inf(raw_weights)
    finite_only = tf.where(inf_mask, tf.zeros_like(raw_weights), raw_weights)
    largest_finite = tf.ones_like(raw_weights) * tf.reduce_max(finite_only)
    class_weights = tf.where(inf_mask, largest_finite, raw_weights)
    # --------------------------------------------------------------------

    # Spatial reductions of pred**2 + label**2 and of pred * label.
    squares_total = tf.reduce_sum(
        tf.add(tf.square(pred), tf.square(label)), axis=spatial_axes
    )
    overlap_total = tf.reduce_sum(tf.multiply(pred, label), axis=spatial_axes)

    # Weighted sums over the class axis for numerator and denominator.
    weighted_overlap = tf.reduce_sum(class_weights * overlap_total, axis=-1)
    weighted_denominator = tf.reduce_sum(
        class_weights * (squares_total - overlap_total), axis=-1
    )

    return (weighted_overlap + eps) / (weighted_denominator + eps)
def Tanimoto_dual_loss():
    """
    Implementation of Tanimoto dual loss in tensorflow 2.x
    ------------------------------------------------------------------------
    Factory returning a ``loss(label, pred)`` callable that evaluates

        1 - 0.5 * (T(label, pred) + T(1 - label, 1 - pred))

    where ``T`` is ``Tanimoto_loss``.  The ``1 - ...`` form is required for
    training: ``Tanimoto_loss`` on its own increases with the overlap between
    label and prediction, so it cannot be minimised directly.

    Returns:
        A function taking ``(label, pred)`` and returning the dual Tanimoto
        loss with the same shape as the output of ``Tanimoto_loss``.
    """
    def loss(label, pred):
        # Bring the ground truth to the prediction dtype so that integer
        # masks can be complemented and multiplied with float predictions.
        pred = tf.convert_to_tensor(pred)
        label = tf.cast(label, pred.dtype)

        # The ground truth must be the first argument in BOTH calls:
        # Tanimoto_loss derives its class weights from its first argument,
        # so swapping the order would weight classes by predicted volume.
        direct = Tanimoto_loss(label, pred)

        # Same coefficient evaluated on the complements of both tensors.
        complement = Tanimoto_loss(tf.subtract(1.0, label), tf.subtract(1.0, pred))

        mean_coefficient = (direct + complement) * 0.5
        return 1.0 - mean_coefficient
    return loss