From 81581ff09d5467f8ac7e4de69a39a551c7ae185a Mon Sep 17 00:00:00 2001 From: Till Hoffmann Date: Wed, 23 Oct 2024 11:54:37 -0400 Subject: [PATCH] Correct event dimensions for `ReshapeTransform`. --- numpyro/distributions/transforms.py | 12 ++++++++---- test/test_transforms.py | 18 ++++++++++++++++++ 2 files changed, 26 insertions(+), 4 deletions(-) diff --git a/numpyro/distributions/transforms.py b/numpyro/distributions/transforms.py index 05ae19efd9..a1aee0015b 100644 --- a/numpyro/distributions/transforms.py +++ b/numpyro/distributions/transforms.py @@ -1209,10 +1209,6 @@ class ReshapeTransform(Transform): :param inverse_shape: Shape of the sample for the inverse transform. """ - domain = constraints.real - codomain = constraints.real - sign = 1 - def __init__(self, forward_shape, inverse_shape) -> None: forward_size = math.prod(forward_shape) inverse_size = math.prod(inverse_shape) @@ -1224,6 +1220,14 @@ def __init__(self, forward_shape, inverse_shape) -> None: self._forward_shape = forward_shape self._inverse_shape = inverse_shape + @property + def domain(self) -> constraints.Constraint: + return constraints.independent(constraints.real, len(self._inverse_shape)) + + @property + def codomain(self) -> constraints.Constraint: + return constraints.independent(constraints.real, len(self._forward_shape)) + def forward_shape(self, shape): return _get_target_shape(shape, self._forward_shape, self._inverse_shape) diff --git a/test/test_transforms.py b/test/test_transforms.py index 9d8c659469..beff83b8c2 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -404,9 +404,27 @@ def test_biject_to(constraint, shape): [ CorrCholeskyTransform(), CorrCholeskyTransform().inv, + ReshapeTransform((3, 4), (12,)), + ReshapeTransform((12,), (3, 4)), ], ) def test_compose_domain_codomain(transform): composed = ComposeTransform([transform]) assert transform.domain.event_dim == composed.domain.event_dim assert transform.codomain.event_dim == composed.codomain.event_dim + + +def test_compose_sequence_domain_codomain(): + parts = [ + CorrCholeskyTransform(), # 1 to 2 + ReshapeTransform((3, 4, 12), (12, 12)), # 2 to 3 + AffineTransform(0, 1), # 3 to 3 + ReshapeTransform((12, 12), (144,)), # 3 to 1 + ] + x = jnp.zeros(66) + for i, event_dim in enumerate([2, 3, 3]): + composed = ComposeTransform(parts[: i + 1]) + y = composed(x) + assert y.ndim == event_dim + assert composed.codomain.event_dim == event_dim + assert composed.domain.event_dim == 1