diff --git a/numpyro/distributions/directional.py b/numpyro/distributions/directional.py index e3e5dd3ee..20b77ece9 100644 --- a/numpyro/distributions/directional.py +++ b/numpyro/distributions/directional.py @@ -8,7 +8,6 @@ import operator from jax import lax -from jax.lax import stop_gradient import jax.numpy as jnp import jax.random as random from jax.scipy import special @@ -309,11 +308,11 @@ class SineBivariateVonMises(Distribution): .. note:: Sample efficiency drops as .. math:: - \frac{\rho}{\kappa_1\kappa_2} \rightarrow 1 + \frac{\rho}{\sqrt(\kappa_1\kappa_2)} \rightarrow 1 - because the distribution becomes increasingly bimodal. To avoid bimodality use the `weighted_correlation` - parameter with a skew away from one (e.g., - `TransformedDistribution(Beta(3,3), AffineTransform(loc=-1, scale=2))`). The `weighted_correlation` + because the distribution becomes increasingly bimodal. To avoid inefficient sampling use the + `weighted_correlation` parameter with a skew away from one (e.g., + `TransformedDistribution(Beta(5,5), AffineTransform(loc=-1, scale=2))`). The `weighted_correlation` should be in [-1,1]. .. note:: The correlation and weighted_correlation params are mutually exclusive. @@ -407,10 +406,8 @@ def norm_const(self): jnp.log(jnp.clip(corr**2, jnp.finfo(jnp.result_type(float)).tiny)) - jnp.log(4 * jnp.prod(conc, axis=-1)) ) - num_I1terms = jnp.maximum( - 501, jnp.max(self.phi_concentration) + jnp.max(self.psi_concentration) - ).astype(int) - fs += log_I1(49, conc, terms=stop_gradient(num_I1terms)).sum(-1) + num_I1terms = 10_001 + fs += log_I1(49, conc, terms=num_I1terms).sum(-1) norm_const = 2 * jnp.log(jnp.array(2 * pi)) + logsumexp(fs, 0) return norm_const.reshape(jnp.shape(self.phi_loc)) diff --git a/test/test_distributions.py b/test/test_distributions.py index 1e2b3be05..54235cdaa 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -3466,7 +3466,7 @@ def test_gaussian_random_walk_linear_recursive_equivalence(): assert jnp.allclose(dist1.log_prob(x1), dist2.log_prob(x2)) -@pytest.mark.parametrize("conc", [1.0, 10.0, 1000.0]) +@pytest.mark.parametrize("conc", [1.0, 10.0, 1000.0, 10000.0]) def test_sine_bivariate_von_mises_norm(conc): dist = SineBivariateVonMises(0, 0, conc, conc, 0.0) num_samples = 500