diff --git a/example/GW150914_IMRPhenomD.py b/example/GW150914_IMRPhenomD.py index 7a7a37b3..66619ddc 100644 --- a/example/GW150914_IMRPhenomD.py +++ b/example/GW150914_IMRPhenomD.py @@ -117,7 +117,7 @@ n_epochs=n_epochs, learning_rate=learning_rate, n_max_examples=30000, - n_flow_samples=100000, + n_flow_sample=100000, momentum=0.9, batch_size=30000, use_global=True, diff --git a/src/jimgw/jim.py b/src/jimgw/jim.py index 8e6bb0bc..2f0086ac 100644 --- a/src/jimgw/jim.py +++ b/src/jimgw/jim.py @@ -107,7 +107,7 @@ def sample(self, key: PRNGKeyArray, initial_position: Array = jnp.array([])): initial_position = jnp.zeros((self.sampler.n_chains, self.prior.n_dim)) + jnp.nan while not jax.tree.reduce(jnp.logical_and, jax.tree.map(lambda x: jnp.isfinite(x), initial_position)).all(): - non_finite_index = jnp.any(~jax.tree.reduce(jnp.logical_and, jax.tree.map(lambda x: jnp.isfinite(x), initial_position)),axis=1) + non_finite_index = jnp.where(jnp.any(~jax.tree.reduce(jnp.logical_and, jax.tree.map(lambda x: jnp.isfinite(x), initial_position)),axis=1))[0] key, subkey = jax.random.split(key) guess = self.prior.sample(subkey, self.sampler.n_chains) diff --git a/test/integration/test_GW150914_D.py b/test/integration/test_GW150914_D.py index 5103e5d8..946b0735 100644 --- a/test/integration/test_GW150914_D.py +++ b/test/integration/test_GW150914_D.py @@ -119,7 +119,7 @@ n_epochs=n_epochs, learning_rate=learning_rate, n_max_examples=30, - n_flow_samples=100, + n_flow_sample=100, momentum=0.9, batch_size=100, use_global=True,