diff --git a/numpyro/distributions/continuous.py b/numpyro/distributions/continuous.py
index 5df46dade..f5b7b81e6 100644
--- a/numpyro/distributions/continuous.py
+++ b/numpyro/distributions/continuous.py
@@ -1276,6 +1276,16 @@ def _batch_solve_triangular(A, B):
     return X
 
 
+def _batch_trace_from_cholesky(L):
+    """Computes the trace of matrix X given its Cholesky decomposition matrix L.
+
+    :param jnp.ndarray(..., M, M) L: An array with lower triangular structure in the last two dimensions.
+
+    :return: Trace of X, where X = L L^T
+    """
+    return jnp.square(L).sum((-1, -2))
+
+
 class MatrixNormal(Distribution):
     """
     Matrix variate normal distribution as described in [1] but with a lower_triangular parametrization,
@@ -1358,9 +1368,7 @@ def log_prob(self, values):
         diff_col_solve = _batch_solve_triangular(
             A=self.scale_tril_column, B=jnp.swapaxes(diff_row_solve, -2, -1)
         )
-        batched_trace_term = jnp.square(
-            diff_col_solve.reshape(diff_col_solve.shape[:-2] + (-1,))
-        ).sum(-1)
+        batched_trace_term = _batch_trace_from_cholesky(diff_col_solve)
 
         log_prob = -0.5 * batched_trace_term - log_det_term
diff --git a/numpyro/distributions/kl.py b/numpyro/distributions/kl.py
index 3d79dca01..1f43ba5ae 100644
--- a/numpyro/distributions/kl.py
+++ b/numpyro/distributions/kl.py
@@ -36,8 +36,11 @@
     Dirichlet,
     Gamma,
     Kumaraswamy,
+    MultivariateNormal,
     Normal,
     Weibull,
+    _batch_solve_triangular,
+    _batch_trace_from_cholesky,
 )
 from numpyro.distributions.discrete import CategoricalProbs
 from numpyro.distributions.distribution import (
@@ -134,6 +137,52 @@ def kl_divergence(p, q):
     return 0.5 * (var_ratio + t1 - 1 - jnp.log(var_ratio))
 
 
+@dispatch(MultivariateNormal, MultivariateNormal)
+def kl_divergence(p: MultivariateNormal, q: MultivariateNormal):
+    # cf https://statproofbook.github.io/P/mvn-kl.html
+
+    def _shapes_are_broadcastable(first_shape, second_shape):
+        try:
+            jnp.broadcast_shapes(first_shape, second_shape)
+            return True
+        except ValueError:
+            return False
+
+    if p.event_shape != q.event_shape:
+        raise ValueError(
+            "Distributions must have the same event shape, but are"
+            f" {p.event_shape} and {q.event_shape} for p and q, respectively."
+        )
+
+    try:
+        result_batch_shape = jnp.broadcast_shapes(p.batch_shape, q.batch_shape)
+    except ValueError as ve:
+        raise ValueError(
+            "Distributions must have broadcastable batch shapes, "
+            f"but have {p.batch_shape} and {q.batch_shape} for p and q, "
+            "respectively."
+ ) from ve + + assert len(p.event_shape) == 1, "event_shape must be one-dimensional" + D = p.event_shape[0] + + p_half_log_det = jnp.log(jnp.diagonal(p.scale_tril, axis1=-2, axis2=-1)).sum(-1) + q_half_log_det = jnp.log(jnp.diagonal(q.scale_tril, axis1=-2, axis2=-1)).sum(-1) + + log_det_ratio = 2 * (p_half_log_det - q_half_log_det) + assert _shapes_are_broadcastable(log_det_ratio.shape, result_batch_shape) + + Lq_inv = _batch_solve_triangular(q.scale_tril, jnp.eye(D)) + + tr = _batch_trace_from_cholesky(Lq_inv @ p.scale_tril) + assert _shapes_are_broadcastable(tr.shape, result_batch_shape) + + t1 = jnp.square(Lq_inv @ (p.loc - q.loc)[..., jnp.newaxis]).sum((-2, -1)) + assert _shapes_are_broadcastable(t1.shape, result_batch_shape) + + return 0.5 * (tr + t1 - D - log_det_ratio) + + @dispatch(Beta, Beta) def kl_divergence(p, q): # From https://en.wikipedia.org/wiki/Beta_distribution#Quantities_of_information_(entropy) diff --git a/test/test_distributions.py b/test/test_distributions.py index 390880cac..4e278646b 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -2849,6 +2849,68 @@ def test_kl_expanded_normal(batch_shape, event_shape): assert_allclose(actual, expected) +@pytest.mark.parametrize( + "batch_shape_p, batch_shape_q", + [ + ((1,), (1,)), + ((2, 3), (2, 3)), + ((5, 1, 3), (2, 3)), + ((1, 3), (5, 2, 3)), + ], + ids=str, +) +@pytest.mark.parametrize("single_scale_p", [False, True], ids=str) +@pytest.mark.parametrize("single_loc_p", [False, True], ids=str) +@pytest.mark.parametrize("single_scale_q", [False, True], ids=str) +@pytest.mark.parametrize("single_loc_q", [False, True], ids=str) +def test_kl_multivariate_normal_consistency_with_independent_normals( + batch_shape_p, + batch_shape_q, + single_scale_p, + single_loc_p, + single_scale_q, + single_loc_q, +): + event_shape = (5,) + + def make_dists(loc_batch_shape, scales_batch_shape): + mus = np.random.normal(size=loc_batch_shape + event_shape) + scales = np.exp(np.random.normal(size=scales_batch_shape + event_shape) * 0.1) + + def diagonalize(v, ignore_axes: int): + if ignore_axes == 0: + return jnp.diag(v) + return vmap(diagonalize, in_axes=(0, None))(v, ignore_axes - 1) + + scale_tril = diagonalize(scales, len(scales_batch_shape)) + return ( + dist.Normal(mus, scales).to_event(len(event_shape)), + dist.MultivariateNormal(mus, scale_tril=scale_tril), + ) + + p_uni, p_mvn = make_dists( + () if single_loc_p else batch_shape_p, () if single_scale_p else batch_shape_p + ) + q_uni, q_mvn = make_dists( + () if single_loc_q else batch_shape_q, () if single_scale_q else batch_shape_q + ) + + actual = kl_divergence(p_mvn, q_mvn) + expected = kl_divergence(p_uni, q_uni) + assert_allclose(actual, expected, atol=1e-5) + + +def test_kl_multivariate_normal_nondiagonal_covariance(): + p_mvn = dist.MultivariateNormal(np.zeros(2), covariance_matrix=np.eye(2)) + q_mvn = dist.MultivariateNormal( + np.ones(2), covariance_matrix=np.array([[2, 0.8], [0.8, 0.5]]) + ) + + actual = kl_divergence(p_mvn, q_mvn) + expected = 3.21138 + assert_allclose(actual, expected, atol=2e-5) + + @pytest.mark.parametrize("shape", [(), (4,), (2, 3)], ids=str) @pytest.mark.parametrize( "p_dist, q_dist",
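
Sanity check for the hard-coded expectation in test_kl_multivariate_normal_nondiagonal_covariance: a minimal sketch (not part of the patch; it assumes a NumPyro build with this patch applied) that reproduces 3.21138 from the closed-form KL between multivariate normals, KL(p||q) = 0.5 * (tr(Σq^-1 Σp) + (μp - μq)^T Σq^-1 (μp - μq) - D + log(det Σq / det Σp)), which is the same formula the new dispatch evaluates via Cholesky factors:

import numpy as np
import numpyro.distributions as dist
from numpyro.distributions.kl import kl_divergence

mu_p, cov_p = np.zeros(2), np.eye(2)
mu_q, cov_q = np.ones(2), np.array([[2.0, 0.8], [0.8, 0.5]])

p = dist.MultivariateNormal(mu_p, covariance_matrix=cov_p)
q = dist.MultivariateNormal(mu_q, covariance_matrix=cov_q)

# Closed form with explicit inverse/determinants, for checking only.
cov_q_inv = np.linalg.inv(cov_q)
diff = mu_p - mu_q
expected = 0.5 * (
    np.trace(cov_q_inv @ cov_p)  # tr(Sigma_q^-1 Sigma_p)
    + diff @ cov_q_inv @ diff  # Mahalanobis term
    - mu_p.shape[-1]  # D = 2
    + np.log(np.linalg.det(cov_q) / np.linalg.det(cov_p))
)

print(float(kl_divergence(p, q)), expected)  # both ~ 3.21138

The implementation itself avoids the dense inverse used in this check: Lq^-1 comes from a triangular solve, the trace term is the squared Frobenius norm of Lq^-1 Lp (via _batch_trace_from_cholesky), and the Mahalanobis term is the squared norm of Lq^-1 (p.loc - q.loc), which is cheaper and better conditioned than forming Σq^-1 explicitly.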