forked from FLming/CRNN.tf2
-
Notifications
You must be signed in to change notification settings - Fork 0
/
losses.py
22 lines (19 loc) · 816 Bytes
/
losses.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
import tensorflow as tf
from tensorflow import keras
class CTCLoss(keras.losses.Loss):
    """Keras loss that wraps ``tf.nn.ctc_loss``.

    Attributes:
        logits_time_major: Layout flag forwarded to ``tf.nn.ctc_loss``. If
            False (default) ``y_pred`` is read as ``[batch, time, classes]``;
            if True it is read as ``[time, batch, classes]``.
        blank_index: Class index of the CTC blank label, forwarded to
            ``tf.nn.ctc_loss``. Defaults to -1.
    """

    def __init__(self, logits_time_major=False, blank_index=-1,
                 reduction=keras.losses.Reduction.AUTO, name='ctc_loss'):
        """Store the CTC options and forward reduction/name to the base class.

        Args:
            logits_time_major: Whether ``y_pred`` is laid out time-major.
            blank_index: Index used for the blank label.
            reduction: Reduction mode handed to ``keras.losses.Loss``.
            name: Name handed to ``keras.losses.Loss``.
        """
        super().__init__(reduction=reduction, name=name)
        self.logits_time_major = logits_time_major
        self.blank_index = blank_index

    def call(self, y_true, y_pred):
        """Compute the mean CTC loss over the batch.

        Args:
            y_true: Labels; cast to int32 and passed to ``tf.nn.ctc_loss``
                with ``label_length=None``.
                NOTE(review): that combination presumably requires a
                ``tf.SparseTensor`` -- confirm against the data pipeline.
            y_pred: Logits, laid out according to ``logits_time_major``.

        Returns:
            A scalar tensor: the per-sample CTC losses averaged with
            ``tf.reduce_mean``.
        """
        y_true = tf.cast(y_true, tf.int32)

        # The batch/time axes swap when logits are time-major; reading them
        # from fixed positions would yield a logit_length of the wrong shape
        # (one entry per time step, filled with the batch size).
        if self.logits_time_major:
            time_axis, batch_axis = 0, 1
        else:
            time_axis, batch_axis = 1, 0

        logits_shape = tf.shape(y_pred)
        # Every sample is treated as using the full time dimension.
        logit_length = tf.fill([logits_shape[batch_axis]],
                               logits_shape[time_axis])

        loss = tf.nn.ctc_loss(
            labels=y_true,
            logits=y_pred,
            label_length=None,
            logit_length=logit_length,
            logits_time_major=self.logits_time_major,
            blank_index=self.blank_index)
        # The loss is already averaged here, so the value handed back to the
        # base class is a scalar rather than a per-sample vector.
        return tf.reduce_mean(loss)