diff --git a/numpyro/distributions/kl.py b/numpyro/distributions/kl.py
index 3d79dca01c..cffa779c2e 100644
--- a/numpyro/distributions/kl.py
+++ b/numpyro/distributions/kl.py
@@ -27,9 +27,10 @@
 
 from multipledispatch import dispatch
 
-from jax import lax
+from jax import lax, vmap
 import jax.numpy as jnp
+from jax.scipy.linalg import solve_triangular
 from jax.scipy.special import betaln, digamma, gammaln
 
 from numpyro.distributions.continuous import (
     Beta,
@@ -37,6 +38,7 @@
     Gamma,
     Kumaraswamy,
+    MultivariateNormal,
     Normal,
     Weibull,
 )
 from numpyro.distributions.discrete import CategoricalProbs
@@ -134,6 +136,70 @@ def kl_divergence(p, q):
     return 0.5 * (var_ratio + t1 - 1 - jnp.log(var_ratio))
 
 
+@dispatch(MultivariateNormal, MultivariateNormal)
+def kl_divergence(p: MultivariateNormal, q: MultivariateNormal):
+    # cf. https://statproofbook.github.io/P/mvn-kl.html
+    # Closed form: for p = N(mu_p, Sigma_p) and q = N(mu_q, Sigma_q),
+    #   KL(p || q) = 0.5 * (tr(Sigma_q^-1 Sigma_p)
+    #                       + (mu_q - mu_p)^T Sigma_q^-1 (mu_q - mu_p)
+    #                       - D + log det Sigma_q - log det Sigma_p)
+
+    if p.event_shape != q.event_shape:
+        raise ValueError(
+            "Distributions must have the same event shape, but they are"
+            f" {p.event_shape} and {q.event_shape} for p and q, respectively."
+        )
+
+    if p.batch_shape != q.batch_shape:
+        raise ValueError(
+            "Distributions must have the same batch shape, but they are"
+            f" {p.batch_shape} and {q.batch_shape} for p and q, respectively."
+        )
+
+    assert len(p.event_shape) == 1, "event_shape must be one-dimensional"
+    D = p.event_shape[0]
+
+    assert p.mean.shape == p.batch_shape + p.event_shape
+    assert q.mean.shape == p.mean.shape
+
+    def _single_mvn_kl(p_mean, p_scale_tril, q_mean, q_scale_tril):
+        assert jnp.ndim(p_mean) == 1
+        assert jnp.ndim(q_mean) == 1
+        assert jnp.ndim(p_scale_tril) == 2
+        assert jnp.ndim(q_scale_tril) == 2
+
+        # log det Sigma = 2 * sum(log(diag(L))) for Sigma = L @ L.T.
+        p_half_log_det = jnp.log(jnp.diagonal(p_scale_tril)).sum(-1)
+        q_half_log_det = jnp.log(jnp.diagonal(q_scale_tril)).sum(-1)
+        log_det_ratio = 2 * (p_half_log_det - q_half_log_det)
+
+        Lq_inv = solve_triangular(q_scale_tril, jnp.eye(D), lower=True)
+
+        # tr(Sigma_q^-1 Sigma_p) equals the squared Frobenius norm of Lq^-1 @ Lp.
+        M = Lq_inv @ p_scale_tril
+        tr = jnp.sum(M * M)
+
+        # Mahalanobis term (mu_p - mu_q)^T Sigma_q^-1 (mu_p - mu_q).
+        z = Lq_inv @ (p_mean - q_mean)
+        t1 = jnp.dot(z, z)
+
+        return 0.5 * (tr + t1 - D - log_det_ratio)
+
+    # Flatten the batch dimensions, vmap over them, then restore the batch shape.
+    p_mean_flat = jnp.reshape(p.mean, (-1, D))
+    p_scale_tril_flat = jnp.reshape(p.scale_tril, (-1, D, D))
+
+    q_mean_flat = jnp.reshape(q.mean, (-1, D))
+    q_scale_tril_flat = jnp.reshape(q.scale_tril, (-1, D, D))
+
+    kl_flat = vmap(_single_mvn_kl)(
+        p_mean_flat, p_scale_tril_flat, q_mean_flat, q_scale_tril_flat
+    )
+    assert jnp.ndim(kl_flat) == 1
+
+    return jnp.reshape(kl_flat, p.batch_shape)
+
+
 @dispatch(Beta, Beta)
 def kl_divergence(p, q):
     # From https://en.wikipedia.org/wiki/Beta_distribution#Quantities_of_information_(entropy)
diff --git a/test/test_distributions.py b/test/test_distributions.py
index 390880cac9..01e95b41fa 100644
--- a/test/test_distributions.py
+++ b/test/test_distributions.py
@@ -2849,6 +2849,48 @@ def test_kl_expanded_normal(batch_shape, event_shape):
     assert_allclose(actual, expected)
 
 
+@pytest.mark.parametrize("batch_shape", [(), (1,), (2, 3)], ids=str)
+def test_kl_multivariate_normal_consistency_with_independent_normals(batch_shape):
+    event_shape = (5,)
+    shape = batch_shape + event_shape
+
+    def make_dists():
+        mus = np.random.normal(size=shape)
+        scales = np.exp(np.random.normal(size=shape))
+
+        def diagonalize(v, ignore_axes: int):
+            # Map jnp.diag over the leading batch axes.
+            if ignore_axes == 0:
+                return jnp.diag(v)
+            return vmap(diagonalize, in_axes=(0, None))(v, ignore_axes - 1)
+
+        scale_tril = diagonalize(scales, len(batch_shape))
+        return (
+            dist.Normal(mus, scales).to_event(len(event_shape)),
+            dist.MultivariateNormal(mus, scale_tril=scale_tril),
+        )
+
+    p_uni, p_mvn = make_dists()
+    q_uni, q_mvn = make_dists()
+
+    actual = kl_divergence(p_mvn, q_mvn)
+    expected = kl_divergence(p_uni, q_uni)
+    assert_allclose(actual, expected, rtol=1e-5, atol=1e-6)
+
+
+def test_kl_multivariate_normal_nondiagonal_covariance():
+    p_mvn = dist.MultivariateNormal(np.zeros(2), covariance_matrix=np.eye(2))
+    q_mvn = dist.MultivariateNormal(
+        np.ones(2),
+        covariance_matrix=np.array([[2.0, 0.8], [0.8, 0.5]]),
+    )
+
+    actual = kl_divergence(p_mvn, q_mvn)
+    # Reference value from the closed-form expression.
+    expected = 3.21138
+    assert_allclose(actual, expected, atol=2e-5)
+
+
 @pytest.mark.parametrize("shape", [(), (4,), (2, 3)], ids=str)
 @pytest.mark.parametrize(
     "p_dist, q_dist",
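
Note (not part of the diff): a quick way to sanity-check the closed form implemented above is to compare it against a Monte Carlo estimate of E_p[log p(x) - log q(x)], using the same distributions as test_kl_multivariate_normal_nondiagonal_covariance. The sketch below assumes the patched numpyro is importable; the sample count of 200,000 is an arbitrary choice.

    import jax
    import jax.numpy as jnp
    import numpy as np

    import numpyro.distributions as dist
    from numpyro.distributions.kl import kl_divergence

    p = dist.MultivariateNormal(np.zeros(2), covariance_matrix=np.eye(2))
    q = dist.MultivariateNormal(
        np.ones(2), covariance_matrix=np.array([[2.0, 0.8], [0.8, 0.5]])
    )

    # KL(p || q) = E_p[log p(x) - log q(x)], estimated from samples of p.
    samples = p.sample(jax.random.PRNGKey(0), (200_000,))
    mc_estimate = jnp.mean(p.log_prob(samples) - q.log_prob(samples))

    print(mc_estimate)          # ~3.21 up to Monte Carlo error
    print(kl_divergence(p, q))  # closed form, ~3.21138

The Monte Carlo estimate converges at a rate of O(1/sqrt(n)), so agreement to two or three decimal places is the most one should expect here.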