diff --git a/glue_astronomy/translators/spectrum1d.py b/glue_astronomy/translators/spectrum1d.py index 42c22a9..e2fa102 100644 --- a/glue_astronomy/translators/spectrum1d.py +++ b/glue_astronomy/translators/spectrum1d.py @@ -7,6 +7,7 @@ from astropy import units as u from astropy.wcs import WCSSUB_SPECTRAL from astropy.nddata import StdDevUncertainty, InverseVariance, VarianceUncertainty +from gwcs import WCS as GWCS from glue_astronomy.spectral_coordinates import SpectralCoordinates @@ -21,26 +22,44 @@ class Specutils1DHandler: def to_data(self, obj): - coords = SpectralCoordinates(obj.spectral_axis) - data = Data(coords=coords) - data['flux'] = obj.flux - data.get_component('flux').units = str(obj.unit) + + # Glue expects spectral axis first for cubes (opposite of specutils). + # Swap the spectral axis to first here. to_object doesn't need this because + # Spectrum1D does it automatically on initialization. + if len(obj.flux.shape) == 3: + data = Data(coords=obj.wcs.swapaxes(-1, 0)) + data['flux'] = np.swapaxes(obj.flux, -1, 0) + data.get_component('flux').units = str(obj.unit) + else: + # Don't use the dummy GWCS created by Spectrum1D initialized with spectral_axis + if isinstance(obj.wcs, GWCS): + data = Data(coords=SpectralCoordinates(obj.spectral_axis)) + else: + data = Data(coords=obj.wcs) + data['flux'] = obj.flux + data.get_component('flux').units = str(obj.unit) # Include uncertainties if they exist if obj.uncertainty is not None: - data['uncertainty'] = obj.uncertainty.quantity + if len(obj.flux.shape) == 3: + data['uncertainty'] = np.swapaxes(obj.uncertainty.quantity, -1, 0) + else: + data['uncertainty'] = obj.uncertainty.quantity data.get_component('uncertainty').units = str(obj.uncertainty.unit) data.meta.update({'uncertainty_type': obj.uncertainty.uncertainty_type}) # Include mask if it exists if obj.mask is not None: - data['mask'] = obj.mask + if len(obj.flux.shape) == 3: + data['mask'] = np.swapaxes(obj.mask, -1, 0) + else: + data['mask'] = obj.mask data.meta.update(obj.meta) return data - def to_object(self, data_or_subset, attribute=None, statistic='mean'): + def to_object(self, data_or_subset, attribute=None, statistic=None): """ Convert a glue Data object to a Spectrum1D object. @@ -69,14 +88,16 @@ def to_object(self, data_or_subset, attribute=None, statistic='mean'): # Find non-spectral axes axes = tuple(i for i in range(data.ndim) if i != spec_axis) - kwargs = {'wcs': data.coords.sub([WCSSUB_SPECTRAL])} + if statistic is not None: + kwargs = {'wcs': data.coords.sub([WCSSUB_SPECTRAL])} + else: + kwargs = {'wcs': data.coords} elif isinstance(data.coords, SpectralCoordinates): kwargs = {'spectral_axis': data.coords.spectral_axis} else: - raise TypeError('data.coords should be an instance of WCS ' 'or SpectralCoordinates') @@ -112,7 +133,7 @@ def parse_attributes(attributes): mask = ~mask # Collapse values and mask to profile - if data.ndim > 1: + if data.ndim > 1 and statistic is not None: # Get units and attach to value values = data.compute_statistic(statistic, attribute, axis=axes, subset_state=subset_state) diff --git a/glue_astronomy/translators/tests/test_spectrum1d.py b/glue_astronomy/translators/tests/test_spectrum1d.py index 4ccfc61..357f4f5 100644 --- a/glue_astronomy/translators/tests/test_spectrum1d.py +++ b/glue_astronomy/translators/tests/test_spectrum1d.py @@ -123,23 +123,36 @@ def test_to_spectrum1d_default_attribute(): 'keyword argument.') -@pytest.mark.parametrize('mode', ('wcs', 'lookup')) +@pytest.mark.parametrize('mode', ('wcs1d', 'wcs3d', 'lookup')) def test_from_spectrum1d(mode): - if mode == 'wcs': - wcs = WCS(naxis=1) - wcs.wcs.ctype = ['FREQ'] + if mode == 'wcs3d': + # This test is intended to be run with the version of Spectrum1D based + # on NDCube 2.0 + pytest.importorskip("ndcube", minversion="1.99") + + # Set up simple spatial+spectral WCS + wcs = WCS(naxis=3) + wcs.wcs.ctype = ['RA---TAN', 'DEC--TAN', 'FREQ'] wcs.wcs.set() - kwargs = {'wcs': wcs} + flux = np.ones((4, 4, 5))*u.Unit('Jy') + uncertainty = VarianceUncertainty(np.square(flux*0.1)) + mask = np.zeros((4, 4, 5)) + kwargs = {'wcs': wcs, 'uncertainty': uncertainty, 'mask': mask} else: - - kwargs = {'spectral_axis': [1, 2, 3, 4] * u.Hz} - - spec = Spectrum1D([2, 3, 4, 5] * u.Jy, - uncertainty=VarianceUncertainty( - [0.1, 0.1, 0.1, 0.1] * u.Jy**2), - mask=[False, False, False, False], - **kwargs) + flux = [2, 3, 4, 5] * u.Jy + uncertainty = VarianceUncertainty([0.1, 0.1, 0.1, 0.1] * u.Jy**2) + mask = [False, False, False, False] + if mode == 'wcs1d': + wcs = WCS(naxis=1) + wcs.wcs.ctype = ['FREQ'] + wcs.wcs.set() + kwargs = {'wcs': wcs, 'uncertainty': uncertainty, 'mask': mask} + else: + kwargs = {'spectral_axis': [1, 2, 3, 4] * u.Hz, + 'uncertainty': uncertainty, 'mask': mask} + + spec = Spectrum1D(flux, **kwargs) data_collection = DataCollection() @@ -150,13 +163,13 @@ def test_from_spectrum1d(mode): assert isinstance(data, Data) assert len(data.main_components) == 3 assert data.main_components[0].label == 'flux' - assert_allclose(data['flux'], [2, 3, 4, 5]) + assert_allclose(data['flux'], flux.value) component = data.get_component('flux') assert component.units == 'Jy' # Check uncertainty parsing within glue data object assert data.main_components[1].label == 'uncertainty' - assert_allclose(data['uncertainty'], [0.1, 0.1, 0.1, 0.1]) + assert_allclose(data['uncertainty'], uncertainty.array) component = data.get_component('uncertainty') assert component.units == 'Jy2' @@ -164,13 +177,24 @@ def test_from_spectrum1d(mode): spec_new = data.get_object(attribute='flux') assert isinstance(spec_new, Spectrum1D) assert_quantity_allclose(spec_new.spectral_axis, [1, 2, 3, 4] * u.Hz) - assert_quantity_allclose(spec_new.flux, [2, 3, 4, 5] * u.Jy) + if mode == 'wcs3d': + assert_quantity_allclose(spec_new.flux, np.ones((5, 4, 4))*u.Unit('Jy')) + else: + assert_quantity_allclose(spec_new.flux, [2, 3, 4, 5] * u.Jy) assert spec_new.uncertainty is None # Check complete round-tripping, including uncertainties spec_new = data.get_object() assert isinstance(spec_new, Spectrum1D) assert_quantity_allclose(spec_new.spectral_axis, [1, 2, 3, 4] * u.Hz) - assert_quantity_allclose(spec_new.flux, [2, 3, 4, 5] * u.Jy) - assert spec_new.uncertainty is not None - assert_quantity_allclose(spec_new.uncertainty.quantity, [0.1, 0.1, 0.1, 0.1] * u.Jy**2) + if mode == 'wcs3d': + assert_quantity_allclose(spec_new.flux, np.ones((5, 4, 4))*u.Unit('Jy')) + assert spec_new.uncertainty is not None + print(spec_new.uncertainty) + print(uncertainty) + assert_quantity_allclose(spec_new.uncertainty.quantity, + np.ones((5, 4, 4))*0.01*u.Jy**2) + else: + assert_quantity_allclose(spec_new.flux, [2, 3, 4, 5] * u.Jy) + assert spec_new.uncertainty is not None + assert_quantity_allclose(spec_new.uncertainty.quantity, [0.1, 0.1, 0.1, 0.1] * u.Jy**2)