Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use LSST transmission test data for passband unit tests #135

Merged
merged 14 commits into from
Oct 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 39 additions & 23 deletions src/tdastro/astro_utils/passbands.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,17 @@ class PassbandGroup:
----------
passbands : dict of Passband
A dictionary of Passband objects, where the keys are the full_names of the passbands (eg, "LSST_u").
table_dir : str
The path to the directory containing the passband tables.
waves : np.ndarray
The union of all wavelengths in the passbands.
"""

def __init__(
self,
preset: str = None,
passband_parameters: list = None,
passband_parameters: Optional[list] = None,
table_dir: Optional[str] = None,
given_passbands: list = None,
**kwargs,
):
Expand All @@ -35,7 +38,7 @@ def __init__(
Parameters
----------
preset : str, optional
A pre-defined set of passbands to load.
A pre-defined set of passbands to load. If using a preset, passband_parameters will be ignored.
passband_parameters : list of dict, optional
A list of dictionaries of passband parameters used to create Passband objects.
Each dictionary must contain the following:
Expand All @@ -49,6 +52,10 @@ def __init__(
- units : str (either 'nm' or 'A')
If survey is not LSST (or other survey with defined defaults), either a table_path or table_url
must be provided.
table_dir : str, optional
The path to the directory containing the passband tables. If a table_path has not been specified
in the passband_parameters dictionary, table paths will be set to
{table_dir}/{survey}/{filter_name}.dat. If None, the table path will be set to a default path.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No change needed right now, but I had been thinking of this as "{table_dir}/{filter_name}.dat" so the user could effectively provide both the base directory and the survey name with the table_dir parameter.

This current approach better matches the presets below though.

given_passbands : list, optional
A list of Passband objects from which to create the PassbandGroup. These
overwrite any passbands with the same full name provided by either
Expand All @@ -65,13 +72,21 @@ def __init__(
)

if preset is not None:
self._load_preset(preset, **kwargs)
self._load_preset(preset, table_dir=table_dir, **kwargs)

elif passband_parameters is not None:
for parameters in passband_parameters:
# Add any missing parameters from kwargs
for key, value in kwargs.items():
if key not in parameters:
parameters[key] = value

# Set the table path if it is not already set and a table_dir is provided
if "table_path" not in parameters and table_dir is not None:
parameters["table_path"] = os.path.join(
table_dir, parameters["survey"], f"{parameters['filter_name']}.dat"
)

passband = Passband(**parameters)
self.passbands[passband.full_name] = passband

Expand All @@ -94,25 +109,32 @@ def __str__(self) -> str:
def __len__(self) -> int:
return len(self.passbands)

def _load_preset(self, preset: str, **kwargs) -> None:
def _load_preset(self, preset: str, table_dir: Optional[str], **kwargs) -> None:
"""Load a pre-defined set of passbands.

Parameters
----------
preset : str
The name of the pre-defined set of passbands to load.
table_dir : str, optional
The path to the directory containing the passband tables. If no table_path has been specified in
the PassbandGroup's passband_parameters and table_dir is not None, table paths will be set to
table_dir/{survey}/{filter_name}.dat.
**kwargs
Additional keyword arguments to pass to the Passband constructor.
"""
if preset == "LSST":
self.passbands = {
"LSST_u": Passband("LSST", "u", **kwargs),
"LSST_g": Passband("LSST", "g", **kwargs),
"LSST_r": Passband("LSST", "r", **kwargs),
"LSST_i": Passband("LSST", "i", **kwargs),
"LSST_z": Passband("LSST", "z", **kwargs),
"LSST_y": Passband("LSST", "y", **kwargs),
}
for filter_name in ["u", "g", "r", "i", "z", "y"]:
if table_dir is None:
self.passbands[f"LSST_{filter_name}"] = Passband("LSST", filter_name, **kwargs)
else:
table_path = os.path.join(table_dir, "LSST", f"{filter_name}.dat")
self.passbands[f"LSST_{filter_name}"] = Passband(
"LSST",
filter_name,
table_path=table_path,
**kwargs,
)
else:
raise ValueError(f"Unknown passband preset: {preset}")

Expand Down Expand Up @@ -157,7 +179,7 @@ def _calculate_in_band_wave_indices(self) -> None:
# do not happen to be on the same phase of the grid; eg, even if the step is 10, if the first
# passband starts at 100 and the second at 105, the passbands won't share the same grid)
if np.array_equal(self.waves[lower_index : upper_index + 1], passband.waves):
indices = (lower_index, upper_index + 1)
indices = slice(lower_index, upper_index + 1)
else:
indices = np.searchsorted(self.waves, passband.waves)
passband._in_band_wave_indices = indices
Expand Down Expand Up @@ -213,10 +235,7 @@ def fluxes_to_bandfluxes(self, flux_density_matrix: np.ndarray) -> np.ndarray:
"This should have been calculated in PassbandGroup._calculate_in_band_wave_indices."
)

if isinstance(passband._in_band_wave_indices, tuple):
in_band_fluxes = flux_density_matrix[:, indices[0] : indices[1]]
else:
in_band_fluxes = flux_density_matrix[:, indices]
in_band_fluxes = flux_density_matrix[:, indices]

bandfluxes[full_name] = passband.fluxes_to_bandflux(in_band_fluxes)
return bandfluxes
Expand Down Expand Up @@ -348,9 +367,10 @@ def _load_transmission_table(self, force_download: bool = False) -> None:
self.table_path = os.path.join(
os.path.dirname(__file__), f"passbands/{self.survey}/{self.filter_name}.dat"
)
os.makedirs(os.path.dirname(self.table_path), exist_ok=True)
os.makedirs(os.path.dirname(self.table_path), exist_ok=True)
if force_download or not os.path.exists(self.table_path):
self._download_transmission_table()

# Load the table file
try:
loaded_table = np.loadtxt(self.table_path)
Expand Down Expand Up @@ -380,11 +400,7 @@ def _download_transmission_table(self) -> bool:
"""
if self.table_url is None:
if self.survey == "LSST":
# TODO switch to files at: https://github.com/lsst/throughputs/blob/main/baseline/total_g.dat
self.table_url = (
f"http://svo2.cab.inta-csic.es/svo/theory/fps3/getdata.php"
f"?format=ascii&id=LSST/LSST.{self.filter_name}"
)
self.table_url = f"https://github.com/lsst/throughputs/blob/main/baseline/total_{self.filter_name}.dat?raw=true"
else:
raise NotImplementedError(
f"Transmission table download is not yet implemented for survey: {self.survey}."
Expand Down
Loading
Loading