diff --git a/numpyro/distributions/directional.py b/numpyro/distributions/directional.py index 8156855b1..e3e5dd3ee 100644 --- a/numpyro/distributions/directional.py +++ b/numpyro/distributions/directional.py @@ -8,6 +8,7 @@ import operator from jax import lax +from jax.lax import stop_gradient import jax.numpy as jnp import jax.random as random from jax.scipy import special @@ -311,7 +312,9 @@ class SineBivariateVonMises(Distribution): \frac{\rho}{\kappa_1\kappa_2} \rightarrow 1 because the distribution becomes increasingly bimodal. To avoid bimodality use the `weighted_correlation` - parameter with a skew away from one (e.g., Beta(1,3)). The `weighted_correlation` should be in [0,1]. + parameter with a skew away from one (e.g., + `TransformedDistribution(Beta(3,3), AffineTransform(loc=-1, scale=2))`). The `weighted_correlation` + should be in [-1,1]. .. note:: The correlation and weighted_correlation params are mutually exclusive. @@ -404,7 +407,10 @@ def norm_const(self): jnp.log(jnp.clip(corr**2, jnp.finfo(jnp.result_type(float)).tiny)) - jnp.log(4 * jnp.prod(conc, axis=-1)) ) - fs += log_I1(49, conc, terms=51).sum(-1) + num_I1terms = jnp.maximum( + 501, jnp.max(self.phi_concentration) + jnp.max(self.psi_concentration) + ).astype(int) + fs += log_I1(49, conc, terms=stop_gradient(num_I1terms)).sum(-1) norm_const = 2 * jnp.log(jnp.array(2 * pi)) + logsumexp(fs, 0) return norm_const.reshape(jnp.shape(self.phi_loc)) diff --git a/test/test_distributions.py b/test/test_distributions.py index 20da4165d..1e2b3be05 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -3464,3 +3464,16 @@ def test_gaussian_random_walk_linear_recursive_equivalence(): x2 = dist2.sample(random.PRNGKey(7)) assert jnp.allclose(x1, x2.squeeze()) assert jnp.allclose(dist1.log_prob(x1), dist2.log_prob(x2)) + + +@pytest.mark.parametrize("conc", [1.0, 10.0, 1000.0]) +def test_sine_bivariate_von_mises_norm(conc): + dist = SineBivariateVonMises(0, 0, conc, conc, 0.0) + num_samples = 500 + x = jnp.linspace(-jnp.pi, jnp.pi, num_samples) + y = jnp.linspace(-jnp.pi, jnp.pi, num_samples) + mesh = jnp.stack(jnp.meshgrid(x, y), axis=-1) + integral_torus = ( + jnp.exp(dist.log_prob(mesh)) * (2 * jnp.pi) ** 2 / num_samples**2 + ).sum() + assert jnp.allclose(integral_torus, 1.0, rtol=1e-2)