Merge pull request #21 from calico/multinomial
avoid tensorflow reduce mean
davek44 authored Mar 28, 2024
2 parents 2e58034 + 74c7f78 commit 9625573
Showing 1 changed file with 30 additions and 15 deletions.
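
For context, the "reduce mean" in the commit message appears to be the reduction built into tf.keras.losses.poisson, which averages the per-element Poisson terms over the last axis before returning. The diff below swaps that call for a hand-written poisson helper that keeps the element-wise terms. A minimal sketch of the difference (the tensor shapes are made up for illustration):

import tensorflow as tf

def poisson(yt, yp, epsilon: float = 1e-7):
    # Element-wise Poisson loss, mirroring the helper added in this diff.
    return yp - yt * tf.math.log(yp + epsilon)

y_true = tf.random.uniform((2, 8, 4)) + 0.1  # B x L x T, illustrative shapes only
y_pred = tf.random.uniform((2, 8, 4)) + 0.1

print(tf.keras.losses.poisson(y_true, y_pred).shape)  # (2, 8): mean taken over the last axis
print(poisson(y_true, y_pred).shape)                  # (2, 8, 4): per-element terms preserved
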
45 changes: 30 additions & 15 deletions src/baskerville/metrics.py
@@ -29,7 +29,11 @@
 # Losses
 ################################################################################
 def mean_squared_error_udot(y_true, y_pred, udot_weight: float = 1):
-    """Mean squared error with mean-normalized specificity term."""
+    """Mean squared error with mean-normalized specificity term.
+
+    Args:
+      udot_weight: Weight of the mean-normalized specificity term.
+    """
     mse_term = tf.keras.losses.mean_squared_error(y_true, y_pred)
 
     yn_true = y_true - tf.math.reduce_mean(y_true, axis=-1, keepdims=True)
@@ -43,7 +47,7 @@ class MeanSquaredErrorUDot(LossFunctionWrapper):
     """Mean squared error with mean-normalized specificity term.
 
     Args:
-      udot_weight: Weight of the mean-normalized specificity term.
+      udot_weight: Weight of the mean-normalized specificity term.
     """
 
     def __init__(
@@ -59,7 +63,13 @@ def __init__(
         )
 
 
-def poisson_kl(y_true, y_pred, kl_weight=1, epsilon=1e-3):
+def poisson_kl(y_true, y_pred, kl_weight=1, epsilon=1e-7):
+    """Poisson decomposition with KL specificity term.
+
+    Args:
+      kl_weight (float): Weight of the KL specificity term.
+      epsilon (float): Added small value to avoid log(0).
+    """
     # poisson loss
     poisson_term = tf.keras.losses.poisson(y_true, y_pred)
 
@@ -96,36 +106,42 @@ def __init__(
         super(PoissonKL, self).__init__(pois_kl, name=name, reduction=reduction)
 
 
+def poisson(yt, yp, epsilon: float = 1e-7):
+    """Poisson loss, without mean reduction."""
+    return yp - yt * tf.math.log(yp + epsilon)
+
+
 def poisson_multinomial(
     y_true,
     y_pred,
     total_weight: float = 1,
-    epsilon: float = 1e-6,
+    epsilon: float = 1e-7,
     rescale: bool = False,
 ):
     """Poisson decomposition with multinomial specificity term.
 
     Args:
-      total_weight (float): Weight of the Poisson total term.
-      epsilon (float): Added small value to avoid log(0).
+      total_weight (float): Weight of the Poisson total term.
+      epsilon (float): Added small value to avoid log(0).
+      rescale (bool): Rescale loss after re-weighting.
     """
     seq_len = y_true.shape[1]
 
+    # add epsilon to protect against tiny values
+    y_true += epsilon
+    y_pred += epsilon
+
     # sum across lengths
     s_true = tf.math.reduce_sum(y_true, axis=-2, keepdims=True)
     s_pred = tf.math.reduce_sum(y_pred, axis=-2, keepdims=True)
+
+    # total count poisson loss, mean across targets
+    poisson_term = poisson(s_true, s_pred)  # B x T
+    poisson_term /= seq_len
 
-    # add epsilon to protect against tiny values
-    y_true += epsilon
-    y_pred += epsilon
-
     # normalize to sum to one
     p_pred = y_pred / s_pred
 
-    # total count poisson loss
-    poisson_term = tf.keras.losses.poisson(s_true, s_pred)  # B x T
-    poisson_term /= seq_len
-
     # multinomial loss
     pl_pred = tf.math.log(p_pred)  # B x L x T
     multinomial_dot = -tf.math.multiply(y_true, pl_pred)  # B x L x T
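
Spelled out, the reworked body above decomposes the per-target loss into a Poisson term on total counts and a multinomial term over positions. Schematically, for one target track over L positions (the reduction of the multinomial term and the total_weight/rescale handling sit in the portion of the function this hunk truncates):

\hat{s} = \sum_{l=1}^{L} \hat{y}_l, \qquad s = \sum_{l=1}^{L} y_l, \qquad \hat{p}_l = \hat{y}_l / \hat{s}
\text{poisson\_term} = \tfrac{1}{L}\,(\hat{s} - s \log \hat{s})
\text{multinomial\_dot}_l = -\, y_l \log \hat{p}_l
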
@@ -147,7 +163,6 @@ class PoissonMultinomial(LossFunctionWrapper):
     Args:
       total_weight (float): Weight of the Poisson total term.
       epsilon (float): Added small value to avoid log(0).
     """
 
     def __init__(
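
The wrapper classes are what get passed to Keras at compile time. A minimal usage sketch, assuming PoissonMultinomial forwards total_weight to poisson_multinomial the same way MeanSquaredErrorUDot and PoissonKL forward their weights; the toy model and shapes are made up, and the softplus keeps predictions positive as the loss requires:

import tensorflow as tf
from baskerville.metrics import PoissonMultinomial

# Toy model emitting B x L x T positive tracks; shapes are illustrative only.
model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(1024, 4)),
    tf.keras.layers.Conv1D(8, 5, padding="same", activation="softplus"),
])

# total_weight matches the Poisson-total weight documented above; the keyword
# name on the constructor is an assumption about the wrapper's signature.
model.compile(optimizer="adam", loss=PoissonMultinomial(total_weight=0.2))
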
