diff --git a/src/jimgw/single_event/transforms.py b/src/jimgw/single_event/transforms.py index c3e77846..cb60e98b 100644 --- a/src/jimgw/single_event/transforms.py +++ b/src/jimgw/single_event/transforms.py @@ -4,7 +4,11 @@ from astropy.time import Time from jimgw.single_event.detector import GroundBased2G -from jimgw.transforms import BijectiveTransform, NtoNTransform +from jimgw.transforms import ( + BijectiveTransform, + NtoNTransform, + reverse_bijective_transform, +) from jimgw.single_event.utils import ( m1_m2_to_Mc_q, Mc_q_to_m1_m2, @@ -20,111 +24,56 @@ @jaxtyped(typechecker=typechecker) -class ComponentMassesToChirpMassMassRatioTransform(BijectiveTransform): - """ - Transform chirp mass and mass ratio to component masses - - Parameters - ---------- - name_mapping : tuple[list[str], list[str]] - The name mapping between the input and output dictionary. +class SpinToCartesianSpinTransform(NtoNTransform): """ - - def __init__( - self, - name_mapping: tuple[list[str], list[str]], - ): - super().__init__(name_mapping) - assert ( - "m_1" in name_mapping[0] - and "m_2" in name_mapping[0] - and "M_c" in name_mapping[1] - and "q" in name_mapping[1] - ) - - def named_transform(x): - Mc, q = m1_m2_to_Mc_q(x["m_1"], x["m_2"]) - return {"M_c": Mc, "q": q} - - self.transform_func = named_transform - - def named_inverse_transform(x): - m1, m2 = Mc_q_to_m1_m2(x["M_c"], x["q"]) - return {"m_1": m1, "m_2": m2} - - self.inverse_transform_func = named_inverse_transform - - -@jaxtyped(typechecker=typechecker) -class ComponentMassesToChirpMassSymmetricMassRatioTransform(BijectiveTransform): + Spin to Cartesian spin transformation """ - Transform mass ratio to symmetric mass ratio - Parameters - ---------- - name_mapping : tuple[list[str], list[str]] - The name mapping between the input and output dictionary. - - """ + freq_ref: Float def __init__( self, - name_mapping: tuple[list[str], list[str]], + freq_ref: Float, ): - super().__init__(name_mapping) - assert ( - "m_1" in name_mapping[0] - and "m_2" in name_mapping[0] - and "M_c" in name_mapping[1] - and "eta" in name_mapping[1] + name_mapping = ( + ["theta_jn", "phi_jl", "theta_1", "theta_2", "phi_12", "a_1", "a_2"], + ["iota", "s1_x", "s1_y", "s1_z", "s2_x", "s2_y", "s2_z"], ) + super().__init__(name_mapping) + + self.freq_ref = freq_ref def named_transform(x): - Mc, eta = m1_m2_to_Mc_eta(x["m_1"], x["m_2"]) - return {"M_c": Mc, "eta": eta} + iota, s1x, s1y, s1z, s2x, s2y, s2z = spin_to_cartesian_spin( + x["theta_jn"], + x["phi_jl"], + x["theta_1"], + x["theta_2"], + x["phi_12"], + x["a_1"], + x["a_2"], + x["M_c"], + x["q"], + self.freq_ref, + x["phase_c"], + ) + return { + "iota": iota, + "s1_x": s1x, + "s1_y": s1y, + "s1_z": s1z, + "s2_x": s2x, + "s2_y": s2y, + "s2_z": s2z, + } self.transform_func = named_transform - def named_inverse_transform(x): - m1, m2 = Mc_eta_to_m1_m2(x["M_c"], x["q"]) - return {"m_1": m1, "m_2": m2} - - self.inverse_transform_func = named_inverse_transform - - -@jaxtyped(typechecker=typechecker) -class MassRatioToSymmetricMassRatioTransform(BijectiveTransform): - """ - Transform mass ratio to symmetric mass ratio - - Parameters - ---------- - name_mapping : tuple[list[str], list[str]] - The name mapping between the input and output dictionary. - - """ - - def __init__( - self, - name_mapping: tuple[list[str], list[str]], - ): - super().__init__(name_mapping) - assert "q" == name_mapping[0][0] and "eta" == name_mapping[1][0] - - self.transform_func = lambda x: {"eta": q_to_eta(x["q"])} - self.inverse_transform_func = lambda x: {"q": eta_to_q(x["eta"])} - @jaxtyped(typechecker=typechecker) class SkyFrameToDetectorFrameSkyPositionTransform(BijectiveTransform): """ Transform sky frame to detector frame sky position - - Parameters - ---------- - name_mapping : tuple[list[str], list[str]] - The name mapping between the input and output dictionary. - """ gmst: Float @@ -133,10 +82,10 @@ class SkyFrameToDetectorFrameSkyPositionTransform(BijectiveTransform): def __init__( self, - name_mapping: tuple[list[str], list[str]], gps_time: Float, ifos: GroundBased2G, ): + name_mapping = (["ra", "dec"], ["zenith", "azimuth"]) super().__init__(name_mapping) self.gmst = ( @@ -146,13 +95,6 @@ def __init__( self.rotation = euler_rotation(delta_x) self.rotation_inv = jnp.linalg.inv(self.rotation) - assert ( - "ra" in name_mapping[0] - and "dec" in name_mapping[0] - and "zenith" in name_mapping[1] - and "azimuth" in name_mapping[1] - ) - def named_transform(x): zenith, azimuth = ra_dec_to_zenith_azimuth( x["ra"], x["dec"], self.gmst, self.rotation @@ -169,63 +111,45 @@ def named_inverse_transform(x): self.inverse_transform_func = named_inverse_transform +def named_m1_m2_to_Mc_q(x): + Mc, q = m1_m2_to_Mc_q(x["m_1"], x["m_2"]) + return {"M_c": Mc, "q": q} -@jaxtyped(typechecker=typechecker) -class SpinToCartesianSpinTransform(NtoNTransform): - """ - Spin to Cartesian spin transformation - """ +def named_Mc_q_to_m1_m2(x): + m1, m2 = Mc_q_to_m1_m2(x["M_c"], x["q"]) + return {"m_1": m1, "m_2": m2} - freq_ref: Float +ComponentMassesToChirpMassMassRatioTransform = BijectiveTransform((["m_1", "m_2"], ["M_c", "q"])) +ComponentMassesToChirpMassMassRatioTransform.transform_func = named_m1_m2_to_Mc_q +ComponentMassesToChirpMassMassRatioTransform.inverse_transform_func = named_Mc_q_to_m1_m2 - def __init__( - self, - name_mapping: tuple[list[str], list[str]], - freq_ref: Float, - ): - super().__init__(name_mapping) +def named_m1_m2_to_Mc_eta(x): + Mc, eta = m1_m2_to_Mc_eta(x["m_1"], x["m_2"]) + return {"M_c": Mc, "eta": eta} - self.freq_ref = freq_ref +def named_Mc_eta_to_m1_m2(x): + m1, m2 = Mc_eta_to_m1_m2(x["M_c"], x["eta"]) + return {"m_1": m1, "m_2": m2} - assert ( - "theta_jn" in name_mapping[0] - and "phi_jl" in name_mapping[0] - and "theta_1" in name_mapping[0] - and "theta_2" in name_mapping[0] - and "phi_12" in name_mapping[0] - and "a_1" in name_mapping[0] - and "a_2" in name_mapping[0] - and "iota" in name_mapping[1] - and "s1_x" in name_mapping[1] - and "s1_y" in name_mapping[1] - and "s1_z" in name_mapping[1] - and "s2_x" in name_mapping[1] - and "s2_y" in name_mapping[1] - and "s2_z" in name_mapping[1] - ) +ComponentMassesToChirpMassSymmetricMassRatioTransform = BijectiveTransform((["m_1", "m_2"], ["M_c", "eta"])) +ComponentMassesToChirpMassSymmetricMassRatioTransform.transform_func = named_m1_m2_to_Mc_eta +ComponentMassesToChirpMassSymmetricMassRatioTransform.inverse_transform_func = named_Mc_eta_to_m1_m2 - def named_transform(x): - iota, s1x, s1y, s1z, s2x, s2y, s2z = spin_to_cartesian_spin( - x["theta_jn"], - x["phi_jl"], - x["theta_1"], - x["theta_2"], - x["phi_12"], - x["a_1"], - x["a_2"], - x["M_c"], - x["q"], - self.freq_ref, - x["phase_c"], - ) - return { - "iota": iota, - "s1_x": s1x, - "s1_y": s1y, - "s1_z": s1z, - "s2_x": s2x, - "s2_y": s2y, - "s2_z": s2z, - } +def named_q_to_eta(x): + return {"eta": q_to_eta(x["q"])} +def named_eta_to_q(x): + return {"q": eta_to_q(x["eta"])} +MassRatioToSymmetricMassRatioTransform = BijectiveTransform((["q"], ["eta"])) +MassRatioToSymmetricMassRatioTransform.transform_func = named_q_to_eta +MassRatioToSymmetricMassRatioTransform.inverse_transform_func = named_eta_to_q - self.transform_func = named_transform + +ChirpMassMassRatioToComponentMassesTransform = reverse_bijective_transform( + ComponentMassesToChirpMassMassRatioTransform +) +ChirpMassSymmetricMassRatioToComponentMassesTransform = reverse_bijective_transform( + ComponentMassesToChirpMassSymmetricMassRatioTransform +) +SymmetricMassRatioToMassRatioTransform = reverse_bijective_transform( + MassRatioToSymmetricMassRatioTransform +) diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index 715d49de..915802a0 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -445,3 +445,19 @@ def __init__( ) for i in range(len(name_mapping[1])) } + + +def reverse_bijective_transform( + original_transform: BijectiveTransform, +) -> BijectiveTransform: + + reversed_name_mapping = ( + original_transform.name_mapping[1], + original_transform.name_mapping[0], + ) + reversed_transform = BijectiveTransform(name_mapping=reversed_name_mapping) + reversed_transform.transform_func = original_transform.inverse_transform_func + reversed_transform.inverse_transform_func = original_transform.transform_func + reversed_transform.__repr__ = lambda: f"Reversed{repr(original_transform)}" + + return reversed_transform diff --git a/test/integration/.gitignore b/test/integration/.gitignore new file mode 100644 index 00000000..a7f7ef0e --- /dev/null +++ b/test/integration/.gitignore @@ -0,0 +1,2 @@ +outdir/ +figures/ diff --git a/test/integration/test_GW150914_D.py b/test/integration/test_GW150914_D.py index e1eee9ac..5103e5d8 100644 --- a/test/integration/test_GW150914_D.py +++ b/test/integration/test_GW150914_D.py @@ -1,3 +1,7 @@ +import os +os.environ["CUDA_VISIBLE_DEVICES"] = "0" +os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.10" + import jax import jax.numpy as jnp @@ -10,6 +14,8 @@ from jimgw.single_event.transforms import ComponentMassesToChirpMassSymmetricMassRatioTransform, SkyFrameToDetectorFrameSkyPositionTransform, ComponentMassesToChirpMassMassRatioTransform from jimgw.single_event.utils import Mc_q_to_m1_m2 from flowMC.strategy.optimization import optimization_Adam +from flowMC.utils.postprocessing import plot_summary +import optax jax.config.update("jax_enable_x64", True) @@ -62,7 +68,7 @@ ) sample_transforms = [ - ComponentMassesToChirpMassMassRatioTransform(name_mapping=[["m_1", "m_2"], ["M_c", "q"]]), + ComponentMassesToChirpMassMassRatioTransform, BoundToUnbound(name_mapping = [["M_c"], ["M_c_unbounded"]], original_lower_bound=M_c_min, original_upper_bound=M_c_max), BoundToUnbound(name_mapping = [["q"], ["q_unbounded"]], original_lower_bound=q_min, original_upper_bound=q_max), BoundToUnbound(name_mapping = [["s1_z"], ["s1_z_unbounded"]] , original_lower_bound=-1.0, original_upper_bound=1.0), @@ -72,13 +78,13 @@ BoundToUnbound(name_mapping = [["phase_c"], ["phase_c_unbounded"]] , original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), BoundToUnbound(name_mapping = [["iota"], ["iota_unbounded"]], original_lower_bound=0., original_upper_bound=jnp.pi), BoundToUnbound(name_mapping = [["psi"], ["psi_unbounded"]], original_lower_bound=0.0, original_upper_bound=jnp.pi), - SkyFrameToDetectorFrameSkyPositionTransform(name_mapping = [["ra", "dec"], ["zenith", "azimuth"]], gps_time=gps, ifos=ifos), + SkyFrameToDetectorFrameSkyPositionTransform(gps_time=gps, ifos=ifos), BoundToUnbound(name_mapping = [["zenith"], ["zenith_unbounded"]], original_lower_bound=0.0, original_upper_bound=jnp.pi), BoundToUnbound(name_mapping = [["azimuth"], ["azimuth_unbounded"]], original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), ] likelihood_transforms = [ - ComponentMassesToChirpMassSymmetricMassRatioTransform(name_mapping=[["m_1", "m_2"], ["M_c", "eta"]]), + ComponentMassesToChirpMassSymmetricMassRatioTransform, ] likelihood = TransientLikelihoodFD( @@ -89,7 +95,6 @@ post_trigger_duration=2, ) - mass_matrix = jnp.eye(11) mass_matrix = mass_matrix.at[1, 1].set(1e-3) mass_matrix = mass_matrix.at[5, 5].set(1e-3) @@ -101,7 +106,6 @@ n_loop_training = 1 learning_rate = 1e-4 - jim = Jim( likelihood, prior, @@ -127,4 +131,4 @@ jim.sample(jax.random.PRNGKey(42)) jim.get_samples() -jim.print_summary() +jim.print_summary() \ No newline at end of file diff --git a/test/integration/test_GW150914_D_heterodyne.py b/test/integration/test_GW150914_D_heterodyne.py index bf97efdb..cbac1788 100644 --- a/test/integration/test_GW150914_D_heterodyne.py +++ b/test/integration/test_GW150914_D_heterodyne.py @@ -62,7 +62,7 @@ ) sample_transforms = [ - ComponentMassesToChirpMassMassRatioTransform(name_mapping=[["m_1", "m_2"], ["M_c", "q"]]), + ComponentMassesToChirpMassMassRatioTransform, BoundToUnbound(name_mapping = [["M_c"], ["M_c_unbounded"]], original_lower_bound=M_c_min, original_upper_bound=M_c_max), BoundToUnbound(name_mapping = [["q"], ["q_unbounded"]], original_lower_bound=q_min, original_upper_bound=q_max), BoundToUnbound(name_mapping = [["s1_z"], ["s1_z_unbounded"]] , original_lower_bound=-1.0, original_upper_bound=1.0), @@ -72,13 +72,13 @@ BoundToUnbound(name_mapping = [["phase_c"], ["phase_c_unbounded"]] , original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), BoundToUnbound(name_mapping = [["iota"], ["iota_unbounded"]], original_lower_bound=0., original_upper_bound=jnp.pi), BoundToUnbound(name_mapping = [["psi"], ["psi_unbounded"]], original_lower_bound=0.0, original_upper_bound=jnp.pi), - SkyFrameToDetectorFrameSkyPositionTransform(name_mapping = [["ra", "dec"], ["zenith", "azimuth"]], gps_time=gps, ifos=ifos), + SkyFrameToDetectorFrameSkyPositionTransform(gps_time=gps, ifos=ifos), BoundToUnbound(name_mapping = [["zenith"], ["zenith_unbounded"]], original_lower_bound=0.0, original_upper_bound=jnp.pi), BoundToUnbound(name_mapping = [["azimuth"], ["azimuth_unbounded"]], original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), ] likelihood_transforms = [ - ComponentMassesToChirpMassSymmetricMassRatioTransform(name_mapping=[["m_1", "m_2"], ["M_c", "eta"]]), + ComponentMassesToChirpMassSymmetricMassRatioTransform, ] likelihood = HeterodynedTransientLikelihoodFD( @@ -132,4 +132,4 @@ jim.sample(jax.random.PRNGKey(42)) jim.get_samples() -jim.print_summary() +jim.print_summary() \ No newline at end of file diff --git a/test/integration/test_GW150914_Pv2.py b/test/integration/test_GW150914_Pv2.py index c9d83a5e..9892058d 100644 --- a/test/integration/test_GW150914_Pv2.py +++ b/test/integration/test_GW150914_Pv2.py @@ -89,8 +89,8 @@ ] likelihood_transforms = [ - SpinToCartesianSpinTransform(name_mapping=[["theta_jn", "phi_jl", "theta_1", "theta_2", "phi_12", "a_1", "a_2"], ["iota", "s1_x", "s1_y", "s1_z", "s2_x", "s2_y", "s2_z"]], freq_ref=20.0), - MassRatioToSymmetricMassRatioTransform(name_mapping=[["q"], ["eta"]]), + SpinToCartesianSpinTransform(freq_ref=20.0), + MassRatioToSymmetricMassRatioTransform, ] likelihood = TransientLikelihoodFD( @@ -139,4 +139,4 @@ jim.sample(jax.random.PRNGKey(42)) jim.get_samples() -jim.print_summary() +jim.print_summary() \ No newline at end of file diff --git a/test/integration/test_mass_transforms.py b/test/integration/test_mass_transforms.py new file mode 100644 index 00000000..65b95244 --- /dev/null +++ b/test/integration/test_mass_transforms.py @@ -0,0 +1,106 @@ +import os +os.environ["CUDA_VISIBLE_DEVICES"] = "3" +os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.10" + +import numpy as np +import matplotlib.pyplot as plt +import corner +import jax +import jax.numpy as jnp +from jaxtyping import Float + +from jimgw.prior import UniformPrior, CombinePrior +from jimgw.single_event.transforms import ChirpMassMassRatioToComponentMassesTransform +from jimgw.base import LikelihoodBase +from jimgw.jim import Jim + +params = {"axes.grid": True, + "text.usetex" : True, + "font.family" : "serif", + "ytick.color" : "black", + "xtick.color" : "black", + "axes.labelcolor" : "black", + "axes.edgecolor" : "black", + "font.serif" : ["Computer Modern Serif"], + "xtick.labelsize": 16, + "ytick.labelsize": 16, + "axes.labelsize": 16, + "legend.fontsize": 16, + "legend.title_fontsize": 16, + "figure.titlesize": 16} + +plt.rcParams.update(params) + +# Improved corner kwargs +default_corner_kwargs = dict(bins=40, + smooth=1., + show_titles=False, + label_kwargs=dict(fontsize=16), + title_kwargs=dict(fontsize=16), + color="blue", + # quantiles=[], + # levels=[0.9], + plot_density=True, + plot_datapoints=False, + fill_contours=True, + max_n_ticks=4, + min_n_ticks=3, + truth_color = "red", + save=False) + +# Likelihood for this test: + +class MyLikelihood(LikelihoodBase): + """Simple toy likelihood: Gaussian centered on the true component masses""" + + true_m1: Float + true_m2: Float + + def __init__(self, + true_m1: Float, + true_m2: Float): + + self.true_m1 = true_m1 + self.true_m2 = true_m2 + + def evaluate(self, params: dict[str, Float], data: dict) -> Float: + m1, m2 = params['m_1'], params['m_2'] + m1_std = 0.1 + m2_std = 0.1 + return -0.5 * (((m1 - self.true_m1) / m1_std)**2 + ((m2 - self.true_m2) / m2_std)**2) + +# Setup +true_m1 = 1.6 +true_m2 = 1.4 +true_mc = (true_m1 * true_m2)**(3/5) / (true_m1 + true_m2)**(1/5) +true_q = true_m2 / true_m1 + +# Priors +eps = 0.5 # half of width of the chirp mass prior +mc_prior = UniformPrior(true_mc - eps, true_mc + eps, parameter_names=['M_c']) +q_prior = UniformPrior(0.125, 1.0, parameter_names=['q']) +combine_prior = CombinePrior([mc_prior, q_prior]) + +# Likelihood and transform +likelihood = MyLikelihood(true_m1, true_m2) +mass_transform = ChirpMassMassRatioToComponentMassesTransform + +print("Checking mass_transform repr") +print(repr(mass_transform)) + +# Other stuff we have to give to Jim to make it work +step = 5e-3 +local_sampler_arg = {"step_size": step * jnp.eye(2)} + +# Jim: +jim = Jim(likelihood, + combine_prior, + likelihood_transforms=[mass_transform], + n_chains = 10, + parameter_names=['M_c', 'q'], + n_loop_training=2, + n_loop_production=2, + local_sampler_arg=local_sampler_arg) + +jim.sample(jax.random.PRNGKey(0)) +jim.print_summary() \ No newline at end of file