diff --git a/src/tdastro/sources/sncomso_models.py b/src/tdastro/sources/sncomso_models.py index cadd68d1..19971ae4 100644 --- a/src/tdastro/sources/sncomso_models.py +++ b/src/tdastro/sources/sncomso_models.py @@ -4,6 +4,7 @@ https://sncosmo.readthedocs.io/en/stable/models.html """ +import numpy as np from astropy import units as u from sncosmo.models import get_source @@ -193,7 +194,15 @@ def compute_flux(self, times, wavelengths, graph_state=None, **kwargs): params = self.get_local_params(graph_state) self._update_sncosmo_model_parameters(graph_state) - flux_flam = self.source.flux(times - params["t0"], wavelengths) + # sncosmo gives an error if the wavelengths are out of bounds, so we truncate + # and fill the rest of the predictions with zero. + min_wave_idx = np.searchsorted(wavelengths, self.source.minwave(), side="left") + max_wave_idx = np.searchsorted(wavelengths, self.source.maxwave(), side="right") + + model_flam = self.source.flux(times - params["t0"], wavelengths[min_wave_idx:max_wave_idx]) + flux_flam = np.zeros((len(times), len(wavelengths))) + flux_flam[:, min_wave_idx:max_wave_idx] = model_flam + flux_fnu = flam_to_fnu( flux_flam, wavelengths, diff --git a/tests/tdastro/sources/test_sncosmo_models.py b/tests/tdastro/sources/test_sncosmo_models.py index 839b8258..444a12d7 100644 --- a/tests/tdastro/sources/test_sncosmo_models.py +++ b/tests/tdastro/sources/test_sncosmo_models.py @@ -72,6 +72,28 @@ def test_sncomso_models_hsiao_t0() -> None: assert np.array_equal(mask, expected_mask) +def test_sncomso_models_bounds() -> None: + """Test that we do not crash if we give wavelengths outside the model bounds.""" + model = SncosmoWrapperModel("nugent-sn1a", amplitude=2.0e10, t0=0.0) + min_w = model.source.minwave() + max_w = model.source.maxwave() + + wavelengths = [ + min_w - 100.0, # Out of bounds + min_w, # edge of bounds (included) + 0.5 * min_w + 0.5 * max_w, # included + max_w, # edge of bounds (included) + max_w + 0.1, # Out of bounds + max_w + 100.0, # Out of bounds + ] + + # Check that columns 0, 4, and 5 are all zeros and the other columns are not. + fluxes_fnu = model.evaluate([54990.0, 54990.5], wavelengths) + assert np.all(fluxes_fnu[:, 0] == 0.0) + assert not np.any(fluxes_fnu[:, 1:4] == 0.0) + assert np.all(fluxes_fnu[:, 4:6] == 0.0) + + def test_sncomso_models_set() -> None: """Test that we can create and evalue a 'hsiao' model and set parameter.""" model = SncosmoWrapperModel("hsiao", t0=0.0, redshift=0.5)