Skip to content

Commit

Permalink
Correct event dimensions for ReshapeTransform. (#1895)
Browse files Browse the repository at this point in the history
  • Loading branch information
tillahoffmann authored Oct 28, 2024
1 parent b611610 commit 6d5e508
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 4 deletions.
12 changes: 8 additions & 4 deletions numpyro/distributions/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1209,10 +1209,6 @@ class ReshapeTransform(Transform):
:param inverse_shape: Shape of the sample for the inverse transform.
"""

domain = constraints.real
codomain = constraints.real
sign = 1

def __init__(self, forward_shape, inverse_shape) -> None:
forward_size = math.prod(forward_shape)
inverse_size = math.prod(inverse_shape)
Expand All @@ -1224,6 +1220,14 @@ def __init__(self, forward_shape, inverse_shape) -> None:
self._forward_shape = forward_shape
self._inverse_shape = inverse_shape

@property
def domain(self) -> constraints.Constraint:
return constraints.independent(constraints.real, len(self._inverse_shape))

@property
def codomain(self) -> constraints.Constraint:
return constraints.independent(constraints.real, len(self._forward_shape))

def forward_shape(self, shape):
return _get_target_shape(shape, self._forward_shape, self._inverse_shape)

Expand Down
18 changes: 18 additions & 0 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,9 +404,27 @@ def test_biject_to(constraint, shape):
[
CorrCholeskyTransform(),
CorrCholeskyTransform().inv,
ReshapeTransform((3, 4), (12,)),
ReshapeTransform((12,), (3, 4)),
],
)
def test_compose_domain_codomain(transform):
composed = ComposeTransform([transform])
assert transform.domain.event_dim == composed.domain.event_dim
assert transform.codomain.event_dim == composed.codomain.event_dim


def test_compose_sequence_domain_codomain():
parts = [
CorrCholeskyTransform(), # 1 to 2
ReshapeTransform((3, 4, 12), (12, 12)), # 2 to 3
AffineTransform(0, 1), # 3 to 3
ReshapeTransform((12, 12), (144,)), # 3 to 1
]
x = jnp.zeros(66)
for i, event_dim in enumerate([2, 3, 3]):
composed = ComposeTransform(parts[: i + 1])
y = composed(x)
assert y.ndim == event_dim
assert composed.codomain.event_dim == event_dim
assert composed.domain.event_dim == 1

0 comments on commit 6d5e508

Please sign in to comment.