From e032436eee6901f458d409db8c65b152300a2e9a Mon Sep 17 00:00:00 2001
From: Max Isi
Date: Fri, 18 Oct 2024 16:14:32 -0400
Subject: [PATCH] likelihood api

Slice the detector data and PSD to the requested frequency band once in
the likelihood constructor and pass the slices explicitly to the
likelihood functions instead of reading them from the detectors.

Update the GW150914 example to fetch the data and estimate the PSD via
jimgw.single_event.data, set Data.name at the start of Data.__init__,
and route several status messages through logging instead of print.

---
 example/GW150914_IMRPhenomPV2.py     |  21 +++++-
 src/jimgw/single_event/data.py       |   2 +-
 src/jimgw/single_event/likelihood.py | 101 +++++++++++++++++----------
 3 files changed, 82 insertions(+), 42 deletions(-)

diff --git a/example/GW150914_IMRPhenomPV2.py b/example/GW150914_IMRPhenomPV2.py
index 11936abc..7cf79424 100644
--- a/example/GW150914_IMRPhenomPV2.py
+++ b/example/GW150914_IMRPhenomPV2.py
@@ -26,6 +26,7 @@
 )
 from jimgw.single_event.utils import Mc_q_to_m1_m2
 from flowMC.strategy.optimization import optimization_Adam
+from jimgw.single_event import data as jd
 
 jax.config.update("jax_enable_x64", True)
 
@@ -36,16 +37,30 @@
 total_time_start = time.time()
 
 # first, fetch a 4s segment centered on GW150914
+# for the analysis
 gps = 1126259462.4
 start = gps - 2
 end = gps + 2
+
+# fetch 4096s of data to estimate the PSD (to be
+# careful we should avoid the on-source segment,
+# but we don't do this in this example)
+psd_start = gps - 2048
+psd_end = gps + 2048
+
+# define frequency integration bounds for the likelihood
 fmin = 20.0
-fmax = 1024.0
+fmax = 1000.0
 
 ifos = [H1, L1]
 
-H1.load_data(gps, 2, 2, fmin, fmax, psd_pad=16, tukey_alpha=0.2)
-L1.load_data(gps, 2, 2, fmin, fmax, psd_pad=16, tukey_alpha=0.2)
+for ifo in ifos:
+    data = jd.Data.from_gwosc(ifo.name, start, end)
+    ifo.set_data(data)
+
+    psd_data = jd.Data.from_gwosc(ifo.name, psd_start, psd_end)
+    psd_fftlength = data.duration * data.sampling_frequency
+    ifo.set_psd(psd_data.to_psd(nperseg=psd_fftlength))
 
 waveform = RippleIMRPhenomPv2(f_ref=20)
 
diff --git a/src/jimgw/single_event/data.py b/src/jimgw/single_event/data.py
index af9deda8..bfe6b4d9 100644
--- a/src/jimgw/single_event/data.py
+++ b/src/jimgw/single_event/data.py
@@ -98,6 +98,7 @@ def __init__(self, td: Float[Array, " n_time"] = jnp.array([]),
         window: array, optional
             Window function to apply to the data before FFT (default: None)
         """
+        self.name = name or ''
         self.td = td
         self.fd = jnp.zeros(self.n_freq)
         self.delta_t = delta_t
@@ -106,7 +107,6 @@ def __init__(self, td: Float[Array, " n_time"] = jnp.array([]),
             self.set_tukey_window()
         else:
             self.window = window
-        self.name = name or ''
 
     def __repr__(self):
         return f"{self.__class__.__name__}(name='{self.name}', " + \
diff --git a/src/jimgw/single_event/likelihood.py b/src/jimgw/single_event/likelihood.py
index fe280f1b..47bd591a 100644
--- a/src/jimgw/single_event/likelihood.py
+++ b/src/jimgw/single_event/likelihood.py
@@ -41,28 +41,42 @@ def __init__(
         self,
         detectors: list[Detector],
         waveform: Waveform,
-        f_min: float = 0,
-        f_max: float = float("inf"),
-        trigger_time: float = 0,
-        post_trigger_duration: float = 2,
+        f_min: Float = 0,
+        f_max: Float = float("inf"),
+        trigger_time: Float = 0,
+        post_trigger_duration: Float = 2,
         **kwargs,
     ) -> None:
         self.detectors = detectors
 
-        # TODO: we can probably make this a bit more elegant
+        # make sure data has a Fourier representation
         for det in detectors:
             if not det.data.has_fd:
                 logging.info("Computing FFT with default window")
                 det.data.fft()
-            det.set_frequency_bounds(f_min, f_max)
-
-        freqs = [d.data.frequency_slice(f_min, f_max)[1] for d in detectors]
-        assert all([
-            (freqs[0]
-             == freq).all()  # noqa: W503
-            for freq in freqs]
-        ), "The detectors must have the same frequency grid"
+
+        # collect the data, psd and frequencies for the requested band
+        freqs = []
+        datas = []
+        psds = []
+        for detector in detectors:
+            data, freq_0 = detector.data.frequency_slice(f_min, f_max)
+            psd, freq_1 = detector.psd.frequency_slice(f_min, f_max)
+            freqs.append(freq_0)
+            datas.append(data)
+            psds.append(psd)
+            # make sure the psd and data are consistent
+            assert (freq_0 == freq_1).all(), \
+                f"The {detector.name} data and PSD must have same frequencies"
+
+        # make sure all detectors are consistent
+        assert all([(freqs[0] == freq).all() for freq in freqs]), \
+            "The detectors must have the same frequency grid"
+
         self.frequencies = freqs[0]  # type: ignore
+        self.datas = datas  # band-limited data, already sliced above
+        self.psds = psds  # band-limited PSDs, already sliced above
+
         self.waveform = waveform
         self.trigger_time = trigger_time
         self.gmst = (
@@ -85,15 +99,15 @@ def __init__(
             if self.marginalization == "phase-time":
                 self.param_func = lambda x: {**x, "phase_c": 0.0, "t_c": 0.0}
                 self.likelihood_function = phase_time_marginalized_likelihood
-                print("Marginalizing over phase and time")
+                logging.info("Marginalizing over phase and time")
             elif self.marginalization == "time":
                 self.param_func = lambda x: {**x, "t_c": 0.0}
                 self.likelihood_function = time_marginalized_likelihood
-                print("Marginalizing over time")
+                logging.info("Marginalizing over time")
             elif self.marginalization == "phase":
                 self.param_func = lambda x: {**x, "phase_c": 0.0}
                 self.likelihood_function = phase_marginalized_likelihood
-                print("Marginalizing over phase")
+                logging.info("Marginalizing over phase")
 
             if "time" in self.marginalization:
                 fs = kwargs["sampling_rate"]
@@ -136,22 +150,19 @@ def __init__(
 
     @property
     def epoch(self):
-        """
-        The epoch of the data.
+        """The epoch of the data.
         """
         return self.duration - self.post_trigger_duration
 
     @property
     def ifos(self):
-        """
-        The interferometers for the likelihood.
+        """The interferometers for the likelihood.
         """
         return [detector.name for detector in self.detectors]
 
     def evaluate(self, params: dict[str, Float], data: dict) -> Float:
         # TODO: Test whether we need to pass data in or with class changes is fine.
-        """
-        Evaluate the likelihood for a given set of parameters.
+        """Evaluate the likelihood for a given set of parameters.
         """
         frequencies = self.frequencies
         params["gmst"] = self.gmst
@@ -169,6 +180,8 @@ def evaluate(self, params: dict[str, Float], data: dict) -> Float:
             waveform_sky,
             self.detectors,
             frequencies,
+            self.datas,
+            self.psds,
             align_time,
             **self.kwargs,
         )
@@ -203,9 +216,10 @@ def __init__(
         self,
         detectors: list[Detector],
         waveform: Waveform,
+        f_min: Float = 0,
+        f_max: Float = float("inf"),
         n_bins: int = 100,
         trigger_time: float = 0,
-        duration: float = 4,
         post_trigger_duration: float = 2,
         popsize: int = 100,
         n_steps: int = 2000,
@@ -217,10 +231,10 @@ def __init__(
         **kwargs,
     ) -> None:
         super().__init__(
-            detectors, waveform, trigger_time, duration, post_trigger_duration
+            detectors, waveform, f_min, f_max, trigger_time, post_trigger_duration
         )
 
-        print("Initializing heterodyned likelihood..")
+        logging.info("Initializing heterodyned likelihood..")
 
         # Can use another waveform to use as reference waveform, but if not provided, use the same waveform
         if reference_waveform is None:
@@ -299,7 +313,7 @@ def __init__(
             print("The eta of the reference parameter is close to 0.25")
             print(f"The eta is adjusted to {self.ref_params['eta']}")
 
-        print("Constructing reference waveforms..")
+        logging.info("Constructing reference waveforms..")
 
         self.ref_params["gmst"] = self.gmst
         # adjust the params due to different marginalzation scheme
@@ -647,15 +661,16 @@ def original_likelihood(
     h_sky: dict[str, Float[Array, " n_dim"]],
     detectors: list[Detector],
     freqs: Float[Array, " n_dim"],
+    datas: list[Float[Array, " n_dim"]],
+    psds: list[Float[Array, " n_dim"]],
     align_time: Float,
     **kwargs,
 ) -> Float:
     log_likelihood = 0.0
     df = freqs[1] - freqs[0]
-    for detector in detectors:
+    for detector, data, psd in zip(detectors, datas, psds):
         h_dec = detector.fd_response(freqs, h_sky, params) * align_time
-        data = detector.fd_data_slice
-        psd = detector.psd_slice
+        # NOTE: do we want to take the slice outside the likelihood?
         match_filter_SNR = (
             4 * jnp.sum((jnp.conj(h_dec) * data) / psd * df).real
         )
@@ -670,18 +685,22 @@ def phase_marginalized_likelihood(
     h_sky: dict[str, Float[Array, " n_dim"]],
     detectors: list[Detector],
     freqs: Float[Array, " n_dim"],
+    datas: list[Float[Array, " n_dim"]],
+    psds: list[Float[Array, " n_dim"]],
     align_time: Float,
     **kwargs,
 ) -> Float:
     log_likelihood = 0.0
     complex_d_inner_h = 0.0
     df = freqs[1] - freqs[0]
-    for detector in detectors:
+    f_min = freqs[0]
+    f_max = freqs[-1]
+    for detector, data, psd in zip(detectors, datas, psds):
         h_dec = detector.fd_response(freqs, h_sky, params) * align_time
         complex_d_inner_h += 4 * jnp.sum(
-            (jnp.conj(h_dec) * detector.data) / detector.psd * df
+            (jnp.conj(h_dec) * data) / psd * df
         )
-        optimal_SNR = 4 * jnp.sum(jnp.conj(h_dec) * h_dec / detector.psd * df).real
+        optimal_SNR = 4 * jnp.sum(jnp.conj(h_dec) * h_dec / psd * df).real
         log_likelihood += -optimal_SNR / 2
 
     log_likelihood += log_i0(jnp.absolute(complex_d_inner_h))
@@ -694,17 +713,21 @@ def time_marginalized_likelihood(
     h_sky: dict[str, Float[Array, " n_dim"]],
     detectors: list[Detector],
     freqs: Float[Array, " n_dim"],
+    datas: list[Float[Array, " n_dim"]],
+    psds: list[Float[Array, " n_dim"]],
     align_time: Float,
     **kwargs,
 ) -> Float:
     log_likelihood = 0.0
     df = freqs[1] - freqs[0]
+    f_min = freqs[0]
+    f_max = freqs[-1]
     # using instead of
     complex_h_inner_d = jnp.zeros_like(freqs)
-    for detector in detectors:
+    for detector, data, psd in zip(detectors, datas, psds):
         h_dec = detector.fd_response(freqs, h_sky, params) * align_time
-        complex_h_inner_d += 4 * h_dec * jnp.conj(detector.data) / detector.psd * df
-        optimal_SNR = 4 * jnp.sum(jnp.conj(h_dec) * h_dec / detector.psd * df).real
+        complex_h_inner_d += 4 * h_dec * jnp.conj(data) / psd * df
+        optimal_SNR = 4 * jnp.sum(jnp.conj(h_dec) * h_dec / psd * df).real
         log_likelihood += -optimal_SNR / 2
 
     # fetch the tc range tc_array, lower padding and higher padding
@@ -743,6 +766,8 @@ def phase_time_marginalized_likelihood(
     h_sky: dict[str, Float[Array, " n_dim"]],
     detectors: list[Detector],
     freqs: Float[Array, " n_dim"],
+    datas: list[Float[Array, " n_dim"]],
+    psds: list[Float[Array, " n_dim"]],
     align_time: Float,
     **kwargs,
 ) -> Float:
@@ -750,10 +775,10 @@ def phase_time_marginalized_likelihood(
     df = freqs[1] - freqs[0]
     # using instead of
     complex_h_inner_d = jnp.zeros_like(freqs)
-    for detector in detectors:
+    for detector, data, psd in zip(detectors, datas, psds):
         h_dec = detector.fd_response(freqs, h_sky, params) * align_time
-        complex_h_inner_d += 4 * h_dec * jnp.conj(detector.data) / detector.psd * df
-        optimal_SNR = 4 * jnp.sum(jnp.conj(h_dec) * h_dec / detector.psd * df).real
+        complex_h_inner_d += 4 * h_dec * jnp.conj(data) / psd * df
+        optimal_SNR = 4 * jnp.sum(jnp.conj(h_dec) * h_dec / psd * df).real
         log_likelihood += -optimal_SNR / 2
 
     # fetch the tc range tc_array, lower padding and higher padding