Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added kl_divergence for multivariate normals #1654

Merged
merged 3 commits into from
Oct 27, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions numpyro/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -1276,6 +1276,16 @@ def _batch_solve_triangular(A, B):
return X


def _batch_trace_from_cholesky(L):
"""Computes the trace of matrix X given it's Cholesky decomposition matrix L.

:param jnp.ndarray(..., M, M) L: An array with lower triangular structure in the last two dimensions.

:return: Trace of X, where X = L L^T
"""
return jnp.square(L).sum((-1, -2))


class MatrixNormal(Distribution):
"""
Matrix variate normal distribution as described in [1] but with a lower_triangular parametrization,
Expand Down Expand Up @@ -1358,9 +1368,9 @@ def log_prob(self, values):
diff_col_solve = _batch_solve_triangular(
A=self.scale_tril_column, B=jnp.swapaxes(diff_row_solve, -2, -1)
)
batched_trace_term = jnp.square(
batched_trace_term = _batch_trace_from_cholesky(
diff_col_solve.reshape(diff_col_solve.shape[:-2] + (-1,))
).sum(-1)
)

log_prob = -0.5 * batched_trace_term - log_det_term

Expand Down
64 changes: 26 additions & 38 deletions numpyro/distributions/kl.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,19 +27,20 @@

from multipledispatch import dispatch

from jax import lax, vmap
from jax import lax
import jax.numpy as jnp
from jax.scipy.special import betaln, digamma, gammaln
from jax.scipy.linalg import solve_triangular

from numpyro.distributions.continuous import (
Beta,
Dirichlet,
Gamma,
Kumaraswamy,
Normal,
MultivariateNormal,
Normal,
Weibull,
_batch_solve_triangular,
_batch_trace_from_cholesky,
)
from numpyro.distributions.discrete import CategoricalProbs
from numpyro.distributions.distribution import (
Expand Down Expand Up @@ -142,58 +143,45 @@ def kl_divergence(p: MultivariateNormal, q: MultivariateNormal):

if p.event_shape != q.event_shape:
raise ValueError(
"Distributions must be have the same event shape, but are"
"Distributions must have the same event shape, but are"
f" {p.event_shape} and {q.event_shape} for p and q, respectively."
)

if p.batch_shape != q.batch_shape:
min_batch_ndim = min(len(p.batch_shape), len(q.batch_shape))
if p.batch_shape[-min_batch_ndim:] != q.batch_shape[-min_batch_ndim:]:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how about only assert that p.batch_shape and q.batch_shape can be broadcasted.

try:
    result_batch_shape = jnp.broadcast_shapes(p.batch_shape, q.batch_shape)
except ValueError:
    raise ...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

raise ValueError(
"Distributions must be have the same batch shape, but are"
f" {p.batch_shape} and {q.batch_shape} for p and q, respectively."
"Distributions must have the same batch shape in common batch dimensions, "
f"but are {p.batch_shape} and {q.batch_shape} for p and q,"
"respectively."
)
result_batch_shape = (
p.batch_shape if len(p.batch_shape) >= len(q.batch_shape) else q.batch_shape
)

assert len(p.event_shape) == 1, "event_shape must be one-dimensional"
D = p.event_shape[0]

assert p.mean.shape == p.batch_shape + p.event_shape
assert q.mean.shape == p.mean.shape

def _single_mvn_kl(p_mean, p_scale_tril, q_mean, q_scale_tril):
assert jnp.ndim(p_mean) == 1
assert jnp.ndim(q_mean) == 1
assert jnp.ndim(p_scale_tril) == 2
assert jnp.ndim(q_scale_tril) == 2
assert q.mean.shape == q.batch_shape + q.event_shape
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

those assertions are unnecessary.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

removed


p_half_log_det = jnp.log(
jnp.diagonal(p_scale_tril)
).sum(-1)
q_half_log_det = jnp.log(
jnp.diagonal(q_scale_tril)
).sum(-1)
log_det_ratio = 2 * (p_half_log_det - q_half_log_det)
p_half_log_det = jnp.log(jnp.diagonal(p.scale_tril, axis1=-2, axis2=-1)).sum(-1)
assert p_half_log_det.shape == p.batch_shape
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this assertion might not be true

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

removed


Lq_inv = solve_triangular(q_scale_tril, jnp.eye(D), lower=True)
q_half_log_det = jnp.log(jnp.diagonal(q.scale_tril, axis1=-2, axis2=-1)).sum(-1)
assert q_half_log_det.shape == q.batch_shape
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: this assertion might not be true. In MultivariateNormal implementation, we avoid unnecessary broadcasting (e.g. we can have a batch of means with a single scale_tril).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right, I wasn't thinking of that. Removed the assertion and added tests for those cases.


tr = jnp.sum(jnp.diagonal(
Lq_inv.T @ (Lq_inv @ p_scale_tril) @ p_scale_tril.T
))
log_det_ratio = 2 * (p_half_log_det - q_half_log_det)
assert log_det_ratio.shape == result_batch_shape

z = jnp.matmul(Lq_inv, (p_mean - q_mean))
t1 = jnp.dot(z, z)
Lq_inv = _batch_solve_triangular(q.scale_tril, jnp.eye(D))

return .5 * (tr + t1 - D - log_det_ratio)
tr = _batch_trace_from_cholesky(Lq_inv @ p.scale_tril)
assert tr.shape == result_batch_shape
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this assertion might not be true.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed


p_mean_flat = jnp.reshape(p.mean, (-1, D))
p_scale_tril_flat = jnp.reshape(p.scale_tril, (-1, D, D))
t1 = jnp.square(Lq_inv @ (p.loc - q.loc)[..., jnp.newaxis]).sum((-2, -1))
assert t1.shape == result_batch_shape
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this assertion might not be true

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed


q_mean_flat = jnp.reshape(q.mean, (-1, D))
q_scale_tril_flat = jnp.reshape(q.scale_tril, (-1, D, D))

kl_flat = vmap(_single_mvn_kl)(p_mean_flat, p_scale_tril_flat, q_mean_flat, q_scale_tril_flat)
assert jnp.ndim(kl_flat) == 1

kl = jnp.reshape(kl_flat, p.batch_shape)
return kl
return 0.5 * (tr + t1 - D - log_det_ratio)


@dispatch(Beta, Beta)
Expand Down
47 changes: 25 additions & 22 deletions test/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2849,12 +2849,24 @@ def test_kl_expanded_normal(batch_shape, event_shape):
assert_allclose(actual, expected)


@pytest.mark.parametrize("batch_shape", [(), (1,), (2, 3)], ids=str)
def test_kl_multivariate_normal_consistency_with_independent_normals(batch_shape):
event_shape = (5, )
shape = batch_shape + event_shape
@pytest.mark.parametrize(
"batch_shape_p, batch_shape_q",
[
((), ()),
((1,), (1,)),
((2, 3), (2, 3)),
((5, 2, 3), (2, 3)),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could you change this to (5, 1, 3) and (2, 3)?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

((2, 3), (5, 2, 3)),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: maybe ((1, 3), (5, 2, 3))?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

],
ids=str,
)
def test_kl_multivariate_normal_consistency_with_independent_normals(
batch_shape_p, batch_shape_q
):
event_shape = (5,)

def make_dists():
def make_dists(batch_shape):
shape = batch_shape + event_shape
mus = np.random.normal(size=shape)
scales = np.exp(np.random.normal(size=shape))
scales = np.ones(shape)
Expand All @@ -2863,37 +2875,28 @@ def diagonalize(v, ignore_axes: int):
if ignore_axes == 0:
return jnp.diag(v)
return vmap(diagonalize, in_axes=(0, None))(v, ignore_axes - 1)

scale_tril = diagonalize(scales, len(batch_shape))
return (
dist.Normal(mus, scales).to_event(len(event_shape)),
dist.MultivariateNormal(mus, scale_tril=scale_tril)
dist.MultivariateNormal(mus, scale_tril=scale_tril),
)

p_uni, p_mvn = make_dists()
q_uni, q_mvn = make_dists()
p_uni, p_mvn = make_dists(batch_shape_p)
q_uni, q_mvn = make_dists(batch_shape_q)

actual = kl_divergence(
p_mvn, q_mvn
)
expected = kl_divergence(
p_uni, q_uni
)
actual = kl_divergence(p_mvn, q_mvn)
expected = kl_divergence(p_uni, q_uni)
assert_allclose(actual, expected, atol=1e-6)


def test_kl_multivariate_normal_nondiagonal_covariance():
p_mvn = dist.MultivariateNormal(np.zeros(2), covariance_matrix=np.eye(2))
q_mvn = dist.MultivariateNormal(
np.ones(2),
covariance_matrix=np.array([
[2, .8],
[.8, .5]
])
np.ones(2), covariance_matrix=np.array([[2, 0.8], [0.8, 0.5]])
)

actual = kl_divergence(
p_mvn, q_mvn
)
actual = kl_divergence(p_mvn, q_mvn)
expected = 3.21138
assert_allclose(actual, expected, atol=2e-5)

Expand Down
Loading