Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added code to reverse transforms for more flexibility #132

202 changes: 67 additions & 135 deletions src/jimgw/single_event/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,12 @@
from astropy.time import Time

from jimgw.single_event.detector import GroundBased2G
from jimgw.transforms import BijectiveTransform, NtoNTransform
from jimgw.transforms import (
BijectiveTransform,
NtoNTransform,
reverse_bijective_transform,
create_bijective_transform,
)
from jimgw.single_event.utils import (
m1_m2_to_Mc_q,
Mc_q_to_m1_m2,
Expand All @@ -20,100 +25,65 @@


@jaxtyped(typechecker=typechecker)
class ComponentMassesToChirpMassMassRatioTransform(BijectiveTransform):
class SpinToCartesianSpinTransform(NtoNTransform):
kazewong marked this conversation as resolved.
Show resolved Hide resolved
"""
Transform chirp mass and mass ratio to component masses

Parameters
----------
name_mapping : tuple[list[str], list[str]]
The name mapping between the input and output dictionary.
Spin to Cartesian spin transformation
"""

freq_ref: Float

def __init__(
self,
name_mapping: tuple[list[str], list[str]],
freq_ref: Float,
):
super().__init__(name_mapping)
assert (
"m_1" in name_mapping[0]
and "m_2" in name_mapping[0]
and "M_c" in name_mapping[1]
and "q" in name_mapping[1]
)

def named_transform(x):
Mc, q = m1_m2_to_Mc_q(x["m_1"], x["m_2"])
return {"M_c": Mc, "q": q}

self.transform_func = named_transform

def named_inverse_transform(x):
m1, m2 = Mc_q_to_m1_m2(x["M_c"], x["q"])
return {"m_1": m1, "m_2": m2}

self.inverse_transform_func = named_inverse_transform


@jaxtyped(typechecker=typechecker)
class ComponentMassesToChirpMassSymmetricMassRatioTransform(BijectiveTransform):
"""
Transform mass ratio to symmetric mass ratio

Parameters
----------
name_mapping : tuple[list[str], list[str]]
The name mapping between the input and output dictionary.

"""
self.freq_ref = freq_ref

def __init__(
self,
name_mapping: tuple[list[str], list[str]],
):
super().__init__(name_mapping)
assert (
"m_1" in name_mapping[0]
and "m_2" in name_mapping[0]
and "M_c" in name_mapping[1]
and "eta" in name_mapping[1]
"theta_jn" in name_mapping[0]
and "phi_jl" in name_mapping[0]
and "theta_1" in name_mapping[0]
and "theta_2" in name_mapping[0]
and "phi_12" in name_mapping[0]
and "a_1" in name_mapping[0]
and "a_2" in name_mapping[0]
and "iota" in name_mapping[1]
and "s1_x" in name_mapping[1]
and "s1_y" in name_mapping[1]
and "s1_z" in name_mapping[1]
and "s2_x" in name_mapping[1]
and "s2_y" in name_mapping[1]
and "s2_z" in name_mapping[1]
)

def named_transform(x):
Mc, eta = m1_m2_to_Mc_eta(x["m_1"], x["m_2"])
return {"M_c": Mc, "eta": eta}
iota, s1x, s1y, s1z, s2x, s2y, s2z = spin_to_cartesian_spin(
x["theta_jn"],
x["phi_jl"],
x["theta_1"],
x["theta_2"],
x["phi_12"],
x["a_1"],
x["a_2"],
x["M_c"],
x["q"],
self.freq_ref,
x["phase_c"],
)
return {
"iota": iota,
"s1_x": s1x,
"s1_y": s1y,
"s1_z": s1z,
"s2_x": s2x,
"s2_y": s2y,
"s2_z": s2z,
}

self.transform_func = named_transform

def named_inverse_transform(x):
m1, m2 = Mc_eta_to_m1_m2(x["M_c"], x["q"])
return {"m_1": m1, "m_2": m2}

self.inverse_transform_func = named_inverse_transform


@jaxtyped(typechecker=typechecker)
class MassRatioToSymmetricMassRatioTransform(BijectiveTransform):
"""
Transform mass ratio to symmetric mass ratio

Parameters
----------
name_mapping : tuple[list[str], list[str]]
The name mapping between the input and output dictionary.

"""

def __init__(
self,
name_mapping: tuple[list[str], list[str]],
):
super().__init__(name_mapping)
assert "q" == name_mapping[0][0] and "eta" == name_mapping[1][0]

self.transform_func = lambda x: {"eta": q_to_eta(x["q"])}
self.inverse_transform_func = lambda x: {"q": eta_to_q(x["eta"])}


@jaxtyped(typechecker=typechecker)
class SkyFrameToDetectorFrameSkyPositionTransform(BijectiveTransform):
Expand Down Expand Up @@ -170,62 +140,24 @@ def named_inverse_transform(x):
self.inverse_transform_func = named_inverse_transform


@jaxtyped(typechecker=typechecker)
class SpinToCartesianSpinTransform(NtoNTransform):
"""
Spin to Cartesian spin transformation
"""

freq_ref: Float

def __init__(
self,
name_mapping: tuple[list[str], list[str]],
freq_ref: Float,
):
super().__init__(name_mapping)

self.freq_ref = freq_ref

assert (
"theta_jn" in name_mapping[0]
and "phi_jl" in name_mapping[0]
and "theta_1" in name_mapping[0]
and "theta_2" in name_mapping[0]
and "phi_12" in name_mapping[0]
and "a_1" in name_mapping[0]
and "a_2" in name_mapping[0]
and "iota" in name_mapping[1]
and "s1_x" in name_mapping[1]
and "s1_y" in name_mapping[1]
and "s1_z" in name_mapping[1]
and "s2_x" in name_mapping[1]
and "s2_y" in name_mapping[1]
and "s2_z" in name_mapping[1]
)
# Pre-made bijective transforms:
ComponentMassesToChirpMassMassRatioTransform = create_bijective_transform(
(["m_1", "m_2"], ["M_c", "q"]), m1_m2_to_Mc_q, Mc_q_to_m1_m2
)
ChirpMassMassRatioToComponentMassesTransform = reverse_bijective_transform(
ComponentMassesToChirpMassMassRatioTransform
)

def named_transform(x):
iota, s1x, s1y, s1z, s2x, s2y, s2z = spin_to_cartesian_spin(
x["theta_jn"],
x["phi_jl"],
x["theta_1"],
x["theta_2"],
x["phi_12"],
x["a_1"],
x["a_2"],
x["M_c"],
x["q"],
self.freq_ref,
x["phase_c"],
)
return {
"iota": iota,
"s1_x": s1x,
"s1_y": s1y,
"s1_z": s1z,
"s2_x": s2x,
"s2_y": s2y,
"s2_z": s2z,
}
ComponentMassesToChirpMassSymmetricMassRatioTransform = create_bijective_transform(
(["m_1", "m_2"], ["M_c", "eta"]), m1_m2_to_Mc_eta, Mc_eta_to_m1_m2
)
ChirpMassSymmetricMassRatioToComponentMassesTransform = reverse_bijective_transform(
ComponentMassesToChirpMassSymmetricMassRatioTransform
)

self.transform_func = named_transform
MassRatioToSymmetricMassRatioTransform = create_bijective_transform(
(["q"], ["eta"]), q_to_eta, eta_to_q
)
SymmetricMassRatioToMassRatioTransform = reverse_bijective_transform(
MassRatioToSymmetricMassRatioTransform
)
51 changes: 51 additions & 0 deletions src/jimgw/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,3 +445,54 @@ def __init__(
)
for i in range(len(name_mapping[1]))
}


def create_bijective_transform(
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is not a good practice in general. Here transform_func_array and inverse_transform_func_array can be potentially harmful functions, and in the long run, this flexibility will create trouble in reproducibility. Any bijective transform should be explicitly declared with the class interface.

name_mapping: tuple[list[str], list[str]],
transform_func_array: Callable[[Float], Float],
inverse_transform_func_array: Callable[[Float], Float],
) -> BijectiveTransform:
"""
Utility function to create a BijectiveTransform object given a name_mapping and the forward and backward transform functions which take arrays as input, e.g. coming from the utils module.

Args:
name_mapping (tuple[list[str], list[str]]): The name_mapping to be used in the named transforms.
transform_func_array (Callable[[Float], Float]): The forward function method taking an array as input.
inverse_transform_func_array (Callable[[Float], Float]): The inverse function method taking an array as input.

Returns:
BijectiveTransform: The BijectiveTransform object.
"""

def named_transform_func(x_named: dict[str, Float]) -> dict[str, Float]:
x_array = jnp.array([x_named[key] for key in name_mapping[0]])
y_array = transform_func_array(*x_array)
y_named = dict(zip(name_mapping[1], jnp.atleast_1d(y_array)))
return y_named

def named_inverse_transform_func(y_named: dict[str, Float]) -> dict[str, Float]:
y_array = jnp.array([y_named[key] for key in name_mapping[1]])
x_array = inverse_transform_func_array(*y_array)
x_named = dict(zip(name_mapping[0], jnp.atleast_1d(x_array)))
return x_named

new_transform = BijectiveTransform(name_mapping)
new_transform.transform_func = named_transform_func
new_transform.inverse_transform_func = named_inverse_transform_func

return new_transform


def reverse_bijective_transform(
original_transform: BijectiveTransform,
) -> BijectiveTransform:

thomasckng marked this conversation as resolved.
Show resolved Hide resolved
reversed_name_mapping = (
original_transform.name_mapping[1],
original_transform.name_mapping[0],
)
reversed_transform = BijectiveTransform(name_mapping=reversed_name_mapping)
reversed_transform.transform_func = original_transform.inverse_transform_func
reversed_transform.inverse_transform_func = original_transform.transform_func

return reversed_transform
2 changes: 2 additions & 0 deletions test/integration/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
outdir/
figures/
thomasckng marked this conversation as resolved.
Show resolved Hide resolved
37 changes: 24 additions & 13 deletions test/integration/test_GW150914_D.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.10"

import jax
import jax.numpy as jnp

Expand All @@ -10,6 +14,8 @@
from jimgw.single_event.transforms import ComponentMassesToChirpMassSymmetricMassRatioTransform, SkyFrameToDetectorFrameSkyPositionTransform, ComponentMassesToChirpMassMassRatioTransform
from jimgw.single_event.utils import Mc_q_to_m1_m2
from flowMC.strategy.optimization import optimization_Adam
from flowMC.utils.postprocessing import plot_summary
import optax

jax.config.update("jax_enable_x64", True)

Expand Down Expand Up @@ -62,7 +68,7 @@
)

sample_transforms = [
ComponentMassesToChirpMassMassRatioTransform(name_mapping=[["m_1", "m_2"], ["M_c", "q"]]),
ComponentMassesToChirpMassMassRatioTransform,
BoundToUnbound(name_mapping = [["M_c"], ["M_c_unbounded"]], original_lower_bound=M_c_min, original_upper_bound=M_c_max),
BoundToUnbound(name_mapping = [["q"], ["q_unbounded"]], original_lower_bound=q_min, original_upper_bound=q_max),
BoundToUnbound(name_mapping = [["s1_z"], ["s1_z_unbounded"]] , original_lower_bound=-1.0, original_upper_bound=1.0),
Expand All @@ -78,7 +84,7 @@
]

likelihood_transforms = [
ComponentMassesToChirpMassSymmetricMassRatioTransform(name_mapping=[["m_1", "m_2"], ["M_c", "eta"]]),
ComponentMassesToChirpMassSymmetricMassRatioTransform,
]

likelihood = TransientLikelihoodFD(
Expand All @@ -97,34 +103,39 @@

Adam_optimizer = optimization_Adam(n_steps=5, learning_rate=0.01, noise_level=1)

n_epochs = 2
n_loop_training = 1
learning_rate = 1e-4

n_epochs = 20
n_loop_training = 10
total_epochs = n_epochs * n_loop_training
start = total_epochs//10
learning_rate = optax.polynomial_schedule(
1e-3, 5e-4, 4.0, total_epochs - start, transition_begin=start
)

jim = Jim(
likelihood,
prior,
sample_transforms=sample_transforms,
likelihood_transforms=likelihood_transforms,
n_loop_training=n_loop_training,
n_loop_production=1,
n_local_steps=5,
n_global_steps=5,
n_chains=4,
n_loop_production=4,
n_local_steps=10,
n_global_steps=1000,
n_chains=500,
n_epochs=n_epochs,
learning_rate=learning_rate,
n_max_examples=30,
n_flow_samples=100,
n_max_examples=30000,
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Test are supposed to be light weight. Reminder to revert these numbers

n_flow_samples=100000,
momentum=0.9,
batch_size=100,
batch_size=30000,
use_global=True,
train_thinning=1,
output_thinning=1,
output_thinning=10,
local_sampler_arg=local_sampler_arg,
strategies=[Adam_optimizer, "default"],
)

jim.sample(jax.random.PRNGKey(42))
jim.get_samples()
jim.print_summary()
plot_summary(jim.sampler)
Loading
Loading