Skip to content

Commit

Permalink
Adding the option to choose preferred domain (gwastro#4916)
Browse files Browse the repository at this point in the history
* Adding the option to choose preferred domain

* Updated logic

* Syntax fix

* Fixed imports

* Stylistic changes

---------

Co-authored-by: kkacanja <[email protected]>
  • Loading branch information
2 people authored and prayush committed Nov 21, 2024
1 parent 7f44c34 commit 17c3dd8
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 47 deletions.
8 changes: 6 additions & 2 deletions pycbc/inference/models/gaussian_noise.py
Original file line number Diff line number Diff line change
Expand Up @@ -1201,7 +1201,11 @@ def create_waveform_generator(
except KeyError:
raise ValueError("no approximant provided in the static args")

generator_function = generator_class.select_rframe_generator(approximant)
dm = static_params.get('preferred_domain', None)
if isinstance(dm, str) and dm.lower() == 'none':
dm = None

gen_function = generator_class.select_rframe_generator(approximant, dm)
# get data parameters; we'll just use one of the data to get the
# values, then check that all the others are the same
delta_f = None
Expand All @@ -1216,7 +1220,7 @@ def create_waveform_generator(
raise ValueError("data must all have the same delta_t, "
"delta_f, and start_time")
waveform_generator = generator_class(
generator_function, epoch=start_time,
gen_function, epoch=start_time,
variable_args=variable_params, detectors=list(data.keys()),
delta_f=delta_f, delta_t=delta_t,
recalib=recalibration, gates=gates,
Expand Down
124 changes: 79 additions & 45 deletions pycbc/waveform/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -702,11 +702,11 @@ def generate(self, **kwargs):
return h

@staticmethod
def select_rframe_generator(approximant):
def select_rframe_generator(approximant, domain):
"""Returns a radiation frame generator class based on the approximant
string.
"""
return select_waveform_generator(approximant)
return select_waveform_generator(approximant, domain)


class FDomainDetFrameTwoPolGenerator(BaseFDomainDetFrameGenerator):
Expand Down Expand Up @@ -845,11 +845,11 @@ def generate(self, **kwargs):
return h

@staticmethod
def select_rframe_generator(approximant):
def select_rframe_generator(approximant, domain):
"""Returns a radiation frame generator class based on the approximant
string.
"""
return select_waveform_generator(approximant)
return select_waveform_generator(approximant, domain)

class FDomainDetFrameTwoPolNoRespGenerator(BaseFDomainDetFrameGenerator):
"""Generates frequency-domain waveform in a specific frame.
Expand Down Expand Up @@ -944,11 +944,12 @@ def generate(self, **kwargs):
return h

@staticmethod
def select_rframe_generator(approximant):
def select_rframe_generator(approximant, domain):
"""Returns a radiation frame generator class based on the approximant
string.
"""
return select_waveform_generator(approximant)
return select_waveform_generator(approximant, domain)


class FDomainDetFrameModesGenerator(BaseFDomainDetFrameGenerator):
"""Generates frequency-domain waveform modes in a specific frame.
Expand Down Expand Up @@ -1096,11 +1097,11 @@ def generate(self, **kwargs):
return h

@staticmethod
def select_rframe_generator(approximant):
def select_rframe_generator(approximant, domain):
"""Returns a radiation frame generator class based on the approximant
string.
"""
return select_waveform_modes_generator(approximant)
return select_waveform_modes_generator(approximant, domain)


class FDomainDirectDetFrameGenerator(BaseCBCGenerator):
Expand Down Expand Up @@ -1197,15 +1198,50 @@ def generate(self, **kwargs):
# =============================================================================
#

def get_td_generator(approximant, modes=False):
"""Returns the time-domain generator for the given approximant."""
if approximant in waveform.td_approximants():
if modes:
return TDomainCBCModesGenerator
return TDomainCBCGenerator

if approximant in ringdown.ringdown_td_approximants:
if approximant == 'TdQNMfromFinalMassSpin':
return TDomainMassSpinRingdownGenerator
return TDomainFreqTauRingdownGenerator

if approximant in supernovae.supernovae_td_approximants:
return TDomainSupernovaeGenerator

raise ValueError(f"No time-domain generator found for "
"approximant: {approximant}")

def get_fd_generator(approximant, modes=False):
"""Returns the frequency-domain generator for the given approximant."""
if approximant in waveform.fd_approximants():
if modes:
return FDomainCBCModesGenerator
return FDomainCBCGenerator

if approximant in ringdown.ringdown_fd_approximants:
if approximant == 'FdQNMfromFinalMassSpin':
return FDomainMassSpinRingdownGenerator
return FDomainFreqTauRingdownGenerator

raise ValueError(f"No frequency-domain generator found for "
"approximant: {approximant}")

def select_waveform_generator(approximant):
def select_waveform_generator(approximant, domain=None):
"""Returns the single-IFO generator for the approximant.
Parameters
----------
approximant : str
Name of waveform approximant. Valid names can be found using
``pycbc.waveform`` methods.
domain : str or None
Name of the preferred waveform domain
('td' for time domain, 'fd' for frequency domain, None for default)
Returns
-------
Expand All @@ -1225,51 +1261,49 @@ def select_waveform_generator(approximant):
>>> from pycbc.waveform.generator import select_waveform_generator
>>> select_waveform_generator(waveform.fd_approximants()[0])
"""
# check if frequency-domain CBC waveform
if approximant in waveform.fd_approximants():
return FDomainCBCGenerator
# check if time-domain CBC waveform
elif approximant in waveform.td_approximants():
return TDomainCBCGenerator
# check if frequency-domain ringdown waveform
elif approximant in ringdown.ringdown_fd_approximants:
if approximant == 'FdQNMfromFinalMassSpin':
return FDomainMassSpinRingdownGenerator
elif approximant == 'FdQNMfromFreqTau':
return FDomainFreqTauRingdownGenerator
elif approximant in ringdown.ringdown_td_approximants:
if approximant == 'TdQNMfromFinalMassSpin':
return TDomainMassSpinRingdownGenerator
elif approximant == 'TdQNMfromFreqTau':
return TDomainFreqTauRingdownGenerator
# check if supernovae waveform:
elif approximant in supernovae.supernovae_td_approximants:
if approximant == 'CoreCollapseBounce':
return TDomainSupernovaeGenerator
# otherwise waveform approximant is not supported
else:
raise ValueError("%s is not a valid approximant." % approximant)


def select_waveform_modes_generator(approximant):

if domain not in {None, 'td', 'fd'}:
raise ValueError(f"Invalid domain '{domain}'. "
"Must be one of: None, 'td', or 'fd'.")

if domain == 'td':
return get_td_generator(approximant)
elif domain == 'fd':
return get_fd_generator(approximant)
elif domain is None:
try:
return get_fd_generator(approximant)
except ValueError:
return get_td_generator(approximant)

def select_waveform_modes_generator(approximant, domain=None):
"""Returns the single-IFO modes generator for the approximant.
Parameters
----------
approximant : str
Name of waveform approximant. Valid names can be found using
``pycbc.waveform`` methods.
domain : str or None
Name of the preferred waveform domain
('td' for time domain, 'fd' for frequency domain, None for default)
Returns
-------
generator : (PyCBC generator instance)
A waveform generator object.
A waveform modes generator object.
"""
# check if frequency-domain CBC waveform
if approximant in waveform.fd_approximants():
return FDomainCBCModesGenerator
# check if time-domain CBC waveform
elif approximant in waveform.td_approximants():
return TDomainCBCModesGenerator
# otherwise waveform approximant is not supported
raise ValueError("%s is not a valid approximant." % approximant)

if domain not in {None, 'td', 'fd'}:
raise ValueError(f"Invalid domain '{domain}'. "
"Must be one of: None, 'td', or 'fd'.")

if domain == 'td':
return get_td_generator(approximant, modes=True)
elif domain == 'fd':
return get_fd_generator(approximant, modes=True)
elif domain is None:
try:
return get_fd_generator(approximant, modes=True)
except ValueError:
return get_td_generator(approximant, modes=True)

0 comments on commit 17c3dd8

Please sign in to comment.