From 87e0824aa8ada3423f65f71afa70fc57323fe83a Mon Sep 17 00:00:00 2001 From: "Peter T. H. Pang" Date: Wed, 16 Oct 2024 17:03:16 -0400 Subject: [PATCH 1/4] Fixing GW150914 Pv2 test --- test/integration/test_GW150914_Pv2.py | 137 ++++++++++++++++---------- 1 file changed, 86 insertions(+), 51 deletions(-) diff --git a/test/integration/test_GW150914_Pv2.py b/test/integration/test_GW150914_Pv2.py index 9892058d..2b54ed3c 100644 --- a/test/integration/test_GW150914_Pv2.py +++ b/test/integration/test_GW150914_Pv2.py @@ -4,12 +4,31 @@ import jax.numpy as jnp from jimgw.jim import Jim -from jimgw.prior import CombinePrior, UniformPrior, CosinePrior, SinePrior, PowerLawPrior +from jimgw.jim import Jim +from jimgw.prior import ( + CombinePrior, + UniformPrior, + CosinePrior, + SinePrior, + PowerLawPrior, + UniformSpherePrior, +) from jimgw.single_event.detector import H1, L1 from jimgw.single_event.likelihood import TransientLikelihoodFD -from jimgw.single_event.waveform import RippleIMRPhenomD +from jimgw.single_event.waveform import RippleIMRPhenomPv2 from jimgw.transforms import BoundToUnbound -from jimgw.single_event.transforms import MassRatioToSymmetricMassRatioTransform, SpinToCartesianSpinTransform +from jimgw.single_event.transforms import ( + SkyFrameToDetectorFrameSkyPositionTransform, + SphereSpinToCartesianSpinTransform, + MassRatioToSymmetricMassRatioTransform, + DistanceToSNRWeightedDistanceTransform, + GeocentricArrivalTimeToDetectorArrivalTimeTransform, + GeocentricArrivalPhaseToDetectorArrivalPhaseTransform, + ComponentMassesToChirpMassMassRatioTransform, +) +from jimgw.single_event.prior import ( + ChirpMassMassRatioBoundedUniformComponentPrior +) from flowMC.strategy.optimization import optimization_Adam jax.config.update("jax_enable_x64", True) @@ -34,75 +53,91 @@ for ifo in ifos: ifo.load_data(gps, start_pad, end_pad, fmin, fmax, psd_pad=16, tukey_alpha=0.2) -Mc_prior = UniformPrior(10.0, 80.0, parameter_names=["M_c"]) -q_prior = UniformPrior(0.125, 1., parameter_names=["q"]) -theta_jn_prior = SinePrior(parameter_names=["theta_jn"]) -phi_jl_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["phi_jl"]) -theta_1_prior = SinePrior(parameter_names=["theta_1"]) -theta_2_prior = SinePrior(parameter_names=["theta_2"]) -phi_12_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["phi_12"]) -a_1_prior = UniformPrior(0.0, 1.0, parameter_names=["a_1"]) -a_2_prior = UniformPrior(0.0, 1.0, parameter_names=["a_2"]) -dL_prior = PowerLawPrior(10.0, 2000.0, 2.0, parameter_names=["d_L"]) +prior = [] + +# Mass prior +M_c_min, M_c_max = 10.0, 80.0 +q_min, q_max = 0.125, 1.0 +m1m2_prior = ChirpMassMassRatioBoundedUniformComponentPrior( + q_min=q_min, + q_max=q_max, + M_c_min=M_c_min, + M_c_max=M_c_max, +) + +prior = prior + [m1m2_prior] + +# Spin prior +s1_prior = UniformSpherePrior(parameter_names=["s1"]) +s2_prior = UniformSpherePrior(parameter_names=["s2"]) +iota_prior = SinePrior(parameter_names=["iota"]) + +prior = prior + [ + s1_prior, + s2_prior, + iota_prior, +] + +# Extrinsic prior +dL_prior = PowerLawPrior(1.0, 2000.0, 2.0, parameter_names=["d_L"]) t_c_prior = UniformPrior(-0.05, 0.05, parameter_names=["t_c"]) phase_c_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["phase_c"]) psi_prior = UniformPrior(0.0, jnp.pi, parameter_names=["psi"]) ra_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["ra"]) dec_prior = CosinePrior(parameter_names=["dec"]) -prior = CombinePrior( - [ - Mc_prior, - q_prior, - theta_jn_prior, - phi_jl_prior, - theta_1_prior, - theta_2_prior, - phi_12_prior, - a_1_prior, - a_2_prior, - dL_prior, - t_c_prior, - phase_c_prior, - psi_prior, - ra_prior, - dec_prior, - ] -) +prior = prior + [ + dL_prior, + t_c_prior, + phase_c_prior, + psi_prior, + ra_prior, + dec_prior, +] + +prior = CombinePrior(prior) + +# Defining Transforms sample_transforms = [ - BoundToUnbound(name_mapping = [["M_c"], ["M_c_unbounded"]], original_lower_bound=10.0, original_upper_bound=80.0), - BoundToUnbound(name_mapping = [["q"], ["q_unbounded"]], original_lower_bound=0.125, original_upper_bound=1.), - BoundToUnbound(name_mapping = [["theta_jn"], ["theta_jn_unbounded"]] , original_lower_bound=0.0, original_upper_bound=jnp.pi), - BoundToUnbound(name_mapping = [["phi_jl"], ["phi_jl_unbounded"]] , original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), - BoundToUnbound(name_mapping = [["theta_1"], ["theta_1_unbounded"]] , original_lower_bound=0.0, original_upper_bound=jnp.pi), - BoundToUnbound(name_mapping = [["theta_2"], ["theta_2_unbounded"]] , original_lower_bound=0.0, original_upper_bound=jnp.pi), - BoundToUnbound(name_mapping = [["phi_12"], ["phi_12_unbounded"]] , original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), - BoundToUnbound(name_mapping = [["a_1"], ["a_1_unbounded"]] , original_lower_bound=0.0, original_upper_bound=1.0), - BoundToUnbound(name_mapping = [["a_2"], ["a_2_unbounded"]] , original_lower_bound=0.0, original_upper_bound=1.0), - BoundToUnbound(name_mapping = [["d_L"], ["d_L_unbounded"]] , original_lower_bound=0.0, original_upper_bound=2000.0), - BoundToUnbound(name_mapping = [["t_c"], ["t_c_unbounded"]] , original_lower_bound=-0.05, original_upper_bound=0.05), - BoundToUnbound(name_mapping = [["phase_c"], ["phase_c_unbounded"]] , original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), - BoundToUnbound(name_mapping = [["psi"], ["psi_unbounded"]], original_lower_bound=0.0, original_upper_bound=jnp.pi), - BoundToUnbound(name_mapping = [["ra"], ["ra_unbounded"]], original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), - BoundToUnbound(name_mapping = [["dec"], ["dec_unbounded"]],original_lower_bound=-jnp.pi / 2, original_upper_bound=jnp.pi / 2) + ComponentMassesToChirpMassMassRatioTransform, + DistanceToSNRWeightedDistanceTransform(gps_time=gps, ifos=ifos, dL_min=dL_prior.xmin, dL_max=dL_prior.xmax), + GeocentricArrivalPhaseToDetectorArrivalPhaseTransform(gps_time=gps, ifo=ifos[0]), + GeocentricArrivalTimeToDetectorArrivalTimeTransform(tc_min=t_c_prior.xmin, tc_max=t_c_prior.xmax, gps_time=gps, ifo=ifos[0]), + SkyFrameToDetectorFrameSkyPositionTransform(gps_time=gps, ifos=ifos), + BoundToUnbound(name_mapping = (["M_c"], ["M_c_unbounded"]), original_lower_bound=M_c_min, original_upper_bound=M_c_max), + BoundToUnbound(name_mapping = (["q"], ["q_unbounded"]), original_lower_bound=q_min, original_upper_bound=q_max), + BoundToUnbound(name_mapping = (["s1_phi"], ["s1_phi_unbounded"]) , original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), + BoundToUnbound(name_mapping = (["s2_phi"], ["s2_phi_unbounded"]) , original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), + BoundToUnbound(name_mapping = (["iota"], ["iota_unbounded"]) , original_lower_bound=0.0, original_upper_bound=jnp.pi), + BoundToUnbound(name_mapping = (["s1_theta"], ["s1_theta_unbounded"]) , original_lower_bound=0.0, original_upper_bound=jnp.pi), + BoundToUnbound(name_mapping = (["s2_theta"], ["s2_theta_unbounded"]) , original_lower_bound=0.0, original_upper_bound=jnp.pi), + BoundToUnbound(name_mapping = (["s1_mag"], ["s1_mag_unbounded"]) , original_lower_bound=0.0, original_upper_bound=0.99), + BoundToUnbound(name_mapping = (["s2_mag"], ["s2_mag_unbounded"]) , original_lower_bound=0.0, original_upper_bound=0.99), + BoundToUnbound(name_mapping = (["phase_det"], ["phase_det_unbounded"]), original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), + BoundToUnbound(name_mapping = (["psi"], ["psi_unbounded"]), original_lower_bound=0.0, original_upper_bound=jnp.pi), + BoundToUnbound(name_mapping = (["zenith"], ["zenith_unbounded"]), original_lower_bound=0.0, original_upper_bound=jnp.pi), + BoundToUnbound(name_mapping = (["azimuth"], ["azimuth_unbounded"]), original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), ] likelihood_transforms = [ - SpinToCartesianSpinTransform(freq_ref=20.0), + ComponentMassesToChirpMassMassRatioTransform, MassRatioToSymmetricMassRatioTransform, + SphereSpinToCartesianSpinTransform("s1"), + SphereSpinToCartesianSpinTransform("s2"), ] likelihood = TransientLikelihoodFD( ifos, - waveform=RippleIMRPhenomD(), + waveform=RippleIMRPhenomPv2(), trigger_time=gps, duration=4, post_trigger_duration=2, ) -mass_matrix = jnp.eye(15) +n_dim = sum([ind_prior.n_dim for ind_prior in prior.base_prior]) +mass_matrix = jnp.eye(n_dim) mass_matrix = mass_matrix.at[1, 1].set(1e-3) mass_matrix = mass_matrix.at[9, 9].set(1e-3) local_sampler_arg = {"step_size": mass_matrix * 3e-3} @@ -139,4 +174,4 @@ jim.sample(jax.random.PRNGKey(42)) jim.get_samples() -jim.print_summary() \ No newline at end of file +jim.print_summary() From 54cb2e12770aa17e0eb03c70848df3e17dc8493d Mon Sep 17 00:00:00 2001 From: "Peter T. H. Pang" Date: Wed, 16 Oct 2024 17:06:42 -0400 Subject: [PATCH 2/4] Update GW150914 PhenomD example --- example/GW150914_IMRPhenomD.py | 186 +++++++++++++++------------------ 1 file changed, 83 insertions(+), 103 deletions(-) diff --git a/example/GW150914_IMRPhenomD.py b/example/GW150914_IMRPhenomD.py index 23d08c7b..9a9b9e09 100644 --- a/example/GW150914_IMRPhenomD.py +++ b/example/GW150914_IMRPhenomD.py @@ -1,6 +1,9 @@ +import time + import jax import jax.numpy as jnp +from jimgw.jim import Jim from jimgw.jim import Jim from jimgw.prior import ( CombinePrior, @@ -8,15 +11,19 @@ CosinePrior, SinePrior, PowerLawPrior, + UniformSpherePrior, ) from jimgw.single_event.detector import H1, L1 from jimgw.single_event.likelihood import TransientLikelihoodFD from jimgw.single_event.waveform import RippleIMRPhenomD from jimgw.transforms import BoundToUnbound from jimgw.single_event.transforms import ( - ComponentMassesToChirpMassSymmetricMassRatioTransform, SkyFrameToDetectorFrameSkyPositionTransform, - ComponentMassesToChirpMassMassRatioTransform, + SphereSpinToCartesianSpinTransform, + MassRatioToSymmetricMassRatioTransform, + DistanceToSNRWeightedDistanceTransform, + GeocentricArrivalTimeToDetectorArrivalTimeTransform, + GeocentricArrivalPhaseToDetectorArrivalPhaseTransform, ) from jimgw.single_event.utils import Mc_q_to_m1_m2 from flowMC.strategy.optimization import optimization_Adam @@ -27,136 +34,110 @@ ########## First we grab data ############# ########################################### +total_time_start = time.time() + # first, fetch a 4s segment centered on GW150914 gps = 1126259462.4 -duration = 4 -post_trigger_duration = 2 -start_pad = duration - post_trigger_duration -end_pad = post_trigger_duration +start = gps - 2 +end = gps + 2 fmin = 20.0 fmax = 1024.0 ifos = [H1, L1] -for ifo in ifos: - ifo.load_data(gps, start_pad, end_pad, fmin, fmax, psd_pad=16, tukey_alpha=0.2) +H1.load_data(gps, 2, 2, fmin, fmax, psd_pad=16, tukey_alpha=0.2) +L1.load_data(gps, 2, 2, fmin, fmax, psd_pad=16, tukey_alpha=0.2) +waveform = RippleIMRPhenomD(f_ref=20) + +########################################### +########## Set up priors ################## +########################################### + +prior = [] + +# Mass prior M_c_min, M_c_max = 10.0, 80.0 -eta_min, eta_max = 0.2, 0.25 -# m_1_prior = UniformPrior(Mc_q_to_m1_m2(M_c_min, q_max)[0], Mc_q_to_m1_m2(M_c_max, q_min)[0], parameter_names=["m_1"]) -# m_2_prior = UniformPrior(Mc_q_to_m1_m2(M_c_min, q_min)[1], Mc_q_to_m1_m2(M_c_max, q_max)[1], parameter_names=["m_2"]) +q_min, q_max = 0.125, 1.0 Mc_prior = UniformPrior(M_c_min, M_c_max, parameter_names=["M_c"]) -eta_prior = UniformPrior(eta_min, eta_max, parameter_names=["eta"]) -s1z_prior = UniformPrior(-1.0, 1.0, parameter_names=["s1_z"]) -s2z_prior = UniformPrior(-1.0, 1.0, parameter_names=["s2_z"]) +q_prior = UniformPrior(q_min, q_max, parameter_names=["q"]) + +prior = prior + [Mc_prior, q_prior] + +# Spin prior +s1_prior = UniformPrior(-1.0, 1.0, parameter_names=["s1_z"]) +s2_prior = UniformPrior(-1.0, 1.0, parameter_names=["s2_z"]) +iota_prior = SinePrior(parameter_names=["iota"]) + +prior = prior + [ + s1_prior, + s2_prior, + iota_prior, +] + +# Extrinsic prior dL_prior = PowerLawPrior(1.0, 2000.0, 2.0, parameter_names=["d_L"]) t_c_prior = UniformPrior(-0.05, 0.05, parameter_names=["t_c"]) phase_c_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["phase_c"]) -iota_prior = SinePrior(parameter_names=["iota"]) psi_prior = UniformPrior(0.0, jnp.pi, parameter_names=["psi"]) ra_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["ra"]) dec_prior = CosinePrior(parameter_names=["dec"]) -prior = CombinePrior( - [ - Mc_prior, - eta_prior, - s1z_prior, - s2z_prior, - dL_prior, - t_c_prior, - phase_c_prior, - iota_prior, - psi_prior, - ra_prior, - dec_prior, - ] -) +prior = prior + [ + dL_prior, + t_c_prior, + phase_c_prior, + psi_prior, + ra_prior, + dec_prior, +] + +prior = CombinePrior(prior) + +# Defining Transforms sample_transforms = [ - # ComponentMassesToChirpMassMassRatioTransform, - BoundToUnbound( - name_mapping=(["M_c"], ["M_c_unbounded"]), - original_lower_bound=M_c_min, - original_upper_bound=M_c_max, - ), - BoundToUnbound( - name_mapping=(["eta"], ["eta_unbounded"]), - original_lower_bound=eta_min, - original_upper_bound=eta_max, - ), - BoundToUnbound( - name_mapping=(["s1_z"], ["s1_z_unbounded"]), - original_lower_bound=-1.0, - original_upper_bound=1.0, - ), - BoundToUnbound( - name_mapping=(["s2_z"], ["s2_z_unbounded"]), - original_lower_bound=-1.0, - original_upper_bound=1.0, - ), - BoundToUnbound( - name_mapping=(["d_L"], ["d_L_unbounded"]), - original_lower_bound=1.0, - original_upper_bound=2000.0, - ), - BoundToUnbound( - name_mapping=(["t_c"], ["t_c_unbounded"]), - original_lower_bound=-0.05, - original_upper_bound=0.05, - ), - BoundToUnbound( - name_mapping=(["phase_c"], ["phase_c_unbounded"]), - original_lower_bound=0.0, - original_upper_bound=2 * jnp.pi, - ), - BoundToUnbound( - name_mapping=(["iota"], ["iota_unbounded"]), - original_lower_bound=0.0, - original_upper_bound=jnp.pi, - ), - BoundToUnbound( - name_mapping=(["psi"], ["psi_unbounded"]), - original_lower_bound=0.0, - original_upper_bound=jnp.pi, - ), + DistanceToSNRWeightedDistanceTransform(gps_time=gps, ifos=ifos, dL_min=dL_prior.xmin, dL_max=dL_prior.xmax), + GeocentricArrivalPhaseToDetectorArrivalPhaseTransform(gps_time=gps, ifo=ifos[0]), + GeocentricArrivalTimeToDetectorArrivalTimeTransform(tc_min=t_c_prior.xmin, tc_max=t_c_prior.xmax, gps_time=gps, ifo=ifos[0]), SkyFrameToDetectorFrameSkyPositionTransform(gps_time=gps, ifos=ifos), - BoundToUnbound( - name_mapping=(["zenith"], ["zenith_unbounded"]), - original_lower_bound=0.0, - original_upper_bound=jnp.pi, - ), - BoundToUnbound( - name_mapping=(["azimuth"], ["azimuth_unbounded"]), - original_lower_bound=0.0, - original_upper_bound=2 * jnp.pi, - ), + BoundToUnbound(name_mapping = (["M_c"], ["M_c_unbounded"]), original_lower_bound=M_c_min, original_upper_bound=M_c_max), + BoundToUnbound(name_mapping = (["q"], ["q_unbounded"]), original_lower_bound=q_min, original_upper_bound=q_max), + BoundToUnbound(name_mapping = (["s1_z"], ["s1_z_unbounded"]) , original_lower_bound=-1.0, original_upper_bound=1.0), + BoundToUnbound(name_mapping = (["s2_z"], ["s2_z_unbounded"]) , original_lower_bound=-1.0, original_upper_bound=1.0), + BoundToUnbound(name_mapping = (["iota"], ["iota_unbounded"]) , original_lower_bound=0.0, original_upper_bound=jnp.pi), + BoundToUnbound(name_mapping = (["phase_det"], ["phase_det_unbounded"]), original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), + BoundToUnbound(name_mapping = (["psi"], ["psi_unbounded"]), original_lower_bound=0.0, original_upper_bound=jnp.pi), + BoundToUnbound(name_mapping = (["zenith"], ["zenith_unbounded"]), original_lower_bound=0.0, original_upper_bound=jnp.pi), + BoundToUnbound(name_mapping = (["azimuth"], ["azimuth_unbounded"]), original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), ] likelihood_transforms = [ - # ComponentMassesToChirpMassSymmetricMassRatioTransform, + MassRatioToSymmetricMassRatioTransform, ] + likelihood = TransientLikelihoodFD( - ifos, - waveform=RippleIMRPhenomD(), - trigger_time=gps, - duration=4, - post_trigger_duration=2, + [H1, L1], waveform=waveform, trigger_time=gps, duration=4, post_trigger_duration=2 ) -mass_matrix = jnp.eye(11) -mass_matrix = mass_matrix.at[1, 1].set(1e-3) -mass_matrix = mass_matrix.at[5, 5].set(1e-3) -local_sampler_arg = {"step_size": mass_matrix * 3e-3} +mass_matrix = jnp.eye(prior.n_dim) +# mass_matrix = mass_matrix.at[1, 1].set(1e-3) +# mass_matrix = mass_matrix.at[9, 9].set(1e-3) +local_sampler_arg = {"step_size": mass_matrix * 1e-3} Adam_optimizer = optimization_Adam(n_steps=3000, learning_rate=0.01, noise_level=1) -n_epochs = 30 -n_loop_training = 20 -learning_rate = 1e-4 +import optax +n_epochs = 20 +n_loop_training = 100 +total_epochs = n_epochs * n_loop_training +start = total_epochs // 10 +learning_rate = optax.polynomial_schedule( + 1e-3, 1e-4, 4.0, total_epochs - start, transition_begin=start +) jim = Jim( likelihood, @@ -175,13 +156,12 @@ momentum=0.9, batch_size=30000, use_global=True, + keep_quantile=0.0, train_thinning=1, output_thinning=10, local_sampler_arg=local_sampler_arg, - strategies=[Adam_optimizer, "default"], - verbose=True, + # strategies=[Adam_optimizer,"default"], ) + jim.sample(jax.random.PRNGKey(42)) -# jim.get_samples() -# jim.print_summary() From ae9204abbf8c246b4d9f8dacf2ea38163400ca12 Mon Sep 17 00:00:00 2001 From: "Peter T. H. Pang" Date: Wed, 16 Oct 2024 17:14:05 -0400 Subject: [PATCH 3/4] Update GW170817 PhenomD --- example/GW170817_IMRPhenomD.py | 184 +++++++++++++++++++++++++++++++++ 1 file changed, 184 insertions(+) create mode 100644 example/GW170817_IMRPhenomD.py diff --git a/example/GW170817_IMRPhenomD.py b/example/GW170817_IMRPhenomD.py new file mode 100644 index 00000000..d3aca76d --- /dev/null +++ b/example/GW170817_IMRPhenomD.py @@ -0,0 +1,184 @@ +import time + +import jax +import jax.numpy as jnp + +from jimgw.jim import Jim +from jimgw.jim import Jim +from jimgw.prior import ( + CombinePrior, + UniformPrior, + CosinePrior, + SinePrior, + PowerLawPrior, + UniformSpherePrior, +) +from jimgw.single_event.detector import H1, L1, V1 +from jimgw.single_event.likelihood import TransientLikelihoodFD, HeterodynedTransientLikelihoodFD +from jimgw.single_event.waveform import RippleIMRPhenomD +from jimgw.transforms import BoundToUnbound +from jimgw.single_event.transforms import ( + SkyFrameToDetectorFrameSkyPositionTransform, + SphereSpinToCartesianSpinTransform, + MassRatioToSymmetricMassRatioTransform, + DistanceToSNRWeightedDistanceTransform, + GeocentricArrivalTimeToDetectorArrivalTimeTransform, + GeocentricArrivalPhaseToDetectorArrivalPhaseTransform, +) +from jimgw.single_event.utils import Mc_q_to_m1_m2 +from flowMC.strategy.optimization import optimization_Adam + +jax.config.update("jax_enable_x64", True) + +########################################### +########## First we grab data ############# +########################################### + +total_time_start = time.time() + +# first, fetch a 4s segment centered on GW150914 + +gps = 1187008882.43 +trigger_time = gps +fmin = 20 +fmax = 2048 +minimum_frequency = fmin +maximum_frequency = fmax +duration = 128 +post_trigger_duration = 2 +epoch = duration - post_trigger_duration +f_ref = fmin + +ifos = [H1, L1, V1] + + +tukey_alpha = 2 / (duration / 2) +H1.load_data( + gps, duration, 2, fmin, fmax, psd_pad=duration + 16, tukey_alpha=tukey_alpha +) +L1.load_data( + gps, duration, 2, fmin, fmax, psd_pad=duration + 16, tukey_alpha=tukey_alpha +) +V1.load_data( + gps, duration, 2, fmin, fmax, psd_pad=duration + 16, tukey_alpha=tukey_alpha +) + + +waveform = RippleIMRPhenomD(f_ref=f_ref) + +########################################### +########## Set up priors ################## +########################################### + +prior = [] + +# Mass prior +M_c_min, M_c_max = 1.18, 1.21 +q_min, q_max = 0.125, 1.0 +Mc_prior = UniformPrior(M_c_min, M_c_max, parameter_names=["M_c"]) +q_prior = UniformPrior(q_min, q_max, parameter_names=["q"]) + +prior = prior + [Mc_prior, q_prior] + +# Spin prior +s1_prior = UniformPrior(-1.0, 1.0, parameter_names=["s1_z"]) +s2_prior = UniformPrior(-1.0, 1.0, parameter_names=["s2_z"]) +iota_prior = SinePrior(parameter_names=["iota"]) + +prior = prior + [ + s1_prior, + s2_prior, + iota_prior, +] + +# Extrinsic prior +dL_prior = PowerLawPrior(1.0, 75.0, 2.0, parameter_names=["d_L"]) +t_c_prior = UniformPrior(-0.1, 0.1, parameter_names=["t_c"]) +phase_c_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["phase_c"]) +psi_prior = UniformPrior(0.0, jnp.pi, parameter_names=["psi"]) +ra_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["ra"]) +dec_prior = CosinePrior(parameter_names=["dec"]) + +prior = prior + [ + dL_prior, + t_c_prior, + phase_c_prior, + psi_prior, + ra_prior, + dec_prior, +] + +prior = CombinePrior(prior) + +# Defining Transforms + +sample_transforms = [ + DistanceToSNRWeightedDistanceTransform(gps_time=gps, ifos=ifos, dL_min=dL_prior.xmin, dL_max=dL_prior.xmax), + GeocentricArrivalPhaseToDetectorArrivalPhaseTransform(gps_time=gps, ifo=ifos[0]), + GeocentricArrivalTimeToDetectorArrivalTimeTransform(tc_min=t_c_prior.xmin, tc_max=t_c_prior.xmax, gps_time=gps, ifo=ifos[0]), + SkyFrameToDetectorFrameSkyPositionTransform(gps_time=gps, ifos=ifos), + BoundToUnbound(name_mapping = (["M_c"], ["M_c_unbounded"]), original_lower_bound=M_c_min, original_upper_bound=M_c_max), + BoundToUnbound(name_mapping = (["q"], ["q_unbounded"]), original_lower_bound=q_min, original_upper_bound=q_max), + BoundToUnbound(name_mapping = (["s1_z"], ["s1_z_unbounded"]) , original_lower_bound=-1.0, original_upper_bound=1.0), + BoundToUnbound(name_mapping = (["s2_z"], ["s2_z_unbounded"]) , original_lower_bound=-1.0, original_upper_bound=1.0), + BoundToUnbound(name_mapping = (["iota"], ["iota_unbounded"]) , original_lower_bound=0.0, original_upper_bound=jnp.pi), + BoundToUnbound(name_mapping = (["phase_det"], ["phase_det_unbounded"]), original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), + BoundToUnbound(name_mapping = (["psi"], ["psi_unbounded"]), original_lower_bound=0.0, original_upper_bound=jnp.pi), + BoundToUnbound(name_mapping = (["zenith"], ["zenith_unbounded"]), original_lower_bound=0.0, original_upper_bound=jnp.pi), + BoundToUnbound(name_mapping = (["azimuth"], ["azimuth_unbounded"]), original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), +] + +likelihood_transforms = [ + MassRatioToSymmetricMassRatioTransform, +] + + +#likelihood = TransientLikelihoodFD( +# [H1, L1, V1], waveform=waveform, trigger_time=trigger_time, duration=duration, post_trigger_duration=post_trigger_duration +#) + +likelihood = HeterodynedTransientLikelihoodFD(ifos, waveform=waveform, n_bins = 1000, trigger_time=trigger_time, duration=duration, post_trigger_duration=post_trigger_duration, prior = prior, sample_transforms = sample_transforms, likelihood_transforms = likelihood_transforms, popsize = 10, n_steps = 50) + +mass_matrix = jnp.eye(prior.n_dim) +# mass_matrix = mass_matrix.at[1, 1].set(1e-3) +# mass_matrix = mass_matrix.at[9, 9].set(1e-3) +local_sampler_arg = {"step_size": mass_matrix * 1e-3} + +Adam_optimizer = optimization_Adam(n_steps=3000, learning_rate=0.01, noise_level=1) + +import optax + +n_epochs = 20 +n_loop_training = 100 +total_epochs = n_epochs * n_loop_training +start = total_epochs // 10 +learning_rate = optax.polynomial_schedule( + 1e-3, 1e-4, 4.0, total_epochs - start, transition_begin=start +) + +jim = Jim( + likelihood, + prior, + sample_transforms=sample_transforms, + likelihood_transforms=likelihood_transforms, + n_loop_training=n_loop_training, + n_loop_production=20, + n_local_steps=10, + n_global_steps=1000, + n_chains=500, + n_epochs=n_epochs, + learning_rate=learning_rate, + n_max_examples=30000, + n_flow_sample=100000, + momentum=0.9, + batch_size=30000, + use_global=True, + keep_quantile=0.0, + train_thinning=1, + output_thinning=10, + local_sampler_arg=local_sampler_arg, + # strategies=[Adam_optimizer,"default"], +) + + +jim.sample(jax.random.PRNGKey(42)) From 38de2ac4c694ecd6f7c8fd362587552e2f9f027d Mon Sep 17 00:00:00 2001 From: "Peter T. H. Pang" Date: Wed, 16 Oct 2024 17:25:23 -0400 Subject: [PATCH 4/4] Fixing the Pv2 test again --- test/integration/test_GW150914_Pv2.py | 16 +++------------- 1 file changed, 3 insertions(+), 13 deletions(-) diff --git a/test/integration/test_GW150914_Pv2.py b/test/integration/test_GW150914_Pv2.py index 2b54ed3c..7f118703 100644 --- a/test/integration/test_GW150914_Pv2.py +++ b/test/integration/test_GW150914_Pv2.py @@ -24,11 +24,7 @@ DistanceToSNRWeightedDistanceTransform, GeocentricArrivalTimeToDetectorArrivalTimeTransform, GeocentricArrivalPhaseToDetectorArrivalPhaseTransform, - ComponentMassesToChirpMassMassRatioTransform, ) -from jimgw.single_event.prior import ( - ChirpMassMassRatioBoundedUniformComponentPrior -) from flowMC.strategy.optimization import optimization_Adam jax.config.update("jax_enable_x64", True) @@ -58,14 +54,10 @@ # Mass prior M_c_min, M_c_max = 10.0, 80.0 q_min, q_max = 0.125, 1.0 -m1m2_prior = ChirpMassMassRatioBoundedUniformComponentPrior( - q_min=q_min, - q_max=q_max, - M_c_min=M_c_min, - M_c_max=M_c_max, -) +Mc_prior = UniformPrior(M_c_min, M_c_max, parameter_names=['M_c']) +q_prior = UniformPrior(q_min, q_max, parameter_names=['q']) -prior = prior + [m1m2_prior] +prior = prior + [Mc_prior, q_prior] # Spin prior s1_prior = UniformSpherePrior(parameter_names=["s1"]) @@ -100,7 +92,6 @@ # Defining Transforms sample_transforms = [ - ComponentMassesToChirpMassMassRatioTransform, DistanceToSNRWeightedDistanceTransform(gps_time=gps, ifos=ifos, dL_min=dL_prior.xmin, dL_max=dL_prior.xmax), GeocentricArrivalPhaseToDetectorArrivalPhaseTransform(gps_time=gps, ifo=ifos[0]), GeocentricArrivalTimeToDetectorArrivalTimeTransform(tc_min=t_c_prior.xmin, tc_max=t_c_prior.xmax, gps_time=gps, ifo=ifos[0]), @@ -121,7 +112,6 @@ ] likelihood_transforms = [ - ComponentMassesToChirpMassMassRatioTransform, MassRatioToSymmetricMassRatioTransform, SphereSpinToCartesianSpinTransform("s1"), SphereSpinToCartesianSpinTransform("s2"),