Skip to content

Commit

Permalink
Reformat
Browse files Browse the repository at this point in the history
  • Loading branch information
thomasckng committed Sep 17, 2024
1 parent 3d3b20a commit b8a0244
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 14 deletions.
33 changes: 26 additions & 7 deletions src/jimgw/jim.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,19 +104,38 @@ def posterior(self, params: Float[Array, " n_dim"], data: dict):

def sample(self, key: PRNGKeyArray, initial_position: Array = jnp.array([])):
if initial_position.size == 0:
initial_position = jnp.zeros((self.sampler.n_chains, self.prior.n_dim)) + jnp.nan

while not jax.tree.reduce(jnp.logical_and, jax.tree.map(lambda x: jnp.isfinite(x), initial_position)).all():
non_finite_index = jnp.where(jnp.any(~jax.tree.reduce(jnp.logical_and, jax.tree.map(lambda x: jnp.isfinite(x), initial_position)),axis=1))[0]
initial_position = (
jnp.zeros((self.sampler.n_chains, self.prior.n_dim)) + jnp.nan
)

while not jax.tree.reduce(
jnp.logical_and,
jax.tree.map(lambda x: jnp.isfinite(x), initial_position),
).all():
non_finite_index = jnp.where(
jnp.any(
~jax.tree.reduce(
jnp.logical_and,
jax.tree.map(lambda x: jnp.isfinite(x), initial_position),
),
axis=1,
)
)[0]

key, subkey = jax.random.split(key)
guess = self.prior.sample(subkey, self.sampler.n_chains)
for transform in self.sample_transforms:
guess = jax.vmap(transform.forward)(guess)
guess = jnp.array(jax.tree.leaves({key: guess[key] for key in self.parameter_names})).T
finite_guess = jnp.where(jnp.all(jax.tree.map(lambda x: jnp.isfinite(x), guess),axis=1))[0]
guess = jnp.array(
jax.tree.leaves({key: guess[key] for key in self.parameter_names})
).T
finite_guess = jnp.where(
jnp.all(jax.tree.map(lambda x: jnp.isfinite(x), guess), axis=1)
)[0]
common_length = min(len(finite_guess), len(non_finite_index))
initial_position = initial_position.at[non_finite_index[:common_length]].set(guess[:common_length])
initial_position = initial_position.at[
non_finite_index[:common_length]
].set(guess[:common_length])
self.sampler.sample(initial_position, None) # type: ignore

def maximize_likelihood(
Expand Down
31 changes: 26 additions & 5 deletions src/jimgw/single_event/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,34 +111,55 @@ def named_inverse_transform(x):

self.inverse_transform_func = named_inverse_transform


def named_m1_m2_to_Mc_q(x):
Mc, q = m1_m2_to_Mc_q(x["m_1"], x["m_2"])
return {"M_c": Mc, "q": q}


def named_Mc_q_to_m1_m2(x):
m1, m2 = Mc_q_to_m1_m2(x["M_c"], x["q"])
return {"m_1": m1, "m_2": m2}

ComponentMassesToChirpMassMassRatioTransform = BijectiveTransform((["m_1", "m_2"], ["M_c", "q"]))

ComponentMassesToChirpMassMassRatioTransform = BijectiveTransform(
(["m_1", "m_2"], ["M_c", "q"])
)
ComponentMassesToChirpMassMassRatioTransform.transform_func = named_m1_m2_to_Mc_q
ComponentMassesToChirpMassMassRatioTransform.inverse_transform_func = named_Mc_q_to_m1_m2
ComponentMassesToChirpMassMassRatioTransform.inverse_transform_func = (
named_Mc_q_to_m1_m2
)


def named_m1_m2_to_Mc_eta(x):
Mc, eta = m1_m2_to_Mc_eta(x["m_1"], x["m_2"])
return {"M_c": Mc, "eta": eta}


def named_Mc_eta_to_m1_m2(x):
m1, m2 = Mc_eta_to_m1_m2(x["M_c"], x["eta"])
return {"m_1": m1, "m_2": m2}

ComponentMassesToChirpMassSymmetricMassRatioTransform = BijectiveTransform((["m_1", "m_2"], ["M_c", "eta"]))
ComponentMassesToChirpMassSymmetricMassRatioTransform.transform_func = named_m1_m2_to_Mc_eta
ComponentMassesToChirpMassSymmetricMassRatioTransform.inverse_transform_func = named_Mc_eta_to_m1_m2

ComponentMassesToChirpMassSymmetricMassRatioTransform = BijectiveTransform(
(["m_1", "m_2"], ["M_c", "eta"])
)
ComponentMassesToChirpMassSymmetricMassRatioTransform.transform_func = (
named_m1_m2_to_Mc_eta
)
ComponentMassesToChirpMassSymmetricMassRatioTransform.inverse_transform_func = (
named_Mc_eta_to_m1_m2
)


def named_q_to_eta(x):
return {"eta": q_to_eta(x["q"])}


def named_eta_to_q(x):
return {"q": eta_to_q(x["eta"])}


MassRatioToSymmetricMassRatioTransform = BijectiveTransform((["q"], ["eta"]))
MassRatioToSymmetricMassRatioTransform.transform_func = named_q_to_eta
MassRatioToSymmetricMassRatioTransform.inverse_transform_func = named_eta_to_q
Expand Down
8 changes: 6 additions & 2 deletions src/jimgw/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,9 @@ def transform(self, x: dict[str, Float]) -> tuple[dict[str, Float], Float]:
output_params = self.transform_func(transform_params)
jacobian = jax.jacfwd(self.transform_func)(transform_params)
jacobian = jnp.array(jax.tree.leaves(jacobian))
jacobian = jnp.log(jnp.absolute(jnp.linalg.det(jacobian.reshape(self.n_dim, self.n_dim))))
jacobian = jnp.log(
jnp.absolute(jnp.linalg.det(jacobian.reshape(self.n_dim, self.n_dim)))
)
jax.tree.map(
lambda key: x_copy.pop(key),
self.name_mapping[0],
Expand Down Expand Up @@ -126,7 +128,9 @@ def inverse(self, y: dict[str, Float]) -> tuple[dict[str, Float], Float]:
output_params = self.inverse_transform_func(transform_params)
jacobian = jax.jacfwd(self.inverse_transform_func)(transform_params)
jacobian = jnp.array(jax.tree.leaves(jacobian))
jacobian = jnp.log(jnp.absolute(jnp.linalg.det(jacobian.reshape(self.n_dim, self.n_dim))))
jacobian = jnp.log(
jnp.absolute(jnp.linalg.det(jacobian.reshape(self.n_dim, self.n_dim)))
)
jax.tree.map(
lambda key: y_copy.pop(key),
self.name_mapping[1],
Expand Down

0 comments on commit b8a0244

Please sign in to comment.