Skip to content

Commit

Permalink
fix continous poisson log prob
Browse files Browse the repository at this point in the history
  • Loading branch information
pfackeldey committed May 10, 2024
1 parent b950146 commit 23d4d72
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions src/evermore/pdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import equinox as eqx
import jax
import jax.numpy as jnp
from jax.scipy.special import gammaln, xlogy
from jaxtyping import Array, PRNGKeyArray

__all__ = [
Expand Down Expand Up @@ -46,8 +47,11 @@ class Poisson(PDF):
lamb: Array = eqx.field(converter=jnp.atleast_1d)

def log_prob(self, x: Array) -> Array:
logpdf_max = jax.scipy.stats.poisson.logpmf(self.lamb, mu=self.lamb)
unnormalized = jax.scipy.stats.poisson.logpmf((x + 1) * self.lamb, mu=self.lamb)
def _continous_poisson_log_prob(x, lamb):
return xlogy(x, lamb) - lamb - gammaln(x + 1)

logpdf_max = _continous_poisson_log_prob(self.lamb, self.lamb)
unnormalized = _continous_poisson_log_prob((x + 1) * self.lamb, self.lamb)
return unnormalized - logpdf_max

def sample(self, key: PRNGKeyArray) -> Array:
Expand Down

0 comments on commit 23d4d72

Please sign in to comment.