diff --git a/glue_jupyter/bqplot/scatter/layer_artist.py b/glue_jupyter/bqplot/scatter/layer_artist.py index 6434f052..89d419e9 100644 --- a/glue_jupyter/bqplot/scatter/layer_artist.py +++ b/glue_jupyter/bqplot/scatter/layer_artist.py @@ -9,6 +9,7 @@ from glue.viewers.scatter.layer_artist import DensityMapLimits from glue.viewers.scatter.state import ScatterLayerState from glue_jupyter.bqplot.scatter.scatter_density_mark import GenericDensityMark +from glue.core.units import UnitConverter from ...utils import colormap_to_hexlist, float_or_none from ..compatibility import ScatterGL, LinesGL @@ -63,6 +64,8 @@ "linewidth", "markers_visible", "vector_scaling", + "x_display_unit", + "y_display_unit" } @@ -175,7 +178,6 @@ def _update_data(self): x = ensure_numerical(self.layer[self._viewer_state.x_att].ravel()) if x.dtype.kind == "M": x = datetime64_to_mpl(x) - except (IncompatibleAttribute, IndexError): # The following includes a call to self.clear() self.disable_invalid_attributes(self._viewer_state.x_att) @@ -195,6 +197,18 @@ def _update_data(self): else: self.enable() + converter = UnitConverter() + + x = converter.to_unit(self._viewer_state.x_att.parent, + self._viewer_state.x_att, + x, + self._viewer_state.x_display_unit) + + y = converter.to_unit(self._viewer_state.y_att.parent, + self._viewer_state.y_att, + y, + self._viewer_state.y_display_unit) + if self.state.markers_visible: if self.state.density_map: diff --git a/glue_jupyter/bqplot/scatter/tests/test_viewer.py b/glue_jupyter/bqplot/scatter/tests/test_viewer.py index dafb39da..f880ba9e 100644 --- a/glue_jupyter/bqplot/scatter/tests/test_viewer.py +++ b/glue_jupyter/bqplot/scatter/tests/test_viewer.py @@ -1,5 +1,11 @@ from itertools import permutations +from numpy.testing import assert_allclose, assert_equal + +from glue.core import Data +from glue.core.roi import RectangularROI +from glue.core.link_helpers import LinkSame, LinkSameWithUnits + def test_scatter2d_nd(app, data_4d): # Regression test for a bug that meant that arrays with more than one @@ -106,3 +112,147 @@ def test_incompatible_data(app): assert len(s.layers) == 1 assert s.layers[0].enabled + + +def test_unit_conversion(app): + + d1 = Data(a=[1, 2, 3], b=[2, 3, 4]) + d1.get_component('a').units = 'm' + d1.get_component('b').units = 's' + + d2 = Data(c=[2000, 1000, 3000], d=[0.001, 0.002, 0.004]) + d2.get_component('c').units = 'mm' + d2.get_component('d').units = 'ks' + + # d3 is the same as d2 but we will link it differently + d3 = Data(e=[2000, 1000, 3000], f=[0.001, 0.002, 0.004]) + d3.get_component('e').units = 'mm' + d3.get_component('f').units = 'ks' + + d4 = Data(g=[2, 2, 3], h=[1, 2, 1]) + d4.get_component('g').units = 'kg' + d4.get_component('h').units = 'm/s' + + session = app.session + + data_collection = session.data_collection + data_collection.append(d1) + data_collection.append(d2) + data_collection.append(d3) + data_collection.append(d4) + + data_collection.add_link(LinkSameWithUnits(d1.id['a'], d2.id['c'])) + data_collection.add_link(LinkSameWithUnits(d1.id['b'], d2.id['d'])) + data_collection.add_link(LinkSame(d1.id['a'], d3.id['e'])) + data_collection.add_link(LinkSame(d1.id['b'], d3.id['f'])) + data_collection.add_link(LinkSame(d1.id['a'], d4.id['g'])) + data_collection.add_link(LinkSame(d1.id['b'], d4.id['h'])) + + viewer = app.scatter2d(data=d1) + viewer.add_data(d2) + viewer.add_data(d3) + viewer.add_data(d4) + + assert viewer.layers[0].enabled + assert viewer.layers[1].enabled + assert viewer.layers[2].enabled + assert viewer.layers[3].enabled + + assert viewer.state.x_axislabel == 'a [m]' + assert viewer.state.y_axislabel == 'b [s]' + + assert_allclose(viewer.layers[0].scatter_mark.x, [1, 2, 3]) + assert_allclose(viewer.layers[0].scatter_mark.y, [2, 3, 4]) + assert_allclose(viewer.layers[1].scatter_mark.x, [2, 1, 3]) + assert_allclose(viewer.layers[1].scatter_mark.y, [1, 2, 4]) + assert_allclose(viewer.layers[2].scatter_mark.x, [2000, 1000, 3000]) + assert_allclose(viewer.layers[2].scatter_mark.y, [0.001, 0.002, 0.004]) + assert_allclose(viewer.layers[3].scatter_mark.x, [2, 2, 3]) + assert_allclose(viewer.layers[3].scatter_mark.y, [1, 2, 1]) + + assert viewer.state.x_min == 0.92 + assert viewer.state.x_max == 3.08 + assert viewer.state.y_min == 1.92 + assert viewer.state.y_max == 4.08 + + roi = RectangularROI(0.5, 2.5, 1.5, 4.5) + viewer.apply_roi(roi) + + assert len(d1.subsets) == 1 + assert_equal(d1.subsets[0].to_mask(), [1, 1, 0]) + + # Because of the LinkSameWithUnits, the points actually appear in the right + # place even before we set the display units. + assert len(d2.subsets) == 1 + assert_equal(d2.subsets[0].to_mask(), [0, 1, 0]) + + # d3 is only linked with LinkSame not LinkSameWithUnits so currently the + # points are outside the visible axes + assert len(d3.subsets) == 1 + assert_equal(d3.subsets[0].to_mask(), [0, 0, 0]) + + # As we haven't set display units yet, the values for this dataset are shown + # on the same scale as for d1 as if the units had never been set. + assert len(d4.subsets) == 1 + assert_equal(d4.subsets[0].to_mask(), [0, 1, 0]) + + # Now try setting the units explicitly + + viewer.state.x_display_unit = 'km' + viewer.state.y_display_unit = 'ms' + + assert viewer.state.x_axislabel == 'a [km]' + assert viewer.state.y_axislabel == 'b [ms]' + + assert_allclose(viewer.layers[0].scatter_mark.x, [1e-3, 2e-3, 3e-3]) + assert_allclose(viewer.layers[0].scatter_mark.y, [2e3, 3e3, 4e3]) + assert_allclose(viewer.layers[1].scatter_mark.x, [2e-3, 1e-3, 3e-3]) + assert_allclose(viewer.layers[1].scatter_mark.y, [1e3, 2e3, 4e3]) + assert_allclose(viewer.layers[2].scatter_mark.x, [2, 1, 3]) + assert_allclose(viewer.layers[2].scatter_mark.y, [1, 2, 4]) + assert_allclose(viewer.layers[3].scatter_mark.x, [2e-3, 2e-3, 3e-3]) + assert_allclose(viewer.layers[3].scatter_mark.y, [1e3, 2e3, 1e3]) + + assert_allclose(viewer.state.x_min, 0.92e-3) + assert_allclose(viewer.state.x_max, 3.08e-3) + assert_allclose(viewer.state.y_min, 1.92e3) + assert_allclose(viewer.state.y_max, 4.08e3) + + roi = RectangularROI(0.5e-3, 2.5e-3, 1.5e3, 4.5e3) + viewer.apply_roi(roi) + + # Results are as above - the display units do not result in any changes to + # the actual content of the axes and does not deal with automatic conversion + # of different units between different datasets - LinkSameWithUnits should + # deal with that already. + + assert_equal(d1.subsets[0].to_mask(), [1, 1, 0]) + assert_equal(d2.subsets[0].to_mask(), [0, 1, 0]) + assert_equal(d3.subsets[0].to_mask(), [0, 0, 0]) + assert_equal(d4.subsets[0].to_mask(), [0, 1, 0]) + + # Change the limits to make sure they are always converted + viewer.state.x_min = 0.0001 + viewer.state.x_max = 0.005 + viewer.state.y_min = 200 + viewer.state.y_max = 7000 + + viewer.state.x_display_unit = 'm' + viewer.state.y_display_unit = 's' +# + assert viewer.state.x_axislabel == 'a [m]' + assert viewer.state.y_axislabel == 'b [s]' + + assert_allclose(viewer.layers[0].scatter_mark.x, [1, 2, 3]) + assert_allclose(viewer.layers[0].scatter_mark.y, [2, 3, 4]) + assert_allclose(viewer.layers[1].scatter_mark.x, [2, 1, 3]) + assert_allclose(viewer.layers[1].scatter_mark.y, [1, 2, 4]) + assert_allclose(viewer.layers[2].scatter_mark.x, [2000, 1000, 3000]) + assert_allclose(viewer.layers[2].scatter_mark.y, [0.001, 0.002, 0.004]) + assert_allclose(viewer.layers[3].scatter_mark.x, [2, 2, 3]) + assert_allclose(viewer.layers[3].scatter_mark.y, [1, 2, 1]) + + assert viewer.state.x_min == 0.1 + assert viewer.state.x_max == 5 + assert viewer.state.y_min == 0.2 + assert viewer.state.y_max == 7 diff --git a/glue_jupyter/bqplot/scatter/viewer.py b/glue_jupyter/bqplot/scatter/viewer.py index 7381053e..0c9429d9 100644 --- a/glue_jupyter/bqplot/scatter/viewer.py +++ b/glue_jupyter/bqplot/scatter/viewer.py @@ -4,6 +4,9 @@ from .layer_artist import BqplotScatterLayerArtist +from glue.core.units import UnitConverter +from glue.core.subset import roi_to_subset_state + from glue_jupyter.common.state_widgets.layer_scatter import ScatterLayerStateWidget from glue_jupyter.common.state_widgets.viewer_scatter import ScatterViewerStateWidget @@ -31,12 +34,43 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.state.add_callback('x_att', self._update_axes) self.state.add_callback('y_att', self._update_axes) + self.state.add_callback('x_display_unit', self._update_axes) + self.state.add_callback('y_display_unit', self._update_axes) self._update_axes() def _update_axes(self, *args): if self.state.x_att is not None: - self.state.x_axislabel = str(self.state.x_att) + if self.state.x_display_unit: + self.state.x_axislabel = str(self.state.x_att) + f' [{self.state.x_display_unit}]' + else: + self.state.x_axislabel = str(self.state.x_att) if self.state.y_att is not None: - self.state.y_axislabel = str(self.state.y_att) + if self.state.y_display_unit: + self.state.y_axislabel = str(self.state.y_att) + f' [{self.state.y_display_unit}]' + else: + self.state.y_axislabel = str(self.state.y_att) + + def _roi_to_subset_state(self, roi): + + converter = UnitConverter() + + if self.state.x_display_unit: + xfunc = lambda x: converter.to_native(self.state.x_att.parent, + self.state.x_att, x, + self.state.x_display_unit) + else: + xfunc = None + + if self.state.y_display_unit: + yfunc = lambda y: converter.to_native(self.state.y_att.parent, + self.state.y_att, y, + self.state.y_display_unit) + else: + yfunc = None + + if xfunc or yfunc: + roi = roi.transformed(xfunc=xfunc, yfunc=yfunc) + + return roi_to_subset_state(roi, x_att=self.state.x_att, y_att=self.state.y_att) diff --git a/glue_jupyter/common/state_widgets/viewer_scatter.py b/glue_jupyter/common/state_widgets/viewer_scatter.py index 018fd049..6a7fa68a 100644 --- a/glue_jupyter/common/state_widgets/viewer_scatter.py +++ b/glue_jupyter/common/state_widgets/viewer_scatter.py @@ -19,6 +19,12 @@ class ScatterViewerStateWidget(v.VuetifyTemplate): y_att_items = traitlets.List().tag(sync=True) y_att_selected = traitlets.Int(allow_none=True).tag(sync=True) + x_display_unit_items = traitlets.List().tag(sync=True) + x_display_unit_selected = traitlets.Int(allow_none=True).tag(sync=True) + + y_display_unit_items = traitlets.List().tag(sync=True) + y_display_unit_selected = traitlets.Int(allow_none=True).tag(sync=True) + def __init__(self, viewer_state): super().__init__() @@ -28,3 +34,5 @@ def __init__(self, viewer_state): link_glue_choices(self, viewer_state, "x_att") link_glue_choices(self, viewer_state, "y_att") + link_glue_choices(self, viewer_state, 'x_display_unit') + link_glue_choices(self, viewer_state, 'y_display_unit') diff --git a/glue_jupyter/common/state_widgets/viewer_scatter.vue b/glue_jupyter/common/state_widgets/viewer_scatter.vue index 8c888e42..63f0b96e 100644 --- a/glue_jupyter/common/state_widgets/viewer_scatter.vue +++ b/glue_jupyter/common/state_widgets/viewer_scatter.vue @@ -6,6 +6,12 @@