Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

change channel names #786

Merged
merged 15 commits into from
Nov 22, 2024
12 changes: 12 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,18 @@ and this project adheres to [Semantic Versioning][].

## [0.2.6] - TBD

### Added

- Added `set_channel_names` method to `SpatialData` to change the channel names of an
image element in `SpatialData`
- Added `write_channel_names` method to `SpatialData` to overwrite channel metadata on disk
without overwriting the image array itself.

### Changed

- `get_channels` is deprecated and will be removed in `SpatialData` v0.3.0. Use `get_channel_names`
instead.

### Fixed

- Updated deprecated default stages of `pre-commit` #771
Expand Down
2 changes: 2 additions & 0 deletions docs/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,8 @@ The elements (building-blocks) that constitute `SpatialData`.
points_geopandas_to_dask_dataframe
points_dask_dataframe_to_geopandas
get_channels
get_channel_names
set_channel_names
force_2d
```

Expand Down
4 changes: 2 additions & 2 deletions src/spatialdata/_core/operations/map.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from dask.array.overlap import coerce_depth
from xarray import DataArray, DataTree

from spatialdata.models._utils import get_axes_names, get_channels, get_raster_model_from_data_dims
from spatialdata.models._utils import get_axes_names, get_channel_names, get_raster_model_from_data_dims
from spatialdata.transformations import get_transformation

__all__ = ["map_raster"]
Expand Down Expand Up @@ -121,7 +121,7 @@ def map_raster(

if "c" in dims:
if c_coords is None:
c_coords = range(arr.shape[0]) if arr.shape[0] != len(get_channels(data)) else get_channels(data)
c_coords = range(arr.shape[0]) if arr.shape[0] != len(get_channel_names(data)) else get_channel_names(data)
else:
c_coords = None
if transformations is None:
Expand Down
4 changes: 2 additions & 2 deletions src/spatialdata/_core/operations/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from spatialdata._core.spatialdata import SpatialData
from spatialdata._types import ArrayLike
from spatialdata.models import SpatialElement, get_axes_names, get_model
from spatialdata.models._utils import DEFAULT_COORDINATE_SYSTEM, get_channels
from spatialdata.models._utils import DEFAULT_COORDINATE_SYSTEM, get_channel_names
from spatialdata.transformations._utils import _get_scale, compute_coordinates, scale_radii

if TYPE_CHECKING:
Expand Down Expand Up @@ -367,7 +367,7 @@ def _(
channel_names = None
elif schema in (Image2DModel, Image3DModel):
kwargs = {}
channel_names = get_channels(data)
channel_names = get_channel_names(data)
else:
raise ValueError(f"DataTree with schema {schema} not supported")

Expand Down
74 changes: 71 additions & 3 deletions src/spatialdata/_core/spatialdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@
from spatialdata._core._elements import Images, Labels, Points, Shapes, Tables
from spatialdata._logging import logger
from spatialdata._types import ArrayLike, Raster_T
from spatialdata._utils import _deprecation_alias, _error_message_add_element
from spatialdata._utils import (
_deprecation_alias,
_error_message_add_element,
)
from spatialdata.models import (
Image2DModel,
Image3DModel,
Expand All @@ -36,7 +39,12 @@
get_model,
get_table_keys,
)
from spatialdata.models._utils import SpatialElement, convert_region_column_to_categorical, get_axes_names
from spatialdata.models._utils import (
SpatialElement,
convert_region_column_to_categorical,
get_axes_names,
set_channel_names,
)

if TYPE_CHECKING:
from spatialdata._core.query.spatial_query import BaseSpatialRequest
Expand Down Expand Up @@ -315,6 +323,26 @@ def get_instance_key_column(table: AnnData) -> pd.Series:
return table.obs[instance_key]
raise KeyError(f"{instance_key} is set as instance key column. However the column is not found in table.obs.")

def set_channel_names(self, element_name: str, channel_names: str | list[str], write: bool = False) -> None:
    """Set the channel names of an image `SpatialElement` in the `SpatialData` object.

    The image stored under `element_name` is replaced in memory by the element returned from the
    module-level `set_channel_names` helper. The channel metadata on disk is only touched when `write` is
    `True`, in which case `SpatialData.write_channel_names` is called for this element.

    Parameters
    ----------
    element_name
        Name of the image `SpatialElement`.
    channel_names
        The channel names to be assigned to the c dimension of the image `SpatialElement`.
    write
        Whether to overwrite the channel metadata on disk.
    """
    # Inside the method body this name resolves to the module-level helper, not to this method.
    renamed_image = set_channel_names(self.images[element_name], channel_names)
    self.images[element_name] = renamed_image
    if not write:
        return
    self.write_channel_names(element_name)

@staticmethod
def _set_table_annotation_target(
table: AnnData,
Expand Down Expand Up @@ -1441,6 +1469,45 @@ def _validate_can_write_metadata_on_element(self, element_name: str) -> tuple[st
)
return element_type, element

def write_channel_names(self, element_name: str | None = None) -> None:
    """
    Write channel names to disk for one image element, or for every image element, without rewriting the data.

    Parameters
    ----------
    element_name
        The name of the element whose channel names are written. If None, the channel names of all image
        elements are written.

    Raises
    ------
    ValueError
        If the element passes validation but is not of type "images", or if `self.path` is None.
    """
    from spatialdata._core._elements import Elements

    # No name given: write the channel names of each image element in turn, then stop.
    if element_name is None:
        for image_name in list(self.images.keys()):
            self.write_channel_names(image_name)
        return

    Elements._check_valid_name(element_name)

    validated = self._validate_can_write_metadata_on_element(element_name)
    if validated is None:
        return
    element_type, element = validated

    # Testing self.path here also lets mypy see that it is not None below.
    # NOTE(review): this branch is also reached for an image element when self.path is None, although the
    # message only mentions the element type.
    if element_type != "images" or self.path is None:
        raise ValueError(f"Can't set channel names for element of type '{element_type}'.")

    _, _, element_group = self._get_groups_for_element(
        zarr_path=Path(self.path), element_type=element_type, element_name=element_name
    )

    from spatialdata._io._utils import overwrite_channel_names

    overwrite_channel_names(element_group, element)

def write_transformations(self, element_name: str | None = None) -> None:
"""
Write transformations to disk for a single element, or for all elements, without rewriting the data.
Expand Down Expand Up @@ -1471,6 +1538,7 @@ def write_transformations(self, element_name: str | None = None) -> None:
transformations = get_transformation(element, get_all=True)
assert isinstance(transformations, dict)

# Mypy does not understand that path is not None, so we assert it here
assert self.path is not None
_, _, element_group = self._get_groups_for_element(
zarr_path=Path(self.path), element_type=element_type, element_name=element_name
Expand Down Expand Up @@ -1546,9 +1614,9 @@ def write_metadata(self, element_name: str | None = None, consolidate_metadata:
Elements._check_valid_name(element_name)

self.write_transformations(element_name)
self.write_channel_names(element_name)
# TODO: write .uns['spatialdata_attrs'] metadata for AnnData.
# TODO: write .attrs['spatialdata_attrs'] metadata for DaskDataFrame.
# TODO: write omero metadata for the channel name of images.

if consolidate_metadata is None and self.has_consolidated_metadata():
consolidate_metadata = True
Expand Down
21 changes: 21 additions & 0 deletions src/spatialdata/_io/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,27 @@ def overwrite_coordinate_transformations_raster(
group.attrs["multiscales"] = multiscales


def overwrite_channel_names(group: zarr.Group, element: DataArray | DataTree) -> None:
    """Write the channel names of an image element into the metadata of a group.

    The names are taken from the `c` coordinate of `element` and stored as a list of `{"label": name}`
    entries under `group.attrs["omero"]["channels"]` and under
    `group.attrs["multiscales"][0]["metadata"]["omero"]["channels"]`.

    Parameters
    ----------
    group
        The group whose `omero` and `multiscales` attributes are updated.
    element
        The image element providing the channel names.

    Raises
    ------
    ValueError
        If the `multiscales` attribute does not contain exactly one entry.
    """
    if isinstance(element, DataArray):
        channel_names = element.coords["c"].data.tolist()
    else:
        # Not a DataArray: the names are read from the image stored at "scale0".
        channel_names = element["scale0"]["image"].coords["c"].data.tolist()
    channel_metadata = [{"label": name} for name in channel_names]

    omero_meta = group.attrs["omero"]
    multiscales_meta = group.attrs["multiscales"]

    # Validate before assigning any attribute so that a failure leaves the group metadata untouched.
    if len(multiscales_meta) != 1:
        raise ValueError(
            f"Multiscale metadata must be of length one but got length {len(multiscales_meta)}. Data might "
            "be corrupted."
        )

    omero_meta["channels"] = channel_metadata
    group.attrs["omero"] = omero_meta
    multiscales_meta[0]["metadata"]["omero"]["channels"] = channel_metadata
    group.attrs["multiscales"] = multiscales_meta


def _write_metadata(
group: zarr.Group,
group_type: str,
Expand Down
4 changes: 2 additions & 2 deletions src/spatialdata/_io/io_raster.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
_parse_version,
)
from spatialdata._utils import get_pyramid_levels
from spatialdata.models._utils import get_channels
from spatialdata.models._utils import get_channel_names
from spatialdata.models.models import ATTRS_KEY
from spatialdata.transformations._utils import (
_get_transformations,
Expand Down Expand Up @@ -151,7 +151,7 @@ def _get_group_for_writing_transformations() -> zarr.Group:
# convert channel names to channel metadata in omero
if raster_type == "image":
metadata["metadata"] = {"omero": {"channels": []}}
channels = get_channels(raster_data)
channels = get_channel_names(raster_data)
for c in channels:
metadata["metadata"]["omero"]["channels"].append({"label": c}) # type: ignore[union-attr, index, call-overload]

Expand Down
35 changes: 35 additions & 0 deletions src/spatialdata/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import pandas as pd
from anndata import AnnData
from dask import array as da
from dask.array import Array as DaskArray
from xarray import DataArray, Dataset, DataTree

from spatialdata._types import ArrayLike
Expand Down Expand Up @@ -311,3 +312,37 @@ def _error_message_add_element() -> None:
"write_labels(), write_points(), write_shapes() and write_table(). We are going to make these calls more "
"ergonomic in a follow up PR."
)


def _check_match_length_channels_c_dim(
    data: DaskArray | DataArray | DataTree, c_coords: str | list[str], dims: tuple[str, ...]
) -> list[str]:
    """
    Check whether channel names `c_coords` are of equal length to the `c` dimension of the data.

    Parameters
    ----------
    data
        The image array
    c_coords
        The channel names. A single string is treated as one channel name.
    dims
        The axes names in the order that is the same as the `ImageModel` from which it is derived.

    Returns
    -------
    c_coords
        The channel names as list

    Raises
    ------
    ValueError
        If the number of channel names differs from the length of the `c` dimension, or if `dims` does not
        contain "c" (raised by `dims.index`).
    """
    c_index = dims.index("c")
    if isinstance(data, (DataArray, DaskArray)):
        c_length = data.shape[c_index]
    else:
        # Neither a DataArray nor a dask array: the shape is read from the image stored at "scale0".
        c_length = data["scale0"]["image"].shape[c_index]

    if isinstance(c_coords, str):
        c_coords = [c_coords]
    if c_coords is not None and len(c_coords) != c_length:
        raise ValueError(
            f"The number of channel names `{len(c_coords)}` does not match the length of dimension 'c'"
            f" with length {c_length}."
        )
    return c_coords
4 changes: 4 additions & 0 deletions src/spatialdata/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@
Z,
force_2d,
get_axes_names,
get_channel_names,
get_channels,
get_spatial_axes,
points_dask_dataframe_to_geopandas,
points_geopandas_to_dask_dataframe,
set_channel_names,
validate_axes,
validate_axis_name,
)
Expand Down Expand Up @@ -49,6 +51,8 @@
"check_target_region_column_symmetry",
"get_table_keys",
"get_channels",
"get_channel_names",
"set_channel_names",
"force_2d",
"RasterSchema",
]
68 changes: 65 additions & 3 deletions src/spatialdata/models/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from xarray import DataArray, DataTree

from spatialdata._logging import logger
from spatialdata._utils import _check_match_length_channels_c_dim
from spatialdata.transformations.transformations import BaseTransformation

SpatialElement: TypeAlias = DataArray | DataTree | GeoDataFrame | DaskDataFrame
Expand Down Expand Up @@ -268,7 +269,7 @@ def points_geopandas_to_dask_dataframe(gdf: GeoDataFrame, suppress_z_warning: bo


@singledispatch
def get_channels(data: Any) -> list[Any]:
def get_channel_names(data: Any) -> list[Any]:
"""Get channels from data for an image element (both single and multiscale).

Parameters
Expand All @@ -287,12 +288,40 @@ def get_channels(data: Any) -> list[Any]:
raise ValueError(f"Cannot get channels from {type(data)}")


@get_channels.register
def get_channels(data: Any) -> list[Any]:
    """Get channels from data for an image element; deprecated alias of `get_channel_names`.

    [Deprecation] This function is deprecated and will be removed in version 0.3.0. Please use
    `get_channel_names`.

    Parameters
    ----------
    data
        data to get channels from

    Returns
    -------
    List of channels, exactly as returned by `get_channel_names(data)`.
    """
    deprecation_message = (
        "The function 'get_channels' is deprecated and will be removed in version 0.3.0. "
        "Please use 'get_channel_names' instead."
    )
    # stacklevel=2 attributes the warning to the caller instead of to this wrapper.
    warnings.warn(deprecation_message, DeprecationWarning, stacklevel=2)
    return get_channel_names(data)


@get_channel_names.register
def _(data: DataArray) -> list[Any]:
    # Read the values of the "c" coordinate and return them as a plain Python list.
    c_values = data.coords["c"].values
    return c_values.tolist()  # type: ignore[no-any-return]


@get_channels.register
@get_channel_names.register
def _(data: DataTree) -> list[Any]:
name = list({list(data[i].data_vars.keys())[0] for i in data})[0]
channels = {tuple(data[i][name].coords["c"].values) for i in data}
Expand Down Expand Up @@ -374,3 +403,36 @@ def convert_region_column_to_categorical(table: AnnData) -> AnnData:
)
table.obs[region_key] = pd.Categorical(table.obs[region_key])
return table


def set_channel_names(element: DataArray | DataTree, channel_names: str | list[str]) -> DataArray | DataTree:
    """Assign channel names to the `c` dimension of an image element.

    Parameters
    ----------
    element
        The image `SpatialElement` or parsed `ImageModel`.
    channel_names
        The channel names to be assigned to the c dimension of the image `SpatialElement`. Anything that is
        not a list is wrapped into a one-element list.

    Returns
    -------
    element
        The image `SpatialElement` or parsed `ImageModel` with the channel names set to the `c` dimension.

    Raises
    ------
    TypeError
        If the model of `element` is neither `Image2DModel` nor `Image3DModel`.
    ValueError
        If the number of channel names does not match the length of the `c` dimension.
    """
    # NOTE(review): imported inside the function, presumably to avoid a circular import - confirm.
    from spatialdata.models import Image2DModel, Image3DModel, get_model

    if not isinstance(channel_names, list):
        channel_names = [channel_names]

    model = get_model(element)
    if model not in [Image2DModel, Image3DModel]:
        raise TypeError("Element model does not support setting channel names, no `c` dimension found.")

    # The length of the names is checked against the `c` dimension before any coordinates are assigned.
    channel_names = _check_match_length_channels_c_dim(element, channel_names, model.dims.dims)  # type: ignore[union-attr]
    if isinstance(element, DataArray):
        return element.assign_coords(c=channel_names)
    return element.msi.assign_coords({"c": channel_names})
Loading
Loading