Skip to content

Commit

Permalink
minor bug fix on initial sample noon_finite_guess
Browse files Browse the repository at this point in the history
  • Loading branch information
kazewong committed Sep 16, 2024
1 parent 24235c2 commit aef6124
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion example/GW150914_IMRPhenomD.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@
n_epochs=n_epochs,
learning_rate=learning_rate,
n_max_examples=30000,
n_flow_samples=100000,
n_flow_sample=100000,
momentum=0.9,
batch_size=30000,
use_global=True,
Expand Down
2 changes: 1 addition & 1 deletion src/jimgw/jim.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def sample(self, key: PRNGKeyArray, initial_position: Array = jnp.array([])):
initial_position = jnp.zeros((self.sampler.n_chains, self.prior.n_dim)) + jnp.nan

while not jax.tree.reduce(jnp.logical_and, jax.tree.map(lambda x: jnp.isfinite(x), initial_position)).all():
non_finite_index = jnp.any(~jax.tree.reduce(jnp.logical_and, jax.tree.map(lambda x: jnp.isfinite(x), initial_position)),axis=1)
non_finite_index = jnp.where(jnp.any(~jax.tree.reduce(jnp.logical_and, jax.tree.map(lambda x: jnp.isfinite(x), initial_position)),axis=1))[0]

key, subkey = jax.random.split(key)
guess = self.prior.sample(subkey, self.sampler.n_chains)
Expand Down
2 changes: 1 addition & 1 deletion test/integration/test_GW150914_D.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@
n_epochs=n_epochs,
learning_rate=learning_rate,
n_max_examples=30,
n_flow_samples=100,
n_flow_sample=100,
momentum=0.9,
batch_size=100,
use_global=True,
Expand Down

0 comments on commit aef6124

Please sign in to comment.