Skip to content

Commit

Permalink
fix docstrings and support points
Browse files Browse the repository at this point in the history
  • Loading branch information
quentinblampey committed May 24, 2024
1 parent 2f42b52 commit bdc15e0
Showing 1 changed file with 41 additions and 30 deletions.
71 changes: 41 additions & 30 deletions src/spatialdata/_core/operations/rasterize.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

from functools import singledispatch
from typing import Any
from warnings import warn

import dask_image.ndinterp
Expand Down Expand Up @@ -155,6 +156,13 @@ def rasterize(
min_coordinate: list[Number] | ArrayLike,
max_coordinate: list[Number] | ArrayLike,
target_coordinate_system: str,
value_key: str | None = None,
values_sdata: SpatialData | None = None,
agg_func: str | ds.reductions.Reduction | None = None,
instance_key_as_default_value_key: bool = False,
return_single_channel: bool = True,
table_name: str | None = None,
return_as_labels: bool = False,
target_unit_to_pixels: float | None = None,
target_width: float | None = None,
target_height: float | None = None,
Expand All @@ -176,6 +184,29 @@ def rasterize(
target_coordinate_system
The coordinate system in which we define the bounding box. This will also be the coordinate system of the
produced rasterized image.
value_key
Name of the column containing the values to aggregate; can refer to either numerical or
categorical values.
The key can be:
- the name of a column(s) in the dataframe (Dask `DataFrame` for points or `GeoDataFrame` for shapes);
- the name of obs column(s) in the associated `AnnData` table (for shapes and labels);
- the name of a var(s), referring to the column(s) of the X matrix in the table (for shapes and labels).
If nothing is passed here, it defaults to the equivalent of a column of ones.
values_sdata
SpatialData object containing the values to aggregate if `value_key` refers to values from a table
agg_func
A reduction function from datashader (its name, or a Callable)
instance_key_as_default_value_key
If `True`, the geometry indices are used as a `value_key`
return_single_channel
If `False`, each category will be counted in a separate channel
table_name
The name of the table that optionally contains the `value_key`; it is also used as the name of the table in the returned `SpatialData` object.
return_as_labels
If `True`, returns labels of shape `(y, x)` instead of an image of shape `(c, y, x)`
target_unit_to_pixels
The number of pixels per unit that the target image should have. It is mandatory to specify precisely one of
the following options: target_unit_to_pixels, target_width, target_height, target_depth.
Expand Down Expand Up @@ -378,6 +409,7 @@ def _(
target_width: float | None = None,
target_height: float | None = None,
target_depth: float | None = None,
**kwargs: Any,
) -> SpatialImage:
min_coordinate = _parse_list_into_array(min_coordinate)
max_coordinate = _parse_list_into_array(max_coordinate)
Expand Down Expand Up @@ -476,34 +508,9 @@ def _(


@rasterize.register(DaskDataFrame)
def _(
data: DaskDataFrame,
axes: tuple[str, ...],
min_coordinate: list[Number] | ArrayLike,
max_coordinate: list[Number] | ArrayLike,
target_coordinate_system: str,
target_unit_to_pixels: float | None = None,
target_width: float | None = None,
target_height: float | None = None,
target_depth: float | None = None,
) -> SpatialImage:
min_coordinate = _parse_list_into_array(min_coordinate)
max_coordinate = _parse_list_into_array(max_coordinate)
target_width, target_height, target_depth = _compute_target_dimensions(
spatial_axes=axes,
min_coordinate=min_coordinate,
max_coordinate=max_coordinate,
target_unit_to_pixels=target_unit_to_pixels,
target_width=target_width,
target_height=target_height,
target_depth=target_depth,
)
raise NotImplementedError()


@rasterize.register(GeoDataFrame)
def _(
data: GeoDataFrame,
data: DaskDataFrame | GeoDataFrame,
axes: tuple[str, ...],
min_coordinate: list[Number] | ArrayLike,
max_coordinate: list[Number] | ArrayLike,
Expand Down Expand Up @@ -547,7 +554,7 @@ def _(
), f"Column name {VALUES_COLUMN} is reserved for internal use. Please rename your column."

if value_key is not None:
data[VALUES_COLUMN] = get_values(value_key, element=data, sdata=values_sdata, table_name=table_name)
data[VALUES_COLUMN] = get_values(value_key, element=data, sdata=values_sdata, table_name=table_name).iloc[:, 0]
elif instance_key_as_default_value_key:
value_key = VALUES_COLUMN
data[VALUES_COLUMN] = data.index.astype("category")
Expand All @@ -566,9 +573,13 @@ def _(
agg_func = getattr(ds, agg_func)(column=value_key)

cnv = ds.Canvas(plot_height=plot_height, plot_width=plot_width, x_range=x_range, y_range=y_range)
agg = cnv.polygons(data, "geometry", agg=agg_func)

if VALUES_COLUMN in data:
if isinstance(data, GeoDataFrame):
agg = cnv.polygons(data, "geometry", agg=agg_func)
else:
agg = cnv.points(data, x="x", y="y", agg=agg_func)

if VALUES_COLUMN in data and isinstance(data, GeoDataFrame):
data.drop(columns=[VALUES_COLUMN], inplace=True)

scale = Scale([(y_range[1] - y_range[0]) / plot_height, (x_range[1] - x_range[0]) / plot_width], axes=("y", "x"))
Expand All @@ -593,7 +604,7 @@ def _(


def _default_agg_func(
data: GeoDataFrame, value_key: str | None, return_single_channel: bool
data: DaskDataFrame | GeoDataFrame, value_key: str | None, return_single_channel: bool
) -> ds.reductions.Reduction:
if value_key is None:
return ds.count()
Expand Down

0 comments on commit bdc15e0

Please sign in to comment.