can take the min/max in channels aggregation (#147)
quentinblampey committed Nov 4, 2024
1 parent b7b4b5f commit faeb203
Showing 5 changed files with 71 additions and 15 deletions.
9 changes: 9 additions & 0 deletions docs/api/aggregation.md
@@ -1,3 +1,12 @@
# Aggregation

!!! tips "Recommendation"
We recommend using the `sopa.aggregate` function below, which is a wrapper for all types of aggregation. Internally, it uses `aggregate_channels`, `count_transcripts`, and/or `aggregate_bins`, which are also documented below if needed.

::: sopa.aggregate

::: sopa.aggregation.aggregate_channels

::: sopa.aggregation.count_transcripts

::: sopa.aggregation.aggregate_bins
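
A hedged usage sketch of the new `mode` argument, based only on the signature added in this commit (`sdata` is assumed to be an existing `SpatialData` object with one image and one shapes element; anything else is an assumption):

```python
import sopa

# `sdata` is assumed to be a SpatialData object that already contains an image
# and cell boundaries (keys can be omitted when there is only one of each).
sdata = ...  # placeholder: load your own SpatialData object here

mean_intensities = sopa.aggregation.aggregate_channels(sdata, mode="average")
max_intensities = sopa.aggregation.aggregate_channels(sdata, mode="max")
min_intensities = sopa.aggregation.aggregate_channels(sdata, expand_radius_ratio=0.1, mode="min")

# Each call returns a numpy array of shape (n_cells, n_channels).
```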
2 changes: 1 addition & 1 deletion sopa/aggregation/__init__.py
@@ -1,5 +1,5 @@
from .bins import aggregate_bins
from .channels import average_channels
from .channels import average_channels, aggregate_channels
from .transcripts import count_transcripts
from .aggregation import aggregate, Aggregator
from .overlay import overlay_segmentation
6 changes: 4 additions & 2 deletions sopa/aggregation/aggregation.py
@@ -17,7 +17,9 @@
get_spatial_element,
get_spatial_image,
)
from . import aggregate_bins, average_channels, count_transcripts
from . import aggregate_bins
from . import aggregate_channels as _aggregate_channels
from . import count_transcripts

log = logging.getLogger(__name__)

@@ -157,7 +159,7 @@ def compute_table(
self.filter_cells(self.table.X.sum(axis=1) < min_transcripts)

if aggregate_channels:
mean_intensities = average_channels(
mean_intensities = _aggregate_channels(
self.sdata,
image_key=self.image_key,
shapes_key=self.shapes_key,
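
The alias import above is needed because `compute_table` also exposes a boolean flag named `aggregate_channels`, which would otherwise shadow the function inside the method body. A minimal, self-contained sketch of the pattern (stand-in names, not the sopa code):

```python
def aggregate_channels(data: list[int]) -> int:
    """Stand-in for sopa.aggregation.aggregate_channels."""
    return sum(data)

# Same trick as in the commit: keep the function reachable under a private alias.
_aggregate_channels = aggregate_channels

def compute_table(data: list[int], aggregate_channels: bool = True) -> int | None:
    # Here `aggregate_channels` is the boolean flag, not the function.
    if aggregate_channels:
        return _aggregate_channels(data)
    return None

print(compute_table([1, 2, 3]))  # 6
```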
47 changes: 40 additions & 7 deletions sopa/aggregation/channels.py
@@ -5,6 +5,7 @@
import dask
import geopandas as gpd
import numpy as np
import numpy.ma as ma
import shapely
from dask.diagnostics import ProgressBar
from shapely.geometry import Polygon, box
@@ -16,35 +17,47 @@

log = logging.getLogger(__name__)

AVAILABLE_MODES = ["average", "min", "max"]


def average_channels(
sdata: SpatialData, image_key: str = None, shapes_key: str = None, expand_radius_ratio: float = 0
) -> np.ndarray:
log.warning("average_channels is deprecated, use `aggregate_channels` instead")
return aggregate_channels(sdata, image_key, shapes_key, expand_radius_ratio, mode="average")


def aggregate_channels(
sdata: SpatialData,
image_key: str = None,
shapes_key: str = None,
expand_radius_ratio: float = 0,
mode: str = "average",
) -> np.ndarray:
"""Average channel intensities per cell.
"""Aggregate the channel intensities per cell (either `"average"`, or take the `"min"` / `"max"`).
Args:
sdata: A `SpatialData` object
image_key: Key of `sdata` containing the image. If only one `images` element, this does not have to be provided.
shapes_key: Key of `sdata` containing the cell boundaries. If only one `shapes` element, this does not have to be provided.
expand_radius_ratio: Cell polygons will be expanded by `expand_radius_ratio * mean_radius`. This helps better aggregate boundary stainings.
mode: Aggregation mode. One of `"average"`, `"min"`, `"max"`. By default, averages the intensity inside the cell mask.
Returns:
A numpy `ndarray` of shape `(n_cells, n_channels)`
"""
assert mode in AVAILABLE_MODES, f"Invalid {mode=}. Available modes are {AVAILABLE_MODES}"

image = get_spatial_image(sdata, image_key)

geo_df = get_boundaries(sdata, key=shapes_key)
geo_df = to_intrinsic(sdata, geo_df, image)
geo_df = expand_radius(geo_df, expand_radius_ratio)

log.info(f"Averaging channels intensity over {len(geo_df)} cells with expansion {expand_radius_ratio=}")
return _average_channels_aligned(image, geo_df)
return _aggregate_channels_aligned(image, geo_df, mode)


def _average_channels_aligned(image: DataArray, geo_df: gpd.GeoDataFrame | list[Polygon]) -> np.ndarray:
def _aggregate_channels_aligned(image: DataArray, geo_df: gpd.GeoDataFrame | list[Polygon], mode: str) -> np.ndarray:
"""Average channel intensities per cell. The image and cells have to be aligned, i.e. be on the same coordinate system.
Args:
@@ -54,11 +67,17 @@ def _average_channels_aligned(image: DataArray, geo_df: gpd.GeoDataFrame | list[
Returns:
A numpy `ndarray` of shape `(n_cells, n_channels)`
"""
log.info(f"Aggregating channels intensity over {len(geo_df)} cells with {mode=}")

cells = geo_df if isinstance(geo_df, list) else list(geo_df.geometry)
tree = shapely.STRtree(cells)

intensities = np.zeros((len(cells), len(image.coords["c"])))
n_channels = len(image.coords["c"])
areas = np.zeros(len(cells))
if mode == "min":
aggregation = np.full((len(cells), n_channels), fill_value=np.inf)
else:
aggregation = np.zeros((len(cells), n_channels))

chunk_sizes = image.data.chunks
offsets_y = np.cumsum(np.pad(chunk_sizes[1], (1, 0), "constant"))
@@ -86,9 +105,20 @@ def _average_chunk_inside_cells(chunk, iy, ix):

mask = rasterize(cell, sub_image.shape[1:], bounds)

intensities[index] += np.sum(sub_image * mask, axis=(1, 2))
areas[index] += np.sum(mask)

if mode == "min":
masked_image = ma.masked_array(sub_image, 1 - np.repeat(mask[None], n_channels, axis=0))
aggregation[index] = np.minimum(aggregation[index], masked_image.min(axis=(1, 2)))
elif mode in ["average", "max"]:
func = np.sum if mode == "average" else np.max
values = func(sub_image * mask, axis=(1, 2))

if mode == "average":
aggregation[index] += values
else:
aggregation[index] = np.maximum(aggregation[index], values)

with ProgressBar():
tasks = [
dask.delayed(_average_chunk_inside_cells)(chunk, iy, ix)
@@ -97,4 +127,7 @@ def _average_chunk_inside_cells(chunk, iy, ix):
]
dask.compute(tasks)

return intensities / areas[:, None].clip(1)
if mode == "average":
return aggregation / areas[:, None].clip(1)
else:
return aggregation
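
The three modes only differ in how masked pixel values are combined: a running sum plus pixel area for `"average"`, a running `np.maximum` for `"max"`, and a masked-array minimum for `"min"` (so pixels outside the cell are ignored rather than read as zeros). A minimal NumPy sketch of the same per-cell logic on a single image, for illustration only (not the sopa implementation):

```python
import numpy as np
import numpy.ma as ma

image = np.random.randint(1, 10, size=(3, 8, 8))  # (n_channels, y, x)
mask = np.zeros((8, 8), dtype=int)
mask[2:5, 3:6] = 1  # rasterized mask of one cell

# "average": masked sum divided by the number of pixels inside the mask
mean_intensity = (image * mask).sum(axis=(1, 2)) / mask.sum()

# "max": multiplying by the mask is harmless for non-negative intensities
max_intensity = (image * mask).max(axis=(1, 2))

# "min": use a masked array so pixels outside the cell are ignored, not counted as 0
masked = ma.masked_array(image, np.repeat((1 - mask)[None], image.shape[0], axis=0))
min_intensity = masked.min(axis=(1, 2))

print(mean_intensity, max_intensity, min_intensity)
```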
22 changes: 17 additions & 5 deletions tests/test_aggregation.py
@@ -6,14 +6,14 @@
import xarray as xr
from shapely.geometry import Polygon, box

from sopa.aggregation.channels import _average_channels_aligned
from sopa.aggregation.channels import _aggregate_channels_aligned
from sopa.aggregation.transcripts import _count_transcripts_aligned

dask.config.set({"dataframe.query-planning": False})
import dask.dataframe as dd # noqa


def test_average_channels_aligned():
def test_aggregate_channels_aligned():
image = np.random.randint(1, 10, size=(3, 8, 16))
arr = da.from_array(image, chunks=(1, 8, 8))
xarr = xr.DataArray(arr, dims=["c", "y", "x"])
@@ -24,11 +24,23 @@ def test_average_channels_aligned():
# One cell is on the first block, one is overlapping on both blocks, and one is on the last block
cells = [box(x, y, x + cell_size - 1, y + cell_size - 1) for x, y in cell_start]

means = _average_channels_aligned(xarr, cells)
mean_intensities = _aggregate_channels_aligned(xarr, cells, "average")
min_intensities = _aggregate_channels_aligned(xarr, cells, "min")
max_intensities = _aggregate_channels_aligned(xarr, cells, "max")

true_means = np.stack([image[:, y : y + cell_size, x : x + cell_size].mean(axis=(1, 2)) for x, y in cell_start])
true_mean_intensities = np.stack(
[image[:, y : y + cell_size, x : x + cell_size].mean(axis=(1, 2)) for x, y in cell_start]
)
true_min_intensities = np.stack(
[image[:, y : y + cell_size, x : x + cell_size].min(axis=(1, 2)) for x, y in cell_start]
)
true_max_intensities = np.stack(
[image[:, y : y + cell_size, x : x + cell_size].max(axis=(1, 2)) for x, y in cell_start]
)

assert (means == true_means).all()
assert (mean_intensities == true_mean_intensities).all()
assert (min_intensities == true_min_intensities).all()
assert (max_intensities == true_max_intensities).all()


def test_count_transcripts():
