diff --git a/src/spatialdata/_utils.py b/src/spatialdata/_utils.py index 4ae696cd..48c0220d 100644 --- a/src/spatialdata/_utils.py +++ b/src/spatialdata/_utils.py @@ -315,7 +315,7 @@ def _error_message_add_element() -> None: def _check_match_length_channels_c_dim( - data: DaskArray | DataArray | DataTree, c_coords: str | list[str], cls_dims: tuple[str] + data: DaskArray | DataArray | DataTree, c_coords: str | list[str], dims: tuple[str, ...] ) -> list[str]: """ Check whether channel names `c_coords` are of equal length to the `c` dimension of the data. @@ -326,15 +326,15 @@ def _check_match_length_channels_c_dim( The image array c_coords The channel names - cls_dims - The dimensions of the particular `ImageModel` + dims + The axes names in the order that is the same as the `ImageModel` from which it is derived. Returns ------- c_coords The channel names as list """ - c_index = cls_dims.index("c") + c_index = dims.index("c") c_length = ( data.shape[c_index] if isinstance(data, DataArray | DaskArray) else data["scale0"]["image"].shape[c_index] ) diff --git a/src/spatialdata/models/_utils.py b/src/spatialdata/models/_utils.py index f5df85e3..fb5a937d 100644 --- a/src/spatialdata/models/_utils.py +++ b/src/spatialdata/models/_utils.py @@ -13,6 +13,7 @@ from xarray import DataArray, DataTree from spatialdata._logging import logger +from spatialdata._utils import _check_match_length_channels_c_dim from spatialdata.transformations.transformations import BaseTransformation SpatialElement: TypeAlias = DataArray | DataTree | GeoDataFrame | DaskDataFrame @@ -374,3 +375,33 @@ def convert_region_column_to_categorical(table: AnnData) -> AnnData: ) table.obs[region_key] = pd.Categorical(table.obs[region_key]) return table + + +def set_channel_names(element: DataArray | DataTree, channel_names: str | list[str]) -> DataArray | DataTree: + """Set the channel names for a image `SpatialElement` in the `SpatialData` object. + + Parameters + ---------- + element + The image `SpatialElement` or parsed `ImageModel`. + channel_names + The channel names to be assigned to the c dimension of the image `SpatialElement`. + + Returns + ------- + element + The image `SpatialElement` or parsed `ImageModel` with the channel names set to the `c` dimension. + """ + channel_names = channel_names if isinstance(channel_names, list) else [channel_names] + + # get_model cannot be used due to circular import so get_axes_names is used instead + if "c" in (dims := get_axes_names(element)): + channel_names = _check_match_length_channels_c_dim(element, channel_names, dims) # type: ignore[union-attr] + if isinstance(element, DataArray): + element = element.assign_coords(c=channel_names) + else: + element = element.msi.assign_coords({"c": channel_names}) + else: + raise TypeError("Element model does not support setting channel names, no `c` dimension found.") + + return element diff --git a/tests/io/test_metadata.py b/tests/io/test_metadata.py index 8f1ffdf5..c0c02a3d 100644 --- a/tests/io/test_metadata.py +++ b/tests/io/test_metadata.py @@ -7,6 +7,7 @@ from spatialdata import SpatialData, read_zarr from spatialdata._io._utils import _is_element_self_contained from spatialdata._logging import logger +from spatialdata.models import get_channels from spatialdata.transformations import Scale, get_transformation, set_transformation @@ -113,12 +114,14 @@ def test_save_transformations_incremental(element_name, full_sdata, caplog): # test io for channel names @pytest.mark.parametrize("write", ["overwrite", "write", "no"]) def test_save_channel_names_incremental(images: SpatialData, write: str) -> None: + old_channels2d = get_channels(images["image2d"]) + old_channels3d = get_channels(images["image3d_numpy"]) + with tempfile.TemporaryDirectory() as tmp_dir: f0 = os.path.join(tmp_dir, "sdata.zarr") images.write(f0) over_write = write == "overwrite" - old_channels = images["image2d"].coords["c"].data.tolist() new_channels = ["first", "second", "third"] images.set_channel_names("image2d", new_channels, write=over_write) @@ -136,10 +139,10 @@ def test_save_channel_names_incremental(images: SpatialData, write: str) -> None assert images["image3d_numpy"].coords["c"].data.tolist() == new_channels assert images["image3d_multiscale_numpy"]["scale0"]["image"].coords["c"].data.tolist() == new_channels else: - assert images["image2d"].coords["c"].data.tolist() == old_channels - assert images["image2d_multiscale"]["scale0"]["image"].coords["c"].data.tolist() == old_channels - assert images["image3d_numpy"].coords["c"].data.tolist() == old_channels - assert images["image3d_multiscale_numpy"]["scale0"]["image"].coords["c"].data.tolist() == old_channels + assert images["image2d"].coords["c"].data.tolist() == old_channels2d + assert images["image2d_multiscale"]["scale0"]["image"].coords["c"].data.tolist() == old_channels2d + assert images["image3d_numpy"].coords["c"].data.tolist() == old_channels3d + assert images["image3d_multiscale_numpy"]["scale0"]["image"].coords["c"].data.tolist() == old_channels3d # test io for consolidated metadata