Skip to content

Commit

Permalink
fix swapped x and y for MultiscaleImage in visium_associated_xenium_io
Browse files Browse the repository at this point in the history
  • Loading branch information
Sonja Stockhaus committed Aug 26, 2023
1 parent 3c05d63 commit 8761190
Showing 1 changed file with 25 additions and 56 deletions.
81 changes: 25 additions & 56 deletions src/spatialdata/_core/data_extent.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,13 +79,9 @@ def _get_extent_of_polygons_multipolygons(
return min_coordinates, max_coordinates, axes


def _get_extent_of_data_array(
e: DataArray, coordinate_system: str
) -> BoundingBoxDescription:
def _get_extent_of_data_array(e: DataArray, coordinate_system: str) -> BoundingBoxDescription:
# lightweight conversion to SpatialImage just to fix the type of the single-dispatch
_check_element_has_coordinate_system(
element=SpatialImage(e), coordinate_system=coordinate_system
)
_check_element_has_coordinate_system(element=SpatialImage(e), coordinate_system=coordinate_system)
# also here
data_axes = get_axes_names(SpatialImage(e))
min_coordinates = []
Expand Down Expand Up @@ -115,7 +111,7 @@ def get_extent(
has_labels: bool = True,
has_points: bool = True,
has_shapes: bool = True,
elements: Iterable[Any] = None,
elements: Iterable[Any] | None = None,
) -> BoundingBoxDescription:
"""
Get the extent (bounding box) of a SpatialData object or a SpatialElement.
Expand Down Expand Up @@ -145,7 +141,7 @@ def _(
has_labels: bool = True,
has_points: bool = True,
has_shapes: bool = True,
elements: Iterable[Any] = None,
elements: Iterable[Any] | None = None,
) -> BoundingBoxDescription:
"""
Get the extent (bounding box) of a SpatialData object: the extent of the union of the extents of all its elements.
Expand All @@ -163,30 +159,22 @@ def _(
new_max_coordinates_dict = defaultdict(list)
mask = [has_images, has_labels, has_points, has_shapes]
include_spatial_elements = ["images", "labels", "points", "shapes"]
include_spatial_elements = [
i for (i, v) in zip(include_spatial_elements, mask) if v
]
include_spatial_elements = [i for (i, v) in zip(include_spatial_elements, mask) if v]

if elements is None: # to shut up ruff
elements = []
if not isinstance(elements, list):
raise ValueError(
f"Invalid type of `elements`: {type(elements)}, expected `list`."
)
raise ValueError(f"Invalid type of `elements`: {type(elements)}, expected `list`.")

for element in e._gen_elements():
plot_element = (len(elements) == 0) or (element[1] in elements)
plot_element = plot_element and (
element[0] in include_spatial_elements
)
plot_element = plot_element and (element[0] in include_spatial_elements)
if plot_element:
transformations = get_transformation(element[2], get_all=True)
assert isinstance(transformations, dict)
coordinate_systems = list(transformations.keys())
if coordinate_system in coordinate_systems:
min_coordinates, max_coordinates, axes = get_extent(
element[2], coordinate_system=coordinate_system
)
min_coordinates, max_coordinates, axes = get_extent(element[2], coordinate_system=coordinate_system)
for i, ax in enumerate(axes):
new_min_coordinates_dict[ax].append(min_coordinates[i])
new_max_coordinates_dict[ax].append(max_coordinates[i])
Expand All @@ -196,12 +184,8 @@ def _(
f"The SpatialData object does not contain any element in the coordinate system {coordinate_system!r}, "
f"please pass a different coordinate system wiht the argument 'coordinate_system'."
)
new_min_coordinates = np.array(
[min(new_min_coordinates_dict[ax]) for ax in axes]
)
new_max_coordinates = np.array(
[max(new_max_coordinates_dict[ax]) for ax in axes]
)
new_min_coordinates = np.array([min(new_min_coordinates_dict[ax]) for ax in axes])
new_max_coordinates = np.array([max(new_max_coordinates_dict[ax]) for ax in axes])
return new_min_coordinates, new_max_coordinates, axes


Expand All @@ -213,7 +197,7 @@ def _(
has_labels: bool = True,
has_points: bool = True,
has_shapes: bool = True,
elements: Iterable[Any] = None,
elements: Iterable[Any] | None = None,
) -> BoundingBoxDescription:
"""
Compute the extent (bounding box) of a set of shapes.
Expand All @@ -222,20 +206,14 @@ def _(
-------
The bounding box description.
"""
_check_element_has_coordinate_system(
element=e, coordinate_system=coordinate_system
)
_check_element_has_coordinate_system(element=e, coordinate_system=coordinate_system)
# remove potentially empty geometries
e_temp = e[e["geometry"].apply(lambda geom: not geom.is_empty)]
if isinstance(e_temp.geometry.iloc[0], Point):
assert (
"radius" in e_temp.columns
), "Shapes must have a 'radius' column."
assert "radius" in e_temp.columns, "Shapes must have a 'radius' column."
min_coordinates, max_coordinates, axes = _get_extent_of_circles(e_temp)
else:
assert isinstance(
e_temp.geometry.iloc[0], (Polygon, MultiPolygon)
), "Shapes must be polygons or multipolygons."
assert isinstance(e_temp.geometry.iloc[0], (Polygon, MultiPolygon)), "Shapes must be polygons or multipolygons."
(
min_coordinates,
max_coordinates,
Expand All @@ -259,11 +237,9 @@ def _(
has_labels: bool = True,
has_points: bool = True,
has_shapes: bool = True,
elements: Iterable[Any] = None,
elements: Iterable[Any] | None = None,
) -> BoundingBoxDescription:
_check_element_has_coordinate_system(
element=e, coordinate_system=coordinate_system
)
_check_element_has_coordinate_system(element=e, coordinate_system=coordinate_system)
axes = get_axes_names(e)
min_coordinates = np.array([e[ax].min().compute() for ax in axes])
max_coordinates = np.array([e[ax].max().compute() for ax in axes])
Expand All @@ -284,7 +260,7 @@ def _(
has_labels: bool = True,
has_points: bool = True,
has_shapes: bool = True,
elements: Iterable[Any] = None,
elements: Iterable[Any] | None = None,
) -> BoundingBoxDescription:
return _get_extent_of_data_array(e, coordinate_system=coordinate_system)

Expand All @@ -297,20 +273,14 @@ def _(
has_labels: bool = True,
has_points: bool = True,
has_shapes: bool = True,
elements: Iterable[Any] = None,
elements: Iterable[Any] | None = None,
) -> BoundingBoxDescription:
_check_element_has_coordinate_system(
element=e, coordinate_system=coordinate_system
)
_check_element_has_coordinate_system(element=e, coordinate_system=coordinate_system)
xdata = next(iter(e["scale0"].values()))
return _get_extent_of_data_array(
xdata, coordinate_system=coordinate_system
)
return _get_extent_of_data_array(xdata, coordinate_system=coordinate_system)


def _check_element_has_coordinate_system(
element: SpatialElement, coordinate_system: str
) -> None:
def _check_element_has_coordinate_system(element: SpatialElement, coordinate_system: str) -> None:
transformations = get_transformation(element, get_all=True)
assert isinstance(transformations, dict)
coordinate_systems = list(transformations.keys())
Expand Down Expand Up @@ -348,9 +318,7 @@ def _compute_extent_in_coordinate_system(
-------
The bounding box description in the specified coordinate system.
"""
transformation = get_transformation(
element, to_coordinate_system=coordinate_system
)
transformation = get_transformation(element, to_coordinate_system=coordinate_system)
assert isinstance(transformation, BaseTransformation)
from spatialdata._core.query._utils import get_bounding_box_corners

Expand All @@ -362,6 +330,7 @@ def _compute_extent_in_coordinate_system(
df = pd.DataFrame(corners.data, columns=corners.axis.data.tolist())
points = PointsModel.parse(df, coordinates={k: k for k in axes})
transformed_corners = transform(points, transformation).compute()
min_coordinates = transformed_corners.min().to_numpy()
max_coordinates = transformed_corners.max().to_numpy()
# Make sure min and max values are in the same order as axes
min_coordinates = transformed_corners.min()[list(axes)].to_numpy()
max_coordinates = transformed_corners.max()[list(axes)].to_numpy()
return min_coordinates, max_coordinates, axes

0 comments on commit 8761190

Please sign in to comment.