Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prior uniform in m1-m2 space with a bound in chirp mass and mass ratio #164

Closed
wants to merge 8 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 81 additions & 15 deletions src/jimgw/single_event/prior.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,91 @@
from dataclasses import field

import jax

from beartype import beartype as typechecker
from jaxtyping import jaxtyped
from jaxtyping import Array, Float, PRNGKeyArray, jaxtyped

from jimgw.transforms import BijectiveTransform
from jimgw.prior import (
PowerLawPrior,
Prior,
UniformPrior,
CombinePrior,
)
from jimgw.single_event.transforms import (
UniformComponentMassSecondaryMassQuantileToSecondaryMassTransform,
)
from jimgw.single_event.utils import (
Mc_q_to_m1_m2,
)


@jaxtyped(typechecker=typechecker)
class UniformComponentChirpMassPrior(PowerLawPrior):
"""
A prior in the range [xmin, xmax) for chirp mass which assumes the
component masses to be uniformly distributed.

p(M_c) ~ M_c
"""

def __repr__(self):
return f"UniformInComponentsChirpMassPrior(xmin={self.xmin}, xmax={self.xmax}, naming={self.parameter_names})"

def __init__(self, xmin: float, xmax: float):
super().__init__(xmin, xmax, 1.0, ["M_c"])
class ChirpMassMassRatioBoundedUniformComponentPrior(CombinePrior):

M_c_min: float = 5.0
M_c_max: float = 15.0
q_min: float = 0.125
q_max: float = 1.0

m_1_min: float = 6.0
m_1_max: float = 53.0
m_2_min: float = 3.0
m_2_max: float = 17.0

base_prior: list[Prior] = field(default_factory=list)
transform: BijectiveTransform = (
UniformComponentMassSecondaryMassQuantileToSecondaryMassTransform(
q_min=q_min,
q_max=q_max,
M_c_min=M_c_min,
M_c_max=M_c_max,
m_1_min=m_1_min,
m_1_max=m_1_max,
)
)

def __init__(self, q_min: Float, q_max: Float, M_c_min: Float, M_c_max: Float):
self.parameter_names = ["m_1", "m_2"]
# calculate the respective range of m1 and m2 given the Mc-q range
self.M_c_min = M_c_min
self.M_c_max = M_c_max
self.q_min = q_min
self.q_max = q_max
self.m_1_min = Mc_q_to_m1_m2(M_c_min, q_max)[0]
self.m_1_max = Mc_q_to_m1_m2(M_c_max, q_min)[0]
self.m_2_min = Mc_q_to_m1_m2(M_c_min, q_min)[1]
self.m_2_max = Mc_q_to_m1_m2(M_c_max, q_max)[1]
# define the prior on m1 and m2_quantile
m1_prior = UniformPrior(self.m_1_min, self.m_1_max, parameter_names=["m_1"])
m2q_prior = UniformPrior(0.0, 1.0, parameter_names=["m_2_quantile"])
self.base_prior = [m1_prior, m2q_prior]
self.transform = (
UniformComponentMassSecondaryMassQuantileToSecondaryMassTransform(
q_min=self.q_min,
q_max=self.q_max,
M_c_min=self.M_c_min,
M_c_max=self.M_c_max,
m_1_min=self.m_1_min,
m_1_max=self.m_1_max,
)
)

def sample(
self, rng_key: PRNGKeyArray, n_samples: int
) -> dict[str, Float[Array, " n_samples"]]:
output = {}
for prior in self.base_prior:
rng_key, subkey = jax.random.split(rng_key)
output.update(prior.sample(subkey, n_samples))
output = jax.vmap(self.transform.forward)(output)
return output

def log_prob(self, z: dict[str, Float]) -> Float:
z, jacobian = self.transform.inverse(z)
output = jacobian
for prior in self.base_prior:
output += prior.log_prob(z)
return output


# ====================== Things below may need rework ======================
Expand Down
131 changes: 130 additions & 1 deletion src/jimgw/single_event/transforms.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import jax.numpy as jnp
from jax import lax
from beartype import beartype as typechecker
from jaxtyping import Float, Array, jaxtyped
from jaxtyping import Float, Array, jaxtyped, Bool
from astropy.time import Time

from jimgw.single_event.detector import GroundBased2G
Expand All @@ -11,6 +12,7 @@
reverse_bijective_transform,
)
from jimgw.single_event.utils import (
Mc_m1_to_m2,
m1_m2_to_Mc_q,
Mc_q_to_m1_m2,
m1_m2_to_Mc_eta,
Expand All @@ -24,6 +26,133 @@
)


@jaxtyped(typechecker=typechecker)
class UniformComponentMassSecondaryMassQuantileToSecondaryMassTransform(
ConditionalBijectiveTransform
):
""" """

q_min: Float
q_max: Float
M_c_min: Float
M_c_max: Float
m_1_turning_point_lower: Float
m_1_turning_point_upper: Float
regime_2_mass_ratio: Bool

def __init__(
self,
q_min: Float,
q_max: Float,
M_c_min: Float,
M_c_max: Float,
m_1_min: Float,
m_1_max: Float,
):
name_mapping = (
[
"m_2_quantile",
],
[
"m_2",
],
)
conditional_names = [
"m_1",
]
super().__init__(name_mapping, conditional_names)

self.q_min = q_min
self.q_max = q_max
self.M_c_min = M_c_min
self.M_c_max = M_c_max

assert (M_c_min >= 0.0) & (M_c_max > 0.0), "Chirp mass range has to be positive"
assert (
M_c_max > M_c_min
), "Upper bound on chirp mass has to be higher than the lower bound"
assert (q_min >= 0.0) & (q_max > 0.0), "Mass ratio range has to be positive"
assert (
q_max > q_min
), "Upper bound on mass ratio has to be higher than the lower bound"
assert q_max <= 1.0, "The mass ratio is defined to be less than 1"

m_1_lower_bound = Mc_q_to_m1_m2(self.M_c_min, self.q_max)[0]
m_1_upper_bound = Mc_q_to_m1_m2(self.M_c_max, self.q_min)[0]

assert (
m_1_min >= m_1_lower_bound
), f"Please increase the lower bound on m_1 to {m_1_lower_bound}"
assert (
m_1_max <= m_1_upper_bound
), f"Please decrease the upper bound on m_1 to {m_1_upper_bound}"

m_1_turning_point_1 = Mc_q_to_m1_m2(self.M_c_min, self.q_min)[0]
m_1_turning_point_2 = Mc_q_to_m1_m2(self.M_c_max, self.q_max)[0]

def m2_range_regime_1(m_1: Float):
lower_bound = Mc_m1_to_m2(self.M_c_min, m_1)
upper_bound = self.q_max * m_1
return [lower_bound, upper_bound]

def m2_range_regime_3(m_1: Float):
lower_bound = self.q_min * m_1
upper_bound = Mc_m1_to_m2(self.M_c_max, m_1)
return [lower_bound, upper_bound]

if m_1_turning_point_2 >= m_1_turning_point_1:
self.m_1_turning_point_lower = m_1_turning_point_1
self.m_1_turning_point_upper = m_1_turning_point_2
self.regime_2_mass_ratio = True
else:
self.m_1_turning_point_lower = m_1_turning_point_2
self.m_1_turning_point_upper = m_1_turning_point_1
self.regime_2_mass_ratio = False

def m2_range_regime_2(m_1: Float):
return lax.cond(
self.regime_2_mass_ratio,
lambda x: [self.q_min * x, self.q_max * x],
lambda x: [
Mc_m1_to_m2(self.M_c_min, x),
Mc_m1_to_m2(self.M_c_max, x),
],
m_1,
)

def m1_to_m2_range(m_1: Float):
m2_range = lax.cond(
m_1 < self.m_1_turning_point_lower,
m2_range_regime_1,
lambda x: lax.cond(
x <= self.m_1_turning_point_upper,
m2_range_regime_2,
m2_range_regime_3,
x,
),
m_1,
)
return m2_range

def named_transform(x):
m2_range = m1_to_m2_range(x["m_1"])
m_2 = (m2_range[1] - m2_range[0]) * x["m_2_quantile"] + m2_range[0]
return {
"m_2": m_2,
}

self.transform_func = named_transform

def named_inverse_transform(x):
m2_range = m1_to_m2_range(x["m_1"])
m_2_quantile = (x["m_2"] - m2_range[0]) / (m2_range[1] - m2_range[0])
return {
"m_2_quantile": m_2_quantile,
}

self.inverse_transform_func = named_inverse_transform


@jaxtyped(typechecker=typechecker)
class PrecessingSpinToCartesianSpinTransform(NtoNTransform):
"""
Expand Down
23 changes: 23 additions & 0 deletions src/jimgw/single_event/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,29 @@ def inner_product(
return 4.0 * jnp.real(trapezoid(integrand, dx=df))


def Mc_m1_to_m2(Mc: Float, m1: Float) -> Float:

a = jnp.power(m1, 3.0)
b = 0.0
c = -jnp.power(Mc, 5.0)
d = -jnp.power(Mc, 5.0) * m1

f = ((3.0 * c / a) - ((b**2) / (a**2))) / 3.0
g = (((2.0 * (b**3)) / (a**3)) - ((9.0 * b * c) / (a**2)) + (27.0 * d / a)) / 27.0
g_squared = g**2
f_cubed = f**3
h = g_squared / 4.0 + f_cubed / 27.0

R = -(g / 2.0) + jnp.sqrt(h)
S = jnp.cbrt(R)
T = -(g / 2.0) - jnp.sqrt(h)
U = jnp.cbrt(T)

x1 = (S + U) - (b / (3.0 * a))

return x1.real


def m1_m2_to_M_q(m1: Float, m2: Float):
"""
Transforming the primary mass m1 and secondary mass m2 to the Total mass M
Expand Down
Loading
Loading