diff --git a/example/GW150914_IMRPhenomPV2.py b/example/GW150914_IMRPhenomPV2.py index 7cf79424..08f48dac 100644 --- a/example/GW150914_IMRPhenomPV2.py +++ b/example/GW150914_IMRPhenomPV2.py @@ -1,3 +1,4 @@ +import optax import time import jax @@ -57,7 +58,7 @@ for ifo in ifos: data = jd.Data.from_gwosc(ifo.name, start, end) ifo.set_data(data) - + psd_data = jd.Data.from_gwosc(ifo.name, psd_start, psd_end) psd_fftlength = data.duration * data.sampling_frequency ifo.set_psd(psd_data.to_psd(nperseg=psd_fftlength)) @@ -111,23 +112,39 @@ # Defining Transforms sample_transforms = [ - DistanceToSNRWeightedDistanceTransform(gps_time=gps, ifos=ifos, dL_min=dL_prior.xmin, dL_max=dL_prior.xmax), - GeocentricArrivalPhaseToDetectorArrivalPhaseTransform(gps_time=gps, ifo=ifos[0]), - GeocentricArrivalTimeToDetectorArrivalTimeTransform(tc_min=t_c_prior.xmin, tc_max=t_c_prior.xmax, gps_time=gps, ifo=ifos[0]), + DistanceToSNRWeightedDistanceTransform( + gps_time=gps, ifos=ifos, dL_min=dL_prior.xmin, dL_max=dL_prior.xmax), + GeocentricArrivalPhaseToDetectorArrivalPhaseTransform( + gps_time=gps, ifo=ifos[0]), + GeocentricArrivalTimeToDetectorArrivalTimeTransform( + tc_min=t_c_prior.xmin, tc_max=t_c_prior.xmax, gps_time=gps, ifo=ifos[0]), SkyFrameToDetectorFrameSkyPositionTransform(gps_time=gps, ifos=ifos), - BoundToUnbound(name_mapping = (["M_c"], ["M_c_unbounded"]), original_lower_bound=M_c_min, original_upper_bound=M_c_max), - BoundToUnbound(name_mapping = (["q"], ["q_unbounded"]), original_lower_bound=q_min, original_upper_bound=q_max), - BoundToUnbound(name_mapping = (["s1_phi"], ["s1_phi_unbounded"]) , original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), - BoundToUnbound(name_mapping = (["s2_phi"], ["s2_phi_unbounded"]) , original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), - BoundToUnbound(name_mapping = (["iota"], ["iota_unbounded"]) , original_lower_bound=0.0, original_upper_bound=jnp.pi), - BoundToUnbound(name_mapping = (["s1_theta"], ["s1_theta_unbounded"]) , original_lower_bound=0.0, original_upper_bound=jnp.pi), - BoundToUnbound(name_mapping = (["s2_theta"], ["s2_theta_unbounded"]) , original_lower_bound=0.0, original_upper_bound=jnp.pi), - BoundToUnbound(name_mapping = (["s1_mag"], ["s1_mag_unbounded"]) , original_lower_bound=0.0, original_upper_bound=0.99), - BoundToUnbound(name_mapping = (["s2_mag"], ["s2_mag_unbounded"]) , original_lower_bound=0.0, original_upper_bound=0.99), - BoundToUnbound(name_mapping = (["phase_det"], ["phase_det_unbounded"]), original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), - BoundToUnbound(name_mapping = (["psi"], ["psi_unbounded"]), original_lower_bound=0.0, original_upper_bound=jnp.pi), - BoundToUnbound(name_mapping = (["zenith"], ["zenith_unbounded"]), original_lower_bound=0.0, original_upper_bound=jnp.pi), - BoundToUnbound(name_mapping = (["azimuth"], ["azimuth_unbounded"]), original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), + BoundToUnbound(name_mapping=(["M_c"], [ + "M_c_unbounded"]), original_lower_bound=M_c_min, original_upper_bound=M_c_max), + BoundToUnbound(name_mapping=(["q"], ["q_unbounded"]), + original_lower_bound=q_min, original_upper_bound=q_max), + BoundToUnbound(name_mapping=(["s1_phi"], [ + "s1_phi_unbounded"]), original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), + BoundToUnbound(name_mapping=(["s2_phi"], [ + "s2_phi_unbounded"]), original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), + BoundToUnbound(name_mapping=(["iota"], ["iota_unbounded"]), + original_lower_bound=0.0, original_upper_bound=jnp.pi), + BoundToUnbound(name_mapping=(["s1_theta"], [ + "s1_theta_unbounded"]), original_lower_bound=0.0, original_upper_bound=jnp.pi), + BoundToUnbound(name_mapping=(["s2_theta"], [ + "s2_theta_unbounded"]), original_lower_bound=0.0, original_upper_bound=jnp.pi), + BoundToUnbound(name_mapping=(["s1_mag"], [ + "s1_mag_unbounded"]), original_lower_bound=0.0, original_upper_bound=0.99), + BoundToUnbound(name_mapping=(["s2_mag"], [ + "s2_mag_unbounded"]), original_lower_bound=0.0, original_upper_bound=0.99), + BoundToUnbound(name_mapping=(["phase_det"], [ + "phase_det_unbounded"]), original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), + BoundToUnbound(name_mapping=(["psi"], ["psi_unbounded"]), + original_lower_bound=0.0, original_upper_bound=jnp.pi), + BoundToUnbound(name_mapping=(["zenith"], [ + "zenith_unbounded"]), original_lower_bound=0.0, original_upper_bound=jnp.pi), + BoundToUnbound(name_mapping=(["azimuth"], [ + "azimuth_unbounded"]), original_lower_bound=0.0, original_upper_bound=2 * jnp.pi), ] likelihood_transforms = [ @@ -147,9 +164,9 @@ # mass_matrix = mass_matrix.at[9, 9].set(1e-3) local_sampler_arg = {"step_size": mass_matrix * 1e-3} -Adam_optimizer = optimization_Adam(n_steps=3000, learning_rate=0.01, noise_level=1) +Adam_optimizer = optimization_Adam( + n_steps=3000, learning_rate=0.01, noise_level=1) -import optax n_epochs = 20 n_loop_training = 100