Added kl_divergence for multivariate normals #1654
Changes from 1 commit
Changes to the KL implementation:
@@ -27,19 +27,20 @@

 from multipledispatch import dispatch

-from jax import lax, vmap
+from jax import lax
 import jax.numpy as jnp
 from jax.scipy.special import betaln, digamma, gammaln
-from jax.scipy.linalg import solve_triangular

 from numpyro.distributions.continuous import (
     Beta,
     Dirichlet,
     Gamma,
     Kumaraswamy,
-    Normal,
     MultivariateNormal,
+    Normal,
     Weibull,
+    _batch_solve_triangular,
+    _batch_trace_from_cholesky,
 )
 from numpyro.distributions.discrete import CategoricalProbs
 from numpyro.distributions.distribution import (
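The two private helpers added to the import list are not shown in this diff. As a rough sketch of the behavior the rewritten KL below appears to rely on, the stand-ins here use only public JAX APIs; the names `batch_solve_triangular` and `batch_trace_from_cholesky` are hypothetical and this is not NumPyro's actual code.

```python
# Hedged sketch of what the private helpers imported above are assumed to do.
# These stand-ins use only public JAX APIs and are not NumPyro's actual code.
from functools import partial

import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular


def batch_solve_triangular(A, B):
    """Solve the lower-triangular system A @ X = B over broadcastable batch dims."""
    solve = jnp.vectorize(
        partial(solve_triangular, lower=True), signature="(n,n),(n,k)->(n,k)"
    )
    return solve(A, B)


def batch_trace_from_cholesky(L):
    """tr(L @ L.T) per batch element, i.e. the squared Frobenius norm of L."""
    return jnp.square(L).sum((-2, -1))
```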
@@ -142,58 +143,45 @@ def kl_divergence(p: MultivariateNormal, q: MultivariateNormal):

     if p.event_shape != q.event_shape:
         raise ValueError(
-            "Distributions must be have the same event shape, but are"
+            "Distributions must have the same event shape, but are"
             f" {p.event_shape} and {q.event_shape} for p and q, respectively."
         )

-    if p.batch_shape != q.batch_shape:
+    min_batch_ndim = min(len(p.batch_shape), len(q.batch_shape))
+    if p.batch_shape[-min_batch_ndim:] != q.batch_shape[-min_batch_ndim:]:
         raise ValueError(
-            "Distributions must be have the same batch shape, but are"
-            f" {p.batch_shape} and {q.batch_shape} for p and q, respectively."
+            "Distributions must have the same batch shape in common batch dimensions, "
+            f"but are {p.batch_shape} and {q.batch_shape} for p and q,"
+            "respectively."
         )
+    result_batch_shape = (
+        p.batch_shape if len(p.batch_shape) >= len(q.batch_shape) else q.batch_shape
+    )

     assert len(p.event_shape) == 1, "event_shape must be one-dimensional"
     D = p.event_shape[0]

     assert p.mean.shape == p.batch_shape + p.event_shape
-    assert q.mean.shape == p.mean.shape
+    assert q.mean.shape == q.batch_shape + q.event_shape

Review comment: those assertions are unnecessary.
Reply: removed.

-    def _single_mvn_kl(p_mean, p_scale_tril, q_mean, q_scale_tril):
-        assert jnp.ndim(p_mean) == 1
-        assert jnp.ndim(q_mean) == 1
-        assert jnp.ndim(p_scale_tril) == 2
-        assert jnp.ndim(q_scale_tril) == 2
-
-        p_half_log_det = jnp.log(
-            jnp.diagonal(p_scale_tril)
-        ).sum(-1)
-        q_half_log_det = jnp.log(
-            jnp.diagonal(q_scale_tril)
-        ).sum(-1)
-        log_det_ratio = 2 * (p_half_log_det - q_half_log_det)
+    p_half_log_det = jnp.log(jnp.diagonal(p.scale_tril, axis1=-2, axis2=-1)).sum(-1)
+    assert p_half_log_det.shape == p.batch_shape

Review comment: this assertion might not be true.
Reply: removed.

-        Lq_inv = solve_triangular(q_scale_tril, jnp.eye(D), lower=True)
+    q_half_log_det = jnp.log(jnp.diagonal(q.scale_tril, axis1=-2, axis2=-1)).sum(-1)
+    assert q_half_log_det.shape == q.batch_shape

Review comment: nit: this assertion might not be true. In the MultivariateNormal implementation, we avoid unnecessary broadcasting (e.g. we can have a batch of means with a single scale_tril).
Reply: Right, I wasn't thinking of that. Removed the assertion and added tests for those cases.

-        tr = jnp.sum(jnp.diagonal(
-            Lq_inv.T @ (Lq_inv @ p_scale_tril) @ p_scale_tril.T
-        ))
+    log_det_ratio = 2 * (p_half_log_det - q_half_log_det)
+    assert log_det_ratio.shape == result_batch_shape

-        z = jnp.matmul(Lq_inv, (p_mean - q_mean))
-        t1 = jnp.dot(z, z)
+    Lq_inv = _batch_solve_triangular(q.scale_tril, jnp.eye(D))

-        return .5 * (tr + t1 - D - log_det_ratio)
+    tr = _batch_trace_from_cholesky(Lq_inv @ p.scale_tril)
+    assert tr.shape == result_batch_shape

Review comment: this assertion might not be true.
Reply: fixed.

-    p_mean_flat = jnp.reshape(p.mean, (-1, D))
-    p_scale_tril_flat = jnp.reshape(p.scale_tril, (-1, D, D))
+    t1 = jnp.square(Lq_inv @ (p.loc - q.loc)[..., jnp.newaxis]).sum((-2, -1))
+    assert t1.shape == result_batch_shape

Review comment: this assertion might not be true.
Reply: fixed.

-    q_mean_flat = jnp.reshape(q.mean, (-1, D))
-    q_scale_tril_flat = jnp.reshape(q.scale_tril, (-1, D, D))
-
-    kl_flat = vmap(_single_mvn_kl)(p_mean_flat, p_scale_tril_flat, q_mean_flat, q_scale_tril_flat)
-    assert jnp.ndim(kl_flat) == 1
-
-    kl = jnp.reshape(kl_flat, p.batch_shape)
-    return kl
+    return 0.5 * (tr + t1 - D - log_det_ratio)


 @dispatch(Beta, Beta)
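For reference, the quantity being computed is the closed-form Gaussian KL, KL(p || q) = 0.5 * [tr(Sigma_q^{-1} Sigma_p) + (mu_p - mu_q)^T Sigma_q^{-1} (mu_p - mu_q) - D + log det Sigma_q - log det Sigma_p]. Below is a minimal, self-contained sketch of the same batched computation using only public JAX APIs; the name `mvn_kl_from_cholesky` is hypothetical and this mirrors, but is not, the code in the diff.

```python
# Hedged, self-contained sketch of the batched computation above, using only
# public JAX APIs; mvn_kl_from_cholesky is a hypothetical name, not NumPyro API.
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular


def mvn_kl_from_cholesky(p_loc, p_scale_tril, q_loc, q_scale_tril):
    """KL(p || q) for multivariate normals given locs (..., D) and Cholesky
    factors (..., D, D); batch dimensions are assumed to broadcast."""
    D = p_loc.shape[-1]

    # log det Sigma = 2 * sum(log(diag(L))), taken along the last two axes.
    p_half_log_det = jnp.log(jnp.diagonal(p_scale_tril, axis1=-2, axis2=-1)).sum(-1)
    q_half_log_det = jnp.log(jnp.diagonal(q_scale_tril, axis1=-2, axis2=-1)).sum(-1)
    log_det_ratio = 2 * (p_half_log_det - q_half_log_det)

    # Lq^{-1}, from solving Lq @ X = I (broadcast the identity to Lq's batch shape).
    eye = jnp.broadcast_to(jnp.eye(D), q_scale_tril.shape)
    Lq_inv = solve_triangular(q_scale_tril, eye, lower=True)

    # tr(Sigma_q^{-1} Sigma_p) = squared Frobenius norm of Lq^{-1} @ Lp.
    tr = jnp.square(Lq_inv @ p_scale_tril).sum((-2, -1))

    # Mahalanobis term: || Lq^{-1} (mu_p - mu_q) ||^2.
    t1 = jnp.square(Lq_inv @ (p_loc - q_loc)[..., jnp.newaxis]).sum((-2, -1))

    return 0.5 * (tr + t1 - D - log_det_ratio)
```

For diagonal Cholesky factors this reduces to a sum of univariate Normal KLs, which is what the consistency test below exercises.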
Changes to the tests:
@@ -2849,12 +2849,24 @@ def test_kl_expanded_normal(batch_shape, event_shape):
     assert_allclose(actual, expected)


-@pytest.mark.parametrize("batch_shape", [(), (1,), (2, 3)], ids=str)
-def test_kl_multivariate_normal_consistency_with_independent_normals(batch_shape):
-    event_shape = (5, )
-    shape = batch_shape + event_shape
+@pytest.mark.parametrize(
+    "batch_shape_p, batch_shape_q",
+    [
+        ((), ()),
+        ((1,), (1,)),
+        ((2, 3), (2, 3)),
+        ((5, 2, 3), (2, 3)),

Review comment: could you change this to …
Reply: done.

+        ((2, 3), (5, 2, 3)),

Review comment: nit: maybe …
Reply: done.

+    ],
+    ids=str,
+)
+def test_kl_multivariate_normal_consistency_with_independent_normals(
+    batch_shape_p, batch_shape_q
+):
+    event_shape = (5,)

-    def make_dists():
+    def make_dists(batch_shape):
+        shape = batch_shape + event_shape
         mus = np.random.normal(size=shape)
-        scales = np.exp(np.random.normal(size=shape))
+        scales = np.ones(shape)
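With the relaxed batch-shape check, the new parametrizations pair distributions whose batch shapes differ only in leading dimensions, and the resulting KL should carry the longer batch shape. A small illustrative usage with hypothetical values, mirroring the ((5, 2, 3), (2, 3)) case and assuming kl_divergence is imported from numpyro.distributions:

```python
# Illustrative usage for mismatched batch shapes (hypothetical example values).
import numpy as np
import numpyro.distributions as dist
from numpyro.distributions import kl_divergence

D = 5
eye = np.eye(D)
p = dist.MultivariateNormal(
    np.zeros((5, 2, 3, D)), scale_tril=np.broadcast_to(eye, (5, 2, 3, D, D))
)
q = dist.MultivariateNormal(
    np.ones((2, 3, D)), scale_tril=np.broadcast_to(eye, (2, 3, D, D))
)

kl = kl_divergence(p, q)  # expected to carry the longer batch shape, (5, 2, 3)
```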
@@ -2863,37 +2875,28 @@ def diagonalize(v, ignore_axes: int):
             if ignore_axes == 0:
                 return jnp.diag(v)
             return vmap(diagonalize, in_axes=(0, None))(v, ignore_axes - 1)

         scale_tril = diagonalize(scales, len(batch_shape))
         return (
             dist.Normal(mus, scales).to_event(len(event_shape)),
-            dist.MultivariateNormal(mus, scale_tril=scale_tril)
+            dist.MultivariateNormal(mus, scale_tril=scale_tril),
         )

-    p_uni, p_mvn = make_dists()
-    q_uni, q_mvn = make_dists()
+    p_uni, p_mvn = make_dists(batch_shape_p)
+    q_uni, q_mvn = make_dists(batch_shape_q)

-    actual = kl_divergence(
-        p_mvn, q_mvn
-    )
-    expected = kl_divergence(
-        p_uni, q_uni
-    )
+    actual = kl_divergence(p_mvn, q_mvn)
+    expected = kl_divergence(p_uni, q_uni)
     assert_allclose(actual, expected, atol=1e-6)


 def test_kl_multivariate_normal_nondiagonal_covariance():
     p_mvn = dist.MultivariateNormal(np.zeros(2), covariance_matrix=np.eye(2))
     q_mvn = dist.MultivariateNormal(
-        np.ones(2),
-        covariance_matrix=np.array([
-            [2, .8],
-            [.8, .5]
-        ])
+        np.ones(2), covariance_matrix=np.array([[2, 0.8], [0.8, 0.5]])
     )

-    actual = kl_divergence(
-        p_mvn, q_mvn
-    )
+    actual = kl_divergence(p_mvn, q_mvn)
     expected = 3.21138
     assert_allclose(actual, expected, atol=2e-5)
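As a sanity check on the 3.21138 reference value, the closed-form Gaussian KL can be evaluated directly with NumPy; this snippet is illustrative and not part of the PR:

```python
# Independent check of the reference value 3.21138 via the closed-form KL
# between N(mu_p, Sigma_p) and N(mu_q, Sigma_q); not part of the PR itself.
import numpy as np

mu_p, Sigma_p = np.zeros(2), np.eye(2)
mu_q, Sigma_q = np.ones(2), np.array([[2.0, 0.8], [0.8, 0.5]])

Sigma_q_inv = np.linalg.inv(Sigma_q)
delta = mu_p - mu_q

kl = 0.5 * (
    np.trace(Sigma_q_inv @ Sigma_p)
    + delta @ Sigma_q_inv @ delta
    - 2  # D, the event dimension
    + np.log(np.linalg.det(Sigma_q) / np.linalg.det(Sigma_p))
)
print(kl)  # ~3.2114, matching the expected value within the test's atol
```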
Review comment: how about only assert that p.batch_shape and q.batch_shape can be broadcasted.
Reply: done.
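The suggestion above is to require only that p.batch_shape and q.batch_shape are broadcast-compatible, rather than demanding equality on the common trailing dimensions. A possible check along those lines (a sketch with a hypothetical helper name, not necessarily what the PR finally adopted) can lean on jax.lax.broadcast_shapes, which raises ValueError for incompatible shapes:

```python
# Sketch of a broadcast-compatibility check, as suggested in the review;
# not necessarily the exact code the PR finally adopted.
from jax import lax


def check_and_broadcast_batch_shapes(p_batch_shape, q_batch_shape):
    try:
        # lax.broadcast_shapes raises if the shapes cannot be broadcast together.
        return lax.broadcast_shapes(p_batch_shape, q_batch_shape)
    except ValueError as e:
        raise ValueError(
            "Distributions must have broadcastable batch shapes, but are"
            f" {p_batch_shape} and {q_batch_shape} for p and q, respectively."
        ) from e
```

The broadcasted shape would then take over the role of result_batch_shape in the implementation above.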