Skip to content

Commit

Permalink
Further consolidation as requested
Browse files Browse the repository at this point in the history
  • Loading branch information
ThibeauWouters committed Sep 2, 2024
1 parent 98e9c0b commit ad90d09
Showing 1 changed file with 31 additions and 70 deletions.
101 changes: 31 additions & 70 deletions src/jimgw/single_event/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,77 +111,38 @@ def named_inverse_transform(x):

self.inverse_transform_func = named_inverse_transform

def named_m1_m2_to_Mc_q(x):
Mc, q = m1_m2_to_Mc_q(x["m_1"], x["m_2"])
return {"M_c": Mc, "q": q}

def named_Mc_q_to_m1_m2(x):
m1, m2 = Mc_q_to_m1_m2(x["M_c"], x["q"])
return {"m_1": m1, "m_2": m2}

ComponentMassesToChirpMassMassRatioTransform = BijectiveTransform((["m_1", "m_2"], ["M_c", "q"]))
ComponentMassesToChirpMassMassRatioTransform.transform_func = named_m1_m2_to_Mc_q
ComponentMassesToChirpMassMassRatioTransform.inverse_transform_func = named_Mc_q_to_m1_m2

def named_m1_m2_to_Mc_eta(x):
Mc, eta = m1_m2_to_Mc_eta(x["m_1"], x["m_2"])
return {"M_c": Mc, "eta": eta}

def named_Mc_eta_to_m1_m2(x):
m1, m2 = Mc_eta_to_m1_m2(x["M_c"], x["eta"])
return {"m_1": m1, "m_2": m2}

ComponentMassesToChirpMassSymmetricMassRatioTransform = BijectiveTransform((["m_1", "m_2"], ["M_c", "eta"]))
ComponentMassesToChirpMassSymmetricMassRatioTransform.transform_func = named_m1_m2_to_Mc_eta
ComponentMassesToChirpMassSymmetricMassRatioTransform.inverse_transform_func = named_Mc_eta_to_m1_m2

def named_q_to_eta(x):
return {"eta": q_to_eta(x["q"])}
def named_eta_to_q(x):
return {"q": eta_to_q(x["eta"])}
MassRatioToSymmetricMassRatioTransform = BijectiveTransform((["q"], ["eta"]))
MassRatioToSymmetricMassRatioTransform.transform_func = named_q_to_eta
MassRatioToSymmetricMassRatioTransform.inverse_transform_func = named_eta_to_q

@jaxtyped(typechecker=typechecker)
class _ComponentMassesToChirpMassMassRatioTransform(BijectiveTransform):
"""
Transform chirp mass and mass ratio to component masses. Instantiated to object below.
"""

def __init__(self):
name_mapping = (["m_1", "m_2"], ["M_c", "q"])
super().__init__(name_mapping)

def named_transform(x):
Mc, q = m1_m2_to_Mc_q(x["m_1"], x["m_2"])
return {"M_c": Mc, "q": q}

self.transform_func = named_transform

def named_inverse_transform(x):
m1, m2 = Mc_q_to_m1_m2(x["M_c"], x["q"])
return {"m_1": m1, "m_2": m2}

self.inverse_transform_func = named_inverse_transform


@jaxtyped(typechecker=typechecker)
class _ComponentMassesToChirpMassSymmetricMassRatioTransform(BijectiveTransform):
"""
Transform mass ratio to symmetric mass ratio. Instantiated to object below.
"""

def __init__(self):
name_mapping = (["m_1", "m_2"], ["M_c", "eta"])
super().__init__(name_mapping)

def named_transform(x):
Mc, eta = m1_m2_to_Mc_eta(x["m_1"], x["m_2"])
return {"M_c": Mc, "eta": eta}

self.transform_func = named_transform

def named_inverse_transform(x):
m1, m2 = Mc_eta_to_m1_m2(x["M_c"], x["q"])
return {"m_1": m1, "m_2": m2}

self.inverse_transform_func = named_inverse_transform


@jaxtyped(typechecker=typechecker)
class _MassRatioToSymmetricMassRatioTransform(BijectiveTransform):
"""
Transform mass ratio to symmetric mass ratio. Instantiated to object below.
"""

def __init__(self):
name_mapping = (["q"], ["eta"])
super().__init__(name_mapping)

self.transform_func = lambda x: {"eta": q_to_eta(x["q"])}
self.inverse_transform_func = lambda x: {"q": eta_to_q(x["eta"])}

def __repr__(self):
return f"{self.__class__.__name__[1:]}()"


ComponentMassesToChirpMassMassRatioTransform = (
_ComponentMassesToChirpMassMassRatioTransform()
)
ComponentMassesToChirpMassSymmetricMassRatioTransform = (
_ComponentMassesToChirpMassSymmetricMassRatioTransform()
)
MassRatioToSymmetricMassRatioTransform = _MassRatioToSymmetricMassRatioTransform()

ChirpMassMassRatioToComponentMassesTransform = reverse_bijective_transform(
ComponentMassesToChirpMassMassRatioTransform
Expand Down

0 comments on commit ad90d09

Please sign in to comment.