diff --git a/src/jimgw/single_event/prior.py b/src/jimgw/single_event/prior.py index 194262f0..e9767dc2 100644 --- a/src/jimgw/single_event/prior.py +++ b/src/jimgw/single_event/prior.py @@ -1,25 +1,91 @@ +from dataclasses import field + +import jax + from beartype import beartype as typechecker -from jaxtyping import jaxtyped +from jaxtyping import Array, Float, PRNGKeyArray, jaxtyped +from jimgw.transforms import BijectiveTransform from jimgw.prior import ( - PowerLawPrior, + Prior, + UniformPrior, + CombinePrior, +) +from jimgw.single_event.transforms import ( + UniformComponentMassSecondaryMassQuantileToSecondaryMassTransform, +) +from jimgw.single_event.utils import ( + Mc_q_to_m1_m2, ) @jaxtyped(typechecker=typechecker) -class UniformComponentChirpMassPrior(PowerLawPrior): - """ - A prior in the range [xmin, xmax) for chirp mass which assumes the - component masses to be uniformly distributed. - - p(M_c) ~ M_c - """ - - def __repr__(self): - return f"UniformInComponentsChirpMassPrior(xmin={self.xmin}, xmax={self.xmax}, naming={self.parameter_names})" - - def __init__(self, xmin: float, xmax: float): - super().__init__(xmin, xmax, 1.0, ["M_c"]) +class ChirpMassMassRatioBoundedUniformComponentPrior(CombinePrior): + + M_c_min: float = 5.0 + M_c_max: float = 15.0 + q_min: float = 0.125 + q_max: float = 1.0 + + m_1_min: float = 6.0 + m_1_max: float = 53.0 + m_2_min: float = 3.0 + m_2_max: float = 17.0 + + base_prior: list[Prior] = field(default_factory=list) + transform: BijectiveTransform = ( + UniformComponentMassSecondaryMassQuantileToSecondaryMassTransform( + q_min=q_min, + q_max=q_max, + M_c_min=M_c_min, + M_c_max=M_c_max, + m_1_min=m_1_min, + m_1_max=m_1_max, + ) + ) + + def __init__(self, q_min: Float, q_max: Float, M_c_min: Float, M_c_max: Float): + self.parameter_names = ["m_1", "m_2"] + # calculate the respective range of m1 and m2 given the Mc-q range + self.M_c_min = M_c_min + self.M_c_max = M_c_max + self.q_min = q_min + self.q_max = q_max + self.m_1_min = Mc_q_to_m1_m2(M_c_min, q_max)[0] + self.m_1_max = Mc_q_to_m1_m2(M_c_max, q_min)[0] + self.m_2_min = Mc_q_to_m1_m2(M_c_min, q_min)[1] + self.m_2_max = Mc_q_to_m1_m2(M_c_max, q_max)[1] + # define the prior on m1 and m2_quantile + m1_prior = UniformPrior(self.m_1_min, self.m_1_max, parameter_names=["m_1"]) + m2q_prior = UniformPrior(0.0, 1.0, parameter_names=["m_2_quantile"]) + self.base_prior = [m1_prior, m2q_prior] + self.transform = ( + UniformComponentMassSecondaryMassQuantileToSecondaryMassTransform( + q_min=self.q_min, + q_max=self.q_max, + M_c_min=self.M_c_min, + M_c_max=self.M_c_max, + m_1_min=self.m_1_min, + m_1_max=self.m_1_max, + ) + ) + + def sample( + self, rng_key: PRNGKeyArray, n_samples: int + ) -> dict[str, Float[Array, " n_samples"]]: + output = {} + for prior in self.base_prior: + rng_key, subkey = jax.random.split(rng_key) + output.update(prior.sample(subkey, n_samples)) + output = jax.vmap(self.transform.forward)(output) + return output + + def log_prob(self, z: dict[str, Float]) -> Float: + z, jacobian = self.transform.inverse(z) + output = jacobian + for prior in self.base_prior: + output += prior.log_prob(z) + return output # ====================== Things below may need rework ====================== diff --git a/src/jimgw/single_event/transforms.py b/src/jimgw/single_event/transforms.py index e9983f06..c7a2dab2 100644 --- a/src/jimgw/single_event/transforms.py +++ b/src/jimgw/single_event/transforms.py @@ -1,6 +1,7 @@ import jax.numpy as jnp +from jax import lax from beartype import beartype as typechecker -from jaxtyping import Float, Array, jaxtyped +from jaxtyping import Float, Array, jaxtyped, Bool from astropy.time import Time from jimgw.single_event.detector import GroundBased2G @@ -11,6 +12,7 @@ reverse_bijective_transform, ) from jimgw.single_event.utils import ( + Mc_m1_to_m2, m1_m2_to_Mc_q, Mc_q_to_m1_m2, m1_m2_to_Mc_eta, @@ -24,6 +26,133 @@ ) +@jaxtyped(typechecker=typechecker) +class UniformComponentMassSecondaryMassQuantileToSecondaryMassTransform( + ConditionalBijectiveTransform +): + """ """ + + q_min: Float + q_max: Float + M_c_min: Float + M_c_max: Float + m_1_turning_point_lower: Float + m_1_turning_point_upper: Float + regime_2_mass_ratio: Bool + + def __init__( + self, + q_min: Float, + q_max: Float, + M_c_min: Float, + M_c_max: Float, + m_1_min: Float, + m_1_max: Float, + ): + name_mapping = ( + [ + "m_2_quantile", + ], + [ + "m_2", + ], + ) + conditional_names = [ + "m_1", + ] + super().__init__(name_mapping, conditional_names) + + self.q_min = q_min + self.q_max = q_max + self.M_c_min = M_c_min + self.M_c_max = M_c_max + + assert (M_c_min >= 0.0) & (M_c_max > 0.0), "Chirp mass range has to be positive" + assert ( + M_c_max > M_c_min + ), "Upper bound on chirp mass has to be higher than the lower bound" + assert (q_min >= 0.0) & (q_max > 0.0), "Mass ratio range has to be positive" + assert ( + q_max > q_min + ), "Upper bound on mass ratio has to be higher than the lower bound" + assert q_max <= 1.0, "The mass ratio is defined to be less than 1" + + m_1_lower_bound = Mc_q_to_m1_m2(self.M_c_min, self.q_max)[0] + m_1_upper_bound = Mc_q_to_m1_m2(self.M_c_max, self.q_min)[0] + + assert ( + m_1_min >= m_1_lower_bound + ), f"Please increase the lower bound on m_1 to {m_1_lower_bound}" + assert ( + m_1_max <= m_1_upper_bound + ), f"Please decrease the upper bound on m_1 to {m_1_upper_bound}" + + m_1_turning_point_1 = Mc_q_to_m1_m2(self.M_c_min, self.q_min)[0] + m_1_turning_point_2 = Mc_q_to_m1_m2(self.M_c_max, self.q_max)[0] + + def m2_range_regime_1(m_1: Float): + lower_bound = Mc_m1_to_m2(self.M_c_min, m_1) + upper_bound = self.q_max * m_1 + return [lower_bound, upper_bound] + + def m2_range_regime_3(m_1: Float): + lower_bound = self.q_min * m_1 + upper_bound = Mc_m1_to_m2(self.M_c_max, m_1) + return [lower_bound, upper_bound] + + if m_1_turning_point_2 >= m_1_turning_point_1: + self.m_1_turning_point_lower = m_1_turning_point_1 + self.m_1_turning_point_upper = m_1_turning_point_2 + self.regime_2_mass_ratio = True + else: + self.m_1_turning_point_lower = m_1_turning_point_2 + self.m_1_turning_point_upper = m_1_turning_point_1 + self.regime_2_mass_ratio = False + + def m2_range_regime_2(m_1: Float): + return lax.cond( + self.regime_2_mass_ratio, + lambda x: [self.q_min * x, self.q_max * x], + lambda x: [ + Mc_m1_to_m2(self.M_c_min, x), + Mc_m1_to_m2(self.M_c_max, x), + ], + m_1, + ) + + def m1_to_m2_range(m_1: Float): + m2_range = lax.cond( + m_1 < self.m_1_turning_point_lower, + m2_range_regime_1, + lambda x: lax.cond( + x <= self.m_1_turning_point_upper, + m2_range_regime_2, + m2_range_regime_3, + x, + ), + m_1, + ) + return m2_range + + def named_transform(x): + m2_range = m1_to_m2_range(x["m_1"]) + m_2 = (m2_range[1] - m2_range[0]) * x["m_2_quantile"] + m2_range[0] + return { + "m_2": m_2, + } + + self.transform_func = named_transform + + def named_inverse_transform(x): + m2_range = m1_to_m2_range(x["m_1"]) + m_2_quantile = (x["m_2"] - m2_range[0]) / (m2_range[1] - m2_range[0]) + return { + "m_2_quantile": m_2_quantile, + } + + self.inverse_transform_func = named_inverse_transform + + @jaxtyped(typechecker=typechecker) class PrecessingSpinToCartesianSpinTransform(NtoNTransform): """ diff --git a/src/jimgw/single_event/utils.py b/src/jimgw/single_event/utils.py index 517b2844..8adada46 100644 --- a/src/jimgw/single_event/utils.py +++ b/src/jimgw/single_event/utils.py @@ -38,6 +38,29 @@ def inner_product( return 4.0 * jnp.real(trapezoid(integrand, dx=df)) +def Mc_m1_to_m2(Mc: Float, m1: Float) -> Float: + + a = jnp.power(m1, 3.0) + b = 0.0 + c = -jnp.power(Mc, 5.0) + d = -jnp.power(Mc, 5.0) * m1 + + f = ((3.0 * c / a) - ((b**2) / (a**2))) / 3.0 + g = (((2.0 * (b**3)) / (a**3)) - ((9.0 * b * c) / (a**2)) + (27.0 * d / a)) / 27.0 + g_squared = g**2 + f_cubed = f**3 + h = g_squared / 4.0 + f_cubed / 27.0 + + R = -(g / 2.0) + jnp.sqrt(h) + S = jnp.cbrt(R) + T = -(g / 2.0) - jnp.sqrt(h) + U = jnp.cbrt(T) + + x1 = (S + U) - (b / (3.0 * a)) + + return x1.real + + def m1_m2_to_M_q(m1: Float, m2: Float): """ Transforming the primary mass m1 and secondary mass m2 to the Total mass M diff --git a/test/integration/test_GW150914_Pv2.py b/test/integration/test_GW150914_Pv2.py index 9892058d..2b54ed3c 100644 --- a/test/integration/test_GW150914_Pv2.py +++ b/test/integration/test_GW150914_Pv2.py @@ -4,12 +4,31 @@ import jax.numpy as jnp from jimgw.jim import Jim -from jimgw.prior import CombinePrior, UniformPrior, CosinePrior, SinePrior, PowerLawPrior +from jimgw.jim import Jim +from jimgw.prior import ( + CombinePrior, + UniformPrior, + CosinePrior, + SinePrior, + PowerLawPrior, + UniformSpherePrior, +) from jimgw.single_event.detector import H1, L1 from jimgw.single_event.likelihood import TransientLikelihoodFD -from jimgw.single_event.waveform import RippleIMRPhenomD +from jimgw.single_event.waveform import RippleIMRPhenomPv2 from jimgw.transforms import BoundToUnbound -from jimgw.single_event.transforms import MassRatioToSymmetricMassRatioTransform, SpinToCartesianSpinTransform +from jimgw.single_event.transforms import ( + SkyFrameToDetectorFrameSkyPositionTransform, + SphereSpinToCartesianSpinTransform, + MassRatioToSymmetricMassRatioTransform, + DistanceToSNRWeightedDistanceTransform, + GeocentricArrivalTimeToDetectorArrivalTimeTransform, + GeocentricArrivalPhaseToDetectorArrivalPhaseTransform, + ComponentMassesToChirpMassMassRatioTransform, +) +from jimgw.single_event.prior import ( + ChirpMassMassRatioBoundedUniformComponentPrior +) from flowMC.strategy.optimization import optimization_Adam jax.config.update("jax_enable_x64", True) @@ -34,75 +53,91 @@ for ifo in ifos: ifo.load_data(gps, start_pad, end_pad, fmin, fmax, psd_pad=16, tukey_alpha=0.2) -Mc_prior = UniformPrior(10.0, 80.0, parameter_names=["M_c"]) -q_prior = UniformPrior(0.125, 1., parameter_names=["q"]) -theta_jn_prior = SinePrior(parameter_names=["theta_jn"]) -phi_jl_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["phi_jl"]) -theta_1_prior = SinePrior(parameter_names=["theta_1"]) -theta_2_prior = SinePrior(parameter_names=["theta_2"]) -phi_12_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["phi_12"]) -a_1_prior = UniformPrior(0.0, 1.0, parameter_names=["a_1"]) -a_2_prior = UniformPrior(0.0, 1.0, parameter_names=["a_2"]) -dL_prior = PowerLawPrior(10.0, 2000.0, 2.0, parameter_names=["d_L"]) +prior = [] + +# Mass prior +M_c_min, M_c_max = 10.0, 80.0 +q_min, q_max = 0.125, 1.0 +m1m2_prior = ChirpMassMassRatioBoundedUniformComponentPrior( + q_min=q_min, + q_max=q_max, + M_c_min=M_c_min, + M_c_max=M_c_max, +) + +prior = prior + [m1m2_prior] + +# Spin prior +s1_prior = UniformSpherePrior(parameter_names=["s1"]) +s2_prior = UniformSpherePrior(parameter_names=["s2"]) +iota_prior = SinePrior(parameter_names=["iota"]) + +prior = prior + [ + s1_prior, + s2_prior, + iota_prior, +] + +# Extrinsic prior +dL_prior = PowerLawPrior(1.0, 2000.0, 2.0, parameter_names=["d_L"]) t_c_prior = UniformPrior(-0.05, 0.05, parameter_names=["t_c"]) phase_c_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["phase_c"]) psi_prior = UniformPrior(0.0, jnp.pi, parameter_names=["psi"]) ra_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["ra"]) dec_prior = CosinePrior(parameter_names=["dec"]) -prior = CombinePrior( - [ - Mc_prior, - q_prior, - theta_jn_prior, - phi_jl_prior, - theta_1_prior, - theta_2_prior, - phi_12_prior, - a_1_prior, - a_2_prior, - dL_prior, - t_c_prior, - phase_c_prior, - psi_prior, - ra_prior, - dec_prior, - ] -) +prior = prior + [ + dL_prior, + t_c_prior, + phase_c_prior, + psi_prior, + ra_prior, + dec_prior, +] + +prior = CombinePrior(prior) + +# Defining Transforms sample_transforms = [ - BoundToUnbound(name_mapping = [["M_c"], ["M_c_unbounded"]], original_lower_bound=10.0, original_upper_bound=80.0), - BoundToUnbound(name_mapping = [["q"], ["q_unbounded"]], original_lower_bound=0.125, original_upper_bound=1.), - BoundToUnbound(name_mapping = [["theta_jn"], ["theta_jn_unbounded"]] , original_lower_bound=0.0, original_upper_bound=jnp.pi), - BoundToUnbound(name_mapping = [["phi_jl"], ["phi_jl_unbounded"]] , original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), - BoundToUnbound(name_mapping = [["theta_1"], ["theta_1_unbounded"]] , original_lower_bound=0.0, original_upper_bound=jnp.pi), - BoundToUnbound(name_mapping = [["theta_2"], ["theta_2_unbounded"]] , original_lower_bound=0.0, original_upper_bound=jnp.pi), - BoundToUnbound(name_mapping = [["phi_12"], ["phi_12_unbounded"]] , original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), - BoundToUnbound(name_mapping = [["a_1"], ["a_1_unbounded"]] , original_lower_bound=0.0, original_upper_bound=1.0), - BoundToUnbound(name_mapping = [["a_2"], ["a_2_unbounded"]] , original_lower_bound=0.0, original_upper_bound=1.0), - BoundToUnbound(name_mapping = [["d_L"], ["d_L_unbounded"]] , original_lower_bound=0.0, original_upper_bound=2000.0), - BoundToUnbound(name_mapping = [["t_c"], ["t_c_unbounded"]] , original_lower_bound=-0.05, original_upper_bound=0.05), - BoundToUnbound(name_mapping = [["phase_c"], ["phase_c_unbounded"]] , original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), - BoundToUnbound(name_mapping = [["psi"], ["psi_unbounded"]], original_lower_bound=0.0, original_upper_bound=jnp.pi), - BoundToUnbound(name_mapping = [["ra"], ["ra_unbounded"]], original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), - BoundToUnbound(name_mapping = [["dec"], ["dec_unbounded"]],original_lower_bound=-jnp.pi / 2, original_upper_bound=jnp.pi / 2) + ComponentMassesToChirpMassMassRatioTransform, + DistanceToSNRWeightedDistanceTransform(gps_time=gps, ifos=ifos, dL_min=dL_prior.xmin, dL_max=dL_prior.xmax), + GeocentricArrivalPhaseToDetectorArrivalPhaseTransform(gps_time=gps, ifo=ifos[0]), + GeocentricArrivalTimeToDetectorArrivalTimeTransform(tc_min=t_c_prior.xmin, tc_max=t_c_prior.xmax, gps_time=gps, ifo=ifos[0]), + SkyFrameToDetectorFrameSkyPositionTransform(gps_time=gps, ifos=ifos), + BoundToUnbound(name_mapping = (["M_c"], ["M_c_unbounded"]), original_lower_bound=M_c_min, original_upper_bound=M_c_max), + BoundToUnbound(name_mapping = (["q"], ["q_unbounded"]), original_lower_bound=q_min, original_upper_bound=q_max), + BoundToUnbound(name_mapping = (["s1_phi"], ["s1_phi_unbounded"]) , original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), + BoundToUnbound(name_mapping = (["s2_phi"], ["s2_phi_unbounded"]) , original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), + BoundToUnbound(name_mapping = (["iota"], ["iota_unbounded"]) , original_lower_bound=0.0, original_upper_bound=jnp.pi), + BoundToUnbound(name_mapping = (["s1_theta"], ["s1_theta_unbounded"]) , original_lower_bound=0.0, original_upper_bound=jnp.pi), + BoundToUnbound(name_mapping = (["s2_theta"], ["s2_theta_unbounded"]) , original_lower_bound=0.0, original_upper_bound=jnp.pi), + BoundToUnbound(name_mapping = (["s1_mag"], ["s1_mag_unbounded"]) , original_lower_bound=0.0, original_upper_bound=0.99), + BoundToUnbound(name_mapping = (["s2_mag"], ["s2_mag_unbounded"]) , original_lower_bound=0.0, original_upper_bound=0.99), + BoundToUnbound(name_mapping = (["phase_det"], ["phase_det_unbounded"]), original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), + BoundToUnbound(name_mapping = (["psi"], ["psi_unbounded"]), original_lower_bound=0.0, original_upper_bound=jnp.pi), + BoundToUnbound(name_mapping = (["zenith"], ["zenith_unbounded"]), original_lower_bound=0.0, original_upper_bound=jnp.pi), + BoundToUnbound(name_mapping = (["azimuth"], ["azimuth_unbounded"]), original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), ] likelihood_transforms = [ - SpinToCartesianSpinTransform(freq_ref=20.0), + ComponentMassesToChirpMassMassRatioTransform, MassRatioToSymmetricMassRatioTransform, + SphereSpinToCartesianSpinTransform("s1"), + SphereSpinToCartesianSpinTransform("s2"), ] likelihood = TransientLikelihoodFD( ifos, - waveform=RippleIMRPhenomD(), + waveform=RippleIMRPhenomPv2(), trigger_time=gps, duration=4, post_trigger_duration=2, ) -mass_matrix = jnp.eye(15) +n_dim = sum([ind_prior.n_dim for ind_prior in prior.base_prior]) +mass_matrix = jnp.eye(n_dim) mass_matrix = mass_matrix.at[1, 1].set(1e-3) mass_matrix = mass_matrix.at[9, 9].set(1e-3) local_sampler_arg = {"step_size": mass_matrix * 3e-3} @@ -139,4 +174,4 @@ jim.sample(jax.random.PRNGKey(42)) jim.get_samples() -jim.print_summary() \ No newline at end of file +jim.print_summary()