From bdc15e00b44caacb6b33595e462c2e6ffd59a3dd Mon Sep 17 00:00:00 2001 From: Blampey Quentin Date: Fri, 24 May 2024 17:05:58 +0200 Subject: [PATCH] fix docstrings and support points --- src/spatialdata/_core/operations/rasterize.py | 71 +++++++++++-------- 1 file changed, 41 insertions(+), 30 deletions(-) diff --git a/src/spatialdata/_core/operations/rasterize.py b/src/spatialdata/_core/operations/rasterize.py index 5ae47406..0c4c1f1e 100644 --- a/src/spatialdata/_core/operations/rasterize.py +++ b/src/spatialdata/_core/operations/rasterize.py @@ -1,6 +1,7 @@ from __future__ import annotations from functools import singledispatch +from typing import Any from warnings import warn import dask_image.ndinterp @@ -155,6 +156,13 @@ def rasterize( min_coordinate: list[Number] | ArrayLike, max_coordinate: list[Number] | ArrayLike, target_coordinate_system: str, + value_key: str | None = None, + values_sdata: SpatialData | None = None, + agg_func: str | ds.reductions.Reduction | None = None, + instance_key_as_default_value_key: bool = False, + return_single_channel: bool = True, + table_name: str | None = None, + return_as_labels: bool = False, target_unit_to_pixels: float | None = None, target_width: float | None = None, target_height: float | None = None, @@ -176,6 +184,29 @@ def rasterize( target_coordinate_system The coordinate system in which we define the bounding box. This will also be the coordinate system of the produced rasterized image. + value_key + Name of the column containing the values to aggregate; can refer to either numerical or + categorical values. + + The key can be: + + - the name of a column(s) in the dataframe (Dask `DataFrame` for points or `GeoDataFrame` for shapes); + - the name of obs column(s) in the associated `AnnData` table (for shapes and labels); + - the name of a var(s), referring to the column(s) of the X matrix in the table (for shapes and labels). + + If nothing is passed here, it defaults to the equivalent of a column of ones. 
+ values_sdata + SpatialData object containing the values to aggregate if `value_key` refers to values from a table + agg_func + A reduction function from datashader (its name, or a Callable) + instance_key_as_default_value_key + If `True`, the geometry indices are used as a `value_key` + return_single_channel + If `False`, each category will be counted in a separate channel + table_name + The table optionally containing the value_key and the name of the table in the returned `SpatialData` object. + return_as_labels + If `True`, returns labels of shape `(y, x)` instead of an image of shape `(c, y, x)` target_unit_to_pixels The number of pixels per unit that the target image should have. It is mandatory to specify precisely one of the following options: target_unit_to_pixels, target_width, target_height, target_depth. @@ -378,6 +409,7 @@ def _( target_width: float | None = None, target_height: float | None = None, target_depth: float | None = None, + **kwargs: Any, ) -> SpatialImage: min_coordinate = _parse_list_into_array(min_coordinate) max_coordinate = _parse_list_into_array(max_coordinate) @@ -476,34 +508,9 @@ def _( @rasterize.register(DaskDataFrame) -def _( - data: DaskDataFrame, - axes: tuple[str, ...], - min_coordinate: list[Number] | ArrayLike, - max_coordinate: list[Number] | ArrayLike, - target_coordinate_system: str, - target_unit_to_pixels: float | None = None, - target_width: float | None = None, - target_height: float | None = None, - target_depth: float | None = None, -) -> SpatialImage: - min_coordinate = _parse_list_into_array(min_coordinate) - max_coordinate = _parse_list_into_array(max_coordinate) - target_width, target_height, target_depth = _compute_target_dimensions( - spatial_axes=axes, - min_coordinate=min_coordinate, - max_coordinate=max_coordinate, - target_unit_to_pixels=target_unit_to_pixels, - target_width=target_width, - target_height=target_height, - target_depth=target_depth, - ) - raise NotImplementedError() - - 
@rasterize.register(GeoDataFrame) def _( - data: GeoDataFrame, + data: DaskDataFrame | GeoDataFrame, axes: tuple[str, ...], min_coordinate: list[Number] | ArrayLike, max_coordinate: list[Number] | ArrayLike, @@ -547,7 +554,7 @@ def _( ), f"Column name {VALUES_COLUMN} is reserved for internal use. Please rename your column." if value_key is not None: - data[VALUES_COLUMN] = get_values(value_key, element=data, sdata=values_sdata, table_name=table_name) + data[VALUES_COLUMN] = get_values(value_key, element=data, sdata=values_sdata, table_name=table_name).iloc[:, 0] elif instance_key_as_default_value_key: value_key = VALUES_COLUMN data[VALUES_COLUMN] = data.index.astype("category") @@ -566,9 +573,13 @@ def _( agg_func = getattr(ds, agg_func)(column=value_key) cnv = ds.Canvas(plot_height=plot_height, plot_width=plot_width, x_range=x_range, y_range=y_range) - agg = cnv.polygons(data, "geometry", agg=agg_func) - if VALUES_COLUMN in data: + if isinstance(data, GeoDataFrame): + agg = cnv.polygons(data, "geometry", agg=agg_func) + else: + agg = cnv.points(data, x="x", y="y", agg=agg_func) + + if VALUES_COLUMN in data and isinstance(data, GeoDataFrame): data.drop(columns=[VALUES_COLUMN], inplace=True) scale = Scale([(y_range[1] - y_range[0]) / plot_height, (x_range[1] - x_range[0]) / plot_width], axes=("y", "x")) @@ -593,7 +604,7 @@ def _( def _default_agg_func( - data: GeoDataFrame, value_key: str | None, return_single_channel: bool + data: DaskDataFrame | GeoDataFrame, value_key: str | None, return_single_channel: bool ) -> ds.reductions.Reduction: if value_key is None: return ds.count()