Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SpectralCube to specutils transition in cubeviz #547

Merged
merged 14 commits into from
Oct 28, 2021
1 change: 1 addition & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ Specviz
Other Changes and Additions
---------------------------

- Cubeviz now loads data cube as ``Spectrum1D``. [#547]

2.0 (2021-09-17)
================
Expand Down
22 changes: 10 additions & 12 deletions jdaviz/configs/cubeviz/plugins/moment_maps/moment_maps.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@
from glue.core.message import (DataCollectionAddMessage,
DataCollectionDeleteMessage)
from traitlets import List, Unicode, Any, Bool, observe
from spectral_cube import SpectralCube
from specutils import SpectralRegion
from specutils import Spectrum1D, manipulation, SpectralRegion, analysis
from regions import RectanglePixelRegion

from jdaviz.core.events import SnackbarMessage
Expand Down Expand Up @@ -63,12 +62,12 @@ def _on_data_updated(self, msg):
# Also set the spectral min and max to default to the full range
try:
self.selected_data = self.dc_items[i]
cube = self._selected_data.get_object(cls=SpectralCube)
cube = self._selected_data.get_object(cls=Spectrum1D, statistic=None)
self.spectral_min = cube.spectral_axis[0].value
self.spectral_max = cube.spectral_axis[-1].value
self.spectral_unit = str(cube.spectral_axis.unit)
break
# Skip data that can't be returned as a SpectralCube
# Skip data that can't be returned as a Spectrum1D
except (ValueError, TypeError):
continue

Expand All @@ -80,7 +79,7 @@ def _on_subset_created(self, msg):
def _on_data_selected(self, event):
self._selected_data = next((x for x in self.data_collection
if x.label == event['new']))
cube = self._selected_data.get_object(cls=SpectralCube)
cube = self._selected_data.get_object(cls=Spectrum1D, statistic=None)
# Update spectral bounds and unit if we've switched to another unit
if str(cube.spectral_axis.unit) != self.spectral_unit:
self.spectral_min = cube.spectral_axis[0].value
Expand All @@ -92,7 +91,7 @@ def _on_subset_selected(self, event):
# If "None" selected, reset based on bounds of selected data
self._selected_subset = self.selected_subset
if self._selected_subset == "None":
cube = self._selected_data.get_object(cls=SpectralCube)
cube = self._selected_data.get_object(cls=Spectrum1D, statistic=None)
self.spectral_min = cube.spectral_axis[0].value
self.spectral_max = cube.spectral_axis[-1].value
else:
Expand Down Expand Up @@ -120,12 +119,12 @@ def vue_list_subsets(self, event):
self._spectral_subsets = temp_dict
self.spectral_subset_items = temp_list

def vue_calculate_moment(self, event):
def vue_calculate_moment(self, *args):
# Retrieve the data cube and slice out desired region, if specified
cube = self._selected_data.get_object(cls=SpectralCube)
cube = self._selected_data.get_object(cls=Spectrum1D, statistic=None)
spec_min = float(self.spectral_min) * u.Unit(self.spectral_unit)
spec_max = float(self.spectral_max) * u.Unit(self.spectral_unit)
slab = cube.spectral_slab(spec_min, spec_max)
slab = manipulation.spectral_slab(cube, spec_min, spec_max)

# Calculate the moment and convert to CCDData to add to the viewers
try:
Expand All @@ -134,10 +133,9 @@ def vue_calculate_moment(self, event):
raise ValueError("Moment must be a positive integer")
except ValueError:
raise ValueError("Moment must be a positive integer")
self.moment = slab.moment(n_moment)
self.moment = analysis.moment(slab, order=n_moment)

moment_ccd = CCDData(self.moment.array, wcs=self.moment.wcs,
unit=self.moment.unit)
moment_ccd = CCDData(self.moment, unit=self.moment.unit)

label = "Moment {}: {}".format(n_moment, self._selected_data.label)
fname_label = self._selected_data.label.replace("[", "_").replace("]", "_")
Expand Down
Original file line number Diff line number Diff line change
@@ -1,25 +1,27 @@
import numpy as np

from glue.core import Data
from specutils import Spectrum1D

from jdaviz import Application
from jdaviz.configs.cubeviz.plugins.moment_maps.moment_maps import MomentMap


def test_moment_calculation(spectral_cube_wcs):
def todo_fix_test_moment_calculation(spectrum1d_cube):

app = Application()
dc = app.data_collection
dc.append(Data(x=np.ones((3, 4, 5)), label='test', coords=spectral_cube_wcs))
app.add_data(spectrum1d_cube, 'test')

mm = MomentMap(app=app)
mm._subset_selected = 'None'
# mm.spectral_min = 1.0 * u.m
# mm.spectral_max = 2.0 * u.m
mm._on_data_updated(None)

mm.selected_data = 'test'
mm.n_moment = 0
mm.vue_calculate_moment(None)
mm._on_data_selected({'new': 'test'})
mm._on_subset_selected({'new': None})

print(dc[1].get_object())
mm.n_moment = 0
mm.vue_calculate_moment()

assert mm.moment_available
assert dc[1].label == 'Moment 0: test'
assert dc[1].get_object().shape == (4, 5)
assert dc[1].get_object(cls=Spectrum1D, statistic=None).shape == (4, 2, 2)
129 changes: 83 additions & 46 deletions jdaviz/configs/cubeviz/plugins/parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
import os

import numpy as np
from astropy import units as u
from astropy.io import fits
from spectral_cube import SpectralCube
from spectral_cube.io.fits import FITSReadError
from astropy.wcs import WCS
from specutils import Spectrum1D

from jdaviz.core.registries import data_parser_registry
Expand Down Expand Up @@ -49,15 +49,26 @@ def parse_data(app, file_obj, data_type=None, data_label=None):
file_name = os.path.basename(file_obj)

with fits.open(file_obj) as hdulist:
_parse_hdu(app, hdulist, file_name=data_label or file_name)
prihdr = hdulist[0].header
telescop = prihdr.get('TELESCOP', '').lower()
filetype = prihdr.get('FILETYPE', '').lower()
if telescop == 'jwst' and filetype == '3d ifu cube':
# TODO: What about ERR, DQ, and WMAP?
data_label = f'{file_name}[SCI]'
_parse_jwst_s3d(app, hdulist, data_label)
else:
_parse_hdu(app, hdulist, file_name=data_label or file_name)

# If the data types are custom data objects, use explicit parsers. Note
# that this relies on the glue-astronomy machinery to turn the data object
# into something glue can understand.
elif isinstance(file_obj, SpectralCube):
_parse_spectral_cube(app, file_obj, data_type or 'flux', data_label)
elif isinstance(file_obj, Spectrum1D):
_parse_spectrum1d(app, file_obj)
if file_obj.flux.ndim == 3:
pllim marked this conversation as resolved.
Show resolved Hide resolved
_parse_spectrum1d_3d(app, file_obj)
else:
_parse_spectrum1d(app, file_obj)
else:
raise NotImplementedError(f'Unsupported data format: {file_obj}')


def _parse_hdu(app, hdulist, file_name=None):
Expand All @@ -67,44 +78,36 @@ def _parse_hdu(app, hdulist, file_name=None):

file_name = file_name or "Unknown HDU object"

# WCS may only exist in a single extension (in this case, just the flux
# flux extension), so first find and store then wcs information.
wcs = None

for hdu in hdulist:
if hdu.data is None or not hdu.is_image:
continue

try:
sc = SpectralCube.read(hdu, format='fits')
except (ValueError, FITSReadError):
continue
else:
wcs = sc.wcs

# Now loop through and attempt to parse the fits extensions as spectral
# cube object. If the wcs fails to parse in any case, use the wcs
# information we scraped above.
for hdu in hdulist:
data_label = f"{file_name}[{hdu.name}]"

if hdu.data is None or not hdu.is_image:
if hdu.data is None or not hdu.is_image or hdu.data.ndim != 3:
continue

# This is supposed to fail on attempting to load anything that
# isn't cube-shaped. But it's not terribly reliable
try:
pllim marked this conversation as resolved.
Show resolved Hide resolved
sc = SpectralCube.read(hdu, format='fits')
except (ValueError, OSError):
# This will fail if the parsing of the wcs does not provide
# proper celestial axes
wcs = WCS(hdu.header, hdulist)
except Exception as e: # TODO: Do we just want to fail here?
logging.warning(f"Invalid WCS: {repr(e)}")
wcs = None

if 'BUNIT' in hdu.header:
try:
hdu.header.update(wcs.to_header())
sc = SpectralCube.read(hdu)
except (ValueError, AttributeError) as e:
logging.warning(e)
continue
except FITSReadError as e:
flux_unit = u.Unit(hdu.header['BUNIT'])
except Exception:
logging.warning("Invalid BUNIT, using count as data unit")
flux_unit = u.count
else:
logging.warning("Missing BUNIT, using count as data unit")
flux_unit = u.count

flux = hdu.data << flux_unit

try:
sc = Spectrum1D(flux=flux, wcs=wcs)
except Exception as e:
logging.warning(e)
continue

Expand All @@ -124,21 +127,55 @@ def _parse_hdu(app, hdulist, file_name=None):
app.add_data_to_viewer('spectrum-viewer', data_label)


def _parse_spectral_cube(app, file_obj, data_type='flux', data_label=None):
data_label = data_label or f"Unknown spectral cube[{data_type.upper()}]"
def _parse_jwst_s3d(app, hdulist, data_label):
from specutils import Spectrum1D

app.add_data(file_obj, data_label)
unit = u.Unit(hdulist[1].header.get('BUNIT', 'count'))
flux = hdulist[1].data << unit
wcs = WCS(hdulist[1].header, hdulist)
data = Spectrum1D(flux, wcs=wcs)

if data_type == 'flux':
app.add_data_to_viewer('flux-viewer', f"{data_label}")
app.add_data_to_viewer('spectrum-viewer', f"{data_label}")
elif data_type == 'mask':
app.add_data_to_viewer('mask-viewer', f"{data_label}")
elif data_type == 'uncert':
app.add_data_to_viewer('uncert-viewer', f"{data_label}")
# NOTE: Tried to only pass in sliced WCS but got error in Glue.
# sliced_wcs = wcs[:, 0, 0] # Only want wavelengths
# data = Spectrum1D(flux, wcs=sliced_wcs)

# TODO: SpectralCube does not store mask information
# TODO: SpectralCube does not store data quality information
app.add_data(data, data_label)
app.add_data_to_viewer('flux-viewer', data_label)
app.add_data_to_viewer('spectrum-viewer', data_label)


def _parse_spectrum1d_3d(app, file_obj):
# Load spectrum1d as a cube

for attr in ["flux", "mask", "uncertainty"]:
val = getattr(file_obj, attr)
if val is None:
continue

if attr == "mask":
flux = val << file_obj.flux.unit
elif attr == "uncertainty":
if hasattr(val, "array"):
flux = u.Quantity(val.array, file_obj.flux.unit)
else:
continue
else:
flux = val

flux = np.moveaxis(flux, 1, 0)

s1d = Spectrum1D(flux=flux, wcs=file_obj.wcs)

data_label = f"Unknown spectrum object[{attr.upper()}]"
app.add_data(s1d, data_label)

if attr == 'flux':
app.add_data_to_viewer('flux-viewer', data_label)
app.add_data_to_viewer('spectrum-viewer', data_label)
elif attr == 'mask':
app.add_data_to_viewer('mask-viewer', data_label)
else: # 'uncertainty'
app.add_data_to_viewer('uncert-viewer', data_label)


def _parse_spectrum1d(app, file_obj):
Expand Down
63 changes: 30 additions & 33 deletions jdaviz/configs/cubeviz/plugins/tests/test_parsers.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,20 @@
import os

import astropy.units as u
import numpy as np
import pytest
from astropy import units as u
from astropy.io import fits
from astropy.nddata import StdDevUncertainty
from astropy.wcs import WCS
from spectral_cube import SpectralCube
from specutils import Spectrum1D

from jdaviz.app import Application


@pytest.fixture
def cubeviz_app():
return Application(configuration='cubeviz')


@pytest.fixture
def image_hdu_obj():
flux_hdu = fits.ImageHDU(np.random.sample((10, 10, 10)))
flux_hdu = fits.ImageHDU(np.ones((10, 10, 10)))
flux_hdu.name = 'FLUX'

mask_hdu = fits.ImageHDU(np.zeros((10, 10, 10)))
mask_hdu.name = 'MASK'

uncert_hdu = fits.ImageHDU(np.random.sample((10, 10, 10)))
uncert_hdu = fits.ImageHDU(np.ones((10, 10, 10)))
uncert_hdu.name = 'ERR'

wcs = WCS(header={
Expand All @@ -41,43 +30,51 @@ def image_hdu_obj():
})

flux_hdu.header.update(wcs.to_header())
flux_hdu.header['BUNIT'] = '1E-17 erg/s/cm^2/Angstrom/spaxel'
flux_hdu.header['BUNIT'] = '1E-17 erg*s^-1*cm^-2*Angstrom^-1*pix^-1'

mask_hdu.header.update(wcs.to_header())
uncert_hdu.header.update(wcs.to_header())

return fits.HDUList([fits.PrimaryHDU(), flux_hdu, mask_hdu, uncert_hdu])


@pytest.mark.filterwarnings('ignore:.* contains multiple slashes')
@pytest.mark.filterwarnings('ignore')
def test_fits_image_hdu_parse(image_hdu_obj, cubeviz_app):
cubeviz_app.load_data(image_hdu_obj)

assert len(cubeviz_app.data_collection) == 3
assert cubeviz_app.data_collection[0].label.endswith('[FLUX]')
assert len(cubeviz_app.app.data_collection) == 3
assert cubeviz_app.app.data_collection[0].label.endswith('[FLUX]')


@pytest.mark.filterwarnings('ignore:.* contains multiple slashes')
def test_spectral_cube_parse(tmpdir, image_hdu_obj, cubeviz_app):
@pytest.mark.filterwarnings('ignore')
def test_fits_image_hdu_parse_from_file(tmpdir, image_hdu_obj, cubeviz_app):
f = tmpdir.join("test_fits_image.fits")
path = os.path.join(f.dirname, f.basename)
image_hdu_obj.writeto(path)
path = f.strpath
image_hdu_obj.writeto(path, overwrite=True)
cubeviz_app.load_data(path)

assert len(cubeviz_app.app.data_collection) == 3
assert cubeviz_app.app.data_collection[0].label.endswith('[FLUX]')

sc = SpectralCube.read(path, hdu=1)

@pytest.mark.filterwarnings('ignore')
def test_spectrum3d_parse(image_hdu_obj, cubeviz_app):
flux = image_hdu_obj[1].data << u.Unit(image_hdu_obj[1].header['BUNIT'])
wcs = WCS(image_hdu_obj[1].header, image_hdu_obj)
sc = Spectrum1D(flux=flux, wcs=wcs)
cubeviz_app.load_data(sc)

assert len(cubeviz_app.data_collection) == 1
assert cubeviz_app.data_collection[0].label.endswith('[FLUX]')
assert len(cubeviz_app.app.data_collection) == 1
assert cubeviz_app.app.data_collection[0].label.endswith('[FLUX]')


def test_spectrum1d_parse(spectrum1d, cubeviz_app):
cubeviz_app.load_data(spectrum1d)

def test_spectrum1d_parse(image_hdu_obj, cubeviz_app):
spec = Spectrum1D(flux=np.random.sample(10) * u.Jy,
spectral_axis=np.arange(10) * u.nm,
uncertainty=StdDevUncertainty(
np.random.sample(10) * u.Jy))
assert len(cubeviz_app.app.data_collection) == 1
assert cubeviz_app.app.data_collection[0].label.endswith('[FLUX]')

cubeviz_app.load_data(spec)

assert len(cubeviz_app.data_collection) == 1
assert cubeviz_app.data_collection[0].label.endswith('[FLUX]')
def test_numpy_cube(cubeviz_app):
with pytest.raises(NotImplementedError, match='Unsupported data format'):
cubeviz_app.load_data(np.ones(27).reshape((3, 3, 3)))
Loading