Skip to content

Commit

Permalink
Merge pull request #8 from xuyuon/fix-test
Browse files Browse the repository at this point in the history
Fix test
  • Loading branch information
xuyuon authored Aug 22, 2024
2 parents bc8e7dc + 2adcabf commit bb26fae
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 12 deletions.
4 changes: 2 additions & 2 deletions src/jimgw/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def transform(self, x: dict[str, Float]) -> tuple[dict[str, Float], Float]:
output_params = self.transform_func(transform_params)
jacobian = jax.jacfwd(self.transform_func)(transform_params)
jacobian = jnp.array(jax.tree.leaves(jacobian))
jacobian = jnp.log(jnp.linalg.det(jacobian.reshape(self.n_dim, self.n_dim)))
jacobian = jnp.log(jnp.absolute(jnp.linalg.det(jacobian.reshape(self.n_dim, self.n_dim))))
jax.tree.map(
lambda key: x_copy.pop(key),
self.name_mapping[0],
Expand Down Expand Up @@ -126,7 +126,7 @@ def inverse(self, y: dict[str, Float]) -> tuple[dict[str, Float], Float]:
output_params = self.inverse_transform_func(transform_params)
jacobian = jax.jacfwd(self.inverse_transform_func)(transform_params)
jacobian = jnp.array(jax.tree.leaves(jacobian))
jacobian = jnp.log(jnp.linalg.det(jacobian.reshape(self.n_dim, self.n_dim)))
jacobian = jnp.log(jnp.absolute(jnp.linalg.det(jacobian.reshape(self.n_dim, self.n_dim))))
jax.tree.map(
lambda key: y_copy.pop(key),
self.name_mapping[1],
Expand Down
14 changes: 4 additions & 10 deletions test/unit/test_prior.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,8 @@ def test_sine(self):
log_prob = jax.vmap(p.log_prob)(samples)
assert jnp.all(jnp.isfinite(log_prob))
# Check that the log_prob is correct in the support
x = trace_prior_parent(p, [])[0].add_name(jnp.linspace(-10.0, 10.0, 1000)[None])
y = jax.vmap(p.base_prior.base_prior.transform)(x)
y = jax.vmap(p.base_prior.transform)(y)
y = jax.vmap(p.transform)(y)
assert jnp.allclose(jax.vmap(p.log_prob)(y), jnp.log(jnp.sin(y['x'])/2.0))
samples = samples['x']
assert jnp.allclose(log_prob, jnp.log(jnp.sin(samples)/2.0))

def test_cosine(self):
p = CosinePrior(["x"])
Expand All @@ -57,11 +54,8 @@ def test_cosine(self):
# Check that the log_prob is finite
log_prob = jax.vmap(p.log_prob)(samples)
assert jnp.all(jnp.isfinite(log_prob))
# Check that the log_prob is correct in the support
x = trace_prior_parent(p, [])[0].add_name(jnp.linspace(-10.0, 10.0, 1000)[None])
y = jax.vmap(p.base_prior.transform)(x)
y = jax.vmap(p.transform)(y)
assert jnp.allclose(jax.vmap(p.log_prob)(y), jnp.log(jnp.cos(y['x'])/2.0))
samples = samples['x']
assert jnp.allclose(log_prob, jnp.log(jnp.cos(samples)/2.0))

def test_uniform_sphere(self):
p = UniformSpherePrior(["x"])
Expand Down

0 comments on commit bb26fae

Please sign in to comment.