From 5e2792ca62feb25dfff4b57bed097b1330b5f479 Mon Sep 17 00:00:00 2001 From: xuyuon <116078673+xuyuon@users.noreply.github.com> Date: Mon, 19 Aug 2024 10:38:45 +0800 Subject: [PATCH 01/20] Updated Pv2 scripts --- example/GW150914_PV2.py | 202 +++++++++++++------------- test/integration/test_GW150914_Pv2.py | 4 +- 2 files changed, 105 insertions(+), 101 deletions(-) diff --git a/example/GW150914_PV2.py b/example/GW150914_PV2.py index 06209ba6..f0eb9ffd 100644 --- a/example/GW150914_PV2.py +++ b/example/GW150914_PV2.py @@ -4,13 +4,14 @@ import jax.numpy as jnp from jimgw.jim import Jim -from jimgw.prior import Composite, Sphere, Unconstrained_Uniform +from jimgw.prior import CombinePrior, UniformPrior, CosinePrior, SinePrior, PowerLawPrior from jimgw.single_event.detector import H1, L1 from jimgw.single_event.likelihood import TransientLikelihoodFD from jimgw.single_event.waveform import RippleIMRPhenomPv2 +from jimgw.transforms import BoundToUnbound +from jimgw.single_event.transforms import MassRatioToSymmetricMassRatioTransform, SpinToCartesianSpinTransform from flowMC.strategy.optimization import optimization_Adam - jax.config.update("jax_enable_x64", True) ########################################### @@ -21,125 +22,103 @@ # first, fetch a 4s segment centered on GW150914 gps = 1126259462.4 -start = gps - 2 -end = gps + 2 +duration = 4 +post_trigger_duration = 2 +start_pad = duration - post_trigger_duration +end_pad = post_trigger_duration fmin = 20.0 fmax = 1024.0 -ifos = ["H1", "L1"] - -H1.load_data(gps, 2, 2, fmin, fmax, psd_pad=16, tukey_alpha=0.2) -L1.load_data(gps, 2, 2, fmin, fmax, psd_pad=16, tukey_alpha=0.2) - -waveform = RippleIMRPhenomPv2(f_ref=20) - -########################################### -########## Set up priors ################## -########################################### - -Mc_prior = Unconstrained_Uniform(10.0, 80.0, naming=["M_c"]) -q_prior = Unconstrained_Uniform( - 0.125, - 1.0, - naming=["q"], - transforms={"q": ("eta", lambda params: params["q"] / (1 + params["q"]) ** 2)}, -) -s1_prior = Sphere(naming="s1") -s2_prior = Sphere(naming="s2") -dL_prior = Unconstrained_Uniform(0.0, 2000.0, naming=["d_L"]) -t_c_prior = Unconstrained_Uniform(-0.05, 0.05, naming=["t_c"]) -phase_c_prior = Unconstrained_Uniform(0.0, 2 * jnp.pi, naming=["phase_c"]) -cos_iota_prior = Unconstrained_Uniform( - -1.0, - 1.0, - naming=["cos_iota"], - transforms={ - "cos_iota": ( - "iota", - lambda params: jnp.arccos( - jnp.arcsin(jnp.sin(params["cos_iota"] / 2 * jnp.pi)) * 2 / jnp.pi - ), - ) - }, -) -psi_prior = Unconstrained_Uniform(0.0, jnp.pi, naming=["psi"]) -ra_prior = Unconstrained_Uniform(0.0, 2 * jnp.pi, naming=["ra"]) -sin_dec_prior = Unconstrained_Uniform( - -1.0, - 1.0, - naming=["sin_dec"], - transforms={ - "sin_dec": ( - "dec", - lambda params: jnp.arcsin( - jnp.arcsin(jnp.sin(params["sin_dec"] / 2 * jnp.pi)) * 2 / jnp.pi - ), - ) - }, -) - -prior = Composite( +ifos = [H1, L1] + +for ifo in ifos: + ifo.load_data(gps, start_pad, end_pad, fmin, fmax, psd_pad=16, tukey_alpha=0.2) + +Mc_prior = UniformPrior(10.0, 80.0, parameter_names=["M_c"]) +q_prior = UniformPrior(0.125, 1., parameter_names=["q"]) +theta_jn_prior = SinePrior(parameter_names=["theta_jn"]) +phi_jl_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["phi_jl"]) +theta_1_prior = SinePrior(parameter_names=["theta_1"]) +theta_2_prior = SinePrior(parameter_names=["theta_2"]) +phi_12_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["phi_12"]) +a_1_prior = UniformPrior(0.0, 1.0, parameter_names=["a_1"]) +a_2_prior = UniformPrior(0.0, 1.0, parameter_names=["a_2"]) +dL_prior = PowerLawPrior(10.0, 2000.0, 2.0, parameter_names=["d_L"]) +t_c_prior = UniformPrior(-0.05, 0.05, parameter_names=["t_c"]) +phase_c_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["phase_c"]) +psi_prior = UniformPrior(0.0, jnp.pi, parameter_names=["psi"]) +ra_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["ra"]) +dec_prior = CosinePrior(parameter_names=["dec"]) + +prior = CombinePrior( [ Mc_prior, q_prior, - s1_prior, - s2_prior, + theta_jn_prior, + phi_jl_prior, + theta_1_prior, + theta_2_prior, + phi_12_prior, + a_1_prior, + a_2_prior, dL_prior, t_c_prior, phase_c_prior, - cos_iota_prior, psi_prior, ra_prior, - sin_dec_prior, - ], + dec_prior, + ] ) -epsilon = 1e-3 -bounds = jnp.array( - [ - [10.0, 80.0], - [0.125, 1.0], - [0, jnp.pi], - [0, 2 * jnp.pi], - [0.0, 1.0], - [0, jnp.pi], - [0, 2 * jnp.pi], - [0.0, 1.0], - [0.0, 2000], - [-0.05, 0.05], - [0.0, 2 * jnp.pi], - [-1.0, 1.0], - [0.0, jnp.pi], - [0.0, 2 * jnp.pi], - [-1.0, 1.0], - ] -) + jnp.array([[epsilon, -epsilon]]) +sample_transforms = [ + BoundToUnbound(name_mapping = [["M_c"], ["M_c_unbounded"]], original_lower_bound=10.0, original_upper_bound=80.0), + BoundToUnbound(name_mapping = [["q"], ["q_unbounded"]], original_lower_bound=0.125, original_upper_bound=1.), + BoundToUnbound(name_mapping = [["theta_jn"], ["theta_jn_unbounded"]] , original_lower_bound=0.0, original_upper_bound=jnp.pi), + BoundToUnbound(name_mapping = [["phi_jl"], ["phi_jl_unbounded"]] , original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), + BoundToUnbound(name_mapping = [["theta_1"], ["theta_1_unbounded"]] , original_lower_bound=0.0, original_upper_bound=jnp.pi), + BoundToUnbound(name_mapping = [["theta_2"], ["theta_2_unbounded"]] , original_lower_bound=0.0, original_upper_bound=jnp.pi), + BoundToUnbound(name_mapping = [["phi_12"], ["phi_12_unbounded"]] , original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), + BoundToUnbound(name_mapping = [["a_1"], ["a_1_unbounded"]] , original_lower_bound=0.0, original_upper_bound=1.0), + BoundToUnbound(name_mapping = [["a_2"], ["a_2_unbounded"]] , original_lower_bound=0.0, original_upper_bound=1.0), + BoundToUnbound(name_mapping = [["d_L"], ["d_L_unbounded"]] , original_lower_bound=0.0, original_upper_bound=2000.0), + BoundToUnbound(name_mapping = [["t_c"], ["t_c_unbounded"]] , original_lower_bound=-0.05, original_upper_bound=0.05), + BoundToUnbound(name_mapping = [["phase_c"], ["phase_c_unbounded"]] , original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), + BoundToUnbound(name_mapping = [["psi"], ["psi_unbounded"]], original_lower_bound=0.0, original_upper_bound=jnp.pi), + BoundToUnbound(name_mapping = [["ra"], ["ra_unbounded"]], original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), + BoundToUnbound(name_mapping = [["dec"], ["dec_unbounded"]],original_lower_bound=-jnp.pi / 2, original_upper_bound=jnp.pi / 2) +] + +likelihood_transforms = [ + SpinToCartesianSpinTransform(name_mapping=[["theta_jn", "phi_jl", "theta_1", "theta_2", "phi_12", "a_1", "a_2"], ["iota", "s1_x", "s1_y", "s1_z", "s2_x", "s2_y", "s2_z"]], freq_ref=20.0), + MassRatioToSymmetricMassRatioTransform(name_mapping=[["q"], ["eta"]]), +] likelihood = TransientLikelihoodFD( - [H1, L1], waveform=waveform, trigger_time=gps, duration=4, post_trigger_duration=2 + ifos, + waveform=RippleIMRPhenomPv2(), + trigger_time=gps, + duration=4, + post_trigger_duration=2, ) -# likelihood = HeterodynedTransientLikelihoodFD([H1, L1], prior=prior, bounds=bounds, waveform=waveform, trigger_time=gps, duration=4, post_trigger_duration=2) -mass_matrix = jnp.eye(prior.n_dim) +mass_matrix = jnp.eye(15) mass_matrix = mass_matrix.at[1, 1].set(1e-3) mass_matrix = mass_matrix.at[9, 9].set(1e-3) -local_sampler_arg = {"step_size": mass_matrix * 1e-3} +local_sampler_arg = {"step_size": mass_matrix * 3e-3} -Adam_optimizer = optimization_Adam(n_steps=3000, learning_rate=0.01, noise_level=1, bounds=bounds) +Adam_optimizer = optimization_Adam(n_steps=3000, learning_rate=0.01, noise_level=1) -import optax -n_epochs = 20 +n_epochs = 30 n_loop_training = 100 -total_epochs = n_epochs * n_loop_training -start = total_epochs//10 -learning_rate = optax.polynomial_schedule( - 1e-3, 1e-4, 4.0, total_epochs - start, transition_begin=start -) +learning_rate = 1e-4 + jim = Jim( likelihood, prior, + sample_transforms=sample_transforms, + likelihood_transforms=likelihood_transforms, n_loop_training=n_loop_training, n_loop_production=20, n_local_steps=10, @@ -148,18 +127,43 @@ n_epochs=n_epochs, learning_rate=learning_rate, n_max_examples=30000, - n_flow_sample=100000, + n_flow_samples=100000, momentum=0.9, batch_size=30000, use_global=True, - keep_quantile=0.0, train_thinning=1, output_thinning=10, local_sampler_arg=local_sampler_arg, - # strategies=[Adam_optimizer,"default"], + strategies=[Adam_optimizer, "default"], ) +jim.sample(jax.random.PRNGKey(42)) +jim.get_samples() +jim.print_summary() + +########################################### +########## Visualize the Data ############# +########################################### +import corner +import matplotlib.pyplot as plt import numpy as np -# chains = np.load('./GW150914_init.npz')['chain'] -jim.sample(jax.random.PRNGKey(42))#,initial_guess=chains) +production_summary = jim.sampler.get_sampler_state(training=False) +production_chain = production_summary["chains"].reshape(-1, len(jim.parameter_names)).T +if jim.sample_transforms: + transformed_chain = jim.add_name(production_chain) + for transform in jim.sample_transforms: + transformed_chain = transform.backward(transformed_chain) +result = transformed_chain +labels = list(transformed_chain.keys()) + +samples = np.array(list(result.values())).reshape(int(len(labels)), -1) # flatten the array +transposed_array = samples.T # transpose the array +figure = corner.corner(transposed_array, labels=labels, plot_datapoints=False, title_quantiles=[0.16, 0.5, 0.84], show_titles=True, title_fmt='g', use_math_text=True) +plt.savefig("GW1500914_pv2.jpeg") + +########################################### +############# Save the Run ################ +########################################### +import pickle +pickle.dump(result, open("GW150914_pv2.pkl", "wb"), protocol=pickle.HIGHEST_PROTOCOL) \ No newline at end of file diff --git a/test/integration/test_GW150914_Pv2.py b/test/integration/test_GW150914_Pv2.py index c9d83a5e..e8b939ca 100644 --- a/test/integration/test_GW150914_Pv2.py +++ b/test/integration/test_GW150914_Pv2.py @@ -7,7 +7,7 @@ from jimgw.prior import CombinePrior, UniformPrior, CosinePrior, SinePrior, PowerLawPrior from jimgw.single_event.detector import H1, L1 from jimgw.single_event.likelihood import TransientLikelihoodFD -from jimgw.single_event.waveform import RippleIMRPhenomD +from jimgw.single_event.waveform import RippleIMRPhenomPv2 from jimgw.transforms import BoundToUnbound from jimgw.single_event.transforms import MassRatioToSymmetricMassRatioTransform, SpinToCartesianSpinTransform from flowMC.strategy.optimization import optimization_Adam @@ -95,7 +95,7 @@ likelihood = TransientLikelihoodFD( ifos, - waveform=RippleIMRPhenomD(), + waveform=RippleIMRPhenomPv2(), trigger_time=gps, duration=4, post_trigger_duration=2, From 420f663040c7735161e1ef60691b18197365dcf4 Mon Sep 17 00:00:00 2001 From: xuyuon <116078673+xuyuon@users.noreply.github.com> Date: Mon, 19 Aug 2024 10:56:31 +0800 Subject: [PATCH 02/20] Updated Pv2 scripts --- example/GW150914_PV2.py | 6 ++++-- test/integration/test_GW150914_Pv2.py | 6 ++++-- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/example/GW150914_PV2.py b/example/GW150914_PV2.py index f0eb9ffd..6a67843b 100644 --- a/example/GW150914_PV2.py +++ b/example/GW150914_PV2.py @@ -31,6 +31,8 @@ ifos = [H1, L1] +f_ref = 20.0 + for ifo in ifos: ifo.load_data(gps, start_pad, end_pad, fmin, fmax, psd_pad=16, tukey_alpha=0.2) @@ -89,13 +91,13 @@ ] likelihood_transforms = [ - SpinToCartesianSpinTransform(name_mapping=[["theta_jn", "phi_jl", "theta_1", "theta_2", "phi_12", "a_1", "a_2"], ["iota", "s1_x", "s1_y", "s1_z", "s2_x", "s2_y", "s2_z"]], freq_ref=20.0), + SpinToCartesianSpinTransform(name_mapping=[["theta_jn", "phi_jl", "theta_1", "theta_2", "phi_12", "a_1", "a_2"], ["iota", "s1_x", "s1_y", "s1_z", "s2_x", "s2_y", "s2_z"]], freq_ref=f_ref), MassRatioToSymmetricMassRatioTransform(name_mapping=[["q"], ["eta"]]), ] likelihood = TransientLikelihoodFD( ifos, - waveform=RippleIMRPhenomPv2(), + waveform=RippleIMRPhenomPv2(f_ref=f_ref), trigger_time=gps, duration=4, post_trigger_duration=2, diff --git a/test/integration/test_GW150914_Pv2.py b/test/integration/test_GW150914_Pv2.py index e8b939ca..d1ee0c37 100644 --- a/test/integration/test_GW150914_Pv2.py +++ b/test/integration/test_GW150914_Pv2.py @@ -29,6 +29,8 @@ fmin = 20.0 fmax = 1024.0 +f_ref = 20.0 + ifos = [H1, L1] for ifo in ifos: @@ -89,13 +91,13 @@ ] likelihood_transforms = [ - SpinToCartesianSpinTransform(name_mapping=[["theta_jn", "phi_jl", "theta_1", "theta_2", "phi_12", "a_1", "a_2"], ["iota", "s1_x", "s1_y", "s1_z", "s2_x", "s2_y", "s2_z"]], freq_ref=20.0), + SpinToCartesianSpinTransform(name_mapping=[["theta_jn", "phi_jl", "theta_1", "theta_2", "phi_12", "a_1", "a_2"], ["iota", "s1_x", "s1_y", "s1_z", "s2_x", "s2_y", "s2_z"]], freq_ref=f_ref), MassRatioToSymmetricMassRatioTransform(name_mapping=[["q"], ["eta"]]), ] likelihood = TransientLikelihoodFD( ifos, - waveform=RippleIMRPhenomPv2(), + waveform=RippleIMRPhenomPv2(f_ref=f_ref), trigger_time=gps, duration=4, post_trigger_duration=2, From aaded8eeba460c70d3fc07dbde67b0347cf7f686 Mon Sep 17 00:00:00 2001 From: xuyuon <116078673+xuyuon@users.noreply.github.com> Date: Mon, 19 Aug 2024 11:03:46 +0800 Subject: [PATCH 03/20] Updated GW150914.py --- example/GW150914.py | 144 ++++++++++++++++++++++++-------------------- 1 file changed, 79 insertions(+), 65 deletions(-) diff --git a/example/GW150914.py b/example/GW150914.py index 559b5b7c..62343b2a 100644 --- a/example/GW150914.py +++ b/example/GW150914.py @@ -1,13 +1,14 @@ -import time - import jax import jax.numpy as jnp from jimgw.jim import Jim -from jimgw.prior import Composite, Unconstrained_Uniform +from jimgw.prior import CombinePrior, UniformPrior, CosinePrior, SinePrior, PowerLawPrior from jimgw.single_event.detector import H1, L1 from jimgw.single_event.likelihood import TransientLikelihoodFD from jimgw.single_event.waveform import RippleIMRPhenomD +from jimgw.transforms import BoundToUnbound +from jimgw.single_event.transforms import ComponentMassesToChirpMassSymmetricMassRatioTransform, SkyFrameToDetectorFrameSkyPositionTransform, ComponentMassesToChirpMassMassRatioTransform +from jimgw.single_event.utils import Mc_q_to_m1_m2 from flowMC.strategy.optimization import optimization_Adam jax.config.update("jax_enable_x64", True) @@ -16,8 +17,6 @@ ########## First we grab data ############# ########################################### -total_time_start = time.time() - # first, fetch a 4s segment centered on GW150914 gps = 1126259462.4 duration = 4 @@ -27,69 +26,63 @@ fmin = 20.0 fmax = 1024.0 -ifos = ["H1", "L1"] - -H1.load_data(gps, start_pad, end_pad, fmin, fmax, psd_pad=16, tukey_alpha=0.2) -L1.load_data(gps, start_pad, end_pad, fmin, fmax, psd_pad=16, tukey_alpha=0.2) - -Mc_prior = Unconstrained_Uniform(10.0, 80.0, naming=["M_c"]) -q_prior = Unconstrained_Uniform( - 0.125, - 1.0, - naming=["q"], - transforms={"q": ("eta", lambda params: params["q"] / (1 + params["q"]) ** 2)}, -) -s1z_prior = Unconstrained_Uniform(-1.0, 1.0, naming=["s1_z"]) -s2z_prior = Unconstrained_Uniform(-1.0, 1.0, naming=["s2_z"]) -dL_prior = Unconstrained_Uniform(0.0, 2000.0, naming=["d_L"]) -t_c_prior = Unconstrained_Uniform(-0.05, 0.05, naming=["t_c"]) -phase_c_prior = Unconstrained_Uniform(0.0, 2 * jnp.pi, naming=["phase_c"]) -cos_iota_prior = Unconstrained_Uniform( - -1.0, - 1.0, - naming=["cos_iota"], - transforms={ - "cos_iota": ( - "iota", - lambda params: jnp.arccos( - jnp.arcsin(jnp.sin(params["cos_iota"] / 2 * jnp.pi)) * 2 / jnp.pi - ), - ) - }, -) -psi_prior = Unconstrained_Uniform(0.0, jnp.pi, naming=["psi"]) -ra_prior = Unconstrained_Uniform(0.0, 2 * jnp.pi, naming=["ra"]) -sin_dec_prior = Unconstrained_Uniform( - -1.0, - 1.0, - naming=["sin_dec"], - transforms={ - "sin_dec": ( - "dec", - lambda params: jnp.arcsin( - jnp.arcsin(jnp.sin(params["sin_dec"] / 2 * jnp.pi)) * 2 / jnp.pi - ), - ) - }, -) - -prior = Composite( +ifos = [H1, L1] + +for ifo in ifos: + ifo.load_data(gps, start_pad, end_pad, fmin, fmax, psd_pad=16, tukey_alpha=0.2) + +M_c_min, M_c_max = 10.0, 80.0 +q_min, q_max = 0.125, 1.0 +m_1_prior = UniformPrior(Mc_q_to_m1_m2(M_c_min, q_max)[0], Mc_q_to_m1_m2(M_c_max, q_min)[0], parameter_names=["m_1"]) +m_2_prior = UniformPrior(Mc_q_to_m1_m2(M_c_min, q_min)[1], Mc_q_to_m1_m2(M_c_max, q_max)[1], parameter_names=["m_2"]) +s1z_prior = UniformPrior(-1.0, 1.0, parameter_names=["s1_z"]) +s2z_prior = UniformPrior(-1.0, 1.0, parameter_names=["s2_z"]) +dL_prior = PowerLawPrior(1.0, 2000.0, 2.0, parameter_names=["d_L"]) +t_c_prior = UniformPrior(-0.05, 0.05, parameter_names=["t_c"]) +phase_c_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["phase_c"]) +iota_prior = SinePrior(parameter_names=["iota"]) +psi_prior = UniformPrior(0.0, jnp.pi, parameter_names=["psi"]) +ra_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["ra"]) +dec_prior = CosinePrior(parameter_names=["dec"]) + +prior = CombinePrior( [ - Mc_prior, - q_prior, + m_1_prior, + m_2_prior, s1z_prior, s2z_prior, dL_prior, t_c_prior, phase_c_prior, - cos_iota_prior, + iota_prior, psi_prior, ra_prior, - sin_dec_prior, + dec_prior, ] ) + +sample_transforms = [ + ComponentMassesToChirpMassMassRatioTransform(name_mapping=[["m_1", "m_2"], ["M_c", "q"]]), + BoundToUnbound(name_mapping = [["M_c"], ["M_c_unbounded"]], original_lower_bound=M_c_min, original_upper_bound=M_c_max), + BoundToUnbound(name_mapping = [["q"], ["q_unbounded"]], original_lower_bound=q_min, original_upper_bound=q_max), + BoundToUnbound(name_mapping = [["s1_z"], ["s1_z_unbounded"]] , original_lower_bound=-1.0, original_upper_bound=1.0), + BoundToUnbound(name_mapping = [["s2_z"], ["s2_z_unbounded"]] , original_lower_bound=-1.0, original_upper_bound=1.0), + BoundToUnbound(name_mapping = [["d_L"], ["d_L_unbounded"]] , original_lower_bound=0.0, original_upper_bound=2000.0), + BoundToUnbound(name_mapping = [["t_c"], ["t_c_unbounded"]] , original_lower_bound=-0.05, original_upper_bound=0.05), + BoundToUnbound(name_mapping = [["phase_c"], ["phase_c_unbounded"]] , original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), + BoundToUnbound(name_mapping = [["iota"], ["iota_unbounded"]], original_lower_bound=0., original_upper_bound=jnp.pi), + BoundToUnbound(name_mapping = [["psi"], ["psi_unbounded"]], original_lower_bound=0.0, original_upper_bound=jnp.pi), + SkyFrameToDetectorFrameSkyPositionTransform(name_mapping = [["ra", "dec"], ["zenith", "azimuth"]], gps_time=gps, ifos=ifos), + BoundToUnbound(name_mapping = [["zenith"], ["zenith_unbounded"]], original_lower_bound=0.0, original_upper_bound=jnp.pi), + BoundToUnbound(name_mapping = [["azimuth"], ["azimuth_unbounded"]], original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), +] + +likelihood_transforms = [ + ComponentMassesToChirpMassSymmetricMassRatioTransform(name_mapping=[["m_1", "m_2"], ["M_c", "eta"]]), +] + likelihood = TransientLikelihoodFD( - [H1, L1], + ifos, waveform=RippleIMRPhenomD(), trigger_time=gps, duration=4, @@ -104,19 +97,16 @@ Adam_optimizer = optimization_Adam(n_steps=3000, learning_rate=0.01, noise_level=1) -import optax -n_epochs = 20 +n_epochs = 30 n_loop_training = 100 -total_epochs = n_epochs * n_loop_training -start = total_epochs//10 -learning_rate = optax.polynomial_schedule( - 1e-3, 1e-4, 4.0, total_epochs - start, transition_begin=start -) +learning_rate = 1e-4 jim = Jim( likelihood, prior, + sample_transforms=sample_transforms, + likelihood_transforms=likelihood_transforms, n_loop_training=n_loop_training, n_loop_production=20, n_local_steps=10, @@ -132,7 +122,31 @@ train_thinning=1, output_thinning=10, local_sampler_arg=local_sampler_arg, - strategies=[Adam_optimizer,"default"], + strategies=[Adam_optimizer, "default"], ) jim.sample(jax.random.PRNGKey(42)) +jim.get_samples() +jim.print_summary() + + +########################################### +########## Visualize the Data ############# +########################################### +import corner +import matplotlib.pyplot as plt +import numpy as np + +production_summary = jim.sampler.get_sampler_state(training=False) +production_chain = production_summary["chains"].reshape(-1, len(jim.parameter_names)).T +if jim.sample_transforms: + transformed_chain = jim.add_name(production_chain) + for transform in jim.sample_transforms: + transformed_chain = transform.backward(transformed_chain) +result = transformed_chain +labels = list(transformed_chain.keys()) + +samples = np.array(list(result.values())).reshape(int(len(labels)), -1) # flatten the array +transposed_array = samples.T # transpose the array +figure = corner.corner(transposed_array, labels=labels, plot_datapoints=False, title_quantiles=[0.16, 0.5, 0.84], show_titles=True, title_fmt='g', use_math_text=True) +plt.savefig("GW1500914_D.jpeg") From 316e7c89c5ca91f7e44bd8276dd297e8c862273b Mon Sep 17 00:00:00 2001 From: xuyuon <116078673+xuyuon@users.noreply.github.com> Date: Mon, 19 Aug 2024 11:07:16 +0800 Subject: [PATCH 04/20] Updated GW150914_heterodyne.py --- example/GW150914_heterodyne.py | 177 ++++++++++++++++----------------- 1 file changed, 88 insertions(+), 89 deletions(-) diff --git a/example/GW150914_heterodyne.py b/example/GW150914_heterodyne.py index c1faed03..bf92f9fd 100644 --- a/example/GW150914_heterodyne.py +++ b/example/GW150914_heterodyne.py @@ -1,16 +1,14 @@ -import time - import jax import jax.numpy as jnp from jimgw.jim import Jim -from jimgw.prior import Composite, Unconstrained_Uniform +from jimgw.prior import CombinePrior, UniformPrior, CosinePrior, SinePrior, PowerLawPrior from jimgw.single_event.detector import H1, L1 -from jimgw.single_event.likelihood import ( - HeterodynedTransientLikelihoodFD, - TransientLikelihoodFD, -) +from jimgw.single_event.likelihood import HeterodynedTransientLikelihoodFD from jimgw.single_event.waveform import RippleIMRPhenomD +from jimgw.transforms import BoundToUnbound +from jimgw.single_event.transforms import ComponentMassesToChirpMassSymmetricMassRatioTransform, SkyFrameToDetectorFrameSkyPositionTransform, ComponentMassesToChirpMassMassRatioTransform +from jimgw.single_event.utils import Mc_q_to_m1_m2 from flowMC.strategy.optimization import optimization_Adam jax.config.update("jax_enable_x64", True) @@ -19,8 +17,6 @@ ########## First we grab data ############# ########################################### -total_time_start = time.time() - # first, fetch a 4s segment centered on GW150914 gps = 1126259462.4 duration = 4 @@ -30,113 +26,92 @@ fmin = 20.0 fmax = 1024.0 -ifos = ["H1", "L1"] +ifos = [H1, L1] -H1.load_data(gps, start_pad, end_pad, fmin, fmax, psd_pad=16, tukey_alpha=0.2) -L1.load_data(gps, start_pad, end_pad, fmin, fmax, psd_pad=16, tukey_alpha=0.2) +for ifo in ifos: + ifo.load_data(gps, start_pad, end_pad, fmin, fmax, psd_pad=16, tukey_alpha=0.2) -Mc_prior = Unconstrained_Uniform(10.0, 80.0, naming=["M_c"]) -q_prior = Unconstrained_Uniform( - 0.125, - 1.0, - naming=["q"], - transforms={"q": ("eta", lambda params: params["q"] / (1 + params["q"]) ** 2)}, -) -s1z_prior = Unconstrained_Uniform(-1.0, 1.0, naming=["s1_z"]) -s2z_prior = Unconstrained_Uniform(-1.0, 1.0, naming=["s2_z"]) -dL_prior = Unconstrained_Uniform(0.0, 2000.0, naming=["d_L"]) -t_c_prior = Unconstrained_Uniform(-0.05, 0.05, naming=["t_c"]) -phase_c_prior = Unconstrained_Uniform(0.0, 2 * jnp.pi, naming=["phase_c"]) -cos_iota_prior = Unconstrained_Uniform( - -1.0, - 1.0, - naming=["cos_iota"], - transforms={ - "cos_iota": ( - "iota", - lambda params: jnp.arccos( - jnp.arcsin(jnp.sin(params["cos_iota"] / 2 * jnp.pi)) * 2 / jnp.pi - ), - ) - }, -) -psi_prior = Unconstrained_Uniform(0.0, jnp.pi, naming=["psi"]) -ra_prior = Unconstrained_Uniform(0.0, 2 * jnp.pi, naming=["ra"]) -sin_dec_prior = Unconstrained_Uniform( - -1.0, - 1.0, - naming=["sin_dec"], - transforms={ - "sin_dec": ( - "dec", - lambda params: jnp.arcsin( - jnp.arcsin(jnp.sin(params["sin_dec"] / 2 * jnp.pi)) * 2 / jnp.pi - ), - ) - }, -) +M_c_min, M_c_max = 10.0, 80.0 +q_min, q_max = 0.125, 1.0 +m_1_prior = UniformPrior(Mc_q_to_m1_m2(M_c_min, q_max)[0], Mc_q_to_m1_m2(M_c_max, q_min)[0], parameter_names=["m_1"]) +m_2_prior = UniformPrior(Mc_q_to_m1_m2(M_c_min, q_min)[1], Mc_q_to_m1_m2(M_c_max, q_max)[1], parameter_names=["m_2"]) +s1z_prior = UniformPrior(-1.0, 1.0, parameter_names=["s1_z"]) +s2z_prior = UniformPrior(-1.0, 1.0, parameter_names=["s2_z"]) +dL_prior = PowerLawPrior(1.0, 2000.0, 2.0, parameter_names=["d_L"]) +t_c_prior = UniformPrior(-0.05, 0.05, parameter_names=["t_c"]) +phase_c_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["phase_c"]) +iota_prior = SinePrior(parameter_names=["iota"]) +psi_prior = UniformPrior(0.0, jnp.pi, parameter_names=["psi"]) +ra_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["ra"]) +dec_prior = CosinePrior(parameter_names=["dec"]) -prior = Composite( +prior = CombinePrior( [ - Mc_prior, - q_prior, + m_1_prior, + m_2_prior, s1z_prior, s2z_prior, dL_prior, t_c_prior, phase_c_prior, - cos_iota_prior, + iota_prior, psi_prior, ra_prior, - sin_dec_prior, + dec_prior, ] ) -bounds = jnp.array( - [ - [10.0, 80.0], - [0.125, 1.0], - [-1.0, 1.0], - [-1.0, 1.0], - [0.0, 2000.0], - [-0.05, 0.05], - [0.0, 2 * jnp.pi], - [-1.0, 1.0], - [0.0, jnp.pi], - [0.0, 2 * jnp.pi], - [-1.0, 1.0], - ] -) +sample_transforms = [ + ComponentMassesToChirpMassMassRatioTransform(name_mapping=[["m_1", "m_2"], ["M_c", "q"]]), + BoundToUnbound(name_mapping = [["M_c"], ["M_c_unbounded"]], original_lower_bound=M_c_min, original_upper_bound=M_c_max), + BoundToUnbound(name_mapping = [["q"], ["q_unbounded"]], original_lower_bound=q_min, original_upper_bound=q_max), + BoundToUnbound(name_mapping = [["s1_z"], ["s1_z_unbounded"]] , original_lower_bound=-1.0, original_upper_bound=1.0), + BoundToUnbound(name_mapping = [["s2_z"], ["s2_z_unbounded"]] , original_lower_bound=-1.0, original_upper_bound=1.0), + BoundToUnbound(name_mapping = [["d_L"], ["d_L_unbounded"]] , original_lower_bound=0.0, original_upper_bound=2000.0), + BoundToUnbound(name_mapping = [["t_c"], ["t_c_unbounded"]] , original_lower_bound=-0.05, original_upper_bound=0.05), + BoundToUnbound(name_mapping = [["phase_c"], ["phase_c_unbounded"]] , original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), + BoundToUnbound(name_mapping = [["iota"], ["iota_unbounded"]], original_lower_bound=0., original_upper_bound=jnp.pi), + BoundToUnbound(name_mapping = [["psi"], ["psi_unbounded"]], original_lower_bound=0.0, original_upper_bound=jnp.pi), + SkyFrameToDetectorFrameSkyPositionTransform(name_mapping = [["ra", "dec"], ["zenith", "azimuth"]], gps_time=gps, ifos=ifos), + BoundToUnbound(name_mapping = [["zenith"], ["zenith_unbounded"]], original_lower_bound=0.0, original_upper_bound=jnp.pi), + BoundToUnbound(name_mapping = [["azimuth"], ["azimuth_unbounded"]], original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), +] + +likelihood_transforms = [ + ComponentMassesToChirpMassSymmetricMassRatioTransform(name_mapping=[["m_1", "m_2"], ["M_c", "eta"]]), +] likelihood = HeterodynedTransientLikelihoodFD( - [H1, L1], + ifos, prior=prior, - bounds=bounds, waveform=RippleIMRPhenomD(), trigger_time=gps, - duration=duration, - post_trigger_duration=post_trigger_duration, - n_steps=3000, + duration=4, + post_trigger_duration=2, + sample_transforms=sample_transforms, + likelihood_transforms=likelihood_transforms, + n_steps=5, + popsize=10, ) + mass_matrix = jnp.eye(11) mass_matrix = mass_matrix.at[1, 1].set(1e-3) mass_matrix = mass_matrix.at[5, 5].set(1e-3) local_sampler_arg = {"step_size": mass_matrix * 3e-3} Adam_optimizer = optimization_Adam(n_steps=3000, learning_rate=0.01, noise_level=1) -import optax -n_epochs = 20 + +n_epochs = 30 n_loop_training = 100 -total_epochs = n_epochs * n_loop_training -start = total_epochs//10 -learning_rate = optax.polynomial_schedule( - 1e-3, 1e-4, 4.0, total_epochs - start, transition_begin=start -) +learning_rate = 1e-4 + jim = Jim( likelihood, prior, + sample_transforms=sample_transforms, + likelihood_transforms=likelihood_transforms, n_loop_training=n_loop_training, n_loop_production=20, n_local_steps=10, @@ -145,14 +120,38 @@ n_epochs=n_epochs, learning_rate=learning_rate, n_max_examples=30000, - n_flow_sample=100000, - momentum=0.9, - batch_size=30000, + n_flow_samples=100000, + momentum=30000, + batch_size=100, use_global=True, - keep_quantile=0.0, train_thinning=1, output_thinning=10, local_sampler_arg=local_sampler_arg, - # strategies=[Adam_optimizer,"default"], + strategies=[Adam_optimizer, "default"], ) + jim.sample(jax.random.PRNGKey(42)) +jim.get_samples() +jim.print_summary() + + +########################################### +########## Visualize the Data ############# +########################################### +import corner +import matplotlib.pyplot as plt +import numpy as np + +production_summary = jim.sampler.get_sampler_state(training=False) +production_chain = production_summary["chains"].reshape(-1, len(jim.parameter_names)).T +if jim.sample_transforms: + transformed_chain = jim.add_name(production_chain) + for transform in jim.sample_transforms: + transformed_chain = transform.backward(transformed_chain) +result = transformed_chain +labels = list(transformed_chain.keys()) + +samples = np.array(list(result.values())).reshape(int(len(labels)), -1) # flatten the array +transposed_array = samples.T # transpose the array +figure = corner.corner(transposed_array, labels=labels, plot_datapoints=False, title_quantiles=[0.16, 0.5, 0.84], show_titles=True, title_fmt='g', use_math_text=True) +plt.savefig("GW1500914_D_heterodyne.jpeg") From d06de0df091d4a9d690e43c6a28e12e2e091a458 Mon Sep 17 00:00:00 2001 From: xuyuon <116078673+xuyuon@users.noreply.github.com> Date: Mon, 19 Aug 2024 12:35:28 +0800 Subject: [PATCH 05/20] Updated test_GW150914_Pv2.py --- test/integration/test_GW150914_Pv2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/integration/test_GW150914_Pv2.py b/test/integration/test_GW150914_Pv2.py index d1ee0c37..392ef131 100644 --- a/test/integration/test_GW150914_Pv2.py +++ b/test/integration/test_GW150914_Pv2.py @@ -82,7 +82,7 @@ BoundToUnbound(name_mapping = [["phi_12"], ["phi_12_unbounded"]] , original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), BoundToUnbound(name_mapping = [["a_1"], ["a_1_unbounded"]] , original_lower_bound=0.0, original_upper_bound=1.0), BoundToUnbound(name_mapping = [["a_2"], ["a_2_unbounded"]] , original_lower_bound=0.0, original_upper_bound=1.0), - BoundToUnbound(name_mapping = [["d_L"], ["d_L_unbounded"]] , original_lower_bound=0.0, original_upper_bound=2000.0), + BoundToUnbound(name_mapping = [["d_L"], ["d_L_unbounded"]] , original_lower_bound=10.0, original_upper_bound=2000.0), BoundToUnbound(name_mapping = [["t_c"], ["t_c_unbounded"]] , original_lower_bound=-0.05, original_upper_bound=0.05), BoundToUnbound(name_mapping = [["phase_c"], ["phase_c_unbounded"]] , original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), BoundToUnbound(name_mapping = [["psi"], ["psi_unbounded"]], original_lower_bound=0.0, original_upper_bound=jnp.pi), From b09ed0c41c6ed9c7737f648e69bd299fb342546c Mon Sep 17 00:00:00 2001 From: xuyuon <116078673+xuyuon@users.noreply.github.com> Date: Mon, 19 Aug 2024 12:40:17 +0800 Subject: [PATCH 06/20] Updated GW150914_Pv2.py --- example/GW150914_PV2.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/example/GW150914_PV2.py b/example/GW150914_PV2.py index 6a67843b..e30e9030 100644 --- a/example/GW150914_PV2.py +++ b/example/GW150914_PV2.py @@ -9,7 +9,8 @@ from jimgw.single_event.likelihood import TransientLikelihoodFD from jimgw.single_event.waveform import RippleIMRPhenomPv2 from jimgw.transforms import BoundToUnbound -from jimgw.single_event.transforms import MassRatioToSymmetricMassRatioTransform, SpinToCartesianSpinTransform +from jimgw.single_event.transforms import MassRatioToSymmetricMassRatioTransform, SpinToCartesianSpinTransform, ComponentMassesToChirpMassMassRatioTransform, SkyFrameToDetectorFrameSkyPositionTransform +from jimgw.single_event.utils import Mc_q_to_m1_m2 from flowMC.strategy.optimization import optimization_Adam jax.config.update("jax_enable_x64", True) @@ -36,8 +37,10 @@ for ifo in ifos: ifo.load_data(gps, start_pad, end_pad, fmin, fmax, psd_pad=16, tukey_alpha=0.2) -Mc_prior = UniformPrior(10.0, 80.0, parameter_names=["M_c"]) -q_prior = UniformPrior(0.125, 1., parameter_names=["q"]) +M_c_min, M_c_max = 10.0, 80.0 +q_min, q_max = 0.125, 1.0 +m_1_prior = UniformPrior(Mc_q_to_m1_m2(M_c_min, q_max)[0], Mc_q_to_m1_m2(M_c_max, q_min)[0], parameter_names=["m_1"]) +m_2_prior = UniformPrior(Mc_q_to_m1_m2(M_c_min, q_min)[1], Mc_q_to_m1_m2(M_c_max, q_max)[1], parameter_names=["m_2"]) theta_jn_prior = SinePrior(parameter_names=["theta_jn"]) phi_jl_prior = UniformPrior(0.0, 2 * jnp.pi, parameter_names=["phi_jl"]) theta_1_prior = SinePrior(parameter_names=["theta_1"]) @@ -54,8 +57,8 @@ prior = CombinePrior( [ - Mc_prior, - q_prior, + m_1_prior, + m_2_prior, theta_jn_prior, phi_jl_prior, theta_1_prior, @@ -73,6 +76,7 @@ ) sample_transforms = [ + ComponentMassesToChirpMassMassRatioTransform(name_mapping=[["m_1", "m_2"], ["M_c", "q"]]), BoundToUnbound(name_mapping = [["M_c"], ["M_c_unbounded"]], original_lower_bound=10.0, original_upper_bound=80.0), BoundToUnbound(name_mapping = [["q"], ["q_unbounded"]], original_lower_bound=0.125, original_upper_bound=1.), BoundToUnbound(name_mapping = [["theta_jn"], ["theta_jn_unbounded"]] , original_lower_bound=0.0, original_upper_bound=jnp.pi), @@ -82,15 +86,17 @@ BoundToUnbound(name_mapping = [["phi_12"], ["phi_12_unbounded"]] , original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), BoundToUnbound(name_mapping = [["a_1"], ["a_1_unbounded"]] , original_lower_bound=0.0, original_upper_bound=1.0), BoundToUnbound(name_mapping = [["a_2"], ["a_2_unbounded"]] , original_lower_bound=0.0, original_upper_bound=1.0), - BoundToUnbound(name_mapping = [["d_L"], ["d_L_unbounded"]] , original_lower_bound=0.0, original_upper_bound=2000.0), + BoundToUnbound(name_mapping = [["d_L"], ["d_L_unbounded"]] , original_lower_bound=10.0, original_upper_bound=2000.0), BoundToUnbound(name_mapping = [["t_c"], ["t_c_unbounded"]] , original_lower_bound=-0.05, original_upper_bound=0.05), BoundToUnbound(name_mapping = [["phase_c"], ["phase_c_unbounded"]] , original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), BoundToUnbound(name_mapping = [["psi"], ["psi_unbounded"]], original_lower_bound=0.0, original_upper_bound=jnp.pi), + SkyFrameToDetectorFrameSkyPositionTransform(name_mapping = [["ra", "dec"], ["zenith", "azimuth"]], gps_time=gps, ifos=ifos), BoundToUnbound(name_mapping = [["ra"], ["ra_unbounded"]], original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), BoundToUnbound(name_mapping = [["dec"], ["dec_unbounded"]],original_lower_bound=-jnp.pi / 2, original_upper_bound=jnp.pi / 2) ] likelihood_transforms = [ + ComponentMassesToChirpMassMassRatioTransform(name_mapping=[["m_1", "m_2"], ["M_c", "q"]]), SpinToCartesianSpinTransform(name_mapping=[["theta_jn", "phi_jl", "theta_1", "theta_2", "phi_12", "a_1", "a_2"], ["iota", "s1_x", "s1_y", "s1_z", "s2_x", "s2_y", "s2_z"]], freq_ref=f_ref), MassRatioToSymmetricMassRatioTransform(name_mapping=[["q"], ["eta"]]), ] @@ -107,7 +113,7 @@ mass_matrix = jnp.eye(15) mass_matrix = mass_matrix.at[1, 1].set(1e-3) mass_matrix = mass_matrix.at[9, 9].set(1e-3) -local_sampler_arg = {"step_size": mass_matrix * 3e-3} +local_sampler_arg = {"step_size": mass_matrix * 1e-3} Adam_optimizer = optimization_Adam(n_steps=3000, learning_rate=0.01, noise_level=1) From f85787c3411a956d0c98e3036bbab3c247aa47d7 Mon Sep 17 00:00:00 2001 From: xuyuon <116078673+xuyuon@users.noreply.github.com> Date: Mon, 19 Aug 2024 12:40:48 +0800 Subject: [PATCH 07/20] Fixed PowerLaw Lowerbound error --- example/GW150914.py | 2 +- example/GW150914_heterodyne.py | 2 +- test/integration/test_GW150914_D.py | 2 +- test/integration/test_GW150914_D_heterodyne.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/example/GW150914.py b/example/GW150914.py index 62343b2a..13ef0c8c 100644 --- a/example/GW150914.py +++ b/example/GW150914.py @@ -67,7 +67,7 @@ BoundToUnbound(name_mapping = [["q"], ["q_unbounded"]], original_lower_bound=q_min, original_upper_bound=q_max), BoundToUnbound(name_mapping = [["s1_z"], ["s1_z_unbounded"]] , original_lower_bound=-1.0, original_upper_bound=1.0), BoundToUnbound(name_mapping = [["s2_z"], ["s2_z_unbounded"]] , original_lower_bound=-1.0, original_upper_bound=1.0), - BoundToUnbound(name_mapping = [["d_L"], ["d_L_unbounded"]] , original_lower_bound=0.0, original_upper_bound=2000.0), + BoundToUnbound(name_mapping = [["d_L"], ["d_L_unbounded"]] , original_lower_bound=1.0, original_upper_bound=2000.0), BoundToUnbound(name_mapping = [["t_c"], ["t_c_unbounded"]] , original_lower_bound=-0.05, original_upper_bound=0.05), BoundToUnbound(name_mapping = [["phase_c"], ["phase_c_unbounded"]] , original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), BoundToUnbound(name_mapping = [["iota"], ["iota_unbounded"]], original_lower_bound=0., original_upper_bound=jnp.pi), diff --git a/example/GW150914_heterodyne.py b/example/GW150914_heterodyne.py index bf92f9fd..8a3e9e66 100644 --- a/example/GW150914_heterodyne.py +++ b/example/GW150914_heterodyne.py @@ -67,7 +67,7 @@ BoundToUnbound(name_mapping = [["q"], ["q_unbounded"]], original_lower_bound=q_min, original_upper_bound=q_max), BoundToUnbound(name_mapping = [["s1_z"], ["s1_z_unbounded"]] , original_lower_bound=-1.0, original_upper_bound=1.0), BoundToUnbound(name_mapping = [["s2_z"], ["s2_z_unbounded"]] , original_lower_bound=-1.0, original_upper_bound=1.0), - BoundToUnbound(name_mapping = [["d_L"], ["d_L_unbounded"]] , original_lower_bound=0.0, original_upper_bound=2000.0), + BoundToUnbound(name_mapping = [["d_L"], ["d_L_unbounded"]] , original_lower_bound=1.0, original_upper_bound=2000.0), BoundToUnbound(name_mapping = [["t_c"], ["t_c_unbounded"]] , original_lower_bound=-0.05, original_upper_bound=0.05), BoundToUnbound(name_mapping = [["phase_c"], ["phase_c_unbounded"]] , original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), BoundToUnbound(name_mapping = [["iota"], ["iota_unbounded"]], original_lower_bound=0., original_upper_bound=jnp.pi), diff --git a/test/integration/test_GW150914_D.py b/test/integration/test_GW150914_D.py index e1eee9ac..24dac718 100644 --- a/test/integration/test_GW150914_D.py +++ b/test/integration/test_GW150914_D.py @@ -67,7 +67,7 @@ BoundToUnbound(name_mapping = [["q"], ["q_unbounded"]], original_lower_bound=q_min, original_upper_bound=q_max), BoundToUnbound(name_mapping = [["s1_z"], ["s1_z_unbounded"]] , original_lower_bound=-1.0, original_upper_bound=1.0), BoundToUnbound(name_mapping = [["s2_z"], ["s2_z_unbounded"]] , original_lower_bound=-1.0, original_upper_bound=1.0), - BoundToUnbound(name_mapping = [["d_L"], ["d_L_unbounded"]] , original_lower_bound=0.0, original_upper_bound=2000.0), + BoundToUnbound(name_mapping = [["d_L"], ["d_L_unbounded"]] , original_lower_bound=1.0, original_upper_bound=2000.0), BoundToUnbound(name_mapping = [["t_c"], ["t_c_unbounded"]] , original_lower_bound=-0.05, original_upper_bound=0.05), BoundToUnbound(name_mapping = [["phase_c"], ["phase_c_unbounded"]] , original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), BoundToUnbound(name_mapping = [["iota"], ["iota_unbounded"]], original_lower_bound=0., original_upper_bound=jnp.pi), diff --git a/test/integration/test_GW150914_D_heterodyne.py b/test/integration/test_GW150914_D_heterodyne.py index bf97efdb..ff5cc57b 100644 --- a/test/integration/test_GW150914_D_heterodyne.py +++ b/test/integration/test_GW150914_D_heterodyne.py @@ -67,7 +67,7 @@ BoundToUnbound(name_mapping = [["q"], ["q_unbounded"]], original_lower_bound=q_min, original_upper_bound=q_max), BoundToUnbound(name_mapping = [["s1_z"], ["s1_z_unbounded"]] , original_lower_bound=-1.0, original_upper_bound=1.0), BoundToUnbound(name_mapping = [["s2_z"], ["s2_z_unbounded"]] , original_lower_bound=-1.0, original_upper_bound=1.0), - BoundToUnbound(name_mapping = [["d_L"], ["d_L_unbounded"]] , original_lower_bound=0.0, original_upper_bound=2000.0), + BoundToUnbound(name_mapping = [["d_L"], ["d_L_unbounded"]] , original_lower_bound=1.0, original_upper_bound=2000.0), BoundToUnbound(name_mapping = [["t_c"], ["t_c_unbounded"]] , original_lower_bound=-0.05, original_upper_bound=0.05), BoundToUnbound(name_mapping = [["phase_c"], ["phase_c_unbounded"]] , original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), BoundToUnbound(name_mapping = [["iota"], ["iota_unbounded"]], original_lower_bound=0., original_upper_bound=jnp.pi), From d34fd529dbb6b58fe50331b6a23732dea211b507 Mon Sep 17 00:00:00 2001 From: xuyuon <116078673+xuyuon@users.noreply.github.com> Date: Mon, 19 Aug 2024 13:00:57 +0800 Subject: [PATCH 08/20] Updated GW150914_Pv2.py --- example/GW150914_PV2.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/example/GW150914_PV2.py b/example/GW150914_PV2.py index e30e9030..63dc0754 100644 --- a/example/GW150914_PV2.py +++ b/example/GW150914_PV2.py @@ -91,8 +91,8 @@ BoundToUnbound(name_mapping = [["phase_c"], ["phase_c_unbounded"]] , original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), BoundToUnbound(name_mapping = [["psi"], ["psi_unbounded"]], original_lower_bound=0.0, original_upper_bound=jnp.pi), SkyFrameToDetectorFrameSkyPositionTransform(name_mapping = [["ra", "dec"], ["zenith", "azimuth"]], gps_time=gps, ifos=ifos), - BoundToUnbound(name_mapping = [["ra"], ["ra_unbounded"]], original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), - BoundToUnbound(name_mapping = [["dec"], ["dec_unbounded"]],original_lower_bound=-jnp.pi / 2, original_upper_bound=jnp.pi / 2) + BoundToUnbound(name_mapping = [["zenith"], ["zenith_unbounded"]], original_lower_bound=0.0, original_upper_bound=jnp.pi), + BoundToUnbound(name_mapping = [["azimuth"], ["azimuth_unbounded"]], original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), ] likelihood_transforms = [ From 250a344b81509a7503b54c26c4076ccff7c6fd13 Mon Sep 17 00:00:00 2001 From: xuyuon <116078673+xuyuon@users.noreply.github.com> Date: Mon, 19 Aug 2024 13:02:06 +0800 Subject: [PATCH 09/20] Rename GW150914.py to GW150914_D.py --- example/{GW150914.py => GW150914_D.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename example/{GW150914.py => GW150914_D.py} (100%) diff --git a/example/GW150914.py b/example/GW150914_D.py similarity index 100% rename from example/GW150914.py rename to example/GW150914_D.py From f3f462da692e2da9bd0a139e97c550853ab41f24 Mon Sep 17 00:00:00 2001 From: xuyuon <116078673+xuyuon@users.noreply.github.com> Date: Mon, 19 Aug 2024 13:02:32 +0800 Subject: [PATCH 10/20] Rename GW150914_PV2.py to GW150914_Pv2.py --- example/{GW150914_PV2.py => GW150914_Pv2.py} | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) rename example/{GW150914_PV2.py => GW150914_Pv2.py} (99%) diff --git a/example/GW150914_PV2.py b/example/GW150914_Pv2.py similarity index 99% rename from example/GW150914_PV2.py rename to example/GW150914_Pv2.py index 63dc0754..2b4c9415 100644 --- a/example/GW150914_PV2.py +++ b/example/GW150914_Pv2.py @@ -174,4 +174,4 @@ ############# Save the Run ################ ########################################### import pickle -pickle.dump(result, open("GW150914_pv2.pkl", "wb"), protocol=pickle.HIGHEST_PROTOCOL) \ No newline at end of file +pickle.dump(result, open("GW150914_pv2.pkl", "wb"), protocol=pickle.HIGHEST_PROTOCOL) From 40339d95f1c3cea0d38b418f9e94a4c6932ecff7 Mon Sep 17 00:00:00 2001 From: xuyuon <116078673+xuyuon@users.noreply.github.com> Date: Mon, 19 Aug 2024 13:02:51 +0800 Subject: [PATCH 11/20] Rename GW150914_heterodyne.py to GW150914_D_heterodyne.py --- example/{GW150914_heterodyne.py => GW150914_D_heterodyne.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename example/{GW150914_heterodyne.py => GW150914_D_heterodyne.py} (100%) diff --git a/example/GW150914_heterodyne.py b/example/GW150914_D_heterodyne.py similarity index 100% rename from example/GW150914_heterodyne.py rename to example/GW150914_D_heterodyne.py From db768ea1c14e393a6e36054477f53ae6fb9daa82 Mon Sep 17 00:00:00 2001 From: xuyuon <116078673+xuyuon@users.noreply.github.com> Date: Mon, 19 Aug 2024 13:04:28 +0800 Subject: [PATCH 12/20] Updated GW150914_Pv2.py --- example/GW150914_Pv2.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/example/GW150914_Pv2.py b/example/GW150914_Pv2.py index 2b4c9415..11d5f5a4 100644 --- a/example/GW150914_Pv2.py +++ b/example/GW150914_Pv2.py @@ -173,5 +173,5 @@ ########################################### ############# Save the Run ################ ########################################### -import pickle -pickle.dump(result, open("GW150914_pv2.pkl", "wb"), protocol=pickle.HIGHEST_PROTOCOL) +# import pickle +# pickle.dump(result, open("GW150914_pv2.pkl", "wb"), protocol=pickle.HIGHEST_PROTOCOL) From b452856d108308c11c3c3604d533f26c22d9b178 Mon Sep 17 00:00:00 2001 From: xuyuon <116078673+xuyuon@users.noreply.github.com> Date: Mon, 19 Aug 2024 13:09:32 +0800 Subject: [PATCH 13/20] Updated GW150914_Pv2.py --- example/GW150914_Pv2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/example/GW150914_Pv2.py b/example/GW150914_Pv2.py index 11d5f5a4..ea990b9a 100644 --- a/example/GW150914_Pv2.py +++ b/example/GW150914_Pv2.py @@ -168,7 +168,7 @@ samples = np.array(list(result.values())).reshape(int(len(labels)), -1) # flatten the array transposed_array = samples.T # transpose the array figure = corner.corner(transposed_array, labels=labels, plot_datapoints=False, title_quantiles=[0.16, 0.5, 0.84], show_titles=True, title_fmt='g', use_math_text=True) -plt.savefig("GW1500914_pv2.jpeg") +plt.savefig("GW1500914_Pv2.jpeg") ########################################### ############# Save the Run ################ From 814c204b93649d6c9dfd8ab1d4f46caa17730037 Mon Sep 17 00:00:00 2001 From: xuyuon <116078673+xuyuon@users.noreply.github.com> Date: Mon, 19 Aug 2024 15:04:16 +0800 Subject: [PATCH 14/20] Fixed chain processing error --- example/GW150914_D.py | 2 +- example/GW150914_D_heterodyne.py | 2 +- example/GW150914_Pv2.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/example/GW150914_D.py b/example/GW150914_D.py index 13ef0c8c..efa58b04 100644 --- a/example/GW150914_D.py +++ b/example/GW150914_D.py @@ -141,7 +141,7 @@ production_chain = production_summary["chains"].reshape(-1, len(jim.parameter_names)).T if jim.sample_transforms: transformed_chain = jim.add_name(production_chain) - for transform in jim.sample_transforms: + for transform in reversed(jim.sample_transforms): transformed_chain = transform.backward(transformed_chain) result = transformed_chain labels = list(transformed_chain.keys()) diff --git a/example/GW150914_D_heterodyne.py b/example/GW150914_D_heterodyne.py index 8a3e9e66..cb89b7e5 100644 --- a/example/GW150914_D_heterodyne.py +++ b/example/GW150914_D_heterodyne.py @@ -146,7 +146,7 @@ production_chain = production_summary["chains"].reshape(-1, len(jim.parameter_names)).T if jim.sample_transforms: transformed_chain = jim.add_name(production_chain) - for transform in jim.sample_transforms: + for transform in reversed(jim.sample_transforms): transformed_chain = transform.backward(transformed_chain) result = transformed_chain labels = list(transformed_chain.keys()) diff --git a/example/GW150914_Pv2.py b/example/GW150914_Pv2.py index ea990b9a..c922c822 100644 --- a/example/GW150914_Pv2.py +++ b/example/GW150914_Pv2.py @@ -160,7 +160,7 @@ production_chain = production_summary["chains"].reshape(-1, len(jim.parameter_names)).T if jim.sample_transforms: transformed_chain = jim.add_name(production_chain) - for transform in jim.sample_transforms: + for transform in reversed(jim.sample_transforms): transformed_chain = transform.backward(transformed_chain) result = transformed_chain labels = list(transformed_chain.keys()) From 645dea1d69f5eba84232b57a41755250e6b86a86 Mon Sep 17 00:00:00 2001 From: xuyuon <116078673+xuyuon@users.noreply.github.com> Date: Mon, 19 Aug 2024 15:15:33 +0800 Subject: [PATCH 15/20] Updated GW150914_D_heterodyne.py --- example/GW150914_D_heterodyne.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/example/GW150914_D_heterodyne.py b/example/GW150914_D_heterodyne.py index cb89b7e5..91e04ba2 100644 --- a/example/GW150914_D_heterodyne.py +++ b/example/GW150914_D_heterodyne.py @@ -121,8 +121,8 @@ learning_rate=learning_rate, n_max_examples=30000, n_flow_samples=100000, - momentum=30000, - batch_size=100, + momentum=0.9, + batch_size=30000, use_global=True, train_thinning=1, output_thinning=10, From 8e014eec785fdf412a03e0a797ed824236c231c3 Mon Sep 17 00:00:00 2001 From: xuyuon <116078673+xuyuon@users.noreply.github.com> Date: Wed, 21 Aug 2024 15:44:26 +0800 Subject: [PATCH 16/20] Updated prior.py --- src/jimgw/transforms.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index 715d49de..ac56a1e1 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -89,7 +89,7 @@ def transform(self, x: dict[str, Float]) -> tuple[dict[str, Float], Float]: output_params = self.transform_func(transform_params) jacobian = jax.jacfwd(self.transform_func)(transform_params) jacobian = jnp.array(jax.tree.leaves(jacobian)) - jacobian = jnp.log(jnp.linalg.det(jacobian.reshape(self.n_dim, self.n_dim))) + jacobian = jnp.log(jnp.absolute(jnp.linalg.det(jacobian.reshape(self.n_dim, self.n_dim)))) jax.tree.map( lambda key: x_copy.pop(key), self.name_mapping[0], @@ -126,7 +126,7 @@ def inverse(self, y: dict[str, Float]) -> tuple[dict[str, Float], Float]: output_params = self.inverse_transform_func(transform_params) jacobian = jax.jacfwd(self.inverse_transform_func)(transform_params) jacobian = jnp.array(jax.tree.leaves(jacobian)) - jacobian = jnp.log(jnp.linalg.det(jacobian.reshape(self.n_dim, self.n_dim))) + jacobian = jnp.log(jnp.absolute(jnp.linalg.det(jacobian.reshape(self.n_dim, self.n_dim)))) jax.tree.map( lambda key: y_copy.pop(key), self.name_mapping[1], From bc8e7dca5e8d5bd37d6555f3db2c08de7532138f Mon Sep 17 00:00:00 2001 From: xuyuon <116078673+xuyuon@users.noreply.github.com> Date: Thu, 22 Aug 2024 02:20:08 +0800 Subject: [PATCH 17/20] Update transforms.py --- src/jimgw/transforms.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/jimgw/transforms.py b/src/jimgw/transforms.py index ac56a1e1..715d49de 100644 --- a/src/jimgw/transforms.py +++ b/src/jimgw/transforms.py @@ -89,7 +89,7 @@ def transform(self, x: dict[str, Float]) -> tuple[dict[str, Float], Float]: output_params = self.transform_func(transform_params) jacobian = jax.jacfwd(self.transform_func)(transform_params) jacobian = jnp.array(jax.tree.leaves(jacobian)) - jacobian = jnp.log(jnp.absolute(jnp.linalg.det(jacobian.reshape(self.n_dim, self.n_dim)))) + jacobian = jnp.log(jnp.linalg.det(jacobian.reshape(self.n_dim, self.n_dim))) jax.tree.map( lambda key: x_copy.pop(key), self.name_mapping[0], @@ -126,7 +126,7 @@ def inverse(self, y: dict[str, Float]) -> tuple[dict[str, Float], Float]: output_params = self.inverse_transform_func(transform_params) jacobian = jax.jacfwd(self.inverse_transform_func)(transform_params) jacobian = jnp.array(jax.tree.leaves(jacobian)) - jacobian = jnp.log(jnp.absolute(jnp.linalg.det(jacobian.reshape(self.n_dim, self.n_dim)))) + jacobian = jnp.log(jnp.linalg.det(jacobian.reshape(self.n_dim, self.n_dim))) jax.tree.map( lambda key: y_copy.pop(key), self.name_mapping[1], From b1133d36f4e6ff20667b30be9d7a91d58419d562 Mon Sep 17 00:00:00 2001 From: xuyuon <116078673+xuyuon@users.noreply.github.com> Date: Thu, 22 Aug 2024 09:07:06 +0800 Subject: [PATCH 18/20] Update runManager.py --- src/jimgw/single_event/runManager.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/src/jimgw/single_event/runManager.py b/src/jimgw/single_event/runManager.py index aa8d0dc7..0a4b502d 100644 --- a/src/jimgw/single_event/runManager.py +++ b/src/jimgw/single_event/runManager.py @@ -71,9 +71,7 @@ class SingleEventRun: str, dict[str, Union[str, float, int, bool]] ] # Transform cannot be included in this way, add it to preset if used often. jim_parameters: dict[str, Union[str, float, int, bool, dict]] - injection_parameters: dict[str, float] = field( - default_factory=lambda: {} - ) + injection_parameters: dict[str, float] injection: bool = False likelihood_parameters: dict[str, Union[str, float, int, bool, PyTree]] = field( default_factory=lambda: {"name": "TransientLikelihoodFD"} @@ -125,9 +123,6 @@ def __init__(self, **kwargs): print("Neither run instance nor path provided.") raise ValueError - if self.run.injection and not self.run.injection_parameters: - raise ValueError("Injection mode requires injection parameters.") - local_prior = self.initialize_prior() local_likelihood = self.initialize_likelihood(local_prior) self.jim = Jim(local_likelihood, local_prior, **self.run.jim_parameters) From 2a9d696a20d28fdd1693698a1840fef9ab276578 Mon Sep 17 00:00:00 2001 From: xuyuon <116078673+xuyuon@users.noreply.github.com> Date: Thu, 22 Aug 2024 09:07:49 +0800 Subject: [PATCH 19/20] Update runManager.py --- src/jimgw/single_event/runManager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/jimgw/single_event/runManager.py b/src/jimgw/single_event/runManager.py index 0a4b502d..3f65166d 100644 --- a/src/jimgw/single_event/runManager.py +++ b/src/jimgw/single_event/runManager.py @@ -71,7 +71,7 @@ class SingleEventRun: str, dict[str, Union[str, float, int, bool]] ] # Transform cannot be included in this way, add it to preset if used often. jim_parameters: dict[str, Union[str, float, int, bool, dict]] - injection_parameters: dict[str, float] + injection_parameters: dict[str, float] injection: bool = False likelihood_parameters: dict[str, Union[str, float, int, bool, PyTree]] = field( default_factory=lambda: {"name": "TransientLikelihoodFD"} From 2adcabf0bbdbee0ee42b163c8ba51a45a4f1b9a6 Mon Sep 17 00:00:00 2001 From: xuyuon <116078673+xuyuon@users.noreply.github.com> Date: Thu, 22 Aug 2024 09:11:25 +0800 Subject: [PATCH 20/20] Updated test_prior.py --- test/unit/test_prior.py | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/test/unit/test_prior.py b/test/unit/test_prior.py index 852ded16..5fbcf3c3 100644 --- a/test/unit/test_prior.py +++ b/test/unit/test_prior.py @@ -43,11 +43,8 @@ def test_sine(self): log_prob = jax.vmap(p.log_prob)(samples) assert jnp.all(jnp.isfinite(log_prob)) # Check that the log_prob is correct in the support - x = trace_prior_parent(p, [])[0].add_name(jnp.linspace(-10.0, 10.0, 1000)[None]) - y = jax.vmap(p.base_prior.base_prior.transform)(x) - y = jax.vmap(p.base_prior.transform)(y) - y = jax.vmap(p.transform)(y) - assert jnp.allclose(jax.vmap(p.log_prob)(y), jnp.log(jnp.sin(y['x'])/2.0)) + samples = samples['x'] + assert jnp.allclose(log_prob, jnp.log(jnp.sin(samples)/2.0)) def test_cosine(self): p = CosinePrior(["x"]) @@ -57,11 +54,8 @@ def test_cosine(self): # Check that the log_prob is finite log_prob = jax.vmap(p.log_prob)(samples) assert jnp.all(jnp.isfinite(log_prob)) - # Check that the log_prob is correct in the support - x = trace_prior_parent(p, [])[0].add_name(jnp.linspace(-10.0, 10.0, 1000)[None]) - y = jax.vmap(p.base_prior.transform)(x) - y = jax.vmap(p.transform)(y) - assert jnp.allclose(jax.vmap(p.log_prob)(y), jnp.log(jnp.cos(y['x'])/2.0)) + samples = samples['x'] + assert jnp.allclose(log_prob, jnp.log(jnp.cos(samples)/2.0)) def test_uniform_sphere(self): p = UniformSpherePrior(["x"])