Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added code to reverse transforms for more flexibility #132

251 changes: 107 additions & 144 deletions src/jimgw/single_event/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,11 @@
from astropy.time import Time

from jimgw.single_event.detector import GroundBased2G
from jimgw.transforms import BijectiveTransform, NtoNTransform
from jimgw.transforms import (
BijectiveTransform,
NtoNTransform,
reverse_bijective_transform,
)
from jimgw.single_event.utils import (
m1_m2_to_Mc_q,
Mc_q_to_m1_m2,
Expand All @@ -20,111 +24,56 @@


@jaxtyped(typechecker=typechecker)
class ComponentMassesToChirpMassMassRatioTransform(BijectiveTransform):
"""
Transform chirp mass and mass ratio to component masses

Parameters
----------
name_mapping : tuple[list[str], list[str]]
The name mapping between the input and output dictionary.
class SpinToCartesianSpinTransform(NtoNTransform):
kazewong marked this conversation as resolved.
Show resolved Hide resolved
"""

def __init__(
self,
name_mapping: tuple[list[str], list[str]],
):
super().__init__(name_mapping)
assert (
"m_1" in name_mapping[0]
and "m_2" in name_mapping[0]
and "M_c" in name_mapping[1]
and "q" in name_mapping[1]
)

def named_transform(x):
Mc, q = m1_m2_to_Mc_q(x["m_1"], x["m_2"])
return {"M_c": Mc, "q": q}

self.transform_func = named_transform

def named_inverse_transform(x):
m1, m2 = Mc_q_to_m1_m2(x["M_c"], x["q"])
return {"m_1": m1, "m_2": m2}

self.inverse_transform_func = named_inverse_transform


@jaxtyped(typechecker=typechecker)
class ComponentMassesToChirpMassSymmetricMassRatioTransform(BijectiveTransform):
Spin to Cartesian spin transformation
"""
Transform mass ratio to symmetric mass ratio

Parameters
----------
name_mapping : tuple[list[str], list[str]]
The name mapping between the input and output dictionary.

"""
freq_ref: Float

def __init__(
self,
name_mapping: tuple[list[str], list[str]],
freq_ref: Float,
):
super().__init__(name_mapping)
assert (
"m_1" in name_mapping[0]
and "m_2" in name_mapping[0]
and "M_c" in name_mapping[1]
and "eta" in name_mapping[1]
name_mapping = (
["theta_jn", "phi_jl", "theta_1", "theta_2", "phi_12", "a_1", "a_2"],
["iota", "s1_x", "s1_y", "s1_z", "s2_x", "s2_y", "s2_z"],
)
super().__init__(name_mapping)

self.freq_ref = freq_ref

def named_transform(x):
Mc, eta = m1_m2_to_Mc_eta(x["m_1"], x["m_2"])
return {"M_c": Mc, "eta": eta}
iota, s1x, s1y, s1z, s2x, s2y, s2z = spin_to_cartesian_spin(
x["theta_jn"],
x["phi_jl"],
x["theta_1"],
x["theta_2"],
x["phi_12"],
x["a_1"],
x["a_2"],
x["M_c"],
x["q"],
self.freq_ref,
x["phase_c"],
)
return {
"iota": iota,
"s1_x": s1x,
"s1_y": s1y,
"s1_z": s1z,
"s2_x": s2x,
"s2_y": s2y,
"s2_z": s2z,
}

self.transform_func = named_transform

def named_inverse_transform(x):
m1, m2 = Mc_eta_to_m1_m2(x["M_c"], x["q"])
return {"m_1": m1, "m_2": m2}

self.inverse_transform_func = named_inverse_transform


@jaxtyped(typechecker=typechecker)
class MassRatioToSymmetricMassRatioTransform(BijectiveTransform):
"""
Transform mass ratio to symmetric mass ratio

Parameters
----------
name_mapping : tuple[list[str], list[str]]
The name mapping between the input and output dictionary.

"""

def __init__(
self,
name_mapping: tuple[list[str], list[str]],
):
super().__init__(name_mapping)
assert "q" == name_mapping[0][0] and "eta" == name_mapping[1][0]

self.transform_func = lambda x: {"eta": q_to_eta(x["q"])}
self.inverse_transform_func = lambda x: {"q": eta_to_q(x["eta"])}


@jaxtyped(typechecker=typechecker)
class SkyFrameToDetectorFrameSkyPositionTransform(BijectiveTransform):
"""
Transform sky frame to detector frame sky position

Parameters
----------
name_mapping : tuple[list[str], list[str]]
The name mapping between the input and output dictionary.

"""

gmst: Float
Expand All @@ -133,10 +82,10 @@ class SkyFrameToDetectorFrameSkyPositionTransform(BijectiveTransform):

def __init__(
self,
name_mapping: tuple[list[str], list[str]],
gps_time: Float,
ifos: GroundBased2G,
):
name_mapping = (["ra", "dec"], ["zenith", "azimuth"])
super().__init__(name_mapping)

self.gmst = (
Expand All @@ -146,13 +95,6 @@ def __init__(
self.rotation = euler_rotation(delta_x)
self.rotation_inv = jnp.linalg.inv(self.rotation)

assert (
"ra" in name_mapping[0]
and "dec" in name_mapping[0]
and "zenith" in name_mapping[1]
and "azimuth" in name_mapping[1]
)

def named_transform(x):
zenith, azimuth = ra_dec_to_zenith_azimuth(
x["ra"], x["dec"], self.gmst, self.rotation
Expand All @@ -171,61 +113,82 @@ def named_inverse_transform(x):


@jaxtyped(typechecker=typechecker)
class SpinToCartesianSpinTransform(NtoNTransform):
class _ComponentMassesToChirpMassMassRatioTransform(BijectiveTransform):
"""
Spin to Cartesian spin transformation
Transform chirp mass and mass ratio to component masses. Instantiated to object below.
"""

freq_ref: Float

def __init__(
self,
name_mapping: tuple[list[str], list[str]],
freq_ref: Float,
):
def __init__(self):
name_mapping = (["m_1", "m_2"], ["M_c", "q"])
super().__init__(name_mapping)

self.freq_ref = freq_ref
def named_transform(x):
Mc, q = m1_m2_to_Mc_q(x["m_1"], x["m_2"])
return {"M_c": Mc, "q": q}

assert (
"theta_jn" in name_mapping[0]
and "phi_jl" in name_mapping[0]
and "theta_1" in name_mapping[0]
and "theta_2" in name_mapping[0]
and "phi_12" in name_mapping[0]
and "a_1" in name_mapping[0]
and "a_2" in name_mapping[0]
and "iota" in name_mapping[1]
and "s1_x" in name_mapping[1]
and "s1_y" in name_mapping[1]
and "s1_z" in name_mapping[1]
and "s2_x" in name_mapping[1]
and "s2_y" in name_mapping[1]
and "s2_z" in name_mapping[1]
)
self.transform_func = named_transform

def named_inverse_transform(x):
m1, m2 = Mc_q_to_m1_m2(x["M_c"], x["q"])
return {"m_1": m1, "m_2": m2}

self.inverse_transform_func = named_inverse_transform


@jaxtyped(typechecker=typechecker)
class _ComponentMassesToChirpMassSymmetricMassRatioTransform(BijectiveTransform):
"""
Transform mass ratio to symmetric mass ratio. Instantiated to object below.
"""

def __init__(self):
name_mapping = (["m_1", "m_2"], ["M_c", "eta"])
super().__init__(name_mapping)

def named_transform(x):
iota, s1x, s1y, s1z, s2x, s2y, s2z = spin_to_cartesian_spin(
x["theta_jn"],
x["phi_jl"],
x["theta_1"],
x["theta_2"],
x["phi_12"],
x["a_1"],
x["a_2"],
x["M_c"],
x["q"],
self.freq_ref,
x["phase_c"],
)
return {
"iota": iota,
"s1_x": s1x,
"s1_y": s1y,
"s1_z": s1z,
"s2_x": s2x,
"s2_y": s2y,
"s2_z": s2z,
}
Mc, eta = m1_m2_to_Mc_eta(x["m_1"], x["m_2"])
return {"M_c": Mc, "eta": eta}

self.transform_func = named_transform

def named_inverse_transform(x):
m1, m2 = Mc_eta_to_m1_m2(x["M_c"], x["q"])
return {"m_1": m1, "m_2": m2}

self.inverse_transform_func = named_inverse_transform


@jaxtyped(typechecker=typechecker)
class _MassRatioToSymmetricMassRatioTransform(BijectiveTransform):
"""
Transform mass ratio to symmetric mass ratio. Instantiated to object below.
"""

def __init__(self):
name_mapping = (["q"], ["eta"])
super().__init__(name_mapping)

self.transform_func = lambda x: {"eta": q_to_eta(x["q"])}
self.inverse_transform_func = lambda x: {"q": eta_to_q(x["eta"])}

def __repr__(self):
return f"{self.__class__.__name__[1:]}()"


ComponentMassesToChirpMassMassRatioTransform = (
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you want to further consolidate it? From what I see, this can be just

def named_transform_func...
def named_transform_inverse_func...
ComponentMassesToChirpMassMassRatioTransform =  BijectiveTransform(name_mapping = (["m_1", "m_2"], ["M_c", "q"]),  transform_func =  named_transform, self.inverse_transform_func = named_inverse_transform)

_ComponentMassesToChirpMassMassRatioTransform()
)
ComponentMassesToChirpMassSymmetricMassRatioTransform = (
_ComponentMassesToChirpMassSymmetricMassRatioTransform()
)
MassRatioToSymmetricMassRatioTransform = _MassRatioToSymmetricMassRatioTransform()

ChirpMassMassRatioToComponentMassesTransform = reverse_bijective_transform(
ComponentMassesToChirpMassMassRatioTransform
)
ChirpMassSymmetricMassRatioToComponentMassesTransform = reverse_bijective_transform(
ComponentMassesToChirpMassSymmetricMassRatioTransform
)
SymmetricMassRatioToMassRatioTransform = reverse_bijective_transform(
MassRatioToSymmetricMassRatioTransform
)
16 changes: 16 additions & 0 deletions src/jimgw/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,3 +445,19 @@ def __init__(
)
for i in range(len(name_mapping[1]))
}


def reverse_bijective_transform(
original_transform: BijectiveTransform,
) -> BijectiveTransform:

thomasckng marked this conversation as resolved.
Show resolved Hide resolved
reversed_name_mapping = (
original_transform.name_mapping[1],
original_transform.name_mapping[0],
)
reversed_transform = BijectiveTransform(name_mapping=reversed_name_mapping)
reversed_transform.transform_func = original_transform.inverse_transform_func
reversed_transform.inverse_transform_func = original_transform.transform_func
reversed_transform.__repr__ = lambda: f"Reversed{repr(original_transform)}"

return reversed_transform
2 changes: 2 additions & 0 deletions test/integration/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
outdir/
figures/
thomasckng marked this conversation as resolved.
Show resolved Hide resolved
Loading
Loading