From acecd23fb69347f420ec75d1f93072552ab7ad6c Mon Sep 17 00:00:00 2001 From: kkacanja <123669569+kkacanja@users.noreply.github.com> Date: Mon, 18 Nov 2024 10:40:41 -0500 Subject: [PATCH] Adding the option to choose preferred domain (#4916) * Adding the option to choose preferred domain * Updated logic * Syntax fix * Fixed imports * Stylistic changes --------- Co-authored-by: kkacanja --- pycbc/inference/models/gaussian_noise.py | 8 +- pycbc/waveform/generator.py | 124 +++++++++++++++-------- 2 files changed, 85 insertions(+), 47 deletions(-) diff --git a/pycbc/inference/models/gaussian_noise.py b/pycbc/inference/models/gaussian_noise.py index 418fdb66cc5..5844f58b956 100644 --- a/pycbc/inference/models/gaussian_noise.py +++ b/pycbc/inference/models/gaussian_noise.py @@ -1201,7 +1201,11 @@ def create_waveform_generator( except KeyError: raise ValueError("no approximant provided in the static args") - generator_function = generator_class.select_rframe_generator(approximant) + dm = static_params.get('preferred_domain', None) + if isinstance(dm, str) and dm.lower() == 'none': + dm = None + + gen_function = generator_class.select_rframe_generator(approximant, dm) # get data parameters; we'll just use one of the data to get the # values, then check that all the others are the same delta_f = None @@ -1216,7 +1220,7 @@ def create_waveform_generator( raise ValueError("data must all have the same delta_t, " "delta_f, and start_time") waveform_generator = generator_class( - generator_function, epoch=start_time, + gen_function, epoch=start_time, variable_args=variable_params, detectors=list(data.keys()), delta_f=delta_f, delta_t=delta_t, recalib=recalibration, gates=gates, diff --git a/pycbc/waveform/generator.py b/pycbc/waveform/generator.py index ff669aede72..ec421a81f57 100644 --- a/pycbc/waveform/generator.py +++ b/pycbc/waveform/generator.py @@ -702,11 +702,11 @@ def generate(self, **kwargs): return h @staticmethod - def select_rframe_generator(approximant): + def select_rframe_generator(approximant, domain): """Returns a radiation frame generator class based on the approximant string. """ - return select_waveform_generator(approximant) + return select_waveform_generator(approximant, domain) class FDomainDetFrameTwoPolGenerator(BaseFDomainDetFrameGenerator): @@ -845,11 +845,11 @@ def generate(self, **kwargs): return h @staticmethod - def select_rframe_generator(approximant): + def select_rframe_generator(approximant, domain): """Returns a radiation frame generator class based on the approximant string. """ - return select_waveform_generator(approximant) + return select_waveform_generator(approximant, domain) class FDomainDetFrameTwoPolNoRespGenerator(BaseFDomainDetFrameGenerator): """Generates frequency-domain waveform in a specific frame. @@ -944,11 +944,12 @@ def generate(self, **kwargs): return h @staticmethod - def select_rframe_generator(approximant): + def select_rframe_generator(approximant, domain): """Returns a radiation frame generator class based on the approximant string. """ - return select_waveform_generator(approximant) + return select_waveform_generator(approximant, domain) + class FDomainDetFrameModesGenerator(BaseFDomainDetFrameGenerator): """Generates frequency-domain waveform modes in a specific frame. @@ -1096,11 +1097,11 @@ def generate(self, **kwargs): return h @staticmethod - def select_rframe_generator(approximant): + def select_rframe_generator(approximant, domain): """Returns a radiation frame generator class based on the approximant string. """ - return select_waveform_modes_generator(approximant) + return select_waveform_modes_generator(approximant, domain) class FDomainDirectDetFrameGenerator(BaseCBCGenerator): @@ -1197,8 +1198,40 @@ def generate(self, **kwargs): # ============================================================================= # +def get_td_generator(approximant, modes=False): + """Returns the time-domain generator for the given approximant.""" + if approximant in waveform.td_approximants(): + if modes: + return TDomainCBCModesGenerator + return TDomainCBCGenerator + + if approximant in ringdown.ringdown_td_approximants: + if approximant == 'TdQNMfromFinalMassSpin': + return TDomainMassSpinRingdownGenerator + return TDomainFreqTauRingdownGenerator + + if approximant in supernovae.supernovae_td_approximants: + return TDomainSupernovaeGenerator + + raise ValueError(f"No time-domain generator found for " + "approximant: {approximant}") + +def get_fd_generator(approximant, modes=False): + """Returns the frequency-domain generator for the given approximant.""" + if approximant in waveform.fd_approximants(): + if modes: + return FDomainCBCModesGenerator + return FDomainCBCGenerator + + if approximant in ringdown.ringdown_fd_approximants: + if approximant == 'FdQNMfromFinalMassSpin': + return FDomainMassSpinRingdownGenerator + return FDomainFreqTauRingdownGenerator + + raise ValueError(f"No frequency-domain generator found for " + "approximant: {approximant}") -def select_waveform_generator(approximant): +def select_waveform_generator(approximant, domain=None): """Returns the single-IFO generator for the approximant. Parameters @@ -1206,6 +1239,9 @@ def select_waveform_generator(approximant): approximant : str Name of waveform approximant. Valid names can be found using ``pycbc.waveform`` methods. + domain : str or None + Name of the preferred waveform domain + ('td' for time domain, 'fd' for frequency domain, None for default) Returns ------- @@ -1225,33 +1261,22 @@ def select_waveform_generator(approximant): >>> from pycbc.waveform.generator import select_waveform_generator >>> select_waveform_generator(waveform.fd_approximants()[0]) """ - # check if frequency-domain CBC waveform - if approximant in waveform.fd_approximants(): - return FDomainCBCGenerator - # check if time-domain CBC waveform - elif approximant in waveform.td_approximants(): - return TDomainCBCGenerator - # check if frequency-domain ringdown waveform - elif approximant in ringdown.ringdown_fd_approximants: - if approximant == 'FdQNMfromFinalMassSpin': - return FDomainMassSpinRingdownGenerator - elif approximant == 'FdQNMfromFreqTau': - return FDomainFreqTauRingdownGenerator - elif approximant in ringdown.ringdown_td_approximants: - if approximant == 'TdQNMfromFinalMassSpin': - return TDomainMassSpinRingdownGenerator - elif approximant == 'TdQNMfromFreqTau': - return TDomainFreqTauRingdownGenerator - # check if supernovae waveform: - elif approximant in supernovae.supernovae_td_approximants: - if approximant == 'CoreCollapseBounce': - return TDomainSupernovaeGenerator - # otherwise waveform approximant is not supported - else: - raise ValueError("%s is not a valid approximant." % approximant) - - -def select_waveform_modes_generator(approximant): + + if domain not in {None, 'td', 'fd'}: + raise ValueError(f"Invalid domain '{domain}'. " + "Must be one of: None, 'td', or 'fd'.") + + if domain == 'td': + return get_td_generator(approximant) + elif domain == 'fd': + return get_fd_generator(approximant) + elif domain is None: + try: + return get_fd_generator(approximant) + except ValueError: + return get_td_generator(approximant) + +def select_waveform_modes_generator(approximant, domain=None): """Returns the single-IFO modes generator for the approximant. Parameters @@ -1259,17 +1284,26 @@ def select_waveform_modes_generator(approximant): approximant : str Name of waveform approximant. Valid names can be found using ``pycbc.waveform`` methods. + domain : str or None + Name of the preferred waveform domain + ('td' for time domain, 'fd' for frequency domain, None for default) Returns ------- generator : (PyCBC generator instance) - A waveform generator object. + A waveform modes generator object. """ - # check if frequency-domain CBC waveform - if approximant in waveform.fd_approximants(): - return FDomainCBCModesGenerator - # check if time-domain CBC waveform - elif approximant in waveform.td_approximants(): - return TDomainCBCModesGenerator - # otherwise waveform approximant is not supported - raise ValueError("%s is not a valid approximant." % approximant) + + if domain not in {None, 'td', 'fd'}: + raise ValueError(f"Invalid domain '{domain}'. " + "Must be one of: None, 'td', or 'fd'.") + + if domain == 'td': + return get_td_generator(approximant, modes=True) + elif domain == 'fd': + return get_fd_generator(approximant, modes=True) + elif domain is None: + try: + return get_fd_generator(approximant, modes=True) + except ValueError: + return get_td_generator(approximant, modes=True)