Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding the option to choose preferred domain #4916

Merged
merged 5 commits into from
Nov 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions pycbc/inference/models/gaussian_noise.py
Original file line number Diff line number Diff line change
Expand Up @@ -1201,7 +1201,11 @@ def create_waveform_generator(
except KeyError:
raise ValueError("no approximant provided in the static args")

generator_function = generator_class.select_rframe_generator(approximant)
dm = static_params.get('preferred_domain', None)
if isinstance(dm, str) and dm.lower() == 'none':
dm = None

gen_function = generator_class.select_rframe_generator(approximant, dm)
# get data parameters; we'll just use one of the data to get the
# values, then check that all the others are the same
delta_f = None
Expand All @@ -1216,7 +1220,7 @@ def create_waveform_generator(
raise ValueError("data must all have the same delta_t, "
"delta_f, and start_time")
waveform_generator = generator_class(
generator_function, epoch=start_time,
gen_function, epoch=start_time,
variable_args=variable_params, detectors=list(data.keys()),
delta_f=delta_f, delta_t=delta_t,
recalib=recalibration, gates=gates,
Expand Down
124 changes: 79 additions & 45 deletions pycbc/waveform/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -702,11 +702,11 @@ def generate(self, **kwargs):
return h

@staticmethod
def select_rframe_generator(approximant):
def select_rframe_generator(approximant, domain):
"""Returns a radiation frame generator class based on the approximant
string.
"""
return select_waveform_generator(approximant)
return select_waveform_generator(approximant, domain)


class FDomainDetFrameTwoPolGenerator(BaseFDomainDetFrameGenerator):
Expand Down Expand Up @@ -845,11 +845,11 @@ def generate(self, **kwargs):
return h

@staticmethod
def select_rframe_generator(approximant):
def select_rframe_generator(approximant, domain):
"""Returns a radiation frame generator class based on the approximant
string.
"""
return select_waveform_generator(approximant)
return select_waveform_generator(approximant, domain)

class FDomainDetFrameTwoPolNoRespGenerator(BaseFDomainDetFrameGenerator):
"""Generates frequency-domain waveform in a specific frame.
Expand Down Expand Up @@ -944,11 +944,12 @@ def generate(self, **kwargs):
return h

@staticmethod
def select_rframe_generator(approximant):
def select_rframe_generator(approximant, domain):
"""Returns a radiation frame generator class based on the approximant
string.
"""
return select_waveform_generator(approximant)
return select_waveform_generator(approximant, domain)


class FDomainDetFrameModesGenerator(BaseFDomainDetFrameGenerator):
"""Generates frequency-domain waveform modes in a specific frame.
Expand Down Expand Up @@ -1096,11 +1097,11 @@ def generate(self, **kwargs):
return h

@staticmethod
def select_rframe_generator(approximant):
def select_rframe_generator(approximant, domain):
"""Returns a radiation frame generator class based on the approximant
string.
"""
return select_waveform_modes_generator(approximant)
return select_waveform_modes_generator(approximant, domain)


class FDomainDirectDetFrameGenerator(BaseCBCGenerator):
Expand Down Expand Up @@ -1197,15 +1198,50 @@ def generate(self, **kwargs):
# =============================================================================
#

def get_td_generator(approximant, modes=False):
"""Returns the time-domain generator for the given approximant."""
if approximant in waveform.td_approximants():
if modes:
return TDomainCBCModesGenerator
return TDomainCBCGenerator

if approximant in ringdown.ringdown_td_approximants:
if approximant == 'TdQNMfromFinalMassSpin':
return TDomainMassSpinRingdownGenerator
return TDomainFreqTauRingdownGenerator

if approximant in supernovae.supernovae_td_approximants:
return TDomainSupernovaeGenerator

raise ValueError(f"No time-domain generator found for "
"approximant: {approximant}")

def get_fd_generator(approximant, modes=False):
"""Returns the frequency-domain generator for the given approximant."""
if approximant in waveform.fd_approximants():
if modes:
return FDomainCBCModesGenerator
return FDomainCBCGenerator

if approximant in ringdown.ringdown_fd_approximants:
if approximant == 'FdQNMfromFinalMassSpin':
return FDomainMassSpinRingdownGenerator
return FDomainFreqTauRingdownGenerator

kkacanja marked this conversation as resolved.
Show resolved Hide resolved
raise ValueError(f"No frequency-domain generator found for "
"approximant: {approximant}")

def select_waveform_generator(approximant):
def select_waveform_generator(approximant, domain=None):
"""Returns the single-IFO generator for the approximant.

Parameters
----------
approximant : str
Name of waveform approximant. Valid names can be found using
``pycbc.waveform`` methods.
domain : str or None
Name of the preferred waveform domain
('td' for time domain, 'fd' for frequency domain, None for default)

Returns
-------
Expand All @@ -1225,51 +1261,49 @@ def select_waveform_generator(approximant):
>>> from pycbc.waveform.generator import select_waveform_generator
>>> select_waveform_generator(waveform.fd_approximants()[0])
"""
# check if frequency-domain CBC waveform
if approximant in waveform.fd_approximants():
return FDomainCBCGenerator
# check if time-domain CBC waveform
elif approximant in waveform.td_approximants():
return TDomainCBCGenerator
# check if frequency-domain ringdown waveform
elif approximant in ringdown.ringdown_fd_approximants:
if approximant == 'FdQNMfromFinalMassSpin':
return FDomainMassSpinRingdownGenerator
elif approximant == 'FdQNMfromFreqTau':
return FDomainFreqTauRingdownGenerator
elif approximant in ringdown.ringdown_td_approximants:
if approximant == 'TdQNMfromFinalMassSpin':
return TDomainMassSpinRingdownGenerator
elif approximant == 'TdQNMfromFreqTau':
return TDomainFreqTauRingdownGenerator
# check if supernovae waveform:
elif approximant in supernovae.supernovae_td_approximants:
if approximant == 'CoreCollapseBounce':
return TDomainSupernovaeGenerator
# otherwise waveform approximant is not supported
else:
raise ValueError("%s is not a valid approximant." % approximant)


def select_waveform_modes_generator(approximant):

if domain not in {None, 'td', 'fd'}:
raise ValueError(f"Invalid domain '{domain}'. "
"Must be one of: None, 'td', or 'fd'.")

if domain == 'td':
return get_td_generator(approximant)
elif domain == 'fd':
return get_fd_generator(approximant)
elif domain is None:
try:
return get_fd_generator(approximant)
except ValueError:
return get_td_generator(approximant)

def select_waveform_modes_generator(approximant, domain=None):
"""Returns the single-IFO modes generator for the approximant.

Parameters
----------
approximant : str
Name of waveform approximant. Valid names can be found using
``pycbc.waveform`` methods.
domain : str or None
Name of the preferred waveform domain
('td' for time domain, 'fd' for frequency domain, None for default)

Returns
-------
generator : (PyCBC generator instance)
A waveform generator object.
A waveform modes generator object.
"""
# check if frequency-domain CBC waveform
if approximant in waveform.fd_approximants():
return FDomainCBCModesGenerator
# check if time-domain CBC waveform
elif approximant in waveform.td_approximants():
return TDomainCBCModesGenerator
# otherwise waveform approximant is not supported
raise ValueError("%s is not a valid approximant." % approximant)

if domain not in {None, 'td', 'fd'}:
raise ValueError(f"Invalid domain '{domain}'. "
"Must be one of: None, 'td', or 'fd'.")

if domain == 'td':
return get_td_generator(approximant, modes=True)
elif domain == 'fd':
return get_fd_generator(approximant, modes=True)
elif domain is None:
try:
return get_fd_generator(approximant, modes=True)
except ValueError:
return get_td_generator(approximant, modes=True)
Loading