Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prior uniform in m1-m2 space with a bound in chirp mass and mass ratio #164

Closed
wants to merge 8 commits into from
Closed
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 128 additions & 1 deletion src/jimgw/single_event/transforms.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import jax.numpy as jnp
from jax import lax
from beartype import beartype as typechecker
from jaxtyping import Float, Array, jaxtyped
from jaxtyping import Float, Array, jaxtyped, Bool
from astropy.time import Time

from jimgw.single_event.detector import GroundBased2G
Expand All @@ -11,6 +12,7 @@
reverse_bijective_transform,
)
from jimgw.single_event.utils import (
Mc_m1_to_m2,
m1_m2_to_Mc_q,
Mc_q_to_m1_m2,
m1_m2_to_Mc_eta,
Expand All @@ -24,6 +26,131 @@
)


@jaxtyped(typechecker=typechecker)
class UniformInComponentMassSecondaryMassTransform(ConditionalBijectiveTransform):
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this a transform or a prior? Also, is this used in the testing script?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is a transform

""" """

q_min: Float
q_max: Float
M_c_min: Float
M_c_max: Float
m_1_turning_point_lower: Float
m_1_turning_point_upper: Float
regime_2_mass_ratio: Bool

def __init__(
self,
q_min: Float,
q_max: Float,
M_c_min: Float,
M_c_max: Float,
m_1_min: Float,
m_1_max: Float,
):
name_mapping = (
[
"m_2_quantile",
],
[
"m_2",
],
)
conditional_names = [
"m_1",
]
super().__init__(name_mapping, conditional_names)

self.q_min = q_min
self.q_max = q_max
self.M_c_min = M_c_min
self.M_c_max = M_c_max

assert (M_c_min >= 0.0) & (M_c_max > 0.0), "Chirp mass range has to be positive"
assert (
M_c_max > M_c_min
), "Upper bound on chirp mass has to be higher than the lower bound"
assert (q_min >= 0.0) & (q_max > 0.0), "Mass ratio range has to be positive"
assert (
q_max > q_min
), "Upper bound on mass ratio has to be higher than the lower bound"
assert q_max <= 1.0, "The mass ratio is defined to be less than 1"

m_1_lower_bound = Mc_q_to_m1_m2(self.M_c_min, self.q_max)[0]
m_1_upper_bound = Mc_q_to_m1_m2(self.M_c_max, self.q_min)[0]

assert (
m_1_min >= m_1_lower_bound
), f"Please increase the lower bound on m_1 to {m_1_lower_bound}"
assert (
m_1_max <= m_1_upper_bound
), f"Please decrease the upper bound on m_1 to {m_1_upper_bound}"

m_1_turning_point_1 = Mc_q_to_m1_m2(self.M_c_min, self.q_min)[0]
m_1_turning_point_2 = Mc_q_to_m1_m2(self.M_c_max, self.q_max)[0]

def m2_range_regime_1(m_1: Float):
lower_bound = Mc_m1_to_m2(self.M_c_min, m_1)[0].real
upper_bound = self.q_max * m_1
return [lower_bound, upper_bound]

def m2_range_regime_3(m_1: Float):
lower_bound = self.q_min * m_1
upper_bound = Mc_m1_to_m2(self.M_c_max, m_1)[0].real
return [lower_bound, upper_bound]

if m_1_turning_point_2 >= m_1_turning_point_1:
self.m_1_turning_point_lower = m_1_turning_point_1
self.m_1_turning_point_upper = m_1_turning_point_2
self.regime_2_mass_ratio = True
else:
self.m_1_turning_point_lower = m_1_turning_point_2
self.m_1_turning_point_upper = m_1_turning_point_1
self.regime_2_mass_ratio = False

def m2_range_regime_2(m_1: Float):
return lax.cond(
self.regime_2_mass_ratio,
lambda x: [self.q_min * x, self.q_max * x],
lambda x: [
Mc_m1_to_m2(self.M_c_min, x)[0].real,
Mc_m1_to_m2(self.M_c_max, x)[0].real,
],
m_1,
)

def m1_to_m2_range(m_1: Float):
m2_range = lax.cond(
m_1 < self.m_1_turning_point_lower,
m2_range_regime_1,
lambda x: lax.cond(
x <= self.m_1_turning_point_upper,
m2_range_regime_2,
m2_range_regime_3,
x,
),
m_1,
)
return m2_range

def named_transform(x):
m2_range = m1_to_m2_range(x["m_1"])
m_2 = (m2_range[1] - m2_range[0]) * x["m_2_quantile"] + m2_range[0]
return {
"m_2": m_2,
}

self.transform_func = named_transform

def named_inverse_transform(x):
m2_range = m1_to_m2_range(x["m_1"])
m_2_quantile = (x["m_2"] - m2_range[0]) / (m2_range[1] - m2_range[0])
return {
"m_2_quantile": m_2_quantile,
}

self.inverse_transform_func = named_inverse_transform


@jaxtyped(typechecker=typechecker)
class PrecessingSpinToCartesianSpinTransform(NtoNTransform):
"""
Expand Down
25 changes: 25 additions & 0 deletions src/jimgw/single_event/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,31 @@ def inner_product(
return 4.0 * jnp.real(trapezoid(integrand, dx=df))


def Mc_m1_to_m2(Mc: Float, m1: Float):
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The return typing hint is missing here


a = jnp.power(m1, 3.0)
b = 0.0
c = -jnp.power(Mc, 5.0)
d = -jnp.power(Mc, 5.0) * m1

f = ((3.0 * c / a) - ((b**2) / (a**2))) / 3.0
g = (((2.0 * (b**3)) / (a**3)) - ((9.0 * b * c) / (a**2)) + (27.0 * d / a)) / 27.0
g_squared = g**2
f_cubed = f**3
h = g_squared / 4.0 + f_cubed / 27.0

R = -(g / 2.0) + jnp.sqrt(h)
S = jnp.cbrt(R)
T = -(g / 2.0) - jnp.sqrt(h)
U = jnp.cbrt(T)

x1 = (S + U) - (b / (3.0 * a))
x2 = -(S + U) / 2 - (b / (3.0 * a)) + (S - U) * jnp.sqrt(3.0) * 0.5j
x3 = -(S + U) / 2 - (b / (3.0 * a)) - (S - U) * jnp.sqrt(3.0) * 0.5j

return jnp.array([x1, x2, x3])


def m1_m2_to_M_q(m1: Float, m2: Float):
"""
Transforming the primary mass m1 and secondary mass m2 to the Total mass M
Expand Down
125 changes: 75 additions & 50 deletions test/integration/test_GW150914_Pv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,28 @@
import jax.numpy as jnp

from jimgw.jim import Jim
from jimgw.prior import CombinePrior, UniformPrior, CosinePrior, SinePrior, PowerLawPrior
from jimgw.jim import Jim
from jimgw.prior import (
CombinePrior,
UniformPrior,
CosinePrior,
SinePrior,
PowerLawPrior,
UniformSpherePrior,
)
from jimgw.single_event.detector import H1, L1
from jimgw.single_event.likelihood import TransientLikelihoodFD
from jimgw.single_event.waveform import RippleIMRPhenomD
from jimgw.single_event.waveform import RippleIMRPhenomPv2
from jimgw.transforms import BoundToUnbound
from jimgw.single_event.transforms import MassRatioToSymmetricMassRatioTransform, SpinToCartesianSpinTransform
from jimgw.single_event.transforms import (
SkyFrameToDetectorFrameSkyPositionTransform,
SphereSpinToCartesianSpinTransform,
MassRatioToSymmetricMassRatioTransform,
DistanceToSNRWeightedDistanceTransform,
GeocentricArrivalTimeToDetectorArrivalTimeTransform,
GeocentricArrivalPhaseToDetectorArrivalPhaseTransform,
)
from jimgw.single_event.utils import Mc_q_to_m1_m2
from flowMC.strategy.optimization import optimization_Adam

jax.config.update("jax_enable_x64", True)
Expand All @@ -34,68 +50,77 @@
for ifo in ifos:
ifo.load_data(gps, start_pad, end_pad, fmin, fmax, psd_pad=16, tukey_alpha=0.2)

Mc_prior = UniformPrior(10.0, 80.0, parameter_names=["M_c"])
q_prior = UniformPrior(0.125, 1., parameter_names=["q"])
theta_jn_prior = SinePrior(parameter_names=["theta_jn"])
phi_jl_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["phi_jl"])
theta_1_prior = SinePrior(parameter_names=["theta_1"])
theta_2_prior = SinePrior(parameter_names=["theta_2"])
phi_12_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["phi_12"])
a_1_prior = UniformPrior(0.0, 1.0, parameter_names=["a_1"])
a_2_prior = UniformPrior(0.0, 1.0, parameter_names=["a_2"])
dL_prior = PowerLawPrior(10.0, 2000.0, 2.0, parameter_names=["d_L"])
prior = []

# Mass prior
M_c_min, M_c_max = 10.0, 80.0
q_min, q_max = 0.125, 1.0
Mc_prior = UniformPrior(M_c_min, M_c_max, parameter_names=["M_c"])
q_prior = UniformPrior(q_min, q_max, parameter_names=["q"])

prior = prior + [Mc_prior, q_prior]

# Spin prior
s1_prior = UniformSpherePrior(parameter_names=["s1"])
s2_prior = UniformSpherePrior(parameter_names=["s2"])
iota_prior = SinePrior(parameter_names=["iota"])

prior = prior + [
s1_prior,
s2_prior,
iota_prior,
]

# Extrinsic prior
dL_prior = PowerLawPrior(1.0, 2000.0, 2.0, parameter_names=["d_L"])
t_c_prior = UniformPrior(-0.05, 0.05, parameter_names=["t_c"])
phase_c_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["phase_c"])
psi_prior = UniformPrior(0.0, jnp.pi, parameter_names=["psi"])
ra_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["ra"])
dec_prior = CosinePrior(parameter_names=["dec"])

prior = CombinePrior(
[
Mc_prior,
q_prior,
theta_jn_prior,
phi_jl_prior,
theta_1_prior,
theta_2_prior,
phi_12_prior,
a_1_prior,
a_2_prior,
dL_prior,
t_c_prior,
phase_c_prior,
psi_prior,
ra_prior,
dec_prior,
]
)
prior = prior + [
dL_prior,
t_c_prior,
phase_c_prior,
psi_prior,
ra_prior,
dec_prior,
]

prior = CombinePrior(prior)

# Defining Transforms

sample_transforms = [
BoundToUnbound(name_mapping = [["M_c"], ["M_c_unbounded"]], original_lower_bound=10.0, original_upper_bound=80.0),
BoundToUnbound(name_mapping = [["q"], ["q_unbounded"]], original_lower_bound=0.125, original_upper_bound=1.),
BoundToUnbound(name_mapping = [["theta_jn"], ["theta_jn_unbounded"]] , original_lower_bound=0.0, original_upper_bound=jnp.pi),
BoundToUnbound(name_mapping = [["phi_jl"], ["phi_jl_unbounded"]] , original_lower_bound=0.0, original_upper_bound=2 * jnp.pi),
BoundToUnbound(name_mapping = [["theta_1"], ["theta_1_unbounded"]] , original_lower_bound=0.0, original_upper_bound=jnp.pi),
BoundToUnbound(name_mapping = [["theta_2"], ["theta_2_unbounded"]] , original_lower_bound=0.0, original_upper_bound=jnp.pi),
BoundToUnbound(name_mapping = [["phi_12"], ["phi_12_unbounded"]] , original_lower_bound=0.0, original_upper_bound=2 * jnp.pi),
BoundToUnbound(name_mapping = [["a_1"], ["a_1_unbounded"]] , original_lower_bound=0.0, original_upper_bound=1.0),
BoundToUnbound(name_mapping = [["a_2"], ["a_2_unbounded"]] , original_lower_bound=0.0, original_upper_bound=1.0),
BoundToUnbound(name_mapping = [["d_L"], ["d_L_unbounded"]] , original_lower_bound=0.0, original_upper_bound=2000.0),
BoundToUnbound(name_mapping = [["t_c"], ["t_c_unbounded"]] , original_lower_bound=-0.05, original_upper_bound=0.05),
BoundToUnbound(name_mapping = [["phase_c"], ["phase_c_unbounded"]] , original_lower_bound=0.0, original_upper_bound=2 * jnp.pi),
BoundToUnbound(name_mapping = [["psi"], ["psi_unbounded"]], original_lower_bound=0.0, original_upper_bound=jnp.pi),
BoundToUnbound(name_mapping = [["ra"], ["ra_unbounded"]], original_lower_bound=0.0, original_upper_bound=2 * jnp.pi),
BoundToUnbound(name_mapping = [["dec"], ["dec_unbounded"]],original_lower_bound=-jnp.pi / 2, original_upper_bound=jnp.pi / 2)
DistanceToSNRWeightedDistanceTransform(gps_time=gps, ifos=ifos, dL_min=dL_prior.xmin, dL_max=dL_prior.xmax),
GeocentricArrivalPhaseToDetectorArrivalPhaseTransform(gps_time=gps, ifo=ifos[0]),
GeocentricArrivalTimeToDetectorArrivalTimeTransform(tc_min=t_c_prior.xmin, tc_max=t_c_prior.xmax, gps_time=gps, ifo=ifos[0]),
SkyFrameToDetectorFrameSkyPositionTransform(gps_time=gps, ifos=ifos),
BoundToUnbound(name_mapping = (["M_c"], ["M_c_unbounded"]), original_lower_bound=M_c_min, original_upper_bound=M_c_max),
BoundToUnbound(name_mapping = (["q"], ["q_unbounded"]), original_lower_bound=q_min, original_upper_bound=q_max),
BoundToUnbound(name_mapping = (["s1_phi"], ["s1_phi_unbounded"]) , original_lower_bound=0.0, original_upper_bound=2 * jnp.pi),
BoundToUnbound(name_mapping = (["s2_phi"], ["s2_phi_unbounded"]) , original_lower_bound=0.0, original_upper_bound=2 * jnp.pi),
BoundToUnbound(name_mapping = (["iota"], ["iota_unbounded"]) , original_lower_bound=0.0, original_upper_bound=jnp.pi),
BoundToUnbound(name_mapping = (["s1_theta"], ["s1_theta_unbounded"]) , original_lower_bound=0.0, original_upper_bound=jnp.pi),
BoundToUnbound(name_mapping = (["s2_theta"], ["s2_theta_unbounded"]) , original_lower_bound=0.0, original_upper_bound=jnp.pi),
BoundToUnbound(name_mapping = (["s1_mag"], ["s1_mag_unbounded"]) , original_lower_bound=0.0, original_upper_bound=0.99),
BoundToUnbound(name_mapping = (["s2_mag"], ["s2_mag_unbounded"]) , original_lower_bound=0.0, original_upper_bound=0.99),
BoundToUnbound(name_mapping = (["phase_det"], ["phase_det_unbounded"]), original_lower_bound=0.0, original_upper_bound=2 * jnp.pi),
BoundToUnbound(name_mapping = (["psi"], ["psi_unbounded"]), original_lower_bound=0.0, original_upper_bound=jnp.pi),
BoundToUnbound(name_mapping = (["zenith"], ["zenith_unbounded"]), original_lower_bound=0.0, original_upper_bound=jnp.pi),
BoundToUnbound(name_mapping = (["azimuth"], ["azimuth_unbounded"]), original_lower_bound=0.0, original_upper_bound=2 * jnp.pi),
]

likelihood_transforms = [
SpinToCartesianSpinTransform(freq_ref=20.0),
MassRatioToSymmetricMassRatioTransform,
SphereSpinToCartesianSpinTransform("s1"),
SphereSpinToCartesianSpinTransform("s2"),
]

likelihood = TransientLikelihoodFD(
ifos,
waveform=RippleIMRPhenomD(),
waveform=RippleIMRPhenomPv2(),
trigger_time=gps,
duration=4,
post_trigger_duration=2,
Expand Down Expand Up @@ -139,4 +164,4 @@

jim.sample(jax.random.PRNGKey(42))
jim.get_samples()
jim.print_summary()
jim.print_summary()
Loading