Skip to content

Commit

Permalink
Merge pull request #135 from lincc-frameworks/issue/113/passband_dir_…
Browse files Browse the repository at this point in the history
…unit_tests

Use LSST transmission test data for passband unit tests
  • Loading branch information
OliviaLynn authored Oct 1, 2024
2 parents ffad7c4 + 294781b commit 4a0ba5f
Show file tree
Hide file tree
Showing 10 changed files with 8,774 additions and 1,946 deletions.
62 changes: 39 additions & 23 deletions src/tdastro/astro_utils/passbands.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,17 @@ class PassbandGroup:
----------
passbands : dict of Passband
A dictionary of Passband objects, where the keys are the full_names of the passbands (eg, "LSST_u").
table_dir : str
The path to the directory containing the passband tables.
waves : np.ndarray
The union of all wavelengths in the passbands.
"""

def __init__(
self,
preset: str = None,
passband_parameters: list = None,
passband_parameters: Optional[list] = None,
table_dir: Optional[str] = None,
given_passbands: list = None,
**kwargs,
):
Expand All @@ -35,7 +38,7 @@ def __init__(
Parameters
----------
preset : str, optional
A pre-defined set of passbands to load.
A pre-defined set of passbands to load. If using a preset, passband_parameters will be ignored.
passband_parameters : list of dict, optional
A list of dictionaries of passband parameters used to create Passband objects.
Each dictionary must contain the following:
Expand All @@ -49,6 +52,10 @@ def __init__(
- units : str (either 'nm' or 'A')
If survey is not LSST (or other survey with defined defaults), either a table_path or table_url
must be provided.
table_dir : str, optional
The path to the directory containing the passband tables. If a table_path has not been specified
in the passband_parameters dictionary, table paths will be set to
{table_dir}/{survey}/{filter_name}.dat. If None, the table path will be set to a default path.
given_passbands : list, optional
A list of Passband objects from which to create the PassbandGroup. These
overwrite any passbands with the same full name provided by either
Expand All @@ -65,13 +72,21 @@ def __init__(
)

if preset is not None:
self._load_preset(preset, **kwargs)
self._load_preset(preset, table_dir=table_dir, **kwargs)

elif passband_parameters is not None:
for parameters in passband_parameters:
# Add any missing parameters from kwargs
for key, value in kwargs.items():
if key not in parameters:
parameters[key] = value

# Set the table path if it is not already set and a table_dir is provided
if "table_path" not in parameters and table_dir is not None:
parameters["table_path"] = os.path.join(
table_dir, parameters["survey"], f"{parameters['filter_name']}.dat"
)

passband = Passband(**parameters)
self.passbands[passband.full_name] = passband

Expand All @@ -94,25 +109,32 @@ def __str__(self) -> str:
def __len__(self) -> int:
return len(self.passbands)

def _load_preset(self, preset: str, **kwargs) -> None:
def _load_preset(self, preset: str, table_dir: Optional[str], **kwargs) -> None:
"""Load a pre-defined set of passbands.
Parameters
----------
preset : str
The name of the pre-defined set of passbands to load.
table_dir : str, optional
The path to the directory containing the passband tables. If no table_path has been specified in
the PassbandGroup's passband_parameters and table_dir is not None, table paths will be set to
table_dir/{survey}/{filter_name}.dat.
**kwargs
Additional keyword arguments to pass to the Passband constructor.
"""
if preset == "LSST":
self.passbands = {
"LSST_u": Passband("LSST", "u", **kwargs),
"LSST_g": Passband("LSST", "g", **kwargs),
"LSST_r": Passband("LSST", "r", **kwargs),
"LSST_i": Passband("LSST", "i", **kwargs),
"LSST_z": Passband("LSST", "z", **kwargs),
"LSST_y": Passband("LSST", "y", **kwargs),
}
for filter_name in ["u", "g", "r", "i", "z", "y"]:
if table_dir is None:
self.passbands[f"LSST_{filter_name}"] = Passband("LSST", filter_name, **kwargs)
else:
table_path = os.path.join(table_dir, "LSST", f"{filter_name}.dat")
self.passbands[f"LSST_{filter_name}"] = Passband(
"LSST",
filter_name,
table_path=table_path,
**kwargs,
)
else:
raise ValueError(f"Unknown passband preset: {preset}")

Expand Down Expand Up @@ -157,7 +179,7 @@ def _calculate_in_band_wave_indices(self) -> None:
# do not happen to be on the same phase of the grid; eg, even if the step is 10, if the first
# passband starts at 100 and the second at 105, the passbands won't share the same grid)
if np.array_equal(self.waves[lower_index : upper_index + 1], passband.waves):
indices = (lower_index, upper_index + 1)
indices = slice(lower_index, upper_index + 1)
else:
indices = np.searchsorted(self.waves, passband.waves)
passband._in_band_wave_indices = indices
Expand Down Expand Up @@ -213,10 +235,7 @@ def fluxes_to_bandfluxes(self, flux_density_matrix: np.ndarray) -> np.ndarray:
"This should have been calculated in PassbandGroup._calculate_in_band_wave_indices."
)

if isinstance(passband._in_band_wave_indices, tuple):
in_band_fluxes = flux_density_matrix[:, indices[0] : indices[1]]
else:
in_band_fluxes = flux_density_matrix[:, indices]
in_band_fluxes = flux_density_matrix[:, indices]

bandfluxes[full_name] = passband.fluxes_to_bandflux(in_band_fluxes)
return bandfluxes
Expand Down Expand Up @@ -348,9 +367,10 @@ def _load_transmission_table(self, force_download: bool = False) -> None:
self.table_path = os.path.join(
os.path.dirname(__file__), f"passbands/{self.survey}/{self.filter_name}.dat"
)
os.makedirs(os.path.dirname(self.table_path), exist_ok=True)
os.makedirs(os.path.dirname(self.table_path), exist_ok=True)
if force_download or not os.path.exists(self.table_path):
self._download_transmission_table()

# Load the table file
try:
loaded_table = np.loadtxt(self.table_path)
Expand Down Expand Up @@ -380,11 +400,7 @@ def _download_transmission_table(self) -> bool:
"""
if self.table_url is None:
if self.survey == "LSST":
# TODO switch to files at: https://github.com/lsst/throughputs/blob/main/baseline/total_g.dat
self.table_url = (
f"http://svo2.cab.inta-csic.es/svo/theory/fps3/getdata.php"
f"?format=ascii&id=LSST/LSST.{self.filter_name}"
)
self.table_url = f"https://github.com/lsst/throughputs/blob/main/baseline/total_{self.filter_name}.dat?raw=true"
else:
raise NotImplementedError(
f"Transmission table download is not yet implemented for survey: {self.survey}."
Expand Down
Loading

0 comments on commit 4a0ba5f

Please sign in to comment.