Commit: style fixes

making the linter tests happy
lumip committed Oct 20, 2023
1 parent 4e3beae commit 1249cad
Showing 3 changed files with 63 additions and 62 deletions.
14 changes: 12 additions & 2 deletions numpyro/distributions/continuous.py
@@ -1276,6 +1276,16 @@ def _batch_solve_triangular(A, B):
     return X
 
 
+def _batch_trace_from_cholesky(L):
+    """Computes the trace of matrix X given its Cholesky decomposition matrix L.
+    :param jnp.ndarray(..., M, M) L: An array with lower triangular structure in the last two dimensions.
+    :return: Trace of X, where X = L L^T
+    """
+    return jnp.square(L).sum((-1, -2))
+
+
 class MatrixNormal(Distribution):
     """
     Matrix variate normal distribution as described in [1] but with a lower_triangular parametrization,
@@ -1358,9 +1368,9 @@ def log_prob(self, values):
         diff_col_solve = _batch_solve_triangular(
             A=self.scale_tril_column, B=jnp.swapaxes(diff_row_solve, -2, -1)
         )
-        batched_trace_term = jnp.square(
+        batched_trace_term = _batch_trace_from_cholesky(
             diff_col_solve.reshape(diff_col_solve.shape[:-2] + (-1,))
-        ).sum(-1)
+        )
 
         log_prob = -0.5 * batched_trace_term - log_det_term

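For context, the new `_batch_trace_from_cholesky` helper rests on the identity tr(L L^T) = sum_ij L_ij^2, the squared Frobenius norm of L, so the trace is available without ever forming L L^T. A minimal standalone check of that identity (illustrative only, not part of the commit):

```python
import jax.numpy as jnp

# For X = L @ L.T, trace(X) equals the sum of squared entries of L,
# so the product L @ L.T never needs to be materialized.
L = jnp.array([[2.0, 0.0], [0.5, 1.5]])  # a lower-triangular factor
trace_direct = jnp.trace(L @ L.T)  # forms X explicitly
trace_cheap = jnp.square(L).sum((-1, -2))  # the helper's one-liner
assert jnp.allclose(trace_direct, trace_cheap)  # both equal 6.5
```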
64 changes: 26 additions & 38 deletions numpyro/distributions/kl.py
@@ -27,19 +27,20 @@

 from multipledispatch import dispatch
 
-from jax import lax, vmap
+from jax import lax
 import jax.numpy as jnp
 from jax.scipy.special import betaln, digamma, gammaln
-from jax.scipy.linalg import solve_triangular
 
 from numpyro.distributions.continuous import (
     Beta,
     Dirichlet,
     Gamma,
     Kumaraswamy,
-    Normal,
     MultivariateNormal,
+    Normal,
     Weibull,
+    _batch_solve_triangular,
+    _batch_trace_from_cholesky,
 )
 from numpyro.distributions.discrete import CategoricalProbs
 from numpyro.distributions.distribution import (
@@ -142,58 +143,45 @@ def kl_divergence(p: MultivariateNormal, q: MultivariateNormal):

     if p.event_shape != q.event_shape:
         raise ValueError(
-            "Distributions must be have the same event shape, but are"
+            "Distributions must have the same event shape, but are"
             f" {p.event_shape} and {q.event_shape} for p and q, respectively."
         )
 
-    if p.batch_shape != q.batch_shape:
+    min_batch_ndim = min(len(p.batch_shape), len(q.batch_shape))
+    if p.batch_shape[-min_batch_ndim:] != q.batch_shape[-min_batch_ndim:]:
         raise ValueError(
-            "Distributions must be have the same batch shape, but are"
-            f" {p.batch_shape} and {q.batch_shape} for p and q, respectively."
+            "Distributions must have the same batch shape in common batch dimensions, "
+            f"but are {p.batch_shape} and {q.batch_shape} for p and q,"
+            "respectively."
         )
+    result_batch_shape = (
+        p.batch_shape if len(p.batch_shape) >= len(q.batch_shape) else q.batch_shape
+    )
 
     assert len(p.event_shape) == 1, "event_shape must be one-dimensional"
     D = p.event_shape[0]
 
     assert p.mean.shape == p.batch_shape + p.event_shape
-    assert q.mean.shape == p.mean.shape
-
-    def _single_mvn_kl(p_mean, p_scale_tril, q_mean, q_scale_tril):
-        assert jnp.ndim(p_mean) == 1
-        assert jnp.ndim(q_mean) == 1
-        assert jnp.ndim(p_scale_tril) == 2
-        assert jnp.ndim(q_scale_tril) == 2
+    assert q.mean.shape == q.batch_shape + q.event_shape
 
-        p_half_log_det = jnp.log(
-            jnp.diagonal(p_scale_tril)
-        ).sum(-1)
-        q_half_log_det = jnp.log(
-            jnp.diagonal(q_scale_tril)
-        ).sum(-1)
-        log_det_ratio = 2 * (p_half_log_det - q_half_log_det)
+    p_half_log_det = jnp.log(jnp.diagonal(p.scale_tril, axis1=-2, axis2=-1)).sum(-1)
+    assert p_half_log_det.shape == p.batch_shape
 
-        Lq_inv = solve_triangular(q_scale_tril, jnp.eye(D), lower=True)
+    q_half_log_det = jnp.log(jnp.diagonal(q.scale_tril, axis1=-2, axis2=-1)).sum(-1)
+    assert q_half_log_det.shape == q.batch_shape
 
-        tr = jnp.sum(jnp.diagonal(
-            Lq_inv.T @ (Lq_inv @ p_scale_tril) @ p_scale_tril.T
-        ))
+    log_det_ratio = 2 * (p_half_log_det - q_half_log_det)
+    assert log_det_ratio.shape == result_batch_shape
 
-        z = jnp.matmul(Lq_inv, (p_mean - q_mean))
-        t1 = jnp.dot(z, z)
+    Lq_inv = _batch_solve_triangular(q.scale_tril, jnp.eye(D))
 
-        return .5 * (tr + t1 - D - log_det_ratio)
+    tr = _batch_trace_from_cholesky(Lq_inv @ p.scale_tril)
+    assert tr.shape == result_batch_shape
 
-    p_mean_flat = jnp.reshape(p.mean, (-1, D))
-    p_scale_tril_flat = jnp.reshape(p.scale_tril, (-1, D, D))
+    t1 = jnp.square(Lq_inv @ (p.loc - q.loc)[..., jnp.newaxis]).sum((-2, -1))
+    assert t1.shape == result_batch_shape
 
-    q_mean_flat = jnp.reshape(q.mean, (-1, D))
-    q_scale_tril_flat = jnp.reshape(q.scale_tril, (-1, D, D))
-
-    kl_flat = vmap(_single_mvn_kl)(p_mean_flat, p_scale_tril_flat, q_mean_flat, q_scale_tril_flat)
-    assert jnp.ndim(kl_flat) == 1
-
-    kl = jnp.reshape(kl_flat, p.batch_shape)
-    return kl
+    return 0.5 * (tr + t1 - D - log_det_ratio)
 
 
 @dispatch(Beta, Beta)
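The rewritten `kl_divergence` evaluates the standard closed form KL(p || q) = 0.5 * [tr(Sigma_q^-1 Sigma_p) + (mu_p - mu_q)^T Sigma_q^-1 (mu_p - mu_q) - D + ln det Sigma_q - ln det Sigma_p] directly on batched arrays, with broadcasting over batch dimensions replacing the old flatten-and-vmap loop. The trace term reuses the Frobenius-norm identity above: Sigma_q^-1 Sigma_p = Lq^-T (Lq^-1 Lp) Lp^T, so tr(Sigma_q^-1 Sigma_p) = ||Lq^-1 Lp||_F^2, which is what `_batch_trace_from_cholesky(Lq_inv @ p.scale_tril)` computes. A dense, unbatched reference for cross-checking (a sketch; `mvn_kl_reference` is a hypothetical name, not library code):

```python
import numpy as np

def mvn_kl_reference(mu_p, cov_p, mu_q, cov_q):
    """Textbook KL(N(mu_p, cov_p) || N(mu_q, cov_q)), dense and unbatched."""
    D = mu_p.shape[-1]
    cov_q_inv = np.linalg.inv(cov_q)
    tr = np.trace(cov_q_inv @ cov_p)
    diff = mu_p - mu_q
    quad = diff @ cov_q_inv @ diff
    # slogdet returns (sign, log|det|); covariances are PD, so sign is +1
    log_det_ratio = np.linalg.slogdet(cov_p)[1] - np.linalg.slogdet(cov_q)[1]
    return 0.5 * (tr + quad - D - log_det_ratio)
```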
47 changes: 25 additions & 22 deletions test/test_distributions.py
@@ -2849,12 +2849,24 @@ def test_kl_expanded_normal(batch_shape, event_shape):
     assert_allclose(actual, expected)
 
 
-@pytest.mark.parametrize("batch_shape", [(), (1,), (2, 3)], ids=str)
-def test_kl_multivariate_normal_consistency_with_independent_normals(batch_shape):
-    event_shape = (5, )
-    shape = batch_shape + event_shape
+@pytest.mark.parametrize(
+    "batch_shape_p, batch_shape_q",
+    [
+        ((), ()),
+        ((1,), (1,)),
+        ((2, 3), (2, 3)),
+        ((5, 2, 3), (2, 3)),
+        ((2, 3), (5, 2, 3)),
+    ],
+    ids=str,
+)
+def test_kl_multivariate_normal_consistency_with_independent_normals(
+    batch_shape_p, batch_shape_q
+):
+    event_shape = (5,)
 
-    def make_dists():
+    def make_dists(batch_shape):
+        shape = batch_shape + event_shape
         mus = np.random.normal(size=shape)
-        scales = np.exp(np.random.normal(size=shape))
+        scales = np.ones(shape)
@@ -2863,37 +2875,28 @@ def diagonalize(v, ignore_axes: int):
             if ignore_axes == 0:
                 return jnp.diag(v)
             return vmap(diagonalize, in_axes=(0, None))(v, ignore_axes - 1)
 
         scale_tril = diagonalize(scales, len(batch_shape))
         return (
             dist.Normal(mus, scales).to_event(len(event_shape)),
-            dist.MultivariateNormal(mus, scale_tril=scale_tril)
+            dist.MultivariateNormal(mus, scale_tril=scale_tril),
         )
 
-    p_uni, p_mvn = make_dists()
-    q_uni, q_mvn = make_dists()
+    p_uni, p_mvn = make_dists(batch_shape_p)
+    q_uni, q_mvn = make_dists(batch_shape_q)
 
-    actual = kl_divergence(
-        p_mvn, q_mvn
-    )
-    expected = kl_divergence(
-        p_uni, q_uni
-    )
+    actual = kl_divergence(p_mvn, q_mvn)
+    expected = kl_divergence(p_uni, q_uni)
     assert_allclose(actual, expected, atol=1e-6)
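This test leans on the fact that a multivariate normal with diagonal covariance factorizes across coordinates, so its KL against another diagonal MVN equals the sum of per-dimension univariate KLs, which is what the `Normal(...).to_event(...)` pair computes. A tiny illustration of that equivalence (assumes numpyro is installed; not part of the test suite):

```python
import numpy as np
import numpyro.distributions as dist
from numpyro.distributions import kl_divergence

mus, scales = np.array([0.0, 1.0]), np.array([1.0, 2.0])
# A diagonal MVN and an Independent stack of Normals describe the same density...
p_mvn = dist.MultivariateNormal(mus, scale_tril=np.diag(scales))
p_uni = dist.Normal(mus, scales).to_event(1)
q_mvn = dist.MultivariateNormal(np.zeros(2), scale_tril=np.eye(2))
q_uni = dist.Normal(np.zeros(2), np.ones(2)).to_event(1)
# ...so both KL paths must agree (both equal ln(1/2) + 5/2 - 1/2 ~= 1.3069).
assert np.allclose(kl_divergence(p_mvn, q_mvn), kl_divergence(p_uni, q_uni))
```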


 def test_kl_multivariate_normal_nondiagonal_covariance():
     p_mvn = dist.MultivariateNormal(np.zeros(2), covariance_matrix=np.eye(2))
     q_mvn = dist.MultivariateNormal(
-        np.ones(2),
-        covariance_matrix=np.array([
-            [2, .8],
-            [.8, .5]
-        ])
+        np.ones(2), covariance_matrix=np.array([[2, 0.8], [0.8, 0.5]])
     )
 
-    actual = kl_divergence(
-        p_mvn, q_mvn
-    )
+    actual = kl_divergence(p_mvn, q_mvn)
     expected = 3.21138
     assert_allclose(actual, expected, atol=2e-5)
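The hard-coded `expected = 3.21138` follows from the closed form; a quick by-hand reproduction (illustrative only):

```python
import numpy as np

# KL(p || q) with Sigma_p = I, mu_p = (0, 0), mu_q = (1, 1), and
# Sigma_q = [[2, 0.8], [0.8, 0.5]]; det(Sigma_q) = 2 * 0.5 - 0.8**2 = 0.36.
Sq = np.array([[2.0, 0.8], [0.8, 0.5]])
Sq_inv = np.linalg.inv(Sq)
tr = np.trace(Sq_inv)  # tr(Sq^-1 @ I) = 2.5 / 0.36 ~= 6.9444
quad = np.ones(2) @ Sq_inv @ np.ones(2)  # 0.9 / 0.36 = 2.5
log_det_q = np.log(np.linalg.det(Sq))  # ln 0.36 ~= -1.0217; ln det Sigma_p = 0
kl = 0.5 * (tr + quad - 2 + log_det_q)  # ~= 3.21140, within the 2e-5 atol
```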

