From de8aeaf0a34e78523b14a3745f22fd32e6f29a61 Mon Sep 17 00:00:00 2001 From: David Kelley Date: Sat, 16 Mar 2024 16:07:23 -0700 Subject: [PATCH 1/2] avoid tensorflow reduce mean --- src/baskerville/metrics.py | 45 +++++++++++++++++++++++++------------- 1 file changed, 30 insertions(+), 15 deletions(-) diff --git a/src/baskerville/metrics.py b/src/baskerville/metrics.py index 29c0d99..dac22bb 100644 --- a/src/baskerville/metrics.py +++ b/src/baskerville/metrics.py @@ -29,7 +29,11 @@ # Losses ################################################################################ def mean_squared_error_udot(y_true, y_pred, udot_weight: float = 1): - """Mean squared error with mean-normalized specificity term.""" + """Mean squared error with mean-normalized specificity term. + + Args: + udot_weight: Weight of the mean-normalized specificity term. + """ mse_term = tf.keras.losses.mean_squared_error(y_true, y_pred) yn_true = y_true - tf.math.reduce_mean(y_true, axis=-1, keepdims=True) @@ -43,7 +47,7 @@ class MeanSquaredErrorUDot(LossFunctionWrapper): """Mean squared error with mean-normalized specificity term. Args: - udot_weight: Weight of the mean-normalized specificity term. + udot_weight: Weight of the mean-normalized specificity term. """ def __init__( @@ -59,7 +63,13 @@ def __init__( ) -def poisson_kl(y_true, y_pred, kl_weight=1, epsilon=1e-3): +def poisson_kl(y_true, y_pred, kl_weight=1, epsilon=1e-7): + """Poisson decomposition with KL specificity term. + + Args: + kl_weight (float): Weight of the KL specificity term. + epsilon (float): Added small value to avoid log(0). + """ # poisson loss poisson_term = tf.keras.losses.poisson(y_true, y_pred) @@ -96,36 +106,42 @@ def __init__( super(PoissonKL, self).__init__(pois_kl, name=name, reduction=reduction) +def poisson(yt, yp, epsilon: float = 1e-7): + """Poisson loss, without mean reduction.""" + return yp - yt * tf.math.log(yp + epsilon) + + def poisson_multinomial( y_true, y_pred, total_weight: float = 1, - epsilon: float = 1e-6, + epsilon: float = 1e-7, rescale: bool = False, ): """Possion decomposition with multinomial specificity term. Args: - total_weight (float): Weight of the Poisson total term. - epsilon (float): Added small value to avoid log(0). + total_weight (float): Weight of the Poisson total term. + epsilon (float): Added small value to avoid log(0). + rescale (bool): Rescale loss after re-weighting. """ seq_len = y_true.shape[1] - # add epsilon to protect against tiny values - y_true += epsilon - y_pred += epsilon - # sum across lengths s_true = tf.math.reduce_sum(y_true, axis=-2, keepdims=True) s_pred = tf.math.reduce_sum(y_pred, axis=-2, keepdims=True) + # total count poisson loss, mean across targets + poisson_term = poisson(s_true, s_pred) # B x T + poisson_term /= seq_len + + # add epsilon to protect against tiny values + y_true += epsilon + y_pred += epsilon + # normalize to sum to one p_pred = y_pred / s_pred - # total count poisson loss - poisson_term = tf.keras.losses.poisson(s_true, s_pred) # B x T - poisson_term /= seq_len - # multinomial loss pl_pred = tf.math.log(p_pred) # B x L x T multinomial_dot = -tf.math.multiply(y_true, pl_pred) # B x L x T @@ -147,7 +163,6 @@ class PoissonMultinomial(LossFunctionWrapper): Args: total_weight (float): Weight of the Poisson total term. - epsilon (float): Added small value to avoid log(0). """ def __init__( From 74c7f782b5fd6bd32a665db945030dea51576064 Mon Sep 17 00:00:00 2001 From: lruizcalico Date: Sun, 17 Mar 2024 11:43:43 -0700 Subject: [PATCH 2/2] black format --- src/baskerville/metrics.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/baskerville/metrics.py b/src/baskerville/metrics.py index dac22bb..8aa549f 100644 --- a/src/baskerville/metrics.py +++ b/src/baskerville/metrics.py @@ -30,7 +30,7 @@ ################################################################################ def mean_squared_error_udot(y_true, y_pred, udot_weight: float = 1): """Mean squared error with mean-normalized specificity term. - + Args: udot_weight: Weight of the mean-normalized specificity term. """