-
Notifications
You must be signed in to change notification settings - Fork 26
/
losses.py
27 lines (19 loc) · 885 Bytes
/
losses.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import tensorflow as tf
def sigmoid_cross_entropy_balanced(logits, label, name='cross_entropy_loss'):
    """Class-balanced sigmoid cross-entropy (Equation [2] of HED).

    Implements Equation [2] in https://arxiv.org/pdf/1504.06375.pdf:

        loss = -beta * sum_{y=1} log(sigmoid(x))
               - (1 - beta) * sum_{y=0} log(1 - sigmoid(x)),
        beta = count_neg / (count_neg + count_pos)

    beta is the fraction of negative (non-edge) mass in ``label``.
    Unlike the paper, the per-element terms are averaged over all elements
    (``reduce_mean``) rather than summed.

    This is mathematically equal to
    ``(1 - beta) * reduce_mean(weighted_cross_entropy_with_logits(
    pos_weight=beta / (1 - beta)))``. It is evaluated without forming
    ``pos_weight``, because that value is infinite when ``label`` has no
    positive elements and would produce NaN gradients.

    Args:
        logits: Unscaled predictions; a floating-point tensor (or a value
            convertible to one).
        label: Ground truth with the same static shape as ``logits``; it is
            cast to ``logits.dtype``. Presumably binary edge maps with
            values in {0, 1} -- confirm against callers.
        name: Name given to the returned op.

    Returns:
        A scalar tensor holding the balanced loss, or 0 when ``label`` has
        no positive elements.

    Raises:
        ValueError: If the static shapes of ``logits`` and ``label`` are
            incompatible.
    """
    logits = tf.convert_to_tensor(logits)
    y = tf.cast(label, logits.dtype)
    # Preserve the shape check that the TF cross-entropy ops perform; the
    # element-wise products below would otherwise broadcast silently.
    logits.get_shape().assert_is_compatible_with(y.get_shape())

    count_neg = tf.reduce_sum(1. - y)
    count_pos = tf.reduce_sum(y)
    # Equation [2]: beta is the fraction of negative (non-edge) elements.
    beta = count_neg / (count_neg + count_pos)

    # -log(sigmoid(x)) == softplus(-x) and -log(1 - sigmoid(x)) == softplus(x);
    # softplus stays numerically stable for large |x|.
    pos_loss = y * tf.nn.softplus(-logits)
    neg_loss = (1. - y) * tf.nn.softplus(logits)
    # Weight the two terms by beta and (1 - beta) directly. The alternative,
    # pos_weight = beta / (1 - beta), is inf when count_pos == 0; inf * 0 then
    # gives NaN inside the loss, and tf.where does not stop that NaN from
    # reaching the gradient of logits.
    cost = tf.reduce_mean(beta * pos_loss + (1. - beta) * neg_loss)

    # An image with no edge pixels contributes zero loss.
    # zeros_like keeps the branch dtype equal to cost's dtype.
    return tf.where(tf.equal(count_pos, 0.0), tf.zeros_like(cost), cost,
                    name=name)