forked from FLming/CRNN.tf2
-
Notifications
You must be signed in to change notification settings - Fork 0
/
metrics.py
38 lines (33 loc) · 1.58 KB
/
metrics.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import tensorflow as tf
from tensorflow import keras
class WordAccuracy(keras.metrics.Metric):
    """Streaming word-level accuracy for CTC model outputs.

    A sample counts as correct only when its greedy CTC decoding matches the
    label sequence exactly (same tokens, same length). The metric accumulates
    the number of samples seen and the number of exact matches across batches.
    """

    def __init__(self, name='word_accuracy', **kwargs):
        """Create the two scalar accumulators.

        Args:
            name: Metric name reported by Keras.
            **kwargs: Forwarded to ``keras.metrics.Metric``.
        """
        super().__init__(name=name, **kwargs)
        # Number of samples seen so far.
        self.total = self.add_weight(name='total', initializer='zeros')
        # Number of samples whose decoded sequence matched the label exactly.
        self.count = self.add_weight(name='count', initializer='zeros')

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Accumulate exact-match statistics for one batch.

        Args:
            y_true: Rank-2 ``tf.SparseTensor`` of label ids, indexed as
                (batch, position). It must be sparse because it is passed to
                ``tf.sparse.reset_shape``.
            y_pred: Rank-3 dense tensor of logits indexed as
                (batch, time, classes); it is transposed to time-major for
                ``tf.nn.ctc_greedy_decoder``.
            sample_weight: Accepted for Keras API compatibility; ignored.
        """
        batch_size = tf.shape(y_true)[0]
        # Both tensors are padded to one common width so they can be compared
        # element-wise after densifying.
        max_width = tf.maximum(tf.shape(y_true)[1], tf.shape(y_pred)[1])
        # Every sample is decoded over the full time dimension.
        logit_length = tf.fill([tf.shape(y_pred)[0]], tf.shape(y_pred)[1])
        decoded, _ = tf.nn.ctc_greedy_decoder(
            inputs=tf.transpose(y_pred, perm=[1, 0, 2]),
            sequence_length=logit_length)

        y_true = tf.sparse.reset_shape(y_true, [batch_size, max_width])
        y_pred = tf.sparse.reset_shape(decoded[0], [batch_size, max_width])
        # -1 marks padding, so sequences of different length never compare
        # equal on the padded positions.
        y_true = tf.sparse.to_dense(y_true, default_value=-1)
        y_pred = tf.sparse.to_dense(y_pred, default_value=-1)
        # Compare in one common integer dtype: tf.math.not_equal rejects
        # mismatched dtypes, and labels may arrive as ints or as floats.
        y_true = tf.cast(y_true, tf.int64)
        y_pred = tf.cast(y_pred, tf.int64)

        # A sample is wrong if any position differs.
        mismatched = tf.math.reduce_any(
            tf.math.not_equal(y_true, y_pred), axis=1)
        num_wrong = tf.reduce_sum(tf.cast(mismatched, tf.float32))

        batch_size = tf.cast(batch_size, tf.float32)
        self.total.assign_add(batch_size)
        self.count.assign_add(batch_size - num_wrong)

    def result(self):
        """Return the accuracy so far, or 0 when no samples have been seen."""
        # divide_no_nan avoids a NaN (0 / 0) before the first update or right
        # after a reset.
        return tf.math.divide_no_nan(self.count, self.total)

    def reset_state(self):
        """Reset both accumulators to zero."""
        self.count.assign(0.0)
        self.total.assign(0.0)

    def reset_states(self):
        """Alias of ``reset_state`` for Keras versions that call this name."""
        self.reset_state()