Skip to content

Commit

Permalink
Raise error when using CircularReparam at observed site (#1856)
Browse files Browse the repository at this point in the history
* raise error when using circular reparam at observed site

* clean up
  • Loading branch information
fehiepsi authored Aug 26, 2024
1 parent d52209c commit f478772
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 0 deletions.
1 change: 1 addition & 0 deletions numpyro/infer/reparam.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,7 @@ def __call__(self, name, fn, obs):
if isinstance(support, constraints.independent):
support = fn.support.base_constraint
assert support is constraints.circular
assert obs is None, "CircularReparam does not support observe statements"

# Draw parameter-free noise.
new_fn = dist.ImproperUniform(constraints.real, fn.batch_shape, fn.event_shape)
Expand Down
10 changes: 10 additions & 0 deletions test/infer/test_reparam.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,16 @@ def get_actual_probe(loc, concentration):
assert_allclose(actual_probe, expected_probe, atol=0.1)


def test_circular_reparam_no_observe():
def model():
numpyro.sample("x", dist.VonMises(0, 1), obs=0.5)

with numpyro.handlers.seed(rng_seed=0):
with numpyro.handlers.reparam(config={"x": CircularReparam()}):
with pytest.raises(AssertionError, match="not support observe"):
model()


_unconstrain_reparam = numpyro.infer.util._unconstrain_reparam


Expand Down

0 comments on commit f478772

Please sign in to comment.