Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Labels support for 0.4.19, points support for > 0.4.19 #239

Merged
merged 6 commits into from
May 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,11 @@ install_requires =
anndata
click
cycler
dask
dask<=2024.2.1
geopandas
loguru
matplotlib
napari>=0.4.19
napari>=0.4.19.post1
napari-matplotlib
numba
numpy
Expand Down
9 changes: 5 additions & 4 deletions src/napari_spatialdata/_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,11 +85,12 @@ def get_obs(
if name not in self.adata.obs.columns:
raise KeyError(f"Key `{name}` not found in `adata.obs`.")
if name != self.instance_key:
adata_obs = self.adata.obs[[self.instance_key, name]]
adata_obs.set_index(self.instance_key, inplace=True)
obs_column = self.adata.obs[[self.instance_key, name]]
obs_column = obs_column.set_index(self.instance_key)[name]
else:
adata_obs = self.adata.obs
return adata_obs[name], self._format_key(name)
obs_column = self.adata.obs[name].copy()
obs_column.index = self.adata.obs[self.instance_key]
return obs_column, self._format_key(name)

@_ensure_dense_vector
def get_columns_df(self, name: Union[str, int], **_: Any) -> Tuple[Optional[NDArrayA], str]:
Expand Down
19 changes: 15 additions & 4 deletions src/napari_spatialdata/_viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import TYPE_CHECKING, Any

import numpy as np
import packaging.version
from anndata import AnnData
from dask.dataframe.core import DataFrame as DaskDataFrame
from geopandas import GeoDataFrame
Expand All @@ -30,6 +31,7 @@
_get_transform,
_transform_coordinates,
get_duplicate_element_names,
get_napari_version,
)
from napari_spatialdata.utils._viewer_utils import _get_polygons_properties

Expand Down Expand Up @@ -372,29 +374,36 @@ def add_sdata_circles(self, sdata: SpatialData, key: str, selected_cs: str, mult
}

CIRCLES_AS_POINTS = True
version = get_napari_version()
kwargs: dict[str, Any] = (
{"edge_width": 0.0} if version <= packaging.version.parse("0.4.20") else {"border_width": 0.0}
)
if CIRCLES_AS_POINTS:
layer = self.viewer.add_points(
yx,
name=key,
affine=affine,
size=1, # the sise doesn't matter here since it will be adjusted in _adjust_radii_of_points_layer
edge_width=0.0,
metadata=metadata,
**kwargs,
)
assert affine is not None
self._adjust_radii_of_points_layer(layer=layer, affine=affine)
else:
if version <= packaging.version.parse("0.4.20"):
kwargs |= {"edge_color": "white"}
else:
kwargs |= {"border_color": "white"}
# useful code to have readily available to debug the correct radius of circles when represented as points
ellipses = _get_ellipses_from_circles(yx=yx, radii=radii)
self.viewer.add_shapes(
ellipses,
shape_type="ellipse",
name=key,
edge_color="white",
face_color="white",
edge_width=0.0,
affine=affine,
metadata=metadata,
**kwargs,
)

def add_sdata_shapes(self, sdata: SpatialData, key: str, selected_cs: str, multi: bool) -> None:
Expand Down Expand Up @@ -539,12 +548,13 @@ def add_sdata_points(self, sdata: SpatialData, key: str, selected_cs: str, multi
np.fliplr(xy)
# radii_size = _calc_default_radii(self.viewer, sdata, selected_cs)
radii_size = 3
version = get_napari_version()
kwargs = {"edge_width": 0.0} if version <= packaging.version.parse("0.4.20") else {"border_width": 0.0}
layer = self.viewer.add_points(
xy,
name=key,
size=radii_size * 2,
affine=affine,
edge_width=0.0,
metadata={
"sdata": sdata,
"adata": adata,
Expand All @@ -562,6 +572,7 @@ def add_sdata_points(self, sdata: SpatialData, key: str, selected_cs: str, multi
else None
),
},
**kwargs,
)
assert affine is not None
self._adjust_radii_of_points_layer(layer=layer, affine=affine)
Expand Down
24 changes: 13 additions & 11 deletions src/napari_spatialdata/_widgets.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,19 @@
from __future__ import annotations

from abc import abstractmethod
from collections import defaultdict
from functools import singledispatchmethod
from typing import TYPE_CHECKING, Any, Iterable, Sequence

import matplotlib.pyplot as plt
import napari
import numpy as np
import packaging.version
import pandas as pd
from anndata import AnnData
from loguru import logger
from napari.layers import Labels, Points, Shapes
from napari.utils import DirectLabelColormap
from napari.viewer import Viewer
from qtpy import QtCore, QtWidgets
from qtpy.QtCore import Qt, Signal
Expand All @@ -22,10 +25,7 @@
from vispy.scene.widgets import ColorBarWidget

from napari_spatialdata._model import DataModel
from napari_spatialdata.utils._utils import (
NDArrayA,
_min_max_norm,
)
from napari_spatialdata.utils._utils import NDArrayA, _min_max_norm, get_napari_version

__all__ = [
"AListWidget",
Expand Down Expand Up @@ -121,11 +121,7 @@ def _onChange(self) -> None:

def _onAction(self, items: Iterable[str]) -> None:
for item in sorted(set(items)):
try:
vec, name = self._getter(item, index=self.getIndex())
except Exception as e: # noqa: BLE001
logger.error(e)
continue
vec, name = self._getter(item, index=self.getIndex())

if self.model.layer is not None:
properties = self._get_points_properties(vec, key=item, layer=self.model.layer)
Expand All @@ -136,8 +132,14 @@ def _onAction(self, items: Iterable[str]) -> None:
self.model.layer.face_color = properties["face_color"]
self.model.layer.text = properties["text"]
elif isinstance(self.model.layer, Labels):
self.model.layer.color = properties["color"]
self.model.layer.properties = properties.get("properties", None)
version = get_napari_version()
if version < packaging.version.parse("0.4.20"):
self.model.layer.color = properties["color"]
self.model.layer.properties = properties.get("properties", None)
else:
ddict = defaultdict(lambda: np.zeros(4), properties["color"])
cmap = DirectLabelColormap(color_dict=ddict)
self.model.layer.colormap = cmap
else:
raise ValueError("TODO")
# TODO(michalk8): add contrasting fg/bg color once https://github.com/napari/napari/issues/2019 is done
Expand Down
8 changes: 7 additions & 1 deletion src/napari_spatialdata/utils/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import TYPE_CHECKING, Any, Callable, Iterable, Optional, Sequence, Union

import numpy as np
import packaging.version
import pandas as pd
from anndata import AnnData
from dask.dataframe.core import DataFrame as DaskDataFrame
Expand All @@ -14,6 +15,7 @@
from loguru import logger
from matplotlib.colors import is_color_like, to_rgb
from multiscale_spatial_image.multiscale_spatial_image import MultiscaleSpatialImage
from napari import __version__
from napari.layers import Layer
from numba import njit, prange
from pandas.api.types import CategoricalDtype, infer_dtype
Expand Down Expand Up @@ -86,7 +88,7 @@ def decorator(self: Any, *args: Any, **kwargs: Any) -> Vector_name_t:
elif not isinstance(res, (np.ndarray, Sequence)):
raise TypeError(f"Unable to process result of type `{type(res).__name__}`.")

res = np.asarray(np.squeeze(res))
res = np.atleast_1d(np.squeeze(res))
if res.ndim != 1:
raise ValueError(f"Expected 1-dimensional array, found `{res.ndim}`.")

Expand Down Expand Up @@ -483,3 +485,7 @@ def _get_ellipses_from_circles(yx: NDArrayA, radii: NDArrayA) -> NDArrayA:
ellipses = np.stack([lower_left, lower_right, upper_right, upper_left], axis=1)
assert isinstance(ellipses, np.ndarray)
return ellipses


def get_napari_version() -> packaging.version.Version:
return packaging.version.parse(__version__)
1 change: 1 addition & 0 deletions tests/test_viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ def test_layer_transform(qtbot, make_napari_viewer: any):

assert np.array_equal(viewer.layers[0].affine.affine_matrix, affine_transform)
assert np.array_equal(viewer.layers[1].affine.affine_matrix, no_transform)
viewer.close()


def test_adata_metadata(qtbot, make_napari_viewer: any):
Expand Down
Loading