Skip to content

Commit

Permalink
correct refactor and test
Browse files Browse the repository at this point in the history
  • Loading branch information
melonora committed Nov 21, 2024
1 parent 4140091 commit 05445ad
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 9 deletions.
8 changes: 4 additions & 4 deletions src/spatialdata/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,7 @@ def _error_message_add_element() -> None:


def _check_match_length_channels_c_dim(
data: DaskArray | DataArray | DataTree, c_coords: str | list[str], cls_dims: tuple[str]
data: DaskArray | DataArray | DataTree, c_coords: str | list[str], dims: tuple[str, ...]
) -> list[str]:
"""
Check whether channel names `c_coords` are of equal length to the `c` dimension of the data.
Expand All @@ -326,15 +326,15 @@ def _check_match_length_channels_c_dim(
The image array
c_coords
The channel names
cls_dims
The dimensions of the particular `ImageModel`
dims
The axes names in the order that is the same as the `ImageModel` from which it is derived.
Returns
-------
c_coords
The channel names as list
"""
c_index = cls_dims.index("c")
c_index = dims.index("c")
c_length = (
data.shape[c_index] if isinstance(data, DataArray | DaskArray) else data["scale0"]["image"].shape[c_index]
)
Expand Down
31 changes: 31 additions & 0 deletions src/spatialdata/models/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from xarray import DataArray, DataTree

from spatialdata._logging import logger
from spatialdata._utils import _check_match_length_channels_c_dim
from spatialdata.transformations.transformations import BaseTransformation

SpatialElement: TypeAlias = DataArray | DataTree | GeoDataFrame | DaskDataFrame
Expand Down Expand Up @@ -374,3 +375,33 @@ def convert_region_column_to_categorical(table: AnnData) -> AnnData:
)
table.obs[region_key] = pd.Categorical(table.obs[region_key])
return table


def set_channel_names(element: DataArray | DataTree, channel_names: str | list[str]) -> DataArray | DataTree:
"""Set the channel names for a image `SpatialElement` in the `SpatialData` object.
Parameters
----------
element
The image `SpatialElement` or parsed `ImageModel`.
channel_names
The channel names to be assigned to the c dimension of the image `SpatialElement`.
Returns
-------
element
The image `SpatialElement` or parsed `ImageModel` with the channel names set to the `c` dimension.
"""
channel_names = channel_names if isinstance(channel_names, list) else [channel_names]

# get_model cannot be used due to circular import so get_axes_names is used instead
if "c" in (dims := get_axes_names(element)):
channel_names = _check_match_length_channels_c_dim(element, channel_names, dims) # type: ignore[union-attr]
if isinstance(element, DataArray):
element = element.assign_coords(c=channel_names)
else:
element = element.msi.assign_coords({"c": channel_names})
else:
raise TypeError("Element model does not support setting channel names, no `c` dimension found.")

return element
13 changes: 8 additions & 5 deletions tests/io/test_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from spatialdata import SpatialData, read_zarr
from spatialdata._io._utils import _is_element_self_contained
from spatialdata._logging import logger
from spatialdata.models import get_channels
from spatialdata.transformations import Scale, get_transformation, set_transformation


Expand Down Expand Up @@ -113,12 +114,14 @@ def test_save_transformations_incremental(element_name, full_sdata, caplog):
# test io for channel names
@pytest.mark.parametrize("write", ["overwrite", "write", "no"])
def test_save_channel_names_incremental(images: SpatialData, write: str) -> None:
old_channels2d = get_channels(images["image2d"])
old_channels3d = get_channels(images["image3d_numpy"])

with tempfile.TemporaryDirectory() as tmp_dir:
f0 = os.path.join(tmp_dir, "sdata.zarr")
images.write(f0)

over_write = write == "overwrite"
old_channels = images["image2d"].coords["c"].data.tolist()

new_channels = ["first", "second", "third"]
images.set_channel_names("image2d", new_channels, write=over_write)
Expand All @@ -136,10 +139,10 @@ def test_save_channel_names_incremental(images: SpatialData, write: str) -> None
assert images["image3d_numpy"].coords["c"].data.tolist() == new_channels
assert images["image3d_multiscale_numpy"]["scale0"]["image"].coords["c"].data.tolist() == new_channels
else:
assert images["image2d"].coords["c"].data.tolist() == old_channels
assert images["image2d_multiscale"]["scale0"]["image"].coords["c"].data.tolist() == old_channels
assert images["image3d_numpy"].coords["c"].data.tolist() == old_channels
assert images["image3d_multiscale_numpy"]["scale0"]["image"].coords["c"].data.tolist() == old_channels
assert images["image2d"].coords["c"].data.tolist() == old_channels2d
assert images["image2d_multiscale"]["scale0"]["image"].coords["c"].data.tolist() == old_channels2d
assert images["image3d_numpy"].coords["c"].data.tolist() == old_channels3d
assert images["image3d_multiscale_numpy"]["scale0"]["image"].coords["c"].data.tolist() == old_channels3d


# test io for consolidated metadata
Expand Down

0 comments on commit 05445ad

Please sign in to comment.