Skip to content

Commit

Permalink
fixed SBVM norm_const
Browse files Browse the repository at this point in the history
  • Loading branch information
OlaRonning committed Nov 5, 2024
1 parent 6d5e508 commit 9f41e00
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 2 deletions.
10 changes: 8 additions & 2 deletions numpyro/distributions/directional.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import operator

from jax import lax
from jax.lax import stop_gradient
import jax.numpy as jnp
import jax.random as random
from jax.scipy import special
Expand Down Expand Up @@ -311,7 +312,9 @@ class SineBivariateVonMises(Distribution):
\frac{\rho}{\kappa_1\kappa_2} \rightarrow 1
because the distribution becomes increasingly bimodal. To avoid bimodality use the `weighted_correlation`
parameter with a skew away from one (e.g., Beta(1,3)). The `weighted_correlation` should be in [0,1].
parameter with a skew away from one (e.g.,
`TransformedDistribution(Beta(3,3), AffineTransform(loc=-1, scale=2))`). The `weighted_correlation`
should be in [-1,1].
.. note:: The correlation and weighted_correlation params are mutually exclusive.
Expand Down Expand Up @@ -404,7 +407,10 @@ def norm_const(self):
jnp.log(jnp.clip(corr**2, jnp.finfo(jnp.result_type(float)).tiny))
- jnp.log(4 * jnp.prod(conc, axis=-1))
)
fs += log_I1(49, conc, terms=51).sum(-1)
num_I1terms = jnp.maximum(
501, jnp.max(self.phi_concentration) + jnp.max(self.psi_concentration)
).astype(int)
fs += log_I1(49, conc, terms=stop_gradient(num_I1terms)).sum(-1)
norm_const = 2 * jnp.log(jnp.array(2 * pi)) + logsumexp(fs, 0)
return norm_const.reshape(jnp.shape(self.phi_loc))

Expand Down
13 changes: 13 additions & 0 deletions test/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3464,3 +3464,16 @@ def test_gaussian_random_walk_linear_recursive_equivalence():
x2 = dist2.sample(random.PRNGKey(7))
assert jnp.allclose(x1, x2.squeeze())
assert jnp.allclose(dist1.log_prob(x1), dist2.log_prob(x2))


@pytest.mark.parametrize("conc", [1.0, 10.0, 1000.0])
def test_sine_bivariate_von_mises_norm(conc):
dist = SineBivariateVonMises(0, 0, conc, conc, 0.0)
num_samples = 500
x = jnp.linspace(-jnp.pi, jnp.pi, num_samples)
y = jnp.linspace(-jnp.pi, jnp.pi, num_samples)
mesh = jnp.stack(jnp.meshgrid(x, y), axis=-1)
integral_torus = (
jnp.exp(dist.log_prob(mesh)) * (2 * jnp.pi) ** 2 / num_samples**2
).sum()
assert jnp.allclose(integral_torus, 1.0, rtol=1e-2)

0 comments on commit 9f41e00

Please sign in to comment.