From ae49da23db4f126e4ff6c3cd0f154509ded924bc Mon Sep 17 00:00:00 2001 From: Qazalbash Date: Wed, 2 Oct 2024 20:26:20 +0500 Subject: [PATCH] Refactor log_prob method in _MixtureBase class to handle negative infinity values in sum_log_probs --- numpyro/distributions/mixtures.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/numpyro/distributions/mixtures.py b/numpyro/distributions/mixtures.py index 21b00d4a8..19a821c89 100644 --- a/numpyro/distributions/mixtures.py +++ b/numpyro/distributions/mixtures.py @@ -149,7 +149,10 @@ def sample(self, key, sample_shape=()): def log_prob(self, value, intermediates=None): del intermediates sum_log_probs = self.component_log_probs(value) - return jax.nn.logsumexp(sum_log_probs, axis=-1) + safe_sum_log_probs = jnp.where( + jnp.isneginf(sum_log_probs), -jnp.inf, sum_log_probs + ) + return jax.nn.logsumexp(safe_sum_log_probs, axis=-1) class MixtureSameFamily(_MixtureBase):