Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding the option to choose preferred domain #4916

Merged
merged 5 commits into from
Nov 18, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions pycbc/inference/models/gaussian_noise.py
Original file line number Diff line number Diff line change
Expand Up @@ -1201,7 +1201,11 @@ def create_waveform_generator(
except KeyError:
raise ValueError("no approximant provided in the static args")

generator_function = generator_class.select_rframe_generator(approximant)
dm = static_params.get('preferred_domain', None)
if isinstance(dm, str) and dm.lower() == 'none':
dm = None

gen_function = generator_class.select_rframe_generator(approximant, dm)
# get data parameters; we'll just use one of the data to get the
# values, then check that all the others are the same
delta_f = None
Expand All @@ -1216,7 +1220,7 @@ def create_waveform_generator(
raise ValueError("data must all have the same delta_t, "
"delta_f, and start_time")
waveform_generator = generator_class(
generator_function, epoch=start_time,
gen_function, epoch=start_time,
variable_args=variable_params, detectors=list(data.keys()),
delta_f=delta_f, delta_t=delta_t,
recalib=recalibration, gates=gates,
Expand Down
111 changes: 66 additions & 45 deletions pycbc/waveform/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -702,11 +702,11 @@ def generate(self, **kwargs):
return h

@staticmethod
def select_rframe_generator(approximant):
def select_rframe_generator(approximant, domain):
"""Returns a radiation frame generator class based on the approximant
string.
"""
return select_waveform_generator(approximant)
return select_waveform_generator(approximant, domain)


class FDomainDetFrameTwoPolGenerator(BaseFDomainDetFrameGenerator):
Expand Down Expand Up @@ -845,11 +845,11 @@ def generate(self, **kwargs):
return h

@staticmethod
def select_rframe_generator(approximant):
def select_rframe_generator(approximant, domain):
"""Returns a radiation frame generator class based on the approximant
string.
"""
return select_waveform_generator(approximant)
return select_waveform_generator(approximant, domain)

class FDomainDetFrameTwoPolNoRespGenerator(BaseFDomainDetFrameGenerator):
"""Generates frequency-domain waveform in a specific frame.
Expand Down Expand Up @@ -944,11 +944,12 @@ def generate(self, **kwargs):
return h

@staticmethod
def select_rframe_generator(approximant):
def select_rframe_generator(approximant, domain):
"""Returns a radiation frame generator class based on the approximant
string.
"""
return select_waveform_generator(approximant)
return select_waveform_generator(approximant, domain)


class FDomainDetFrameModesGenerator(BaseFDomainDetFrameGenerator):
"""Generates frequency-domain waveform modes in a specific frame.
Expand Down Expand Up @@ -1096,11 +1097,11 @@ def generate(self, **kwargs):
return h

@staticmethod
def select_rframe_generator(approximant):
def select_rframe_generator(approximant, domain):
"""Returns a radiation frame generator class based on the approximant
string.
"""
return select_waveform_modes_generator(approximant)
return select_waveform_modes_generator(approximant, domain)


class FDomainDirectDetFrameGenerator(BaseCBCGenerator):
Expand Down Expand Up @@ -1198,14 +1199,38 @@ def generate(self, **kwargs):
#


def select_waveform_generator(approximant):
# Updated code to make the logic clearer and simplify decision-making based on the domain

def get_td_generator(approximant, modes=False):
"""Returns the time-domain generator for the given approximant."""
if approximant in waveform.td_approximants():
return TDomainCBCModesGenerator if modes else TDomainCBCGenerator
kkacanja marked this conversation as resolved.
Show resolved Hide resolved
if approximant in ringdown.ringdown_td_approximants:
return TDomainMassSpinRingdownGenerator if approximant == 'TdQNMfromFinalMassSpin' else TDomainFreqTauRingdownGenerator
kkacanja marked this conversation as resolved.
Show resolved Hide resolved
if approximant in supernovae.supernovae_td_approximants:
return TDomainSupernovaeGenerator
raise ValueError(f"No time-domain generator found for approximant: {approximant}")

def get_fd_generator(approximant, modes=False):
"""Returns the frequency-domain generator for the given approximant."""
if approximant in waveform.fd_approximants():
return FDomainCBCModesGenerator if modes else FDomainCBCGenerator
if approximant in ringdown.ringdown_fd_approximants:
return FDomainMassSpinRingdownGenerator if approximant == 'FdQNMfromFinalMassSpin' else FDomainFreqTauRingdownGenerator
raise ValueError(f"No frequency-domain generator found for approximant: {approximant}")

kkacanja marked this conversation as resolved.
Show resolved Hide resolved

def select_waveform_generator(approximant, domain=None):
"""Returns the single-IFO generator for the approximant.

Parameters
----------
approximant : str
Name of waveform approximant. Valid names can be found using
``pycbc.waveform`` methods.
domain : str or None
Name of the preferred waveform domain
('td' for time domain, 'fd' for frequency domain, None for default)

Returns
-------
Expand All @@ -1225,51 +1250,47 @@ def select_waveform_generator(approximant):
>>> from pycbc.waveform.generator import select_waveform_generator
>>> select_waveform_generator(waveform.fd_approximants()[0])
"""
# check if frequency-domain CBC waveform
if approximant in waveform.fd_approximants():
return FDomainCBCGenerator
# check if time-domain CBC waveform
elif approximant in waveform.td_approximants():
return TDomainCBCGenerator
# check if frequency-domain ringdown waveform
elif approximant in ringdown.ringdown_fd_approximants:
if approximant == 'FdQNMfromFinalMassSpin':
return FDomainMassSpinRingdownGenerator
elif approximant == 'FdQNMfromFreqTau':
return FDomainFreqTauRingdownGenerator
elif approximant in ringdown.ringdown_td_approximants:
if approximant == 'TdQNMfromFinalMassSpin':
return TDomainMassSpinRingdownGenerator
elif approximant == 'TdQNMfromFreqTau':
return TDomainFreqTauRingdownGenerator
# check if supernovae waveform:
elif approximant in supernovae.supernovae_td_approximants:
if approximant == 'CoreCollapseBounce':
return TDomainSupernovaeGenerator
# otherwise waveform approximant is not supported
else:
raise ValueError("%s is not a valid approximant." % approximant)


def select_waveform_modes_generator(approximant):

if domain not in {None, 'td', 'fd'}:
raise ValueError(f"Invalid domain '{domain}'. Must be one of: None, 'td', or 'fd'.")

if domain == 'td':
return get_td_generator(approximant)
elif domain == 'fd':
return get_fd_generator(approximant)
elif domain is None:
try:
return get_fd_generator(approximant)
except ValueError:
return get_td_generator(approximant)

def select_waveform_modes_generator(approximant, domain=None):
"""Returns the single-IFO modes generator for the approximant.

Parameters
----------
approximant : str
Name of waveform approximant. Valid names can be found using
``pycbc.waveform`` methods.
domain : str or None
Name of the preferred waveform domain
('td' for time domain, 'fd' for frequency domain, None for default)

Returns
-------
generator : (PyCBC generator instance)
A waveform generator object.
A waveform modes generator object.
"""
# check if frequency-domain CBC waveform
if approximant in waveform.fd_approximants():
return FDomainCBCModesGenerator
# check if time-domain CBC waveform
elif approximant in waveform.td_approximants():
return TDomainCBCModesGenerator
# otherwise waveform approximant is not supported
raise ValueError("%s is not a valid approximant." % approximant)

if domain not in {None, 'td', 'fd'}:
raise ValueError(f"Invalid domain '{domain}'. Must be one of: None, 'td', or 'fd'.")
kkacanja marked this conversation as resolved.
Show resolved Hide resolved

if domain == 'td':
return get_td_generator(approximant, modes=True)
elif domain == 'fd':
return get_fd_generator(approximant, modes=True)
elif domain is None:
try:
return get_fd_generator(approximant, modes=True)
except ValueError:
return get_td_generator(approximant, modes=True)
Loading