diff --git a/src/spatialdata/transformations/_utils.py b/src/spatialdata/transformations/_utils.py index 716961bc..7e7f4f8f 100644 --- a/src/spatialdata/transformations/_utils.py +++ b/src/spatialdata/transformations/_utils.py @@ -14,6 +14,7 @@ from spatialdata._types import ArrayLike if TYPE_CHECKING: + from spatialdata._core.spatialdata import SpatialData from spatialdata.models import SpatialElement from spatialdata.models._utils import MappingToCoordinateSystem_t from spatialdata.transformations.transformations import Affine, BaseTransformation, Scale @@ -254,3 +255,30 @@ def scale_radii(radii: ArrayLike, affine: Affine, axes: tuple[str, ...]) -> Arra new_radii = radii * scale_factor assert isinstance(new_radii, np.ndarray) return new_radii + + +def convert_transformations_to_affine(sdata: SpatialData, coordinate_system: str) -> None: + """ + Convert all transformations to the given coordinate system to affine transformations. + + Parameters + ---------- + coordinate_system + The coordinate system to convert to. + + Notes + ----- + The new transformations are modified only in-memory. If you want to save the changes to disk please call + `SpatialData.write_transformations()`. + """ + from spatialdata.transformations.operations import get_transformation, set_transformation + from spatialdata.transformations.transformations import Affine, _get_affine_for_element + + for _, _, element in sdata.gen_spatial_elements(): + transformations = get_transformation(element, get_all=True) + assert isinstance(transformations, dict) + if coordinate_system in transformations: + t = transformations[coordinate_system] + if not isinstance(t, Affine): + affine = _get_affine_for_element(element, t) + set_transformation(element, transformation=affine, to_coordinate_system=coordinate_system) diff --git a/tests/transformations/test_transformations_utils.py b/tests/transformations/test_transformations_utils.py new file mode 100644 index 00000000..2c69d96e --- /dev/null +++ b/tests/transformations/test_transformations_utils.py @@ -0,0 +1,15 @@ +from spatialdata.transformations._utils import convert_transformations_to_affine +from spatialdata.transformations.operations import get_transformation, set_transformation +from spatialdata.transformations.transformations import Affine, Scale, Sequence, Translation + + +def test_convert_transformations_to_affine(full_sdata): + translation = Translation([1, 2, 3], axes=("x", "y", "z")) + scale = Scale([1, 2, 3], axes=("x", "y", "z")) + sequence = Sequence([translation, scale]) + for _, _, element in full_sdata.gen_spatial_elements(): + set_transformation(element, transformation=sequence, to_coordinate_system="test") + convert_transformations_to_affine(full_sdata, "test") + for _, _, element in full_sdata.gen_spatial_elements(): + t = get_transformation(element, "test") + assert isinstance(t, Affine)