diff --git a/src/jimgw/jim.py b/src/jimgw/jim.py index 2f0086ac..243d6a8f 100644 --- a/src/jimgw/jim.py +++ b/src/jimgw/jim.py @@ -104,19 +104,38 @@ def posterior(self, params: Float[Array, " n_dim"], data: dict): def sample(self, key: PRNGKeyArray, initial_position: Array = jnp.array([])): if initial_position.size == 0: - initial_position = jnp.zeros((self.sampler.n_chains, self.prior.n_dim)) + jnp.nan - - while not jax.tree.reduce(jnp.logical_and, jax.tree.map(lambda x: jnp.isfinite(x), initial_position)).all(): - non_finite_index = jnp.where(jnp.any(~jax.tree.reduce(jnp.logical_and, jax.tree.map(lambda x: jnp.isfinite(x), initial_position)),axis=1))[0] + initial_position = ( + jnp.zeros((self.sampler.n_chains, self.prior.n_dim)) + jnp.nan + ) + + while not jax.tree.reduce( + jnp.logical_and, + jax.tree.map(lambda x: jnp.isfinite(x), initial_position), + ).all(): + non_finite_index = jnp.where( + jnp.any( + ~jax.tree.reduce( + jnp.logical_and, + jax.tree.map(lambda x: jnp.isfinite(x), initial_position), + ), + axis=1, + ) + )[0] key, subkey = jax.random.split(key) guess = self.prior.sample(subkey, self.sampler.n_chains) for transform in self.sample_transforms: guess = jax.vmap(transform.forward)(guess) - guess = jnp.array(jax.tree.leaves({key: guess[key] for key in self.parameter_names})).T - finite_guess = jnp.where(jnp.all(jax.tree.map(lambda x: jnp.isfinite(x), guess),axis=1))[0] + guess = jnp.array( + jax.tree.leaves({key: guess[key] for key in self.parameter_names}) + ).T + finite_guess = jnp.where( + jnp.all(jax.tree.map(lambda x: jnp.isfinite(x), guess), axis=1) + )[0] common_length = min(len(finite_guess), len(non_finite_index)) - initial_position = initial_position.at[non_finite_index[:common_length]].set(guess[:common_length]) + initial_position = initial_position.at[ + non_finite_index[:common_length] + ].set(guess[:common_length]) self.sampler.sample(initial_position, None) # type: ignore def maximize_likelihood( diff --git a/src/jimgw/single_event/transforms.py b/src/jimgw/single_event/transforms.py index 0cfae761..8ac4b126 100644 --- a/src/jimgw/single_event/transforms.py +++ b/src/jimgw/single_event/transforms.py @@ -111,34 +111,55 @@ def named_inverse_transform(x): self.inverse_transform_func = named_inverse_transform + def named_m1_m2_to_Mc_q(x): Mc, q = m1_m2_to_Mc_q(x["m_1"], x["m_2"]) return {"M_c": Mc, "q": q} + def named_Mc_q_to_m1_m2(x): m1, m2 = Mc_q_to_m1_m2(x["M_c"], x["q"]) return {"m_1": m1, "m_2": m2} -ComponentMassesToChirpMassMassRatioTransform = BijectiveTransform((["m_1", "m_2"], ["M_c", "q"])) + +ComponentMassesToChirpMassMassRatioTransform = BijectiveTransform( + (["m_1", "m_2"], ["M_c", "q"]) +) ComponentMassesToChirpMassMassRatioTransform.transform_func = named_m1_m2_to_Mc_q -ComponentMassesToChirpMassMassRatioTransform.inverse_transform_func = named_Mc_q_to_m1_m2 +ComponentMassesToChirpMassMassRatioTransform.inverse_transform_func = ( + named_Mc_q_to_m1_m2 +) + def named_m1_m2_to_Mc_eta(x): Mc, eta = m1_m2_to_Mc_eta(x["m_1"], x["m_2"]) return {"M_c": Mc, "eta": eta} + def named_Mc_eta_to_m1_m2(x): m1, m2 = Mc_eta_to_m1_m2(x["M_c"], x["eta"]) return {"m_1": m1, "m_2": m2} -ComponentMassesToChirpMassSymmetricMassRatioTransform = BijectiveTransform((["m_1", "m_2"], ["M_c", "eta"])) -ComponentMassesToChirpMassSymmetricMassRatioTransform.transform_func = named_m1_m2_to_Mc_eta -ComponentMassesToChirpMassSymmetricMassRatioTransform.inverse_transform_func = named_Mc_eta_to_m1_m2 + +ComponentMassesToChirpMassSymmetricMassRatioTransform = BijectiveTransform( + (["m_1", "m_2"], ["M_c", "eta"]) +) +ComponentMassesToChirpMassSymmetricMassRatioTransform.transform_func = ( + named_m1_m2_to_Mc_eta +) +ComponentMassesToChirpMassSymmetricMassRatioTransform.inverse_transform_func = ( + named_Mc_eta_to_m1_m2 +) + def named_q_to_eta(x): return {"eta": q_to_eta(x["q"])} + + def named_eta_to_q(x): return {"q": eta_to_q(x["eta"])} + + MassRatioToSymmetricMassRatioTransform = BijectiveTransform((["q"], ["eta"])) MassRatioToSymmetricMassRatioTransform.transform_func = named_q_to_eta MassRatioToSymmetricMassRatioTransform.inverse_transform_func = named_eta_to_q diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index 3ad51e62..7dfa23df 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -89,7 +89,9 @@ def transform(self, x: dict[str, Float]) -> tuple[dict[str, Float], Float]: output_params = self.transform_func(transform_params) jacobian = jax.jacfwd(self.transform_func)(transform_params) jacobian = jnp.array(jax.tree.leaves(jacobian)) - jacobian = jnp.log(jnp.absolute(jnp.linalg.det(jacobian.reshape(self.n_dim, self.n_dim)))) + jacobian = jnp.log( + jnp.absolute(jnp.linalg.det(jacobian.reshape(self.n_dim, self.n_dim))) + ) jax.tree.map( lambda key: x_copy.pop(key), self.name_mapping[0], @@ -126,7 +128,9 @@ def inverse(self, y: dict[str, Float]) -> tuple[dict[str, Float], Float]: output_params = self.inverse_transform_func(transform_params) jacobian = jax.jacfwd(self.inverse_transform_func)(transform_params) jacobian = jnp.array(jax.tree.leaves(jacobian)) - jacobian = jnp.log(jnp.absolute(jnp.linalg.det(jacobian.reshape(self.n_dim, self.n_dim)))) + jacobian = jnp.log( + jnp.absolute(jnp.linalg.det(jacobian.reshape(self.n_dim, self.n_dim))) + ) jax.tree.map( lambda key: y_copy.pop(key), self.name_mapping[1],