Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added code to reverse transforms for more flexibility #132

222 changes: 73 additions & 149 deletions src/jimgw/single_event/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,11 @@
from astropy.time import Time

from jimgw.single_event.detector import GroundBased2G
from jimgw.transforms import BijectiveTransform, NtoNTransform
from jimgw.transforms import (
BijectiveTransform,
NtoNTransform,
reverse_bijective_transform,
)
from jimgw.single_event.utils import (
m1_m2_to_Mc_q,
Mc_q_to_m1_m2,
Expand All @@ -20,111 +24,56 @@


@jaxtyped(typechecker=typechecker)
class ComponentMassesToChirpMassMassRatioTransform(BijectiveTransform):
"""
Transform chirp mass and mass ratio to component masses

Parameters
----------
name_mapping : tuple[list[str], list[str]]
The name mapping between the input and output dictionary.
class SpinToCartesianSpinTransform(NtoNTransform):
kazewong marked this conversation as resolved.
Show resolved Hide resolved
"""

def __init__(
self,
name_mapping: tuple[list[str], list[str]],
):
super().__init__(name_mapping)
assert (
"m_1" in name_mapping[0]
and "m_2" in name_mapping[0]
and "M_c" in name_mapping[1]
and "q" in name_mapping[1]
)

def named_transform(x):
Mc, q = m1_m2_to_Mc_q(x["m_1"], x["m_2"])
return {"M_c": Mc, "q": q}

self.transform_func = named_transform

def named_inverse_transform(x):
m1, m2 = Mc_q_to_m1_m2(x["M_c"], x["q"])
return {"m_1": m1, "m_2": m2}

self.inverse_transform_func = named_inverse_transform


@jaxtyped(typechecker=typechecker)
class ComponentMassesToChirpMassSymmetricMassRatioTransform(BijectiveTransform):
Spin to Cartesian spin transformation
"""
Transform mass ratio to symmetric mass ratio

Parameters
----------
name_mapping : tuple[list[str], list[str]]
The name mapping between the input and output dictionary.

"""
freq_ref: Float

def __init__(
self,
name_mapping: tuple[list[str], list[str]],
freq_ref: Float,
):
super().__init__(name_mapping)
assert (
"m_1" in name_mapping[0]
and "m_2" in name_mapping[0]
and "M_c" in name_mapping[1]
and "eta" in name_mapping[1]
name_mapping = (
["theta_jn", "phi_jl", "theta_1", "theta_2", "phi_12", "a_1", "a_2"],
["iota", "s1_x", "s1_y", "s1_z", "s2_x", "s2_y", "s2_z"],
)
super().__init__(name_mapping)

self.freq_ref = freq_ref

def named_transform(x):
Mc, eta = m1_m2_to_Mc_eta(x["m_1"], x["m_2"])
return {"M_c": Mc, "eta": eta}
iota, s1x, s1y, s1z, s2x, s2y, s2z = spin_to_cartesian_spin(
x["theta_jn"],
x["phi_jl"],
x["theta_1"],
x["theta_2"],
x["phi_12"],
x["a_1"],
x["a_2"],
x["M_c"],
x["q"],
self.freq_ref,
x["phase_c"],
)
return {
"iota": iota,
"s1_x": s1x,
"s1_y": s1y,
"s1_z": s1z,
"s2_x": s2x,
"s2_y": s2y,
"s2_z": s2z,
}

self.transform_func = named_transform

def named_inverse_transform(x):
m1, m2 = Mc_eta_to_m1_m2(x["M_c"], x["q"])
return {"m_1": m1, "m_2": m2}

self.inverse_transform_func = named_inverse_transform


@jaxtyped(typechecker=typechecker)
class MassRatioToSymmetricMassRatioTransform(BijectiveTransform):
"""
Transform mass ratio to symmetric mass ratio

Parameters
----------
name_mapping : tuple[list[str], list[str]]
The name mapping between the input and output dictionary.

"""

def __init__(
self,
name_mapping: tuple[list[str], list[str]],
):
super().__init__(name_mapping)
assert "q" == name_mapping[0][0] and "eta" == name_mapping[1][0]

self.transform_func = lambda x: {"eta": q_to_eta(x["q"])}
self.inverse_transform_func = lambda x: {"q": eta_to_q(x["eta"])}


@jaxtyped(typechecker=typechecker)
class SkyFrameToDetectorFrameSkyPositionTransform(BijectiveTransform):
"""
Transform sky frame to detector frame sky position

Parameters
----------
name_mapping : tuple[list[str], list[str]]
The name mapping between the input and output dictionary.

"""

gmst: Float
Expand All @@ -133,10 +82,10 @@ class SkyFrameToDetectorFrameSkyPositionTransform(BijectiveTransform):

def __init__(
self,
name_mapping: tuple[list[str], list[str]],
gps_time: Float,
ifos: GroundBased2G,
):
name_mapping = (["ra", "dec"], ["zenith", "azimuth"])
super().__init__(name_mapping)

self.gmst = (
Expand All @@ -146,13 +95,6 @@ def __init__(
self.rotation = euler_rotation(delta_x)
self.rotation_inv = jnp.linalg.inv(self.rotation)

assert (
"ra" in name_mapping[0]
and "dec" in name_mapping[0]
and "zenith" in name_mapping[1]
and "azimuth" in name_mapping[1]
)

def named_transform(x):
zenith, azimuth = ra_dec_to_zenith_azimuth(
x["ra"], x["dec"], self.gmst, self.rotation
Expand All @@ -169,63 +111,45 @@ def named_inverse_transform(x):

self.inverse_transform_func = named_inverse_transform

def named_m1_m2_to_Mc_q(x):
Mc, q = m1_m2_to_Mc_q(x["m_1"], x["m_2"])
return {"M_c": Mc, "q": q}

@jaxtyped(typechecker=typechecker)
class SpinToCartesianSpinTransform(NtoNTransform):
"""
Spin to Cartesian spin transformation
"""
def named_Mc_q_to_m1_m2(x):
m1, m2 = Mc_q_to_m1_m2(x["M_c"], x["q"])
return {"m_1": m1, "m_2": m2}

freq_ref: Float
ComponentMassesToChirpMassMassRatioTransform = BijectiveTransform((["m_1", "m_2"], ["M_c", "q"]))
ComponentMassesToChirpMassMassRatioTransform.transform_func = named_m1_m2_to_Mc_q
ComponentMassesToChirpMassMassRatioTransform.inverse_transform_func = named_Mc_q_to_m1_m2

def __init__(
self,
name_mapping: tuple[list[str], list[str]],
freq_ref: Float,
):
super().__init__(name_mapping)
def named_m1_m2_to_Mc_eta(x):
Mc, eta = m1_m2_to_Mc_eta(x["m_1"], x["m_2"])
return {"M_c": Mc, "eta": eta}

self.freq_ref = freq_ref
def named_Mc_eta_to_m1_m2(x):
m1, m2 = Mc_eta_to_m1_m2(x["M_c"], x["eta"])
return {"m_1": m1, "m_2": m2}

assert (
"theta_jn" in name_mapping[0]
and "phi_jl" in name_mapping[0]
and "theta_1" in name_mapping[0]
and "theta_2" in name_mapping[0]
and "phi_12" in name_mapping[0]
and "a_1" in name_mapping[0]
and "a_2" in name_mapping[0]
and "iota" in name_mapping[1]
and "s1_x" in name_mapping[1]
and "s1_y" in name_mapping[1]
and "s1_z" in name_mapping[1]
and "s2_x" in name_mapping[1]
and "s2_y" in name_mapping[1]
and "s2_z" in name_mapping[1]
)
ComponentMassesToChirpMassSymmetricMassRatioTransform = BijectiveTransform((["m_1", "m_2"], ["M_c", "eta"]))
ComponentMassesToChirpMassSymmetricMassRatioTransform.transform_func = named_m1_m2_to_Mc_eta
ComponentMassesToChirpMassSymmetricMassRatioTransform.inverse_transform_func = named_Mc_eta_to_m1_m2

def named_transform(x):
iota, s1x, s1y, s1z, s2x, s2y, s2z = spin_to_cartesian_spin(
x["theta_jn"],
x["phi_jl"],
x["theta_1"],
x["theta_2"],
x["phi_12"],
x["a_1"],
x["a_2"],
x["M_c"],
x["q"],
self.freq_ref,
x["phase_c"],
)
return {
"iota": iota,
"s1_x": s1x,
"s1_y": s1y,
"s1_z": s1z,
"s2_x": s2x,
"s2_y": s2y,
"s2_z": s2z,
}
def named_q_to_eta(x):
return {"eta": q_to_eta(x["q"])}
def named_eta_to_q(x):
return {"q": eta_to_q(x["eta"])}
MassRatioToSymmetricMassRatioTransform = BijectiveTransform((["q"], ["eta"]))
MassRatioToSymmetricMassRatioTransform.transform_func = named_q_to_eta
MassRatioToSymmetricMassRatioTransform.inverse_transform_func = named_eta_to_q

self.transform_func = named_transform

ChirpMassMassRatioToComponentMassesTransform = reverse_bijective_transform(
ComponentMassesToChirpMassMassRatioTransform
)
ChirpMassSymmetricMassRatioToComponentMassesTransform = reverse_bijective_transform(
ComponentMassesToChirpMassSymmetricMassRatioTransform
)
SymmetricMassRatioToMassRatioTransform = reverse_bijective_transform(
MassRatioToSymmetricMassRatioTransform
)
16 changes: 16 additions & 0 deletions src/jimgw/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,3 +445,19 @@ def __init__(
)
for i in range(len(name_mapping[1]))
}


def reverse_bijective_transform(
original_transform: BijectiveTransform,
) -> BijectiveTransform:

thomasckng marked this conversation as resolved.
Show resolved Hide resolved
reversed_name_mapping = (
original_transform.name_mapping[1],
original_transform.name_mapping[0],
)
reversed_transform = BijectiveTransform(name_mapping=reversed_name_mapping)
reversed_transform.transform_func = original_transform.inverse_transform_func
reversed_transform.inverse_transform_func = original_transform.transform_func
reversed_transform.__repr__ = lambda: f"Reversed{repr(original_transform)}"

return reversed_transform
2 changes: 2 additions & 0 deletions test/integration/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
outdir/
figures/
thomasckng marked this conversation as resolved.
Show resolved Hide resolved
16 changes: 10 additions & 6 deletions test/integration/test_GW150914_D.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.10"

import jax
import jax.numpy as jnp

Expand All @@ -10,6 +14,8 @@
from jimgw.single_event.transforms import ComponentMassesToChirpMassSymmetricMassRatioTransform, SkyFrameToDetectorFrameSkyPositionTransform, ComponentMassesToChirpMassMassRatioTransform
from jimgw.single_event.utils import Mc_q_to_m1_m2
from flowMC.strategy.optimization import optimization_Adam
from flowMC.utils.postprocessing import plot_summary
import optax

jax.config.update("jax_enable_x64", True)

Expand Down Expand Up @@ -62,7 +68,7 @@
)

sample_transforms = [
ComponentMassesToChirpMassMassRatioTransform(name_mapping=[["m_1", "m_2"], ["M_c", "q"]]),
ComponentMassesToChirpMassMassRatioTransform,
BoundToUnbound(name_mapping = [["M_c"], ["M_c_unbounded"]], original_lower_bound=M_c_min, original_upper_bound=M_c_max),
BoundToUnbound(name_mapping = [["q"], ["q_unbounded"]], original_lower_bound=q_min, original_upper_bound=q_max),
BoundToUnbound(name_mapping = [["s1_z"], ["s1_z_unbounded"]] , original_lower_bound=-1.0, original_upper_bound=1.0),
Expand All @@ -72,13 +78,13 @@
BoundToUnbound(name_mapping = [["phase_c"], ["phase_c_unbounded"]] , original_lower_bound=0.0, original_upper_bound=2 * jnp.pi),
BoundToUnbound(name_mapping = [["iota"], ["iota_unbounded"]], original_lower_bound=0., original_upper_bound=jnp.pi),
BoundToUnbound(name_mapping = [["psi"], ["psi_unbounded"]], original_lower_bound=0.0, original_upper_bound=jnp.pi),
SkyFrameToDetectorFrameSkyPositionTransform(name_mapping = [["ra", "dec"], ["zenith", "azimuth"]], gps_time=gps, ifos=ifos),
SkyFrameToDetectorFrameSkyPositionTransform(gps_time=gps, ifos=ifos),
BoundToUnbound(name_mapping = [["zenith"], ["zenith_unbounded"]], original_lower_bound=0.0, original_upper_bound=jnp.pi),
BoundToUnbound(name_mapping = [["azimuth"], ["azimuth_unbounded"]], original_lower_bound=0.0, original_upper_bound=2 * jnp.pi),
]

likelihood_transforms = [
ComponentMassesToChirpMassSymmetricMassRatioTransform(name_mapping=[["m_1", "m_2"], ["M_c", "eta"]]),
ComponentMassesToChirpMassSymmetricMassRatioTransform,
]

likelihood = TransientLikelihoodFD(
Expand All @@ -89,7 +95,6 @@
post_trigger_duration=2,
)


mass_matrix = jnp.eye(11)
mass_matrix = mass_matrix.at[1, 1].set(1e-3)
mass_matrix = mass_matrix.at[5, 5].set(1e-3)
Expand All @@ -101,7 +106,6 @@
n_loop_training = 1
learning_rate = 1e-4


jim = Jim(
likelihood,
prior,
Expand All @@ -127,4 +131,4 @@

jim.sample(jax.random.PRNGKey(42))
jim.get_samples()
jim.print_summary()
jim.print_summary()
Loading
Loading