From 33d83205872599856a1e8c75ef70e2e4c983afc4 Mon Sep 17 00:00:00 2001
From: Wouter-Michiel Vierdag
Date: Mon, 27 Nov 2023 10:24:57 +0100
Subject: [PATCH 01/27] initial tests multi_table design (#405)

* initial tests multi_table design
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* add mock new table
* create test class and cleanup
* additional cleanup
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* additional cleanup
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* add pseudo methods

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
 tests/conftest.py            |   7 ++
 tests/io/test_multi_table.py | 197 +++++++++++++++++++++++++++++++++++
 2 files changed, 204 insertions(+)
 create mode 100644 tests/io/test_multi_table.py

diff --git a/tests/conftest.py b/tests/conftest.py
index 045f88bf..490cd929 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -2,6 +2,8 @@
 # isort: off
 import os
+from typing import Any
+from collections.abc import Sequence

 os.environ["USE_PYGEOS"] = "0"
 # isort:on
@@ -288,6 +290,11 @@ def _get_table(
     return TableModel.parse(adata=adata, region=region, region_key=region_key, instance_key=instance_key)


+def _get_new_table(spatial_element: None | str | Sequence[str] = None, instance_id: None | Sequence[Any] = None) -> AnnData:
+    # rng.random() takes the shape as a single tuple; the None defaults allow calling _get_new_table() with no arguments
+    adata = AnnData(np.random.default_rng().random((10, 20000)))
+    return TableModel.parse(adata=adata, spatial_element=spatial_element, instance_id=instance_id)
+
+
 @pytest.fixture()
 def labels_blobs() -> ArrayLike:
     """Create a 2D labels."""
diff --git a/tests/io/test_multi_table.py b/tests/io/test_multi_table.py
new file mode 100644
index 00000000..f7d4671c
--- /dev/null
+++ b/tests/io/test_multi_table.py
@@ -0,0 +1,197 @@
+from pathlib import Path
+
+import anndata as ad
+import numpy as np
+import pandas as pd
+from anndata import AnnData
+from spatialdata import SpatialData
+
+from tests.conftest import _get_new_table, _get_shapes
+
+# notes on paths: https://github.com/orgs/scverse/projects/17/views/1?pane=issue&itemId=44066734
+# notes for the people (to prettify): https://hackmd.io/wd7K4Eg1SlykKVN-nOP44w
+
+# shapes
+test_shapes = _get_shapes()
+instance_id = np.array([str(i) for i in range(5)])
+table = _get_new_table(spatial_element="test_shapes", instance_id=instance_id)
+adata0 = _get_new_table()
+adata1 = _get_new_table()
+
+
+# shuffle the indices of the dataframe (a pandas Index is immutable, so assign a permutation instead of shuffling in place)
+test_shapes["poly"].index = np.random.default_rng().permutation(test_shapes["poly"].index)
+
+# tables is a dict
+SpatialData.tables
+
+# def get_table_keys(sdata: SpatialData) -> tuple[list[str], str, str]:
+#     d = sdata.table.uns[sd.models.TableModel.ATTRS_KEY]
+#     return d['region'], d['region_key'], d['instance_key']
+#
+# @staticmethod
+# def SpatialData.get_key_column(table: AnnData, key_column: str) -> ...:
+#     region, region_key, instance_key = sd.models.get_table_keys()
+#     if key_column == 'region_key':
+#         return table.obs[region_key]
+#     else: ....
+#
+# @staticmethod
+# def SpatialData.get_region_key_column(table: AnnData | str):
+#     return get_key_column(...)
+
+# @staticmethod
+# def SpatialData.get_instance_key_column(table: AnnData | str):
+#     return get_key_column(...)
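+
+# A possible concrete implementation of the get_key_column(...) sketch above (illustrative only;
+# it assumes get_table_keys(table) returns (region, region_key, instance_key) as in the design notes):
+# def get_key_column(table: AnnData, key_column: str) -> pd.Series:
+#     region, region_key, instance_key = sd.models.get_table_keys(table)
+#     if key_column == "region_key":
+#         return table.obs[region_key]
+#     if key_column == "instance_key":
+#         return table.obs[instance_key]
+#     raise ValueError(f"Unknown key column: {key_column}")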
+
+# we need also the two set_...() functions
+
+
+def get_annotation_target_of_table(table: AnnData) -> pd.Series:
+    return SpatialData.get_region_key_column(table)
+
+
+def set_annotation_target_of_table(table: AnnData, spatial_element: str | pd.Series) -> None:
+    # symmetric to the getter above: the annotation target is the region/spatial element, not the instance key
+    SpatialData.set_region_key_column(table, spatial_element)
+
+
+class TestMultiTable:
+    def test_set_get_tables_from_spatialdata(self, sdata):  # sdata is from conftest
+        sdata["my_new_table0"] = adata0
+        sdata["my_new_table1"] = adata1
+
+    def test_old_accessor_deprecation(self, sdata):
+        # assume no table is present
+        # this prints a deprecation warning
+        sdata.table = adata0  # this gets placed in sdata['table']
+        # this prints a deprecation warning
+        _ = sdata.table  # this returns sdata['table']
+        # this prints a deprecation warning
+        del sdata.table
+
+        sdata["my_new_table0"] = adata0
+        # will fail, because there is no sdata['table'], even if another table is present
+        _ = sdata.table
+
+    def test_single_table(self, tmp_path: str):
+        # shared table
+        tmpdir = Path(tmp_path) / "tmp.zarr"
+
+        test_sdata = SpatialData(
+            shapes={
+                "test_shapes": test_shapes["poly"],
+            },
+            tables={"shape_annotate": table},
+        )
+        test_sdata.write(tmpdir)
+        sdata = SpatialData.read(tmpdir)
+        assert sdata.get("shape_annotate") is not None
+        assert isinstance(sdata["shape_annotate"], AnnData)
+        from anndata.tests.helpers import assert_equal
+
+        # assert_equal raises on mismatch and returns None, so it is not wrapped in an assert
+        assert_equal(test_sdata["shape_annotate"], sdata["shape_annotate"])
+
+        # note (to keep in the code): these tests should simulate the interactions from the users; if the syntax
+        # is too verbose/complex, we need to adjust the internals to make it smoother.
+        # Here we are matching the table to the shapes and vice versa (= subset + reordering);
+        # there is already a function to do one of these two join operations, which is match_table_to_element().
+        # Alternatively, we can have helper functions (join, and simpler ones: "match_table_to_element()",
+        # "match_element_to_table()", "match_annotations_order(...)", "match_reference_element_order??(...)").
+        # # use case example 1
+        # # sorting the shapes to match the order of the table
+        # sdata["visium0"][SpatialData.get_instance_key_column(sdata.table['visium0'])]
+        # assert ...
+        # # use case example 2
+        # # sorting the table to match the order of the shapes
+        # sdata.table.obs.set_index(keys=["__instance_id__"])
+        # sdata.table.obs[sdata["visium0"]]
+        # assert ...
+
+    def test_paired_elements_tables(self, tmp_path: str):
+        pass
+
+    def test_elements_transfer_annotation(self, tmp_path: str):
+        test_sdata = SpatialData(
+            shapes={"test_shapes": test_shapes["poly"], "test_multipoly": test_shapes["multipoly"]},
+            tables={"segmentation": table},
+        )
+        set_annotation_target_of_table(test_sdata["segmentation"], "test_multipoly")
+        assert get_annotation_target_of_table(test_sdata["segmentation"]) == "test_multipoly"
+
+    def test_single_table_multiple_elements(self, tmp_path: str):
+        tmpdir = Path(tmp_path) / "tmp.zarr"
+
+        test_sdata = SpatialData(
+            shapes={
+                "test_shapes": test_shapes["poly"],
+                "test_multipoly": test_shapes["multipoly"],
+            },
+            tables={"segmentation": table},
+        )
+        test_sdata.write(tmpdir)
+        # sdata = SpatialData.read(tmpdir)
+
+        # # use case example 1 (see the runnable sketch right below)
+        # # sorting the shapes visium0 to match the order of the table
+        # sdata["visium0"][sdata.table.obs["__instance_id__"][sdata.table.obs["__spatial_element__"] == "visium0"]]
+        # assert ...
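+        # A concrete sketch of use case example 1 (illustrative, not settled API):
+        # "__instance_id__" / "__spatial_element__" are placeholder column names and "visium0" a
+        # placeholder element name; the table is assumed to be stored as sdata["table"].
+        # obs = sdata["table"].obs
+        # ids = obs.loc[obs["__spatial_element__"] == "visium0", "__instance_id__"]
+        # sorted_shapes = sdata["visium0"].loc[ids]  # subset + reorder the GeoDataFrame rows
+        # assert list(sorted_shapes.index) == list(ids)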
+        # # use case example 2
+        # # subsetting and sorting the table to match the order of the shapes visium0
+        # sub_table = sdata.table[sdata.table.obs["__spatial_element__"] == "visium0"]
+        # sub_table.set_index(keys=["__instance_id__"])
+        # sub_table.obs[sdata["visium0"]]
+        # assert ...
+
+    def test_concatenate_tables(self):
+        table_two = _get_new_table(spatial_element="test_multipoly", instance_id=np.array([str(i) for i in range(2)]))
+        concatenated_table = ad.concat([table, table_two])
+        test_sdata = SpatialData(
+            shapes={
+                "test_shapes": test_shapes["poly"],
+                "test_multipoly": test_shapes["multipoly"],
+            },
+            tables={"segmentation": concatenated_table},
+        )
+        # use case tests as above (we test only visium0)
+
+    def test_multiple_table_without_element(self):
+        table = _get_new_table()
+        table_two = _get_new_table()
+
+        test_sdata = SpatialData(
+            tables={"table": table, "table_two": table_two},
+        )
+
+    def test_multiple_tables_same_element(self, tmp_path: str):
+        tmpdir = Path(tmp_path) / "tmp.zarr"
+        table_two = _get_new_table(spatial_element="test_shapes", instance_id=instance_id)
+
+        test_sdata = SpatialData(
+            shapes={
+                "test_shapes": test_shapes["poly"],
+            },
+            tables={"segmentation": table, "segmentation_two": table_two},
+        )
+        test_sdata.write(tmpdir)
+
+
+# # these use cases could be the preferred ones for the users; we need to choose one or two preferred ones
+# # (either this, or a helper function, ...)
+# # use cases
+# # use case example 1
+# # sorting the shapes to match the order of the table
+# sdata["visium0"][sdata.table.obs["__instance_id__"]]
+# assert ...
+# # use case example 2
+# # sorting the table to match the order of the shapes
+# sdata.table.obs.set_index(keys=["__instance_id__"])
+# sdata.table.obs[sdata["visium0"]]
+# assert ...
+#
+# def test_partial_match():
+#     # the function spatialdata._core.query.relational_query.match_table_to_element (no 's') needs to be
+#     # modified (it will be simpler); we also need a function match_element_to_table (a sketch follows).
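+#     # A hypothetical sketch of match_element_to_table (names and logic are illustrative only;
+#     # get_table_keys is the helper added by this PR series):
+#     # def match_element_to_table(sdata, element_name, table_name):
+#     #     table = sdata.tables[table_name]
+#     #     _, region_key, instance_key = get_table_keys(table)
+#     #     mask = table.obs[region_key] == element_name
+#     #     ids = table.obs.loc[mask, instance_key]
+#     #     return sdata[element_name].loc[ids], table[mask]  # element subset + matching table rows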
+#     # Maybe we can have just one function doing both things, called match_table_and_elements.
+#     # Test that tables and elements do not need to have the same indices.
+#     pass
+#     # the test would check that we can call SpatialData() on such combinations of mismatching elements and that the
+#     # match_table_to_element-like functions return the correct subset of the data

From 75d66f12477558a87c816e5feee7c9a3d3157d56 Mon Sep 17 00:00:00 2001
From: Wouter-Michiel Vierdag
Date: Wed, 17 Jan 2024 19:11:49 +0100
Subject: [PATCH 02/27] Multi table (#410)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* [pre-commit.ci] pre-commit autoupdate (#394)

  updates:
  - [github.com/psf/black: 23.10.1 → 23.11.0](https://github.com/psf/black/compare/23.10.1...23.11.0)
  - [github.com/pre-commit/mirrors-prettier: v3.0.3 → v3.1.0](https://github.com/pre-commit/mirrors-prettier/compare/v3.0.3...v3.1.0)
  - [github.com/pre-commit/mirrors-mypy: v1.6.1 → v1.7.0](https://github.com/pre-commit/mirrors-mypy/compare/v1.6.1...v1.7.0)
  - [github.com/astral-sh/ruff-pre-commit: v0.1.3 → v0.1.6](https://github.com/astral-sh/ruff-pre-commit/compare/v0.1.3...v0.1.6)

* ficx pre-precommit

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: giovp

* initial tests multi_table design
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* add mock new table
* create test class and cleanup
* additional cleanup
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* additional cleanup
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* add pseudo methods
* Change table type in init
* make tables plural and add to validation in __init__
* revert to old public accessor
* Validate each table in dictionary
* iterate dict values
* add comment
* adjust table getter
* Add tables getter
* Fix missing parenthesis
* change to warnings.warn DeprecationWarning
* allow for backward compatibility in init
* [pre-commit.ci] pre-commit autoupdate (#408) (Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>)
* fix dict subscriptable
* fix string representation of sdata
* add deprecation decorator for future
* Allow for tables not annotating elements
* switch to using tables with deprecation
* fix string representation
* write tables element group
* adjust io to multi_table
* Alter io to give None as default value for spatialdata attrs keys
* add tables setter
* raise keyerror table getter
* remove commented tables setter
* raise keyerror in table deleter
* add deprecation warning
* fix tests
* add DeprecationWarning
* comment test
* change setter into method
* circumvent mappingproxy set issue
* adjust set get test
* add get table keys
* add column getters
* add change set target table
* Give default table name
* fix spatialdata without table
* add int32 because of windows and add docstring
* fix filtering by coordinate system
* Change to Path to not be linux / mac specific
* Change to Path to not be linux / mac specific
* table should annotate existing element
* return table with AnnData having 0 rows
* Adjust for windows
* adjust for accessing table elements
* fix change annotation target
* fix set annotation target
* fix/add tests
* fix init from elements
* fix init from elements tests
* add validation check
* add table validation
SpatialData.__init * fix ruff * only concatenate if annotating * change into warning because of filtering * fix last tests * adjust to tables * use tables parameter * fix some mypy * some mypy fixes * some more mypy * fix another mypy * circumvent typing error on py3.9 * mypy yet again * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix pre_commit * down to 12 mypy errors * down to 1mypy error * fixed mypy errors * fix set_table_annotation * added docstring * refactor data loader (#299) Co-authored-by: LucaMarconato <2664412+LucaMarconato@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Luca Marconato * add documentation * add documentation * minor adjustment docstring * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Added / adjusted docstrings * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix mypy after merge * refactor function name This is to avoid confusion. Many not easily resolved errors are created if we let this function generate table values. This makes it clear that only spatial element values are generated and not tables. This in opposite to gen_elements which does return tables as well. * [pre-commit.ci] pre-commit autoupdate (#411) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * small fixes * added gen_elements docstrings * tiny comments * fix ruff pre-commit * removed types from docstring * refactor of set_table_annotation_target * add quotes * fix (?) * refactor error messages * fix incremental update (#329) Co-authored-by: Wouter-Michiel Vierdag Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * add concatenate argument * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add util functions to init * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add util functions to init * add tables class * add table class * add deprecation back * rename function in tests * rename function in tests * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix precommit * update precommit and remove add_table, store_table and general fixes * adjust tables init to incremental update pr * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix deletion of deprecated table * revert filter change * add public generators * adjust to public generators * Find element uses public generator * add validation in sdata for tables * add deprecation version number * fix mypy errors * Fix backing when deleting table * Fix mypy * cleanup * chance target_element_name to region * refactor test * adjust concatenate regarding not concatenating tables * add utility function * concatenate if in multiple sdata objects * minor docstring refactor * fix import * concatenate with tables * cleanup * fix test * [pre-commit.ci] pre-commit autoupdate (#415) updates: - [github.com/psf/black: 23.11.0 → 23.12.1](https://github.com/psf/black/compare/23.11.0...23.12.1) - [github.com/pre-commit/mirrors-prettier: v4.0.0-alpha.4 → v4.0.0-alpha.8](https://github.com/pre-commit/mirrors-prettier/compare/v4.0.0-alpha.4...v4.0.0-alpha.8) - [github.com/pre-commit/mirrors-mypy: v1.7.1 → 
v1.8.0](https://github.com/pre-commit/mirrors-mypy/compare/v1.7.1...v1.8.0) - [github.com/astral-sh/ruff-pre-commit: v0.1.7 → v0.1.9](https://github.com/astral-sh/ruff-pre-commit/compare/v0.1.7...v0.1.9) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * fixed typings; made pytest raises explicit * minor fixes human readable strings * fixed tests * fix * Add changes to changelog * make private * Remove commented code * add cache to ignore * updated changelog with giovp old pr * refactor into private function * Fix docstring * Fix import * specify key reuse in docstring * add orphan_table argument * Change docstring * remove todo * add example * change concatenate logic * updated changelog * Allow force-overwriting existing files (non-backing) (#344) * Add test for writing unbacked data over existing files * Protect overwriting existing file only if it is backing file * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Simplify assertion, remove try/except * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fixed pre-commit * added get_dask_backing_files(); improved sdata.write with overwrite=True * fix docs * changed version in changelog * fix exception string --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Luca Marconato Co-authored-by: LucaMarconato <2664412+LucaMarconato@users.noreply.github.com> * Added error message for removed add_elements functions (#420) * added error message for removed add_elements functions * moved _error_message_add_element() to _utils * added validate and set region key * fix docstring * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Added public functions Spatialdata * fix tests * add docstrings * Added subset API; fix behavior with zero-len table (#426) * added subset API, returning None instead of empty table for APIs with filter_table=True * fix 3.9 * [pre-commit.ci] pre-commit autoupdate (#424) updates: - [github.com/astral-sh/ruff-pre-commit: v0.1.9 → v0.1.11](https://github.com/astral-sh/ruff-pre-commit/compare/v0.1.9...v0.1.11) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * remove docstring typehint * Warn user over overwrite in docstring * Fix query of 2D/3D data with 2D/3D bounding box (#409) * wip * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * wip * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix 2d/3d bb for raster data * support for 2d/3d bb for 2d/3d points * better tests * applied suggestions from giovp * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * added comments --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [pre-commit.ci] pre-commit autoupdate (#430) updates: - [github.com/astral-sh/ruff-pre-commit: v0.1.11 → v0.1.13](https://github.com/astral-sh/ruff-pre-commit/compare/v0.1.11...v0.1.13) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * updated changelog * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * 
refactor subset * remove todo * Made _locate_spatial_element public, renamed to locate_element() (#427) * made _locate_spatial_element public, renamed to locate_element() * returning path instead of tuple in locate_element() * updated changelog * locate_elements() now returns a list * fix test * Update test_and_deploy.yaml (#434) Triggering the tests for pull requests to any branch. * change docstring * fix query test * add todo * refactor filter_by_coordinate_system * test filter with keep table * adjust docstring * adjust docstring --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: giovp Co-authored-by: Giovanni Palla <25887487+giovp@users.noreply.github.com> Co-authored-by: LucaMarconato <2664412+LucaMarconato@users.noreply.github.com> Co-authored-by: Luca Marconato Co-authored-by: aeisenbarth <54448967+aeisenbarth@users.noreply.github.com> --- .github/workflows/test_and_deploy.yaml | 4 +- .gitignore | 3 + .pre-commit-config.yaml | 8 +- .readthedocs.yaml | 2 +- CHANGELOG.md | 50 +- docs/_templates/autosummary/class.rst | 8 - docs/api.md | 16 +- pyproject.toml | 1 - src/spatialdata/__init__.py | 3 +- src/spatialdata/_core/_elements.py | 116 ++ src/spatialdata/_core/_utils.py | 22 + src/spatialdata/_core/concatenate.py | 51 +- src/spatialdata/_core/data_extent.py | 64 +- src/spatialdata/_core/operations/aggregate.py | 9 +- src/spatialdata/_core/operations/rasterize.py | 4 +- src/spatialdata/_core/query/_utils.py | 41 + .../_core/query/relational_query.py | 16 +- src/spatialdata/_core/query/spatial_query.py | 159 +- src/spatialdata/_core/spatialdata.py | 1423 +++++++++-------- src/spatialdata/_io/__init__.py | 2 + src/spatialdata/_io/_utils.py | 129 +- src/spatialdata/_io/io_raster.py | 8 +- src/spatialdata/_io/io_table.py | 12 +- src/spatialdata/_io/io_zarr.py | 59 +- src/spatialdata/_types.py | 8 +- src/spatialdata/_utils.py | 88 +- src/spatialdata/dataloader/__init__.py | 8 +- src/spatialdata/dataloader/datasets.py | 531 ++++-- src/spatialdata/datasets.py | 6 +- src/spatialdata/models/__init__.py | 4 + src/spatialdata/models/_utils.py | 1 - src/spatialdata/models/models.py | 229 ++- src/spatialdata/transformations/operations.py | 12 +- tests/conftest.py | 23 +- tests/core/operations/test_aggregations.py | 3 +- .../operations/test_spatialdata_operations.py | 174 +- tests/core/operations/test_transform.py | 5 +- tests/core/query/test_spatial_query.py | 228 ++- tests/dataloader/test_datasets.py | 151 +- tests/dataloader/test_transforms.py | 0 tests/io/test_multi_table.py | 303 ++-- tests/io/test_readwrite.py | 104 +- tests/io/test_utils.py | 47 +- tests/models/test_models.py | 5 +- 44 files changed, 2675 insertions(+), 1465 deletions(-) create mode 100644 src/spatialdata/_core/_elements.py create mode 100644 src/spatialdata/_core/_utils.py delete mode 100644 tests/dataloader/test_transforms.py diff --git a/.github/workflows/test_and_deploy.yaml b/.github/workflows/test_and_deploy.yaml index 40882d26..d11ca125 100644 --- a/.github/workflows/test_and_deploy.yaml +++ b/.github/workflows/test_and_deploy.yaml @@ -6,7 +6,7 @@ on: tags: - "v*" # Push events to matching v*, i.e. 
v1.0, v20.15.10 pull_request: - branches: [main] + branches: "*" jobs: test: @@ -63,7 +63,7 @@ jobs: PLATFORM: ${{ matrix.os }} DISPLAY: :42 run: | - pytest -v --cov --color=yes --cov-report=xml + pytest --cov --color=yes --cov-report=xml - name: Upload coverage to Codecov uses: codecov/codecov-action@v3.1.1 with: diff --git a/.gitignore b/.gitignore index c0a79ef9..666248e4 100644 --- a/.gitignore +++ b/.gitignore @@ -45,3 +45,6 @@ spatialdata-sandbox # version file _version.py + +# other +node_modules/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 89e11456..9af87581 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -9,11 +9,11 @@ ci: skip: [] repos: - repo: https://github.com/psf/black - rev: 23.10.1 + rev: 23.12.1 hooks: - id: black - repo: https://github.com/pre-commit/mirrors-prettier - rev: v3.0.3 + rev: v4.0.0-alpha.8 hooks: - id: prettier - repo: https://github.com/asottile/blacken-docs @@ -21,13 +21,13 @@ repos: hooks: - id: blacken-docs - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.6.1 + rev: v1.8.0 hooks: - id: mypy additional_dependencies: [numpy, types-requests] exclude: tests/|docs/ - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.1.3 + rev: v0.1.13 hooks: - id: ruff args: [--fix, --exit-non-zero-on-fix] diff --git a/.readthedocs.yaml b/.readthedocs.yaml index 690bf115..b59dfb7b 100644 --- a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -6,7 +6,7 @@ build: python: "3.10" sphinx: configuration: docs/conf.py - fail_on_warning: false + fail_on_warning: true python: install: - method: pip diff --git a/CHANGELOG.md b/CHANGELOG.md index f2fe3358..d913b652 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,10 +8,58 @@ and this project adheres to [Semantic Versioning][]. [keep a changelog]: https://keepachangelog.com/en/1.0.0/ [semantic versioning]: https://semver.org/spec/v2.0.0.html -## [0.0.15] - tbd +## [0.1.0] - tbd ### Added +#### Major + +- Implemented support in SpatialData for storing multiple tables. These tables + can annotate a SpatialElement but not necessarily so. +- Increased in-memory vs on-disk control: changes performed in-memory (e.g. adding a new image) are not automatically performed on-disk. + +#### Minor + +- Added public helper function get_table_keys in spatialdata.models to retrieve annotation information of a given table. +- Added public helper function check_target_region_column_symmetry in spatialdata.models to check whether annotation + metadata in table.uns['spatialdata_attrs'] corresponds with respective columns in table.obs. +- Added function validate_table_in_spatialdata in SpatialData to validate the annotation target of a table being + present in the SpatialData object. +- Added function get_annotated_regions in SpatialData to get the regions annotated by a given table. +- Added function get_region_key_column in SpatialData to get the region_key column in table.obs. +- Added function get_instance_key_column in SpatialData to get the instance_key column in table.obs. +- Added function set_table_annotates_spatialelement in SpatialData to either set or change the annotation metadata of + a table in a given SpatialData object. +- Added tables property in SpatialData. +- Added tables setter in SpatialData. +- Added gen_spatial_elements generator in SpatialData to generate the SpatialElements in a given SpatialData object. +- Added gen_elements generator in SpatialData to generate elements of a SpatialData object including tables. 
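+
+A minimal sketch of the multi-table workflow described above (illustrative only: `sdata` is any
+`SpatialData` object, `my_table` an `AnnData` following `TableModel`, and the element/table names are placeholders):
+
+```python
+from spatialdata.models import get_table_keys
+
+sdata.tables["second_table"] = my_table  # tables is now a dict-like of named AnnData objects
+sdata.set_table_annotates_spatialelement("second_table", region="my_shapes")
+region, region_key, instance_key = get_table_keys(sdata.tables["second_table"])
+```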
+ +### Changed + +#### Minor + +- Changed the string representation of SpatialData to reflect the changes in regard to multiple tables. + +## [0.0.x] - tbd + +### Minor + +- improved usability and robustness of sdata.write() when overwrite=True @aeisenbarth + +### Added + +- added SpatialData.subset() API +- added SpatialData.locate_element() API + +### Fixed + +- generalized queries to any combination of 2D/3D data and 2D/3D query region #409 + +#### Minor + +- refactored data loader for deep learning + ## [0.0.14] - 2023-10-11 ### Added diff --git a/docs/_templates/autosummary/class.rst b/docs/_templates/autosummary/class.rst index d4668a41..e4665dfc 100644 --- a/docs/_templates/autosummary/class.rst +++ b/docs/_templates/autosummary/class.rst @@ -12,11 +12,8 @@ Attributes table ~~~~~~~~~~~~~~~~~~ .. autosummary:: - {% for item in attributes %} - ~{{ fullname }}.{{ item }} - {%- endfor %} {% endif %} {% endblock %} @@ -27,13 +24,10 @@ Methods table ~~~~~~~~~~~~~ .. autosummary:: - {% for item in methods %} - {%- if item != '__init__' %} ~{{ fullname }}.{{ item }} {%- endif -%} - {%- endfor %} {% endif %} {% endblock %} @@ -46,7 +40,6 @@ Attributes {% for item in attributes %} .. autoattribute:: {{ [objname, item] | join(".") }} - {%- endfor %} {% endif %} @@ -61,7 +54,6 @@ Methods {%- if item != '__init__' %} .. automethod:: {{ [objname, item] | join(".") }} - {%- endif -%} {%- endfor %} diff --git a/docs/api.md b/docs/api.md index 696c92e0..93509ffd 100644 --- a/docs/api.md +++ b/docs/api.md @@ -29,12 +29,12 @@ Operations on `SpatialData` objects. get_extent match_table_to_element concatenate - rasterize transform + rasterize aggregate ``` -### Utilities +### Operations Utilities ```{eval-rst} .. autosummary:: @@ -49,6 +49,7 @@ The elements (building-blocks) that consitute `SpatialData`. ```{eval-rst} .. currentmodule:: spatialdata.models + .. autosummary:: :toctree: generated @@ -61,9 +62,11 @@ The elements (building-blocks) that consitute `SpatialData`. TableModel ``` -### Utilities +### Models Utilities ```{eval-rst} +.. currentmodule:: spatialdata.models + .. autosummary:: :toctree: generated @@ -94,9 +97,11 @@ The transformations that can be defined between elements and coordinate systems Sequence ``` -### Utilities +### Transformations Utilities ```{eval-rst} +.. currentmodule:: spatialdata.transformations + .. autosummary:: :toctree: generated @@ -119,7 +124,7 @@ The transformations that can be defined between elements and coordinate systems ImageTilesDataset ``` -## Input/output +## Input/Output ```{eval-rst} .. 
currentmodule:: spatialdata @@ -129,4 +134,5 @@ The transformations that can be defined between elements and coordinate systems read_zarr save_transformations + get_dask_backing_files ``` diff --git a/pyproject.toml b/pyproject.toml index acd3e191..0904668a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,7 +50,6 @@ dev = [ docs = [ "sphinx>=4.5", "sphinx-book-theme>=1.0.0", - "sphinx_rtd_theme", "myst-nb", "sphinxcontrib-bibtex>=1.0.0", "sphinx-autodoc-typehints", diff --git a/src/spatialdata/__init__.py b/src/spatialdata/__init__.py index 0541c491..e09f42c0 100644 --- a/src/spatialdata/__init__.py +++ b/src/spatialdata/__init__.py @@ -28,6 +28,7 @@ "read_zarr", "unpad_raster", "save_transformations", + "get_dask_backing_files", ] from spatialdata import dataloader, models, transformations @@ -40,6 +41,6 @@ from spatialdata._core.query.relational_query import get_values, match_table_to_element from spatialdata._core.query.spatial_query import bounding_box_query, polygon_query from spatialdata._core.spatialdata import SpatialData -from spatialdata._io._utils import save_transformations +from spatialdata._io._utils import get_dask_backing_files, save_transformations from spatialdata._io.io_zarr import read_zarr from spatialdata._utils import unpad_raster diff --git a/src/spatialdata/_core/_elements.py b/src/spatialdata/_core/_elements.py new file mode 100644 index 00000000..023128ae --- /dev/null +++ b/src/spatialdata/_core/_elements.py @@ -0,0 +1,116 @@ +"""SpatialData elements.""" +from __future__ import annotations + +from collections import UserDict +from collections.abc import Iterable +from typing import Any +from warnings import warn + +from anndata import AnnData +from dask.dataframe.core import DataFrame as DaskDataFrame +from datatree import DataTree +from geopandas import GeoDataFrame + +from spatialdata._types import Raster_T +from spatialdata._utils import multiscale_spatial_image_from_data_tree +from spatialdata.models import ( + Image2DModel, + Image3DModel, + Labels2DModel, + Labels3DModel, + PointsModel, + ShapesModel, + TableModel, + get_axes_names, + get_model, +) + + +class Elements(UserDict[str, Any]): + def __init__(self, shared_keys: set[str | None]) -> None: + self._shared_keys = shared_keys + super().__init__() + + @staticmethod + def _check_key(key: str, element_keys: Iterable[str], shared_keys: set[str | None]) -> None: + if key in element_keys: + warn(f"Key `{key}` already exists. 
Overwriting it.", UserWarning, stacklevel=2)
+        elif key in shared_keys:
+            raise KeyError(f"Key `{key}` already exists.")
+
+    def __setitem__(self, key: str, value: Any) -> None:
+        self._shared_keys.add(key)
+        super().__setitem__(key, value)
+
+    def __delitem__(self, key: str) -> None:
+        self._shared_keys.remove(key)
+        super().__delitem__(key)
+
+
+class Images(Elements):
+    def __setitem__(self, key: str, value: Raster_T) -> None:
+        self._check_key(key, self.keys(), self._shared_keys)
+        if isinstance(value, DataTree):
+            value = multiscale_spatial_image_from_data_tree(value)
+        schema = get_model(value)
+        if schema not in (Image2DModel, Image3DModel):
+            raise TypeError(f"Unknown element type with schema: {schema!r}.")
+        ndim = len(get_axes_names(value))
+        if ndim == 3:
+            Image2DModel().validate(value)
+            super().__setitem__(key, value)
+        elif ndim == 4:
+            Image3DModel().validate(value)
+            super().__setitem__(key, value)
+        else:
+            # the exception must be raised, not just instantiated
+            raise NotImplementedError("TODO: implement for ndim > 4.")
+
+
+class Labels(Elements):
+    def __setitem__(self, key: str, value: Raster_T) -> None:
+        self._check_key(key, self.keys(), self._shared_keys)
+        if isinstance(value, DataTree):
+            value = multiscale_spatial_image_from_data_tree(value)
+        schema = get_model(value)
+        if schema not in (Labels2DModel, Labels3DModel):
+            raise TypeError(f"Unknown element type with schema: {schema!r}.")
+        ndim = len(get_axes_names(value))
+        if ndim == 2:
+            Labels2DModel().validate(value)
+            super().__setitem__(key, value)
+        elif ndim == 3:
+            Labels3DModel().validate(value)
+            super().__setitem__(key, value)
+        else:
+            # the exception must be raised, not just instantiated
+            raise NotImplementedError("TODO: implement for ndim > 3.")
+
+
+class Shapes(Elements):
+    def __setitem__(self, key: str, value: GeoDataFrame) -> None:
+        self._check_key(key, self.keys(), self._shared_keys)
+        schema = get_model(value)
+        if schema != ShapesModel:
+            raise TypeError(f"Unknown element type with schema: {schema!r}.")
+        ShapesModel().validate(value)
+        super().__setitem__(key, value)
+
+
+class Points(Elements):
+    def __setitem__(self, key: str, value: DaskDataFrame) -> None:
+        self._check_key(key, self.keys(), self._shared_keys)
+        schema = get_model(value)
+        if schema != PointsModel:
+            raise TypeError(f"Unknown element type with schema: {schema!r}.")
+        PointsModel().validate(value)
+        super().__setitem__(key, value)
+
+
+class Tables(Elements):
+    def __setitem__(self, key: str, value: AnnData) -> None:
+        self._check_key(key, self.keys(), self._shared_keys)
+        schema = get_model(value)
+        if schema != TableModel:
+            raise TypeError(f"Unknown element type with schema: {schema!r}.")
+        TableModel().validate(value)
+        super().__setitem__(key, value)
diff --git a/src/spatialdata/_core/_utils.py b/src/spatialdata/_core/_utils.py
new file mode 100644
index 00000000..1c22c802
--- /dev/null
+++ b/src/spatialdata/_core/_utils.py
@@ -0,0 +1,22 @@
+from spatialdata._core.spatialdata import SpatialData
+
+
+def _find_common_table_keys(sdatas: list[SpatialData]) -> set[str]:
+    """
+    Find table keys present in more than one SpatialData object.
+
+    Parameters
+    ----------
+    sdatas
+        A list of SpatialData objects.
+
+    Returns
+    -------
+    A set of common keys that are present in the tables of more than one SpatialData object.
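+
+    Examples
+    --------
+    Illustrative only (``table_a``/``table_b`` are any parsed tables):
+
+    >>> sdata0 = SpatialData(tables={"a": table_a, "b": table_b})
+    >>> sdata1 = SpatialData(tables={"b": table_b})
+    >>> _find_common_table_keys([sdata0, sdata1])
+    {'b'}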
+ """ + common_keys = set(sdatas[0].tables.keys()) + + for sdata in sdatas[1:]: + common_keys.intersection_update(sdata.tables.keys()) + + return common_keys diff --git a/src/spatialdata/_core/concatenate.py b/src/spatialdata/_core/concatenate.py index 77f82c53..8312d660 100644 --- a/src/spatialdata/_core/concatenate.py +++ b/src/spatialdata/_core/concatenate.py @@ -1,15 +1,16 @@ from __future__ import annotations +from collections import defaultdict from copy import copy # Should probably go up at the top from itertools import chain -from typing import TYPE_CHECKING, Any +from typing import Any +from warnings import warn import numpy as np from anndata import AnnData -if TYPE_CHECKING: - from spatialdata._core.spatialdata import SpatialData - +from spatialdata._core._utils import _find_common_table_keys +from spatialdata._core.spatialdata import SpatialData from spatialdata.models import TableModel __all__ = [ @@ -25,6 +26,8 @@ def _concatenate_tables( ) -> AnnData: import anndata as ad + if not all(TableModel.ATTRS_KEY in table.uns for table in tables): + raise ValueError("Not all tables are annotating a spatial element") region_keys = [table.uns[TableModel.ATTRS_KEY][TableModel.REGION_KEY_KEY] for table in tables] instance_keys = [table.uns[TableModel.ATTRS_KEY][TableModel.INSTANCE_KEY] for table in tables] regions = [table.uns[TableModel.ATTRS_KEY][TableModel.REGION_KEY] for table in tables] @@ -73,6 +76,7 @@ def concatenate( sdatas: list[SpatialData], region_key: str | None = None, instance_key: str | None = None, + concatenate_tables: bool = False, **kwargs: Any, ) -> SpatialData: """ @@ -87,6 +91,8 @@ def concatenate( If all region_keys are the same, the `region_key` is used. instance_key The key to use for the instance column in the concatenated object. + concatenate_tables + Whether to merge the tables in case of having the same element name. kwargs See :func:`anndata.concat` for more details. @@ -94,8 +100,6 @@ def concatenate( ------- The concatenated :class:`spatialdata.SpatialData` object. """ - from spatialdata import SpatialData - merged_images = {**{k: v for sdata in sdatas for k, v in sdata.images.items()}} if len(merged_images) != np.sum([len(sdata.images) for sdata in sdatas]): raise KeyError("Images must have unique names across the SpatialData objects to concatenate") @@ -112,16 +116,43 @@ def concatenate( assert isinstance(sdatas, list), "sdatas must be a list" assert len(sdatas) > 0, "sdatas must be a non-empty list" - merged_table = _concatenate_tables( - [sdata.table for sdata in sdatas if sdata.table is not None], region_key, instance_key, **kwargs - ) + if not concatenate_tables: + key_counts: dict[str, int] = defaultdict(int) + for sdata in sdatas: + for k in sdata.tables: + key_counts[k] += 1 + + if any(value > 1 for value in key_counts.values()): + warn( + "Duplicate table names found. Tables will be added with integer suffix. 
Set concatenate_tables to True" + "if concatenation is wished for instead.", + UserWarning, + stacklevel=2, + ) + merged_tables = {} + count_dict: dict[str, int] = defaultdict(int) + + for sdata in sdatas: + for k, v in sdata.tables.items(): + new_key = f"{k}_{count_dict[k]}" if key_counts[k] > 1 else k + count_dict[k] += 1 + merged_tables[new_key] = v + else: + common_keys = _find_common_table_keys(sdatas) + merged_tables = {} + for sdata in sdatas: + for k, v in sdata.tables.items(): + if k in common_keys and merged_tables.get(k) is not None: + merged_tables[k] = _concatenate_tables([merged_tables[k], v], region_key, instance_key, **kwargs) + else: + merged_tables[k] = v return SpatialData( images=merged_images, labels=merged_labels, points=merged_points, shapes=merged_shapes, - table=merged_table, + tables=merged_tables, ) diff --git a/src/spatialdata/_core/data_extent.py b/src/spatialdata/_core/data_extent.py index 251a9e7b..3947fe5f 100644 --- a/src/spatialdata/_core/data_extent.py +++ b/src/spatialdata/_core/data_extent.py @@ -115,9 +115,9 @@ def get_extent( has_labels: bool = True, has_points: bool = True, has_shapes: bool = True, - # python 3.9 tests fail if we don't use Union here, see - # https://github.com/scverse/spatialdata/pull/318#issuecomment-1755714287 - elements: Union[list[str], None] = None, # noqa: UP007 + elements: Union[ # noqa: UP007 # https://github.com/scverse/spatialdata/pull/318#issuecomment-1755714287 + list[str], None + ] = None, ) -> BoundingBoxDescription: """ Get the extent (bounding box) of a SpatialData object or a SpatialElement. @@ -129,43 +129,50 @@ def get_extent( Returns ------- + The bounding box description. + min_coordinate The minimum coordinate of the bounding box. max_coordinate The maximum coordinate of the bounding box. axes - The names of the dimensions of the bounding box + The names of the dimensions of the bounding box. exact - If True, the extent is computed exactly. If False, an approximation faster to compute is given. The - approximation is guaranteed to contain all the data, see notes for details. + Whether the extent is computed exactly or not. + + - If `True`, the extent is computed exactly. + - If `False`, an approximation faster to compute is given. + + The approximation is guaranteed to contain all the data, see notes for details. has_images - If True, images are included in the computation of the extent. + If `True`, images are included in the computation of the extent. has_labels - If True, labels are included in the computation of the extent. + If `True`, labels are included in the computation of the extent. has_points - If True, points are included in the computation of the extent. + If `True`, points are included in the computation of the extent. has_shapes - If True, shapes are included in the computation of the extent. + If `True`, shapes are included in the computation of the extent. elements - If not None, only the elements with the given names are included in the computation of the extent. + If not `None`, only the elements with the given names are included in the computation of the extent. Notes ----- - The extent of a SpatialData object is the extent of the union of the extents of all its elements. The extent of a - SpatialElement is the extent of the element in the coordinate system specified by the argument `coordinate_system`. + The extent of a `SpatialData` object is the extent of the union of the extents of all its elements. 
+ The extent of a `SpatialElement` is the extent of the element in the coordinate system + specified by the argument `coordinate_system`. - If `exact` is False, first the extent of the SpatialElement before any transformation is computed. Then, the extent - is transformed to the target coordinate system. This is faster than computing the extent after the transformation, - since the transformation is applied to extent of the untransformed data, as opposed to transforming the data and - then computing the extent. + If `exact` is `False`, first the extent of the `SpatialElement` before any transformation is computed. + Then, the extent is transformed to the target coordinate system. This is faster than computing the extent + after the transformation, since the transformation is applied to extent of the untransformed data, + as opposed to transforming the data and then computing the extent. - The exact and approximate extent are the same if the transformation doesn't contain any rotation or shear, or in the - case in which the transformation is affine but all the corners of the extent of the untransformed data + The exact and approximate extent are the same if the transformation does not contain any rotation or shear, or in + the case in which the transformation is affine but all the corners of the extent of the untransformed data (bounding box corners) are part of the dataset itself. Note that this is always the case for raster data. - An extreme case is a dataset composed of the two points (0, 0) and (1, 1), rotated anticlockwise by 45 degrees. The - exact extent is the bounding box [minx, miny, maxx, maxy] = [0, 0, 0, 1.414], while the approximate extent is the - box [minx, miny, maxx, maxy] = [-0.707, 0, 0.707, 1.414]. + An extreme case is a dataset composed of the two points `(0, 0)` and `(1, 1)`, rotated anticlockwise by 45 degrees. + The exact extent is the bounding box `[minx, miny, maxx, maxy] = [0, 0, 0, 1.414]`, while the approximate extent is + the box `[minx, miny, maxx, maxy] = [-0.707, 0, 0.707, 1.414]`. """ raise ValueError("The object type is not supported.") @@ -184,7 +191,9 @@ def _( elements: Union[list[str], None] = None, # noqa: UP007 ) -> BoundingBoxDescription: """ - Get the extent (bounding box) of a SpatialData object: the extent of the union of the extents of all its elements. + Get the extent (bounding box) of a SpatialData object. + + The resulting extent is the union of the extents of all its elements. Parameters ---------- @@ -259,7 +268,14 @@ def _get_extent_of_shapes(e: GeoDataFrame) -> BoundingBoxDescription: @get_extent.register def _(e: GeoDataFrame, coordinate_system: str = "global", exact: bool = True) -> BoundingBoxDescription: """ - Compute the extent (bounding box) of a set of shapes. + Get the extent (bounding box) of a SpatialData object. + + The resulting extent is the union of the extents of all its elements. + + Parameters + ---------- + e + The SpatialData object. 
Returns ------- diff --git a/src/spatialdata/_core/operations/aggregate.py b/src/spatialdata/_core/operations/aggregate.py index 9881dc7b..29a78b3f 100644 --- a/src/spatialdata/_core/operations/aggregate.py +++ b/src/spatialdata/_core/operations/aggregate.py @@ -1,7 +1,7 @@ from __future__ import annotations import warnings -from typing import TYPE_CHECKING, Any +from typing import Any import anndata as ad import dask as da @@ -20,6 +20,7 @@ from spatialdata._core.operations.transform import transform from spatialdata._core.query._utils import circles_to_polygons from spatialdata._core.query.relational_query import get_values +from spatialdata._core.spatialdata import SpatialData from spatialdata._types import ArrayLike from spatialdata.models import ( Image2DModel, @@ -32,9 +33,6 @@ ) from spatialdata.transformations import BaseTransformation, Identity, get_transformation -if TYPE_CHECKING: - from spatialdata import SpatialData - __all__ = ["aggregate"] @@ -236,7 +234,6 @@ def _create_sdata_from_table_and_shapes( instance_key: str, deepcopy: bool, ) -> SpatialData: - from spatialdata import SpatialData from spatialdata._utils import _deepcopy_geodataframe table.obs[instance_key] = table.obs_names.copy() @@ -250,7 +247,7 @@ def _create_sdata_from_table_and_shapes( if deepcopy: shapes = _deepcopy_geodataframe(shapes) - return SpatialData.from_elements_dict({shapes_name: shapes, "": table}) + return SpatialData.from_elements_dict({shapes_name: shapes, "table": table}) def _aggregate_image_by_labels( diff --git a/src/spatialdata/_core/operations/rasterize.py b/src/spatialdata/_core/operations/rasterize.py index d1a30c46..a850542f 100644 --- a/src/spatialdata/_core/operations/rasterize.py +++ b/src/spatialdata/_core/operations/rasterize.py @@ -207,8 +207,6 @@ def _( target_height: Optional[float] = None, target_depth: Optional[float] = None, ) -> SpatialData: - from spatialdata import SpatialData - min_coordinate = _parse_list_into_array(min_coordinate) max_coordinate = _parse_list_into_array(max_coordinate) @@ -232,7 +230,7 @@ def _( ) new_name = f"{name}_rasterized_{element_type}" new_images[new_name] = rasterized - return SpatialData(images=new_images, table=sdata.table) + return SpatialData(images=new_images, tables=sdata.tables) # get xdata diff --git a/src/spatialdata/_core/query/_utils.py b/src/spatialdata/_core/query/_utils.py index 25e8caa9..15fbe5c9 100644 --- a/src/spatialdata/_core/query/_utils.py +++ b/src/spatialdata/_core/query/_utils.py @@ -1,8 +1,13 @@ from __future__ import annotations +from typing import Any + import geopandas as gpd +from anndata import AnnData from xarray import DataArray +from spatialdata._core._elements import Tables +from spatialdata._core.spatialdata import SpatialData from spatialdata._types import ArrayLike from spatialdata._utils import Number, _parse_list_into_array @@ -78,3 +83,39 @@ def get_bounding_box_corners( ], coords={"corner": range(8), "axis": list(axes)}, ) + + +def _get_filtered_or_unfiltered_tables( + filter_table: bool, elements: dict[str, Any], sdata: SpatialData +) -> dict[str, AnnData] | Tables: + """ + Get the tables in a SpatialData object. + + The tables of the SpatialData object can either be filtered to only include the tables that annotate an element in + elements or all tables are returned. + + Parameters + ---------- + filter_table + Specifies whether to filter the tables to only include tables that annotate elements in the retrieved + SpatialData object of the query. 
+ elements + A dictionary containing the elements to use for filtering the tables. + sdata + The SpatialData object that contains the tables to filter. + + Returns + ------- + A dictionary containing the filtered or unfiltered tables based on the value of the 'filter_table' parameter. + + """ + if filter_table: + from spatialdata._core.query.relational_query import _filter_table_by_elements + + return { + name: filtered_table + for name, table in sdata.tables.items() + if (filtered_table := _filter_table_by_elements(table, elements)) and len(filtered_table) != 0 + } + + return sdata.tables diff --git a/src/spatialdata/_core/query/relational_query.py b/src/spatialdata/_core/query/relational_query.py index 14d6e88c..9beb4f16 100644 --- a/src/spatialdata/_core/query/relational_query.py +++ b/src/spatialdata/_core/query/relational_query.py @@ -1,7 +1,7 @@ from __future__ import annotations from dataclasses import dataclass -from typing import TYPE_CHECKING, Any +from typing import Any import dask.array as da import numpy as np @@ -11,6 +11,7 @@ from multiscale_spatial_image import MultiscaleSpatialImage from spatial_image import SpatialImage +from spatialdata._core.spatialdata import SpatialData from spatialdata._utils import _inplace_fix_subset_categorical_obs from spatialdata.models import ( Labels2DModel, @@ -22,11 +23,8 @@ get_model, ) -if TYPE_CHECKING: - from spatialdata import SpatialData - -def _filter_table_by_coordinate_system(table: AnnData | None, coordinate_system: str | list[str]) -> AnnData | None: +def _filter_table_by_element_names(table: AnnData | None, element_names: str | list[str]) -> AnnData | None: """ Filter an AnnData table to keep only the rows that are in the coordinate system. @@ -34,19 +32,19 @@ def _filter_table_by_coordinate_system(table: AnnData | None, coordinate_system: ---------- table The table to filter; if None, returns None - coordinate_system - The coordinate system to keep + element_names + The element_names to keep in the tables obs.region column Returns ------- The filtered table, or None if the input table was None """ - if table is None: + if table is None or not table.uns.get(TableModel.ATTRS_KEY): return None table_mapping_metadata = table.uns[TableModel.ATTRS_KEY] region_key = table_mapping_metadata[TableModel.REGION_KEY_KEY] table.obs = pd.DataFrame(table.obs) - table = table[table.obs[region_key].isin(coordinate_system)].copy() + table = table[table.obs[region_key].isin(element_names)].copy() table.uns[TableModel.ATTRS_KEY][TableModel.REGION_KEY] = table.obs[region_key].unique().tolist() return table diff --git a/src/spatialdata/_core/query/spatial_query.py b/src/spatialdata/_core/query/spatial_query.py index 2cbd02f3..ef14c465 100644 --- a/src/spatialdata/_core/query/spatial_query.py +++ b/src/spatialdata/_core/query/spatial_query.py @@ -17,7 +17,7 @@ from tqdm import tqdm from xarray import DataArray -from spatialdata._core.query._utils import get_bounding_box_corners +from spatialdata._core.query._utils import _get_filtered_or_unfiltered_tables, get_bounding_box_corners from spatialdata._core.spatialdata import SpatialData from spatialdata._logging import logger from spatialdata._types import ArrayLike @@ -61,7 +61,9 @@ def _get_bounding_box_corners_in_intrinsic_coordinates( target_coordinate_system The coordinate system the bounding box is defined in. - Returns ------- All the corners of the bounding box in the intrinsic coordinate system of the element. 
The shape + Returns + ------- + All the corners of the bounding box in the intrinsic coordinate system of the element. The shape is (2, 4) when axes has 2 spatial dimensions, and (2, 8) when axes has 3 spatial dimensions. The axes of the intrinsic coordinate system. @@ -73,6 +75,12 @@ def _get_bounding_box_corners_in_intrinsic_coordinates( # get the transformation from the element's intrinsic coordinate system # to the query coordinate space transform_to_query_space = get_transformation(element, to_coordinate_system=target_coordinate_system) + m_without_c, input_axes_without_c, output_axes_without_c = _get_axes_of_tranformation( + element, target_coordinate_system + ) + axes, min_coordinate, max_coordinate = _adjust_bounding_box_to_real_axes( + axes, min_coordinate, max_coordinate, output_axes_without_c + ) # get the coordinates of the bounding box corners bounding_box_corners = get_bounding_box_corners( @@ -155,7 +163,7 @@ def _bounding_box_mask_points( min_coordinate: list[Number] | ArrayLike, max_coordinate: list[Number] | ArrayLike, ) -> da.Array: - """Compute a mask that is true for the points inside of an axis-aligned bounding box.. + """Compute a mask that is true for the points inside an axis-aligned bounding box. Parameters ---------- @@ -164,23 +172,26 @@ def _bounding_box_mask_points( axes The axes that min_coordinate and max_coordinate refer to. min_coordinate - The upper left hand corner of the bounding box (i.e., minimum coordinates - along all dimensions). + The upper left hand corner of the bounding box (i.e., minimum coordinates along all dimensions). max_coordinate - The lower right hand corner of the bounding box (i.e., the maximum coordinates - along all dimensions + The lower right hand corner of the bounding box (i.e., the maximum coordinates along all dimensions). Returns ------- - The mask for the points inside of the bounding box. + The mask for the points inside the bounding box. 
""" + element_axes = get_axes_names(points) min_coordinate = _parse_list_into_array(min_coordinate) max_coordinate = _parse_list_into_array(max_coordinate) in_bounding_box_masks = [] for axis_index, axis_name in enumerate(axes): + if axis_name not in element_axes: + continue min_value = min_coordinate[axis_index] in_bounding_box_masks.append(points[axis_name].gt(min_value).to_dask_array(lengths=True)) for axis_index, axis_name in enumerate(axes): + if axis_name not in element_axes: + continue max_value = max_coordinate[axis_index] in_bounding_box_masks.append(points[axis_name].lt(max_value).to_dask_array(lengths=True)) in_bounding_box_masks = da.stack(in_bounding_box_masks, axis=-1) @@ -248,9 +259,6 @@ def _( target_coordinate_system: str, filter_table: bool = True, ) -> SpatialData: - from spatialdata import SpatialData - from spatialdata._core.query.relational_query import _filter_table_by_elements - min_coordinate = _parse_list_into_array(min_coordinate) max_coordinate = _parse_list_into_array(max_coordinate) new_elements = {} @@ -266,8 +274,80 @@ def _( ) new_elements[element_type] = queried_elements - table = _filter_table_by_elements(sdata.table, new_elements) if filter_table else sdata.table - return SpatialData(**new_elements, table=table) + tables = _get_filtered_or_unfiltered_tables(filter_table, new_elements, sdata) + + return SpatialData(**new_elements, tables=tables) + + +def _get_axes_of_tranformation( + element: SpatialElement, target_coordinate_system: str +) -> tuple[ArrayLike, tuple[str, ...], tuple[str, ...]]: + """ + Get the transformation matrix and the transformation's axes (ignoring `c`). + + The transformation is the one from the element's intrinsic coordinate system to the query coordinate space. + Note that the axes which specify the query shape are not necessarily the same as the axes that are output of the + transformation + + Parameters + ---------- + element + SpatialData element to be transformed. + target_coordinate_system + The target coordinate system for the transformation. + + Returns + ------- + m_without_c + The transformation from the element's intrinsic coordinate system to the query coordinate space, without the + "c" axis. + input_axes_without_c + The axes of the element's intrinsic coordinate system, without the "c" axis. + output_axes_without_c + The axes of the query coordinate system, without the "c" axis. + + """ + from spatialdata.transformations import get_transformation + + transform_to_query_space = get_transformation(element, to_coordinate_system=target_coordinate_system) + assert isinstance(transform_to_query_space, BaseTransformation) + m = _get_affine_for_element(element, transform_to_query_space) + input_axes_without_c = tuple([ax for ax in m.input_axes if ax != "c"]) + output_axes_without_c = tuple([ax for ax in m.output_axes if ax != "c"]) + m_without_c = m.to_affine_matrix(input_axes=input_axes_without_c, output_axes=output_axes_without_c) + return m_without_c, input_axes_without_c, output_axes_without_c + + +def _adjust_bounding_box_to_real_axes( + axes: tuple[str, ...], + min_coordinate: ArrayLike, + max_coordinate: ArrayLike, + output_axes_without_c: tuple[str, ...], +) -> tuple[tuple[str, ...], ArrayLike, ArrayLike]: + """ + Adjust the bounding box to the real axes of the transformation. + + The bounding box is defined by the user and it's axes may not coincide with the axes of the transformation. 
+ """ + if set(axes) != set(output_axes_without_c): + axes_only_in_bb = set(axes) - set(output_axes_without_c) + axes_only_in_output = set(output_axes_without_c) - set(axes) + + # let's remove from the bounding box whose axes that are not in the output axes (e.g. querying 2D points with a + # 3D bounding box) + indices_to_remove_from_bb = [axes.index(ax) for ax in axes_only_in_bb] + axes = tuple([ax for ax in axes if ax not in axes_only_in_bb]) + min_coordinate = np.delete(min_coordinate, indices_to_remove_from_bb) + max_coordinate = np.delete(max_coordinate, indices_to_remove_from_bb) + + # if there are axes in the output axes that are not in the bounding box, we need to add them to the bounding box + # with a range that includes everything (e.g. querying 3D points with a 2D bounding box) + for ax in axes_only_in_output: + axes = axes + (ax,) + M = np.finfo(np.float32).max - 1 + min_coordinate = np.append(min_coordinate, -M) + max_coordinate = np.append(max_coordinate, M) + return axes, min_coordinate, max_coordinate @bounding_box_query.register(SpatialImage) @@ -283,7 +363,6 @@ def _( Notes ----- - _____ See https://github.com/scverse/spatialdata/pull/151 for a detailed overview of the logic of this code, and for the cases the comments refer to. """ @@ -300,15 +379,10 @@ def _( max_coordinate=max_coordinate, ) - # get the transformation from the element's intrinsic coordinate system to the query coordinate space - transform_to_query_space = get_transformation(image, to_coordinate_system=target_coordinate_system) - assert isinstance(transform_to_query_space, BaseTransformation) - m = _get_affine_for_element(image, transform_to_query_space) - input_axes_without_c = tuple([ax for ax in m.input_axes if ax != "c"]) - output_axes_without_c = tuple([ax for ax in m.output_axes if ax != "c"]) - m_without_c = m.to_affine_matrix(input_axes=input_axes_without_c, output_axes=output_axes_without_c) + m_without_c, input_axes_without_c, output_axes_without_c = _get_axes_of_tranformation( + image, target_coordinate_system + ) m_without_c_linear = m_without_c[:-1, :-1] - transform_dimension = np.linalg.matrix_rank(m_without_c_linear) transform_coordinate_length = len(output_axes_without_c) data_dim = len(input_axes_without_c) @@ -336,24 +410,13 @@ def _( error_message = ( f"This case is not supported (data with dimension" f"{data_dim} but transformation with rank {transform_dimension}." - f"Please open a GitHub issue if you want to discuss a case." + f"Please open a GitHub issue if you want to discuss a use case." ) raise ValueError(error_message) - if set(axes) != set(output_axes_without_c): - if set(axes).issubset(output_axes_without_c): - logger.warning( - f"The element has axes {output_axes_without_c}, but the query has axes {axes}. Excluding the element " - f"from the query result. In the future we can add support for this case. If you are interested, " - f"please open a GitHub issue." - ) - return None - error_messeage = ( - f"Invalid case. 
The bounding box axes are {axes}," - f"the spatial axes in {target_coordinate_system} are" - f"{output_axes_without_c}" - ) - raise ValueError(error_messeage) + axes, min_coordinate, max_coordinate = _adjust_bounding_box_to_real_axes( + axes, min_coordinate, max_coordinate, output_axes_without_c + ) spatial_transform = Affine(m_without_c, input_axes=input_axes_without_c, output_axes=output_axes_without_c) spatial_transform_bb_axes = Affine( @@ -370,7 +433,7 @@ def _( ) else: assert case == 2 - # TODO: we need to intersect the plane in the extrinsic coordiante system with the 3D bounding box. The + # TODO: we need to intersect the plane in the extrinsic coordinate system with the 3D bounding box. The # vertices of this polygons needs to be transformed to the intrinsic coordinate system raise NotImplementedError( "Case 2 (the transformation is embedding 2D data in the 3D space, is not " @@ -570,7 +633,6 @@ def _polygon_query( labels: bool, ) -> SpatialData: from spatialdata._core.query._utils import circles_to_polygons - from spatialdata._core.query.relational_query import _filter_table_by_elements from spatialdata.models import ( PointsModel, ShapesModel, @@ -640,11 +702,10 @@ def _polygon_query( "issue and we will prioritize the implementation." ) - if filter_table and sdata.table is not None: - table = _filter_table_by_elements(sdata.table, {"shapes": new_shapes, "points": new_points}) - else: - table = sdata.table - return SpatialData(shapes=new_shapes, points=new_points, images=new_images, table=table) + elements = {"shapes": new_shapes, "points": new_points} + tables = _get_filtered_or_unfiltered_tables(filter_table, elements, sdata) + + return SpatialData(shapes=new_shapes, points=new_points, images=new_images, tables=tables) # this function is currently excluded from the API documentation. TODO: add it after the refactoring @@ -669,6 +730,9 @@ def polygon_query( The polygon (or list of polygons) to query by target_coordinate_system The coordinate system of the polygon + filter_table + Specifies whether to filter the tables to only include tables that annotate elements in the retrieved + SpatialData object of the query. shapes Whether to filter shapes points @@ -685,8 +749,6 @@ def polygon_query( making this function more general and ergonomic. 
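+
+    Examples
+    --------
+    A minimal sketch of typical usage (the coordinate system name and the queried ``sdata`` object are
+    illustrative, not from the source):
+
+    >>> from shapely.geometry import Polygon
+    >>> polygon = Polygon([(0.0, 0.0), (0.0, 10.0), (10.0, 10.0), (10.0, 0.0)])
+    >>> queried = polygon_query(sdata, polygon, target_coordinate_system="global")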
""" - from spatialdata._core.query.relational_query import _filter_table_by_elements - # adjust coordinate transformation (this implementation can be made faster) sdata = sdata.transform_to_coordinate_system(target_coordinate_system) @@ -749,6 +811,7 @@ def polygon_query( vv = vv[~vv.index.duplicated(keep="first")] geodataframes[k] = vv - table = _filter_table_by_elements(sdata.table, {"shapes": geodataframes}) if filter_table else sdata.table + elements = {"shapes": geodataframes} + tables = _get_filtered_or_unfiltered_tables(filter_table, elements, sdata) - return SpatialData(shapes=geodataframes, table=table) + return SpatialData(shapes=geodataframes, tables=tables) diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py index d0506c48..3cdf91d2 100644 --- a/src/spatialdata/_core/spatialdata.py +++ b/src/spatialdata/_core/spatialdata.py @@ -2,11 +2,13 @@ import hashlib import os +import warnings from collections.abc import Generator +from itertools import chain from pathlib import Path -from types import MappingProxyType -from typing import TYPE_CHECKING, Any, Union +from typing import TYPE_CHECKING, Any, Literal +import pandas as pd import zarr from anndata import AnnData from dask.dataframe import read_parquet @@ -18,17 +20,10 @@ from ome_zarr.types import JSONDict from spatial_image import SpatialImage -from spatialdata._io import ( - write_image, - write_labels, - write_points, - write_shapes, - write_table, -) -from spatialdata._io._utils import get_backing_files +from spatialdata._core._elements import Images, Labels, Points, Shapes, Tables from spatialdata._logging import logger -from spatialdata._types import ArrayLike -from spatialdata._utils import _natural_keys +from spatialdata._types import ArrayLike, Raster_T +from spatialdata._utils import _error_message_add_element, deprecation_alias from spatialdata.models import ( Image2DModel, Image3DModel, @@ -36,11 +31,12 @@ Labels3DModel, PointsModel, ShapesModel, - SpatialElement, TableModel, - get_axes_names, + check_target_region_column_symmetry, get_model, + get_table_keys, ) +from spatialdata.models._utils import SpatialElement, get_axes_names if TYPE_CHECKING: from spatialdata._core.query.spatial_query import BaseSpatialRequest @@ -54,9 +50,6 @@ Point_s = PointsModel() Table_s = TableModel() -# create a shorthand for raster image types -Raster_T = Union[SpatialImage, MultiscaleSpatialImage] - class SpatialData: """ @@ -116,53 +109,91 @@ class SpatialData: """ - _images: dict[str, Raster_T] = MappingProxyType({}) # type: ignore[assignment] - _labels: dict[str, Raster_T] = MappingProxyType({}) # type: ignore[assignment] - _points: dict[str, DaskDataFrame] = MappingProxyType({}) # type: ignore[assignment] - _shapes: dict[str, GeoDataFrame] = MappingProxyType({}) # type: ignore[assignment] - _table: AnnData | None = None - path: str | None = None - + @deprecation_alias(table="tables") def __init__( self, - images: dict[str, Raster_T] = MappingProxyType({}), # type: ignore[assignment] - labels: dict[str, Raster_T] = MappingProxyType({}), # type: ignore[assignment] - points: dict[str, DaskDataFrame] = MappingProxyType({}), # type: ignore[assignment] - shapes: dict[str, GeoDataFrame] = MappingProxyType({}), # type: ignore[assignment] - table: AnnData | None = None, + images: dict[str, Raster_T] | None = None, + labels: dict[str, Raster_T] | None = None, + points: dict[str, DaskDataFrame] | None = None, + shapes: dict[str, GeoDataFrame] | None = None, + tables: dict[str, AnnData] | Tables | None = 
None,
     ) -> None:
-        self.path = None
+        self._path: Path | None = None
+
+        self._shared_keys: set[str | None] = set()
+        self._images: Images = Images(shared_keys=self._shared_keys)
+        self._labels: Labels = Labels(shared_keys=self._shared_keys)
+        self._points: Points = Points(shared_keys=self._shared_keys)
+        self._shapes: Shapes = Shapes(shared_keys=self._shared_keys)
+        self._tables: Tables = Tables(shared_keys=self._shared_keys)
+
+        # Workaround to allow for backward compatibility
+        if isinstance(tables, AnnData):
+            tables = {"table": tables}
         self._validate_unique_element_names(
-            list(images.keys()) + list(labels.keys()) + list(points.keys()) + list(shapes.keys())
+            list(chain.from_iterable([e.keys() for e in [images, labels, points, shapes] if e is not None]))
         )
 
         if images is not None:
-            self._images: dict[str, SpatialImage | MultiscaleSpatialImage] = {}
             for k, v in images.items():
-                self._add_image_in_memory(name=k, image=v)
+                self.images[k] = v
 
         if labels is not None:
-            self._labels: dict[str, SpatialImage | MultiscaleSpatialImage] = {}
             for k, v in labels.items():
-                self._add_labels_in_memory(name=k, labels=v)
+                self.labels[k] = v
 
         if shapes is not None:
-            self._shapes: dict[str, GeoDataFrame] = {}
             for k, v in shapes.items():
-                self._add_shapes_in_memory(name=k, shapes=v)
+                self.shapes[k] = v
 
         if points is not None:
-            self._points: dict[str, DaskDataFrame] = {}
             for k, v in points.items():
-                self._add_points_in_memory(name=k, points=v)
+                self.points[k] = v
 
-        if table is not None:
-            Table_s.validate(table)
-            self._table = table
+        if tables is not None:
+            for k, v in tables.items():
+                self.validate_table_in_spatialdata(v)
+                self.tables[k] = v
 
         self._query = QueryManager(self)
 
+    def validate_table_in_spatialdata(self, data: AnnData) -> None:
+        """
+        Validate the presence of the annotation target of a SpatialData table in the SpatialData object.
+
+        This method validates a table in the SpatialData object to ensure that, if annotation metadata is present,
+        the annotation target (a SpatialElement) is also present in the SpatialData object. Otherwise, a warning is
+        emitted.
+
+        Parameters
+        ----------
+        data
+            The table potentially annotating a SpatialElement.
+
+        Warns
+        -----
+        UserWarning
+            If the table is annotating elements not present in the SpatialData object.
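+
+        Examples
+        --------
+        A minimal sketch (the element and column names are illustrative): a table annotating a region that is not
+        present in the SpatialData object triggers the warning when the object is constructed.
+
+        >>> import numpy as np
+        >>> import pandas as pd
+        >>> from anndata import AnnData
+        >>> from spatialdata.models import TableModel
+        >>> adata = AnnData(np.zeros((3, 2)))
+        >>> adata.obs["region"] = pd.Categorical(["missing_element"] * 3)
+        >>> adata.obs["instance_id"] = [0, 1, 2]
+        >>> table = TableModel.parse(
+        ...     adata, region="missing_element", region_key="region", instance_key="instance_id"
+        ... )
+        >>> sdata = SpatialData(tables={"table": table})  # warns: the annotated element is missing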
+        """
+        TableModel().validate(data)
+        element_names = [
+            element_name for element_type, element_name, _ in self._gen_elements() if element_type != "tables"
+        ]
+        if TableModel.ATTRS_KEY in data.uns:
+            attrs = data.uns[TableModel.ATTRS_KEY]
+            regions = (
+                attrs[TableModel.REGION_KEY]
+                if isinstance(attrs[TableModel.REGION_KEY], list)
+                else [attrs[TableModel.REGION_KEY]]
+            )
+            # TODO: decide whether this should raise an error instead of warning
+            if not all(element_name in element_names for element_name in regions):
+                warnings.warn(
+                    "The table is annotating one or more elements not present in the SpatialData object",
+                    UserWarning,
+                    stacklevel=2,
+                )
+
     @staticmethod
     def from_elements_dict(elements_dict: dict[str, SpatialElement | AnnData]) -> SpatialData:
         """
@@ -183,7 +214,7 @@ def from_elements_dict(elements_dict: dict[str, SpatialElement | AnnData]) -> Sp
             "labels": {},
             "points": {},
             "shapes": {},
-            "table": None,
+            "tables": {},
         }
         for k, e in elements_dict.items():
             schema = get_model(e)
@@ -200,13 +231,200 @@ def from_elements_dict(elements_dict: dict[str, SpatialElement | AnnData]) -> Sp
                 assert isinstance(d["shapes"], dict)
                 d["shapes"][k] = e
             elif schema == TableModel:
-                if d["table"] is not None:
-                    raise ValueError("Only one table can be present in the dataset.")
-                d["table"] = e
+                assert isinstance(d["tables"], dict)
+                d["tables"][k] = e
             else:
                 raise ValueError(f"Unknown schema {schema}")
         return SpatialData(**d)  # type: ignore[arg-type]
 
+    @staticmethod
+    def get_annotated_regions(table: AnnData) -> str | list[str]:
+        """
+        Get the regions annotated by a table.
+
+        Parameters
+        ----------
+        table
+            The AnnData table for which to retrieve annotated regions.
+
+        Returns
+        -------
+        The annotated regions.
+        """
+        regions, _, _ = get_table_keys(table)
+        return regions
+
+    @staticmethod
+    def get_region_key_column(table: AnnData) -> pd.Series:
+        """Get the column of table.obs containing, for each row, the region annotated by that row.
+
+        Parameters
+        ----------
+        table
+            The AnnData table.
+
+        Returns
+        -------
+        The region key column.
+
+        Raises
+        ------
+        KeyError
+            If the region key column is not found in table.obs.
+        """
+        _, region_key, _ = get_table_keys(table)
+        if region_key in table.obs:
+            return table.obs[region_key]
+        raise KeyError(f"{region_key} is set as region key column. However, the column is not found in table.obs.")
+
+    @staticmethod
+    def get_instance_key_column(table: AnnData) -> pd.Series:
+        """
+        Return the column of table.obs containing, for each row, the instance id of that row.
+
+        Parameters
+        ----------
+        table
+            The AnnData table.
+
+        Returns
+        -------
+        The instance key column.
+
+        Raises
+        ------
+        KeyError
+            If the instance key column is not found in table.obs.
+
+        """
+        _, _, instance_key = get_table_keys(table)
+        if instance_key in table.obs:
+            return table.obs[instance_key]
+        raise KeyError(f"{instance_key} is set as instance key column. However, the column is not found in table.obs.")
+
+    @staticmethod
+    def _set_table_annotation_target(
+        table: AnnData,
+        region: str | pd.Series,
+        region_key: str,
+        instance_key: str,
+    ) -> None:
+        """
+        Set the SpatialElement annotation target of an AnnData table.
+
+        This method sets the target annotation element of a table based on the specified parameters. It creates the
+        `attrs` dictionary for `table.uns` and updates the annotation metadata of the table only after validating
+        that the regions are present in the region_key column of table.obs.
+
+        Parameters
+        ----------
+        table
+            The AnnData object containing the data table.
+ region + The name of the target element for the table annotation. + region_key + The key for the region annotation column in `table.obs`. + instance_key + The key for the instance annotation column in `table.obs`. + + Raises + ------ + ValueError + If `region_key` is not present in the `table.obs` columns. + ValueError + If `instance_key` is not present in the `table.obs` columns. + """ + TableModel()._validate_set_region_key(table, region_key) + TableModel()._validate_set_instance_key(table, instance_key) + attrs = { + TableModel.REGION_KEY: region, + TableModel.REGION_KEY_KEY: region_key, + TableModel.INSTANCE_KEY: instance_key, + } + check_target_region_column_symmetry(table, region_key, region) + table.uns[TableModel.ATTRS_KEY] = attrs + + @staticmethod + def _change_table_annotation_target( + table: AnnData, + region: str | pd.Series, + region_key: None | str = None, + instance_key: None | str = None, + ) -> None: + """Change the annotation target of a table currently having annotation metadata already. + + Parameters + ---------- + table + The table already annotating a SpatialElement. + region + The name of the target SpatialElement for which the table annotation will be changed. + region_key + The name of the region key column in the table. If not provided, it will be extracted from the table's uns + attribute. If present here but also given as argument, the value in the table's uns attribute will be + overwritten. + instance_key + The name of the instance key column in the table. If not provided, it will be extracted from the table's uns + attribute. If present here but also given as argument, the value in the table's uns attribute will be + overwritten. + + Raises + ------ + ValueError + If no region_key is provided, and it is not present in both table.uns['spatialdata_attrs'] and table.obs. + ValueError + If provided region_key is not present in table.obs. + """ + attrs = table.uns[TableModel.ATTRS_KEY] + table_region_key = region_key if region_key else attrs.get(TableModel.REGION_KEY_KEY) + + TableModel()._validate_set_region_key(table, region_key) + TableModel()._validate_set_instance_key(table, instance_key) + check_target_region_column_symmetry(table, table_region_key, region) + attrs[TableModel.REGION_KEY] = region + + def set_table_annotates_spatialelement( + self, + table_name: str, + region: str | pd.Series, + region_key: None | str = None, + instance_key: None | str = None, + ) -> None: + """ + Set the SpatialElement annotation target of a given AnnData table. + + Parameters + ---------- + table_name + The name of the table to set the annotation target for. + region + The name of the target element for the annotation. This can either be a string or a pandas Series object. + region_key + The region key for the annotation. If not specified, defaults to None which means the currently set region + key is reused. + instance_key + The instance key for the annotation. If not specified, defaults to None which means the currently set + instance key is reused. + + Raises + ------ + ValueError + If the annotation SpatialElement target is not present in the SpatialData object. + TypeError + If no current annotation metadata is found and both region_key and instance_key are not specified. 
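+
+        Examples
+        --------
+        A sketch (assuming ``sdata`` contains a shapes element named ``"circles"`` and a table named ``"table"``
+        whose region and instance keys are already set, so both keys are reused):
+
+        >>> sdata.set_table_annotates_spatialelement("table", region="circles")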
+ """ + table = self.tables[table_name] + element_names = {element[1] for element in self._gen_elements()} + if region not in element_names: + raise ValueError(f"Annotation target '{region}' not present as SpatialElement in " f"SpatialData object.") + + if table.uns.get(TableModel.ATTRS_KEY): + self._change_table_annotation_target(table, region, region_key, instance_key) + elif isinstance(region_key, str) and isinstance(instance_key, str): + self._set_table_annotation_target(table, region, region_key, instance_key) + else: + raise TypeError("No current annotation metadata found. Please specify both region_key and instance_key.") + @property def query(self) -> QueryManager: return self._query @@ -231,8 +449,8 @@ def aggregate( Notes ----- - This function calls :func:`spatialdata.aggregate` with the convenience that values and by can be string - without having to specify the values_sdata and by_sdata, which in that case will be replaced by `self`. + This function calls :func:`spatialdata.aggregate` with the convenience that `values` and `by` can be string + without having to specify the `values_sdata` and `by_sdata`, which in that case will be replaced by `self`. Please see :func:`spatialdata.aggregate` for the complete docstring. @@ -263,114 +481,18 @@ def aggregate( def _validate_unique_element_names(element_names: list[str]) -> None: if len(element_names) != len(set(element_names)): duplicates = {x for x in element_names if element_names.count(x) > 1} - raise ValueError( + raise KeyError( f"Element names must be unique. The following element names are used multiple times: {duplicates}" ) - def _add_image_in_memory( - self, name: str, image: SpatialImage | MultiscaleSpatialImage, overwrite: bool = False - ) -> None: - """Add an image element to the SpatialData object. - - Parameters - ---------- - name - name of the image - image - the image element to be added - overwrite - whether to overwrite the image if the name already exists. - """ - self._validate_unique_element_names( - list(self.labels.keys()) + list(self.points.keys()) + list(self.shapes.keys()) + [name] - ) - if name in self._images and not overwrite: - raise KeyError(f"Image {name} already exists in the dataset.") - ndim = len(get_axes_names(image)) - if ndim == 3: - Image2D_s.validate(image) - self._images[name] = image - elif ndim == 4: - Image3D_s.validate(image) - self._images[name] = image - else: - raise ValueError("Only czyx and cyx images supported") - - def _add_labels_in_memory( - self, name: str, labels: SpatialImage | MultiscaleSpatialImage, overwrite: bool = False - ) -> None: - """ - Add a labels element to the SpatialData object. - - Parameters - ---------- - name - name of the labels - labels - the labels element to be added - overwrite - whether to overwrite the labels if the name already exists. - """ - self._validate_unique_element_names( - list(self.images.keys()) + list(self.points.keys()) + list(self.shapes.keys()) + [name] - ) - if name in self._labels and not overwrite: - raise KeyError(f"Labels {name} already exists in the dataset.") - ndim = len(get_axes_names(labels)) - if ndim == 2: - Label2D_s.validate(labels) - self._labels[name] = labels - elif ndim == 3: - Label3D_s.validate(labels) - self._labels[name] = labels - else: - raise ValueError(f"Only yx and zyx labels supported, got {ndim} dimensions") - - def _add_shapes_in_memory(self, name: str, shapes: GeoDataFrame, overwrite: bool = False) -> None: - """ - Add a shapes element to the SpatialData object. 
- - Parameters - ---------- - name - name of the shapes - shapes - the shapes element to be added - overwrite - whether to overwrite the shapes if the name already exists. - """ - self._validate_unique_element_names( - list(self.images.keys()) + list(self.points.keys()) + list(self.labels.keys()) + [name] - ) - if name in self._shapes and not overwrite: - raise KeyError(f"Shapes {name} already exists in the dataset.") - Shape_s.validate(shapes) - self._shapes[name] = shapes - - def _add_points_in_memory(self, name: str, points: DaskDataFrame, overwrite: bool = False) -> None: - """ - Add a points element to the SpatialData object. - - Parameters - ---------- - name - name of the points element - points - the points to be added - overwrite - whether to overwrite the points if the name already exists. - """ - self._validate_unique_element_names( - list(self.images.keys()) + list(self.labels.keys()) + list(self.shapes.keys()) + [name] - ) - if name in self._points and not overwrite: - raise KeyError(f"Points {name} already exists in the dataset.") - Point_s.validate(points) - self._points[name] = points - def is_backed(self) -> bool: """Check if the data is backed by a Zarr storage or if it is in-memory.""" - return self.path is not None + return self._path is not None + + @property + def path(self) -> Path | None: + """Path to the Zarr storage.""" + return self._path # TODO: from a commennt from Giovanni: consolite somewhere in # a future PR (luca: also _init_add_element could be cleaned) @@ -387,7 +509,7 @@ def _get_group_for_element(self, name: str, element_type: str) -> zarr.Group: Returns ------- - either the existing Zarr sub-group or a new one + either the existing Zarr sub-group or a new one. """ store = parse_url(self.path, mode="r+").store root = zarr.group(store=store) @@ -396,13 +518,6 @@ def _get_group_for_element(self, name: str, element_type: str) -> zarr.Group: return element_type_group.require_group(name) def _init_add_element(self, name: str, element_type: str, overwrite: bool) -> zarr.Group: - if self.path is None: - # in the future we can relax this, but this ensures that we don't have objects that are partially backed - # and partially in memory - raise RuntimeError( - "The data is not backed by a Zarr storage. In order to add new elements after " - "initializing a SpatialData object you need to call SpatialData.write() first" - ) store = parse_url(self.path, mode="r+").store root = zarr.group(store=store) assert element_type in ["images", "labels", "points", "shapes"] @@ -430,24 +545,19 @@ def _init_add_element(self, name: str, element_type: str, overwrite: bool) -> za return elem_group return root - def _locate_spatial_element(self, element: SpatialElement) -> tuple[str, str]: + def locate_element(self, element: SpatialElement) -> list[str] | None: """ - Find the SpatialElement within the SpatialData object. + Locate a SpatialElement within the SpatialData object and, if found, returns its Zarr path relative to the root. Parameters ---------- element The queried SpatialElement - Returns ------- - name and type of the element - - Raises - ------ - ValueError - the element is not found or found multiple times in the SpatialData object + A list of Zarr paths of the element relative to the root (multiple copies of the same element are allowed), or + None if the element is not found. 
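+
+        Examples
+        --------
+        A sketch (assuming ``sdata`` contains a shapes element named ``"circles"``):
+
+        >>> sdata.locate_element(sdata.shapes["circles"])
+        ['shapes/circles']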
""" found: list[SpatialElement] = [] found_element_type: list[str] = [] @@ -459,39 +569,8 @@ def _locate_spatial_element(self, element: SpatialElement) -> tuple[str, str]: found_element_type.append(element_type) found_element_name.append(element_name) if len(found) == 0: - raise ValueError("Element not found in the SpatialData object.") - if len(found) > 1: - raise ValueError( - f"Element found multiple times in the SpatialData object." - f"Found {len(found)} elements with names: {found_element_name}," - f" and types: {found_element_type}" - ) - assert len(found_element_name) == 1 - assert len(found_element_type) == 1 - return found_element_name[0], found_element_type[0] - - def contains_element(self, element: SpatialElement, raise_exception: bool = False) -> bool: - """ - Check if the SpatialElement is contained in the SpatialData object. - - Parameters - ---------- - element - The SpatialElement to check - raise_exception - If True, raise an exception if the element is not found. If False, return False if the element is not found. - - Returns - ------- - True if the element is found; False otherwise (if raise_exception is False). - """ - try: - self._locate_spatial_element(element) - return True - except ValueError as e: - if raise_exception: - raise e - return False + return None + return [f"{found_element_type[i]}/{found_element_name[i]}" for i in range(len(found))] def _write_transformations_to_disk(self, element: SpatialElement) -> None: """ @@ -506,27 +585,37 @@ def _write_transformations_to_disk(self, element: SpatialElement) -> None: transformations = get_transformation(element, get_all=True) assert isinstance(transformations, dict) - found_element_name, found_element_type = self._locate_spatial_element(element) - + located = self.locate_element(element) + if located is None: + raise ValueError( + "Cannot save the transformation to the element as it has not been found in the SpatialData object" + ) if self.path is not None: - group = self._get_group_for_element(name=found_element_name, element_type=found_element_type) - axes = get_axes_names(element) - if isinstance(element, (SpatialImage, MultiscaleSpatialImage)): - from spatialdata._io._utils import ( - overwrite_coordinate_transformations_raster, - ) + for path in located: + found_element_type, found_element_name = path.split("/") + group = self._get_group_for_element(name=found_element_name, element_type=found_element_type) + axes = get_axes_names(element) + if isinstance(element, (SpatialImage, MultiscaleSpatialImage)): + from spatialdata._io._utils import ( + overwrite_coordinate_transformations_raster, + ) - overwrite_coordinate_transformations_raster(group=group, axes=axes, transformations=transformations) - elif isinstance(element, (DaskDataFrame, GeoDataFrame, AnnData)): - from spatialdata._io._utils import ( - overwrite_coordinate_transformations_non_raster, - ) + overwrite_coordinate_transformations_raster(group=group, axes=axes, transformations=transformations) + elif isinstance(element, (DaskDataFrame, GeoDataFrame, AnnData)): + from spatialdata._io._utils import ( + overwrite_coordinate_transformations_non_raster, + ) - overwrite_coordinate_transformations_non_raster(group=group, axes=axes, transformations=transformations) - else: - raise ValueError("Unknown element type") + overwrite_coordinate_transformations_non_raster( + group=group, axes=axes, transformations=transformations + ) + else: + raise ValueError("Unknown element type") - def filter_by_coordinate_system(self, coordinate_system: str | list[str], 
filter_table: bool = True) -> SpatialData: + @deprecation_alias(filter_table="filter_tables") + def filter_by_coordinate_system( + self, coordinate_system: str | list[str], filter_tables: bool = True, include_orphan_tables: bool = False + ) -> SpatialData: """ Filter the SpatialData by one (or a list of) coordinate system. @@ -537,37 +626,104 @@ def filter_by_coordinate_system(self, coordinate_system: str | list[str], filter ---------- coordinate_system The coordinate system(s) to filter by. - filter_table - If True (default), the table will be filtered to only contain regions + filter_tables + If True (default), the tables will be filtered to only contain regions of an element belonging to the specified coordinate system(s). + include_orphan_tables + If True (not default), include tables that do not annotate SpatialElement(s). Only has an effect if + filter_tables is also set to True. Returns ------- The filtered SpatialData. """ - from spatialdata._core.query.relational_query import _filter_table_by_coordinate_system + # TODO: decide whether to add parameter to filter only specific table. + from spatialdata.transformations.operations import get_transformation elements: dict[str, dict[str, SpatialElement]] = {} - element_paths_in_coordinate_system = [] + element_names_in_coordinate_system = [] if isinstance(coordinate_system, str): coordinate_system = [coordinate_system] for element_type, element_name, element in self._gen_elements(): - transformations = get_transformation(element, get_all=True) - assert isinstance(transformations, dict) - for cs in coordinate_system: - if cs in transformations: - if element_type not in elements: - elements[element_type] = {} - elements[element_type][element_name] = element - element_paths_in_coordinate_system.append(element_name) - - if filter_table: - table = _filter_table_by_coordinate_system(self.table, element_paths_in_coordinate_system) + if element_type != "tables": + transformations = get_transformation(element, get_all=True) + assert isinstance(transformations, dict) + for cs in coordinate_system: + if cs in transformations: + if element_type not in elements: + elements[element_type] = {} + elements[element_type][element_name] = element + element_names_in_coordinate_system.append(element_name) + tables = self._filter_tables( + set(), filter_tables, "cs", include_orphan_tables, element_names=element_names_in_coordinate_system + ) + + return SpatialData(**elements, tables=tables) + + # TODO: move to relational query with refactor + def _filter_tables( + self, + names_tables_to_keep: set[str], + filter_tables: bool = True, + by: Literal["cs", "elements"] | None = None, + include_orphan_tables: bool = False, + element_names: str | list[str] | None = None, + elements_dict: dict[str, dict[str, Any]] | None = None, + ) -> Tables | dict[str, AnnData]: + """ + Filter tables by coordinate system or elements or return tables. + + Parameters + ---------- + names_tables_to_keep + The names of the tables to keep even when filter_tables is True. + filter_tables + If True (default), the tables will be filtered to only contain regions + of an element belonging to the specified coordinate system(s) or including only rows annotating specified + elements. + by + Filter mode. Valid values are "cs" or "elements". Default is None. + include_orphan_tables + Flag indicating whether to include orphan tables. Default is False. + element_names + Element names of elements present in specific coordinate system. 
+ elements_dict + Dictionary of elements for filtering the tables. Default is None. + + Returns + ------- + The filtered tables if filter_tables was True, otherwise tables of the SpatialData object. + + """ + if filter_tables: + tables: dict[str, AnnData] | Tables = {} + for table_name, table in self._tables.items(): + if include_orphan_tables and not table.uns.get(TableModel.ATTRS_KEY): + tables[table_name] = table + continue + if table_name in names_tables_to_keep: + tables[table_name] = table + continue + # each mode here requires paths or elements, using assert here to avoid mypy errors. + if by == "cs": + from spatialdata._core.query.relational_query import _filter_table_by_element_names + + assert element_names is not None + table = _filter_table_by_element_names(table, element_names) + if len(table) != 0: + tables[table_name] = table + elif by == "elements": + from spatialdata._core.query.relational_query import _filter_table_by_elements + + assert elements_dict is not None + table = _filter_table_by_elements(table, elements_dict=elements_dict) + if len(table) != 0: + tables[table_name] = table else: - table = self.table + tables = self.tables - return SpatialData(**elements, table=table) + return tables def rename_coordinate_systems(self, rename_dict: dict[str, str]) -> None: """ @@ -599,7 +755,7 @@ def rename_coordinate_systems(self, rename_dict: dict[str, str]) -> None: new_names.append(new_cs) # rename the coordinate systems - for element in self._gen_elements_values(): + for element in self._gen_spatial_element_values(): # get the transformations transformations = get_transformation(element, get_all=True) assert isinstance(transformations, dict) @@ -673,300 +829,15 @@ def transform_to_coordinate_system( ------- The transformed SpatialData. """ - sdata = self.filter_by_coordinate_system(target_coordinate_system, filter_table=False) + sdata = self.filter_by_coordinate_system(target_coordinate_system, filter_tables=False) elements: dict[str, dict[str, SpatialElement]] = {} for element_type, element_name, element in sdata._gen_elements(): - transformed = sdata.transform_element_to_coordinate_system(element, target_coordinate_system) - if element_type not in elements: - elements[element_type] = {} - elements[element_type][element_name] = transformed - return SpatialData(**elements, table=sdata.table) - - def add_image( - self, - name: str, - image: SpatialImage | MultiscaleSpatialImage, - storage_options: JSONDict | list[JSONDict] | None = None, - overwrite: bool = False, - ) -> None: - """ - Add an image to the SpatialData object. - - Parameters - ---------- - name - Key to the element inside the SpatialData object. - image - The image to add, the object needs to pass validation - (see :class:`~spatialdata.Image2DModel` and :class:`~spatialdata.Image3DModel`). - storage_options - Storage options for the Zarr storage. - See https://zarr.readthedocs.io/en/stable/api/storage.html for more details. - overwrite - If True, overwrite the element if it already exists. - - Notes - ----- - If the SpatialData object is backed by a Zarr storage, the image will be written to the Zarr storage. - """ - if self.is_backed(): - files = get_backing_files(image) - assert self.path is not None - target_path = os.path.realpath(os.path.join(self.path, "images", name)) - if target_path in files: - raise ValueError( - "Cannot add the image to the SpatialData object because it would overwrite an element that it is" - "using for backing. 
See more here: https://github.com/scverse/spatialdata/pull/138" - ) - self._add_image_in_memory(name=name, image=image, overwrite=overwrite) - # old code to support overwriting the backing file - # with tempfile.TemporaryDirectory() as tmpdir: - # store = parse_url(Path(tmpdir) / "data.zarr", mode="w").store - # root = zarr.group(store=store) - # write_image( - # image=self.images[name], - # group=root, - # name=name, - # storage_options=storage_options, - # ) - # src_element_path = Path(store.path) / name - # assert isinstance(self.path, str) - # tgt_element_path = Path(self.path) / "images" / name - # if os.path.isdir(tgt_element_path) and overwrite: - # element_store = parse_url(tgt_element_path, mode="w").store - # _ = zarr.group(store=element_store, overwrite=True) - # element_store.close() - # pathlib.Path(tgt_element_path).mkdir(parents=True, exist_ok=True) - # for file in os.listdir(str(src_element_path)): - # src_file = src_element_path / file - # tgt_file = tgt_element_path / file - # os.rename(src_file, tgt_file) - # from spatialdata._io.read import _read_multiscale - # - # # reload the image from the Zarr storage so that now the element is lazy loaded, and most importantly, - # # from the correct storage - # image = _read_multiscale(str(tgt_element_path), raster_type="image") - # self._add_image_in_memory(name=name, image=image, overwrite=True) - elem_group = self._init_add_element(name=name, element_type="images", overwrite=overwrite) - write_image( - image=self.images[name], - group=elem_group, - name=name, - storage_options=storage_options, - ) - from spatialdata._io.io_raster import _read_multiscale - - # reload the image from the Zarr storage so that now the element is lazy loaded, and most importantly, - # from the correct storage - assert elem_group.path == "images" - path = Path(elem_group.store.path) / "images" / name - image = _read_multiscale(path, raster_type="image") - self._add_image_in_memory(name=name, image=image, overwrite=True) - else: - self._add_image_in_memory(name=name, image=image, overwrite=overwrite) - - def add_labels( - self, - name: str, - labels: SpatialImage | MultiscaleSpatialImage, - storage_options: JSONDict | list[JSONDict] | None = None, - overwrite: bool = False, - ) -> None: - """ - Add labels to the SpatialData object. - - Parameters - ---------- - name - Key to the element inside the SpatialData object. - labels - The labels (masks) to add, the object needs to pass validation - (see :class:`~spatialdata.Labels2DModel` and :class:`~spatialdata.Labels3DModel`). - storage_options - Storage options for the Zarr storage. - See https://zarr.readthedocs.io/en/stable/api/storage.html for more details. - overwrite - If True, overwrite the element if it already exists. - - Notes - ----- - If the SpatialData object is backed by a Zarr storage, the image will be written to the Zarr storage. - """ - if self.is_backed(): - files = get_backing_files(labels) - assert self.path is not None - target_path = os.path.realpath(os.path.join(self.path, "labels", name)) - if target_path in files: - raise ValueError( - "Cannot add the image to the SpatialData object because it would overwrite an element that it is" - "using for backing. We are considering changing this behavior to allow the overwriting of " - "elements used for backing. 
If you would like to support this use case please leave a comment on " - "https://github.com/scverse/spatialdata/pull/138" - ) - self._add_labels_in_memory(name=name, labels=labels, overwrite=overwrite) - # old code to support overwriting the backing file - # with tempfile.TemporaryDirectory() as tmpdir: - # store = parse_url(Path(tmpdir) / "data.zarr", mode="w").store - # root = zarr.group(store=store) - # write_labels( - # labels=self.labels[name], - # group=root, - # name=name, - # storage_options=storage_options, - # ) - # src_element_path = Path(store.path) / "labels" / name - # assert isinstance(self.path, str) - # tgt_element_path = Path(self.path) / "labels" / name - # if os.path.isdir(tgt_element_path) and overwrite: - # element_store = parse_url(tgt_element_path, mode="w").store - # _ = zarr.group(store=element_store, overwrite=True) - # element_store.close() - # pathlib.Path(tgt_element_path).mkdir(parents=True, exist_ok=True) - # for file in os.listdir(str(src_element_path)): - # src_file = src_element_path / file - # tgt_file = tgt_element_path / file - # os.rename(src_file, tgt_file) - # from spatialdata._io.read import _read_multiscale - # - # # reload the labels from the Zarr storage so that now the element is lazy loaded, and most importantly, - # # from the correct storage - # labels = _read_multiscale(str(tgt_element_path), raster_type="labels") - # self._add_labels_in_memory(name=name, labels=labels, overwrite=True) - elem_group = self._init_add_element(name=name, element_type="labels", overwrite=overwrite) - write_labels( - labels=self.labels[name], - group=elem_group, - name=name, - storage_options=storage_options, - ) - # reload the labels from the Zarr storage so that now the element is lazy loaded, and most importantly, - # from the correct storage - from spatialdata._io.io_raster import _read_multiscale - - # just a check to make sure that things go as expected - assert elem_group.path == "" - path = Path(elem_group.store.path) / "labels" / name - labels = _read_multiscale(path, raster_type="labels") - self._add_labels_in_memory(name=name, labels=labels, overwrite=True) - else: - self._add_labels_in_memory(name=name, labels=labels, overwrite=overwrite) - - def add_points( - self, - name: str, - points: DaskDataFrame, - overwrite: bool = False, - ) -> None: - """ - Add points to the SpatialData object. - - Parameters - ---------- - name - Key to the element inside the SpatialData object. - points - The points to add, the object needs to pass validation (see :class:`spatialdata.PointsModel`). - storage_options - Storage options for the Zarr storage. - See https://zarr.readthedocs.io/en/stable/api/storage.html for more details. - overwrite - If True, overwrite the element if it already exists. - - Notes - ----- - If the SpatialData object is backed by a Zarr storage, the image will be written to the Zarr storage. - """ - if self.is_backed(): - files = get_backing_files(points) - assert self.path is not None - target_path = os.path.realpath(os.path.join(self.path, "points", name, "points.parquet")) - if target_path in files: - raise ValueError( - "Cannot add the image to the SpatialData object because it would overwrite an element that it is" - "using for backing. We are considering changing this behavior to allow the overwriting of " - "elements used for backing. 
If you would like to support this use case please leave a comment on " - "https://github.com/scverse/spatialdata/pull/138" - ) - self._add_points_in_memory(name=name, points=points, overwrite=overwrite) - # old code to support overwriting the backing file - # with tempfile.TemporaryDirectory() as tmpdir: - # store = parse_url(Path(tmpdir) / "data.zarr", mode="w").store - # root = zarr.group(store=store) - # write_points( - # points=self.points[name], - # group=root, - # name=name, - # ) - # src_element_path = Path(store.path) / name - # assert isinstance(self.path, str) - # tgt_element_path = Path(self.path) / "points" / name - # if os.path.isdir(tgt_element_path) and overwrite: - # element_store = parse_url(tgt_element_path, mode="w").store - # _ = zarr.group(store=element_store, overwrite=True) - # element_store.close() - # pathlib.Path(tgt_element_path).mkdir(parents=True, exist_ok=True) - # for file in os.listdir(str(src_element_path)): - # src_file = src_element_path / file - # tgt_file = tgt_element_path / file - # os.rename(src_file, tgt_file) - # from spatialdata._io.read import _read_points - # - # # reload the points from the Zarr storage so that now the element is lazy loaded, and most importantly, - # # from the correct storage - # points = _read_points(str(tgt_element_path)) - # self._add_points_in_memory(name=name, points=points, overwrite=True) - elem_group = self._init_add_element(name=name, element_type="points", overwrite=overwrite) - write_points( - points=self.points[name], - group=elem_group, - name=name, - ) - # reload the points from the Zarr storage so that now the element is lazy loaded, and most importantly, - # from the correct storage - from spatialdata._io.io_points import _read_points - - assert elem_group.path == "points" - - path = Path(elem_group.store.path) / "points" / name - points = _read_points(path) - self._add_points_in_memory(name=name, points=points, overwrite=True) - else: - self._add_points_in_memory(name=name, points=points, overwrite=overwrite) - - def add_shapes( - self, - name: str, - shapes: GeoDataFrame, - overwrite: bool = False, - ) -> None: - """ - Add shapes to the SpatialData object. - - Parameters - ---------- - name - Key to the element inside the SpatialData object. - shapes - The shapes to add, the object needs to pass validation (see :class:`~spatialdata.ShapesModel`). - storage_options - Storage options for the Zarr storage. - See https://zarr.readthedocs.io/en/stable/api/storage.html for more details. - overwrite - If True, overwrite the element if it already exists. - - Notes - ----- - If the SpatialData object is backed by a Zarr storage, the image will be written to the Zarr storage. 
- """ - self._add_shapes_in_memory(name=name, shapes=shapes, overwrite=overwrite) - if self.is_backed(): - elem_group = self._init_add_element(name=name, element_type="shapes", overwrite=overwrite) - write_shapes( - shapes=self.shapes[name], - group=elem_group, - name=name, - ) - # no reloading of the file storage since the AnnData is not lazy loaded + if element_type != "tables": + transformed = sdata.transform_element_to_coordinate_system(element, target_coordinate_system) + if element_type not in elements: + elements[element_type] = {} + elements[element_type][element_name] = transformed + return SpatialData(**elements, tables=sdata.tables) def write( self, @@ -975,17 +846,21 @@ def write( overwrite: bool = False, consolidate_metadata: bool = True, ) -> None: + from spatialdata._io import write_image, write_labels, write_points, write_shapes, write_table + from spatialdata._io._utils import get_dask_backing_files + """Write the SpatialData object to Zarr.""" if isinstance(file_path, str): file_path = Path(file_path) assert isinstance(file_path, Path) - if self.is_backed() and self.path != file_path: + if self.is_backed() and str(self.path) != str(file_path): logger.info(f"The Zarr file used for backing will now change from {self.path} to {file_path}") # old code to support overwriting the backing file # target_path = None # tmp_zarr_file = None + if os.path.exists(file_path): if parse_url(file_path, mode="r") is None: raise ValueError( @@ -993,14 +868,22 @@ def write( "a Zarr store. Overwriting non-Zarr stores is not supported to prevent accidental " "data loss." ) - if not overwrite and self.path != str(file_path): + if not overwrite: raise ValueError("The Zarr store already exists. Use `overwrite=True` to overwrite the store.") - raise ValueError( - "The file path specified is the same as the one used for backing. " - "Overwriting the backing file is not supported to prevent accidental data loss." - "We are discussing how to support this use case in the future, if you would like us to " - "support it please leave a comment on https://github.com/scverse/spatialdata/pull/138" - ) + if self.is_backed() and str(self.path) == str(file_path): + raise ValueError( + "The file path specified is the same as the one used for backing. " + "Overwriting the backing file is not supported to prevent accidental data loss." + "We are discussing how to support this use case in the future, if you would like us to " + "support it please leave a comment on https://github.com/scverse/spatialdata/pull/138" + ) + if any(Path(fp).resolve().is_relative_to(file_path.resolve()) for fp in get_dask_backing_files(self)): + raise ValueError( + "The file path specified is a parent directory of one or more files used for backing for one or " + "more elements in the SpatialData object. You can either load every element of the SpatialData " + "object in memory, or save the current spatialdata object to a different path." + ) + # old code to support overwriting the backing file # else: # target_path = tempfile.TemporaryDirectory() @@ -1023,14 +906,13 @@ def write( # self.path = str(file_path) # else: # self.path = str(tmp_zarr_file) - self.path = str(file_path) + self._path = Path(file_path) try: if len(self.images): root.create_group(name="images") # add_image_in_memory will delete and replace the same key in self.images, # so we need to make a copy of the keys. 
Same for the other elements keys = self.images.keys() - from spatialdata._io.io_raster import _read_multiscale for name in keys: elem_group = self._init_add_element(name=name, element_type="images", overwrite=overwrite) @@ -1041,17 +923,16 @@ def write( storage_options=storage_options, ) + # TODO(giovp): fix or remove # reload the image from the Zarr storage so that now the element is lazy loaded, # and most importantly, from the correct storage - element_path = Path(self.path) / "images" / name - image = _read_multiscale(element_path, raster_type="image") - self._add_image_in_memory(name=name, image=image, overwrite=True) + # element_path = Path(self.path) / "images" / name + # _read_multiscale(element_path, raster_type="image") if len(self.labels): root.create_group(name="labels") # keys = list(self.labels.keys()) keys = self.labels.keys() - from spatialdata._io.io_raster import _read_multiscale for name in keys: elem_group = self._init_add_element(name=name, element_type="labels", overwrite=overwrite) @@ -1062,17 +943,16 @@ def write( storage_options=storage_options, ) + # TODO(giovp): fix or remove # reload the labels from the Zarr storage so that now the element is lazy loaded, # and most importantly, from the correct storage - element_path = Path(self.path) / "labels" / name - labels = _read_multiscale(element_path, raster_type="labels") - self._add_labels_in_memory(name=name, labels=labels, overwrite=True) + # element_path = Path(self.path) / "labels" / name + # _read_multiscale(element_path, raster_type="labels") if len(self.points): root.create_group(name="points") # keys = list(self.points.keys()) keys = self.points.keys() - from spatialdata._io.io_points import _read_points for name in keys: elem_group = self._init_add_element(name=name, element_type="points", overwrite=overwrite) @@ -1081,12 +961,12 @@ def write( group=elem_group, name=name, ) - element_path = Path(self.path) / "points" / name + # TODO(giovp): fix or remove + # element_path = Path(self.path) / "points" / name - # reload the points from the Zarr storage so that the element is lazy loaded, - # and most importantly, from the correct storage - points = _read_points(element_path) - self._add_points_in_memory(name=name, points=points, overwrite=True) + # # reload the points from the Zarr storage so that the element is lazy loaded, + # # and most importantly, from the correct storage + # _read_points(element_path) if len(self.shapes): root.create_group(name="shapes") @@ -1099,14 +979,14 @@ def write( group=elem_group, name=name, ) - # no reloading of the file storage since the AnnData is not lazy loaded - if self.table is not None: - elem_group = root.create_group(name="table") - write_table(table=self.table, group=elem_group, name="table") + if len(self.tables): + elem_group = root.create_group(name="tables") + for key in self.tables: + write_table(table=self.tables[key], group=elem_group, name=key) except Exception as e: # noqa: B902 - self.path = None + self._path = None raise e if consolidate_metadata: @@ -1147,56 +1027,76 @@ def write( # else: # raise ValueError(f"Unknown element type {element_type}") # self.__getattribute__(element_type)[name] = element - assert isinstance(self.path, str) + assert isinstance(self.path, Path) @property - def table(self) -> AnnData: + def tables(self) -> Tables: """ - Return the table. + Return tables dictionary. Returns ------- - The table. 
+        dict[str, AnnData]
+            Either an empty dictionary or a dictionary whose keys are the table names and whose values are the
+            AnnData tables themselves.
         """
-        return self._table
+        return self._tables
 
-    @table.setter
-    def table(self, table: AnnData) -> None:
-        """
-        Set the table of a SpatialData object in a object that doesn't contain a table.
+    @tables.setter
+    def tables(self, tables: dict[str, AnnData]) -> None:
+        """Set tables."""
+        self._shared_keys = self._shared_keys - set(self._tables.keys())
+        self._tables = Tables(shared_keys=self._shared_keys)
+        for k, v in tables.items():
+            self._tables[k] = v
 
-        Parameters
-        ----------
-        table
-            The table to set.
+    @property
+    def table(self) -> None | AnnData:
+        """
+        Return the table named "table" from tables, if present.
 
-        Notes
-        -----
-        If a table is already present, it needs to be removed first.
-        The table needs to pass validation (see :class:`~spatialdata.TableModel`).
-        If the SpatialData object is backed by a Zarr storage, the table will be written to the Zarr storage.
+        Returns
+        -------
+        The table.
         """
+        warnings.warn(
+            "Table accessor will be deprecated with SpatialData version 0.1, use sdata.tables instead.",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+        # use isinstance rather than truthiness: an AnnData with 0 rows is falsy but should still be returned
+        if isinstance(self.tables.get("table"), AnnData):
+            return self.tables["table"]
+        return None
+
+    @table.setter
+    def table(self, table: AnnData) -> None:
+        warnings.warn(
+            "Table setter will be deprecated with SpatialData version 0.1, use tables instead.",
+            DeprecationWarning,
+            stacklevel=2,
+        )
        TableModel().validate(table)
-        if self.table is not None:
-            raise ValueError("The table already exists. Use del sdata.table to remove it first.")
-        self._table = table
-        if self.is_backed():
-            store = parse_url(self.path, mode="r+").store
-            root = zarr.group(store=store)
-            elem_group = root.require_group(name="table")
-            write_table(table=self.table, group=elem_group, name="table")
+        if self.tables.get("table") is not None:
+            raise ValueError("The table already exists. Use del sdata.tables['table'] to remove it first.")
+        self.tables["table"] = table
 
     @table.deleter
     def table(self) -> None:
         """Delete the table."""
-        self._table = None
-        if self.is_backed():
-            store = parse_url(self.path, mode="r+").store
-            root = zarr.group(store=store)
-            del root["table/table"]
+        warnings.warn(
+            "del sdata.table will be deprecated with SpatialData version 0.1, use del sdata.tables['table'] instead.",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+        # check against None rather than truthiness so that a 0-row table can still be deleted
+        if self.tables.get("table") is not None:
+            del self.tables["table"]
+        else:
+            # More informative than the error in the zarr library.
+            raise KeyError("table with name 'table' not present in the SpatialData object.")
 
     @staticmethod
-    def read(file_path: str, selection: tuple[str] | None = None) -> SpatialData:
+    def read(file_path: Path | str, selection: tuple[str] | None = None) -> SpatialData:
         """
         Read a SpatialData object from a Zarr storage (on-disk or remote).
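+
+        A sketch of typical usage (the path is illustrative)::
+
+            sdata = SpatialData.read("data.zarr")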
@@ -1215,32 +1115,98 @@ def read(file_path: str, selection: tuple[str] | None = None) -> SpatialData: return read_zarr(file_path, selection=selection) + def add_image( + self, + name: str, + image: SpatialImage | MultiscaleSpatialImage, + storage_options: JSONDict | list[JSONDict] | None = None, + overwrite: bool = False, + ) -> None: + _error_message_add_element() + + def add_labels( + self, + name: str, + labels: SpatialImage | MultiscaleSpatialImage, + storage_options: JSONDict | list[JSONDict] | None = None, + overwrite: bool = False, + ) -> None: + _error_message_add_element() + + def add_points( + self, + name: str, + points: DaskDataFrame, + overwrite: bool = False, + ) -> None: + _error_message_add_element() + + def add_shapes( + self, + name: str, + shapes: GeoDataFrame, + overwrite: bool = False, + ) -> None: + _error_message_add_element() + @property - def images(self) -> dict[str, SpatialImage | MultiscaleSpatialImage]: + def images(self) -> Images: """Return images as a Dict of name to image data.""" return self._images + @images.setter + def images(self, images: dict[str, Raster_T]) -> None: + """Set images.""" + self._shared_keys = self._shared_keys - set(self._images.keys()) + self._images = Images(shared_keys=self._shared_keys) + for k, v in images.items(): + self._images[k] = v + @property - def labels(self) -> dict[str, SpatialImage | MultiscaleSpatialImage]: + def labels(self) -> Labels: """Return labels as a Dict of name to label data.""" return self._labels + @labels.setter + def labels(self, labels: dict[str, Raster_T]) -> None: + """Set labels.""" + self._shared_keys = self._shared_keys - set(self._labels.keys()) + self._labels = Labels(shared_keys=self._shared_keys) + for k, v in labels.items(): + self._labels[k] = v + @property - def points(self) -> dict[str, DaskDataFrame]: + def points(self) -> Points: """Return points as a Dict of name to point data.""" return self._points + @points.setter + def points(self, points: dict[str, DaskDataFrame]) -> None: + """Set points.""" + self._shared_keys = self._shared_keys - set(self._points.keys()) + self._points = Points(shared_keys=self._shared_keys) + for k, v in points.items(): + self._points[k] = v + @property - def shapes(self) -> dict[str, GeoDataFrame]: + def shapes(self) -> Shapes: """Return shapes as a Dict of name to shape data.""" return self._shapes + @shapes.setter + def shapes(self, shapes: dict[str, GeoDataFrame]) -> None: + """Set shapes.""" + self._shared_keys = self._shared_keys - set(self._shapes.keys()) + self._shapes = Shapes(shared_keys=self._shared_keys) + for k, v in shapes.items(): + self._shapes[k] = v + @property def coordinate_systems(self) -> list[str]: from spatialdata.transformations.operations import get_transformation all_cs = set() - gen = self._gen_elements_values() + gen = self._gen_spatial_element_values() for obj in gen: transformations = get_transformation(obj, get_all=True) assert isinstance(transformations, dict) @@ -1256,7 +1222,7 @@ def _non_empty_elements(self) -> list[str]: non_empty_elements The names of the elements that are not empty. """ - all_elements = ["images", "labels", "points", "shapes", "table"] + all_elements = ["images", "labels", "points", "shapes", "tables"] return [ element for element in all_elements @@ -1276,6 +1242,7 @@ def _gen_repr( ------- The string representation of the SpatialData object. 
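+
+        Examples
+        --------
+        The rendered tree looks roughly like this (the output below is illustrative, shown for orientation only):
+
+        >>> print(sdata)  # doctest: +SKIP
+        SpatialData object with:
+        ├── Images
+        │     └── 'image': SpatialImage[cyx] (3, 64, 64)
+        └── Tables
+              └── 'table': AnnData (10, 20)
+        with coordinate systems:
+        ▸ 'global', with elements:
+                image (Images)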
""" + from spatialdata._utils import _natural_keys def rreplace(s: str, old: str, new: str, occurrence: int) -> str: li = s.rsplit(old, occurrence) @@ -1293,71 +1260,64 @@ def h(s: str) -> str: attribute = getattr(self, attr) descr += f"\n{h('level0')}{attr.capitalize()}" - if isinstance(attribute, AnnData): + + unsorted_elements = attribute.items() + sorted_elements = sorted(unsorted_elements, key=lambda x: _natural_keys(x[0])) + for k, v in sorted_elements: descr += f"{h('empty_line')}" - descr_class = attribute.__class__.__name__ - descr += f"{h('level1.0')}{attribute!r}: {descr_class} {attribute.shape}" - descr = rreplace(descr, h("level1.0"), " └── ", 1) - else: - unsorted_elements = attribute.items() - sorted_elements = sorted(unsorted_elements, key=lambda x: _natural_keys(x[0])) - for k, v in sorted_elements: - descr += f"{h('empty_line')}" - descr_class = v.__class__.__name__ - if attr == "shapes": - descr += f"{h(attr + 'level1.1')}{k!r}: {descr_class} " f"shape: {v.shape} (2D shapes)" - elif attr == "points": - length: int | None = None - if len(v.dask.layers) == 1: - name, layer = v.dask.layers.items().__iter__().__next__() - if "read-parquet" in name: - t = layer.creation_info["args"] - assert isinstance(t, tuple) - assert len(t) == 1 - parquet_file = t[0] - table = read_parquet(parquet_file) - length = len(table) - else: - # length = len(v) - length = None + descr_class = v.__class__.__name__ + if attr == "shapes": + descr += f"{h(attr + 'level1.1')}{k!r}: {descr_class} " f"shape: {v.shape} (2D shapes)" + elif attr == "points": + length: int | None = None + if len(v.dask.layers) == 1: + name, layer = v.dask.layers.items().__iter__().__next__() + if "read-parquet" in name: + t = layer.creation_info["args"] + assert isinstance(t, tuple) + assert len(t) == 1 + parquet_file = t[0] + table = read_parquet(parquet_file) + length = len(table) else: + # length = len(v) length = None + else: + length = None - n = len(get_axes_names(v)) - dim_string = f"({n}D points)" + n = len(get_axes_names(v)) + dim_string = f"({n}D points)" - assert len(v.shape) == 2 - if length is not None: - shape_str = f"({length}, {v.shape[1]})" - else: - shape_str = ( - "(" - + ", ".join( - [str(dim) if not isinstance(dim, Delayed) else "" for dim in v.shape] - ) - + ")" - ) - descr += f"{h(attr + 'level1.1')}{k!r}: {descr_class} " f"with shape: {shape_str} {dim_string}" + assert len(v.shape) == 2 + if length is not None: + shape_str = f"({length}, {v.shape[1]})" else: - if isinstance(v, SpatialImage): - descr += f"{h(attr + 'level1.1')}{k!r}: {descr_class}[{''.join(v.dims)}] {v.shape}" - elif isinstance(v, MultiscaleSpatialImage): - shapes = [] - dims: str | None = None - for pyramid_level in v: - dataset_names = list(v[pyramid_level].keys()) - assert len(dataset_names) == 1 - dataset_name = dataset_names[0] - vv = v[pyramid_level][dataset_name] - shape = vv.shape - if dims is None: - dims = "".join(vv.dims) - shapes.append(shape) - descr += ( - f"{h(attr + 'level1.1')}{k!r}: {descr_class}[{dims}] " f"{', '.join(map(str, shapes))}" - ) - else: - raise TypeError(f"Unknown type {type(v)}") + shape_str = ( + "(" + + ", ".join([str(dim) if not isinstance(dim, Delayed) else "" for dim in v.shape]) + + ")" + ) + descr += f"{h(attr + 'level1.1')}{k!r}: {descr_class} " f"with shape: {shape_str} {dim_string}" + elif attr == "tables": + descr += f"{h(attr + 'level1.1')}{k!r}: {descr_class} {v.shape}" + else: + if isinstance(v, SpatialImage): + descr += f"{h(attr + 'level1.1')}{k!r}: {descr_class}[{''.join(v.dims)}] 
{v.shape}" + elif isinstance(v, MultiscaleSpatialImage): + shapes = [] + dims: str | None = None + for pyramid_level in v: + dataset_names = list(v[pyramid_level].keys()) + assert len(dataset_names) == 1 + dataset_name = dataset_names[0] + vv = v[pyramid_level][dataset_name] + shape = vv.shape + if dims is None: + dims = "".join(vv.dims) + shapes.append(shape) + descr += f"{h(attr + 'level1.1')}{k!r}: {descr_class}[{dims}] " f"{', '.join(map(str, shapes))}" + else: + raise TypeError(f"Unknown type {type(v)}") if last_attr is True: descr = descr.replace(h("empty_line"), "\n ") else: @@ -1366,7 +1326,7 @@ def h(s: str) -> str: descr = rreplace(descr, h("level0"), "└── ", 1) descr = descr.replace(h("level0"), "├── ") - for attr in ["images", "labels", "points", "table", "shapes"]: + for attr in ["images", "labels", "points", "tables", "shapes"]: descr = rreplace(descr, h(attr + "level1.1"), " └── ", 1) descr = descr.replace(h(attr + "level1.1"), " ├── ") @@ -1380,13 +1340,14 @@ def h(s: str) -> str: gen = self._gen_elements() elements_in_cs: dict[str, list[str]] = {} for k, name, obj in gen: - transformations = get_transformation(obj, get_all=True) - assert isinstance(transformations, dict) - target_css = transformations.keys() - if cs in target_css: - if k not in elements_in_cs: - elements_in_cs[k] = [] - elements_in_cs[k].append(name) + if not isinstance(obj, AnnData): + transformations = get_transformation(obj, get_all=True) + assert isinstance(transformations, dict) + target_css = transformations.keys() + if cs in target_css: + if k not in elements_in_cs: + elements_in_cs[k] = [] + elements_in_cs[k].append(name) for element_names in elements_in_cs.values(): element_names.sort(key=_natural_keys) if len(elements_in_cs) > 0: @@ -1402,26 +1363,97 @@ def h(s: str) -> str: descr += "\n" return descr - def _gen_elements_values(self) -> Generator[SpatialElement, None, None]: + def _gen_spatial_element_values(self) -> Generator[SpatialElement, None, None]: + """ + Generate spatial element objects contained in the SpatialData instance. + + Returns + ------- + Generator[SpatialElement, None, None] + A generator that yields spatial element objects contained in the SpatialData instance. + + """ for element_type in ["images", "labels", "points", "shapes"]: d = getattr(SpatialData, element_type).fget(self) yield from d.values() - def _gen_elements(self) -> Generator[tuple[str, str, SpatialElement], None, None]: - for element_type in ["images", "labels", "points", "shapes"]: + def _gen_elements( + self, include_table: bool = False + ) -> Generator[tuple[str, str, SpatialElement | AnnData], None, None]: + """ + Generate elements contained in the SpatialData instance. + + Parameters + ---------- + include_table + Whether to also generate table elements. + + Returns + ------- + A generator object that returns a tuple containing the type of the element, its name, and the element + itself. + """ + element_types = ["images", "labels", "points", "shapes"] + if include_table: + element_types.append("tables") + for element_type in element_types: d = getattr(SpatialData, element_type).fget(self) for k, v in d.items(): yield element_type, k, v - def _find_element(self, element_name: str) -> tuple[str, str, SpatialElement]: - for element_type, element_name_, element in self._gen_elements(): + def gen_spatial_elements(self) -> Generator[tuple[str, str, SpatialElement], None, None]: + """ + Generate spatial elements within the SpatialData object. 
+
+        This method generates spatial elements (images, labels, points and shapes).
+
+        Returns
+        -------
+        A generator that yields tuples containing the element type, the element name, and the SpatialElement itself.
+        """
+        return self._gen_elements()
+
+    def gen_elements(self) -> Generator[tuple[str, str, SpatialElement | AnnData], None, None]:
+        """
+        Generate elements within the SpatialData object.
+
+        This method generates elements in the SpatialData object (images, labels, points, shapes and tables).
+
+        Returns
+        -------
+        A generator that yields tuples containing the element type, the element name, and the element itself.
+        """
+        return self._gen_elements(include_table=True)
+
+    def _find_element(self, element_name: str) -> tuple[str, str, SpatialElement | AnnData]:
+        """
+        Retrieve the element from the SpatialData instance matching `element_name`.
+
+        Parameters
+        ----------
+        element_name
+            The name of the element to find.
+
+        Returns
+        -------
+        A tuple containing the element type, the element name, and the retrieved element itself.
+
+        Raises
+        ------
+        KeyError
+            If the element with the given name cannot be found.
+        """
+        for element_type, element_name_, element in self.gen_elements():
             if element_name_ == element_name:
                 return element_type, element_name_, element
         else:
             raise KeyError(f"Could not find element with name {element_name!r}")
 
     @classmethod
-    def init_from_elements(cls, elements: dict[str, SpatialElement], table: AnnData | None = None) -> SpatialData:
+    @deprecation_alias(table="tables")
+    def init_from_elements(
+        cls, elements: dict[str, SpatialElement], tables: AnnData | dict[str, AnnData] | None = None
+    ) -> SpatialData:
         """
         Create a SpatialData object from a dict of named elements and an optional table.
 
@@ -1429,8 +1461,8 @@ def init_from_elements(cls, elements: dict[str, SpatialElement], table: AnnData
         ----------
         elements
             A dict of named elements.
-        table
-            An optional table.
+        tables
+            An optional table or dictionary of tables.
 
         Returns
         -------
@@ -1449,7 +1481,46 @@ def init_from_elements(cls, elements: dict[str, SpatialElement], table: AnnData
             assert model == ShapesModel
             element_type = "shapes"
         elements_dict.setdefault(element_type, {})[name] = element
-        return cls(**elements_dict, table=table)
+        return cls(**elements_dict, tables=tables)
+
+    def subset(
+        self, element_names: list[str], filter_tables: bool = True, include_orphan_tables: bool = False
+    ) -> SpatialData:
+        """
+        Subset the SpatialData object.
+
+        Parameters
+        ----------
+        element_names
+            The names of the elements to subset. If the name of a table is given, the table is included in the
+            subset in full, even if filter_tables is True.
+        filter_tables
+            If True (default), the tables are filtered to only contain rows that are annotating regions
+            contained within the element_names.
+        include_orphan_tables
+            If True (not default), include tables that do not annotate SpatialElement(s). Only has an effect if
+            filter_tables is also set to True.
+
+        Returns
+        -------
+        The subsetted SpatialData object.
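+
+        Example
+        -------
+        A minimal sketch of the intended usage (the element name 'blobs_circles' is hypothetical and assumes a
+        SpatialData object that contains an element with that name):
+
+        ```python
+        # keep only the 'blobs_circles' element; annotating tables are filtered to the matching rows
+        subset_sdata = sdata.subset(["blobs_circles"])
+        # keep the element but leave the tables unfiltered
+        subset_sdata = sdata.subset(["blobs_circles"], filter_tables=False)
+        ```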
+ """ + elements_dict: dict[str, SpatialElement] = {} + names_tables_to_keep: set[str] = set() + for element_type, element_name, element in self._gen_elements(include_table=True): + if element_name in element_names: + if element_type != "tables": + elements_dict.setdefault(element_type, {})[element_name] = element + else: + names_tables_to_keep.add(element_name) + tables = self._filter_tables( + names_tables_to_keep, + filter_tables, + "elements", + include_orphan_tables, + elements_dict=elements_dict, + ) + return SpatialData(**elements_dict, tables=tables) def __getitem__(self, item: str) -> SpatialElement: """ @@ -1480,17 +1551,17 @@ def __setitem__(self, key: str, value: SpatialElement | AnnData) -> None: """ schema = get_model(value) if schema in (Image2DModel, Image3DModel): - self.add_image(key, value) + self.images[key] = value elif schema in (Labels2DModel, Labels3DModel): - self.add_labels(key, value) + self.labels[key] = value elif schema == PointsModel: - self.add_points(key, value) + self.points[key] = value elif schema == ShapesModel: - self.add_shapes(key, value) + self.shapes[key] = value elif schema == TableModel: - raise TypeError("Use the table property to set the table (e.g. sdata.table = value)") + self.tables[key] = value else: - raise TypeError(f"Unknown element type with schema{schema!r}") + raise TypeError(f"Unknown element type with schema: {schema!r}.") class QueryManager: diff --git a/src/spatialdata/_io/__init__.py b/src/spatialdata/_io/__init__.py index fd72da5c..d9fc3cd6 100644 --- a/src/spatialdata/_io/__init__.py +++ b/src/spatialdata/_io/__init__.py @@ -1,3 +1,4 @@ +from spatialdata._io._utils import get_dask_backing_files from spatialdata._io.format import SpatialDataFormatV01 from spatialdata._io.io_points import write_points from spatialdata._io.io_raster import write_image, write_labels @@ -11,4 +12,5 @@ "write_shapes", "write_table", "SpatialDataFormatV01", + "get_dask_backing_files", ] diff --git a/src/spatialdata/_io/_utils.py b/src/spatialdata/_io/_utils.py index bfa12721..f5caa59d 100644 --- a/src/spatialdata/_io/_utils.py +++ b/src/spatialdata/_io/_utils.py @@ -8,17 +8,24 @@ from collections.abc import Generator, Mapping from contextlib import contextmanager from functools import singledispatch -from typing import TYPE_CHECKING, Any +from typing import Any +import numpy as np import zarr +from anndata import AnnData +from anndata import read_zarr as read_anndata_zarr +from anndata.experimental import read_elem +from dask.array.core import Array as DaskArray from dask.dataframe.core import DataFrame as DaskDataFrame from multiscale_spatial_image import MultiscaleSpatialImage from ome_zarr.format import Format from ome_zarr.writer import _get_valid_axes from spatial_image import SpatialImage -from xarray import DataArray +from spatialdata._core.spatialdata import SpatialData +from spatialdata._logging import logger from spatialdata._utils import iterate_pyramid_levels +from spatialdata.models import TableModel from spatialdata.models._utils import ( MappingToCoordinateSystem_t, ValidAxis_t, @@ -30,9 +37,6 @@ _get_current_output_axes, ) -if TYPE_CHECKING: - from spatialdata import SpatialData - # suppress logger debug from ome_zarr with context manager @contextmanager @@ -175,8 +179,8 @@ def _are_directories_identical( if _root_dir2 is None: _root_dir2 = dir2 if exclude_regexp is not None and ( - re.match(rf"{_root_dir1}/" + exclude_regexp, str(dir1)) - or re.match(rf"{_root_dir2}/" + exclude_regexp, str(dir2)) + 
re.match(rf"{re.escape(str(_root_dir1))}/" + exclude_regexp, str(dir1)) + or re.match(rf"{re.escape(str(_root_dir2))}/" + exclude_regexp, str(dir2)) ): return True @@ -196,8 +200,6 @@ def _are_directories_identical( def _compare_sdata_on_disk(a: SpatialData, b: SpatialData) -> bool: - from spatialdata import SpatialData - if not isinstance(a, SpatialData) or not isinstance(b, SpatialData): return False # TODO: if the sdata object is backed on disk, don't create a new zarr file @@ -207,37 +209,59 @@ def _compare_sdata_on_disk(a: SpatialData, b: SpatialData) -> bool: return _are_directories_identical(os.path.join(tmpdir, "a.zarr"), os.path.join(tmpdir, "b.zarr")) -def _get_backing_files_raster(raster: DataArray) -> list[str]: - files = [] - for k, v in raster.data.dask.layers.items(): - if k.startswith("original-from-zarr-"): - mapping = v.mapping[k] - path = mapping.store.path - files.append(os.path.realpath(path)) - return files +@singledispatch +def get_dask_backing_files(element: SpatialData | SpatialImage | MultiscaleSpatialImage | DaskDataFrame) -> list[str]: + """ + Get the backing files that appear in the Dask computational graph of an element/any element of a SpatialData object. + Parameters + ---------- + element + The element to get the backing files from. -@singledispatch -def get_backing_files(element: SpatialImage | MultiscaleSpatialImage | DaskDataFrame) -> list[str]: + Returns + ------- + List of backing files. + + Notes + ----- + It is possible for lazy objects to be constructed from multiple files. + """ raise TypeError(f"Unsupported type: {type(element)}") -@get_backing_files.register(SpatialImage) +@get_dask_backing_files.register(SpatialData) +def _(element: SpatialData) -> list[str]: + files: set[str] = set() + for e in element._gen_spatial_element_values(): + if isinstance(e, (SpatialImage, MultiscaleSpatialImage, DaskDataFrame)): + files = files.union(get_dask_backing_files(e)) + return list(files) + + +@get_dask_backing_files.register(SpatialImage) def _(element: SpatialImage) -> list[str]: - return _get_backing_files_raster(element) + return _get_backing_files(element.data) -@get_backing_files.register(MultiscaleSpatialImage) +@get_dask_backing_files.register(MultiscaleSpatialImage) def _(element: MultiscaleSpatialImage) -> list[str]: xdata0 = next(iter(iterate_pyramid_levels(element))) - return _get_backing_files_raster(xdata0) + return _get_backing_files(xdata0.data) -@get_backing_files.register(DaskDataFrame) +@get_dask_backing_files.register(DaskDataFrame) def _(element: DaskDataFrame) -> list[str]: + return _get_backing_files(element) + + +def _get_backing_files(element: DaskArray | DaskDataFrame) -> list[str]: files = [] - layers = element.dask.layers - for k, v in layers.items(): + for k, v in element.dask.layers.items(): + if k.startswith("original-from-zarr-"): + mapping = v.mapping[k] + path = mapping.store.path + files.append(os.path.realpath(path)) if k.startswith("read-parquet-"): t = v.creation_info["args"] assert isinstance(t, tuple) @@ -286,6 +310,57 @@ def save_transformations(sdata: SpatialData) -> None: """ from spatialdata.transformations import get_transformation, set_transformation - for element in sdata._gen_elements_values(): + for element in sdata._gen_spatial_element_values(): transformations = get_transformation(element, get_all=True) set_transformation(element, transformations, set_all=True, write_to_sdata=sdata) + + +def read_table_and_validate( + zarr_store_path: str, group: zarr.Group, subgroup: zarr.Group, tables: dict[str, AnnData] 
+) -> dict[str, AnnData]:
+    """
+    Read and validate the tables in the `tables` Zarr group of a SpatialData Zarr store.
+
+    Parameters
+    ----------
+    zarr_store_path
+        The path to the Zarr store.
+    group
+        The parent group containing the subgroup.
+    subgroup
+        The subgroup containing the tables.
+    tables
+        A dictionary of tables.
+
+    Returns
+    -------
+    The modified dictionary with the tables.
+    """
+    count = 0
+    for table_name in subgroup:
+        f_elem = subgroup[table_name]
+        f_elem_store = os.path.join(zarr_store_path, f_elem.path)
+        if isinstance(group.store, zarr.storage.ConsolidatedMetadataStore):
+            tables[table_name] = read_elem(f_elem)
+            # we can replace read_elem with read_anndata_zarr after this PR gets into a release (>= 0.6.5)
+            # https://github.com/scverse/anndata/pull/1057#pullrequestreview-1530623183
+            # table = read_anndata_zarr(f_elem)
+        else:
+            tables[table_name] = read_anndata_zarr(f_elem_store)
+        if TableModel.ATTRS_KEY in tables[table_name].uns:
+            # fill in any missing attributes that were omitted because their value was None
+            attrs = tables[table_name].uns[TableModel.ATTRS_KEY]
+            if "region" not in attrs:
+                attrs["region"] = None
+            if "region_key" not in attrs:
+                attrs["region_key"] = None
+            if "instance_key" not in attrs:
+                attrs["instance_key"] = None
+            # fix type for region
+            if "region" in attrs and isinstance(attrs["region"], np.ndarray):
+                attrs["region"] = attrs["region"].tolist()
+
+        count += 1
+
+    logger.debug(f"Found {count} elements in {subgroup}")
+    return tables
diff --git a/src/spatialdata/_io/io_raster.py b/src/spatialdata/_io/io_raster.py
index 7fafb676..57a6069c 100644
--- a/src/spatialdata/_io/io_raster.py
+++ b/src/spatialdata/_io/io_raster.py
@@ -1,4 +1,3 @@
-import os
 from pathlib import Path
 from typing import Any, Literal, Optional, Union
 
@@ -67,7 +66,8 @@ def _read_multiscale(
     # and for instance in the xenium example
     encoded_ngff_transformations = multiscales[0]["coordinateTransformations"]
     transformations = _get_transformations_from_ngff_dict(encoded_ngff_transformations)
-    name = os.path.basename(node.metadata["name"])
+    # TODO: what to do with name? For now remove?
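+    # (assumption: the constant name "image" used below keeps the in-memory name independent of the store metadata)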
+ # name = os.path.basename(node.metadata["name"]) # if image, read channels metadata channels: Optional[list[Any]] = None if raster_type == "image" and channels_metadata is not None: @@ -79,7 +79,7 @@ def _read_multiscale( data = node.load(Multiscales).array(resolution=d, version=fmt.version) multiscale_image[f"scale{i}"] = DataArray( data, - name=name, + name="image", dims=axes, coords={"c": channels} if channels is not None else {}, ) @@ -89,7 +89,7 @@ def _read_multiscale( data = node.load(Multiscales).array(resolution=datasets[0], version=fmt.version) si = SpatialImage( data, - name=name, + name="image", dims=axes, coords={"c": channels} if channels is not None else {}, ) diff --git a/src/spatialdata/_io/io_table.py b/src/spatialdata/_io/io_table.py index 72ae5f4c..ead604af 100644 --- a/src/spatialdata/_io/io_table.py +++ b/src/spatialdata/_io/io_table.py @@ -4,6 +4,7 @@ from ome_zarr.format import Format from spatialdata._io.format import CurrentTablesFormat +from spatialdata.models import TableModel def write_table( @@ -13,10 +14,13 @@ def write_table( group_type: str = "ngff:regions_table", fmt: Format = CurrentTablesFormat(), ) -> None: - region = table.uns["spatialdata_attrs"]["region"] - region_key = table.uns["spatialdata_attrs"].get("region_key", None) - instance_key = table.uns["spatialdata_attrs"].get("instance_key", None) - fmt.validate_table(table, region_key, instance_key) + if TableModel.ATTRS_KEY in table.uns: + region = table.uns["spatialdata_attrs"]["region"] + region_key = table.uns["spatialdata_attrs"].get("region_key", None) + instance_key = table.uns["spatialdata_attrs"].get("instance_key", None) + fmt.validate_table(table, region_key, instance_key) + else: + region, region_key, instance_key = (None, None, None) write_adata(group, name, table) # creates group[name] tables_group = group[name] tables_group.attrs["spatialdata-encoding-type"] = group_type diff --git a/src/spatialdata/_io/io_zarr.py b/src/spatialdata/_io/io_zarr.py index 7b4f286c..f5e378a2 100644 --- a/src/spatialdata/_io/io_zarr.py +++ b/src/spatialdata/_io/io_zarr.py @@ -1,21 +1,18 @@ import logging import os +import warnings from pathlib import Path from typing import Optional, Union -import numpy as np import zarr from anndata import AnnData -from anndata import read_zarr as read_anndata_zarr -from anndata.experimental import read_elem -from spatialdata import SpatialData -from spatialdata._io._utils import ome_zarr_logger +from spatialdata._core.spatialdata import SpatialData +from spatialdata._io._utils import ome_zarr_logger, read_table_and_validate from spatialdata._io.io_points import _read_points from spatialdata._io.io_raster import _read_multiscale from spatialdata._io.io_shapes import _read_shapes from spatialdata._logging import logger -from spatialdata.models import TableModel def _open_zarr_store(store: Union[str, Path, zarr.Group]) -> tuple[zarr.Group, str]: @@ -61,10 +58,11 @@ def read_zarr(store: Union[str, Path, zarr.Group], selection: Optional[tuple[str images = {} labels = {} points = {} - table: Optional[AnnData] = None + tables: dict[str, AnnData] = {} shapes = {} - selector = {"images", "labels", "points", "shapes", "table"} if not selection else set(selection or []) + # TODO: remove table once deprecated. 
+    selector = {"images", "labels", "points", "shapes", "tables", "table"} if not selection else set(selection or [])
     logger.debug(f"Reading selection {selector}")
 
     # read multiscale images
@@ -123,36 +121,21 @@ def read_zarr(store: Union[str, Path, zarr.Group], selection: Optional[tuple[str
                 shapes[subgroup_name] = _read_shapes(f_elem_store)
                 count += 1
             logger.debug(f"Found {count} elements in {group}")
+    if "tables" in selector and "tables" in f:
+        group = f["tables"]
+        tables = read_table_and_validate(f_store_path, f, group, tables)
 
     if "table" in selector and "table" in f:
-        group = f["table"]
-        count = 0
-        for subgroup_name in group:
-            if Path(subgroup_name).name.startswith("."):
-                # skip hidden files like .zgroup or .zmetadata
-                continue
-            f_elem = group[subgroup_name]
-            f_elem_store = os.path.join(f_store_path, f_elem.path)
-            if isinstance(f.store, zarr.storage.ConsolidatedMetadataStore):
-                table = read_elem(f_elem)
-                # we can replace read_elem with read_anndata_zarr after this PR gets into a release (>= 0.6.5)
-                # https://github.com/scverse/anndata/pull/1057#pullrequestreview-1530623183
-                # table = read_anndata_zarr(f_elem)
-            else:
-                table = read_anndata_zarr(f_elem_store)
-            if TableModel.ATTRS_KEY in table.uns:
-                # fill out eventual missing attributes that has been omitted because their value was None
-                attrs = table.uns[TableModel.ATTRS_KEY]
-                if "region" not in attrs:
-                    attrs["region"] = None
-                if "region_key" not in attrs:
-                    attrs["region_key"] = None
-                if "instance_key" not in attrs:
-                    attrs["instance_key"] = None
-                # fix type for region
-                if "region" in attrs and isinstance(attrs["region"], np.ndarray):
-                    attrs["region"] = attrs["region"].tolist()
-            count += 1
+        warnings.warn(
+            f"Table group found in zarr store at location {f_store_path}. Please update the zarr store "
+            "to use tables instead.",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+        subgroup_name = "table"
+        group = f[subgroup_name]
+        tables = read_table_and_validate(f_store_path, f, group, tables)
 
         logger.debug(f"Found {count} elements in {group}")
 
     sdata = SpatialData(
@@ -160,7 +143,7 @@ def read_zarr(store: Union[str, Path, zarr.Group], selection: Optional[tuple[str
         labels=labels,
         points=points,
         shapes=shapes,
-        table=table,
+        tables=tables,
     )
-    sdata.path = str(store)
+    sdata._path = Path(store)
     return sdata
diff --git a/src/spatialdata/_types.py b/src/spatialdata/_types.py
index 6ae68e1a..30b235c1 100644
--- a/src/spatialdata/_types.py
+++ b/src/spatialdata/_types.py
@@ -1,8 +1,12 @@
 from __future__ import annotations
 
+from typing import Union
+
 import numpy as np
+from multiscale_spatial_image import MultiscaleSpatialImage
+from spatial_image import SpatialImage
 
-__all__ = ["ArrayLike", "DTypeLike"]
+__all__ = ["ArrayLike", "DTypeLike", "Raster_T"]
 
 try:
     from numpy.typing import DTypeLike, NDArray
@@ -11,3 +15,5 @@
 except (ImportError, TypeError):
     ArrayLike = np.ndarray  # type: ignore[misc]
     DTypeLike = np.dtype  # type: ignore[misc]
+
+Raster_T = Union[SpatialImage, MultiscaleSpatialImage]
diff --git a/src/spatialdata/_utils.py b/src/spatialdata/_utils.py
index 59eaec6c..205308e8 100644
--- a/src/spatialdata/_utils.py
+++ b/src/spatialdata/_utils.py
@@ -1,9 +1,11 @@
 from __future__ import annotations
 
+import functools
 import re
+import warnings
 from collections.abc import Generator
 from copy import deepcopy
-from typing import TYPE_CHECKING, Union
+from typing import Any, Callable, TypeVar, Union
 
 import numpy as np
 import pandas as pd
@@ -25,9 +27,7 @@
 # I was using "from numbers import Number" but this led to mypy errors, so I switched to the following:
 Number = Union[int, float]
 
-
-if TYPE_CHECKING:
-    pass
+RT = TypeVar("RT")
 
 
 def _parse_list_into_array(array: list[Number] | ArrayLike) -> ArrayLike:
@@ -231,3 +231,83 @@ def _deepcopy_geodataframe(gdf: GeoDataFrame) -> GeoDataFrame:
     new_attrs = deepcopy(gdf.attrs)
     new_gdf.attrs = new_attrs
     return new_gdf
+
+
+# TODO: change to paramspec as soon as we drop support for python 3.9, see https://stackoverflow.com/a/68290080
+def deprecation_alias(**aliases: str) -> Callable[[Callable[..., RT]], Callable[..., RT]]:
+    """
+    Decorate a function to warn the user when deprecated arguments are used.
+
+    Parameters
+    ----------
+    aliases
+        Deprecated argument aliases to be mapped to the new arguments.
+
+    Returns
+    -------
+    A decorator that can be used to mark an argument for deprecation, substituting it with the new argument.
+
+    Raises
+    ------
+    TypeError
+        If the provided aliases are not of string type.
+
+    Example
+    -------
+    Assuming we have an argument 'table' set for deprecation and we want to warn the user and substitute with 'tables':
+
+    ```python
+    @deprecation_alias(table="tables")
+    def my_function(tables: AnnData | dict[str, AnnData]):
+        pass
+    ```
+    """
+
+    def deprecation_decorator(f: Callable[..., RT]) -> Callable[..., RT]:
+        @functools.wraps(f)
+        def wrapper(*args: Any, **kwargs: Any) -> RT:
+            # f.__qualname__ includes the function name itself; keep only the enclosing class (if any)
+            class_name = f.__qualname__.rsplit(".", 1)[0] if "." in f.__qualname__ else None
+            rename_kwargs(f.__name__, kwargs, aliases, class_name)
+            return f(*args, **kwargs)
+
+        return wrapper
+
+    return deprecation_decorator
+
+
+def rename_kwargs(func_name: str, kwargs: dict[str, Any], aliases: dict[str, str], class_name: None | str) -> None:
+    """Rename deprecated function arguments and warn when they are used."""
+    for alias, new in aliases.items():
+        if alias in kwargs:
+            class_name = class_name + "." if class_name else ""
+            if new in kwargs:
+                raise TypeError(
+                    f"{class_name}{func_name} received both {alias} and {new} as arguments!"
+                    f" {alias} is being deprecated in SpatialData version 0.1, only use {new} instead."
+                )
+            warnings.warn(
+                message=(
+                    f"`{alias}` is being deprecated as an argument to `{class_name}{func_name}` in SpatialData "
+                    f"version 0.1, switch to `{new}` instead."
+                ),
+                category=DeprecationWarning,
+                stacklevel=3,
+            )
+            kwargs[new] = kwargs.pop(alias)
+
+
+def _error_message_add_element() -> None:
+    raise RuntimeError(
+        "The functions add_image(), add_labels(), add_points() and add_shapes() have been removed in favor of "
+        "dict-like access to the elements. Please use the following syntax to add an element:\n"
+        "\n"
+        '\tsdata.images["image_name"] = image\n'
+        '\tsdata.labels["labels_name"] = labels\n'
+        "\t...\n"
+        "\n"
+        "The new syntax does not automatically update the disk storage, so you need to call sdata.write() when "
+        "the in-memory object is ready to be saved.\n"
+        "To save only a new specific element to an existing Zarr storage please use the functions write_image(), "
+        "write_labels(), write_points(), write_shapes() and write_table(). We are going to make these calls more "
+        "ergonomic in a follow-up PR."
+    )
diff --git a/src/spatialdata/dataloader/__init__.py b/src/spatialdata/dataloader/__init__.py
index f9262f85..819ab58e 100644
--- a/src/spatialdata/dataloader/__init__.py
+++ b/src/spatialdata/dataloader/__init__.py
@@ -1,6 +1,4 @@
-import contextlib
-
-with contextlib.suppress(ImportError):
+try:
     from spatialdata.dataloader.datasets import ImageTilesDataset
-
-__all__ = ["ImageTilesDataset"]
+except ImportError:
+    ImageTilesDataset = None  # type: ignore[assignment, misc]
diff --git a/src/spatialdata/dataloader/datasets.py b/src/spatialdata/dataloader/datasets.py
index e31750f2..66fc5b4c 100644
--- a/src/spatialdata/dataloader/datasets.py
+++ b/src/spatialdata/dataloader/datasets.py
@@ -1,15 +1,20 @@
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, Any, Callable
+from collections.abc import Mapping
+from functools import partial
+from itertools import chain
+from types import MappingProxyType
+from typing import Any, Callable
 
 import numpy as np
+import pandas as pd
+from anndata import AnnData
 from geopandas import GeoDataFrame
 from multiscale_spatial_image import MultiscaleSpatialImage
-from shapely import MultiPolygon, Point, Polygon
-from spatial_image import SpatialImage
+from scipy.sparse import issparse
 from torch.utils.data import Dataset
 
-from spatialdata._core.operations.rasterize import rasterize
+from spatialdata._core.spatialdata import SpatialData
 from spatialdata._utils import _affine_matrix_multiplication
 from spatialdata.models import (
     Image2DModel,
@@ -17,175 +22,389 @@
     Labels2DModel,
     Labels3DModel,
     ShapesModel,
+    TableModel,
     get_axes_names,
     get_model,
 )
-from spatialdata.transformations.operations import get_transformation
+from spatialdata.transformations import get_transformation
 from spatialdata.transformations.transformations import BaseTransformation
 
-if TYPE_CHECKING:
-    from spatialdata import SpatialData
+__all__ = ["ImageTilesDataset"]
 
 
 class ImageTilesDataset(Dataset):
+    """
+    Dataset for SpatialData.
+
+    :class:`torch.utils.data.Dataset` for loading tiles from a :class:`spatialdata.SpatialData` object.
+
+    By default, the dataset returns a SpatialData object, but when `return_annotations` is set,
+    the dataset returns a tuple containing:
+
+    - the tile image, centered in the target coordinate system of the region.
+    - a vector or scalar value from the table.
+
+    Parameters
+    ----------
+    sdata
+        The SpatialData object.
+    regions_to_images
+        A mapping between regions and images. The regions are used to compute the tile centers, while the images are
+        used to get the pixel values.
+    regions_to_coordinate_systems
+        A mapping between regions and coordinate systems. The coordinate systems are used to transform both
+        the regions' coordinates for the tiles as well as the images.
+    tile_scale
+        The scale of the tiles. This is used only if the `regions` are `shapes`.
+        It is a scaling factor applied to either the radius (spots) or length (polygons) of the `shapes`
+        according to the geometry type of the `shapes` element:
+
+        - if `shapes` are circles (spots), the radius is scaled by `tile_scale`.
+        - if `shapes` are polygons, the length of the polygon is scaled by `tile_scale`.
+
+        If `tile_dim_in_units` is passed, `tile_scale` is ignored.
+    tile_dim_in_units
+        The dimension of the requested tile in the units of the target coordinate system.
+        This specifies the extent of the tile. This is not related to the size in pixels of each returned tile.
+    rasterize
+        If True, the images are rasterized using :func:`spatialdata.rasterize`.
+        If False, they are queried using :func:`spatialdata.bounding_box_query`.
+    return_annotations
+        If not None, a value from the table is returned together with the image tile.
+        Only columns in :attr:`anndata.AnnData.obs` and :attr:`anndata.AnnData.X`
+        can be returned. If None, a `SpatialData` object containing the image tile and the
+        corresponding table row is returned.
+    transform
+        A callable that takes as input the tuple (image, table_value) and returns a new tuple when
+        `return_annotations` is not None; a callable that takes as input the `SpatialData` object and
+        returns a transformed object when `return_annotations` is `None`.
+        This parameter can be used to apply data transformations (for instance a normalization operation) to the
+        image and the table value.
+    rasterize_kwargs
+        Keyword arguments passed to :func:`spatialdata.rasterize` if `rasterize` is True.
+        This argument can be used for instance to choose the pixel dimension of the image tile.
+
+    Returns
+    -------
+    :class:`torch.utils.data.Dataset` for loading tiles from a :class:`spatialdata.SpatialData`.
+    """
+
+    INSTANCE_KEY = "instance_id"
+    CS_KEY = "cs"
+    REGION_KEY = "region"
+    IMAGE_KEY = "image"
+
     def __init__(
         self,
         sdata: SpatialData,
         regions_to_images: dict[str, str],
-        tile_dim_in_units: float,
-        tile_dim_in_pixels: int,
-        target_coordinate_system: str = "global",
-        # unused at the moment, see
-        transform: Callable[[SpatialData], Any] | None = None,
+        regions_to_coordinate_systems: dict[str, str],
+        tile_scale: float = 1.0,
+        tile_dim_in_units: float | None = None,
+        rasterize: bool = False,
+        return_annotations: str | list[str] | None = None,
+        transform: Callable[[Any], Any] | None = None,
+        rasterize_kwargs: Mapping[str, Any] = MappingProxyType({}),
     ):
-        """
-        Torch Dataset that returns image tiles around regions from a SpatialData object.
-
-        Parameters
-        ----------
-        sdata
-            The SpatialData object containing the regions and images from which to extract the tiles from.
-        regions_to_images
-            A dictionary mapping the regions element key we want to extract the tiles around to the images element key
-            we want to get the image data from.
-        tile_dim_in_units
-            The dimension of the requested tile in the units of the target coordinate system. This specifies the extent
-            of the image each tile is querying. This is not related he size in pixel of each returned tile.
-        tile_dim_in_pixels
-            The dimension of the requested tile in pixels. This specifies the size of the output tiles that we will get,
-            independently of which extent of the image the tile is covering.
-        target_coordinate_system
-            The coordinate system in which the tile_dim_in_units is specified.
- """ - # TODO: we can extend this code to support: - # - automatic dermination of the tile_dim_in_pixels to match the image resolution (prevent down/upscaling) - # - use the bounding box query instead of the raster function if the user wants - self.sdata = sdata - self.regions_to_images = regions_to_images - self.tile_dim_in_units = tile_dim_in_units - self.tile_dim_in_pixels = tile_dim_in_pixels + from spatialdata import bounding_box_query + from spatialdata._core.operations.rasterize import rasterize as rasterize_fn + + self._validate(sdata, regions_to_images, regions_to_coordinate_systems) + self._preprocess(tile_scale, tile_dim_in_units) + + self._crop_image: Callable[..., Any] = ( + partial( + rasterize_fn, + **dict(rasterize_kwargs), + ) + if rasterize + else bounding_box_query # type: ignore[assignment] + ) + self._return = self._get_return(return_annotations) self.transform = transform - self.target_coordinate_system = target_coordinate_system - - self.n_spots_dict = self._compute_n_spots_dict() - self.n_spots = sum(self.n_spots_dict.values()) - - def _validate_regions_to_images(self) -> None: - for region_key, image_key in self.regions_to_images.items(): - regions_element = self.sdata[region_key] - images_element = self.sdata[image_key] - # we could allow also for points - if get_model(regions_element) not in [ShapesModel, Labels2DModel, Labels3DModel]: - raise ValueError("regions_element must be a shapes element or a labels element") - if get_model(images_element) not in [Image2DModel, Image3DModel]: - raise ValueError("images_element must be an image element") - - def _compute_n_spots_dict(self) -> dict[str, int]: - n_spots_dict = {} - for region_key in self.regions_to_images: - element = self.sdata[region_key] - # we could allow also points - if isinstance(element, GeoDataFrame): - n_spots_dict[region_key] = len(element) - elif isinstance(element, (SpatialImage, MultiscaleSpatialImage)): - raise NotImplementedError("labels not supported yet") - else: - raise ValueError("element must be a geodataframe or a spatial image") - return n_spots_dict - - def _get_region_info_for_index(self, index: int) -> tuple[str, int]: - # TODO: this implmenetation can be improved - i = 0 - for region_key, n_spots in self.n_spots_dict.items(): - if index < i + n_spots: - return region_key, index - i - i += n_spots - raise ValueError(f"index {index} is out of range") - def __len__(self) -> int: - return self.n_spots + def _validate( + self, + sdata: SpatialData, + regions_to_images: dict[str, str], + regions_to_coordinate_systems: dict[str, str], + ) -> None: + """Validate input parameters.""" + self._region_key = sdata.tables["table"].uns[TableModel.ATTRS_KEY][TableModel.REGION_KEY_KEY] + self._instance_key = sdata.tables["table"].uns[TableModel.ATTRS_KEY][TableModel.INSTANCE_KEY] + available_regions = sdata.tables["table"].obs[self._region_key].cat.categories + cs_region_image = [] # list of tuples (coordinate_system, region, image) - def __getitem__(self, idx: int) -> Any | SpatialData: - from spatialdata import SpatialData - - if idx >= self.n_spots: - raise IndexError() - regions_name, region_index = self._get_region_info_for_index(idx) - regions = self.sdata[regions_name] - # TODO: here we just need to compute the centroids, - # we probably want to move this functionality to a different file - if isinstance(regions, GeoDataFrame): - dims = get_axes_names(regions) - region = regions.iloc[region_index] - shape = regions.geometry.iloc[0] - if isinstance(shape, Polygon): - xy = 
region.geometry.centroid.coords.xy - centroid = np.array([[xy[0][0], xy[1][0]]]) - elif isinstance(shape, MultiPolygon): - raise NotImplementedError("MultiPolygon not supported yet") - elif isinstance(shape, Point): - xy = region.geometry.coords.xy - centroid = np.array([[xy[0][0], xy[1][0]]]) - else: - raise RuntimeError(f"Unsupported type: {type(shape)}") - - t = get_transformation(regions, self.target_coordinate_system) + # check unique matching between regions and images and coordinate systems + assert len(set(regions_to_images.values())) == len( + regions_to_images.keys() + ), "One region cannot be paired to multiple images." + assert len(set(regions_to_coordinate_systems.values())) == len( + regions_to_coordinate_systems.keys() + ), "One region cannot be paired to multiple coordinate systems." + + for region_key, image_key in regions_to_images.items(): + # get elements + region_elem = sdata[region_key] + image_elem = sdata[image_key] + + # check that the elements are supported + if get_model(region_elem) in [Labels2DModel, Labels3DModel]: + raise NotImplementedError("labels elements are not implemented yet.") + if get_model(region_elem) not in [ShapesModel]: + raise ValueError("`regions_element` must be a shapes element.") + if get_model(image_elem) not in [Image2DModel, Image3DModel]: + raise ValueError("`images_element` must be an image element.") + if isinstance(image_elem, MultiscaleSpatialImage): + raise NotImplementedError("Multiscale images are not implemented yet.") + + if region_key not in available_regions: + raise ValueError(f"region {region_key} not found in the spatialdata object.") + + # check that the coordinate systems are valid for the elements + try: + cs = regions_to_coordinate_systems[region_key] + region_trans = get_transformation(region_elem, cs) + image_trans = get_transformation(image_elem, cs) + if isinstance(region_trans, BaseTransformation) and isinstance(image_trans, BaseTransformation): + cs_region_image.append((cs, region_key, image_key)) + except KeyError as e: + raise KeyError(f"region {region_key} not found in `regions_to_coordinate_systems`") from e + + self.regions = list(regions_to_coordinate_systems.keys()) # all regions for the dataloader + self.sdata = sdata + self.dataset_table = self.sdata.tables["table"][ + self.sdata.tables["table"].obs[self._region_key].isin(self.regions) + ] # filtered table for the data loader + self._cs_region_image = tuple(cs_region_image) # tuple of tuples (coordinate_system, region_key, image_key) + + def _preprocess( + self, + tile_scale: float = 1.0, + tile_dim_in_units: float | None = None, + ) -> None: + """Preprocess the dataset.""" + index_df = [] + tile_coords_df = [] + dims_l = [] + shapes_l = [] + table = self.sdata.tables["table"] + for cs, region, image in self._cs_region_image: + # get dims and transformations for the region element + dims = get_axes_names(self.sdata[region]) + dims_l.append(dims) + t = get_transformation(self.sdata[region], cs) assert isinstance(t, BaseTransformation) - aff = t.to_affine_matrix(input_axes=dims, output_axes=dims) - transformed_centroid = np.squeeze(_affine_matrix_multiplication(aff, centroid), 0) - elif isinstance(regions, (SpatialImage, MultiscaleSpatialImage)): - raise NotImplementedError("labels not supported yet") - else: - raise ValueError("element must be shapes or labels") - min_coordinate = np.array(transformed_centroid) - self.tile_dim_in_units / 2 - max_coordinate = np.array(transformed_centroid) + self.tile_dim_in_units / 2 - - raster = 
self.sdata[self.regions_to_images[regions_name]] - tile = rasterize( - raster, - axes=dims, - min_coordinate=min_coordinate, - max_coordinate=max_coordinate, - target_coordinate_system=self.target_coordinate_system, - target_width=self.tile_dim_in_pixels, + + # get instances from region + inst = table.obs[table.obs[self._region_key] == region][self._instance_key].values + + # subset the regions by instances + subset_region = self.sdata[region].iloc[inst] + # get coordinates of centroids and extent for tiles + tile_coords = _get_tile_coords(subset_region, t, dims, tile_scale, tile_dim_in_units) + tile_coords_df.append(tile_coords) + + # get shapes + shapes_l.append(self.sdata[region]) + + # get index dictionary, with `instance_id`, `cs`, `region`, and `image` + df = pd.DataFrame({self.INSTANCE_KEY: inst}) + df[self.CS_KEY] = cs + df[self.REGION_KEY] = region + df[self.IMAGE_KEY] = image + index_df.append(df) + + # concatenate and assign to self + self.dataset_index = pd.concat(index_df).reset_index(drop=True) + self.tiles_coords = pd.concat(tile_coords_df).reset_index(drop=True) + # get table filtered by regions + self.filtered_table = table.obs[table.obs[self._region_key].isin(self.regions)] + + assert len(self.tiles_coords) == len(self.dataset_index) + dims_ = set(chain(*dims_l)) + assert np.all([i in self.tiles_coords for i in dims_]) + self.dims = list(dims_) + + def _get_return( + self, + return_annot: str | list[str] | None, + ) -> Callable[[int, Any], tuple[Any, Any] | SpatialData]: + """Get function to return values from the table of the dataset.""" + if return_annot is not None: + # table is always returned as array shape (1, len(return_annot)) + # where return_table can be a single column or a list of columns + return_annot = [return_annot] if isinstance(return_annot, str) else return_annot + # return tuple of (tile, table) + if np.all([i in self.dataset_table.obs for i in return_annot]): + return lambda x, tile: (tile, self.dataset_table.obs[return_annot].iloc[x].values.reshape(1, -1)) + if np.all([i in self.dataset_table.var_names for i in return_annot]): + if issparse(self.dataset_table.X): + return lambda x, tile: (tile, self.dataset_table[x, return_annot].X.A) + return lambda x, tile: (tile, self.dataset_table[x, return_annot].X) + raise ValueError( + f"`return_annot` must be a column name in the table or a variable name in the table. " + f"Got {return_annot}." 
+ ) + # return spatialdata consisting of the image tile and the associated table + return lambda x, tile: SpatialData( + images={self.dataset_index.iloc[x][self.IMAGE_KEY]: tile}, + table=self.dataset_table[x], ) - tile_regions = regions.iloc[region_index : region_index + 1] - # TODO: as explained in the TODO in the __init__(), we want to let the - # user also use the bounding box query instaed of the rasterization - # the return function of this function would change, so we need to - # decide if instead having an extra Tile dataset class - # from spatialdata._core._spatial_query import BoundingBoxRequest - # request = BoundingBoxRequest( - # target_coordinate_system=self.target_coordinate_system, - # axes=dims, - # min_coordinate=min_coordinate, - # max_coordinate=max_coordinate, - # ) - # sdata_item = self.sdata.query.bounding_box(**request.to_dict()) - table = self.sdata.table - filter_table = False - if table is not None: - region = table.uns["spatialdata_attrs"]["region"] - region_key = table.uns["spatialdata_attrs"]["region_key"] - instance_key = table.uns["spatialdata_attrs"]["instance_key"] - if isinstance(region, str): - if regions_name == region: - filter_table = True - elif isinstance(region, list): - if regions_name in region: - filter_table = True - else: - raise ValueError("region must be a string or a list of strings") - # TODO: maybe slow, we should check if there is a better way to do this - if filter_table: - instance = self.sdata[regions_name].iloc[region_index].name - row = table[(table.obs[region_key] == regions_name) & (table.obs[instance_key] == instance)].copy() - tile_table = row - else: - tile_table = None - tile_sdata = SpatialData( - images={self.regions_to_images[regions_name]: tile}, shapes={regions_name: tile_regions}, table=tile_table + + def __len__(self) -> int: + return len(self.dataset_index) + + def __getitem__(self, idx: int) -> Any | SpatialData: + """Get item from the dataset.""" + # get the row from the index + row = self.dataset_index.iloc[idx] + # get the tile coordinates + t_coords = self.tiles_coords.iloc[idx] + + image = self.sdata[row["image"]] + tile = self._crop_image( + image, + axes=self.dims, + min_coordinate=t_coords[[f"min{i}" for i in self.dims]].values, + max_coordinate=t_coords[[f"max{i}" for i in self.dims]].values, + target_coordinate_system=row["cs"], ) + if self.transform is not None: - return self.transform(tile_sdata) - return tile_sdata + out = self._return(idx, tile) + return self.transform(out) + return self._return(idx, tile) + + @property + def regions(self) -> list[str]: + """List of regions in the dataset.""" + return self._regions + + @regions.setter + def regions(self, regions: list[str]) -> None: # D102 + self._regions = regions + + @property + def sdata(self) -> SpatialData: + """The original SpatialData object.""" + return self._sdata + + @sdata.setter + def sdata(self, sdata: SpatialData) -> None: # D102 + self._sdata = sdata + + @property + def coordinate_systems(self) -> list[str]: + """List of coordinate systems in the dataset.""" + return self._coordinate_systems + + @coordinate_systems.setter + def coordinate_systems(self, coordinate_systems: list[str]) -> None: # D102 + self._coordinate_systems = coordinate_systems + + @property + def tiles_coords(self) -> pd.DataFrame: + """DataFrame with the index of tiles. + + It contains axis coordinates of the centroids, and extent of the tiles. + For example, for a 2D image, it contains the following columns: + + - `x`: the x coordinate of the centroid. 
+ - `y`: the y coordinate of the centroid. + - `extent`: the extent of the tile. + - `minx`: the minimum x coordinate of the tile. + - `miny`: the minimum y coordinate of the tile. + - `maxx`: the maximum x coordinate of the tile. + - `maxy`: the maximum y coordinate of the tile. + """ + return self._tiles_coords + + @tiles_coords.setter + def tiles_coords(self, tiles: pd.DataFrame) -> None: + self._tiles_coords = tiles + + @property + def dataset_index(self) -> pd.DataFrame: + """DataFrame with the metadata of the tiles. + + It contains the following columns: + + - `instance`: the name of the instance in the region. + - `cs`: the coordinate system of the region-image pair. + - `region`: the name of the region. + - `image`: the name of the image. + """ + return self._dataset_index + + @dataset_index.setter + def dataset_index(self, dataset_index: pd.DataFrame) -> None: + self._dataset_index = dataset_index + + @property + def dataset_table(self) -> AnnData: + """AnnData table filtered by the `region` and `cs` present in the dataset.""" + return self._dataset_table + + @dataset_table.setter + def dataset_table(self, dataset_table: AnnData) -> None: + self._dataset_table = dataset_table + + @property + def dims(self) -> list[str]: + """Dimensions of the dataset.""" + return self._dims + + @dims.setter + def dims(self, dims: list[str]) -> None: + self._dims = dims + + +def _get_tile_coords( + elem: GeoDataFrame, + transformation: BaseTransformation, + dims: tuple[str, ...], + tile_scale: float | None = None, + tile_dim_in_units: float | None = None, +) -> pd.DataFrame: + """Get the (transformed) centroid of the region and the extent.""" + # get centroids and transform them + centroids = elem.centroid.get_coordinates().values + aff = transformation.to_affine_matrix(input_axes=dims, output_axes=dims) + centroids = _affine_matrix_multiplication(aff, centroids) + + # get extent, first by checking shape defaults, then by using the `tile_dim_in_units` + if tile_dim_in_units is None: + if elem.iloc[0, 0].geom_type == "Point": + extent = elem[ShapesModel.RADIUS_KEY].values * tile_scale + elif elem.iloc[0, 0].geom_type in ["Polygon", "MultiPolygon"]: + extent = elem[ShapesModel.GEOMETRY_KEY].length * tile_scale + else: + raise ValueError("Only point and polygon shapes are supported.") + if tile_dim_in_units is not None: + if isinstance(tile_dim_in_units, (float, int)): + extent = np.repeat(tile_dim_in_units, len(centroids)) + else: + raise TypeError( + f"`tile_dim_in_units` must be a `float`, `int`, `list`, `tuple` or `np.ndarray`, " + f"not {type(tile_dim_in_units)}." + ) + if len(extent) != len(centroids): + raise ValueError( + f"the number of elements in the region ({len(extent)}) does not match" + f" the number of instances ({len(centroids)})." + ) + + # transform extent + aff = transformation.to_affine_matrix(input_axes=tuple(dims[0]), output_axes=tuple(dims[0])) + extent = _affine_matrix_multiplication(aff, np.array(extent)[:, np.newaxis]) + + # get min and max coordinates + min_coordinates = np.array(centroids) - extent / 2 + max_coordinates = np.array(centroids) + extent / 2 + + # return a dataframe with columns e.g. 
["x", "y", "extent", "minx", "miny", "maxx", "maxy"] + return pd.DataFrame( + np.hstack([centroids, extent, min_coordinates, max_coordinates]), + columns=list(dims) + ["extent"] + ["min" + dim for dim in dims] + ["max" + dim for dim in dims], + ) diff --git a/src/spatialdata/datasets.py b/src/spatialdata/datasets.py index 0d811ad3..3b207e7b 100644 --- a/src/spatialdata/datasets.py +++ b/src/spatialdata/datasets.py @@ -13,8 +13,8 @@ from skimage.segmentation import slic from spatial_image import SpatialImage -from spatialdata import SpatialData from spatialdata._core.operations.aggregate import aggregate +from spatialdata._core.spatialdata import SpatialData from spatialdata._logging import logger from spatialdata._types import ArrayLike from spatialdata.models import ( @@ -153,7 +153,7 @@ def blobs( circles = self._circles_blobs(self.transformations, self.length, self.n_shapes) polygons = self._polygons_blobs(self.transformations, self.length, self.n_shapes) multipolygons = self._polygons_blobs(self.transformations, self.length, self.n_shapes, multipolygons=True) - adata = aggregate(values=image, by=labels).table + adata = aggregate(values=image, by=labels).tables["table"] adata.obs["region"] = pd.Categorical(["blobs_labels"] * len(adata)) adata.obs["instance_id"] = adata.obs_names.astype(int) del adata.uns[TableModel.ATTRS_KEY] @@ -164,7 +164,7 @@ def blobs( labels={"blobs_labels": labels, "blobs_multiscale_labels": multiscale_labels}, points={"blobs_points": points}, shapes={"blobs_circles": circles, "blobs_polygons": polygons, "blobs_multipolygons": multipolygons}, - table=table, + tables=table, ) def _image_blobs( diff --git a/src/spatialdata/models/__init__.py b/src/spatialdata/models/__init__.py index 9a6cf64b..df370e4a 100644 --- a/src/spatialdata/models/__init__.py +++ b/src/spatialdata/models/__init__.py @@ -21,7 +21,9 @@ PointsModel, ShapesModel, TableModel, + check_target_region_column_symmetry, get_model, + get_table_keys, ) __all__ = [ @@ -44,4 +46,6 @@ "get_axes_names", "points_geopandas_to_dask_dataframe", "points_dask_dataframe_to_geopandas", + "check_target_region_column_symmetry", + "get_table_keys", ] diff --git a/src/spatialdata/models/_utils.py b/src/spatialdata/models/_utils.py index cf139d5e..d78dc8b9 100644 --- a/src/spatialdata/models/_utils.py +++ b/src/spatialdata/models/_utils.py @@ -16,7 +16,6 @@ SpatialElement = Union[SpatialImage, MultiscaleSpatialImage, GeoDataFrame, DaskDataFrame] TRANSFORM_KEY = "transform" DEFAULT_COORDINATE_SYSTEM = "global" -# ValidAxis_t = Literal["c", "x", "y", "z"] ValidAxis_t = str MappingToCoordinateSystem_t = dict[str, BaseTransformation] C = "c" diff --git a/src/spatialdata/models/models.py b/src/spatialdata/models/models.py index f7155a3b..f36d91e9 100644 --- a/src/spatialdata/models/models.py +++ b/src/spatialdata/models/models.py @@ -18,7 +18,7 @@ from multiscale_spatial_image import to_multiscale from multiscale_spatial_image.multiscale_spatial_image import MultiscaleSpatialImage from multiscale_spatial_image.to_multiscale.to_multiscale import Methods -from pandas.api.types import is_categorical_dtype +from pandas import CategoricalDtype from shapely._geometry import GeometryType from shapely.geometry import MultiPolygon, Point, Polygon from shapely.geometry.collection import GeometryCollection @@ -465,12 +465,13 @@ def validate(cls, data: DaskDataFrame) -> None: """ for ax in [X, Y, Z]: if ax in data.columns: - assert data[ax].dtype in [np.float32, np.float64, np.int64] + # TODO: check why this can return int32 on windows. 
+            assert data[ax].dtype in [np.int32, np.float32, np.float64, np.int64]
         if cls.TRANSFORM_KEY not in data.attrs:
             raise ValueError(f":attr:`dask.dataframe.core.DataFrame.attrs` does not contain `{cls.TRANSFORM_KEY}`.")
         if cls.ATTRS_KEY in data.attrs and "feature_key" in data.attrs[cls.ATTRS_KEY]:
             feature_key = data.attrs[cls.ATTRS_KEY][cls.FEATURE_KEY]
-            if not is_categorical_dtype(data[feature_key]):
+            if not isinstance(data[feature_key].dtype, CategoricalDtype):
                 logger.info(f"Feature key `{feature_key}` could be of type `pd.Categorical`. Consider casting it.")
 
     @singledispatchmethod
@@ -624,7 +625,7 @@ def _add_metadata_and_validate(
             # Here we are explicitly importing the categories
             # but it is a convenient way to ensure that the categories are known.
             # It also just changes the state of the series, so it is not a big deal.
-            if is_categorical_dtype(data[c]) and not data[c].cat.known:
+            if isinstance(data[c].dtype, CategoricalDtype) and not data[c].cat.known:
                 try:
                     data[c] = data[c].cat.set_categories(data[c].head(1).cat.categories)
                 except ValueError:
@@ -642,24 +643,121 @@ class TableModel:
     REGION_KEY_KEY = "region_key"
     INSTANCE_KEY = "instance_key"
 
-    def validate(
-        self,
-        data: AnnData,
-    ) -> AnnData:
+    def _validate_set_region_key(self, data: AnnData, region_key: str | None = None) -> None:
         """
-        Validate the data.
+        Validate the region_key in table.uns, or set the provided column as the new region_key.
 
         Parameters
         ----------
         data
-            The data to validate.
+            The AnnData table.
+        region_key
+            The region key to be validated and set in table.uns.
+
+        Raises
+        ------
+        ValueError
+            If no region_key is found in table.uns and no region_key is provided as an argument.
+        ValueError
+            If the region_key specified in table.uns is not present as a column in table.obs.
+        ValueError
+            If the region_key provided as an argument is not present as a column in table.obs.
+        """
+        attrs = data.uns.get(self.ATTRS_KEY)
+        if attrs is None:
+            data.uns[self.ATTRS_KEY] = attrs = {}
+        table_region_key = attrs.get(self.REGION_KEY_KEY)
+        if not region_key:
+            if not table_region_key:
+                raise ValueError(
+                    "No region_key in table.uns and no region_key provided as argument. Please specify 'region_key'."
+                )
+            if data.obs.get(attrs[TableModel.REGION_KEY_KEY]) is None:
+                raise ValueError(
+                    f"Specified region_key in table.uns '{table_region_key}' is not "
+                    f"present as column in table.obs. Please specify region_key."
+                )
+        else:
+            if region_key not in data.obs:
+                raise ValueError(f"'{region_key}' column not present in table.obs")
+            attrs[self.REGION_KEY_KEY] = region_key
+
+    def _validate_set_instance_key(self, data: AnnData, instance_key: str | None = None) -> None:
+        """
+        Validate the instance_key in table.uns, or set the provided column as the new instance_key.
+
+        If no instance_key is provided as argument, the presence of instance_key in table.uns is checked and validated.
+        If instance_key is provided, its presence in table.obs will be validated and, if present, it will be set as
+        the new instance_key in table.uns.
+
+        Parameters
+        ----------
+        data
+            The AnnData table.
+        instance_key
+            The instance_key to be validated and set in table.uns.
+
+        Raises
+        ------
+        ValueError
+            If no instance_key is provided as argument and no instance_key is found in the `uns` attribute of table.
+        ValueError
+            If no instance_key is provided and the instance_key in table.uns does not match any column in table.obs.
+        ValueError
+            If the provided instance_key is not present as a column in table.obs.
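+
+        Example
+        -------
+        A minimal sketch (the table `adata` and the obs column 'cell_id' are hypothetical):
+
+        ```python
+        model = TableModel()
+        # set 'cell_id' as the instance_key column and record it in adata.uns
+        model._validate_set_instance_key(adata, instance_key="cell_id")
+        # a later call without arguments validates the key stored in adata.uns
+        model._validate_set_instance_key(adata)
+        ```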
+ """ + attrs = data.uns.get(self.ATTRS_KEY) + if attrs is None: + data.uns[self.ATTRS_KEY] = {} + + if not instance_key: + if not attrs.get(TableModel.INSTANCE_KEY): + raise ValueError( + "No instance_key in table.uns and no instance_key provided as argument. Please " + "specify instance_key." + ) + if data.obs.get(attrs[self.INSTANCE_KEY]) is None: + raise ValueError( + f"Specified instance_key in table.uns '{attrs.get(self.INSTANCE_KEY)}' is not present" + f" as column in table.obs. Please specify instance_key." + ) + if instance_key: + if instance_key in data.obs: + attrs[self.INSTANCE_KEY] = instance_key + else: + raise ValueError(f"Instance key column '{instance_key}' not found in table.obs.") + + def _validate_table_annotation_metadata(self, data: AnnData) -> None: + """ + Validate annotation metadata. + + Parameters + ---------- + data + The AnnData object containing the table annotation data. + + Raises + ------ + ValueError + If any of the required metadata keys are not found in the `adata.uns` dictionary or the `adata.obs` + dataframe. + + - If "region" is not found in `adata.uns['ATTRS_KEY']`. + - If "region_key" is not found in `adata.uns['ATTRS_KEY']`. + - If "instance_key" is not found in `adata.uns['ATTRS_KEY']`. + - If `attr[self.REGION_KEY_KEY]` is not found in `adata.obs`, with attr = adata.uns['ATTRS_KEY'] + - If `attr[self.INSTANCE_KEY]` is not found in `adata.obs`. + - If the regions in `adata.uns['ATTRS_KEY']['self.REGION_KEY']` and the unique values of + `attr[self.REGION_KEY_KEY]` do not match. + + Notes + ----- + This does not check whether the annotation target of the table is present in a given SpatialData object. Rather + it is an internal validation of the annotation metadata of the table. - Returns - ------- - The validated data. """ - if self.ATTRS_KEY not in data.uns: - raise ValueError(f"`{self.ATTRS_KEY}` not found in `adata.uns`.") attr = data.uns[self.ATTRS_KEY] if "region" not in attr: @@ -678,6 +776,27 @@ def validate( if len(set(expected_regions).symmetric_difference(set(found_regions))) > 0: raise ValueError(f"Regions in the AnnData object and `{attr[self.REGION_KEY_KEY]}` do not match.") + def validate( + self, + data: AnnData, + ) -> AnnData: + """ + Validate the data. + + Parameters + ---------- + data + The data to validate. + + Returns + ------- + The validated data. + """ + if self.ATTRS_KEY not in data.uns: + return data + + self._validate_table_annotation_metadata(data) + return data @classmethod @@ -704,15 +823,17 @@ def parse( Returns ------- - :class:`anndata.AnnData`. + The parsed data. """ # either all live in adata.uns or all be passed in as argument n_args = sum([region is not None, region_key is not None, instance_key is not None]) + if n_args == 0: + return adata if n_args > 0: if cls.ATTRS_KEY in adata.uns: raise ValueError( - f"Either pass `{cls.REGION_KEY}`, `{cls.REGION_KEY_KEY}` and `{cls.INSTANCE_KEY}`" - f"as arguments or have them in `adata.uns[{cls.ATTRS_KEY!r}]`." + f"`{cls.REGION_KEY}`, `{cls.REGION_KEY_KEY}` and / or `{cls.INSTANCE_KEY}` is/has been passed as" + f"as argument(s). However, `adata.uns[{cls.ATTRS_KEY!r}]` has already been set." 
         elif cls.ATTRS_KEY in adata.uns:
             attr = adata.uns[cls.ATTRS_KEY]
@@ -729,7 +850,7 @@ def parse(
             region_: list[str] = region if isinstance(region, list) else [region]
             if not adata.obs[region_key].isin(region_).all():
                 raise ValueError(f"`adata.obs[{region_key}]` values do not match with `{cls.REGION_KEY}` values.")
-            if not is_categorical_dtype(adata.obs[region_key]):
+            if not isinstance(adata.obs[region_key].dtype, CategoricalDtype):
                 warnings.warn(
                     f"Converting `{cls.REGION_KEY_KEY}: {region_key}` to categorical dtype.", UserWarning, stacklevel=2
                 )
@@ -792,3 +913,73 @@ def _validate_and_return(
     if isinstance(e, AnnData):
         return _validate_and_return(TableModel, e)
     raise TypeError(f"Unsupported type {type(e)}")
+
+
+def get_table_keys(table: AnnData) -> tuple[str | list[str], str, str]:
+    """
+    Get the table keys giving information about what spatial element is annotated.
+
+    The first returned element indicates which spatial element(s) are annotated by the table. The second is the
+    name of the table.obs column that stores, for each row, which spatial element that row annotates. The third
+    is the name of the table.obs column that stores the instance id of each row.
+
+    Parameters
+    ----------
+    table
+        AnnData table for which to retrieve the spatialdata_attrs keys.
+
+    Returns
+    -------
+    The keys in table.uns['spatialdata_attrs'].
+    """
+    if table.uns.get(TableModel.ATTRS_KEY):
+        attrs = table.uns[TableModel.ATTRS_KEY]
+        return attrs[TableModel.REGION_KEY], attrs[TableModel.REGION_KEY_KEY], attrs[TableModel.INSTANCE_KEY]
+
+    raise ValueError(
+        "No spatialdata_attrs key found in table.uns, so no table keys can be retrieved. Please parse the table."
+    )
+
+
+def check_target_region_column_symmetry(table: AnnData, region_key: str, target: str | pd.Series) -> None:
+    """
+    Check region and region_key column symmetry.
+
+    This checks whether the specified targets are also present in the region_key column in obs and raises an
+    error if this is not the case.
+
+    Parameters
+    ----------
+    table
+        The table annotating specific SpatialElements.
+    region_key
+        The column in obs containing, for each row, which SpatialElement is annotated by that row.
+    target
+        Name(s) of the target SpatialElement(s).
+
+    Raises
+    ------
+    ValueError
+        If there is a mismatch between the specified target regions and the regions in the region_key column of
+        table.obs.
+
+    Example
+    -------
+    Assume we have a table whose region column in obs is given by `region_key` (here called 'region'), and we want
+    to check whether it contains the annotation targets specified in the `target` variable, e.g.
+    `pd.Series(['region1', 'region2'])`:
+
+    ```python
+    check_target_region_column_symmetry(table, region_key=region_key, target=target)
+    ```
+
+    This returns None if both specified targets are present in the region_key column in obs, in which case the
+    annotation targets can be safely set. Otherwise, a ValueError is raised listing the elements that are not
+    shared between the region_key column in obs and the specified targets.
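+
+    A concrete (hypothetical) failure case: if `table.obs["region"]` only contains 'region1' while both targets
+    are requested, the call raises:
+
+    ```python
+    check_target_region_column_symmetry(table, region_key="region", target=pd.Series(["region1", "region2"]))
+    # ValueError: Mismatch(es) found between regions in region column in obs and target element: region2
+    ```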
+ """ + found_regions = set(table.obs[region_key].unique().tolist()) + target_element_set = [target] if isinstance(target, str) else target + symmetric_difference = found_regions.symmetric_difference(target_element_set) + if symmetric_difference: + raise ValueError( + f"Mismatch(es) found between regions in region column in obs and target element: " + f"{', '.join(diff for diff in symmetric_difference)}" + ) diff --git a/src/spatialdata/transformations/operations.py b/src/spatialdata/transformations/operations.py index 165551c4..2e4a9e1f 100644 --- a/src/spatialdata/transformations/operations.py +++ b/src/spatialdata/transformations/operations.py @@ -15,7 +15,7 @@ ) if TYPE_CHECKING: - from spatialdata import SpatialData + from spatialdata._core.spatialdata import SpatialData from spatialdata.models import SpatialElement from spatialdata.transformations import Affine, BaseTransformation @@ -68,8 +68,8 @@ def set_transformation( assert to_coordinate_system is None _set_transformations(element, transformation) else: - if not write_to_sdata.contains_element(element, raise_exception=True): - raise RuntimeError("contains_element() failed without raising an exception.") + if write_to_sdata.locate_element(element) is None: + raise RuntimeError("The element is not found in the SpatialData object.") if not write_to_sdata.is_backed(): raise ValueError( "The SpatialData object is not backed. You can either set a transformation to an element " @@ -164,8 +164,8 @@ def remove_transformation( assert to_coordinate_system is None _set_transformations(element, {}) else: - if not write_to_sdata.contains_element(element, raise_exception=True): - raise RuntimeError("contains_element() failed without raising an exception.") + if write_to_sdata.locate_element(element) is None: + raise RuntimeError("The element is not found in the SpatialData object.") if not write_to_sdata.is_backed(): raise ValueError( "The SpatialData object is not backed. 
You can either remove a transformation from an " @@ -178,7 +178,7 @@ def remove_transformation( def _build_transformations_graph(sdata: SpatialData) -> nx.Graph: g = nx.DiGraph() - gen = sdata._gen_elements_values() + gen = sdata._gen_spatial_element_values() for cs in sdata.coordinate_systems: g.add_node(cs) for e in gen: diff --git a/tests/conftest.py b/tests/conftest.py index 490cd929..3fcfe005 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -21,7 +21,7 @@ from numpy.random import default_rng from shapely.geometry import MultiPolygon, Point, Polygon from spatial_image import SpatialImage -from spatialdata import SpatialData +from spatialdata._core.spatialdata import SpatialData from spatialdata.models import ( Image2DModel, Image3DModel, @@ -66,12 +66,12 @@ def points() -> SpatialData: @pytest.fixture() def table_single_annotation() -> SpatialData: - return SpatialData(table=_get_table(region="sample1")) + return SpatialData(tables=_get_table(region="labels2d")) @pytest.fixture() def table_multiple_annotations() -> SpatialData: - return SpatialData(table=_get_table(region=["sample1", "sample2"])) + return SpatialData(table=_get_table(region=["labels2d", "poly"])) @pytest.fixture() @@ -93,7 +93,7 @@ def full_sdata() -> SpatialData: labels=_get_labels(), shapes=_get_shapes(), points=_get_points(), - table=_get_table(region="sample1"), + tables=_get_table(region="labels2d"), ) @@ -128,7 +128,7 @@ def sdata(request) -> SpatialData: labels=_get_labels(), shapes=_get_shapes(), points=_get_points(), - table=_get_table("sample1"), + tables=_get_table("labels2d"), ) if request.param == "empty": return SpatialData() @@ -141,7 +141,10 @@ def _get_images() -> dict[str, SpatialImage | MultiscaleSpatialImage]: dims_3d = ("z", "y", "x", "c") out["image2d"] = Image2DModel.parse(RNG.normal(size=(3, 64, 64)), dims=dims_2d, c_coords=["r", "g", "b"]) out["image2d_multiscale"] = Image2DModel.parse( - RNG.normal(size=(3, 64, 64)), scale_factors=[2, 2], dims=dims_2d, c_coords=["r", "g", "b"] + RNG.normal(size=(3, 64, 64)), + scale_factors=[2, 2], + dims=dims_2d, + c_coords=["r", "g", "b"], ) out["image2d_xarray"] = Image2DModel.parse(DataArray(RNG.normal(size=(3, 64, 64)), dims=dims_2d), dims=None) out["image2d_multiscale_xarray"] = Image2DModel.parse( @@ -277,11 +280,13 @@ def _get_points() -> dict[str, DaskDataFrame]: def _get_table( - region: str | list[str] = "sample1", - region_key: str = "region", - instance_key: str = "instance_id", + region: None | str | list[str] = "sample1", + region_key: None | str = "region", + instance_key: None | str = "instance_id", ) -> AnnData: adata = AnnData(RNG.normal(size=(100, 10)), obs=pd.DataFrame(RNG.normal(size=(100, 3)), columns=["a", "b", "c"])) + if not all(var for var in (region, region_key, instance_key)): + return TableModel.parse(adata=adata) adata.obs[instance_key] = np.arange(adata.n_obs) if isinstance(region, str): adata.obs[region_key] = region diff --git a/tests/core/operations/test_aggregations.py b/tests/core/operations/test_aggregations.py index 51ffa98e..36609464 100644 --- a/tests/core/operations/test_aggregations.py +++ b/tests/core/operations/test_aggregations.py @@ -9,8 +9,9 @@ from anndata.tests.helpers import assert_equal from geopandas import GeoDataFrame from numpy.random import default_rng -from spatialdata import SpatialData, aggregate +from spatialdata import aggregate from spatialdata._core.query._utils import circles_to_polygons +from spatialdata._core.spatialdata import SpatialData from spatialdata._utils import 
_deepcopy_geodataframe from spatialdata.models import Image2DModel, Labels2DModel, PointsModel, TableModel from spatialdata.transformations import Affine, Identity, set_transformation diff --git a/tests/core/operations/test_spatialdata_operations.py b/tests/core/operations/test_spatialdata_operations.py index c551461c..3b2b3e6a 100644 --- a/tests/core/operations/test_spatialdata_operations.py +++ b/tests/core/operations/test_spatialdata_operations.py @@ -1,5 +1,4 @@ -import tempfile -from pathlib import Path +from __future__ import annotations import numpy as np import pytest @@ -9,8 +8,8 @@ from geopandas import GeoDataFrame from multiscale_spatial_image import MultiscaleSpatialImage from spatial_image import SpatialImage -from spatialdata import SpatialData from spatialdata._core.concatenate import _concatenate_tables, concatenate +from spatialdata._core.spatialdata import SpatialData from spatialdata.datasets import blobs from spatialdata.models import ( Image2DModel, @@ -25,34 +24,88 @@ from tests.conftest import _get_table -def test_element_names_unique(): +def test_element_names_unique() -> None: shapes = ShapesModel.parse(np.array([[0, 0]]), geometry=0, radius=1) points = PointsModel.parse(np.array([[0, 0]])) labels = Labels2DModel.parse(np.array([[0, 0], [0, 0]]), dims=["y", "x"]) image = Image2DModel.parse(np.array([[[0, 0], [0, 0]]]), dims=["c", "y", "x"]) - with pytest.raises(ValueError): + with pytest.raises(KeyError): SpatialData(images={"image": image}, points={"image": points}) - with pytest.raises(ValueError): + with pytest.raises(KeyError): SpatialData(images={"image": image}, shapes={"image": shapes}) - with pytest.raises(ValueError): + with pytest.raises(KeyError): SpatialData(images={"image": image}, labels={"image": labels}) sdata = SpatialData( images={"image": image}, points={"points": points}, shapes={"shapes": shapes}, labels={"labels": labels} ) - with pytest.raises(ValueError): - sdata.add_image(name="points", image=image) - with pytest.raises(ValueError): - sdata.add_points(name="image", points=points) - with pytest.raises(ValueError): - sdata.add_shapes(name="image", shapes=shapes) - with pytest.raises(ValueError): - sdata.add_labels(name="image", labels=labels) + # add elements with the same name + # of element of same type + with pytest.warns(UserWarning): + sdata.images["image"] = image + with pytest.warns(UserWarning): + sdata.points["points"] = points + with pytest.warns(UserWarning): + sdata.shapes["shapes"] = shapes + with pytest.warns(UserWarning): + sdata.labels["labels"] = labels + + # add elements with the same name + # of element of different type + with pytest.raises(KeyError): + sdata.images["points"] = image + with pytest.raises(KeyError): + sdata.images["shapes"] = image + with pytest.raises(KeyError): + sdata.labels["points"] = labels + with pytest.raises(KeyError): + sdata.points["shapes"] = points + with pytest.raises(KeyError): + sdata.shapes["labels"] = shapes + + assert sdata["image"].shape == image.shape + assert sdata["labels"].shape == labels.shape + assert len(sdata["points"]) == len(points) + assert sdata["shapes"].shape == shapes.shape + + # add elements with the same name, test only couples of elements + with pytest.raises(KeyError): + sdata["labels"] = image + with pytest.warns(UserWarning): + sdata["points"] = points + # this should not raise warnings because it's a different (new) name + sdata["image2"] = image -def _assert_elements_left_to_right_seem_identical(sdata0: SpatialData, sdata1: SpatialData): + # test replacing 
a complete element attribute (e.g. sdata.images)
+    sdata = SpatialData(
+        images={"image": image}, points={"points": points}, shapes={"shapes": shapes}, labels={"labels": labels}
+    )
+    # test for images
+    sdata.images = {"image2": image}
+    assert set(sdata.images.keys()) == {"image2"}
+    assert "image2" in sdata._shared_keys
+    assert "image" not in sdata._shared_keys
+    # test for labels
+    sdata.labels = {"labels2": labels}
+    assert set(sdata.labels.keys()) == {"labels2"}
+    assert "labels2" in sdata._shared_keys
+    assert "labels" not in sdata._shared_keys
+    # test for points
+    sdata.points = {"points2": points}
+    assert set(sdata.points.keys()) == {"points2"}
+    assert "points2" in sdata._shared_keys
+    assert "points" not in sdata._shared_keys
+    # test for shapes
+    sdata.shapes = {"shapes2": shapes}
+    assert set(sdata.shapes.keys()) == {"shapes2"}
+    assert "shapes2" in sdata._shared_keys
+    assert "shapes" not in sdata._shared_keys
+
+
+def _assert_elements_left_to_right_seem_identical(sdata0: SpatialData, sdata1: SpatialData) -> None:
     for element_type, element_name, element in sdata0._gen_elements():
         elements = sdata1.__getattribute__(element_type)
         assert element_name in elements
@@ -72,11 +125,11 @@ def _assert_elements_left_to_right_seem_identical(sdata0: SpatialData, sdata1: S
             raise TypeError(f"Unsupported type {type(element)}")


-def _assert_tables_seem_identical(table0: AnnData, table1: AnnData):
-    assert table0.shape == table1.shape
+def _assert_tables_seem_identical(table0: AnnData | None, table1: AnnData | None) -> None:
+    assert (table0 is None and table1 is None) or table0.shape == table1.shape


-def _assert_spatialdata_objects_seem_identical(sdata0: SpatialData, sdata1: SpatialData):
+def _assert_spatialdata_objects_seem_identical(sdata0: SpatialData, sdata1: SpatialData) -> None:
     # this is not a full comparison, but it's fine anyway
     assert len(list(sdata0._gen_elements())) == len(list(sdata1._gen_elements()))
     assert set(sdata0.coordinate_systems) == set(sdata1.coordinate_systems)
@@ -85,7 +138,7 @@ def _assert_spatialdata_objects_seem_identical(sdata0: SpatialData, sdata1: Spat
     _assert_tables_seem_identical(sdata0.table, sdata1.table)


-def test_filter_by_coordinate_system(full_sdata):
+def test_filter_by_coordinate_system(full_sdata: SpatialData) -> None:
     sdata = full_sdata.filter_by_coordinate_system(coordinate_system="global", filter_table=False)
     _assert_spatialdata_objects_seem_identical(sdata, full_sdata)
@@ -95,16 +148,16 @@ def test_filter_by_coordinate_system(full_sdata):
     set_transformation(full_sdata.shapes["poly"], Identity(), "my_space1")

     sdata_my_space = full_sdata.filter_by_coordinate_system(coordinate_system="my_space0", filter_table=False)
-    assert len(list(sdata_my_space._gen_elements())) == 2
+    assert len(list(sdata_my_space.gen_elements())) == 3
     _assert_tables_seem_identical(sdata_my_space.table, full_sdata.table)

     sdata_my_space1 = full_sdata.filter_by_coordinate_system(
         coordinate_system=["my_space0", "my_space1", "my_space2"], filter_table=False
     )
-    assert len(list(sdata_my_space1._gen_elements())) == 3
+    assert len(list(sdata_my_space1.gen_elements())) == 4


-def test_filter_by_coordinate_system_also_table(full_sdata):
+def test_filter_by_coordinate_system_also_table(full_sdata: SpatialData) -> None:
     from spatialdata.models import TableModel

     rng = np.random.default_rng(seed=0)
@@ -128,7 +181,7 @@ def test_filter_by_coordinate_system_also_table(full_sdata):
     assert len(filtered_sdata2.table) == len(full_sdata.table)


-def test_rename_coordinate_systems(full_sdata):
+def
test_rename_coordinate_systems(full_sdata: SpatialData) -> None: # all the elements point to global, add new coordinate systems set_transformation( element=full_sdata.shapes["circles"], transformation=Identity(), to_coordinate_system="my_space0" @@ -181,7 +234,7 @@ def test_rename_coordinate_systems(full_sdata): assert elements_in_global_before == elements_in_global_after -def test_concatenate_tables(): +def test_concatenate_tables() -> None: """ The concatenation uses AnnData.concatenate(), here we test the concatenation result on region, region_key, instance_key @@ -226,7 +279,7 @@ def test_concatenate_tables(): ) -def test_concatenate_sdatas(full_sdata): +def test_concatenate_sdatas(full_sdata: SpatialData) -> None: with pytest.raises(KeyError): concatenate([full_sdata, SpatialData(images={"image2d": full_sdata.images["image2d"]})]) with pytest.raises(KeyError): @@ -241,7 +294,7 @@ def test_concatenate_sdatas(full_sdata): set_transformation(full_sdata.shapes["circles"], Identity(), "my_space0") set_transformation(full_sdata.shapes["poly"], Identity(), "my_space1") filtered = full_sdata.filter_by_coordinate_system(coordinate_system=["my_space0", "my_space1"], filter_table=False) - assert len(list(filtered._gen_elements())) == 2 + assert len(list(filtered.gen_elements())) == 3 filtered0 = filtered.filter_by_coordinate_system(coordinate_system="my_space0", filter_table=False) filtered1 = filtered.filter_by_coordinate_system(coordinate_system="my_space1", filter_table=False) # this is needed cause we can't handle regions with same name. @@ -252,23 +305,22 @@ def test_concatenate_sdatas(full_sdata): filtered1.table = table_new filtered1.table.uns[TableModel.ATTRS_KEY][TableModel.REGION_KEY] = new_region filtered1.table.obs[filtered1.table.uns[TableModel.ATTRS_KEY][TableModel.REGION_KEY_KEY]] = new_region - concatenated = concatenate([filtered0, filtered1]) - assert len(list(concatenated._gen_elements())) == 2 + concatenated = concatenate([filtered0, filtered1], concatenate_tables=True) + assert len(list(concatenated.gen_elements())) == 3 -def test_locate_spatial_element(full_sdata): - assert full_sdata._locate_spatial_element(full_sdata.images["image2d"]) == ("image2d", "images") +def test_locate_spatial_element(full_sdata: SpatialData) -> None: + assert full_sdata.locate_element(full_sdata.images["image2d"])[0] == "images/image2d" im = full_sdata.images["image2d"] del full_sdata.images["image2d"] - with pytest.raises(ValueError, match="Element not found in the SpatialData object."): - full_sdata._locate_spatial_element(im) + assert full_sdata.locate_element(im) is None full_sdata.images["image2d"] = im full_sdata.images["image2d_again"] = im - with pytest.raises(ValueError): - full_sdata._locate_spatial_element(im) + paths = full_sdata.locate_element(im) + assert len(paths) == 2 -def test_get_item(points): +def test_get_item(points: SpatialData) -> None: assert id(points["points_0"]) == id(points.points["points_0"]) # removed this test after this change: https://github.com/scverse/spatialdata/pull/145#discussion_r1133122720 @@ -282,20 +334,14 @@ def test_get_item(points): _ = points["not_present"] -def test_set_item(full_sdata): +def test_set_item(full_sdata: SpatialData) -> None: for name in ["image2d", "labels2d", "points_0", "circles", "poly"]: full_sdata[name + "_again"] = full_sdata[name] - with pytest.raises(KeyError): + with pytest.warns(UserWarning): full_sdata[name] = full_sdata[name] - with tempfile.TemporaryDirectory() as tmpdir: - full_sdata.write(Path(tmpdir) / "test.zarr") - 
for name in ["image2d", "labels2d", "points_0"]: - # trying to overwrite the file used for backing (only for images, labels and points) - with pytest.raises(ValueError): - full_sdata[name] = full_sdata[name] -def test_no_shared_transformations(): +def test_no_shared_transformations() -> None: """Test transformation dictionary copy for transformations not to be shared.""" sdata = blobs() element_name = "blobs_image" @@ -303,15 +349,41 @@ def test_no_shared_transformations(): set_transformation(sdata.images[element_name], Identity(), to_coordinate_system=test_space) gen = sdata._gen_elements() - for _, name, obj in gen: - if name != element_name: - assert test_space not in get_transformation(obj, get_all=True) - else: - assert test_space in get_transformation(obj, get_all=True) + for element_type, name, obj in gen: + if element_type != "tables": + if name != element_name: + assert test_space not in get_transformation(obj, get_all=True) + else: + assert test_space in get_transformation(obj, get_all=True) -def test_init_from_elements(full_sdata): +def test_init_from_elements(full_sdata: SpatialData) -> None: all_elements = {name: el for _, name, el in full_sdata._gen_elements()} sdata = SpatialData.init_from_elements(all_elements, table=full_sdata.table) for element_type in ["images", "labels", "points", "shapes"]: assert set(getattr(sdata, element_type).keys()) == set(getattr(full_sdata, element_type).keys()) + + +def test_subset(full_sdata: SpatialData) -> None: + element_names = ["image2d", "points_0", "circles", "poly"] + subset0 = full_sdata.subset(element_names) + unique_names = set() + for _, k, _ in subset0.gen_spatial_elements(): + unique_names.add(k) + assert "image3d_xarray" in full_sdata.images + assert unique_names == set(element_names) + assert subset0.table is None + + adata = AnnData( + shape=(10, 0), + obs={"region": ["circles"] * 5 + ["poly"] * 5, "instance_id": [0, 1, 2, 3, 4, "a", "b", "c", "d", "e"]}, + ) + del full_sdata.table + sdata_table = TableModel.parse(adata, region=["circles", "poly"], region_key="region", instance_key="instance_id") + full_sdata.table = sdata_table + full_sdata.tables["second_table"] = sdata_table + subset1 = full_sdata.subset(["poly", "second_table"]) + assert subset1.table is not None + assert len(subset1.table) == 5 + assert subset1.table.obs["region"].unique().tolist() == ["poly"] + assert len(subset1["second_table"]) == 10 diff --git a/tests/core/operations/test_transform.py b/tests/core/operations/test_transform.py index f28b345f..5c11083d 100644 --- a/tests/core/operations/test_transform.py +++ b/tests/core/operations/test_transform.py @@ -8,7 +8,8 @@ from geopandas.testing import geom_almost_equals from multiscale_spatial_image import MultiscaleSpatialImage from spatial_image import SpatialImage -from spatialdata import SpatialData, transform +from spatialdata import transform +from spatialdata._core.spatialdata import SpatialData from spatialdata._utils import unpad_raster from spatialdata.models import Image2DModel, PointsModel, ShapesModel, get_axes_names from spatialdata.transformations.operations import ( @@ -462,7 +463,7 @@ def test_transform_elements_and_entire_spatial_data_object(sdata: SpatialData): # TODO: we are just applying the transformation, # we are not checking it is correct. 
We could improve this test scale = Scale([2], axes=("x",)) - for element in sdata._gen_elements_values(): + for element in sdata._gen_spatial_element_values(): set_transformation(element, scale, "my_space") sdata.transform_element_to_coordinate_system(element, "my_space") sdata.transform_to_coordinate_system("my_space") diff --git a/tests/core/query/test_spatial_query.py b/tests/core/query/test_spatial_query.py index 6db7e904..04ed6b11 100644 --- a/tests/core/query/test_spatial_query.py +++ b/tests/core/query/test_spatial_query.py @@ -2,17 +2,18 @@ import numpy as np import pytest +import xarray from anndata import AnnData from multiscale_spatial_image import MultiscaleSpatialImage from shapely import Polygon from spatial_image import SpatialImage -from spatialdata import SpatialData from spatialdata._core.query.spatial_query import ( BaseSpatialRequest, BoundingBoxRequest, bounding_box_query, polygon_query, ) +from spatialdata._core.spatialdata import SpatialData from spatialdata.models import ( Image2DModel, Image3DModel, @@ -101,23 +102,54 @@ def test_bounding_box_request_wrong_coordinate_order(): ) -def test_bounding_box_points(): +@pytest.mark.parametrize("is_3d", [True, False]) +@pytest.mark.parametrize("is_bb_3d", [True, False]) +def test_bounding_box_points(is_3d: bool, is_bb_3d: bool): """test the points bounding box_query""" - points_element = _make_points(np.array([[10, 10], [20, 20], [20, 30]])) - original_x = np.array(points_element["x"]) - original_y = np.array(points_element["y"]) + data_x = np.array([10, 20, 20, 20]) + data_y = np.array([10, 20, 30, 30]) + data_z = np.array([100, 200, 200, 300]) + + data = np.stack((data_x, data_y), axis=1) + if is_3d: + data = np.hstack((data, data_z.reshape(-1, 1))) + points_element = _make_points(data) + + original_x = points_element["x"] + original_y = points_element["y"] + if is_3d: + original_z = points_element["z"] + + if is_bb_3d: + _min_coordinate = np.array([18, 25, 250]) + _max_coordinate = np.array([22, 35, 350]) + _axes = ("x", "y", "z") + else: + _min_coordinate = np.array([18, 25]) + _max_coordinate = np.array([22, 35]) + _axes = ("x", "y") points_result = bounding_box_query( points_element, - axes=("x", "y"), - min_coordinate=np.array([18, 25]), - max_coordinate=np.array([22, 35]), + axes=_axes, + min_coordinate=_min_coordinate, + max_coordinate=_max_coordinate, target_coordinate_system="global", ) # Check that the correct point was selected - np.testing.assert_allclose(points_result["x"].compute(), [20]) - np.testing.assert_allclose(points_result["y"].compute(), [30]) + if is_3d: + if is_bb_3d: + np.testing.assert_allclose(points_result["x"].compute(), [20]) + np.testing.assert_allclose(points_result["y"].compute(), [30]) + np.testing.assert_allclose(points_result["z"].compute(), [300]) + else: + np.testing.assert_allclose(points_result["x"].compute(), [20, 20]) + np.testing.assert_allclose(points_result["y"].compute(), [30, 30]) + np.testing.assert_allclose(points_result["z"].compute(), [200, 300]) + else: + np.testing.assert_allclose(points_result["x"].compute(), [20, 20]) + np.testing.assert_allclose(points_result["y"].compute(), [30, 30]) # result should be valid points element PointsModel.validate(points_result) @@ -125,6 +157,8 @@ def test_bounding_box_points(): # original element should be unchanged np.testing.assert_allclose(points_element["x"].compute(), original_x) np.testing.assert_allclose(points_element["y"].compute(), original_y) + if is_3d: + np.testing.assert_allclose(points_element["z"].compute(), 
original_z) def test_bounding_box_points_no_points(): @@ -142,57 +176,74 @@ def test_bounding_box_points_no_points(): assert request is None +# @pytest.mark.parametrize("n_channels", [1, 2, 3]) @pytest.mark.parametrize("n_channels", [1, 2, 3]) -def test_bounding_box_image_2d(n_channels): - """Apply a bounding box to a 2D image""" - image = np.zeros((n_channels, 10, 10)) - # y: [5, 9], x: [0, 4] has value 1 - image[:, 5::, 0:5] = 1 - image_element = Image2DModel.parse(image) - image_element_multiscale = Image2DModel.parse(image, scale_factors=[2, 2]) +@pytest.mark.parametrize("is_labels", [True, False]) +@pytest.mark.parametrize("is_3d", [True, False]) +@pytest.mark.parametrize("is_bb_3d", [True, False]) +def test_bounding_box_raster(n_channels: int, is_labels: bool, is_3d: bool, is_bb_3d: bool): + """Apply a bounding box to a raster element.""" + if is_labels and n_channels > 1: + # labels cannot have multiple channels, let's ignore this combination of parameters + return + + shape = (10, 10) + if is_3d: + shape = (10,) + shape + shape = (n_channels,) + shape if not is_labels else (1,) + shape + + image = np.zeros(shape) + axes = ["y", "x"] + if is_3d: + image[:, 2:7, 5::, 0:5] = 1 + axes = ["z"] + axes + else: + image[:, 5::, 0:5] = 1 + + if is_labels: + image = np.squeeze(image, axis=0) + else: + axes = ["c"] + axes + + ximage = xarray.DataArray(image, dims=axes) + model = ( + Labels3DModel + if is_labels and is_3d + else Labels2DModel + if is_labels + else Image3DModel + if is_3d + else Image2DModel + ) - for image in [image_element, image_element_multiscale]: - # bounding box: y: [5, 10[, x: [0, 5[ - image_result = bounding_box_query( - image, - axes=("y", "x"), - min_coordinate=np.array([5, 0]), - max_coordinate=np.array([10, 5]), - target_coordinate_system="global", - ) - expected_image = np.ones((n_channels, 5, 5)) # c dimension is preserved - if isinstance(image, SpatialImage): - assert isinstance(image, SpatialImage) - np.testing.assert_allclose(image_result, expected_image) - elif isinstance(image, MultiscaleSpatialImage): - assert isinstance(image_result, MultiscaleSpatialImage) - v = image_result["scale0"].values() - assert len(v) == 1 - xdata = v.__iter__().__next__() - np.testing.assert_allclose(xdata, expected_image) - else: - raise ValueError("Unexpected type") + image_element = model.parse(image) + image_element_multiscale = model.parse(image, scale_factors=[2, 2]) + images = [image_element, image_element_multiscale] -@pytest.mark.parametrize("n_channels", [1, 2, 3]) -def test_bounding_box_image_3d(n_channels): - """Apply a bounding box to a 3D image""" - image = np.zeros((n_channels, 10, 10, 10)) - # z: [5, 9], y: [0, 4], x: [2, 6] has value 1 - image[:, 5::, 0:5, 2:7] = 1 - image_element = Image3DModel.parse(image) - image_element_multiscale = Image3DModel.parse(image, scale_factors=[2, 2]) + for image in images: + if is_bb_3d: + _min_coordinate = np.array([2, 5, 0]) + _max_coordinate = np.array([7, 10, 5]) + _axes = ("z", "y", "x") + else: + _min_coordinate = np.array([5, 0]) + _max_coordinate = np.array([10, 5]) + _axes = ("y", "x") - for image in [image_element, image_element_multiscale]: - # bounding box: z: [5, 10[, y: [0, 5[, x: [2, 7[ image_result = bounding_box_query( image, - axes=("z", "y", "x"), - min_coordinate=np.array([5, 0, 2]), - max_coordinate=np.array([10, 5, 7]), + axes=_axes, + min_coordinate=_min_coordinate, + max_coordinate=_max_coordinate, target_coordinate_system="global", ) - expected_image = np.ones((n_channels, 5, 5, 5)) # c dimension is 
preserved + + slices = {"y": slice(5, 10), "x": slice(0, 5)} + if is_bb_3d and is_3d: + slices["z"] = slice(2, 7) + expected_image = ximage.sel(**slices) + if isinstance(image, SpatialImage): assert isinstance(image, SpatialImage) np.testing.assert_allclose(image_result, expected_image) @@ -206,69 +257,6 @@ def test_bounding_box_image_3d(n_channels): raise ValueError("Unexpected type") -def test_bounding_box_labels_2d(): - """Apply a bounding box to a 2D label image""" - # in this test let's try some affine transformations, we could do that also for the other tests - image = np.zeros((10, 10)) - # y: [5, 9], x: [0, 4] has value 1 - image[5::, 0:5] = 1 - labels_element = Labels2DModel.parse(image) - labels_element_multiscale = Labels2DModel.parse(image, scale_factors=[2, 2]) - - for labels in [labels_element, labels_element_multiscale]: - # bounding box: y: [5, 10[, x: [0, 5[ - labels_result = bounding_box_query( - labels, - axes=("y", "x"), - min_coordinate=np.array([5, 0]), - max_coordinate=np.array([10, 5]), - target_coordinate_system="global", - ) - expected_image = np.ones((5, 5)) - if isinstance(labels, SpatialImage): - assert isinstance(labels, SpatialImage) - np.testing.assert_allclose(labels_result, expected_image) - elif isinstance(labels, MultiscaleSpatialImage): - assert isinstance(labels_result, MultiscaleSpatialImage) - v = labels_result["scale0"].values() - assert len(v) == 1 - xdata = v.__iter__().__next__() - np.testing.assert_allclose(xdata, expected_image) - else: - raise ValueError("Unexpected type") - - -def test_bounding_box_labels_3d(): - """Apply a bounding box to a 3D label image""" - image = np.zeros((10, 10, 10), dtype=int) - # z: [5, 9], y: [0, 4], x: [2, 6] has value 1 - image[5::, 0:5, 2:7] = 1 - labels_element = Labels3DModel.parse(image) - labels_element_multiscale = Labels3DModel.parse(image, scale_factors=[2, 2]) - - for labels in [labels_element, labels_element_multiscale]: - # bounding box: z: [5, 10[, y: [0, 5[, x: [2, 7[ - labels_result = bounding_box_query( - labels, - axes=("z", "y", "x"), - min_coordinate=np.array([5, 0, 2]), - max_coordinate=np.array([10, 5, 7]), - target_coordinate_system="global", - ) - expected_image = np.ones((5, 5, 5)) - if isinstance(labels, SpatialImage): - assert isinstance(labels, SpatialImage) - np.testing.assert_allclose(labels_result, expected_image) - elif isinstance(labels, MultiscaleSpatialImage): - assert isinstance(labels_result, MultiscaleSpatialImage) - v = labels_result["scale0"].values() - assert len(v) == 1 - xdata = v.__iter__().__next__() - np.testing.assert_allclose(xdata, expected_image) - else: - raise ValueError("Unexpected type") - - # TODO: more tests can be added for spatial queries after the cases 2, 3, 4 are implemented # (see https://github.com/scverse/spatialdata/pull/151, also for details on more tests) @@ -323,7 +311,7 @@ def test_bounding_box_spatial_data(full_sdata): _assert_spatialdata_objects_seem_identical(result, result2) - for element in result._gen_elements_values(): + for element in result._gen_spatial_element_values(): d = get_transformation(element, get_all=True) new_d = {k.replace("global", "cropped"): v for k, v in d.items()} set_transformation(element, new_d, set_all=True) @@ -338,7 +326,7 @@ def test_bounding_box_filter_table(): table.obs["region"] = ["circles0", "circles0", "circles1"] table.obs["instance"] = [0, 1, 0] table = TableModel.parse(table, region=["circles0", "circles1"], region_key="region", instance_key="instance") - sdata = SpatialData(shapes={"circles0": circles0, 
"circles1": circles1}, table=table) + sdata = SpatialData(shapes={"circles0": circles0, "circles1": circles1}, tables=table) queried0 = sdata.query.bounding_box( axes=("y", "x"), min_coordinate=np.array([15, 15]), @@ -364,7 +352,7 @@ def test_polygon_query_points(sdata_query_aggregation): queried = polygon_query(sdata, polygons=polygon, target_coordinate_system="global", shapes=False, points=True) points = queried["points"].compute() assert len(points) == 6 - assert len(queried.table) == 0 + assert queried.table is None # TODO: the case of querying points with multiple polygons is not currently implemented @@ -373,7 +361,7 @@ def test_polygon_query_shapes(sdata_query_aggregation): sdata = sdata_query_aggregation values_sdata = SpatialData( shapes={"values_polygons": sdata["values_polygons"], "values_circles": sdata["values_circles"]}, - table=sdata.table, + tables=sdata.table, ) polygon = sdata["by_polygons"].geometry.iloc[0] circle = sdata["by_circles"].geometry.iloc[0] @@ -427,7 +415,7 @@ def test_polygon_query_spatial_data(sdata_query_aggregation): "values_circles": sdata["values_circles"], }, points={"points": sdata["points"]}, - table=sdata.table, + tables=sdata.table, ) polygon = sdata["by_polygons"].geometry.iloc[0] queried = polygon_query(values_sdata, polygons=polygon, target_coordinate_system="global", shapes=True, points=True) diff --git a/tests/dataloader/test_datasets.py b/tests/dataloader/test_datasets.py index 59126996..dac01e80 100644 --- a/tests/dataloader/test_datasets.py +++ b/tests/dataloader/test_datasets.py @@ -4,59 +4,112 @@ import pandas as pd import pytest from anndata import AnnData -from spatialdata.dataloader.datasets import ImageTilesDataset +from spatialdata._core.spatialdata import SpatialData +from spatialdata.dataloader import ImageTilesDataset from spatialdata.models import TableModel -@pytest.mark.parametrize("image_element", ["blobs_image", "blobs_multiscale_image"]) -@pytest.mark.parametrize( - "regions_element", - ["blobs_labels", "blobs_multiscale_labels", "blobs_circles", "blobs_polygons", "blobs_multipolygons"], -) -def test_tiles_dataset(sdata_blobs, image_element, regions_element): - if regions_element in ["blobs_labels", "blobs_multipolygons", "blobs_multiscale_labels"]: - cm = pytest.raises(NotImplementedError) - else: - cm = contextlib.nullcontext() - with cm: +class TestImageTilesDataset: + @pytest.mark.parametrize("image_element", ["blobs_image", "blobs_multiscale_image"]) + @pytest.mark.parametrize( + "regions_element", + ["blobs_labels", "blobs_multiscale_labels", "blobs_circles", "blobs_polygons", "blobs_multipolygons"], + ) + def test_validation(self, sdata_blobs, image_element, regions_element): + if regions_element in ["blobs_labels", "blobs_multiscale_labels"] or image_element == "blobs_multiscale_image": + cm = pytest.raises(NotImplementedError) + elif regions_element in ["blobs_circles", "blobs_polygons", "blobs_multipolygons"]: + cm = pytest.raises(ValueError) + else: + cm = contextlib.nullcontext() + with cm: + _ = ImageTilesDataset( + sdata=sdata_blobs, + regions_to_images={regions_element: image_element}, + regions_to_coordinate_systems={regions_element: "global"}, + ) + + @pytest.mark.parametrize("regions_element", ["blobs_circles", "blobs_polygons", "blobs_multipolygons"]) + @pytest.mark.parametrize("raster", [True, False]) + def test_default(self, sdata_blobs, regions_element, raster): + raster_kwargs = {"target_unit_to_pixels": 2} if raster else {} + + sdata = self._annotate_shapes(sdata_blobs, regions_element) ds = 
ImageTilesDataset( - sdata=sdata_blobs, - regions_to_images={regions_element: image_element}, - tile_dim_in_units=10, - tile_dim_in_pixels=32, - target_coordinate_system="global", + sdata=sdata, + rasterize=raster, + regions_to_images={regions_element: "blobs_image"}, + regions_to_coordinate_systems={regions_element: "global"}, + rasterize_kwargs=raster_kwargs, ) - tile = ds[0].images.values().__iter__().__next__() - assert tile.shape == (3, 32, 32) + sdata_tile = ds[0] + tile = sdata_tile.images.values().__iter__().__next__() -def test_tiles_table(sdata_blobs): - new_table = AnnData( - X=np.random.default_rng().random((3, 10)), - obs=pd.DataFrame({"region": "blobs_circles", "instance_id": np.array([0, 1, 2])}), - ) - new_table = TableModel.parse(new_table, region="blobs_circles", region_key="region", instance_key="instance_id") - del sdata_blobs.table - sdata_blobs.table = new_table - ds = ImageTilesDataset( - sdata=sdata_blobs, - regions_to_images={"blobs_circles": "blobs_image"}, - tile_dim_in_units=10, - tile_dim_in_pixels=32, - target_coordinate_system="global", - ) - assert len(ds) == 3 - assert len(ds[0].table) == 1 - assert np.all(ds[0].table.X == new_table[0].X) - - -def test_tiles_multiple_elements(sdata_blobs): - ds = ImageTilesDataset( - sdata=sdata_blobs, - regions_to_images={"blobs_circles": "blobs_image", "blobs_polygons": "blobs_multiscale_image"}, - tile_dim_in_units=10, - tile_dim_in_pixels=32, - target_coordinate_system="global", - ) - assert len(ds) == 6 - _ = ds[0] + if regions_element == "blobs_circles": + if raster: + assert tile.shape == (3, 50, 50) + else: + assert tile.shape == (3, 25, 25) + elif regions_element == "blobs_polygons": + if raster: + assert tile.shape == (3, 164, 164) + else: + assert tile.shape == (3, 82, 82) + elif regions_element == "blobs_multipolygons": + if raster: + assert tile.shape == (3, 329, 329) + else: + assert tile.shape == (3, 165, 164) + else: + raise ValueError(f"Unexpected regions_element: {regions_element}") + + # extent has units in pixel so should be the same as tile shape + if raster: + assert round(ds.tiles_coords.extent.unique()[0] * 2) == tile.shape[1] + else: + if regions_element != "blobs_multipolygons": + assert int(ds.tiles_coords.extent.unique()[0]) == tile.shape[1] + else: + assert int(ds.tiles_coords.extent.unique()[0]) + 1 == tile.shape[1] + assert np.all(sdata_tile.table.obs.columns == ds.sdata.table.obs.columns) + assert list(sdata_tile.images.keys())[0] == "blobs_image" + + @pytest.mark.parametrize("regions_element", ["blobs_circles", "blobs_polygons", "blobs_multipolygons"]) + @pytest.mark.parametrize("return_annot", ["region", ["region", "instance_id"]]) + def test_return_annot(self, sdata_blobs, regions_element, return_annot): + sdata = self._annotate_shapes(sdata_blobs, regions_element) + ds = ImageTilesDataset( + sdata=sdata, + regions_to_images={regions_element: "blobs_image"}, + regions_to_coordinate_systems={regions_element: "global"}, + return_annotations=return_annot, + ) + + tile, annot = ds[0] + if regions_element == "blobs_circles": + assert tile.shape == (3, 25, 25) + elif regions_element == "blobs_polygons": + assert tile.shape == (3, 82, 82) + elif regions_element == "blobs_multipolygons": + assert tile.shape == (3, 165, 164) + else: + raise ValueError(f"Unexpected regions_element: {regions_element}") + # extent has units in pixel so should be the same as tile shape + if regions_element != "blobs_multipolygons": + assert int(ds.tiles_coords.extent.unique()[0]) == tile.shape[1] + else: + assert 
round(ds.tiles_coords.extent.unique()[0]) + 1 == tile.shape[1] + return_annot = [return_annot] if isinstance(return_annot, str) else return_annot + assert annot.shape[1] == len(return_annot) + + # TODO: consider adding this logic to blobs, to generate blobs with arbitrary table annotation + def _annotate_shapes(self, sdata: SpatialData, shape: str) -> SpatialData: + new_table = AnnData( + X=np.random.default_rng().random((len(sdata[shape]), 10)), + obs=pd.DataFrame({"region": shape, "instance_id": sdata[shape].index.values}), + ) + new_table = TableModel.parse(new_table, region=shape, region_key="region", instance_key="instance_id") + del sdata.table + sdata.table = new_table + return sdata diff --git a/tests/dataloader/test_transforms.py b/tests/dataloader/test_transforms.py deleted file mode 100644 index e69de29b..00000000 diff --git a/tests/io/test_multi_table.py b/tests/io/test_multi_table.py index f7d4671c..56755b79 100644 --- a/tests/io/test_multi_table.py +++ b/tests/io/test_multi_table.py @@ -1,94 +1,172 @@ from pathlib import Path -import anndata as ad -import numpy as np +import pytest from anndata import AnnData -from spatialdata import SpatialData +from anndata.tests.helpers import assert_equal +from spatialdata import SpatialData, concatenate +from spatialdata.models import TableModel -from tests.conftest import _get_new_table, _get_shapes +from tests.conftest import _get_shapes, _get_table # notes on paths: https://github.com/orgs/scverse/projects/17/views/1?pane=issue&itemId=44066734 -# notes for the people (to prettify) https://hackmd.io/wd7K4Eg1SlykKVN-nOP44w - -# shapes test_shapes = _get_shapes() -instance_id = np.array([str(i) for i in range(5)]) -table = _get_new_table(spatial_element="test_shapes", instance_id=instance_id) -adata0 = _get_new_table() -adata1 = _get_new_table() - # shuffle the indices of the dataframe -np.random.default_rng().shuffle(test_shapes["poly"].index) +# np.random.default_rng().shuffle(test_shapes["poly"].index) -# tables is a dict -SpatialData.tables - -# def get_table_keys(sdata: SpatialData) -> tuple[list[str], str, str]: -# d = sdata.table.uns[sd.models.TableModel.ATTRS_KEY] -# return d['region'], d['region_key'], d['instance_key'] -# -# @staticmethod -# def SpatialData.get_key_column(table: AnnData, key_column: str) -> ...: -# region, region_key, instance_key = sd.models.get_table_keys() -# if key_clumns == 'region_key': -# return table.obs[region_key] -# else: .... -# -# @staticmethod -# def SpatialData.get_region_key_column(table: AnnData | str): -# return get_key_column(...) -# @staticmethod -# def SpatialData.get_instance_key_column(table: AnnData | str): -# return get_key_column(...) 
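+# The tests below exercise the dict-like tables API. A minimal sketch of the
+# interactions being tested (element and table names are hypothetical):
+#
+#     sdata = SpatialData(shapes={"poly": poly}, tables={"annotation": table})
+#     sdata.tables["second"] = other_table  # sdata.tables behaves like a mapping
+#     adata = sdata["annotation"]           # __getitem__ also resolves tables
+
+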
+class TestMultiTable:
+    def test_set_get_tables_from_spatialdata(self, full_sdata: SpatialData, tmp_path: str):
+        tmpdir = Path(tmp_path) / "tmp.zarr"
+        adata0 = _get_table(region="polygon")
+        adata1 = _get_table(region="multipolygon")
+        full_sdata["adata0"] = adata0
+        full_sdata["adata1"] = adata1
+
+        adata2 = adata0.copy()
+        del adata2.obs["region"]
+        # fails because either none or all three of 'region', 'region_key', 'instance_key' are required
+        with pytest.raises(ValueError):
+            full_sdata["not_added_table"] = adata2
+
+        assert len(full_sdata.tables) == 3
+        assert "adata0" in full_sdata.tables and "adata1" in full_sdata.tables
+        full_sdata.write(tmpdir)
+
+        full_sdata = SpatialData.read(tmpdir)
+        assert_equal(adata0, full_sdata["adata0"])
+        assert_equal(adata1, full_sdata["adata1"])
+        assert "adata0" in full_sdata.tables and "adata1" in full_sdata.tables
+
+    @pytest.mark.parametrize(
+        "region_key, instance_key, error_msg",
+        [
+            (
+                None,
+                None,
+                "Specified instance_key in table.uns 'instance_id' is not present as column in table.obs. "
+                "Please specify instance_key.",
+            ),
+            (
+                "region",
+                None,
+                "Specified instance_key in table.uns 'instance_id' is not present as column in table.obs. "
+                "Please specify instance_key.",
+            ),
+            ("region", "instance_id", "Instance key column 'instance_id' not found in table.obs."),
+            (None, "instance_id", "Instance key column 'instance_id' not found in table.obs."),
+        ],
+    )
+    def test_change_annotation_target(self, full_sdata, region_key, instance_key, error_msg):
+        n_obs = full_sdata["table"].n_obs
+        with pytest.raises(
+            ValueError, match=r"Mismatch\(es\) found between regions in region column in obs and target element: "
+        ):
+            # ValueError: Mismatch(es) found between regions in region column in obs and target element: labels2d, poly
+            full_sdata.set_table_annotates_spatialelement("table", "poly")
+
+        del full_sdata["table"].obs["region"]
+        with pytest.raises(
+            ValueError,
+            match="Specified region_key in table.uns 'region' is not present as column in table.obs. "
+            "Please specify region_key.",
+        ):
+            full_sdata.set_table_annotates_spatialelement("table", "poly")
+
+        del full_sdata["table"].obs["instance_id"]
+        full_sdata["table"].obs["region"] = ["poly"] * n_obs
+        with pytest.raises(ValueError, match=error_msg):
+            full_sdata.set_table_annotates_spatialelement(
+                "table", "poly", region_key=region_key, instance_key=instance_key
+            )
+
+        full_sdata["table"].obs["instance_id"] = range(n_obs)
+        full_sdata.set_table_annotates_spatialelement(
+            "table", "poly", instance_key="instance_id", region_key=region_key
+        )
+
+        with pytest.raises(ValueError, match="'not_existing' column not present in table.obs"):
+            full_sdata.set_table_annotates_spatialelement("table", "circles", region_key="not_existing")
+
+    def test_set_table_nonexisting_target(self, full_sdata):
+        with pytest.raises(
+            ValueError,
+            match="Annotation target 'non_existing' not present as SpatialElement in SpatialData object.",
+        ):
+            full_sdata.set_table_annotates_spatialelement("table", "non_existing")
+
+    def test_set_table_annotates_spatialelement(self, full_sdata):
+        del full_sdata["table"].uns[TableModel.ATTRS_KEY]
+        with pytest.raises(
+            TypeError, match="No current annotation metadata found. Please specify both region_key and instance_key."
+ ): + full_sdata.set_table_annotates_spatialelement("table", "labels2d", region_key="non_existent") + with pytest.raises(ValueError, match="Instance key column 'non_existent' not found in table.obs."): + full_sdata.set_table_annotates_spatialelement( + "table", "labels2d", region_key="region", instance_key="non_existent" + ) + with pytest.raises(ValueError, match="column not present"): + full_sdata.set_table_annotates_spatialelement( + "table", "labels2d", region_key="non_existing", instance_key="instance_id" + ) + full_sdata.set_table_annotates_spatialelement( + "table", "labels2d", region_key="region", instance_key="instance_id" + ) + def test_old_accessor_deprecation(self, full_sdata, tmp_path): + # To test self._backed + tmpdir = Path(tmp_path) / "tmp.zarr" + full_sdata.write(tmpdir) + adata0 = _get_table(region="polygon") -def get_annotation_target_of_table(table: AnnData) -> pd.Series: - return SpatialData.get_region_key_column(table) + with pytest.warns(DeprecationWarning): + _ = full_sdata.table + with pytest.raises(ValueError): + full_sdata.table = adata0 + with pytest.warns(DeprecationWarning): + del full_sdata.table + with pytest.raises(KeyError): + del full_sdata.table + with pytest.warns(DeprecationWarning): + full_sdata.table = adata0 # this gets placed in sdata['table'] + assert_equal(adata0, full_sdata.table) -def set_annotation_target_of_table(table: AnnData, spatial_element: str | pd.Series) -> None: - SpatialData.set_instance_key_column(table, spatial_element) + del full_sdata.table + full_sdata.tables["my_new_table0"] = adata0 + assert full_sdata.table is None -class TestMultiTable: - def test_set_get_tables_from_spatialdata(self, sdata): # sdata is form conftest - sdata["my_new_table0"] = adata0 - sdata["my_new_table1"] = adata1 - - def test_old_accessor_deprecation(self, sdata): - # assume no table is present - # this prints a deprecation warning - sdata.table = adata0 # this gets placed in sdata['table'] - # this prints a deprecation warning - _ = sdata.table # this returns sdata['table'] - # this prints a deprecation waring - del sdata.table - - sdata["my_new_table0"] = adata0 - # will fail, because there is no sdata['table'], even if another table is present - _ = sdata.table - - def test_single_table(self, tmp_path: str): - # shared table + @pytest.mark.parametrize("region", ["test_shapes", "non_existing"]) + def test_single_table(self, tmp_path: str, region: str): tmpdir = Path(tmp_path) / "tmp.zarr" + table = _get_table(region=region) + + # Create shapes dictionary + shapes_dict = { + "test_shapes": test_shapes["poly"], + } + + if region == "non_existing": + with pytest.warns( + UserWarning, match=r"The table is annotating an/some element\(s\) not present in the SpatialData object" + ): + SpatialData( + shapes=shapes_dict, + tables={"shape_annotate": table}, + ) test_sdata = SpatialData( - shapes={ - "test_shapes": test_shapes["poly"], - }, + shapes=shapes_dict, tables={"shape_annotate": table}, ) + test_sdata.write(tmpdir) sdata = SpatialData.read(tmpdir) - assert sdata.get("segmentation") - assert isinstance(sdata["segmentation"], AnnData) - from anndata.tests.helpers import assert_equal - - assert assert_equal(test_sdata["segmentation"], sdata["segmentation"]) + assert isinstance(sdata["shape_annotate"], AnnData) + assert_equal(test_sdata["shape_annotate"], sdata["shape_annotate"]) # note (to keep in the code): these tests here should silmulate the interactions from teh users; if the syntax # here we are matching the table to the shapes and viceversa (= 
subset + reordeing)
@@ -107,28 +185,41 @@ def test_single_table(self, tmp_path: str):
 #     assert ...

     def test_paired_elements_tables(self, tmp_path: str):
-        pass
-
-    def test_elements_transfer_annotation(self, tmp_path: str):
+        tmpdir = Path(tmp_path) / "tmp.zarr"
+        table = _get_table(region="poly")
+        table2 = _get_table(region="multipoly")
+        table3 = _get_table(region="non_existing")
+        with pytest.warns(
+            UserWarning, match=r"The table is annotating an/some element\(s\) not present in the SpatialData object"
+        ):
+            SpatialData(
+                shapes={"poly": test_shapes["poly"], "multipoly": test_shapes["multipoly"]},
+                tables={"poly_annotate": table, "multipoly_annotate": table3},
+            )
         test_sdata = SpatialData(
-            shapes={"test_shapes": test_shapes["poly"], "test_multipoly": test_shapes["multipoly"]},
-            tables={"segmentation": table},
+            shapes={"poly": test_shapes["poly"], "multipoly": test_shapes["multipoly"]},
+            tables={"poly_annotate": table, "multipoly_annotate": table2},
         )
-        set_annotation_target_of_table(test_sdata["segmentation"], "test_multipoly")
-        assert get_annotation_target_of_table(test_sdata["segmentation"]) == "test_multipoly"
+        test_sdata.write(tmpdir)
+        test_sdata = SpatialData.read(tmpdir)
+        assert len(test_sdata.tables) == 2

     def test_single_table_multiple_elements(self, tmp_path: str):
         tmpdir = Path(tmp_path) / "tmp.zarr"
+        table = _get_table(region=["poly", "multipoly"])
+        subset = table[table.obs.region == "multipoly"]
+        with pytest.raises(ValueError, match="Regions in"):
+            TableModel().validate(subset)

         test_sdata = SpatialData(
             shapes={
-                "test_shapes": test_shapes["poly"],
-                "test_multipoly": test_shapes["multi_poly"],
+                "poly": test_shapes["poly"],
+                "multipoly": test_shapes["multipoly"],
             },
-            tables={"segmentation": table},
+            tables=table,
         )
         test_sdata.write(tmpdir)
-        # sdata = SpatialData.read(tmpdir)
+        SpatialData.read(tmpdir)

         # # use case example 1
         # # sorting the shapes visium0 to match the order of the table
         # sdata["visium0"][sdata.table.obs["__instance_id__"][sdata.table.obs["__spatial_element__"] == "visium0"]]
         # assert ...
@@ -141,41 +232,58 @@ def test_multiple_tables_same_element(self, tmp_path: str):
 #     sub_table.obs[sdata["visium0"]]
 #     assert ...
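+        # a possible shape of use case 1, sketched under the assumption that helper
+        # functions like match_table_to_element() / match_element_to_table() will exist
+        # (all names below are hypothetical):
+        #
+        #     sub_table = table[table.obs["region"] == "poly"]
+        #     sub_table = sub_table[sub_table.obs["instance_id"].values.argsort()]
+        #     # sub_table rows now follow the (sorted) instance order of sdata["poly"].index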
-    def test_concatenate_tables(self):
-        table_two = _get_new_table(spatial_element="test_multipoly", instance_id=np.array([str(i) for i in range(2)]))
-        concatenated_table = ad.concat([table, table_two])
-        test_sdata = SpatialData(
-            shapes={
-                "test_shapes": test_shapes["poly"],
-                "test_multipoly": test_shapes["multi_poly"],
-            },
-            tables={"segmentation": concatenated_table},
-        )
-        # use case tests as above (we test only visium0)
-
-    def test_multiple_table_without_element(self):
-        table = _get_new_table()
-        table_two = _get_new_table()
+    def test_multiple_table_without_element(self, tmp_path: str):
+        tmpdir = Path(tmp_path) / "tmp.zarr"
+        table = _get_table(region=None, region_key=None, instance_key=None)
+        table_two = _get_table(region=None, region_key=None, instance_key=None)

-        test_sdata = SpatialData(
+        sdata = SpatialData(
             tables={"table": table, "table_two": table_two},
         )
+        sdata.write(tmpdir)
+        SpatialData.read(tmpdir)

     def test_multiple_tables_same_element(self, tmp_path: str):
         tmpdir = Path(tmp_path) / "tmp.zarr"
-        table_two = _get_new_table(spatial_element="test_shapes", instance_id=instance_id)
+        table = _get_table(region="test_shapes")
+        table2 = _get_table(region="test_shapes")

         test_sdata = SpatialData(
             shapes={
                 "test_shapes": test_shapes["poly"],
             },
-            tables={"segmentation": table, "segmentation_two": table_two},
+            tables={"table": table, "table2": table2},
         )
         test_sdata.write(tmpdir)
+        SpatialData.read(tmpdir)

-#
-# # these use cases could be the preferred one for the users; we need to choose one/two preferred ones (either this, either helper function, ...)
+def test_concatenate_sdata_multitables():
+    sdatas = [
+        SpatialData(
+            shapes={f"poly_{i + 1}": test_shapes["poly"], f"multipoly_{i + 1}": test_shapes["multipoly"]},
+            tables={"table": _get_table(region=f"poly_{i + 1}"), "table2": _get_table(region=f"multipoly_{i + 1}")},
+        )
+        for i in range(3)
+    ]
+
+    with pytest.warns(
+        UserWarning,
+        match="Duplicate table names found.",
+    ):
+        concatenate(sdatas)
+
+    merged_sdata = concatenate(sdatas, concatenate_tables=True)
+    assert merged_sdata.tables["table"].n_obs == 300
+    assert merged_sdata.tables["table2"].n_obs == 300
+    assert all(merged_sdata.tables["table"].obs.region.unique() == ["poly_1", "poly_2", "poly_3"])
+    assert all(merged_sdata.tables["table2"].obs.region.unique() == ["multipoly_1", "multipoly_2", "multipoly_3"])
+
+
+# The following use cases need to be put in the tutorial notebook; let's keep the comment here until we have the
+# notebook ready.
+# # these use cases could be the preferred ones for the users; we need to choose one/two preferred ones (either this
+# # or a helper function, ...)
 # # use cases
 # # use case example 1
 # # sorting the shapes to match the order of the table
@@ -187,6 +295,7 @@ def test_multiple_tables_same_element(self, tmp_path: str):
 # sdata.table.obs[sdata["visium0"]]
 # assert ...
 #
+# We can postpone the implementation of this test until the functions "match_table_to_element" etc. are ready.
 # def test_partial_match():
 # # the function spatialdata._core.query.relational_query.match_table_to_element(no s) needs to be modified (will be
 # # simpler), we need also a function match_element_to_table.
Maybe we can have just one function doing both the diff --git a/tests/io/test_readwrite.py b/tests/io/test_readwrite.py index 046baf3b..e629182d 100644 --- a/tests/io/test_readwrite.py +++ b/tests/io/test_readwrite.py @@ -12,7 +12,7 @@ from numpy.random import default_rng from shapely.geometry import Point from spatial_image import SpatialImage -from spatialdata import SpatialData +from spatialdata import SpatialData, read_zarr from spatialdata._io._utils import _are_directories_identical from spatialdata.models import TableModel from spatialdata.transformations.operations import ( @@ -110,7 +110,7 @@ def test_incremental_io( tmpdir = Path(tmp_path) / "tmp.zarr" sdata = full_sdata - sdata.add_image(name="sdata_not_saved_yet", image=_get_images().values().__iter__().__next__()) + sdata.images["sdata_not_saved_yet"] = _get_images().values().__iter__().__next__() sdata.write(tmpdir) for k, v in _get_images().items(): @@ -122,10 +122,10 @@ def test_incremental_io( assert len(names) == 1 name = names[0] v[scale] = v[scale].rename_vars({name: f"incremental_{k}"}) - sdata.add_image(name=f"incremental_{k}", image=v) - with pytest.raises(KeyError): - sdata.add_image(name=f"incremental_{k}", image=v) - sdata.add_image(name=f"incremental_{k}", image=v, overwrite=True) + sdata.images[f"incremental_{k}"] = v + with pytest.warns(UserWarning): + sdata.images[f"incremental_{k}"] = v + sdata[f"incremental_{k}"] = v for k, v in _get_labels().items(): if isinstance(v, SpatialImage): @@ -136,26 +136,26 @@ def test_incremental_io( assert len(names) == 1 name = names[0] v[scale] = v[scale].rename_vars({name: f"incremental_{k}"}) - sdata.add_labels(name=f"incremental_{k}", labels=v) - with pytest.raises(KeyError): - sdata.add_labels(name=f"incremental_{k}", labels=v) - sdata.add_labels(name=f"incremental_{k}", labels=v, overwrite=True) + sdata.labels[f"incremental_{k}"] = v + with pytest.warns(UserWarning): + sdata.labels[f"incremental_{k}"] = v + sdata[f"incremental_{k}"] = v for k, v in _get_shapes().items(): - sdata.add_shapes(name=f"incremental_{k}", shapes=v) - with pytest.raises(KeyError): - sdata.add_shapes(name=f"incremental_{k}", shapes=v) - sdata.add_shapes(name=f"incremental_{k}", shapes=v, overwrite=True) + sdata.shapes[f"incremental_{k}"] = v + with pytest.warns(UserWarning): + sdata.shapes[f"incremental_{k}"] = v + sdata[f"incremental_{k}"] = v break for k, v in _get_points().items(): - sdata.add_points(name=f"incremental_{k}", points=v) - with pytest.raises(KeyError): - sdata.add_points(name=f"incremental_{k}", points=v) - sdata.add_points(name=f"incremental_{k}", points=v, overwrite=True) + sdata.points[f"incremental_{k}"] = v + with pytest.warns(UserWarning): + sdata.points[f"incremental_{k}"] = v + sdata[f"incremental_{k}"] = v break - def test_incremental_io_table(self, table_single_annotation): + def test_incremental_io_table(self, table_single_annotation: SpatialData) -> None: s = table_single_annotation t = s.table[:10, :].copy() with pytest.raises(ValueError): @@ -182,8 +182,8 @@ def test_io_and_lazy_loading_points(self, points): f = os.path.join(td, "data.zarr") dask0 = points.points[elem_name] points.write(f) - dask1 = points.points[elem_name] assert all("read-parquet" not in key for key in dask0.dask.layers) + dask1 = read_zarr(f).points[elem_name] assert any("read-parquet" in key for key in dask1.dask.layers) def test_io_and_lazy_loading_raster(self, images, labels): @@ -198,6 +198,7 @@ def test_io_and_lazy_loading_raster(self, images, labels): sdata.write(f) dask1 = d[elem_name].data 
assert all("from-zarr" not in key for key in dask0.dask.layers) + dask1 = read_zarr(f)[elem_name].data assert any("from-zarr" in key for key in dask1.dask.layers) def test_replace_transformation_on_disk_raster(self, images, labels): @@ -238,12 +239,34 @@ def test_replace_transformation_on_disk_non_raster(self, shapes, points): t1 = get_transformation(SpatialData.read(f).__getattribute__(k)[elem_name]) assert type(t1) == Scale + def test_overwrite_files_without_backed_data(self, full_sdata): + with tempfile.TemporaryDirectory() as tmpdir: + f = os.path.join(tmpdir, "data.zarr") + old_data = SpatialData() + old_data.write(f) + # Since not backed, no risk of overwriting backing data. + # Should not raise "The file path specified is the same as the one used for backing." + full_sdata.write(f, overwrite=True) + + def test_not_overwrite_files_without_backed_data_but_with_dask_backed_data(self, full_sdata, points): + with tempfile.TemporaryDirectory() as tmpdir: + f = os.path.join(tmpdir, "data.zarr") + points.write(f) + points2 = SpatialData.read(f) + p = points2["points_0"] + full_sdata["points_0"] = p + with pytest.raises( + ValueError, + match="The file path specified is a parent directory of one or more files used for backing for one or ", + ): + full_sdata.write(f, overwrite=True) + def test_overwrite_files_with_backed_data(self, full_sdata): # addressing https://github.com/scverse/spatialdata/issues/137 with tempfile.TemporaryDirectory() as tmpdir: f = os.path.join(tmpdir, "data.zarr") full_sdata.write(f) - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="The file path specified is the same as the one used for backing."): full_sdata.write(f, overwrite=True) # support for overwriting backed sdata has been temporarily removed @@ -275,45 +298,6 @@ def test_overwrite_onto_non_zarr_file(self, full_sdata): with pytest.raises(ValueError): full_sdata.write(f1) - def test_incremental_io_with_backed_elements(self, full_sdata): - # addressing https://github.com/scverse/spatialdata/issues/137 - # we test also the non-backed case so that if we switch to the - # backed version in the future we already have the tests - - with tempfile.TemporaryDirectory() as tmpdir: - f = os.path.join(tmpdir, "data.zarr") - full_sdata.write(f) - - e = full_sdata.images.values().__iter__().__next__() - full_sdata.add_image("new_images", e, overwrite=True) - # support for overwriting backed images has been temporarily removed - with pytest.raises(ValueError): - full_sdata.add_image("new_images", full_sdata.images["new_images"], overwrite=True) - - e = full_sdata.labels.values().__iter__().__next__() - full_sdata.add_labels("new_labels", e, overwrite=True) - # support for overwriting backed labels has been temporarily removed - with pytest.raises(ValueError): - full_sdata.add_labels("new_labels", full_sdata.labels["new_labels"], overwrite=True) - - e = full_sdata.points.values().__iter__().__next__() - full_sdata.add_points("new_points", e, overwrite=True) - # support for overwriting backed points has been temporarily removed - with pytest.raises(ValueError): - full_sdata.add_points("new_points", full_sdata.points["new_points"], overwrite=True) - - e = full_sdata.shapes.values().__iter__().__next__() - full_sdata.add_shapes("new_shapes", e, overwrite=True) - full_sdata.add_shapes("new_shapes", full_sdata.shapes["new_shapes"], overwrite=True) - - # commenting out as it is failing - # f2 = os.path.join(tmpdir, "data2.zarr") - # sdata2 = SpatialData(table=full_sdata.table.copy()) - # sdata2.write(f2) - # 
del full_sdata.table - # full_sdata.table = sdata2.table - # full_sdata.write(f2, overwrite=True) - def test_io_table(shapes): adata = AnnData(X=RNG.normal(size=(5, 10))) diff --git a/tests/io/test_utils.py b/tests/io/test_utils.py index 3e2cb04f..d8f86c44 100644 --- a/tests/io/test_utils.py +++ b/tests/io/test_utils.py @@ -5,12 +5,13 @@ import numpy as np import pytest from spatialdata import read_zarr, save_transformations -from spatialdata._io._utils import get_backing_files +from spatialdata._io._utils import get_dask_backing_files from spatialdata._utils import multiscale_spatial_image_from_data_tree from spatialdata.transformations import Scale, get_transformation, set_transformation def test_backing_files_points(points): + """Test the ability to identify the backing files of a dask dataframe from examining its computational graph""" with tempfile.TemporaryDirectory() as tmp_dir: f0 = os.path.join(tmp_dir, "points0.zarr") f1 = os.path.join(tmp_dir, "points1.zarr") @@ -21,7 +22,7 @@ def test_backing_files_points(points): p0 = points0.points["points_0"] p1 = points1.points["points_0"] p2 = dd.concat([p0, p1], axis=0) - files = get_backing_files(p2) + files = get_dask_backing_files(p2) expected_zarr_locations = [ os.path.realpath(os.path.join(f, "points/points_0/points.parquet")) for f in [f0, f1] ] @@ -29,6 +30,10 @@ def test_backing_files_points(points): def test_backing_files_images(images): + """ + Test the ability to identify the backing files of single scale and multiscale images from examining their + computational graph + """ with tempfile.TemporaryDirectory() as tmp_dir: f0 = os.path.join(tmp_dir, "images0.zarr") f1 = os.path.join(tmp_dir, "images1.zarr") @@ -41,7 +46,7 @@ def test_backing_files_images(images): im0 = images0.images["image2d"] im1 = images1.images["image2d"] im2 = im0 + im1 - files = get_backing_files(im2) + files = get_dask_backing_files(im2) expected_zarr_locations = [os.path.realpath(os.path.join(f, "images/image2d")) for f in [f0, f1]] assert set(files) == set(expected_zarr_locations) @@ -49,13 +54,17 @@ def test_backing_files_images(images): im3 = images0.images["image2d_multiscale"] im4 = images1.images["image2d_multiscale"] im5 = multiscale_spatial_image_from_data_tree(im3 + im4) - files = get_backing_files(im5) + files = get_dask_backing_files(im5) expected_zarr_locations = [os.path.realpath(os.path.join(f, "images/image2d_multiscale")) for f in [f0, f1]] assert set(files) == set(expected_zarr_locations) # TODO: this function here below is very similar to the above, unify the test with the above or delete this todo def test_backing_files_labels(labels): + """ + Test the ability to identify the backing files of single scale and multiscale labels from examining their + computational graph + """ with tempfile.TemporaryDirectory() as tmp_dir: f0 = os.path.join(tmp_dir, "labels0.zarr") f1 = os.path.join(tmp_dir, "labels1.zarr") @@ -68,7 +77,7 @@ def test_backing_files_labels(labels): im0 = labels0.labels["labels2d"] im1 = labels1.labels["labels2d"] im2 = im0 + im1 - files = get_backing_files(im2) + files = get_dask_backing_files(im2) expected_zarr_locations = [os.path.realpath(os.path.join(f, "labels/labels2d")) for f in [f0, f1]] assert set(files) == set(expected_zarr_locations) @@ -76,11 +85,37 @@ def test_backing_files_labels(labels): im3 = labels0.labels["labels2d_multiscale"] im4 = labels1.labels["labels2d_multiscale"] im5 = multiscale_spatial_image_from_data_tree(im3 + im4) - files = get_backing_files(im5) + files = get_dask_backing_files(im5) 
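As a usage sketch of the renamed helper tested above (get_backing_files became get_dask_backing_files): the function inspects the Dask task graph of a lazily loaded element and reports the on-disk files backing it. The zarr paths below are hypothetical.

import dask.dataframe as dd
from spatialdata import read_zarr
from spatialdata._io._utils import get_dask_backing_files

points0 = read_zarr("points0.zarr")
points1 = read_zarr("points1.zarr")

# concatenating two lazily loaded dataframes merges their task graphs, so
# the parquet files of both stores are reported as backing files
combined = dd.concat(
    [points0.points["points_0"], points1.points["points_0"]], axis=0
)
print(get_dask_backing_files(combined))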
expected_zarr_locations = [os.path.realpath(os.path.join(f, "labels/labels2d_multiscale")) for f in [f0, f1]] assert set(files) == set(expected_zarr_locations) +def test_backing_files_combining_points_and_images(points, images): + """ + Test the ability to identify the backing files of an object that depends both on dask dataframes and dask arrays + from examining its computational graph + """ + with tempfile.TemporaryDirectory() as tmp_dir: + f0 = os.path.join(tmp_dir, "points0.zarr") + f1 = os.path.join(tmp_dir, "images1.zarr") + points.write(f0) + images.write(f1) + points0 = read_zarr(f0) + images1 = read_zarr(f1) + + p0 = points0.points["points_0"] + im1 = images1.images["image2d"] + v = p0["x"].loc[0].values + v.compute_chunk_sizes() + im2 = v + im1 + files = get_dask_backing_files(im2) + expected_zarr_locations = [ + os.path.realpath(os.path.join(f0, "points/points_0/points.parquet")), + os.path.realpath(os.path.join(f1, "images/image2d")), + ] + assert set(files) == set(expected_zarr_locations) + + def test_save_transformations(labels): with tempfile.TemporaryDirectory() as tmp_dir: f0 = os.path.join(tmp_dir, "labels0.zarr") diff --git a/tests/models/test_models.py b/tests/models/test_models.py index 34035c7d..8e2cd333 100644 --- a/tests/models/test_models.py +++ b/tests/models/test_models.py @@ -1,7 +1,6 @@ from __future__ import annotations import os -import pathlib import tempfile from copy import deepcopy from functools import partial @@ -22,7 +21,7 @@ from pandas.api.types import is_categorical_dtype from shapely.io import to_ragged_array from spatial_image import SpatialImage, to_spatial_image -from spatialdata import SpatialData +from spatialdata._core.spatialdata import SpatialData from spatialdata._types import ArrayLike from spatialdata.models import ( Image2DModel, @@ -119,7 +118,7 @@ def _parse_transformation_from_multiple_places(self, model: Any, element: Any, * str, np.ndarray, dask.array.core.Array, - pathlib.PosixPath, + Path, pd.DataFrame, ) ): From a5f01b5b0a3e3ca0bedd3303b608efe04288afdf Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Mon, 29 Jan 2024 13:49:53 +0100 Subject: [PATCH 03/27] Enforce instance key to be dtype int (#444) * change genes in blobs points * force instance_id of int dtype * Change error * Check unique instance_key values per region --- src/spatialdata/datasets.py | 2 +- src/spatialdata/models/models.py | 18 ++++++++++++ tests/conftest.py | 2 +- .../operations/test_spatialdata_operations.py | 2 +- tests/models/test_models.py | 28 ++++++++++++++++++- 5 files changed, 48 insertions(+), 4 deletions(-) diff --git a/src/spatialdata/datasets.py b/src/spatialdata/datasets.py index 3b207e7b..58df2420 100644 --- a/src/spatialdata/datasets.py +++ b/src/spatialdata/datasets.py @@ -240,7 +240,7 @@ def _points_blobs( arr = rng.integers(padding, length - padding, size=(n_points, 2)).astype(np.int64) # randomly assign some values from v to the points points_assignment0 = rng.integers(0, 10, size=arr.shape[0]).astype(np.int64) - genes = rng.choice(["a", "b"], size=arr.shape[0]) + genes = rng.choice(["gene_a", "gene_b"], size=arr.shape[0]) annotation = pd.DataFrame( { "genes": genes, diff --git a/src/spatialdata/models/models.py b/src/spatialdata/models/models.py index f36d91e9..c99179a2 100644 --- a/src/spatialdata/models/models.py +++ b/src/spatialdata/models/models.py @@ -19,6 +19,7 @@ from multiscale_spatial_image.multiscale_spatial_image import MultiscaleSpatialImage from multiscale_spatial_image.to_multiscale.to_multiscale import Methods 
from pandas import CategoricalDtype +from pandas.errors import IntCastingNaNError from shapely._geometry import GeometryType from shapely.geometry import MultiPolygon, Point, Polygon from shapely.geometry.collection import GeometryCollection @@ -857,6 +858,23 @@ def parse( adata.obs[region_key] = pd.Categorical(adata.obs[region_key]) if instance_key is None: raise ValueError("`instance_key` must be provided.") + if adata.obs[instance_key].dtype != int: + try: + warnings.warn( + f"Converting `{cls.INSTANCE_KEY}: {instance_key}` to integer dtype.", UserWarning, stacklevel=2 + ) + adata.obs[instance_key] = adata.obs[instance_key].astype(int) + except IntCastingNaNError as exc: + raise ValueError("Values within table.obs[] must be able to be coerced to int dtype.") from exc + + grouped = adata.obs.groupby(region_key) + grouped_size = grouped.size() + grouped_nunique = grouped.nunique() + not_unique = grouped_size[grouped_size != grouped_nunique[instance_key]].index.tolist() + if not_unique: + raise ValueError( + f"Instance key column for region(s) `{', '.join(not_unique)}` does not contain only unique integers" + ) attr = {"region": region, "region_key": region_key, "instance_key": instance_key} adata.uns[cls.ATTRS_KEY] = attr diff --git a/tests/conftest.py b/tests/conftest.py index 3fcfe005..9f4ab6e7 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -247,7 +247,7 @@ def _get_shapes() -> dict[str, GeoDataFrame]: points["radius"] = rng.normal(size=(len(points), 1)) out["poly"] = ShapesModel.parse(poly) - out["poly"].index = ["a", "b", "c", "d", "e"] + out["poly"].index = [0, 1, 2, 3, 4] out["multipoly"] = ShapesModel.parse(multipoly) out["circles"] = ShapesModel.parse(points) diff --git a/tests/core/operations/test_spatialdata_operations.py b/tests/core/operations/test_spatialdata_operations.py index 3b2b3e6a..a72f470a 100644 --- a/tests/core/operations/test_spatialdata_operations.py +++ b/tests/core/operations/test_spatialdata_operations.py @@ -376,7 +376,7 @@ def test_subset(full_sdata: SpatialData) -> None: adata = AnnData( shape=(10, 0), - obs={"region": ["circles"] * 5 + ["poly"] * 5, "instance_id": [0, 1, 2, 3, 4, "a", "b", "c", "d", "e"]}, + obs={"region": ["circles"] * 5 + ["poly"] * 5, "instance_id": [0, 1, 2, 3, 4, 0, 1, 2, 3, 4]}, ) del full_sdata.table sdata_table = TableModel.parse(adata, region=["circles", "poly"], region_key="region", instance_key="instance_id") diff --git a/tests/models/test_models.py b/tests/models/test_models.py index 8e2cd333..e6556946 100644 --- a/tests/models/test_models.py +++ b/tests/models/test_models.py @@ -1,6 +1,7 @@ from __future__ import annotations import os +import re import tempfile from copy import deepcopy from functools import partial @@ -305,7 +306,7 @@ def test_table_model( region: str | np.ndarray, ) -> None: region_key = "reg" - obs = pd.DataFrame(RNG.integers(0, 100, size=(10, 3)), columns=["A", "B", "C"]) + obs = pd.DataFrame(RNG.choice(np.arange(0, 100), size=(10, 3), replace=False), columns=["A", "B", "C"]) obs[region_key] = region adata = AnnData(RNG.normal(size=(10, 2)), obs=obs) table = model.parse(adata, region=region, region_key=region_key, instance_key="A") @@ -319,6 +320,31 @@ def test_table_model( assert TableModel.REGION_KEY_KEY in table.uns[TableModel.ATTRS_KEY] assert table.uns[TableModel.ATTRS_KEY][TableModel.REGION_KEY] == region + obs["A"] = obs["A"].astype(str) + adata = AnnData(RNG.normal(size=(10, 2)), obs=obs) + with pytest.warns(UserWarning, match="Converting"): + model.parse(adata, region=region, 
region_key=region_key, instance_key="A") + + obs["A"] = pd.Series(len([chr(ord("a") + i) for i in range(10)])) + adata = AnnData(RNG.normal(size=(10, 2)), obs=obs) + with pytest.raises(ValueError, match="Values within"): + model.parse(adata, region=region, region_key=region_key, instance_key="A") + + @pytest.mark.parametrize("model", [TableModel]) + @pytest.mark.parametrize("region", [["sample_1"] * 5 + ["sample_2"] * 5]) + def test_table_instance_key_values_not_unique(self, model: TableModel, region: str | np.ndarray): + region_key = "region" + obs = pd.DataFrame(RNG.integers(0, 100, size=(10, 3)), columns=["A", "B", "C"]) + obs[region_key] = region + obs["A"] = [1] * 5 + list(range(5)) + adata = AnnData(RNG.normal(size=(10, 2)), obs=obs) + with pytest.raises(ValueError, match=re.escape("Instance key column for region(s) `sample_1`")): + model.parse(adata, region=region, region_key=region_key, instance_key="A") + + adata.obs["A"] = [1] * 10 + with pytest.raises(ValueError, match=re.escape("Instance key column for region(s) `sample_1, sample_2`")): + model.parse(adata, region=region, region_key=region_key, instance_key="A") + def test_get_schema(): images = _get_images() From 8813e70670ef60e3fe7dc6e3433a96a2345fd8ae Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Tue, 13 Feb 2024 14:25:36 +0100 Subject: [PATCH 04/27] Join elements table (#445) * change genes in blobs points * force instance_id of int dtype * Change error * Check unique instance_key values per region * silence warning * add left inner join * change to left join * add left_exclusive join * add to Enum * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add inner join * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add right join * add ugly right exclusive join * return none instead of empty df * add left join tests * test left_exclusive_join * test inner join * refactor get_table_keys * assert valid element in elements_dict * test warnings * add fail join tests * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * silence warning * add left inner join * change to left join * add left_exclusive join * add to Enum * add inner join * fix merge conflict * Add right join * add ugly right exclusive join * return none instead of empty df * add left join tests * test left_exclusive_join * test inner join * refactor get_table_keys * assert valid element in elements_dict * test warnings * add fail join tests * fix comments * lowercase + docstring * complete docstring * lowercase enum * explicit instance_id column in test * test instance_id and region column * change error * add get method * add contain * get rid of create_element_dict * add todo * adjust enums * add ability for matching rows * some cleanup * make any element none in right exclusive join * add check of table order conservation * fix tests * add match_element_to_table * change docstring * revert to old match_table_to_element * readd match_element_to_table * added tests for label joins * add points tests * fixed type and docstring * fixed comments and docstrings * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * change typehint * update changelog with joins * include join functions * docs for joins * added test for right match rows * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add 
validation of match rows * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix docstring * fix docs * fix typo in changelog * attempt fix docs * fix docs * fix docs --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Luca Marconato Co-authored-by: LucaMarconato <2664412+LucaMarconato@users.noreply.github.com> --- CHANGELOG.md | 6 + docs/api.md | 2 + docs/design_doc.md | 1 + src/spatialdata/__init__.py | 9 +- .../_core/query/relational_query.py | 442 +++++++++++++++++- src/spatialdata/_core/spatialdata.py | 27 ++ src/spatialdata/models/models.py | 2 +- tests/core/operations/test_aggregations.py | 1 + tests/core/query/test_relational_query.py | 218 ++++++++- 9 files changed, 697 insertions(+), 11 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d913b652..8ab1e6c2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,12 @@ and this project adheres to [Semantic Versioning][]. - Implemented support in SpatialData for storing multiple tables. These tables can annotate a SpatialElement but not necessarily so. +- Added SQL like joins that can be executed by calling one public function `join_sdata_spatialelement_table`. The following + joins are supported: `left`, `left_exclusive`, `right`, `right_exclusive` and `inner`. The function has an option to + match rows. For `left` only matching `left` is supported and for `right` join only `right` matching of rows is supported. + Not all joins are supported for `Labels` elements. +- Added function `match_element_to_table` which allows the user to perform a right join of `SpatialElement`(s) with a table + with rows matching the row order in the table. - Increased in-memory vs on-disk control: changes performed in-memory (e.g. adding a new image) are not automatically performed on-disk. #### Minor diff --git a/docs/api.md b/docs/api.md index 93509ffd..d6d8e5df 100644 --- a/docs/api.md +++ b/docs/api.md @@ -27,6 +27,8 @@ Operations on `SpatialData` objects. 
polygon_query get_values get_extent + join_sdata_spatialelement_table + match_element_to_table match_table_to_element concatenate transform diff --git a/docs/design_doc.md b/docs/design_doc.md index 5001d305..91b10748 100644 --- a/docs/design_doc.md +++ b/docs/design_doc.md @@ -564,6 +564,7 @@ with coordinate systems: with axes: c, y, x with elements: /images/point8, /labels/point8 """ + sdata0 = sdata.query.coordinate_system("point23", filter_rows=False) sdata1 = sdata.query.bounding_box((0, 20, 0, 300)) sdata1 = sdata.query.polygon("/polygons/annotations") diff --git a/src/spatialdata/__init__.py b/src/spatialdata/__init__.py index e09f42c0..9334dfc6 100644 --- a/src/spatialdata/__init__.py +++ b/src/spatialdata/__init__.py @@ -22,6 +22,8 @@ "bounding_box_query", "polygon_query", "get_values", + "join_sdata_spatialelement_table", + "match_element_to_table", "match_table_to_element", "SpatialData", "get_extent", @@ -38,7 +40,12 @@ from spatialdata._core.operations.rasterize import rasterize from spatialdata._core.operations.transform import transform from spatialdata._core.query._utils import circles_to_polygons, get_bounding_box_corners -from spatialdata._core.query.relational_query import get_values, match_table_to_element +from spatialdata._core.query.relational_query import ( + get_values, + join_sdata_spatialelement_table, + match_element_to_table, + match_table_to_element, +) from spatialdata._core.query.spatial_query import bounding_box_query, polygon_query from spatialdata._core.spatialdata import SpatialData from spatialdata._io._utils import get_dask_backing_files, save_transformations diff --git a/src/spatialdata/_core/query/relational_query.py b/src/spatialdata/_core/query/relational_query.py index 9beb4f16..9e537cc7 100644 --- a/src/spatialdata/_core/query/relational_query.py +++ b/src/spatialdata/_core/query/relational_query.py @@ -1,7 +1,12 @@ from __future__ import annotations +import math +import warnings +from collections import defaultdict from dataclasses import dataclass -from typing import Any +from enum import Enum +from functools import partial +from typing import Any, Literal import dask.array as da import numpy as np @@ -21,6 +26,7 @@ SpatialElement, TableModel, get_model, + get_table_keys, ) @@ -49,6 +55,21 @@ def _filter_table_by_element_names(table: AnnData | None, element_names: str | l return table +def _get_unique_label_values_as_index(element: SpatialElement) -> pd.Index: + if isinstance(element, SpatialImage): + # get unique labels value (including 0 if present) + instances = da.unique(element.data).compute() + else: + assert isinstance(element, MultiscaleSpatialImage) + v = element["scale0"].values() + assert len(v) == 1 + xdata = next(iter(v)) + # can be slow + instances = da.unique(xdata.data).compute() + return pd.Index(np.sort(instances)) + + +# TODO: replace function use throughout repo by `join_sdata_spatialelement_table` def _filter_table_by_elements( table: AnnData | None, elements_dict: dict[str, dict[str, Any]], match_rows: bool = False ) -> AnnData | None: @@ -129,6 +150,389 @@ def _filter_table_by_elements( return table +def _get_joined_table_indices( + joined_indices: pd.Index | None, + element_indices: pd.RangeIndex, + table_instance_key_column: pd.Series, + match_rows: Literal["left", "no", "right"], +) -> pd.Index: + """ + Get indices of the table that are present in element_indices. 
+ + Parameters + ---------- + joined_indices + Current indices that have been found to match indices of an element + element_indices + Element indices to match against table_instance_key_column. + table_instance_key_column + The column of a table containing the instance ids. + match_rows + Whether to match the indices of the element and table and if so how. If left, element_indices take priority and + if right table instance ids take priority. + + Returns + ------- + The indices that of the table that match the SpatialElement indices. + """ + mask = table_instance_key_column.isin(element_indices) + if joined_indices is None: + if match_rows == "left": + joined_indices = _match_rows(table_instance_key_column, mask, element_indices, match_rows) + else: + joined_indices = table_instance_key_column[mask].index + else: + if match_rows == "left": + add_indices = _match_rows(table_instance_key_column, mask, element_indices, match_rows) + joined_indices = joined_indices.append(add_indices) + # in place append does not work with pd.Index + else: + joined_indices = joined_indices.append(table_instance_key_column[mask].index) + return joined_indices + + +def _get_masked_element( + element_indices: pd.RangeIndex, + element: SpatialElement, + table_instance_key_column: pd.Series, + match_rows: Literal["left", "no", "right"], +) -> SpatialElement: + """ + Get element rows matching the instance ids in the table_instance_key_column. + + Parameters + ---------- + element_indices + The indices of an element. + element + The spatial element to be masked. + table_instance_key_column + The column of a table containing the instance ids + match_rows + Whether to match the indices of the element and table and if so how. If left, element_indices take priority and + if right table instance ids take priority. + + Returns + ------- + The masked spatial element based on the provided indices and match rows. + """ + mask = table_instance_key_column.isin(element_indices) + masked_table_instance_key_column = table_instance_key_column[mask] + mask_values = mask_values if len(mask_values := masked_table_instance_key_column.values) != 0 else None + if match_rows == "right": + mask_values = _match_rows(table_instance_key_column, mask, element_indices, match_rows) + + return element.loc[mask_values, :] + + +def _right_exclusive_join_spatialelement_table( + element_dict: dict[str, dict[str, Any]], table: AnnData, match_rows: Literal["left", "no", "right"] +) -> tuple[dict[str, Any], AnnData | None]: + regions, region_column_name, instance_key = get_table_keys(table) + groups_df = table.obs.groupby(by=region_column_name) + mask = [] + for element_type, name_element in element_dict.items(): + for name, element in name_element.items(): + if name in regions: + group_df = groups_df.get_group(name) + table_instance_key_column = group_df[instance_key] + if element_type in ["points", "shapes"]: + element_indices = element.index + else: + element_indices = _get_unique_label_values_as_index(element) + + element_dict[element_type][name] = None + submask = ~table_instance_key_column.isin(element_indices) + mask.append(submask) + else: + warnings.warn( + f"The element `{name}` is not annotated by the table. 
Skipping", UserWarning, stacklevel=2 + ) + element_dict[element_type][name] = None + continue + + if len(mask) != 0: + mask = pd.concat(mask) + exclusive_table = table[mask, :].copy() if mask.sum() != 0 else None # type: ignore[attr-defined] + else: + exclusive_table = None + + return element_dict, exclusive_table + + +def _right_join_spatialelement_table( + element_dict: dict[str, dict[str, Any]], table: AnnData, match_rows: Literal["left", "no", "right"] +) -> tuple[dict[str, Any], AnnData]: + if match_rows == "left": + warnings.warn("Matching rows ``'left'`` is not supported for ``'right'`` join.", UserWarning, stacklevel=2) + regions, region_column_name, instance_key = get_table_keys(table) + groups_df = table.obs.groupby(by=region_column_name) + for element_type, name_element in element_dict.items(): + for name, element in name_element.items(): + if name in regions: + group_df = groups_df.get_group(name) + table_instance_key_column = group_df[instance_key] + if element_type in ["points", "shapes"]: + element_indices = element.index + else: + warnings.warn( + f"Element type `labels` not supported for left exclusive join. Skipping `{name}`", + UserWarning, + stacklevel=2, + ) + continue + + masked_element = _get_masked_element(element_indices, element, table_instance_key_column, match_rows) + element_dict[element_type][name] = masked_element + else: + warnings.warn( + f"The element `{name}` is not annotated by the table. Skipping", UserWarning, stacklevel=2 + ) + continue + return element_dict, table + + +def _inner_join_spatialelement_table( + element_dict: dict[str, dict[str, Any]], table: AnnData, match_rows: Literal["left", "no", "right"] +) -> tuple[dict[str, Any], AnnData]: + regions, region_column_name, instance_key = get_table_keys(table) + groups_df = table.obs.groupby(by=region_column_name) + joined_indices = None + for element_type, name_element in element_dict.items(): + for name, element in name_element.items(): + if name in regions: + group_df = groups_df.get_group(name) + table_instance_key_column = group_df[instance_key] # This is always a series + if element_type in ["points", "shapes"]: + element_indices = element.index + else: + warnings.warn( + f"Element type `labels` not supported for left exclusive join. Skipping `{name}`", + UserWarning, + stacklevel=2, + ) + continue + + masked_element = _get_masked_element(element_indices, element, table_instance_key_column, match_rows) + element_dict[element_type][name] = masked_element + + joined_indices = _get_joined_table_indices( + joined_indices, element_indices, table_instance_key_column, match_rows + ) + else: + warnings.warn( + f"The element `{name}` is not annotated by the table. 
Skipping", UserWarning, stacklevel=2 + ) + element_dict[element_type][name] = None + continue + + joined_table = table[joined_indices, :].copy() if joined_indices is not None else None + return element_dict, joined_table + + +def _left_exclusive_join_spatialelement_table( + element_dict: dict[str, dict[str, Any]], table: AnnData, match_rows: Literal["left", "no", "right"] +) -> tuple[dict[str, Any], AnnData | None]: + regions, region_column_name, instance_key = get_table_keys(table) + groups_df = table.obs.groupby(by=region_column_name) + for element_type, name_element in element_dict.items(): + for name, element in name_element.items(): + if name in regions: + group_df = groups_df.get_group(name) + table_instance_key_column = group_df[instance_key] + if element_type in ["points", "shapes"]: + mask = np.full(len(element), True, dtype=bool) + mask[table_instance_key_column.values] = False + masked_element = element.loc[mask, :] if mask.sum() != 0 else None + element_dict[element_type][name] = masked_element + else: + warnings.warn( + f"Element type `labels` not supported for left exclusive join. Skipping `{name}`", + UserWarning, + stacklevel=2, + ) + continue + else: + warnings.warn( + f"The element `{name}` is not annotated by the table. Skipping", UserWarning, stacklevel=2 + ) + continue + + return element_dict, None + + +def _left_join_spatialelement_table( + element_dict: dict[str, dict[str, Any]], table: AnnData, match_rows: Literal["left", "no", "right"] +) -> tuple[dict[str, Any], AnnData]: + if match_rows == "right": + warnings.warn("Matching rows ``'right'`` is not supported for ``'left'`` join.", UserWarning, stacklevel=2) + regions, region_column_name, instance_key = get_table_keys(table) + groups_df = table.obs.groupby(by=region_column_name) + joined_indices = None + for element_type, name_element in element_dict.items(): + for name, element in name_element.items(): + if name in regions: + group_df = groups_df.get_group(name) + table_instance_key_column = group_df[instance_key] # This is always a series + if element_type in ["points", "shapes"]: + element_indices = element.index + else: + element_indices = _get_unique_label_values_as_index(element) + + joined_indices = _get_joined_table_indices( + joined_indices, element_indices, table_instance_key_column, match_rows + ) + else: + warnings.warn( + f"The element `{name}` is not annotated by the table. 
Skipping", UserWarning, stacklevel=2 + ) + continue + + joined_indices = joined_indices.dropna() if joined_indices is not None else None + joined_table = table[joined_indices, :].copy() if joined_indices is not None else None + + return element_dict, joined_table + + +def _match_rows( + table_instance_key_column: pd.Series, + mask: pd.Series, + element_indices: pd.RangeIndex, + match_rows: str, +) -> pd.Index: + instance_id_df = pd.DataFrame( + {"instance_id": table_instance_key_column[mask].values, "index_right": table_instance_key_column[mask].index} + ) + element_index_df = pd.DataFrame({"index_left": element_indices}) + index_col = "index_left" if match_rows == "right" else "index_right" + + merged_df = pd.merge( + element_index_df, instance_id_df, left_on="index_left", right_on="instance_id", how=match_rows + )[index_col] + + # With labels it can be that index 0 is NaN + if isinstance(merged_df.iloc[0], float) and math.isnan(merged_df.iloc[0]): + merged_df = merged_df.iloc[1:] + + return pd.Index(merged_df) + + +class JoinTypes(Enum): + """Available join types for matching elements to tables and vice versa.""" + + left = partial(_left_join_spatialelement_table) + left_exclusive = partial(_left_exclusive_join_spatialelement_table) + inner = partial(_inner_join_spatialelement_table) + right = partial(_right_join_spatialelement_table) + right_exclusive = partial(_right_exclusive_join_spatialelement_table) + + def __call__(self, *args: Any) -> tuple[dict[str, Any], AnnData]: + return self.value(*args) + + +class MatchTypes(Enum): + """Available match types for matching rows of elements and tables.""" + + left = "left" + right = "right" + no = "no" + + +def join_sdata_spatialelement_table( + sdata: SpatialData, + spatial_element_name: str | list[str], + table_name: str, + how: str = "left", + match_rows: Literal["no", "left", "right"] = "no", +) -> tuple[dict[str, Any], AnnData]: + """Join SpatialElement(s) and table together in SQL like manner. + + The function allows the user to perform SQL like joins of SpatialElements and a table. The elements are not + returned together in one dataframe like structure, but instead filtered elements are returned. To determine matches, + for the SpatialElement the index is used and for the table the region key column and instance key column. The + elements are not overwritten in the `SpatialData` object. + + The following joins are supported: ``'left'``, ``'left_exclusive'``, ``'inner'``, ``'right'`` and + ``'right_exclusive'``. In case of a ``'left'`` join the SpatialElements are returned in a dictionary as is + while the table is filtered to only include matching rows. In case of ``'left_exclusive'`` join None is returned + for table while the SpatialElements returned are filtered to only include indices not present in the table. The + cases for ``'right'`` joins are symmetric to the ``'left'`` joins. In case of an ``'inner'`` join of + SpatialElement(s) and a table, for each an element is returned only containing the rows that are present in + both the SpatialElement and table. + + For Points and Shapes elements every valid join for argument how is supported. For Labels elements only + the ``'left'`` and ``'right_exclusive'`` joins are supported. + + Parameters + ---------- + sdata + The SpatialData object containing the tables and spatial elements. + spatial_element_name + The name(s) of the spatial elements to be joined with the table. + table_name + The name of the table to join with the spatial elements. 
+ how + The type of SQL like join to perform, default is ``'left'``. Options are ``'left'``, ``'left_exclusive'``, + ``'inner'``, ``'right'`` and ``'right_exclusive'``. + match_rows + Whether to match the indices of the element and table and if so how. If ``'left'``, element_indices take + priority and if ``'right'`` table instance ids take priority. + + Returns + ------- + A tuple containing the joined elements as a dictionary and the joined table as an AnnData object. + + Raises + ------ + AssertionError + If no table with the given table_name exists in the SpatialData object. + ValueError + If the provided join type is not supported. + """ + assert sdata.tables.get(table_name), f"No table with `{table_name}` exists in the SpatialData object." + table = sdata.tables[table_name] + if isinstance(spatial_element_name, str): + spatial_element_name = [spatial_element_name] + + elements_dict: dict[str, dict[str, Any]] = defaultdict(lambda: defaultdict(dict)) + for name in spatial_element_name: + if name in sdata.tables: + warnings.warn( + f"Tables: `{', '.join(elements_dict['tables'].keys())}` given in spatial_element_names cannot be " + f"joined with a table using this function.", + UserWarning, + stacklevel=2, + ) + elif name in sdata.images: + warnings.warn( + f"Images: `{', '.join(elements_dict['images'].keys())}` cannot be joined with a table", + UserWarning, + stacklevel=2, + ) + else: + element_type, _, element = sdata._find_element(name) + elements_dict[element_type][name] = element + + assert any(key in elements_dict for key in ["labels", "shapes", "points"]), ( + "No valid element to join in spatial_element_name. Must provide at least one of either `labels`, `points` or " + "`shapes`." + ) + + if match_rows not in MatchTypes.__dict__["_member_names_"]: + raise TypeError( + f"`{match_rows}` is an invalid argument for `match_rows`. Can be either `no`, ``'left'`` or ``'right'``" + ) + if how in JoinTypes.__dict__["_member_names_"]: + elements_dict, table = JoinTypes[how](elements_dict, table, match_rows) + else: + raise TypeError(f"`{how}` is not a valid type of join.") + + elements_dict = { + name: element for outer_key, dict_val in elements_dict.items() for name, element in dict_val.items() + } + return elements_dict, table + + def match_table_to_element(sdata: SpatialData, element_name: str) -> AnnData: """ Filter the table and reorders the rows to match the instances (rows/labels) of the specified SpatialElement. @@ -138,12 +542,23 @@ def match_table_to_element(sdata: SpatialData, element_name: str) -> AnnData: sdata SpatialData object element_name - Name of the element to match the table to + The name of the spatial elements to be joined with the table. Returns ------- Table with the rows matching the instances of the element """ + # TODO: refactor this to make use of the new join_sdata_spatialelement_table function. + # if table_name is None: + # warnings.warn( + # "Assumption of table with name `table` being present is being deprecated in SpatialData v0.1. 
" + # "Please provide the name of the table as argument to table_name.", + # DeprecationWarning, + # stacklevel=2, + # ) + # table_name = "table" + # _, table = join_sdata_spatialelement_table(sdata, element_name, table_name, "left", match_rows="left") + # return table assert sdata.table is not None, "No table found in the SpatialData" element_type, _, element = sdata._find_element(element_name) assert element_type in ["labels", "shapes"], f"Element {element_name} ({element_type}) is not supported" @@ -151,6 +566,29 @@ def match_table_to_element(sdata: SpatialData, element_name: str) -> AnnData: return _filter_table_by_elements(sdata.table, elements_dict, match_rows=True) +def match_element_to_table( + sdata: SpatialData, element_name: str | list[str], table_name: str +) -> tuple[dict[str, Any], AnnData]: + """ + Filter the elements and make the indices match those in the table. + + Parameters + ---------- + sdata + SpatialData object + element_name + The name(s) of the spatial elements to be joined with the table. Not supported for Label elements. + table_name + The name of the table to join with the spatial elements. + + Returns + ------- + A tuple containing the joined elements as a dictionary and the joined table as an AnnData object. + """ + element_dict, table = join_sdata_spatialelement_table(sdata, element_name, table_name, "right", match_rows="right") + return element_dict, table + + @dataclass class _ValueOrigin: origin: str diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py index 3cdf91d2..80fee105 100644 --- a/src/spatialdata/_core/spatialdata.py +++ b/src/spatialdata/_core/spatialdata.py @@ -1538,6 +1538,33 @@ def __getitem__(self, item: str) -> SpatialElement: _, _, element = self._find_element(item) return element + def __contains__(self, key: str) -> bool: + element_dict = { + element_name: element_value for _, element_name, element_value in self._gen_elements(include_table=True) + } + return key in element_dict + + def get(self, key: str, default_value: SpatialElement | AnnData | None = None) -> SpatialElement | AnnData: + """ + Get element from SpatialData object based on corresponding name. + + Parameters + ---------- + key + The key to lookup in the spatial elements. + default_value + The default value (a SpatialElement or a table) to return if the key is not found. Default is None. + + Returns + ------- + The SpatialData element associated with the given key, if found. Otherwise, the default value is returned. + """ + for _, element_name_, element in self.gen_elements(): + if element_name_ == key: + return element + else: + return default_value + def __setitem__(self, key: str, value: SpatialElement | AnnData) -> None: """ Add the element to the SpatialData object. 
diff --git a/src/spatialdata/models/models.py b/src/spatialdata/models/models.py index c99179a2..4d588e07 100644 --- a/src/spatialdata/models/models.py +++ b/src/spatialdata/models/models.py @@ -867,7 +867,7 @@ def parse( except IntCastingNaNError as exc: raise ValueError("Values within table.obs[] must be able to be coerced to int dtype.") from exc - grouped = adata.obs.groupby(region_key) + grouped = adata.obs.groupby(region_key, observed=True) grouped_size = grouped.size() grouped_nunique = grouped.nunique() not_unique = grouped_size[grouped_size != grouped_nunique[instance_key]].index.tolist() diff --git a/tests/core/operations/test_aggregations.py b/tests/core/operations/test_aggregations.py index 36609464..01aaa2ae 100644 --- a/tests/core/operations/test_aggregations.py +++ b/tests/core/operations/test_aggregations.py @@ -124,6 +124,7 @@ def test_aggregate_points_by_shapes(sdata_query_aggregation, by_shapes: str, val ) +# TODO: refactor in smaller functions for easier understanding @pytest.mark.parametrize("by_shapes", ["by_circles", "by_polygons"]) @pytest.mark.parametrize("values_shapes", ["values_circles", "values_polygons"]) @pytest.mark.parametrize( diff --git a/tests/core/query/test_relational_query.py b/tests/core/query/test_relational_query.py index 6cb4daec..c39d0e11 100644 --- a/tests/core/query/test_relational_query.py +++ b/tests/core/query/test_relational_query.py @@ -3,17 +3,11 @@ import pytest from anndata import AnnData from spatialdata import get_values, match_table_to_element -from spatialdata._core.query.relational_query import _locate_value, _ValueOrigin +from spatialdata._core.query.relational_query import _locate_value, _ValueOrigin, join_sdata_spatialelement_table from spatialdata.models.models import TableModel def test_match_table_to_element(sdata_query_aggregation): - # table can't annotate points - with pytest.raises(AssertionError): - match_table_to_element(sdata=sdata_query_aggregation, element_name="points") - # table is not annotating "by_circles" - with pytest.raises(AssertionError, match="No row matches in the table annotates the element"): - match_table_to_element(sdata=sdata_query_aggregation, element_name="by_circles") matched_table = match_table_to_element(sdata=sdata_query_aggregation, element_name="values_circles") arr = np.array(list(reversed(sdata_query_aggregation["values_circles"].index))) sdata_query_aggregation["values_circles"].index = arr @@ -23,6 +17,149 @@ def test_match_table_to_element(sdata_query_aggregation): # TODO: add tests for labels +def test_left_inner_right_exclusive_join(sdata_query_aggregation): + element_dict, table = join_sdata_spatialelement_table( + sdata_query_aggregation, "values_polygons", "table", "right_exclusive" + ) + assert table is None + assert all(element_dict[key] is None for key in element_dict) + + sdata_query_aggregation["values_polygons"] = sdata_query_aggregation["values_polygons"].drop([10, 11]) + with pytest.raises(AssertionError, match="No table with"): + join_sdata_spatialelement_table(sdata_query_aggregation, "values_polygons", "not_existing_table", "left") + + # Should we reindex before returning the table? 
+ element_dict, table = join_sdata_spatialelement_table(sdata_query_aggregation, "values_polygons", "table", "left") + assert all(element_dict["values_polygons"].index == table.obs["instance_id"].values) + + # Check no matches in table for element not annotated by table + element_dict, table = join_sdata_spatialelement_table(sdata_query_aggregation, "by_polygons", "table", "left") + assert table is None + assert element_dict["by_polygons"] is sdata_query_aggregation["by_polygons"] + + # Check multiple elements, one of which not annotated by table + with pytest.warns(UserWarning, match="The element"): + element_dict, table = join_sdata_spatialelement_table( + sdata_query_aggregation, ["by_polygons", "values_polygons"], "table", "left" + ) + assert "by_polygons" in element_dict + + # check multiple elements joined to table. + sdata_query_aggregation["values_circles"] = sdata_query_aggregation["values_circles"].drop([7, 8]) + element_dict, table = join_sdata_spatialelement_table( + sdata_query_aggregation, ["values_circles", "values_polygons"], "table", "left" + ) + indices = pd.concat( + [element_dict["values_circles"].index.to_series(), element_dict["values_polygons"].index.to_series()] + ) + assert all(table.obs["instance_id"] == indices.values) + + with pytest.warns(UserWarning, match="The element"): + element_dict, table = join_sdata_spatialelement_table( + sdata_query_aggregation, ["values_circles", "values_polygons", "by_polygons"], "table", "right_exclusive" + ) + assert all(element_dict[key] is None for key in element_dict) + assert all(table.obs.index == ["7", "8", "19", "20"]) + assert all(table.obs["instance_id"].values == [7, 8, 10, 11]) + assert all(table.obs["region"].values == ["values_circles", "values_circles", "values_polygons", "values_polygons"]) + + # the triggered warning is: UserWarning: The element `{name}` is not annotated by the table. 
Skipping + with pytest.warns(UserWarning, match="The element"): + element_dict, table = join_sdata_spatialelement_table( + sdata_query_aggregation, ["values_circles", "values_polygons", "by_polygons"], "table", "inner" + ) + indices = pd.concat( + [element_dict["values_circles"].index.to_series(), element_dict["values_polygons"].index.to_series()] + ) + assert all(table.obs["instance_id"] == indices.values) + assert element_dict["by_polygons"] is None + + +def test_join_spatialelement_table_fail(full_sdata): + with pytest.warns(UserWarning, match="Images:"): + join_sdata_spatialelement_table(full_sdata, ["image2d", "labels2d"], "table", "left_exclusive") + with pytest.warns(UserWarning, match="Tables:"): + join_sdata_spatialelement_table(full_sdata, ["labels2d", "table"], "table", "left_exclusive") + with pytest.raises(TypeError, match="`not_join` is not a"): + join_sdata_spatialelement_table(full_sdata, "labels2d", "table", "not_join") + + +def test_left_exclusive_and_right_join(sdata_query_aggregation): + # Test case in which all table rows match rows in elements + element_dict, table = join_sdata_spatialelement_table( + sdata_query_aggregation, ["values_circles", "values_polygons"], "table", "left_exclusive" + ) + assert all(element_dict[key] is None for key in element_dict) + assert table is None + + # Dropped indices correspond to instance ids 7, 8 for 'values_circles' and 10, 11 for 'values_polygons' + sdata_query_aggregation["table"] = sdata_query_aggregation["table"][ + sdata_query_aggregation["table"].obs.index.drop(["7", "8", "19", "20"]) + ] + with pytest.warns(UserWarning, match="The element"): + element_dict, table = join_sdata_spatialelement_table( + sdata_query_aggregation, ["values_polygons", "by_polygons"], "table", "left_exclusive" + ) + assert table is None + assert not set(element_dict["values_polygons"].index).issubset(sdata_query_aggregation["table"].obs["instance_id"]) + + # test right join + with pytest.warns(UserWarning, match="The element"): + element_dict, table = join_sdata_spatialelement_table( + sdata_query_aggregation, ["values_circles", "values_polygons", "by_polygons"], "table", "right" + ) + assert table is sdata_query_aggregation["table"] + assert not {7, 8}.issubset(element_dict["values_circles"].index) + assert not {10, 11}.issubset(element_dict["values_polygons"].index) + + element_dict, table = join_sdata_spatialelement_table( + sdata_query_aggregation, ["values_circles", "values_polygons"], "table", "left_exclusive" + ) + assert table is None + assert not np.array_equal( + sdata_query_aggregation["table"].obs.iloc[7:9]["instance_id"].values, + element_dict["values_circles"].index.values, + ) + assert not np.array_equal( + sdata_query_aggregation["table"].obs.iloc[19:21]["instance_id"].values, + element_dict["values_polygons"].index.values, + ) + + +def test_match_rows_join(sdata_query_aggregation): + reversed_instance_id = [3, 4, 5, 6, 7, 8, 1, 2, 0] + list(reversed(range(12))) + original_instance_id = sdata_query_aggregation.table.obs["instance_id"] + sdata_query_aggregation.table.obs["instance_id"] = reversed_instance_id + + element_dict, table = join_sdata_spatialelement_table( + sdata_query_aggregation, ["values_circles", "values_polygons"], "table", "left", match_rows="left" + ) + assert all(table.obs["instance_id"].values == original_instance_id.values) + + element_dict, table = join_sdata_spatialelement_table( + sdata_query_aggregation, ["values_circles", "values_polygons"], "table", "right", match_rows="right" + ) + indices = 
[*element_dict["values_circles"].index, *element_dict[("values_polygons")].index] + assert all(indices == table.obs["instance_id"]) + + element_dict, table = join_sdata_spatialelement_table( + sdata_query_aggregation, ["values_circles", "values_polygons"], "table", "inner", match_rows="left" + ) + assert all(table.obs["instance_id"].values == original_instance_id.values) + + element_dict, table = join_sdata_spatialelement_table( + sdata_query_aggregation, ["values_circles", "values_polygons"], "table", "inner", match_rows="right" + ) + indices = [*element_dict["values_circles"].index, *element_dict[("values_polygons")].index] + assert all(indices == table.obs["instance_id"]) + + # check whether table ordering is preserved if not matching + element_dict, table = join_sdata_spatialelement_table( + sdata_query_aggregation, ["values_circles", "values_polygons"], "table", "left" + ) + assert all(table.obs["instance_id"] == reversed_instance_id) + + def test_locate_value(sdata_query_aggregation): def _check_location(locations: list[_ValueOrigin], origin: str, is_categorical: bool): assert len(locations) == 1 @@ -200,3 +337,70 @@ def test_filter_table_categorical_bug(shapes): adata_subset = adata[adata.obs["categorical"] == "a"].copy() shapes.table = adata_subset shapes.filter_by_coordinate_system("global") + + +def test_labels_table_joins(full_sdata): + element_dict, table = join_sdata_spatialelement_table( + full_sdata, + "labels2d", + "table", + "left", + ) + assert all(table.obs["instance_id"] == range(100)) + + full_sdata["table"].obs["instance_id"] = list(reversed(range(100))) + + element_dict, table = join_sdata_spatialelement_table(full_sdata, "labels2d", "table", "left", match_rows="left") + assert all(table.obs["instance_id"] == range(100)) + + with pytest.warns(UserWarning, match="Element type"): + join_sdata_spatialelement_table(full_sdata, "labels2d", "table", "left_exclusive") + + with pytest.warns(UserWarning, match="Element type"): + join_sdata_spatialelement_table(full_sdata, "labels2d", "table", "inner") + + with pytest.warns(UserWarning, match="Element type"): + join_sdata_spatialelement_table(full_sdata, "labels2d", "table", "right") + + # all labels are present in table so should return None + element_dict, table = join_sdata_spatialelement_table(full_sdata, "labels2d", "table", "right_exclusive") + assert element_dict["labels2d"] is None + assert table is None + + +def test_points_table_joins(full_sdata): + full_sdata["table"].uns["spatialdata_attrs"]["region"] = "points_0" + full_sdata["table"].obs["region"] = ["points_0"] * 100 + + element_dict, table = join_sdata_spatialelement_table(full_sdata, "points_0", "table", "left") + + # points should have the same number of rows as before and table as well + assert len(element_dict["points_0"]) == 300 + assert all(table.obs["instance_id"] == range(100)) + + full_sdata["table"].obs["instance_id"] = list(reversed(range(100))) + + element_dict, table = join_sdata_spatialelement_table(full_sdata, "points_0", "table", "left", match_rows="left") + assert len(element_dict["points_0"]) == 300 + assert all(table.obs["instance_id"] == range(100)) + + # We have 100 table instances so resulting length of points should be 200 as we started with 300 + element_dict, table = join_sdata_spatialelement_table(full_sdata, "points_0", "table", "left_exclusive") + assert len(element_dict["points_0"]) == 200 + assert table is None + + element_dict, table = join_sdata_spatialelement_table(full_sdata, "points_0", "table", "inner") + assert 
len(element_dict["points_0"]) == 100 + assert all(table.obs["instance_id"] == list(reversed(range(100)))) + + element_dict, table = join_sdata_spatialelement_table(full_sdata, "points_0", "table", "right") + assert len(element_dict["points_0"]) == 100 + assert all(table.obs["instance_id"] == list(reversed(range(100)))) + + element_dict, table = join_sdata_spatialelement_table(full_sdata, "points_0", "table", "right", match_rows="right") + assert all(element_dict["points_0"].index.values.compute() == list(reversed(range(100)))) + assert all(table.obs["instance_id"] == list(reversed(range(100)))) + + element_dict, table = join_sdata_spatialelement_table(full_sdata, "points_0", "table", "right_exclusive") + assert element_dict["points_0"] is None + assert table is None From 055e549a577dc6dfd3fc533a5cc79ec8e7071f4b Mon Sep 17 00:00:00 2001 From: Luca Marconato Date: Mon, 19 Feb 2024 19:31:33 +0100 Subject: [PATCH 05/27] fix tests --- pyproject.toml | 2 +- src/spatialdata/_core/operations/_utils.py | 6 +++--- tests/core/operations/test_spatialdata_operations.py | 7 ++++--- tests/core/operations/test_transform.py | 10 +++++----- 4 files changed, 13 insertions(+), 12 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 416a315a..bab753b4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -182,7 +182,7 @@ unfixable = ["B", "C4", "UP", "BLE", "T20", "RET"] [tool.ruff.lint.pydocstyle] convention = "numpy" -[tool.ruff.per-file-ignores] +[tool.ruff.lint.per-file-ignores] "tests/*" = ["D", "PT", "B024"] "*/__init__.py" = ["F401", "D104", "D107", "E402"] "docs/*" = ["D","B","E","A"] diff --git a/src/spatialdata/_core/operations/_utils.py b/src/spatialdata/_core/operations/_utils.py index 36903298..a5fbd7be 100644 --- a/src/spatialdata/_core/operations/_utils.py +++ b/src/spatialdata/_core/operations/_utils.py @@ -112,7 +112,7 @@ def transform_to_data_extent( **sdata_vector_transformed_inplace.points, } - for _, element_name, element in sdata_raster._gen_elements(): + for _, element_name, element in sdata_raster.gen_spatial_elements(): if isinstance(element, (MultiscaleSpatialImage, SpatialImage)): rasterized = rasterize( element, @@ -128,9 +128,9 @@ def transform_to_data_extent( sdata_to_return_elements[element_name] = rasterized else: sdata_to_return_elements[element_name] = element - if sdata.table is not None: - sdata_to_return_elements["table"] = sdata.table if not maintain_positioning: for el in sdata_to_return_elements.values(): set_transformation(el, transformation={coordinate_system: Identity()}, set_all=True) + for k, v in sdata.tables.items(): + sdata_to_return_elements[k] = v.copy() return SpatialData.from_elements_dict(sdata_to_return_elements) diff --git a/tests/core/operations/test_spatialdata_operations.py b/tests/core/operations/test_spatialdata_operations.py index 8fc4d760..c457f089 100644 --- a/tests/core/operations/test_spatialdata_operations.py +++ b/tests/core/operations/test_spatialdata_operations.py @@ -376,13 +376,14 @@ def test_init_from_elements(full_sdata: SpatialData) -> None: def test_subset(full_sdata: SpatialData) -> None: - element_names = ["image2d", "labels2d", "points_0", "circles", "poly"] + element_names = ["image2d", "points_0", "circles", "poly"] subset0 = full_sdata.subset(element_names) unique_names = set() for _, k, _ in subset0.gen_spatial_elements(): unique_names.add(k) assert "image3d_xarray" in full_sdata.images assert unique_names == set(element_names) + # no table since the labels are not present in the subset assert subset0.table is None adata 
= AnnData( @@ -415,7 +416,7 @@ def test_transform_to_data_extent(full_sdata: SpatialData, maintain_positioning: scale = Scale([2.0], axes=("x",)) translation = Translation([-100.0, 200.0], axes=("x", "y")) sequence = Sequence([rotation, scale, translation]) - for el in full_sdata._gen_elements_values(): + for el in full_sdata._gen_spatial_element_values(): set_transformation(el, sequence, "global") elements = [ "image2d", @@ -431,7 +432,7 @@ def test_transform_to_data_extent(full_sdata: SpatialData, maintain_positioning: sdata = transform_to_data_extent(full_sdata, "global", target_width=1000, maintain_positioning=maintain_positioning) matrices = [] - for el in sdata._gen_elements_values(): + for el in sdata._gen_spatial_element_values(): t = get_transformation(el, to_coordinate_system="global") assert isinstance(t, BaseTransformation) a = t.to_affine_matrix(input_axes=("x", "y", "z"), output_axes=("x", "y", "z")) diff --git a/tests/core/operations/test_transform.py b/tests/core/operations/test_transform.py index 953de15c..e53515fd 100644 --- a/tests/core/operations/test_transform.py +++ b/tests/core/operations/test_transform.py @@ -123,7 +123,7 @@ def _unpad_rasters(sdata: SpatialData) -> SpatialData: def _postpone_transformation( sdata: SpatialData, from_coordinate_system: str, to_coordinate_system: str, transformation: BaseTransformation ): - for element in sdata._gen_elements_values(): + for element in sdata._gen_spatial_element_values(): d = get_transformation(element, get_all=True) assert isinstance(d, dict) assert len(d) == 1 @@ -490,7 +490,7 @@ def test_transform_elements_and_entire_spatial_data_object(full_sdata: SpatialDa scale = Scale([k], axes=("x",)) translation = Translation([k], axes=("x",)) sequence = Sequence([scale, translation]) - for element in full_sdata._gen_elements_values(): + for element in full_sdata._gen_spatial_element_values(): set_transformation(element, sequence, "my_space") transformed_element = full_sdata.transform_element_to_coordinate_system( element, "my_space", maintain_positioning=maintain_positioning @@ -524,7 +524,7 @@ def test_transform_elements_and_entire_spatial_data_object_multi_hop( ): k = 10.0 scale = Scale([k], axes=("x",)) - for element in full_sdata._gen_elements_values(): + for element in full_sdata._gen_spatial_element_values(): set_transformation(element, scale, "my_space") # testing the scenario "element1 -> cs1 <- element2 -> cs2" and transforming element1 to cs2 @@ -535,13 +535,13 @@ def test_transform_elements_and_entire_spatial_data_object_multi_hop( ) # otherwise we have multiple paths to go from my_space to multi_hop_space - for element in full_sdata._gen_elements_values(): + for element in full_sdata._gen_spatial_element_values(): d = get_transformation(element, get_all=True) assert isinstance(d, dict) if "global" in d: remove_transformation(element, "global") - for element in full_sdata._gen_elements_values(): + for element in full_sdata._gen_spatial_element_values(): transformed_element = full_sdata.transform_element_to_coordinate_system( element, "multi_hop_space", maintain_positioning=maintain_positioning ) From c6ae76aebe21a3f326712d57ba661c83b98c7b65 Mon Sep 17 00:00:00 2001 From: Luca Marconato Date: Mon, 19 Feb 2024 19:32:41 +0100 Subject: [PATCH 06/27] fix docs --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 62a93ccf..ae0717d1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,8 @@ and this project adheres to [Semantic Versioning][]. 
## [0.0.x] - tbd +### Added + #### Major - Implemented support in SpatialData for storing multiple tables. These tables can annotate a SpatialElement but not From b8ea945998fc40fc72e4a8863df2deac5fd342e7 Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Mon, 19 Feb 2024 19:52:57 +0100 Subject: [PATCH 07/27] add possibility for custom table name (#459) * add possibility for custom table name * change docstring * updated changelog * added table_name to SpatialData.aggregate() --------- Co-authored-by: Luca Marconato --- CHANGELOG.md | 2 + src/spatialdata/_core/operations/aggregate.py | 7 +- src/spatialdata/_core/spatialdata.py | 2 + tests/core/operations/test_aggregations.py | 111 ++++++++++-------- 4 files changed, 69 insertions(+), 53 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index ae0717d1..3ccf1b99 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -38,6 +38,8 @@ and this project adheres to [Semantic Versioning][]. - Added function get_instance_key_column in SpatialData to get the instance_key column in table.obs. - Added function set_table_annotates_spatialelement in SpatialData to either set or change the annotation metadata of a table in a given SpatialData object. +- Added table_name parameter to the aggegate function to allow users to give a custom table name to table resulting + from aggregation. - Added tables property in SpatialData. - Added tables setter in SpatialData. - Added gen_spatial_elements generator in SpatialData to generate the SpatialElements in a given SpatialData object. diff --git a/src/spatialdata/_core/operations/aggregate.py b/src/spatialdata/_core/operations/aggregate.py index ce1b7ed7..8bdab94f 100644 --- a/src/spatialdata/_core/operations/aggregate.py +++ b/src/spatialdata/_core/operations/aggregate.py @@ -64,6 +64,7 @@ def aggregate( region_key: str = "region", instance_key: str = "instance_id", deepcopy: bool = True, + table_name: str = "table", **kwargs: Any, ) -> SpatialData: """ @@ -125,6 +126,8 @@ def aggregate( deepcopy Whether to deepcopy the shapes in the returned `SpatialData` object. If the shapes are large (e.g. large multiscale labels), you may consider disabling the deepcopy to use a lazy Dask representation. + table_name + The name of the table resulting from the aggregation. kwargs Additional keyword arguments to pass to :func:`xrspatial.zonal_stats`. 
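As a usage sketch of the new argument (assuming the blobs demo dataset from spatialdata.datasets; the table name "points_by_polygons" is an arbitrary illustration, not part of this patch):

import spatialdata as sd
from spatialdata.datasets import blobs

sdata = blobs()
# aggregate the points over the polygons; table_name controls the key under
# which the resulting AnnData is stored in the returned SpatialData object
result = sd.aggregate(
    values=sdata["blobs_points"],
    by=sdata["blobs_polygons"],
    agg_func="sum",
    table_name="points_by_polygons",
)
adata = result.tables["points_by_polygons"]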
@@ -218,6 +221,7 @@ def aggregate( shapes_name = by if isinstance(by, str) else "by" return _create_sdata_from_table_and_shapes( table=adata, + table_name=table_name, shapes_name=shapes_name, shapes=by_, region_key=region_key, @@ -228,6 +232,7 @@ def aggregate( def _create_sdata_from_table_and_shapes( table: ad.AnnData, + table_name: str, shapes: GeoDataFrame | SpatialImage | MultiscaleSpatialImage, shapes_name: str, region_key: str, @@ -247,7 +252,7 @@ def _create_sdata_from_table_and_shapes( if deepcopy: shapes = _deepcopy_geodataframe(shapes) - return SpatialData.from_elements_dict({shapes_name: shapes, "table": table}) + return SpatialData.from_elements_dict({shapes_name: shapes, table_name: table}) def _aggregate_image_by_labels( diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py index cee7c2c5..b44c6c55 100644 --- a/src/spatialdata/_core/spatialdata.py +++ b/src/spatialdata/_core/spatialdata.py @@ -443,6 +443,7 @@ def aggregate( region_key: str = "region", instance_key: str = "instance_id", deepcopy: bool = True, + table_name: str = "table", **kwargs: Any, ) -> SpatialData: """ @@ -475,6 +476,7 @@ def aggregate( region_key=region_key, instance_key=instance_key, deepcopy=deepcopy, + table_name=table_name, **kwargs, ) diff --git a/tests/core/operations/test_aggregations.py b/tests/core/operations/test_aggregations.py index 01aaa2ae..8d1a734a 100644 --- a/tests/core/operations/test_aggregations.py +++ b/tests/core/operations/test_aggregations.py @@ -44,10 +44,10 @@ def test_aggregate_points_by_shapes(sdata_query_aggregation, by_shapes: str, val shapes = sdata[by_shapes] # testing that we can call aggregate with the two equivalent syntaxes for the values argument - result_adata = aggregate(values=points, by=shapes, value_key=value_key, agg_func="sum").table + result_adata = aggregate(values=points, by=shapes, value_key=value_key, agg_func="sum").tables["table"] result_adata_bis = aggregate( values_sdata=sdata, values="points", by=shapes, value_key=value_key, agg_func="sum" - ).table + ).tables["table"] np.testing.assert_equal(result_adata.X.A, result_adata_bis.X.A) # check that the obs of aggregated values are correct @@ -75,12 +75,12 @@ def test_aggregate_points_by_shapes(sdata_query_aggregation, by_shapes: str, val # id_key can be implicit for points points.attrs[PointsModel.ATTRS_KEY][PointsModel.FEATURE_KEY] = value_key - result_adata_implicit = aggregate(values=points, by=shapes, agg_func="sum").table + result_adata_implicit = aggregate(values=points, by=shapes, agg_func="sum").tables["table"] assert_equal(result_adata, result_adata_implicit) # in the categorical case, check that sum and count behave the same if value_key == "categorical_in_ddf": - result_adata_count = aggregate(values=points, by=shapes, value_key=value_key, agg_func="count").table + result_adata_count = aggregate(values=points, by=shapes, value_key=value_key, agg_func="count").tables["table"] assert_equal(result_adata, result_adata_count) # querying multiple values at the same time @@ -91,7 +91,9 @@ def test_aggregate_points_by_shapes(sdata_query_aggregation, by_shapes: str, val aggregate(values=points, by=shapes, value_key=new_value_key, agg_func="sum") else: points["another_" + value_key] = points[value_key] + 10 - result_adata_multiple = aggregate(values=points, by=shapes, value_key=new_value_key, agg_func="sum").table + result_adata_multiple = aggregate(values=points, by=shapes, value_key=new_value_key, agg_func="sum").tables[ + "table" + ] assert 
result_adata_multiple.var_names.to_list() == new_value_key if by_shapes == "by_circles": row = ( @@ -144,12 +146,14 @@ def test_aggregate_shapes_by_shapes( by = _parse_shapes(sdata, by_shapes=by_shapes) values = _parse_shapes(sdata, values_shapes=values_shapes) - result_adata = aggregate(values_sdata=sdata, values=values_shapes, by=by, value_key=value_key, agg_func="sum").table + result_adata = aggregate( + values_sdata=sdata, values=values_shapes, by=by, value_key=value_key, agg_func="sum" + ).tables["table"] # testing that we can call aggregate with the two equivalent syntaxes for the values argument (only relevant when # the values to aggregate are not in the table, for which only one of the two syntaxes is possible) if value_key.endswith("_in_gdf"): - result_adata_bis = aggregate(values=values, by=by, value_key=value_key, agg_func="sum").table + result_adata_bis = aggregate(values=values, by=by, value_key=value_key, agg_func="sum").tables["table"] np.testing.assert_equal(result_adata.X.A, result_adata_bis.X.A) # check that the obs of the aggregated values are correct @@ -162,41 +166,41 @@ def test_aggregate_shapes_by_shapes( if value_key == "numerical_in_var": if values_shapes == "values_circles": if by_shapes == "by_circles": - s = sdata.table[np.array([0, 1, 2, 3]), "numerical_in_var"].X.sum() + s = sdata.tables["table"][np.array([0, 1, 2, 3]), "numerical_in_var"].X.sum() assert np.all(np.isclose(result_adata.X.A, np.array([[s], [0]]))) else: - s0 = sdata.table[np.array([5, 6, 7, 8]), "numerical_in_var"].X.sum() + s0 = sdata.tables["table"][np.array([5, 6, 7, 8]), "numerical_in_var"].X.sum() assert np.all(np.isclose(result_adata.X.A, np.array([[s0], [0], [0], [0], [0]]))) else: if by_shapes == "by_circles": - s = sdata.table[np.array([9, 10, 11, 12]), "numerical_in_var"].X.sum() + s = sdata.tables["table"][np.array([9, 10, 11, 12]), "numerical_in_var"].X.sum() assert np.all(np.isclose(result_adata.X.A, np.array([[s], [0]]))) else: - s0 = sdata.table[np.array([14, 15, 16, 17]), "numerical_in_var"].X.sum() - s1 = sdata.table[np.array([20]), "numerical_in_var"].X.sum() - s2 = sdata.table[np.array([20]), "numerical_in_var"].X.sum() + s0 = sdata.tables["table"][np.array([14, 15, 16, 17]), "numerical_in_var"].X.sum() + s1 = sdata.tables["table"][np.array([20]), "numerical_in_var"].X.sum() + s2 = sdata.tables["table"][np.array([20]), "numerical_in_var"].X.sum() s3 = 0 - s4 = sdata.table[np.array([18, 19]), "numerical_in_var"].X.sum() + s4 = sdata.tables["table"][np.array([18, 19]), "numerical_in_var"].X.sum() assert np.all(np.isclose(result_adata.X.A, np.array([[s0], [s1], [s2], [s3], [s4]]))) elif value_key == "numerical_in_obs": # these cases are basically identical to the one above if values_shapes == "values_circles": if by_shapes == "by_circles": - s = sdata.table[np.array([0, 1, 2, 3]), :].obs["numerical_in_obs"].sum() + s = sdata.tables["table"][np.array([0, 1, 2, 3]), :].obs["numerical_in_obs"].sum() assert np.all(np.isclose(result_adata.X.A, np.array([[s], [0]]))) else: - s0 = sdata.table[np.array([5, 6, 7, 8]), :].obs["numerical_in_obs"].sum() + s0 = sdata.tables["table"][np.array([5, 6, 7, 8]), :].obs["numerical_in_obs"].sum() assert np.all(np.isclose(result_adata.X.A, np.array([[s0], [0], [0], [0], [0]]))) else: if by_shapes == "by_circles": - s = sdata.table[np.array([9, 10, 11, 12]), :].obs["numerical_in_obs"].sum() + s = sdata.tables["table"][np.array([9, 10, 11, 12]), :].obs["numerical_in_obs"].sum() assert np.all(np.isclose(result_adata.X.A, np.array([[s], [0]]))) else: -
s0 = sdata.table[np.array([14, 15, 16, 17]), :].obs["numerical_in_obs"].sum() - s1 = sdata.table[np.array([20]), :].obs["numerical_in_obs"].sum() - s2 = sdata.table[np.array([20]), :].obs["numerical_in_obs"].sum() + s0 = sdata.tables["table"][np.array([14, 15, 16, 17]), :].obs["numerical_in_obs"].sum() + s1 = sdata.tables["table"][np.array([20]), :].obs["numerical_in_obs"].sum() + s2 = sdata.tables["table"][np.array([20]), :].obs["numerical_in_obs"].sum() s3 = 0 - s4 = sdata.table[np.array([18, 19]), :].obs["numerical_in_obs"].sum() + s4 = sdata.tables["table"][np.array([18, 19]), :].obs["numerical_in_obs"].sum() assert np.all(np.isclose(result_adata.X.A, np.array([[s0], [s1], [s2], [s3], [s4]]))) elif value_key == "numerical_in_gdf": if values_shapes == "values_circles": @@ -252,7 +256,7 @@ def test_aggregate_shapes_by_shapes( if value_key in ["categorical_in_obs", "categorical_in_gdf"]: result_adata_count = aggregate( values_sdata=sdata, values=values_shapes, by=by, value_key=value_key, agg_func="count" - ).table + ).tables["table"] assert_equal(result_adata, result_adata_count) # querying multiple values at the same time @@ -263,20 +267,20 @@ def test_aggregate_shapes_by_shapes( aggregate(values_sdata=sdata, values=values_shapes, by=by, value_key=new_value_key, agg_func="sum") else: if value_key == "numerical_in_obs": - sdata.table.obs["another_numerical_in_obs"] = 1.0 + sdata.tables["table"].obs["another_numerical_in_obs"] = 1.0 elif value_key == "numerical_in_gdf": values["another_numerical_in_gdf"] = 1.0 else: assert value_key == "numerical_in_var" - new_var = pd.concat((sdata.table.var, pd.DataFrame(index=["another_numerical_in_var"]))) - new_x = np.concatenate((sdata.table.X, np.ones_like(sdata.table.X[:, :1])), axis=1) - new_table = AnnData(X=new_x, obs=sdata.table.obs, var=new_var, uns=sdata.table.uns) - del sdata.table - sdata.table = new_table + new_var = pd.concat((sdata.tables["table"].var, pd.DataFrame(index=["another_numerical_in_var"]))) + new_x = np.concatenate((sdata.tables["table"].X, np.ones_like(sdata.tables["table"].X[:, :1])), axis=1) + new_table = AnnData(X=new_x, obs=sdata.tables["table"].obs, var=new_var, uns=sdata.tables["table"].uns) + del sdata.tables["table"] + sdata.tables["table"] = new_table result_adata = aggregate( values_sdata=sdata, values=values_shapes, by=by, value_key=new_value_key, agg_func="sum" - ).table + ).tables["table"] assert result_adata.var_names.to_list() == new_value_key # since we added only columns of 1., we just have 4 cases to check all the aggregations, and not 12 like before @@ -327,15 +331,16 @@ def test_aggregate_image_by_labels(labels_blobs, image_schema, labels_schema) -> image = image_schema.parse(image) labels = labels_schema.parse(labels_blobs) - out = aggregate(values=image, by=labels, agg_func="mean").table + out_sdata = aggregate(values=image, by=labels, agg_func="mean", table_name="aggregation") + out = out_sdata.tables["aggregation"] assert len(out) + 1 == len(np.unique(labels_blobs)) assert isinstance(out, AnnData) np.testing.assert_array_equal(out.var_names, [f"channel_{i}_mean" for i in image.coords["c"].values]) - out = aggregate(values=image, by=labels, agg_func=["mean", "sum", "count"]).table + out = aggregate(values=image, by=labels, agg_func=["mean", "sum", "count"]).tables["table"] assert len(out) + 1 == len(np.unique(labels_blobs)) - out = aggregate(values=image, by=labels, zone_ids=[1, 2, 3]).table + out = aggregate(values=image, by=labels, zone_ids=[1, 2, 3]).tables["table"] assert len(out) == 3 @@ -353,7 
+358,7 @@ def test_aggregate_requiring_alignment(sdata_blobs: SpatialData, values, by) -> assert by.attrs["transform"] is not values.attrs["transform"] sdata = SpatialData.init_from_elements({"values": values, "by": by}) - out0 = aggregate(values=values, by=by, agg_func="sum").table + out0 = aggregate(values=values, by=by, agg_func="sum").tables["table"] theta = np.pi / 7 affine = Affine( @@ -375,12 +380,12 @@ def test_aggregate_requiring_alignment(sdata_blobs: SpatialData, values, by) -> # both values and by map to the "other" coordinate system, but they are not aligned set_transformation(by, Identity(), "other") - out1 = aggregate(values=values, by=by, target_coordinate_system="other", agg_func="sum").table + out1 = aggregate(values=values, by=by, target_coordinate_system="other", agg_func="sum").tables["table"] assert not np.allclose(out0.X.A, out1.X.A) # both values and by map to the "other" coordinate system, and they are aligned set_transformation(by, affine, "other") - out2 = aggregate(values=values, by=by, target_coordinate_system="other", agg_func="sum").table + out2 = aggregate(values=values, by=by, target_coordinate_system="other", agg_func="sum").tables["table"] assert np.allclose(out0.X.A, out2.X.A) # actually transforming the data still leads to a correct result @@ -388,7 +393,9 @@ def test_aggregate_requiring_alignment(sdata_blobs: SpatialData, values, by) -> sdata2 = SpatialData.init_from_elements({"values": sdata["values"], "by": transformed_sdata["by"]}) # let's take values from the original sdata (non-transformed but aligned to 'other'); let's take by from the # transformed sdata - out3 = aggregate(values=sdata["values"], by=sdata2["by"], target_coordinate_system="other", agg_func="sum").table + out3 = aggregate(values=sdata["values"], by=sdata2["by"], target_coordinate_system="other", agg_func="sum").tables[ + "table" + ] assert np.allclose(out0.X.A, out3.X.A) @@ -407,7 +414,7 @@ def test_aggregate_considering_fractions_single_values( sdata = sdata_query_aggregation values = sdata[values_name] by = sdata[by_name] - result_adata = aggregate(values=values, by=by, value_key=value_key, agg_func="sum", fractions=True).table + result_adata = aggregate(values=values, by=by, value_key=value_key, agg_func="sum", fractions=True).tables["table"] # to manually compute the fractions of overlap that we use to test that aggregate() works values = circles_to_polygons(values) values["__index"] = values.index @@ -475,11 +482,11 @@ def test_aggregate_considering_fractions_multiple_values( sdata_query_aggregation: SpatialData, by_name, values_name, value_key ) -> None: sdata = sdata_query_aggregation - new_var = pd.concat((sdata.table.var, pd.DataFrame(index=["another_numerical_in_var"]))) - new_x = np.concatenate((sdata.table.X, np.ones_like(sdata.table.X[:, :1])), axis=1) - new_table = AnnData(X=new_x, obs=sdata.table.obs, var=new_var, uns=sdata.table.uns) - del sdata.table - sdata.table = new_table + new_var = pd.concat((sdata.tables["table"].var, pd.DataFrame(index=["another_numerical_in_var"]))) + new_x = np.concatenate((sdata.tables["table"].X, np.ones_like(sdata.tables["table"].X[:, :1])), axis=1) + new_table = AnnData(X=new_x, obs=sdata.tables["table"].obs, var=new_var, uns=sdata.tables["table"].uns) + del sdata.tables["table"] + sdata.tables["table"] = new_table out = aggregate( values_sdata=sdata, values="values_circles", by=sdata["by_circles"], value_key=["numerical_in_var", "another_numerical_in_var"], agg_func="sum",
fractions=True, - ).table + ).tables["table"] overlaps = np.array([0.655781239649211, 1.0000000000000002, 1.0000000000000004, 0.1349639285777728]) - row0 = np.sum(sdata.table.X[[0, 1, 2, 3], :] * overlaps.reshape(-1, 1), axis=0) + row0 = np.sum(sdata.tables["table"].X[[0, 1, 2, 3], :] * overlaps.reshape(-1, 1), axis=0) assert np.all(np.isclose(out.X.A, np.array([row0, [0, 0]]))) @@ -530,19 +537,19 @@ def test_aggregate_spatialdata(sdata_blobs: SpatialData) -> None: sdata2 = sdata_blobs.aggregate(values="blobs_points", by=sdata_blobs["blobs_polygons"], agg_func="sum") sdata3 = sdata_blobs.aggregate(values=sdata_blobs["blobs_points"], by=sdata_blobs["blobs_polygons"], agg_func="sum") - assert_equal(sdata0.table, sdata1.table) - assert_equal(sdata2.table, sdata3.table) + assert_equal(sdata0.tables["table"], sdata1.tables["table"]) + assert_equal(sdata2.tables["table"], sdata3.tables["table"]) # in sdata2 the name of the "by" region was not passed, so a default one is used - assert sdata2.table.obs["region"].value_counts()["by"] == 3 + assert sdata2.tables["table"].obs["region"].value_counts()["by"] == 3 # let's change it so we can make the objects comparable - sdata2.table.obs["region"] = "blobs_polygons" - sdata2.table.obs["region"] = sdata2.table.obs["region"].astype("category") - sdata2.table.uns[TableModel.ATTRS_KEY]["region"] = "blobs_polygons" - assert_equal(sdata0.table, sdata2.table) + sdata2.tables["table"].obs["region"] = "blobs_polygons" + sdata2.tables["table"].obs["region"] = sdata2.tables["table"].obs["region"].astype("category") + sdata2.tables["table"].uns[TableModel.ATTRS_KEY]["region"] = "blobs_polygons" + assert_equal(sdata0.tables["table"], sdata2.tables["table"]) assert len(sdata0.shapes["blobs_polygons"]) == 3 - assert sdata0.table.shape == (3, 2) + assert sdata0.tables["table"].shape == (3, 2) def test_aggregate_deepcopy(sdata_blobs: SpatialData) -> None: From e05ba1a7bdc249044872ca009baa79e44408063f Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Mon, 19 Feb 2024 20:01:45 +0100 Subject: [PATCH 08/27] Update locate values (#460) * add possibility for custom table name * change docstring * updated changelog * added table_name parameter * update changelog --------- Co-authored-by: LucaMarconato <2664412+LucaMarconato@users.noreply.github.com> --- CHANGELOG.md | 3 ++- .../_core/query/relational_query.py | 20 +++++++++++++------ 2 files changed, 16 insertions(+), 7 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3ccf1b99..83719405 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -38,8 +38,9 @@ and this project adheres to [Semantic Versioning][]. - Added function get_instance_key_column in SpatialData to get the instance_key column in table.obs. - Added function set_table_annotates_spatialelement in SpatialData to either set or change the annotation metadata of a table in a given SpatialData object. -- Added table_name parameter to the aggegate function to allow users to give a custom table name to table resulting +- Added table_name parameter to the aggregate function to allow users to give a custom table name to table resulting from aggregation. +- Added table_name parameter to the get_values function. - Added tables property in SpatialData. - Added tables setter in SpatialData. - Added gen_spatial_elements generator in SpatialData to generate the SpatialElements in a given SpatialData object. 
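A sketch of the call this patch enables, assuming the blobs demo dataset, whose table named "table" annotates the labels element "blobs_labels" (the relational_query.py diff implementing it follows):

from spatialdata import get_values
from spatialdata.datasets import blobs

sdata = blobs()
# the value key is looked up in the element's own columns and in the obs/var
# of the named annotating table; with several tables per SpatialData object,
# table_name makes the lookup unambiguous
df = get_values(value_key="instance_id", sdata=sdata, element_name="blobs_labels", table_name="table")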
diff --git a/src/spatialdata/_core/query/relational_query.py b/src/spatialdata/_core/query/relational_query.py index 9e537cc7..247bcfa3 100644 --- a/src/spatialdata/_core/query/relational_query.py +++ b/src/spatialdata/_core/query/relational_query.py @@ -533,7 +533,7 @@ def join_sdata_spatialelement_table( return elements_dict, table -def match_table_to_element(sdata: SpatialData, element_name: str) -> AnnData: +def match_table_to_element(sdata: SpatialData, element_name: str, table_name: str = "table") -> AnnData: """ Filter the table and reorders the rows to match the instances (rows/labels) of the specified SpatialElement. @@ -543,6 +543,8 @@ def match_table_to_element(sdata: SpatialData, element_name: str) -> AnnData: SpatialData object element_name The name of the spatial elements to be joined with the table. + table_name + The name of the table to match to the element. Returns ------- @@ -559,11 +561,11 @@ def match_table_to_element(sdata: SpatialData, element_name: str) -> AnnData: # table_name = "table" # _, table = join_sdata_spatialelement_table(sdata, element_name, table_name, "left", match_rows="left") # return table - assert sdata.table is not None, "No table found in the SpatialData" + assert sdata[table_name] is not None, "No table found in the SpatialData" element_type, _, element = sdata._find_element(element_name) assert element_type in ["labels", "shapes"], f"Element {element_name} ({element_type}) is not supported" elements_dict = {element_type: {element_name: element}} - return _filter_table_by_elements(sdata.table, elements_dict, match_rows=True) + return _filter_table_by_elements(sdata[table_name], elements_dict, match_rows=True) def match_element_to_table( @@ -611,6 +613,7 @@ def _locate_value( element: SpatialElement | None = None, sdata: SpatialData | None = None, element_name: str | None = None, + table_name: str = "table", ) -> list[_ValueOrigin]: el = _get_element(element=element, sdata=sdata, element_name=element_name) origins = [] @@ -625,7 +628,7 @@ def _locate_value( # adding from the obs columns or var if model in [ShapesModel, Labels2DModel, Labels3DModel] and sdata is not None: - table = sdata.table + table = sdata[table_name] if table is not None: # check if the table is annotating the element region = table.uns[TableModel.ATTRS_KEY][TableModel.REGION_KEY] @@ -646,6 +649,7 @@ def get_values( element: SpatialElement | None = None, sdata: SpatialData | None = None, element_name: str | None = None, + table_name: str = "table", ) -> pd.DataFrame: """ Get the values from the element, from any location: df columns, obs or var columns (table). @@ -660,6 +664,8 @@ def get_values( SpatialData object; either element or (sdata, element_name) must be provided element_name Name of the element; either element or (sdata, element_name) must be provided + table_name + Name of the table to get the values from. 
Returns ------- @@ -674,7 +680,9 @@ def get_values( value_keys = [value_key] if isinstance(value_key, str) else value_key locations = [] for vk in value_keys: - origins = _locate_value(value_key=vk, element=element, sdata=sdata, element_name=element_name) + origins = _locate_value( + value_key=vk, element=element, sdata=sdata, element_name=element_name, table_name=table_name + ) if len(origins) > 1: raise ValueError( f"{vk} has been found in multiple locations of (element, sdata, element_name) = " @@ -706,7 +714,7 @@ def get_values( return df if sdata is not None: assert element_name is not None - matched_table = match_table_to_element(sdata=sdata, element_name=element_name) + matched_table = match_table_to_element(sdata=sdata, element_name=element_name, table_name=table_name) region_key = matched_table.uns[TableModel.ATTRS_KEY][TableModel.REGION_KEY_KEY] instance_key = matched_table.uns[TableModel.ATTRS_KEY][TableModel.INSTANCE_KEY] obs = matched_table.obs From ea0989d7c17bc8bbfd1f44f999790a36d0a2c8ea Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Tue, 20 Feb 2024 15:43:45 +0100 Subject: [PATCH 09/27] Filter table annotate (#462) * add get_element_annotator * add docstring * add test --- .../_core/query/relational_query.py | 24 +++++++++++++++++++ tests/core/query/test_relational_query.py | 20 +++++++++++++++- 2 files changed, 43 insertions(+), 1 deletion(-) diff --git a/src/spatialdata/_core/query/relational_query.py b/src/spatialdata/_core/query/relational_query.py index 247bcfa3..05858d42 100644 --- a/src/spatialdata/_core/query/relational_query.py +++ b/src/spatialdata/_core/query/relational_query.py @@ -30,6 +30,30 @@ ) +def _get_element_annotators(sdata: SpatialData, element_name: str) -> set[str]: + """ + Retrieve names of tables that annotate a SpatialElement in a SpatialData object. + + Parameters + ---------- + sdata + SpatialData object. + element_name + The name of the SpatialElement. + + Returns + ------- + The names of the tables annotating the SpatialElement. + """ + table_names = set() + for name, table in sdata.tables.items(): + if table.uns.get(TableModel.ATTRS_KEY): + regions, _, _ = get_table_keys(table) + if element_name in regions: + table_names.add(name) + return table_names + + def _filter_table_by_element_names(table: AnnData | None, element_names: str | list[str]) -> AnnData | None: """ Filter an AnnData table to keep only the rows that are in the coordinate system. 
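A short sketch of the new helper, mirroring the test added below; it assumes the blobs demo dataset, whose single table annotates the labels element "blobs_labels", and the copied table exists only for illustration:

from spatialdata._core.query.relational_query import _get_element_annotators
from spatialdata.datasets import blobs

sdata = blobs()
# the demo table annotates "blobs_labels"; elements without an annotating
# table yield an empty set
assert _get_element_annotators(sdata, "blobs_labels") == {"table"}
assert _get_element_annotators(sdata, "blobs_points") == set()

# a second table targeting the same element is discovered as well
sdata.tables["another_table"] = sdata.tables["table"].copy()
assert _get_element_annotators(sdata, "blobs_labels") == {"another_table", "table"}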
diff --git a/tests/core/query/test_relational_query.py b/tests/core/query/test_relational_query.py index c39d0e11..d4af878a 100644 --- a/tests/core/query/test_relational_query.py +++ b/tests/core/query/test_relational_query.py @@ -3,7 +3,12 @@ import pytest from anndata import AnnData from spatialdata import get_values, match_table_to_element -from spatialdata._core.query.relational_query import _locate_value, _ValueOrigin, join_sdata_spatialelement_table +from spatialdata._core.query.relational_query import ( + _get_element_annotators, + _locate_value, + _ValueOrigin, + join_sdata_spatialelement_table, +) from spatialdata.models.models import TableModel @@ -404,3 +409,16 @@ def test_points_table_joins(full_sdata): element_dict, table = join_sdata_spatialelement_table(full_sdata, "points_0", "table", "right_exclusive") assert element_dict["points_0"] is None assert table is None + + +def test_get_element_annotators(full_sdata): + names = _get_element_annotators(full_sdata, "points_0") + assert len(names) == 0 + + names = _get_element_annotators(full_sdata, "labels2d") + assert names == {"table"} + + another_table = full_sdata.tables["table"].copy() + full_sdata.tables["another_table"] = another_table + names = _get_element_annotators(full_sdata, "labels2d") + assert names == {"another_table", "table"} From f03ba37206c239e232ddfd32496add918d12c56e Mon Sep 17 00:00:00 2001 From: Luca Marconato Date: Thu, 22 Feb 2024 18:23:16 +0100 Subject: [PATCH 10/27] wip get_centroids --- docs/api.md | 1 + pyproject.toml | 2 +- src/spatialdata/__init__.py | 2 + src/spatialdata/_core/centroids.py | 89 ++++++++++++++++++ src/spatialdata/_core/operations/transform.py | 4 + tests/core/test_centroids.py | 90 +++++++++++++++++++ tests/core/test_data_extent.py | 2 +- 7 files changed, 188 insertions(+), 2 deletions(-) create mode 100644 src/spatialdata/_core/centroids.py create mode 100644 tests/core/test_centroids.py diff --git a/docs/api.md b/docs/api.md index 6ca058a2..2388c9f6 100644 --- a/docs/api.md +++ b/docs/api.md @@ -27,6 +27,7 @@ Operations on `SpatialData` objects. 
polygon_query get_values get_extent + get_centroids match_table_to_element concatenate transform diff --git a/pyproject.toml b/pyproject.toml index 416a315a..bab753b4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -182,7 +182,7 @@ unfixable = ["B", "C4", "UP", "BLE", "T20", "RET"] [tool.ruff.lint.pydocstyle] convention = "numpy" -[tool.ruff.per-file-ignores] +[tool.ruff.lint.per-file-ignores] "tests/*" = ["D", "PT", "B024"] "*/__init__.py" = ["F401", "D104", "D107", "E402"] "docs/*" = ["D","B","E","A"] diff --git a/src/spatialdata/__init__.py b/src/spatialdata/__init__.py index 1781b3b0..ced697ca 100644 --- a/src/spatialdata/__init__.py +++ b/src/spatialdata/__init__.py @@ -25,6 +25,7 @@ "match_table_to_element", "SpatialData", "get_extent", + "get_centroids", "read_zarr", "unpad_raster", "save_transformations", @@ -33,6 +34,7 @@ ] from spatialdata import dataloader, models, transformations +from spatialdata._core.centroids import get_centroids from spatialdata._core.concatenate import concatenate from spatialdata._core.data_extent import are_extents_equal, get_extent from spatialdata._core.operations.aggregate import aggregate diff --git a/src/spatialdata/_core/centroids.py b/src/spatialdata/_core/centroids.py new file mode 100644 index 00000000..eced3fcf --- /dev/null +++ b/src/spatialdata/_core/centroids.py @@ -0,0 +1,89 @@ +from __future__ import annotations + +# Functions to compute the bounding box describing the extent of a SpatialElement or SpatialData object +from functools import singledispatch + +from dask.dataframe.core import DataFrame as DaskDataFrame +from geopandas import GeoDataFrame +from multiscale_spatial_image import MultiscaleSpatialImage +from spatial_image import SpatialImage + +from spatialdata._core.operations.transform import transform +from spatialdata.models import get_axes_names +from spatialdata.models._utils import SpatialElement +from spatialdata.models.models import Image2DModel, Image3DModel, Labels2DModel, Labels3DModel, PointsModel, get_model +from spatialdata.transformations.operations import get_transformation +from spatialdata.transformations.transformations import BaseTransformation + +BoundingBoxDescription = dict[str, tuple[float, float]] + + +def _validate_coordinate_system(e: SpatialElement, coordinate_system: str) -> None: + d = get_transformation(e, get_all=True) + assert isinstance(d, dict) + assert coordinate_system in d, ( + f"No transformation to coordinate system {coordinate_system} is available for the given element.\n" + f"Available coordinate systems: {list(d.keys())}" + ) + + +@singledispatch +def get_centroids( + e: SpatialElement, + coordinate_system: str = "global", +) -> DaskDataFrame: + """ + Get the centroids of the geometries contained in a SpatialElement, as a new Points element. + + Parameters + ---------- + e + The SpatialElement. Only points, shapes (circles, polygons and multipolygons) and labels are supported. + coordinate_system + The coordinate system in which the centroids are computed. 
+ """ + raise ValueError(f"The object type {type(e)} is not supported.") + + +@get_centroids.register(SpatialImage) +@get_centroids.register(MultiscaleSpatialImage) +def _( + e: SpatialImage | MultiscaleSpatialImage, + coordinate_system: str = "global", +) -> DaskDataFrame: + model = get_model(e) + if model in [Image2DModel, Image3DModel]: + raise ValueError("Cannot compute centroids for images.") + assert model in [Labels2DModel, Labels3DModel] + + _validate_coordinate_system(e, coordinate_system) + + +# def _get_extent_of_shapes(e: GeoDataFrame) -> BoundingBoxDescription: +# # remove potentially empty geometries +# e_temp = e[e["geometry"].apply(lambda geom: not geom.is_empty)] +# assert len(e_temp) > 0, "Cannot compute extent of an empty collection of geometries." +# +# # separate points from (multi-)polygons +# first_geometry = e_temp["geometry"].iloc[0] +# if isinstance(first_geometry, Point): +# return _get_extent_of_circles(e) +# assert isinstance(first_geometry, (Polygon, MultiPolygon)) +# return _get_extent_of_polygons_multipolygons(e) + + +@get_centroids.register(GeoDataFrame) +def _(e: GeoDataFrame, coordinate_system: str = "global") -> DaskDataFrame: + _validate_coordinate_system(e, coordinate_system) + + +@get_centroids.register(DaskDataFrame) +def _(e: DaskDataFrame, coordinate_system: str = "global") -> DaskDataFrame: + _validate_coordinate_system(e, coordinate_system) + axes = get_axes_names(e) + assert axes in [("x", "y"), ("x", "y", "z")] + coords = e[list(axes)].compute().values + t = get_transformation(e, coordinate_system) + assert isinstance(t, BaseTransformation) + centroids = PointsModel.parse(coords, transformations={coordinate_system: t}) + return transform(centroids, to_coordinate_system=coordinate_system) diff --git a/src/spatialdata/_core/operations/transform.py b/src/spatialdata/_core/operations/transform.py index 86704629..b0e356b8 100644 --- a/src/spatialdata/_core/operations/transform.py +++ b/src/spatialdata/_core/operations/transform.py @@ -331,6 +331,7 @@ def _( transformation, raster_translation=raster_translation, maintain_positioning=maintain_positioning, + to_coordinate_system=to_coordinate_system, ) transformed_data = compute_coordinates(transformed_data) schema().validate(transformed_data) @@ -404,6 +405,7 @@ def _( transformation, raster_translation=raster_translation, maintain_positioning=maintain_positioning, + to_coordinate_system=to_coordinate_system, ) transformed_data = compute_coordinates(transformed_data) schema().validate(transformed_data) @@ -447,6 +449,7 @@ def _( transformation, raster_translation=None, maintain_positioning=maintain_positioning, + to_coordinate_system=to_coordinate_system, ) PointsModel.validate(transformed) return transformed @@ -490,6 +493,7 @@ def _( transformation, raster_translation=None, maintain_positioning=maintain_positioning, + to_coordinate_system=to_coordinate_system, ) ShapesModel.validate(transformed_data) return transformed_data diff --git a/tests/core/test_centroids.py b/tests/core/test_centroids.py new file mode 100644 index 00000000..e2784239 --- /dev/null +++ b/tests/core/test_centroids.py @@ -0,0 +1,90 @@ +import numpy as np +import pytest +from anndata import AnnData +from numpy.random import default_rng +from spatialdata._core.centroids import get_centroids +from spatialdata.models import TableModel, get_axes_names +from spatialdata.transformations import Identity, get_transformation, set_transformation + +from tests.core.operations.test_transform import _get_affine + +RNG = default_rng(42) + + 
+@pytest.mark.parametrize("coordinate_system", ["global", "aligned"]) +@pytest.mark.parametrize("is_3d", [False, True]) +def test_get_centroids_points(points, coordinate_system: str, is_3d: bool): + element = points["points_0"] + + affine = _get_affine() + # by default, the coordinate system is global and the points are 2D; let's modify the points as instructed by the + # test arguments + if coordinate_system == "aligned": + set_transformation(element, transformation=affine, to_coordinate_system=coordinate_system) + if is_3d: + element["z"] = element["x"] + + axes = get_axes_names(element) + centroids = get_centroids(element, coordinate_system=coordinate_system) + + # the axes of the centroids should be the same as the axes of the element + assert centroids.columns.tolist() == list(axes) + + # the centroids should not contain extra columns + assert "genes" in element.columns and "genes" not in centroids.columns + + # the centroids transformation to the target coordinate system should be an Identity because the transformation has + # already been applied + assert get_transformation(centroids, to_coordinate_system=coordinate_system) == Identity() + + # let's check the values + if coordinate_system == "global": + assert np.array_equal(centroids.compute().values, element[list(axes)].compute().values) + else: + matrix = affine.to_affine_matrix(input_axes=axes, output_axes=axes) + centroids_untransformed = element[list(axes)].compute().values + n = len(axes) + centroids_transformed = np.dot(centroids_untransformed, matrix[:n, :n].T) + matrix[:n, n] + assert np.allclose(centroids.compute().values, centroids_transformed) + + +def test_get_centroids_circles(): + pass + + +def test_get_centroids_polygons(): + pass + + +def test_get_centroids_multipolygons(): + pass + + +def test_get_centroids_single_scale_labels(): + pass + + +def test_get_centroids_multiscale_labels(): + pass + + +def test_get_centroids_invalid_element(images): + # cannot compute centroids for images + with pytest.raises(ValueError, match="Cannot compute centroids for images."): + get_centroids(images["image2d"]) + + # cannot compute centroids for tables + N = 10 + adata = TableModel.parse( + AnnData(X=RNG.random((N, N)), obs={"region": ["dummy" for _ in range(N)], "instance_id": np.arange(N)}), + region="dummy", + region_key="region", + instance_key="instance_id", + ) + with pytest.raises(ValueError, match="The object type is not supported."): + get_centroids(adata) + + +def test_get_centroids_invalid_coordinate_system(points): + with pytest.raises(AssertionError, match="No transformation to coordinate system"): + get_centroids(points["points_0"], coordinate_system="invalid") diff --git a/tests/core/test_data_extent.py b/tests/core/test_data_extent.py index bf03a15a..d7304ddf 100644 --- a/tests/core/test_data_extent.py +++ b/tests/core/test_data_extent.py @@ -265,7 +265,7 @@ def test_get_extent_affine_circles(): gdf = ShapesModel.parse(gdf, transformations={"transformed": affine}) transformed_bounding_box = transform(gdf, to_coordinate_system="transformed") - transformed_bounding_box_extent = get_extent(transformed_bounding_box) + transformed_bounding_box_extent = get_extent(transformed_bounding_box, coordinate_system="transformed") assert transformed_axes == list(transformed_bounding_box_extent.keys()) for ax in transformed_axes: From 0c7293bcfefccfbae71773fbb43c5255b1f27ac0 Mon Sep 17 00:00:00 2001 From: Luca Marconato Date: Thu, 22 Feb 2024 20:35:12 +0100 Subject: [PATCH 11/27] implemented get_centroids() --- CHANGELOG.md | 1 + 
src/spatialdata/_core/centroids.py | 83 +++++++++++++++++++++----- tests/core/test_centroids.py | 93 +++++++++++++++++++++++++----- 3 files changed, 150 insertions(+), 27 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d4c8f314..0d63d4d6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,6 +18,7 @@ and this project adheres to [Semantic Versioning][]. - added utils function: are_extents_equal() - added utils function: postpone_transformation() - added utils function: remove_transformations_to_coordinate_system() +- added utils function: get_centroids() ### Minor diff --git a/src/spatialdata/_core/centroids.py b/src/spatialdata/_core/centroids.py index eced3fcf..4bafd458 100644 --- a/src/spatialdata/_core/centroids.py +++ b/src/spatialdata/_core/centroids.py @@ -1,11 +1,15 @@ from __future__ import annotations -# Functions to compute the bounding box describing the extent of a SpatialElement or SpatialData object +from collections import defaultdict from functools import singledispatch +import dask.array as da +import pandas as pd +import xarray as xr from dask.dataframe.core import DataFrame as DaskDataFrame from geopandas import GeoDataFrame from multiscale_spatial_image import MultiscaleSpatialImage +from shapely import MultiPolygon, Point, Polygon from spatial_image import SpatialImage from spatialdata._core.operations.transform import transform @@ -41,10 +45,51 @@ def get_centroids( The SpatialElement. Only points, shapes (circles, polygons and multipolygons) and labels are supported. coordinate_system The coordinate system in which the centroids are computed. + + Notes + ----- + For :class:`~shapely.Multipolygon`s, the centroids are the average of the centroids of the polygons that constitute + each :class:`~shapely.Multipolygon`. """ raise ValueError(f"The object type {type(e)} is not supported.") +def _get_centroids_for_axis(xdata: xr.DataArray, axis: str) -> pd.DataFrame: + """ + Compute the component "axis" of the centroid of each label as a weighted average of the xarray coordinates. + + Parameters + ---------- + xdata + The xarray DataArray containing the labels. + axis + The axis for which the centroids are computed. + + Returns + ------- + pd.DataFrame + A DataFrame containing one column, named after "axis", with the centroids of the labels along that axis. + The index of the DataFrame is the collection of label values, sorted ascendingly. 
+ """ + centroids: dict[int, float] = defaultdict(float) + for i in xdata[axis]: + portion = xdata.sel(**{axis: i}).data + u = da.unique(portion, return_counts=True) + labels_values = u[0].compute() + counts = u[1].compute() + for j in range(len(labels_values)): + label_value = labels_values[j] + count = counts[j] + centroids[label_value] += count * i.values.item() + + all_labels_values, all_labels_counts = da.unique(xdata.data, return_counts=True) + all_labels = dict(zip(all_labels_values.compute(), all_labels_counts.compute())) + for label_value in centroids: + centroids[label_value] /= all_labels[label_value] + centroids = dict(sorted(centroids.items(), key=lambda x: x[0])) + return pd.DataFrame({axis: centroids.values()}, index=list(centroids.keys())) + + @get_centroids.register(SpatialImage) @get_centroids.register(MultiscaleSpatialImage) def _( @@ -55,26 +100,35 @@ def _( if model in [Image2DModel, Image3DModel]: raise ValueError("Cannot compute centroids for images.") assert model in [Labels2DModel, Labels3DModel] - _validate_coordinate_system(e, coordinate_system) + if isinstance(e, MultiscaleSpatialImage): + assert len(e["scale0"]) == 1 + e = SpatialImage(next(iter(e["scale0"].values()))) -# def _get_extent_of_shapes(e: GeoDataFrame) -> BoundingBoxDescription: -# # remove potentially empty geometries -# e_temp = e[e["geometry"].apply(lambda geom: not geom.is_empty)] -# assert len(e_temp) > 0, "Cannot compute extent of an empty collection of geometries." -# -# # separate points from (multi-)polygons -# first_geometry = e_temp["geometry"].iloc[0] -# if isinstance(first_geometry, Point): -# return _get_extent_of_circles(e) -# assert isinstance(first_geometry, (Polygon, MultiPolygon)) -# return _get_extent_of_polygons_multipolygons(e) + dfs = [] + for axis in get_axes_names(e): + dfs.append(_get_centroids_for_axis(e, axis)) + df = pd.concat(dfs, axis=1) + t = get_transformation(e, coordinate_system) + centroids = PointsModel.parse(df, transformations={coordinate_system: t}) + return transform(centroids, to_coordinate_system=coordinate_system) @get_centroids.register(GeoDataFrame) def _(e: GeoDataFrame, coordinate_system: str = "global") -> DaskDataFrame: _validate_coordinate_system(e, coordinate_system) + t = get_transformation(e, coordinate_system) + assert isinstance(t, BaseTransformation) + # separate points from (multi-)polygons + first_geometry = e["geometry"].iloc[0] + if isinstance(first_geometry, Point): + xy = e.geometry.get_coordinates().values + else: + assert isinstance(first_geometry, (Polygon, MultiPolygon)) + xy = e.centroid.get_coordinates().values + points = PointsModel.parse(xy, transformations={coordinate_system: t}) + return transform(points, to_coordinate_system=coordinate_system) @get_centroids.register(DaskDataFrame) @@ -87,3 +141,6 @@ def _(e: DaskDataFrame, coordinate_system: str = "global") -> DaskDataFrame: assert isinstance(t, BaseTransformation) centroids = PointsModel.parse(coords, transformations={coordinate_system: t}) return transform(centroids, to_coordinate_system=coordinate_system) + + +## diff --git a/tests/core/test_centroids.py b/tests/core/test_centroids.py index e2784239..ede95b51 100644 --- a/tests/core/test_centroids.py +++ b/tests/core/test_centroids.py @@ -1,22 +1,24 @@ import numpy as np +import pandas as pd import pytest from anndata import AnnData from numpy.random import default_rng from spatialdata._core.centroids import get_centroids -from spatialdata.models import TableModel, get_axes_names +from spatialdata.models import 
Labels2DModel, Labels3DModel, TableModel, get_axes_names from spatialdata.transformations import Identity, get_transformation, set_transformation from tests.core.operations.test_transform import _get_affine RNG = default_rng(42) +affine = _get_affine() + @pytest.mark.parametrize("coordinate_system", ["global", "aligned"]) @pytest.mark.parametrize("is_3d", [False, True]) def test_get_centroids_points(points, coordinate_system: str, is_3d: bool): element = points["points_0"] - affine = _get_affine() # by default, the coordinate system is global and the points are 2D; let's modify the points as instructed by the # test arguments if coordinate_system == "aligned": @@ -48,24 +50,87 @@ def test_get_centroids_points(points, coordinate_system: str, is_3d: bool): assert np.allclose(centroids.compute().values, centroids_transformed) -def test_get_centroids_circles(): - pass - - -def test_get_centroids_polygons(): - pass +@pytest.mark.parametrize("coordinate_system", ["global", "aligned"]) +@pytest.mark.parametrize("shapes_name", ["circles", "poly", "multipoly"]) +def test_get_centroids_shapes(shapes, coordinate_system: str, shapes_name: str): + element = shapes[shapes_name] + if coordinate_system == "aligned": + set_transformation(element, transformation=affine, to_coordinate_system=coordinate_system) + centroids = get_centroids(element, coordinate_system=coordinate_system) + if shapes_name == "circles": + xy = element.geometry.get_coordinates().values + else: + assert shapes_name in ["poly", "multipoly"] + xy = element.geometry.centroid.get_coordinates().values -def test_get_centroids_multipolygons(): - pass + if coordinate_system == "global": + assert np.array_equal(centroids.compute().values, xy) + else: + matrix = affine.to_affine_matrix(input_axes=("x", "y"), output_axes=("x", "y")) + centroids_transformed = np.dot(xy, matrix[:2, :2].T) + matrix[:2, 2] + assert np.allclose(centroids.compute().values, centroids_transformed) -def test_get_centroids_single_scale_labels(): - pass +@pytest.mark.parametrize("coordinate_system", ["global", "aligned"]) +@pytest.mark.parametrize("is_multiscale", [False, True]) +@pytest.mark.parametrize("is_3d", [False, True]) +def test_get_centroids_labels(labels, coordinate_system: str, is_multiscale: bool, is_3d: bool): + scale_factors = [2] if is_multiscale else None + if is_3d: + model = Labels3DModel + array = np.array( + [ + [ + [0, 0, 1, 1], + [0, 0, 1, 1], + ], + [ + [2, 2, 1, 1], + [2, 2, 1, 1], + ], + ] + ) + expected_centroids = pd.DataFrame( + { + "x": [1, 3, 1], + "y": [1, 1.0, 1], + "z": [0.5, 1, 1.5], + }, + index=[0, 1, 2], + ) + else: + array = np.array( + [ + [1, 1, 1, 1], + [2, 2, 2, 2], + [2, 2, 2, 2], + [2, 2, 2, 2], + ] + ) + model = Labels2DModel + expected_centroids = pd.DataFrame( + { + "x": [2, 2], + "y": [0.5, 2.5], + }, + index=[1, 2], + ) + element = model.parse(array, scale_factors=scale_factors) + if coordinate_system == "aligned": + set_transformation(element, transformation=affine, to_coordinate_system=coordinate_system) + centroids = get_centroids(element, coordinate_system=coordinate_system) -def test_get_centroids_multiscale_labels(): - pass + if coordinate_system == "global": + assert np.array_equal(centroids.compute().values, expected_centroids.values) + else: + axes = get_axes_names(element) + n = len(axes) + # the axes from the labels have 'x' last, but we want it first to manually transform the points, so we sort + matrix = affine.to_affine_matrix(input_axes=sorted(axes), output_axes=sorted(axes)) + centroids_transformed = 
np.dot(expected_centroids.values, matrix[:n, :n].T) + matrix[:n, n] + assert np.allclose(centroids.compute().values, centroids_transformed) def test_get_centroids_invalid_element(images): From b0f17104c2e36221fb3d8cbb3af5d434ef2034f0 Mon Sep 17 00:00:00 2001 From: Luca Marconato Date: Thu, 22 Feb 2024 23:30:41 +0100 Subject: [PATCH 12/27] made _assert_spatialdata_objects_seem_identical() into a util --- src/spatialdata/_utils.py | 47 ++++++++++++++++++- .../operations/test_spatialdata_operations.py | 41 +--------------- tests/core/query/test_spatial_query.py | 4 +- 3 files changed, 49 insertions(+), 43 deletions(-) diff --git a/src/spatialdata/_utils.py b/src/spatialdata/_utils.py index 205308e8..8900e4bf 100644 --- a/src/spatialdata/_utils.py +++ b/src/spatialdata/_utils.py @@ -5,12 +5,15 @@ import warnings from collections.abc import Generator from copy import deepcopy -from typing import Any, Callable, TypeVar, Union +from typing import TYPE_CHECKING, Any, Callable, TypeVar, Union import numpy as np import pandas as pd from anndata import AnnData +from anndata.tests.helpers import assert_equal from dask import array as da +from dask.dataframe import DataFrame as DaskDataFrame +from dask.delayed import Delayed from datatree import DataTree from geopandas import GeoDataFrame from multiscale_spatial_image import MultiscaleSpatialImage @@ -25,6 +28,9 @@ set_transformation, ) +if TYPE_CHECKING: + from spatialdata import SpatialData + # I was using "from numbers import Number" but this led to mypy errors, so I switched to the following: Number = Union[int, float] RT = TypeVar("RT") @@ -311,3 +317,42 @@ def _error_message_add_element() -> None: "write_labels(), write_points(), write_shapes() and write_table(). We are going to make these calls more " "ergonomic in a follow up PR." 
) + + +def _assert_elements_left_to_right_seem_identical(sdata0: SpatialData, sdata1: SpatialData) -> None: + for element_type, element_name, element in sdata0._gen_elements(): + elements = sdata1.__getattribute__(element_type) + assert element_name in elements + element1 = elements[element_name] + if isinstance(element, (AnnData, SpatialImage, GeoDataFrame)): + assert element.shape == element1.shape + elif isinstance(element, DaskDataFrame): + for s0, s1 in zip(element.shape, element1.shape): + if isinstance(s0, Delayed): + s0 = s0.compute() + if isinstance(s1, Delayed): + s1 = s1.compute() + assert s0 == s1 + elif isinstance(element, MultiscaleSpatialImage): + assert len(element) == len(element1) + else: + raise TypeError(f"Unsupported type {type(element)}") + + +def _assert_tables_seem_identical(sdata0: SpatialData, sdata1: SpatialData) -> None: + tables0 = sdata0.tables + tables1 = sdata1.tables + assert set(tables0.keys()) == set(tables1.keys()) + for k in tables0: + t0 = tables0[k] + t1 = tables1[k] + assert_equal(t0, t1) + + +def _assert_spatialdata_objects_seem_identical(sdata0: SpatialData, sdata1: SpatialData) -> None: + # this is not a full comparison, but it's fine anyway + assert len(list(sdata0._gen_elements())) == len(list(sdata1._gen_elements())) + assert set(sdata0.coordinate_systems) == set(sdata1.coordinate_systems) + _assert_elements_left_to_right_seem_identical(sdata0, sdata1) + _assert_elements_left_to_right_seem_identical(sdata1, sdata0) + _assert_tables_seem_identical(sdata0, sdata1) diff --git a/tests/core/operations/test_spatialdata_operations.py b/tests/core/operations/test_spatialdata_operations.py index c457f089..22723f51 100644 --- a/tests/core/operations/test_spatialdata_operations.py +++ b/tests/core/operations/test_spatialdata_operations.py @@ -5,15 +5,11 @@ import numpy as np import pytest from anndata import AnnData -from dask.dataframe.core import DataFrame as DaskDataFrame -from dask.delayed import Delayed -from geopandas import GeoDataFrame -from multiscale_spatial_image import MultiscaleSpatialImage -from spatial_image import SpatialImage from spatialdata._core.concatenate import _concatenate_tables, concatenate from spatialdata._core.data_extent import are_extents_equal, get_extent from spatialdata._core.operations._utils import transform_to_data_extent from spatialdata._core.spatialdata import SpatialData +from spatialdata._utils import _assert_spatialdata_objects_seem_identical, _assert_tables_seem_identical from spatialdata.datasets import blobs from spatialdata.models import ( Image2DModel, @@ -116,39 +112,6 @@ def test_element_names_unique() -> None: assert "shapes" not in sdata._shared_keys -def _assert_elements_left_to_right_seem_identical(sdata0: SpatialData, sdata1: SpatialData) -> None: - for element_type, element_name, element in sdata0._gen_elements(): - elements = sdata1.__getattribute__(element_type) - assert element_name in elements - element1 = elements[element_name] - if isinstance(element, (AnnData, SpatialImage, GeoDataFrame)): - assert element.shape == element1.shape - elif isinstance(element, DaskDataFrame): - for s0, s1 in zip(element.shape, element1.shape): - if isinstance(s0, Delayed): - s0 = s0.compute() - if isinstance(s1, Delayed): - s1 = s1.compute() - assert s0 == s1 - elif isinstance(element, MultiscaleSpatialImage): - assert len(element) == len(element1) - else: - raise TypeError(f"Unsupported type {type(element)}") - - -def _assert_tables_seem_identical(table0: AnnData | None, table1: AnnData | None) -> None: - assert 
table0 is None and table1 is None or table0.shape == table1.shape - - -def _assert_spatialdata_objects_seem_identical(sdata0: SpatialData, sdata1: SpatialData) -> None: - # this is not a full comparison, but it's fine anyway - assert len(list(sdata0._gen_elements())) == len(list(sdata1._gen_elements())) - assert set(sdata0.coordinate_systems) == set(sdata1.coordinate_systems) - _assert_elements_left_to_right_seem_identical(sdata0, sdata1) - _assert_elements_left_to_right_seem_identical(sdata1, sdata0) - _assert_tables_seem_identical(sdata0.table, sdata1.table) - - def test_filter_by_coordinate_system(full_sdata: SpatialData) -> None: sdata = full_sdata.filter_by_coordinate_system(coordinate_system="global", filter_table=False) _assert_spatialdata_objects_seem_identical(sdata, full_sdata) @@ -160,7 +123,7 @@ def test_filter_by_coordinate_system(full_sdata: SpatialData) -> None: sdata_my_space = full_sdata.filter_by_coordinate_system(coordinate_system="my_space0", filter_table=False) assert len(list(sdata_my_space.gen_elements())) == 3 - _assert_tables_seem_identical(sdata_my_space.table, full_sdata.table) + _assert_tables_seem_identical(sdata_my_space, full_sdata) sdata_my_space1 = full_sdata.filter_by_coordinate_system( coordinate_system=["my_space0", "my_space1", "my_space2"], filter_table=False diff --git a/tests/core/query/test_spatial_query.py b/tests/core/query/test_spatial_query.py index bbd9ccee..6ca91d6a 100644 --- a/tests/core/query/test_spatial_query.py +++ b/tests/core/query/test_spatial_query.py @@ -18,6 +18,7 @@ polygon_query, ) from spatialdata._core.spatialdata import SpatialData +from spatialdata._utils import _assert_spatialdata_objects_seem_identical from spatialdata.models import ( Image2DModel, Image3DModel, @@ -30,9 +31,6 @@ from spatialdata.transformations import Identity, set_transformation from tests.conftest import _make_points, _make_squares -from tests.core.operations.test_spatialdata_operations import ( - _assert_spatialdata_objects_seem_identical, -) def test_bounding_box_request_immutable(): From a873a33c118de57694c21e698c40555ee80ee66f Mon Sep 17 00:00:00 2001 From: Luca Marconato Date: Sun, 25 Feb 2024 19:05:08 +0100 Subject: [PATCH 13/27] fix docs, attempt --- src/spatialdata/_utils.py | 2 +- src/spatialdata/transformations/operations.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/spatialdata/_utils.py b/src/spatialdata/_utils.py index 8900e4bf..b04d5f7f 100644 --- a/src/spatialdata/_utils.py +++ b/src/spatialdata/_utils.py @@ -29,7 +29,7 @@ ) if TYPE_CHECKING: - from spatialdata import SpatialData + from spatialdata._core.spatialdata import SpatialData # I was using "from numbers import Number" but this led to mypy errors, so I switched to the following: Number = Union[int, float] diff --git a/src/spatialdata/transformations/operations.py b/src/spatialdata/transformations/operations.py index f3e06adc..9206d3e3 100644 --- a/src/spatialdata/transformations/operations.py +++ b/src/spatialdata/transformations/operations.py @@ -16,8 +16,8 @@ if TYPE_CHECKING: from spatialdata._core.spatialdata import SpatialData - from spatialdata.models import SpatialElement - from spatialdata.transformations import Affine, BaseTransformation + from spatialdata.models._utils import SpatialElement + from spatialdata.transformations.transformations import Affine, BaseTransformation def set_transformation( @@ -329,7 +329,7 @@ def get_transformation_between_landmarks( example on how to call this function on two sets of numpy arrays describing x, y
coordinates. >>> import numpy as np >>> from spatialdata.models import PointsModel - >>> from spatialdata.transform import get_transformation_between_landmarks + >>> from spatialdata.transformations import get_transformation_between_landmarks >>> points_moving = np.array([[0, 0], [1, 1], [2, 2]]) >>> points_reference = np.array([[0, 0], [10, 10], [20, 20]]) >>> moving_coords = PointsModel(points_moving) From 0685fb7a1578ab343e3151ee2c5f1eff3bf13288 Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Sun, 25 Feb 2024 23:35:13 +0100 Subject: [PATCH 14/27] allow table to be None in get_values and _locate_values (#466) * allow table to be None * fixes to aggregate * check annotation * adjust docstrings * remove table check * add sphinx_pytest --- pyproject.toml | 1 + src/spatialdata/_core/operations/aggregate.py | 13 +++- .../_core/query/relational_query.py | 8 +- tests/core/operations/test_aggregations.py | 20 +++-- tests/core/query/test_relational_query.py | 73 +++++++++++++++---- 5 files changed, 90 insertions(+), 25 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index bab753b4..8884abe2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,6 +57,7 @@ docs = [ # For notebooks "ipython>=8.6.0", "sphinx-copybutton", + "sphinx-pytest", ] test = [ "pytest", diff --git a/src/spatialdata/_core/operations/aggregate.py b/src/spatialdata/_core/operations/aggregate.py index 8bdab94f..26993e62 100644 --- a/src/spatialdata/_core/operations/aggregate.py +++ b/src/spatialdata/_core/operations/aggregate.py @@ -64,7 +64,7 @@ def aggregate( region_key: str = "region", instance_key: str = "instance_id", deepcopy: bool = True, - table_name: str = "table", + table_name: str | None = None, **kwargs: Any, ) -> SpatialData: """ @@ -127,7 +127,7 @@ def aggregate( Whether to deepcopy the shapes in the returned `SpatialData` object. If the shapes are large (e.g. large multiscale labels), you may consider disabling the deepcopy to use a lazy Dask representation. table_name - The name of the table resulting from the aggregation. + The table optionally containing the value_key and the name of the table in the returned `SpatialData` object. kwargs Additional keyword arguments to pass to :func:`xrspatial.zonal_stats`. @@ -203,6 +203,7 @@ def aggregate( value_key=value_key, agg_func=agg_func, fractions=fractions, + table_name=table_name, ) # eventually remove the colum of ones if it was added @@ -217,6 +218,7 @@ def aggregate( if adata is None: raise NotImplementedError(f"Cannot aggregate {values_type} by {by_type}") + table_name = table_name if table_name is not None else "table" # create a SpatialData object with the aggregated table and the "by" shapes shapes_name = by if isinstance(by, str) else "by" return _create_sdata_from_table_and_shapes( @@ -322,6 +324,7 @@ def _aggregate_shapes( by: gpd.GeoDataFrame, values_sdata: SpatialData | None = None, values_element_name: str | None = None, + table_name: str | None = None, value_key: str | list[str] | None = None, agg_func: str | list[str] = "count", fractions: bool = False, @@ -348,13 +351,17 @@ def _aggregate_shapes( Column in value dataframe to perform aggregation on. agg_func Aggregation function to apply over grouped values. Passed to pandas.DataFrame.groupby.agg. + table_name + Name of the table optionally containing the value_key column. 
""" from spatialdata.models import points_dask_dataframe_to_geopandas assert value_key is not None assert (values_sdata is None) == (values_element_name is None) if values_sdata is not None: - actual_values = get_values(value_key=value_key, sdata=values_sdata, element_name=values_element_name) + actual_values = get_values( + value_key=value_key, sdata=values_sdata, element_name=values_element_name, table_name=table_name + ) else: actual_values = get_values(value_key=value_key, element=values) assert isinstance(actual_values, pd.DataFrame), f"Expected pd.DataFrame, got {type(actual_values)}" diff --git a/src/spatialdata/_core/query/relational_query.py b/src/spatialdata/_core/query/relational_query.py index 05858d42..9ca9dd70 100644 --- a/src/spatialdata/_core/query/relational_query.py +++ b/src/spatialdata/_core/query/relational_query.py @@ -637,7 +637,7 @@ def _locate_value( element: SpatialElement | None = None, sdata: SpatialData | None = None, element_name: str | None = None, - table_name: str = "table", + table_name: str | None = None, ) -> list[_ValueOrigin]: el = _get_element(element=element, sdata=sdata, element_name=element_name) origins = [] @@ -652,7 +652,7 @@ def _locate_value( # adding from the obs columns or var if model in [ShapesModel, Labels2DModel, Labels3DModel] and sdata is not None: - table = sdata[table_name] + table = sdata.tables.get(table_name) if table_name is not None else None if table is not None: # check if the table is annotating the element region = table.uns[TableModel.ATTRS_KEY][TableModel.REGION_KEY] @@ -673,7 +673,7 @@ def get_values( element: SpatialElement | None = None, sdata: SpatialData | None = None, element_name: str | None = None, - table_name: str = "table", + table_name: str | None = None, ) -> pd.DataFrame: """ Get the values from the element, from any location: df columns, obs or var columns (table). 
@@ -736,7 +736,7 @@ def get_values( if isinstance(el, DaskDataFrame): df = df.compute() return df - if sdata is not None: + if sdata is not None and table_name is not None: assert element_name is not None matched_table = match_table_to_element(sdata=sdata, element_name=element_name, table_name=table_name) region_key = matched_table.uns[TableModel.ATTRS_KEY][TableModel.REGION_KEY_KEY] diff --git a/tests/core/operations/test_aggregations.py b/tests/core/operations/test_aggregations.py index 8d1a734a..09283dca 100644 --- a/tests/core/operations/test_aggregations.py +++ b/tests/core/operations/test_aggregations.py @@ -46,7 +46,7 @@ def test_aggregate_points_by_shapes(sdata_query_aggregation, by_shapes: str, val # testing that we can call aggregate with the two equivalent syntaxes for the values argument result_adata = aggregate(values=points, by=shapes, value_key=value_key, agg_func="sum").tables["table"] result_adata_bis = aggregate( - values_sdata=sdata, values="points", by=shapes, value_key=value_key, agg_func="sum" + values_sdata=sdata, values="points", by=shapes, value_key=value_key, agg_func="sum", table_name="table" ).tables["table"] np.testing.assert_equal(result_adata.X.A, result_adata_bis.X.A) @@ -147,7 +147,7 @@ def test_aggregate_shapes_by_shapes( values = _parse_shapes(sdata, values_shapes=values_shapes) result_adata = aggregate( - values_sdata=sdata, values=values_shapes, by=by, value_key=value_key, agg_func="sum" + values_sdata=sdata, values=values_shapes, by=by, value_key=value_key, agg_func="sum", table_name="table" ).tables["table"] # testing that we can call aggregate with the two equivalent syntaxes for the values argument (only relevant when @@ -255,7 +255,7 @@ def test_aggregate_shapes_by_shapes( # in the categorical case, check that sum and count behave the same if value_key in ["categorical_in_obs", "categorical_in_gdf"]: result_adata_count = aggregate( - values_sdata=sdata, values=values_shapes, by=by, value_key=value_key, agg_func="count" + values_sdata=sdata, values=values_shapes, by=by, value_key=value_key, agg_func="count", table_name="table" ).tables["table"] assert_equal(result_adata, result_adata_count) @@ -264,7 +264,14 @@ def test_aggregate_shapes_by_shapes( if value_key in ["categorical_in_obs", "categorical_in_gdf"]: # can't aggregate multiple categorical values with pytest.raises(ValueError): - aggregate(values_sdata=sdata, values=values_shapes, by=by, value_key=new_value_key, agg_func="sum") + aggregate( + values_sdata=sdata, + values=values_shapes, + by=by, + value_key=new_value_key, + agg_func="sum", + table_name="table", + ) else: if value_key == "numerical_in_obs": sdata.tables["table"].obs["another_numerical_in_obs"] = 1.0 @@ -279,7 +286,7 @@ def test_aggregate_shapes_by_shapes( sdata.tables["table"] = new_table result_adata = aggregate( - values_sdata=sdata, values=values_shapes, by=by, value_key=new_value_key, agg_func="sum" + values_sdata=sdata, values=values_shapes, by=by, value_key=new_value_key, agg_func="sum", table_name="table" ).tables["table"] assert result_adata.var_names.to_list() == new_value_key @@ -311,6 +318,7 @@ def test_aggregate_shapes_by_shapes( by=by, value_key=value_key, agg_func="sum", + table_name="table", ) # test we can't aggregate from mixed categorical and numerical sources (let's just test one case) with pytest.raises(ValueError): @@ -320,6 +328,7 @@ def test_aggregate_shapes_by_shapes( by=by, value_key=["numerical_values_in_obs", "categorical_values_in_obs"], agg_func="sum", + table_name="table", ) @@ -494,6 +503,7 @@ def 
test_aggregate_considering_fractions_multiple_values( value_key=["numerical_in_var", "another_numerical_in_var"], agg_func="sum", fractions=True, + table_name="table", ).tables["table"] overlaps = np.array([0.655781239649211, 1.0000000000000002, 1.0000000000000004, 0.1349639285777728]) row0 = np.sum(sdata.tables["table"].X[[0, 1, 2, 3], :] * overlaps.reshape(-1, 1), axis=0) diff --git a/tests/core/query/test_relational_query.py b/tests/core/query/test_relational_query.py index d4af878a..8f68c448 100644 --- a/tests/core/query/test_relational_query.py +++ b/tests/core/query/test_relational_query.py @@ -173,45 +173,74 @@ def _check_location(locations: list[_ValueOrigin], origin: str, is_categorical: # var, numerical _check_location( - _locate_value(value_key="numerical_in_var", sdata=sdata_query_aggregation, element_name="values_circles"), + _locate_value( + value_key="numerical_in_var", + sdata=sdata_query_aggregation, + element_name="values_circles", + table_name="table", + ), origin="var", is_categorical=False, ) # obs, categorical _check_location( - _locate_value(value_key="categorical_in_obs", sdata=sdata_query_aggregation, element_name="values_circles"), + _locate_value( + value_key="categorical_in_obs", + sdata=sdata_query_aggregation, + element_name="values_circles", + table_name="table", + ), origin="obs", is_categorical=True, ) # obs, numerical _check_location( - _locate_value(value_key="numerical_in_obs", sdata=sdata_query_aggregation, element_name="values_circles"), + _locate_value( + value_key="numerical_in_obs", + sdata=sdata_query_aggregation, + element_name="values_circles", + table_name="table", + ), origin="obs", is_categorical=False, ) # gdf, categorical # sdata + element_name _check_location( - _locate_value(value_key="categorical_in_gdf", sdata=sdata_query_aggregation, element_name="values_circles"), + _locate_value( + value_key="categorical_in_gdf", + sdata=sdata_query_aggregation, + element_name="values_circles", + table_name="table", + ), origin="df", is_categorical=True, ) # element _check_location( - _locate_value(value_key="categorical_in_gdf", element=sdata_query_aggregation["values_circles"]), + _locate_value( + value_key="categorical_in_gdf", element=sdata_query_aggregation["values_circles"], table_name="table" + ), origin="df", is_categorical=True, ) # gdf, numerical # sdata + element_name _check_location( - _locate_value(value_key="numerical_in_gdf", sdata=sdata_query_aggregation, element_name="values_circles"), + _locate_value( + value_key="numerical_in_gdf", + sdata=sdata_query_aggregation, + element_name="values_circles", + table_name="table", + ), origin="df", is_categorical=False, ) # element _check_location( - _locate_value(value_key="numerical_in_gdf", element=sdata_query_aggregation["values_circles"]), + _locate_value( + value_key="numerical_in_gdf", element=sdata_query_aggregation["values_circles"], table_name="table" + ), origin="df", is_categorical=False, ) @@ -245,7 +274,9 @@ def _check_location(locations: list[_ValueOrigin], origin: str, is_categorical: def test_get_values_df(sdata_query_aggregation): # test with a single value, in the dataframe; using sdata + element_name - v = get_values(value_key="numerical_in_gdf", sdata=sdata_query_aggregation, element_name="values_circles") + v = get_values( + value_key="numerical_in_gdf", sdata=sdata_query_aggregation, element_name="values_circles", table_name="table" + ) assert v.shape == (9, 1) # test with multiple values, in the dataframe; using element @@ -256,7 +287,9 @@ def 
test_get_values_df(sdata_query_aggregation): assert v.shape == (9, 2) # test with a single value, in the obs - v = get_values(value_key="numerical_in_obs", sdata=sdata_query_aggregation, element_name="values_circles") + v = get_values( + value_key="numerical_in_obs", sdata=sdata_query_aggregation, element_name="values_circles", table_name="table" + ) assert v.shape == (9, 1) # test with multiple values, in the obs @@ -265,11 +298,14 @@ def test_get_values_df(sdata_query_aggregation): value_key=["numerical_in_obs", "another_numerical_in_obs"], sdata=sdata_query_aggregation, element_name="values_circles", + table_name="table", ) assert v.shape == (9, 2) # test with a single value, in the var - v = get_values(value_key="numerical_in_var", sdata=sdata_query_aggregation, element_name="values_circles") + v = get_values( + value_key="numerical_in_var", sdata=sdata_query_aggregation, element_name="values_circles", table_name="table" + ) assert v.shape == (9, 1) # test with multiple values, in the var @@ -287,6 +323,7 @@ def test_get_values_df(sdata_query_aggregation): value_key=["numerical_in_var", "another_numerical_in_var"], sdata=sdata_query_aggregation, element_name="values_circles", + table_name="table", ) assert v.shape == (9, 2) @@ -294,11 +331,18 @@ def test_get_values_df(sdata_query_aggregation): # value found in multiple locations sdata_query_aggregation.table.obs["another_numerical_in_gdf"] = np.zeros(21) with pytest.raises(ValueError): - get_values(value_key="another_numerical_in_gdf", sdata=sdata_query_aggregation, element_name="values_circles") + get_values( + value_key="another_numerical_in_gdf", + sdata=sdata_query_aggregation, + element_name="values_circles", + table_name="table", + ) # value not found with pytest.raises(ValueError): - get_values(value_key="not_present", sdata=sdata_query_aggregation, element_name="values_circles") + get_values( + value_key="not_present", sdata=sdata_query_aggregation, element_name="values_circles", table_name="table" + ) # mixing categorical and numerical values with pytest.raises(ValueError): @@ -306,6 +350,7 @@ def test_get_values_df(sdata_query_aggregation): value_key=["numerical_in_gdf", "categorical_in_gdf"], sdata=sdata_query_aggregation, element_name="values_circles", + table_name="table", ) # multiple categorical values @@ -315,6 +360,7 @@ def test_get_values_df(sdata_query_aggregation): value_key=["categorical_in_gdf", "another_categorical_in_gdf"], sdata=sdata_query_aggregation, element_name="values_circles", + table_name="table", ) # mixing different origins @@ -323,6 +369,7 @@ def test_get_values_df(sdata_query_aggregation): value_key=["numerical_in_gdf", "numerical_in_obs"], sdata=sdata_query_aggregation, element_name="values_circles", + table_name="table", ) @@ -330,7 +377,7 @@ def test_get_values_labels_bug(sdata_blobs): # https://github.com/scverse/spatialdata-plot/issues/165 from spatialdata import get_values - get_values("channel_0_sum", sdata=sdata_blobs, element_name="blobs_labels") + get_values("channel_0_sum", sdata=sdata_blobs, element_name="blobs_labels", table_name="table") def test_filter_table_categorical_bug(shapes): From 5248d3e9b7a966720fc7745ed2c94f7c38b616ca Mon Sep 17 00:00:00 2001 From: LucaMarconato <2664412+LucaMarconato@users.noreply.github.com> Date: Mon, 26 Feb 2024 09:55:07 +0100 Subject: [PATCH 15/27] Added `validate_table_annotation_target()` (#468) * added validate_table_annotation_target() * removed test that is no longer relevant * fix docs * fixed way to test no warnings emitted * merged 
validate_table_annotation_target() into validate_table_in_spatialdata() * fix test --- .../_core/query/relational_query.py | 8 +-- src/spatialdata/_core/spatialdata.py | 58 ++++++++++++------- .../operations/test_spatialdata_operations.py | 48 ++++++++++++--- tests/io/test_multi_table.py | 12 ++-- 4 files changed, 86 insertions(+), 40 deletions(-) diff --git a/src/spatialdata/_core/query/relational_query.py b/src/spatialdata/_core/query/relational_query.py index 9ca9dd70..46713e64 100644 --- a/src/spatialdata/_core/query/relational_query.py +++ b/src/spatialdata/_core/query/relational_query.py @@ -288,7 +288,7 @@ def _right_join_spatialelement_table( element_dict: dict[str, dict[str, Any]], table: AnnData, match_rows: Literal["left", "no", "right"] ) -> tuple[dict[str, Any], AnnData]: if match_rows == "left": - warnings.warn("Matching rows ``'left'`` is not supported for ``'right'`` join.", UserWarning, stacklevel=2) + warnings.warn("Matching rows 'left' is not supported for 'right' join.", UserWarning, stacklevel=2) regions, region_column_name, instance_key = get_table_keys(table) groups_df = table.obs.groupby(by=region_column_name) for element_type, name_element in element_dict.items(): @@ -300,7 +300,7 @@ def _right_join_spatialelement_table( element_indices = element.index else: warnings.warn( - f"Element type `labels` not supported for left exclusive join. Skipping `{name}`", + f"Element type `labels` not supported for 'right' join. Skipping `{name}`", UserWarning, stacklevel=2, ) @@ -331,7 +331,7 @@ def _inner_join_spatialelement_table( element_indices = element.index else: warnings.warn( - f"Element type `labels` not supported for left exclusive join. Skipping `{name}`", + f"Element type `labels` not supported for 'inner' join. Skipping `{name}`", UserWarning, stacklevel=2, ) @@ -389,7 +389,7 @@ def _left_join_spatialelement_table( element_dict: dict[str, dict[str, Any]], table: AnnData, match_rows: Literal["left", "no", "right"] ) -> tuple[dict[str, Any], AnnData]: if match_rows == "right": - warnings.warn("Matching rows ``'right'`` is not supported for ``'left'`` join.", UserWarning, stacklevel=2) + warnings.warn("Matching rows 'right' is not supported for 'left' join.", UserWarning, stacklevel=2) regions, region_column_name, instance_key = get_table_keys(table) groups_df = table.obs.groupby(by=region_column_name) joined_indices = None diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py index b44c6c55..9baa4ccc 100644 --- a/src/spatialdata/_core/spatialdata.py +++ b/src/spatialdata/_core/spatialdata.py @@ -159,41 +159,55 @@ def __init__( self._query = QueryManager(self) - def validate_table_in_spatialdata(self, data: AnnData) -> None: + def validate_table_in_spatialdata(self, table: AnnData) -> None: """ Validate the presence of the annotation target of a SpatialData table in the SpatialData object. This method validates a table in the SpatialData object to ensure that if annotation metadata is present, the - annotation target (SpatialElement) is present in the SpatialData object. Otherwise, a warning is raised. + annotation target (SpatialElement) is present in the SpatialData object and that the dtype of the instance key + column in the table matches the dtype of the indices of the annotation target. Otherwise, a warning is raised. Parameters ---------- - data + table The table potentially annotating a SpatialElement Raises ------ UserWarning If the table is annotating elements not present in the SpatialData object.
+ UserWarning + If the dtype of the instance key column in the table does not match the dtype of the indices of the annotation target. """ - TableModel().validate(data) - element_names = [ - element_name for element_type, element_name, _ in self._gen_elements() if element_type != "tables" - ] - if TableModel.ATTRS_KEY in data.uns: - attrs = data.uns[TableModel.ATTRS_KEY] - regions = ( - attrs[TableModel.REGION_KEY] - if isinstance(attrs[TableModel.REGION_KEY], list) - else [attrs[TableModel.REGION_KEY]] - ) - # TODO: check throwing error - if not all(element_name in element_names for element_name in regions): - warnings.warn( - "The table is annotating an/some element(s) not present in the SpatialData object", - UserWarning, - stacklevel=2, - ) + TableModel().validate(table) + if TableModel.ATTRS_KEY in table.uns: + region, _, instance_key = get_table_keys(table) + region = region if isinstance(region, list) else [region] + for r in region: + element = self.get(r) + if element is None: + warnings.warn( + f"The table is annotating {r!r}, which is not present in the SpatialData object.", + UserWarning, + stacklevel=2, + ) + else: + if isinstance(element, (SpatialImage, MultiscaleSpatialImage)): + dtype = element.dtype + else: + dtype = element.index.dtype + if dtype != table.obs[instance_key].dtype: + warnings.warn( + ( + f"Table instance_key column ({instance_key}) has a dtype " + f"({table.obs[instance_key].dtype}) that does not match the dtype of the indices of " + f"the annotated element ({dtype}). Please note that cases such as int16 vs int32 may be " + "tolerated by downstream methods, but it is recommended to make the dtypes match." + ), + UserWarning, + stacklevel=2, + ) @staticmethod def from_elements_dict(elements_dict: dict[str, SpatialElement | AnnData]) -> SpatialData: @@ -417,7 +431,7 @@ def set_table_annotates_spatialelement( table = self.tables[table_name] element_names = {element[1] for element in self._gen_elements()} if region not in element_names: - raise ValueError(f"Annotation target '{region}' not present as SpatialElement in " f"SpatialData object.") + raise ValueError(f"Annotation target '{region}' not present as SpatialElement in SpatialData object.") if table.uns.get(TableModel.ATTRS_KEY): self._change_table_annotation_target(table, region, region_key, instance_key) diff --git a/tests/core/operations/test_spatialdata_operations.py b/tests/core/operations/test_spatialdata_operations.py index 22723f51..43b648ef 100644 --- a/tests/core/operations/test_spatialdata_operations.py +++ b/tests/core/operations/test_spatialdata_operations.py @@ -1,6 +1,7 @@ from __future__ import annotations import math +import warnings import numpy as np import pytest @@ -11,13 +12,7 @@ from spatialdata._core.spatialdata import SpatialData from spatialdata._utils import _assert_spatialdata_objects_seem_identical, _assert_tables_seem_identical from spatialdata.datasets import blobs -from spatialdata.models import ( - Image2DModel, - Labels2DModel, - PointsModel, - ShapesModel, - TableModel, -) +from spatialdata.models import Image2DModel, Labels2DModel, PointsModel, ShapesModel, TableModel, get_table_keys from spatialdata.transformations.operations import get_transformation, set_transformation from spatialdata.transformations.transformations import ( Affine, @@ -417,3 +412,42 @@ def test_transform_to_data_extent(full_sdata: SpatialData, maintain_positioning: assert are_extents_equal( data_extent_before, data_extent_after, atol=3 ), f"data_extent_before: {data_extent_before}, data_extent_after:
{data_extent_after} for element {element}" + + +def test_validate_table_in_spatialdata(full_sdata): + table = full_sdata["table"] + region, region_key, _ = get_table_keys(table) + assert region == "labels2d" + + # no warnings + with warnings.catch_warnings(): + warnings.simplefilter("error") + full_sdata.validate_table_in_spatialdata(table) + + # dtype mismatch + full_sdata.labels["labels2d"] = Labels2DModel.parse(full_sdata.labels["labels2d"].astype("int16")) + with pytest.warns(UserWarning, match="that does not match the dtype of the indices of the annotated element"): + full_sdata.validate_table_in_spatialdata(table) + + # region not found + del full_sdata.labels["labels2d"] + with pytest.warns(UserWarning, match="in the SpatialData object"): + full_sdata.validate_table_in_spatialdata(table) + + table.obs[region_key] = "points_0" + full_sdata.set_table_annotates_spatialelement("table", region="points_0") + + # no warnings + with warnings.catch_warnings(): + warnings.simplefilter("error") + full_sdata.validate_table_in_spatialdata(table) + + # dtype mismatch + full_sdata.points["points_0"].index = full_sdata.points["points_0"].index.astype("int16") + with pytest.warns(UserWarning, match="that does not match the dtype of the indices of the annotated element"): + full_sdata.validate_table_in_spatialdata(table) + + # region not found + del full_sdata.points["points_0"] + with pytest.warns(UserWarning, match="in the SpatialData object"): + full_sdata.validate_table_in_spatialdata(table) diff --git a/tests/io/test_multi_table.py b/tests/io/test_multi_table.py index 56755b79..d1bd21e0 100644 --- a/tests/io/test_multi_table.py +++ b/tests/io/test_multi_table.py @@ -93,7 +93,7 @@ def test_change_annotation_target(self, full_sdata, region_key, instance_key, er def test_set_table_nonexisting_target(self, full_sdata): with pytest.raises( ValueError, - match="Annotation target 'non_existing' not present as SpatialElement in " "SpatialData object.", + match="Annotation target 'non_existing' not present as SpatialElement in SpatialData object.", ): full_sdata.set_table_annotates_spatialelement("table", "non_existing") @@ -150,9 +150,8 @@ def test_single_table(self, tmp_path: str, region: str): } if region == "non_existing": - with pytest.warns( - UserWarning, match=r"The table is annotating an/some element\(s\) not present in the SpatialData object" - ): + # annotation target not present in the SpatialData object + with pytest.warns(UserWarning, match=r", which is not present in the SpatialData object"): SpatialData( shapes=shapes_dict, tables={"shape_annotate": table}, @@ -189,9 +188,8 @@ def test_paired_elements_tables(self, tmp_path: str): table = _get_table(region="poly") table2 = _get_table(region="multipoly") table3 = _get_table(region="non_existing") - with pytest.warns( - UserWarning, match=r"The table is annotating an/some element\(s\) not present in the SpatialData object" - ): + # annotation target not present in the SpatialData object + with pytest.warns(UserWarning, match=r", which is not present in the SpatialData object"): SpatialData( shapes={"poly": test_shapes["poly"], "multipoly": test_shapes["multipoly"]}, table={"poly_annotate": table, "multipoly_annotate": table3}, From 5cc1347c2e0bf2fa16b48d8d25c97112b154f6e9 Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Wed, 6 Mar 2024 02:49:10 +0100 Subject: [PATCH 16/27] silence warning --- src/spatialdata/_io/format.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spatialdata/_io/format.py 
b/src/spatialdata/_io/format.py index f82b74f3..75ebab8d 100644 --- a/src/spatialdata/_io/format.py +++ b/src/spatialdata/_io/format.py @@ -3,7 +3,7 @@ from anndata import AnnData from multiscale_spatial_image.multiscale_spatial_image import MultiscaleSpatialImage from ome_zarr.format import CurrentFormat -from pandas.api.types import is_categorical_dtype +from pandas.api.types import CategoricalDtype from shapely import GeometryType from spatial_image import SpatialImage @@ -166,7 +166,7 @@ def validate_table( ) -> None: if not isinstance(table, AnnData): raise TypeError(f"`table` must be `anndata.AnnData`, was {type(table)}.") - if region_key is not None and not is_categorical_dtype(table.obs[region_key]): + if region_key is not None and not isinstance(table.obs[region_key].dtype, CategoricalDtype): raise ValueError( f"`table.obs[region_key]` must be of type `categorical`, not `{type(table.obs[region_key])}`." ) From a09ea4937e494ea58504f44c6652fbea1e489e83 Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Tue, 12 Mar 2024 16:46:28 +0100 Subject: [PATCH 17/27] fix getting dtype from multiscale --- src/spatialdata/_core/spatialdata.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py index 9baa4ccc..06be0586 100644 --- a/src/spatialdata/_core/spatialdata.py +++ b/src/spatialdata/_core/spatialdata.py @@ -192,10 +192,10 @@ def validate_table_in_spatialdata(self, table: AnnData) -> None: stacklevel=2, ) else: - if isinstance(element, (SpatialImage, MultiscaleSpatialImage)): + if isinstance(element, SpatialImage): dtype = element.dtype - else: - dtype = element.index.dtype + elif isinstance(element, MultiscaleSpatialImage): + dtype = element.scale0.ds.dtypes["image"] if dtype != table.obs[instance_key].dtype: warnings.warn( ( From b3571ef70360e244c7af54791cc38005fdb2d006 Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Tue, 12 Mar 2024 16:50:10 +0100 Subject: [PATCH 18/27] add else dtype back --- src/spatialdata/_core/spatialdata.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py index 06be0586..ed0f6a27 100644 --- a/src/spatialdata/_core/spatialdata.py +++ b/src/spatialdata/_core/spatialdata.py @@ -196,6 +196,8 @@ def validate_table_in_spatialdata(self, table: AnnData) -> None: dtype = element.dtype elif isinstance(element, MultiscaleSpatialImage): dtype = element.scale0.ds.dtypes["image"] + else: + dtype = element.index.dtype if dtype != table.obs[instance_key].dtype: warnings.warn( ( From 5392ea1122a0a486af27bd99ef0e08b780728c0e Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Wed, 13 Mar 2024 12:20:45 +0100 Subject: [PATCH 19/27] silence scipy.misc.face deprecation --- src/spatialdata/datasets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spatialdata/datasets.py b/src/spatialdata/datasets.py index 70dc0e39..546a0927 100644 --- a/src/spatialdata/datasets.py +++ b/src/spatialdata/datasets.py @@ -86,7 +86,7 @@ def raccoon( self, ) -> SpatialData: """Raccoon dataset.""" - im_data = scipy.misc.face() + im_data = scipy.datasets.face() im = Image2DModel.parse(im_data, dims=["y", "x", "c"]) labels_data = slic(im_data, n_segments=100, compactness=10, sigma=1) labels = Labels2DModel.parse(labels_data, dims=["y", "x"]) From 3386b3df8187ae31458e69e288f93397ab7a9595 Mon Sep 17 00:00:00 2001 From: LucaMarconato <2664412+LucaMarconato@users.noreply.github.com> Date: Wed,
13 Mar 2024 22:01:02 +0100 Subject: [PATCH 20/27] Operation `to_circles()` (#473) * `get_centroids()` (#465) * wip get_centroids * implemented get_centroids() * code suggestions from kevin * wip vectorize.py * added vectorize; still no tests * refactored testing functions; wip tests to_circles() * considering removing the target_coordinate_system parameter in to_circles() * adjusted tests, docs, changelog * added changelog * fix tests, remove inject_docs * fix sphinx * attempt fix sphinx --- CHANGELOG.md | 3 + docs/api.md | 15 ++ docs/conf.py | 9 +- src/spatialdata/__init__.py | 2 + src/spatialdata/_core/operations/vectorize.py | 156 ++++++++++++++++++ src/spatialdata/_utils.py | 47 +----- src/spatialdata/testing.py | 156 ++++++++++++++++++ .../operations/test_spatialdata_operations.py | 6 +- tests/core/operations/test_transform.py | 3 +- tests/core/operations/test_vectorize.py | 68 ++++++++ tests/core/query/test_spatial_query.py | 10 +- 11 files changed, 415 insertions(+), 60 deletions(-) create mode 100644 src/spatialdata/_core/operations/vectorize.py create mode 100644 src/spatialdata/testing.py create mode 100644 tests/core/operations/test_vectorize.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 9c3cd242..768fc6df 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -52,6 +52,9 @@ and this project adheres to [Semantic Versioning][]. - added utils function: postpone_transformation() - added utils function: remove_transformations_to_coordinate_system() - added utils function: get_centroids() +- added operation: to_circles() +- added testing utilities: assert_spatial_data_objects_are_identical(), assert_elements_are_identical(), + assert_elements_dict_are_identical() ### Changed diff --git a/docs/api.md b/docs/api.md index cc24bf3c..b0bb586d 100644 --- a/docs/api.md +++ b/docs/api.md @@ -27,6 +27,7 @@ Operations on `SpatialData` objects. polygon_query get_values get_extent + get_centroids join_sdata_spatialelement_table match_element_to_table get_centroids @@ -34,6 +35,7 @@ Operations on `SpatialData` objects. concatenate transform rasterize + to_circles aggregate ``` @@ -141,3 +143,16 @@ The transformations that can be defined between elements and coordinate systems save_transformations get_dask_backing_files ``` + +## Testing utilities + +```{eval-rst} +.. currentmodule:: spatialdata.testing + +.. 
autosummary:: + :toctree: generated + + assert_spatial_data_objects_are_identical + assert_elements_are_identical + assert_elements_dict_are_identical +``` diff --git a/docs/conf.py b/docs/conf.py index b394f0b0..8262f9d4 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -133,10 +133,11 @@ html_title = project_name html_logo = "_static/img/spatialdata_horizontal.png" -# html_theme_options = { -# "repository_url": repository_url, -# "use_repository_button": True, -# } +html_theme_options = { + "navigation_with_keys": True, + # "repository_url": repository_url, + # "use_repository_button": True, +} pygments_style = "default" diff --git a/src/spatialdata/__init__.py b/src/spatialdata/__init__.py index d4521699..a3afe8e3 100644 --- a/src/spatialdata/__init__.py +++ b/src/spatialdata/__init__.py @@ -17,6 +17,7 @@ "dataloader", "concatenate", "rasterize", + "to_circles", "transform", "aggregate", "bounding_box_query", @@ -42,6 +43,7 @@ from spatialdata._core.operations.aggregate import aggregate from spatialdata._core.operations.rasterize import rasterize from spatialdata._core.operations.transform import transform +from spatialdata._core.operations.vectorize import to_circles from spatialdata._core.query._utils import circles_to_polygons, get_bounding_box_corners from spatialdata._core.query.relational_query import ( get_values, diff --git a/src/spatialdata/_core/operations/vectorize.py b/src/spatialdata/_core/operations/vectorize.py new file mode 100644 index 00000000..7f496b1c --- /dev/null +++ b/src/spatialdata/_core/operations/vectorize.py @@ -0,0 +1,156 @@ +from __future__ import annotations + +from functools import singledispatch + +import numpy as np +import pandas as pd +from geopandas import GeoDataFrame +from multiscale_spatial_image import MultiscaleSpatialImage +from shapely import MultiPolygon, Point, Polygon +from spatial_image import SpatialImage + +from spatialdata._core.centroids import get_centroids +from spatialdata._core.operations.aggregate import aggregate +from spatialdata.models import ( + Image2DModel, + Image3DModel, + Labels3DModel, + ShapesModel, + SpatialElement, + get_axes_names, + get_model, +) +from spatialdata.transformations.operations import get_transformation +from spatialdata.transformations.transformations import Identity + +INTRINSIC_COORDINATE_SYSTEM = "__intrinsic" + + +@singledispatch +def to_circles( + data: SpatialElement, +) -> GeoDataFrame: + """ + Convert a set of geometries (2D/3D labels, 2D shapes) to approximated circles/spheres. + + Parameters + ---------- + data + The SpatialElement representing the geometries to approximate as circles/spheres. + + Returns + ------- + The approximated circles/spheres. + + Notes + ----- + The approximation is done by computing the centroids and the area/volume of the geometries. The geometries are then + replaced by circles/spheres with the same centroids and area/volume. 
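+ +    Examples +    -------- +    A minimal usage sketch, assuming a SpatialData object ``sdata`` that contains a 2D labels element named +    ``"blobs_labels"`` (the name is taken from the example blobs dataset and is otherwise hypothetical): + +    >>> from spatialdata import to_circles +    >>> circles = to_circles(sdata["blobs_labels"])  # doctest: +SKIP +    >>> circles["radius"]  # estimated radii; the geometry column holds the centroid points  # doctest: +SKIP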
+ """ + raise RuntimeError(f"Unsupported type: {type(data)}") + + +@to_circles.register(SpatialImage) +@to_circles.register(MultiscaleSpatialImage) +def _( + element: SpatialImage | MultiscaleSpatialImage, +) -> GeoDataFrame: + model = get_model(element) + if model in (Image2DModel, Image3DModel): + raise RuntimeError("Cannot apply to_circles() to images.") + if model == Labels3DModel: + raise RuntimeError("to_circles() is not yet implemented for 3D labels.") + + # reduce to the single scale case + if isinstance(element, MultiscaleSpatialImage): + element_single_scale = SpatialImage(element["scale0"].values().__iter__().__next__()) + else: + element_single_scale = element + shape = element_single_scale.shape + + # find the area of labels, estimate the radius from it; find the centroids + axes = get_axes_names(element) + model = Image3DModel if "z" in axes else Image2DModel + ones = model.parse(np.ones((1,) + shape), dims=("c",) + axes) + aggregated = aggregate(values=ones, by=element_single_scale, agg_func="sum")["table"] + areas = aggregated.X.todense().A1.reshape(-1) + aobs = aggregated.obs + aobs["areas"] = areas + aobs["radius"] = np.sqrt(areas / np.pi) + + # get the centroids; remove the background if present (the background is not considered during aggregation) + centroids = _get_centroids(element) + if 0 in centroids.index: + centroids = centroids.drop(index=0) + # instance_id is the key used by the aggregation APIs + aobs.index = aobs["instance_id"] + aobs.index.name = None + assert len(aobs) == len(centroids) + obs = pd.merge(aobs, centroids, left_index=True, right_index=True, how="inner") + assert len(obs) == len(centroids) + return _make_circles(element, obs) + + +@to_circles.register(GeoDataFrame) +def _( + element: GeoDataFrame, +) -> GeoDataFrame: + if isinstance(element.geometry.iloc[0], (Polygon, MultiPolygon)): + radius = np.sqrt(element.geometry.area / np.pi) + centroids = _get_centroids(element) + obs = pd.DataFrame({"radius": radius}) + obs = pd.merge(obs, centroids, left_index=True, right_index=True, how="inner") + return _make_circles(element, obs) + assert isinstance(element.geometry.iloc[0], Point), ( + f"Unsupported geometry type: " f"{type(element.geometry.iloc[0])}" + ) + return element + + +def _get_centroids(element: SpatialElement) -> pd.DataFrame: + d = get_transformation(element, get_all=True) + assert isinstance(d, dict) + if INTRINSIC_COORDINATE_SYSTEM in d: + raise RuntimeError(f"The name {INTRINSIC_COORDINATE_SYSTEM} is reserved.") + d[INTRINSIC_COORDINATE_SYSTEM] = Identity() + centroids = get_centroids(element, coordinate_system=INTRINSIC_COORDINATE_SYSTEM).compute() + del d[INTRINSIC_COORDINATE_SYSTEM] + return centroids + + +def _make_circles(element: SpatialImage | MultiscaleSpatialImage | GeoDataFrame, obs: pd.DataFrame) -> GeoDataFrame: + spatial_axes = sorted(get_axes_names(element)) + centroids = obs[spatial_axes].values + transformations = get_transformation(element, get_all=True) + assert isinstance(transformations, dict) + return ShapesModel.parse( + centroids, + geometry=0, + index=obs.index, + radius=obs["radius"].values, + transformations=transformations.copy(), + ) + + +# TODO: depending on the implementation, add a parameter to control the degree of approximation of the constructed +# polygons/multipolygons +@singledispatch +def to_polygons( + data: SpatialElement, + target_coordinate_system: str, +) -> GeoDataFrame: + """ + Convert a set of geometries (2D labels, 2D shapes) to approximated 2D polygons/multipolygons.
+ +    Parameters + ---------- + data + The SpatialElement representing the geometries to approximate as 2D polygons/multipolygons. + target_coordinate_system + The coordinate system to which the geometries to consider should be transformed. + + Returns + ------- + The approximated 2D polygons/multipolygons in the specified coordinate system. + """ + raise RuntimeError(f"Unsupported type: {type(data)}") diff --git a/src/spatialdata/_utils.py b/src/spatialdata/_utils.py index b04d5f7f..205308e8 100644 --- a/src/spatialdata/_utils.py +++ b/src/spatialdata/_utils.py @@ -5,15 +5,12 @@ import warnings from collections.abc import Generator from copy import deepcopy -from typing import TYPE_CHECKING, Any, Callable, TypeVar, Union +from typing import Any, Callable, TypeVar, Union import numpy as np import pandas as pd from anndata import AnnData -from anndata.tests.helpers import assert_equal from dask import array as da -from dask.dataframe import DataFrame as DaskDataFrame -from dask.delayed import Delayed from datatree import DataTree from geopandas import GeoDataFrame from multiscale_spatial_image import MultiscaleSpatialImage @@ -28,9 +25,6 @@ set_transformation, ) -if TYPE_CHECKING: - from spatialdata._core.spatialdata import SpatialData - # I was using "from numbers import Number" but this led to mypy errors, so I switched to the following: Number = Union[int, float] RT = TypeVar("RT") @@ -317,42 +311,3 @@ def _error_message_add_element() -> None: "write_labels(), write_points(), write_shapes() and write_table(). We are going to make these calls more " "ergonomic in a follow up PR." ) - - -def _assert_elements_left_to_right_seem_identical(sdata0: SpatialData, sdata1: SpatialData) -> None: - for element_type, element_name, element in sdata0._gen_elements(): - elements = sdata1.__getattribute__(element_type) - assert element_name in elements - element1 = elements[element_name] - if isinstance(element, (AnnData, SpatialImage, GeoDataFrame)): - assert element.shape == element1.shape - elif isinstance(element, DaskDataFrame): - for s0, s1 in zip(element.shape, element1.shape): - if isinstance(s0, Delayed): - s0 = s0.compute() - if isinstance(s1, Delayed): - s1 = s1.compute() - assert s0 == s1 - elif isinstance(element, MultiscaleSpatialImage): - assert len(element) == len(element1) - else: - raise TypeError(f"Unsupported type {type(element)}") - - -def _assert_tables_seem_identical(sdata0: SpatialData, sdata1: SpatialData) -> None: - tables0 = sdata0.tables - tables1 = sdata1.tables - assert set(tables0.keys()) == set(tables1.keys()) - for k in tables0: - t0 = tables0[k] - t1 = tables1[k] - assert_equal(t0, t1) - - -def _assert_spatialdata_objects_seem_identical(sdata0: SpatialData, sdata1: SpatialData) -> None: - # this is not a full comparison, but it's fine anyway - assert len(list(sdata0._gen_elements())) == len(list(sdata1._gen_elements())) - assert set(sdata0.coordinate_systems) == set(sdata1.coordinate_systems) - _assert_elements_left_to_right_seem_identical(sdata0, sdata1) - _assert_elements_left_to_right_seem_identical(sdata1, sdata0) - _assert_tables_seem_identical(sdata0, sdata1) diff --git a/src/spatialdata/testing.py b/src/spatialdata/testing.py new file mode 100644 index 00000000..16f155bd --- /dev/null +++ b/src/spatialdata/testing.py @@ -0,0 +1,156 @@ +from __future__ import annotations + +from anndata import AnnData +from anndata.tests.helpers import assert_equal as assert_anndata_equal +from dask.dataframe import DataFrame as DaskDataFrame +from dask.dataframe.tests.test_dataframe
import assert_eq as assert_dask_dataframe_equal +from datatree.testing import assert_equal as assert_datatree_equal +from geopandas import GeoDataFrame +from geopandas.testing import assert_geodataframe_equal +from multiscale_spatial_image import MultiscaleSpatialImage +from spatial_image import SpatialImage +from xarray.testing import assert_equal as assert_xarray_equal + +from spatialdata import SpatialData +from spatialdata._core._elements import Elements +from spatialdata.models._utils import SpatialElement +from spatialdata.transformations.operations import get_transformation + + +def assert_elements_dict_are_identical( + elements0: Elements, elements1: Elements, check_transformations: bool = True +) -> None: + """ + Compare two dictionaries of elements and assert that they are identical (except for the order of the keys). + + The dictionaries of elements can be obtained from a SpatialData object using the `.shapes`, `.labels`, `.points`, + `.images` and `.tables` properties. + + Parameters + ---------- + elements0 + The first dictionary of elements. + elements1 + The second dictionary of elements. + + Returns + ------- + None + + Raises + ------ + AssertionError + If the two dictionaries of elements are not identical. + + Notes + ----- + With the current implementation, the transformations Translate([1.0, 2.0], + axes=('x', 'y')) and Translate([2.0, 1.0], axes=('y', 'x')) are considered different. + A quick way to avoid an error in this case is to use the check_transformations=False parameter. + """ + assert set(elements0.keys()) == set(elements1.keys()) + for k in elements0: + element0 = elements0[k] + element1 = elements1[k] + assert_elements_are_identical(element0, element1, check_transformations=check_transformations) + + +def assert_elements_are_identical( + element0: SpatialElement | AnnData, element1: SpatialElement | AnnData, check_transformations: bool = True +) -> None: + """ + Compare two elements (two SpatialElements or two tables) and assert that they are identical. + + Parameters + ---------- + element0 + The first element. + element1 + The second element. + check_transformations + Whether to check if the transformations are identical, for each element. + + Returns + ------- + None + + Raises + ------ + AssertionError + If the two elements are not identical. + + Notes + ----- + With the current implementation, the transformations Translate([1.0, 2.0], + axes=('x', 'y')) and Translate([2.0, 1.0], axes=('y', 'x')) are considered different. + A quick way to avoid an error in this case is to use the check_transformations=False parameter. 
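+ +    Examples +    -------- +    A minimal sketch; the element name "circles" is hypothetical and both objects are assumed to contain it: + +    >>> from spatialdata.testing import assert_elements_are_identical +    >>> assert_elements_are_identical(sdata0["circles"], sdata1["circles"])  # doctest: +SKIP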
+ """ + assert type(element0) == type(element1) + + # compare transformations (only for SpatialElements) + if not isinstance(element0, AnnData): + transformations0 = get_transformation(element0, get_all=True) + transformations1 = get_transformation(element1, get_all=True) + assert isinstance(transformations0, dict) + assert isinstance(transformations1, dict) + if check_transformations: + assert transformations0.keys() == transformations1.keys() + for key in transformations0: + assert ( + transformations0[key] == transformations1[key] + ), f"transformations0[{key}] != transformations1[{key}]" + + # compare the elements + if isinstance(element0, AnnData): + assert_anndata_equal(element0, element1) + elif isinstance(element0, SpatialImage): + assert_xarray_equal(element0, element1) + elif isinstance(element0, MultiscaleSpatialImage): + assert_datatree_equal(element0, element1) + elif isinstance(element0, GeoDataFrame): + assert_geodataframe_equal(element0, element1, check_less_precise=True) + else: + assert isinstance(element0, DaskDataFrame) + assert_dask_dataframe_equal(element0, element1) + + +def assert_spatial_data_objects_are_identical( + sdata0: SpatialData, sdata1: SpatialData, check_transformations: bool = True +) -> None: + """ + Compare two SpatialData objects and assert that they are identical. + + Parameters + ---------- + sdata0 + The first SpatialData object. + sdata1 + The second SpatialData object. + check_transformations + Whether to check if the transformations are identical, for each element. + + Returns + ------- + None + + Raises + ------ + AssertionError + If the two SpatialData objects are not identical. + + Notes + ----- + With the current implementation, the transformations Translate([1.0, 2.0], + axes=('x', 'y')) and Translate([2.0, 1.0], axes=('y', 'x')) are considered different. + A quick way to avoid an error in this case is to use the check_transformations=False parameter. 
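+ +    Examples +    -------- +    For instance, as a write/read round-trip check (the Zarr path is hypothetical): + +    >>> sdata.write("data.zarr")  # doctest: +SKIP +    >>> assert_spatial_data_objects_are_identical(sdata, SpatialData.read("data.zarr"))  # doctest: +SKIP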
+ """ + # this is not a full comparison, but it's fine anyway + element_names0 = [element_name for _, element_name, _ in sdata0.gen_elements()] + element_names1 = [element_name for _, element_name, _ in sdata1.gen_elements()] + assert len(set(element_names0)) == len(element_names0) + assert len(set(element_names1)) == len(element_names1) + assert set(sdata0.coordinate_systems) == set(sdata1.coordinate_systems) + for element_name in element_names0: + element0 = sdata0[element_name] + element1 = sdata1[element_name] + assert_elements_are_identical(element0, element1, check_transformations=check_transformations) diff --git a/tests/core/operations/test_spatialdata_operations.py b/tests/core/operations/test_spatialdata_operations.py index 43b648ef..c7fcf3db 100644 --- a/tests/core/operations/test_spatialdata_operations.py +++ b/tests/core/operations/test_spatialdata_operations.py @@ -10,9 +10,9 @@ from spatialdata._core.data_extent import are_extents_equal, get_extent from spatialdata._core.operations._utils import transform_to_data_extent from spatialdata._core.spatialdata import SpatialData -from spatialdata._utils import _assert_spatialdata_objects_seem_identical, _assert_tables_seem_identical from spatialdata.datasets import blobs from spatialdata.models import Image2DModel, Labels2DModel, PointsModel, ShapesModel, TableModel, get_table_keys +from spatialdata.testing import assert_elements_dict_are_identical, assert_spatial_data_objects_are_identical from spatialdata.transformations.operations import get_transformation, set_transformation from spatialdata.transformations.transformations import ( Affine, @@ -109,7 +109,7 @@ def test_element_names_unique() -> None: def test_filter_by_coordinate_system(full_sdata: SpatialData) -> None: sdata = full_sdata.filter_by_coordinate_system(coordinate_system="global", filter_table=False) - _assert_spatialdata_objects_seem_identical(sdata, full_sdata) + assert_spatial_data_objects_are_identical(sdata, full_sdata) scale = Scale([2.0], axes=("x",)) set_transformation(full_sdata.images["image2d"], scale, "my_space0") @@ -118,7 +118,7 @@ def test_filter_by_coordinate_system(full_sdata: SpatialData) -> None: sdata_my_space = full_sdata.filter_by_coordinate_system(coordinate_system="my_space0", filter_table=False) assert len(list(sdata_my_space.gen_elements())) == 3 - _assert_tables_seem_identical(sdata_my_space, full_sdata) + assert_elements_dict_are_identical(sdata_my_space.tables, full_sdata.tables) sdata_my_space1 = full_sdata.filter_by_coordinate_system( coordinate_system=["my_space0", "my_space1", "my_space2"], filter_table=False diff --git a/tests/core/operations/test_transform.py b/tests/core/operations/test_transform.py index e53515fd..5ad61dec 100644 --- a/tests/core/operations/test_transform.py +++ b/tests/core/operations/test_transform.py @@ -94,8 +94,7 @@ def test_physical_units(self, tmp_path: str, shapes: SpatialData) -> None: assert new_sdata.coordinate_systems["test"]._axes[0].unit == "micrometers" -def _get_affine(small_translation: bool = True) -> Affine: - theta = math.pi / 18 +def _get_affine(small_translation: bool = True, theta: float = math.pi / 18) -> Affine: k = 10.0 if small_translation else 1.0 return Affine( [ diff --git a/tests/core/operations/test_vectorize.py b/tests/core/operations/test_vectorize.py new file mode 100644 index 00000000..a0a7306e --- /dev/null +++ b/tests/core/operations/test_vectorize.py @@ -0,0 +1,68 @@ +import numpy as np +import pytest +from geopandas import GeoDataFrame +from shapely import Point +from 
spatialdata._core.operations.vectorize import to_circles +from spatialdata.datasets import blobs +from spatialdata.models.models import ShapesModel +from spatialdata.testing import assert_elements_are_identical + +# each of the tests operates on different elements, hence we can initialize the data once without conflicts +sdata = blobs() + + +@pytest.mark.parametrize("is_multiscale", [False, True]) +def test_labels_2d_to_circles(is_multiscale: bool) -> None: + key = "blobs" + ("_multiscale" if is_multiscale else "") + "_labels" + element = sdata[key] + new_circles = to_circles(element) + + assert np.isclose(new_circles.loc[1].geometry.x, 330.59258152354386) + assert np.isclose(new_circles.loc[1].geometry.y, 78.85026897788404) + assert np.isclose(new_circles.loc[1].radius, 69.229993) + assert 7 not in new_circles.index + + +@pytest.mark.skip(reason="Not implemented") +# @pytest.mark.parametrize("background", [0, 1]) +# @pytest.mark.parametrize("is_multiscale", [False, True]) +def test_labels_3d_to_circles() -> None: + pass + + +def test_circles_to_circles() -> None: + element = sdata["blobs_circles"] + new_circles = to_circles(element) + assert_elements_are_identical(element, new_circles) + + +def test_polygons_to_circles() -> None: + element = sdata["blobs_polygons"].iloc[:2] + new_circles = to_circles(element) + + data = { + "geometry": [Point(315.8120722406787, 220.18894606643332), Point(270.1386975678398, 417.8747936281634)], + "radius": [16.608781, 17.541365], + } + expected = ShapesModel.parse(GeoDataFrame(data, geometry="geometry")) + + assert_elements_are_identical(new_circles, expected) + + +def test_multipolygons_to_circles() -> None: + element = sdata["blobs_multipolygons"] + new_circles = to_circles(element) + + data = { + "geometry": [Point(340.37951022629096, 250.76310705786318), Point(337.1680699150594, 316.39984581697314)], + "radius": [23.488363, 19.059285], + } + expected = ShapesModel.parse(GeoDataFrame(data, geometry="geometry")) + assert_elements_are_identical(new_circles, expected) + + +def test_points_images_to_circles() -> None: + with pytest.raises(RuntimeError, match=r"Cannot apply to_circles\(\) to images."): + to_circles(sdata["blobs_image"]) + with pytest.raises(RuntimeError, match="Unsupported type"): + to_circles(sdata["blobs_points"]) diff --git a/tests/core/query/test_spatial_query.py b/tests/core/query/test_spatial_query.py index 6ca91d6a..a043cab0 100644 --- a/tests/core/query/test_spatial_query.py +++ b/tests/core/query/test_spatial_query.py @@ -18,7 +18,6 @@ polygon_query, ) from spatialdata._core.spatialdata import SpatialData -from spatialdata._utils import _assert_spatialdata_objects_seem_identical from spatialdata.models import ( Image2DModel, Image3DModel, @@ -28,6 +27,7 @@ ShapesModel, TableModel, ) +from spatialdata.testing import assert_spatial_data_objects_are_identical from spatialdata.transformations import Identity, set_transformation from tests.conftest import _make_points, _make_squares @@ -356,15 +356,15 @@ def test_query_spatial_data(full_sdata): result1 = full_sdata.query(request, filter_table=True) result2 = full_sdata.query.bounding_box(**request.to_dict(), filter_table=True) - _assert_spatialdata_objects_seem_identical(result0, result1) - _assert_spatialdata_objects_seem_identical(result0, result2) + assert_spatial_data_objects_are_identical(result0, result1) + assert_spatial_data_objects_are_identical(result0, result2) polygon = Polygon([(1, 2), (60, 2), (60, 40), (1, 40)]) result3 = polygon_query(full_sdata, polygon=polygon, 
target_coordinate_system="global", filter_table=True) result4 = full_sdata.query.polygon(polygon=polygon, target_coordinate_system="global", filter_table=True) - _assert_spatialdata_objects_seem_identical(result0, result3) - _assert_spatialdata_objects_seem_identical(result0, result4) + assert_spatial_data_objects_are_identical(result0, result3, check_transformations=False) + assert_spatial_data_objects_are_identical(result0, result4, check_transformations=False) @pytest.mark.parametrize("with_polygon_query", [True, False]) From 0c1e339dc8756dd94effc8685d89809becbcf873 Mon Sep 17 00:00:00 2001 From: LucaMarconato <2664412+LucaMarconato@users.noreply.github.com> Date: Thu, 14 Mar 2024 00:06:47 +0100 Subject: [PATCH 21/27] deedcopy() utils function (#480) * deedcopy() utils function * fixed missings seeds for default_rng() * wip fix * wip fix * fix bug due to data being computed in-place and then failing validation * add pooch requirement * Update src/spatialdata/_core/_deepcopy.py * Update src/spatialdata/_core/_deepcopy.py * Update src/spatialdata/_core/_deepcopy.py * Update src/spatialdata/_core/_deepcopy.py --------- Co-authored-by: Wouter-Michiel Vierdag --- CHANGELOG.md | 1 + docs/api.md | 1 + pyproject.toml | 3 +- src/spatialdata/__init__.py | 2 + src/spatialdata/_core/_deepcopy.py | 100 ++++++++++++++++++ src/spatialdata/_core/operations/aggregate.py | 4 +- src/spatialdata/_core/spatialdata.py | 3 +- src/spatialdata/_utils.py | 25 +---- tests/conftest.py | 6 +- tests/core/operations/test_aggregations.py | 6 +- .../operations/test_spatialdata_operations.py | 2 +- tests/core/test_data_extent.py | 8 +- tests/core/test_deepcopy.py | 21 ++++ tests/dataloader/test_datasets.py | 2 +- tests/io/test_multi_table.py | 2 +- tests/io/test_readwrite.py | 2 +- 16 files changed, 147 insertions(+), 41 deletions(-) create mode 100644 src/spatialdata/_core/_deepcopy.py create mode 100644 tests/core/test_deepcopy.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 768fc6df..d7260e43 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -52,6 +52,7 @@ and this project adheres to [Semantic Versioning][]. - added utils function: postpone_transformation() - added utils function: remove_transformations_to_coordinate_system() - added utils function: get_centroids() +- added utils function: deepcopy() - added operation: to_circles() - added testing utilities: assert_spatial_data_objects_are_identical(), assert_elements_are_identical(), assert_elements_dict_are_identical() diff --git a/docs/api.md b/docs/api.md index b0bb586d..c3e39478 100644 --- a/docs/api.md +++ b/docs/api.md @@ -47,6 +47,7 @@ Operations on `SpatialData` objects. 
unpad_raster are_extents_equal + deepcopy ``` ## Models diff --git a/pyproject.toml b/pyproject.toml index df9c93f8..8f020ada 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,7 +41,8 @@ dependencies = [ "xarray-spatial>=0.3.5", "tqdm", "fsspec<=2023.6", - "dask<=2024.2.1" + "dask<=2024.2.1", + "pooch", ] [project.optional-dependencies] diff --git a/src/spatialdata/__init__.py b/src/spatialdata/__init__.py index a3afe8e3..d899a6ec 100644 --- a/src/spatialdata/__init__.py +++ b/src/spatialdata/__init__.py @@ -34,9 +34,11 @@ "save_transformations", "get_dask_backing_files", "are_extents_equal", + "deepcopy", ] from spatialdata import dataloader, models, transformations +from spatialdata._core._deepcopy import deepcopy from spatialdata._core.centroids import get_centroids from spatialdata._core.concatenate import concatenate from spatialdata._core.data_extent import are_extents_equal, get_extent diff --git a/src/spatialdata/_core/_deepcopy.py b/src/spatialdata/_core/_deepcopy.py new file mode 100644 index 00000000..ff3c9570 --- /dev/null +++ b/src/spatialdata/_core/_deepcopy.py @@ -0,0 +1,100 @@ +from __future__ import annotations + +from copy import deepcopy as _deepcopy +from functools import singledispatch + +from anndata import AnnData +from dask.array.core import Array as DaskArray +from dask.array.core import from_array +from dask.dataframe.core import DataFrame as DaskDataFrame +from geopandas import GeoDataFrame +from multiscale_spatial_image import MultiscaleSpatialImage +from spatial_image import SpatialImage + +from spatialdata._core.spatialdata import SpatialData +from spatialdata._utils import multiscale_spatial_image_from_data_tree +from spatialdata.models._utils import SpatialElement +from spatialdata.models.models import Image2DModel, Image3DModel, Labels2DModel, Labels3DModel, PointsModel, get_model + + +@singledispatch +def deepcopy(element: SpatialData | SpatialElement | AnnData) -> SpatialData | SpatialElement | AnnData: + """ + Deepcopy a SpatialData or SpatialElement object. + + Deepcopy will load the data in memory. Using this function for large Dask-backed objects is discouraged. In that + case, please save the SpatialData object to a different disk location and read it back again. + + Parameters + ---------- + element + The SpatialData or SpatialElement object to deepcopy + + Returns + ------- + A deepcopy of the SpatialData or SpatialElement object + """ + raise RuntimeError(f"Wrong type for deepcopy: {type(element)}") + + +# In the implementations below, when the data is loaded from Dask, we first use compute() and then we deepcopy the data. +# This leads to double copying the data, but since we expect the data to be small, this is acceptable. 
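+# A minimal usage sketch (assuming a small, in-memory SpatialData object `sdata`; the element name "labels" is +# hypothetical): +# +#     from spatialdata import deepcopy +#     sdata_copy = deepcopy(sdata)  # copies the whole object, element by element +#     labels_copy = deepcopy(sdata["labels"])  # copies a single element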
+@deepcopy.register(SpatialData) +def _(sdata: SpatialData) -> SpatialData: + elements_dict = {} + for _, element_name, element in sdata.gen_elements(): + elements_dict[element_name] = deepcopy(element) + return SpatialData.from_elements_dict(elements_dict) + + +@deepcopy.register(SpatialImage) +def _(element: SpatialImage) -> SpatialImage: + model = get_model(element) + if isinstance(element.data, DaskArray): + element = element.compute() + if model in [Image2DModel, Image3DModel]: + return model.parse(element.copy(deep=True), c_coords=element["c"]) # type: ignore[call-arg] + assert model in [Labels2DModel, Labels3DModel] + return model.parse(element.copy(deep=True)) + + +@deepcopy.register(MultiscaleSpatialImage) +def _(element: MultiscaleSpatialImage) -> MultiscaleSpatialImage: + # the complexity here is due to the fact that the parsers don't accept MultiscaleSpatialImage types and that we need + # to convert the DataTree to a MultiscaleSpatialImage. This will be simplified once we support + # multiscale_spatial_image 1.0.0 + model = get_model(element) + for key in element: + ds = element[key].ds + assert len(ds) == 1 + variable = ds.__iter__().__next__() + if isinstance(element[key][variable].data, DaskArray): + element[key][variable] = element[key][variable].compute() + msi = multiscale_spatial_image_from_data_tree(element.copy(deep=True)) + for key in msi: + ds = msi[key].ds + variable = ds.__iter__().__next__() + msi[key][variable].data = from_array(msi[key][variable].data) + element[key][variable].data = from_array(element[key][variable].data) + assert model in [Image2DModel, Image3DModel, Labels2DModel, Labels3DModel] + model().validate(msi) + return msi + + +@deepcopy.register(GeoDataFrame) +def _(gdf: GeoDataFrame) -> GeoDataFrame: + new_gdf = _deepcopy(gdf) + # temporary fix for https://github.com/scverse/spatialdata/issues/286. + new_attrs = _deepcopy(gdf.attrs) + new_gdf.attrs = new_attrs + return new_gdf + + +@deepcopy.register(DaskDataFrame) +def _(df: DaskDataFrame) -> DaskDataFrame: + return PointsModel.parse(df.compute().copy(deep=True)) + + +@deepcopy.register(AnnData) +def _(adata: AnnData) -> AnnData: + return adata.copy() diff --git a/src/spatialdata/_core/operations/aggregate.py b/src/spatialdata/_core/operations/aggregate.py index 26993e62..46f32b50 100644 --- a/src/spatialdata/_core/operations/aggregate.py +++ b/src/spatialdata/_core/operations/aggregate.py @@ -241,7 +241,7 @@ def _create_sdata_from_table_and_shapes( instance_key: str, deepcopy: bool, ) -> SpatialData: - from spatialdata._utils import _deepcopy_geodataframe + from spatialdata._core._deepcopy import deepcopy as _deepcopy table.obs[instance_key] = table.obs_names.copy() table.obs[region_key] = shapes_name @@ -252,7 +252,7 @@ def _create_sdata_from_table_and_shapes( table.obs[instance_key] = table.obs[instance_key].astype(int) if deepcopy: - shapes = _deepcopy_geodataframe(shapes) + shapes = _deepcopy(shapes) return SpatialData.from_elements_dict({shapes_name: shapes, table_name: table}) diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py index ed0f6a27..b4644c4c 100644 --- a/src/spatialdata/_core/spatialdata.py +++ b/src/spatialdata/_core/spatialdata.py @@ -1477,7 +1477,8 @@ def gen_spatial_elements(self) -> Generator[tuple[str, str, SpatialElement], Non Returns ------- - A generator that yields tuples containing the name, description, and SpatialElement objects themselves. 
+        A generator that yields tuples containing the element_type (string), name, and SpatialElement objects
+        themselves.
         """
         return self._gen_elements()

diff --git a/src/spatialdata/_utils.py b/src/spatialdata/_utils.py
index 205308e8..0e59d63d 100644
--- a/src/spatialdata/_utils.py
+++ b/src/spatialdata/_utils.py
@@ -4,7 +4,6 @@
 import re
 import warnings
 from collections.abc import Generator
-from copy import deepcopy
 from typing import Any, Callable, TypeVar, Union

 import numpy as np
@@ -12,7 +11,6 @@
 from anndata import AnnData
 from dask import array as da
 from datatree import DataTree
-from geopandas import GeoDataFrame
 from multiscale_spatial_image import MultiscaleSpatialImage
 from spatial_image import SpatialImage
 from xarray import DataArray
@@ -162,7 +160,10 @@ def multiscale_spatial_image_from_data_tree(data_tree: DataTree) -> MultiscaleSp
         assert len(v) == 1
         xdata = v.__iter__().__next__()
         d[k] = xdata
+    # this stopped working; we should add support for multiscale_spatial_image 1.0.0 so that the problem is solved
     return MultiscaleSpatialImage.from_dict(d)
+    # data_tree.__class__ = MultiscaleSpatialImage
+    # return cast(MultiscaleSpatialImage, data_tree)


 # TODO: this functions is similar to _iter_multiscale(), the latter is more powerful but not exposed to the user.
@@ -213,26 +214,6 @@ def _inplace_fix_subset_categorical_obs(subset_adata: AnnData, original_adata: A
     subset_adata.obs = obs


-def _deepcopy_geodataframe(gdf: GeoDataFrame) -> GeoDataFrame:
-    """
-    temporary fix for https://github.com/scverse/spatialdata/issues/286.
-
-    Parameters
-    ----------
-    gdf
-        The GeoDataFrame to deepcopy
-
-    Returns
-    -------
-    A deepcopy of the GeoDataFrame
-    """
-    #
-    new_gdf = deepcopy(gdf)
-    new_attrs = deepcopy(gdf.attrs)
-    new_gdf.attrs = new_attrs
-    return new_gdf
-
-
 # TODO: change to paramspec as soon as we drop support for python 3.9, see https://stackoverflow.com/a/68290080
 def deprecation_alias(**aliases: str) -> Callable[[Callable[..., RT]], Callable[..., RT]]:
     """

diff --git a/tests/conftest.py b/tests/conftest.py
index 53913150..66f65e0c 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -35,7 +35,7 @@
 from spatialdata.datasets import BlobsDataset
 import geopandas as gpd
 import dask.dataframe as dd
-from spatialdata._utils import _deepcopy_geodataframe
+from spatialdata._core._deepcopy import deepcopy as _deepcopy

 RNG = default_rng(seed=0)

@@ -295,7 +295,7 @@ def _get_table(

 def _get_new_table(spatial_element: None | str | Sequence[str], instance_id: None | Sequence[Any]) -> AnnData:
-    adata = AnnData(np.random.default_rng().random(10, 20000))
+    adata = AnnData(np.random.default_rng(seed=0).random((10, 20000)))
     return TableModel.parse(adata=adata, spatial_element=spatial_element, instance_id=instance_id)


@@ -313,7 +313,7 @@ def sdata_blobs() -> SpatialData:
     sdata = deepcopy(blobs(256, 300, 3))

     for k, v in sdata.shapes.items():
-        sdata.shapes[k] = _deepcopy_geodataframe(v)
+        sdata.shapes[k] = _deepcopy(v)

     from spatialdata._utils import multiscale_spatial_image_from_data_tree

     sdata.images["blobs_multiscale_image"] = multiscale_spatial_image_from_data_tree(

diff --git a/tests/core/operations/test_aggregations.py b/tests/core/operations/test_aggregations.py
index 09283dca..fcbf89b9 100644
--- a/tests/core/operations/test_aggregations.py
+++ b/tests/core/operations/test_aggregations.py
@@ -1,4 +1,3 @@
-from copy import deepcopy
 from typing import Optional

 import geopandas
@@ -10,9 +9,9 @@
 from geopandas import GeoDataFrame
 from numpy.random import default_rng
 from spatialdata import
aggregate +from spatialdata._core._deepcopy import deepcopy as _deepcopy from spatialdata._core.query._utils import circles_to_polygons from spatialdata._core.spatialdata import SpatialData -from spatialdata._utils import _deepcopy_geodataframe from spatialdata.models import Image2DModel, Labels2DModel, PointsModel, TableModel from spatialdata.transformations import Affine, Identity, set_transformation @@ -362,8 +361,7 @@ def test_aggregate_requiring_alignment(sdata_blobs: SpatialData, values, by) -> by = sdata_blobs[by] if id(values) == id(by): # warning: this will give problems when aggregation labels by labels (not supported yet), because of this: https://github.com/scverse/spatialdata/issues/269 - by = deepcopy(by) - by = _deepcopy_geodataframe(by) + by = _deepcopy(by) assert by.attrs["transform"] is not values.attrs["transform"] sdata = SpatialData.init_from_elements({"values": values, "by": by}) diff --git a/tests/core/operations/test_spatialdata_operations.py b/tests/core/operations/test_spatialdata_operations.py index c7fcf3db..0b39ddce 100644 --- a/tests/core/operations/test_spatialdata_operations.py +++ b/tests/core/operations/test_spatialdata_operations.py @@ -410,7 +410,7 @@ def test_transform_to_data_extent(full_sdata: SpatialData, maintain_positioning: data_extent_after = get_extent(after, coordinate_system="global") # huge tolerance because of the bug with pixel perfectness assert are_extents_equal( - data_extent_before, data_extent_after, atol=3 + data_extent_before, data_extent_after, atol=4 ), f"data_extent_before: {data_extent_before}, data_extent_after: {data_extent_after} for element {element}" diff --git a/tests/core/test_data_extent.py b/tests/core/test_data_extent.py index d7304ddf..94a1216f 100644 --- a/tests/core/test_data_extent.py +++ b/tests/core/test_data_extent.py @@ -7,7 +7,7 @@ from numpy.random import default_rng from shapely.geometry import MultiPolygon, Point, Polygon from spatialdata import SpatialData, get_extent, transform -from spatialdata._utils import _deepcopy_geodataframe +from spatialdata._core._deepcopy import deepcopy as _deepcopy from spatialdata.datasets import blobs from spatialdata.models import Image2DModel, PointsModel, ShapesModel from spatialdata.transformations import Affine, Translation, remove_transformation, set_transformation @@ -237,7 +237,7 @@ def test_get_extent_affine_circles(): affine = _get_affine(small_translation=True) # let's do a deepcopy of the circles since we don't want to modify the original data - circles = _deepcopy_geodataframe(sdata["blobs_circles"]) + circles = _deepcopy(sdata["blobs_circles"]) set_transformation(element=circles, transformation=affine, to_coordinate_system="transformed") @@ -304,8 +304,8 @@ def test_get_extent_affine_sdata(): # let's make a copy since we don't want to modify the original data sdata2 = SpatialData( shapes={ - "circles": _deepcopy_geodataframe(sdata["blobs_circles"]), - "polygons": _deepcopy_geodataframe(sdata["blobs_polygons"]), + "circles": _deepcopy(sdata["blobs_circles"]), + "polygons": _deepcopy(sdata["blobs_polygons"]), } ) translation0 = Translation([10], axes=("x",)) diff --git a/tests/core/test_deepcopy.py b/tests/core/test_deepcopy.py new file mode 100644 index 00000000..57a9adb7 --- /dev/null +++ b/tests/core/test_deepcopy.py @@ -0,0 +1,21 @@ +from spatialdata._core._deepcopy import deepcopy as _deepcopy +from spatialdata.testing import assert_spatial_data_objects_are_identical + + +def test_deepcopy(full_sdata): + to_delete = [] + for element_type, element_name in 
to_delete: + del getattr(full_sdata, element_type)[element_name] + + copied = _deepcopy(full_sdata) + # we first compute() the data in-place, then deepcopy and then we make the data lazy again; if the last step is + # missing, calling _deepcopy() again on the original data would fail. Here we check for that. + copied_again = _deepcopy(full_sdata) + + assert_spatial_data_objects_are_identical(full_sdata, copied) + assert_spatial_data_objects_are_identical(full_sdata, copied_again) + + for _, element_name, _ in full_sdata.gen_elements(): + assert full_sdata[element_name] is not copied[element_name] + assert full_sdata[element_name] is not copied_again[element_name] + assert copied[element_name] is not copied_again[element_name] diff --git a/tests/dataloader/test_datasets.py b/tests/dataloader/test_datasets.py index dac01e80..9c8e7c60 100644 --- a/tests/dataloader/test_datasets.py +++ b/tests/dataloader/test_datasets.py @@ -106,7 +106,7 @@ def test_return_annot(self, sdata_blobs, regions_element, return_annot): # TODO: consider adding this logic to blobs, to generate blobs with arbitrary table annotation def _annotate_shapes(self, sdata: SpatialData, shape: str) -> SpatialData: new_table = AnnData( - X=np.random.default_rng().random((len(sdata[shape]), 10)), + X=np.random.default_rng(0).random((len(sdata[shape]), 10)), obs=pd.DataFrame({"region": shape, "instance_id": sdata[shape].index.values}), ) new_table = TableModel.parse(new_table, region=shape, region_key="region", instance_key="instance_id") diff --git a/tests/io/test_multi_table.py b/tests/io/test_multi_table.py index d1bd21e0..d2382941 100644 --- a/tests/io/test_multi_table.py +++ b/tests/io/test_multi_table.py @@ -12,7 +12,7 @@ test_shapes = _get_shapes() # shuffle the indices of the dataframe -# np.random.default_rng().shuffle(test_shapes["poly"].index) +# np.random.default_rng(0).shuffle(test_shapes["poly"].index) class TestMultiTable: diff --git a/tests/io/test_readwrite.py b/tests/io/test_readwrite.py index e629182d..fcbb87ae 100644 --- a/tests/io/test_readwrite.py +++ b/tests/io/test_readwrite.py @@ -23,7 +23,7 @@ from tests.conftest import _get_images, _get_labels, _get_points, _get_shapes -RNG = default_rng() +RNG = default_rng(0) class TestReadWrite: From 9c20ee617c51eeb8f5027ebea1a7fd20ecbc3118 Mon Sep 17 00:00:00 2001 From: Luca Marconato Date: Thu, 14 Mar 2024 13:59:48 +0100 Subject: [PATCH 22/27] fix bug deepcopy() of points wrong columns order --- src/spatialdata/models/models.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/spatialdata/models/models.py b/src/spatialdata/models/models.py index f84bf4d5..527c13ea 100644 --- a/src/spatialdata/models/models.py +++ b/src/spatialdata/models/models.py @@ -609,6 +609,12 @@ def _( logger.info(f"Column `{Z}` in `data` will be ignored since the data is 2D.") for c in set(data.columns) - {feature_key, instance_key, *coordinates.values(), X, Y, Z}: table[c] = data[c] + + # reorder the columns to respect the original order + old_columns = list(data.columns) + col_order = [col for col in old_columns if col in data.columns] + table = table[col_order] + return cls._add_metadata_and_validate( table, feature_key=feature_key, instance_key=instance_key, transformations=transformations ) From b6897a88ed12cc7661e916f971f5a1450da5acb2 Mon Sep 17 00:00:00 2001 From: Luca Marconato Date: Thu, 14 Mar 2024 14:36:14 +0100 Subject: [PATCH 23/27] workaround wrong order points columns after deepcopy --- src/spatialdata/_core/_deepcopy.py | 5 +++++ src/spatialdata/models/models.py | 16 
+++++++++++---- tests/core/test_deepcopy.py | 31 +++++++++++++++++++++++++++--- 3 files changed, 45 insertions(+), 7 deletions(-) diff --git a/src/spatialdata/_core/_deepcopy.py b/src/spatialdata/_core/_deepcopy.py index ff3c9570..e3634df4 100644 --- a/src/spatialdata/_core/_deepcopy.py +++ b/src/spatialdata/_core/_deepcopy.py @@ -33,6 +33,11 @@ def deepcopy(element: SpatialData | SpatialElement | AnnData) -> SpatialData | S Returns ------- A deepcopy of the SpatialData or SpatialElement object + + Notes + ----- + The order of the columns for a deepcopied points element may be differ from the original one, please see more here: + https://github.com/scverse/spatialdata/issues/486 """ raise RuntimeError(f"Wrong type for deepcopy: {type(element)}") diff --git a/src/spatialdata/models/models.py b/src/spatialdata/models/models.py index 527c13ea..b84bb4cd 100644 --- a/src/spatialdata/models/models.py +++ b/src/spatialdata/models/models.py @@ -519,6 +519,11 @@ def parse(cls, data: Any, **kwargs: Any) -> DaskDataFrame: Returns ------- :class:`dask.dataframe.core.DataFrame` + + Notes + ----- + The order of the columns of the dataframe returned by the parser is not guaranteed to be the same as the order + of the columns in the dataframe passed as an argument. """ raise NotImplementedError() @@ -610,10 +615,13 @@ def _( for c in set(data.columns) - {feature_key, instance_key, *coordinates.values(), X, Y, Z}: table[c] = data[c] - # reorder the columns to respect the original order - old_columns = list(data.columns) - col_order = [col for col in old_columns if col in data.columns] - table = table[col_order] + # when `coordinates` is None, and no columns have been added or removed, preserves the original order + # here I tried to fix https://github.com/scverse/spatialdata/issues/486, didn't work + # old_columns = list(data.columns) + # new_columns = list(table.columns) + # if new_columns == set(old_columns) and new_columns != old_columns: + # col_order = [col for col in old_columns if col in new_columns] + # table = table[col_order] return cls._add_metadata_and_validate( table, feature_key=feature_key, instance_key=instance_key, transformations=transformations diff --git a/tests/core/test_deepcopy.py b/tests/core/test_deepcopy.py index 57a9adb7..7c3bcae5 100644 --- a/tests/core/test_deepcopy.py +++ b/tests/core/test_deepcopy.py @@ -1,3 +1,4 @@ +from pandas.testing import assert_frame_equal from spatialdata._core._deepcopy import deepcopy as _deepcopy from spatialdata.testing import assert_spatial_data_objects_are_identical @@ -12,10 +13,34 @@ def test_deepcopy(full_sdata): # missing, calling _deepcopy() again on the original data would fail. Here we check for that. 
     copied_again = _deepcopy(full_sdata)

-    assert_spatial_data_objects_are_identical(full_sdata, copied)
-    assert_spatial_data_objects_are_identical(full_sdata, copied_again)
-
+    # workaround for https://github.com/scverse/spatialdata/issues/486
     for _, element_name, _ in full_sdata.gen_elements():
         assert full_sdata[element_name] is not copied[element_name]
         assert full_sdata[element_name] is not copied_again[element_name]
         assert copied[element_name] is not copied_again[element_name]
+
+    p0_0 = full_sdata["points_0"].compute()
+    columns = list(p0_0.columns)
+    p0_1 = full_sdata["points_0_1"].compute()[columns]
+
+    p1_0 = copied["points_0"].compute()[columns]
+    p1_1 = copied["points_0_1"].compute()[columns]
+
+    p2_0 = copied_again["points_0"].compute()[columns]
+    p2_1 = copied_again["points_0_1"].compute()[columns]
+
+    assert_frame_equal(p0_0, p1_0)
+    assert_frame_equal(p0_1, p1_1)
+    assert_frame_equal(p0_0, p2_0)
+    assert_frame_equal(p0_1, p2_1)
+
+    del full_sdata.points["points_0"]
+    del full_sdata.points["points_0_1"]
+    del copied.points["points_0"]
+    del copied.points["points_0_1"]
+    del copied_again.points["points_0"]
+    del copied_again.points["points_0_1"]
+    # end workaround
+
+    assert_spatial_data_objects_are_identical(full_sdata, copied)
+    assert_spatial_data_objects_are_identical(full_sdata, copied_again)

From 09e339e765c36ff147b9fe62f748bf0bb66bd3aa Mon Sep 17 00:00:00 2001
From: LucaMarconato <2664412+LucaMarconato@users.noreply.github.com>
Date: Thu, 14 Mar 2024 16:01:22 +0100
Subject: [PATCH 24/27] rechunking raster data after spatial query (#479)

* rechunking raster data after spatial query

* using xarray chunk() instead of dask rechunk()
---
 src/spatialdata/_core/query/spatial_query.py |  5 +++++
 tests/io/test_readwrite.py                   | 21 +++++++++++++++++++-
 2 files changed, 25 insertions(+), 1 deletion(-)

diff --git a/src/spatialdata/_core/query/spatial_query.py b/src/spatialdata/_core/query/spatial_query.py
index 0aa4f3fc..b6874193 100644
--- a/src/spatialdata/_core/query/spatial_query.py
+++ b/src/spatialdata/_core/query/spatial_query.py
@@ -563,6 +563,8 @@ def _(
         if 0 in query_result.shape:
             return None
         assert isinstance(query_result, SpatialImage)
+        # rechunk the data to avoid irregular chunks
+        query_result = query_result.chunk("auto")
     else:
         assert isinstance(image, MultiscaleSpatialImage)
         assert isinstance(query_result, DataTree)
@@ -579,6 +581,9 @@ def _(
             else:
                 d[k] = xdata
         query_result = MultiscaleSpatialImage.from_dict(d)
+        # rechunk the data to avoid irregular chunks
+        for scale in query_result:
+            query_result[scale]["image"] = query_result[scale]["image"].chunk("auto")
     query_result = compute_coordinates(query_result)

     # the bounding box, mapped back to the intrinsic coordinate system is a set of points. The bounding box of these

diff --git a/tests/io/test_readwrite.py b/tests/io/test_readwrite.py
index fcbb87ae..81d43864 100644
--- a/tests/io/test_readwrite.py
+++ b/tests/io/test_readwrite.py
@@ -14,7 +14,7 @@
 from spatial_image import SpatialImage
 from spatialdata import SpatialData, read_zarr
 from spatialdata._io._utils import _are_directories_identical
-from spatialdata.models import TableModel
+from spatialdata.models import Image2DModel, TableModel
 from spatialdata.transformations.operations import (
     get_transformation,
     set_transformation,
@@ -319,3 +319,22 @@ def test_io_table(shapes):
     shapes2.table = adata
     assert shapes2.table is not None
     assert shapes2.table.shape == (5, 10)
+
+
+def test_bug_rechunking_after_queried_raster():
+    # https://github.com/scverse/spatialdata-io/issues/117
+    single_scale = Image2DModel.parse(RNG.random((100, 10, 10)), chunks=(5, 5, 5))
+    multi_scale = Image2DModel.parse(RNG.random((100, 10, 10)), scale_factors=[2, 2], chunks=(5, 5, 5))
+    images = {"single_scale": single_scale, "multi_scale": multi_scale}
+    sdata = SpatialData(images=images)
+    queried = sdata.query.bounding_box(
+        axes=("x", "y"), min_coordinate=[2, 5], max_coordinate=[12, 12], target_coordinate_system="global"
+    )
+    with tempfile.TemporaryDirectory() as tmpdir:
+        f = os.path.join(tmpdir, "data.zarr")
+        queried.write(f)

From a2970d3219bc6254e3a6a64c661bd82218df2366 Mon Sep 17 00:00:00 2001
From: Wouter-Michiel Vierdag
Date: Thu, 14 Mar 2024 16:24:45 +0100
Subject: [PATCH 25/27] Test joins with string indices and instance id (#485)

* test join strings

* fix dtype aggregate

---------

Co-authored-by: Luca Marconato
---
 src/spatialdata/_core/operations/aggregate.py |  9 +++++-
 src/spatialdata/_core/spatialdata.py          |  7 +++++
 src/spatialdata/models/models.py              | 15 ++++-----
 .../operations/test_spatialdata_operations.py | 11 ++-----
 tests/core/query/test_relational_query.py     | 31 +++++++++++++++++++
 tests/models/test_models.py                   | 18 +++++------
 6 files changed, 62 insertions(+), 29 deletions(-)

diff --git a/src/spatialdata/_core/operations/aggregate.py b/src/spatialdata/_core/operations/aggregate.py
index 46f32b50..e1a13b6b 100644
--- a/src/spatialdata/_core/operations/aggregate.py
+++ b/src/spatialdata/_core/operations/aggregate.py
@@ -243,7 +243,14 @@ def _create_sdata_from_table_and_shapes(
 ) -> SpatialData:
     from spatialdata._core._deepcopy import deepcopy as _deepcopy

-    table.obs[instance_key] = table.obs_names.copy()
+    shapes_index_dtype = shapes.index.dtype if isinstance(shapes, GeoDataFrame) else shapes.dtype
+    try:
+        table.obs[instance_key] = table.obs_names.copy().astype(shapes_index_dtype)
+    except ValueError as err:
+        raise TypeError(
+            f"Instance key column dtype in table resulting from aggregation cannot be cast to the dtype of "
+            f"element {shapes_name}.index"
+        ) from err
     table.obs[region_key] = shapes_name
     table = TableModel.parse(table, region=shapes_name, region_key=region_key, instance_key=instance_key)

diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py
index b4644c4c..f636afd8 100644
--- a/src/spatialdata/_core/spatialdata.py
+++ b/src/spatialdata/_core/spatialdata.py
@@ -199,6 +199,13 @@ def validate_table_in_spatialdata(self, table: AnnData) -> None:
                 else:
                     dtype = element.index.dtype
                 if dtype != table.obs[instance_key].dtype:
+                    if dtype == str or table.obs[instance_key].dtype == str:
+                        raise TypeError(
+                            f"Table instance_key column ({instance_key}) has a dtype "
+                            f"({table.obs[instance_key].dtype}) that does not match
the dtype of the indices of " + f"the annotated element ({dtype})." + ) + warnings.warn( ( f"Table instance_key column ({instance_key}) has a dtype " diff --git a/src/spatialdata/models/models.py b/src/spatialdata/models/models.py index b84bb4cd..c7234175 100644 --- a/src/spatialdata/models/models.py +++ b/src/spatialdata/models/models.py @@ -20,7 +20,6 @@ from multiscale_spatial_image.multiscale_spatial_image import MultiscaleSpatialImage from multiscale_spatial_image.to_multiscale.to_multiscale import Methods from pandas import CategoricalDtype -from pandas.errors import IntCastingNaNError from shapely._geometry import GeometryType from shapely.geometry import MultiPolygon, Point, Polygon from shapely.geometry.collection import GeometryCollection @@ -795,6 +794,11 @@ def _validate_table_annotation_metadata(self, data: AnnData) -> None: raise ValueError(f"`{attr[self.REGION_KEY_KEY]}` not found in `adata.obs`.") if attr[self.INSTANCE_KEY] not in data.obs: raise ValueError(f"`{attr[self.INSTANCE_KEY]}` not found in `adata.obs`.") + if (dtype := data.obs[attr[self.INSTANCE_KEY]].dtype) not in [np.int16, np.int32, np.int64, str]: + raise TypeError( + f"Only np.int16, np.int32, np.int64 or string allowed as dtype for " + f"instance_key column in obs. Dtype found to be {dtype}" + ) expected_regions = attr[self.REGION_KEY] if isinstance(attr[self.REGION_KEY], list) else [attr[self.REGION_KEY]] found_regions = data.obs[attr[self.REGION_KEY_KEY]].unique().tolist() if len(set(expected_regions).symmetric_difference(set(found_regions))) > 0: @@ -881,14 +885,6 @@ def parse( adata.obs[region_key] = pd.Categorical(adata.obs[region_key]) if instance_key is None: raise ValueError("`instance_key` must be provided.") - if adata.obs[instance_key].dtype != int: - try: - warnings.warn( - f"Converting `{cls.INSTANCE_KEY}: {instance_key}` to integer dtype.", UserWarning, stacklevel=2 - ) - adata.obs[instance_key] = adata.obs[instance_key].astype(int) - except IntCastingNaNError as exc: - raise ValueError("Values within table.obs[] must be able to be coerced to int dtype.") from exc grouped = adata.obs.groupby(region_key, observed=True) grouped_size = grouped.size() @@ -901,6 +897,7 @@ def parse( attr = {"region": region, "region_key": region_key, "instance_key": instance_key} adata.uns[cls.ATTRS_KEY] = attr + cls().validate(adata) return adata diff --git a/tests/core/operations/test_spatialdata_operations.py b/tests/core/operations/test_spatialdata_operations.py index 0b39ddce..9e4e235e 100644 --- a/tests/core/operations/test_spatialdata_operations.py +++ b/tests/core/operations/test_spatialdata_operations.py @@ -1,7 +1,6 @@ from __future__ import annotations import math -import warnings import numpy as np import pytest @@ -419,10 +418,7 @@ def test_validate_table_in_spatialdata(full_sdata): region, region_key, _ = get_table_keys(table) assert region == "labels2d" - # no warnings - with warnings.catch_warnings(): - warnings.simplefilter("error") - full_sdata.validate_table_in_spatialdata(table) + full_sdata.validate_table_in_spatialdata(table) # dtype mismatch full_sdata.labels["labels2d"] = Labels2DModel.parse(full_sdata.labels["labels2d"].astype("int16")) @@ -437,10 +433,7 @@ def test_validate_table_in_spatialdata(full_sdata): table.obs[region_key] = "points_0" full_sdata.set_table_annotates_spatialelement("table", region="points_0") - # no warnings - with warnings.catch_warnings(): - warnings.simplefilter("error") - full_sdata.validate_table_in_spatialdata(table) + 
full_sdata.validate_table_in_spatialdata(table) # dtype mismatch full_sdata.points["points_0"].index = full_sdata.points["points_0"].index.astype("int16") diff --git a/tests/core/query/test_relational_query.py b/tests/core/query/test_relational_query.py index 8f68c448..db4cef3c 100644 --- a/tests/core/query/test_relational_query.py +++ b/tests/core/query/test_relational_query.py @@ -22,6 +22,37 @@ def test_match_table_to_element(sdata_query_aggregation): # TODO: add tests for labels +def test_join_using_string_instance_id_and_index(sdata_query_aggregation): + sdata_query_aggregation["table"].obs["instance_id"] = [ + f"string_{i}" for i in sdata_query_aggregation["table"].obs["instance_id"] + ] + sdata_query_aggregation["values_circles"].index = pd.Index( + [f"string_{i}" for i in sdata_query_aggregation["values_circles"].index] + ) + sdata_query_aggregation["values_polygons"].index = pd.Index( + [f"string_{i}" for i in sdata_query_aggregation["values_polygons"].index] + ) + + sdata_query_aggregation["values_polygons"] = sdata_query_aggregation["values_polygons"][:5] + sdata_query_aggregation["values_circles"] = sdata_query_aggregation["values_circles"][:5] + + element_dict, table = join_sdata_spatialelement_table( + sdata_query_aggregation, ["values_circles", "values_polygons"], "table", "inner" + ) + # Note that we started with 21 n_obs. + assert table.n_obs == 10 + + element_dict, table = join_sdata_spatialelement_table( + sdata_query_aggregation, ["values_circles", "values_polygons"], "table", "right_exclusive" + ) + assert table.n_obs == 11 + + element_dict, table = join_sdata_spatialelement_table( + sdata_query_aggregation, ["values_circles", "values_polygons"], "table", "right" + ) + assert table.n_obs == 21 + + def test_left_inner_right_exclusive_join(sdata_query_aggregation): element_dict, table = join_sdata_spatialelement_table( sdata_query_aggregation, "values_polygons", "table", "right_exclusive" diff --git a/tests/models/test_models.py b/tests/models/test_models.py index 213bde86..116bdbe4 100644 --- a/tests/models/test_models.py +++ b/tests/models/test_models.py @@ -318,6 +318,14 @@ def test_table_model( region: str | np.ndarray, ) -> None: region_key = "reg" + obs = pd.DataFrame( + RNG.choice(np.arange(0, 100, dtype=float), size=(10, 3), replace=False), columns=["A", "B", "C"] + ) + obs[region_key] = region + adata = AnnData(RNG.normal(size=(10, 2)), obs=obs) + with pytest.raises(TypeError, match="Only np.int16"): + model.parse(adata, region=region, region_key=region_key, instance_key="A") + obs = pd.DataFrame(RNG.choice(np.arange(0, 100), size=(10, 3), replace=False), columns=["A", "B", "C"]) obs[region_key] = region adata = AnnData(RNG.normal(size=(10, 2)), obs=obs) @@ -332,16 +340,6 @@ def test_table_model( assert TableModel.REGION_KEY_KEY in table.uns[TableModel.ATTRS_KEY] assert table.uns[TableModel.ATTRS_KEY][TableModel.REGION_KEY] == region - obs["A"] = obs["A"].astype(str) - adata = AnnData(RNG.normal(size=(10, 2)), obs=obs) - with pytest.warns(UserWarning, match="Converting"): - model.parse(adata, region=region, region_key=region_key, instance_key="A") - - obs["A"] = pd.Series(len([chr(ord("a") + i) for i in range(10)])) - adata = AnnData(RNG.normal(size=(10, 2)), obs=obs) - with pytest.raises(ValueError, match="Values within"): - model.parse(adata, region=region, region_key=region_key, instance_key="A") - @pytest.mark.parametrize("model", [TableModel]) @pytest.mark.parametrize("region", [["sample_1"] * 5 + ["sample_2"] * 5]) def 
test_table_instance_key_values_not_unique(self, model: TableModel, region: str | np.ndarray): From e67ab4794c732080213b1cfe7eb725b4e0e2c64a Mon Sep 17 00:00:00 2001 From: wmv_hpomen Date: Thu, 14 Mar 2024 17:11:04 +0100 Subject: [PATCH 26/27] cleanup tests --- tests/io/test_multi_table.py | 30 ------------------------------ 1 file changed, 30 deletions(-) diff --git a/tests/io/test_multi_table.py b/tests/io/test_multi_table.py index d2382941..fc05d6e0 100644 --- a/tests/io/test_multi_table.py +++ b/tests/io/test_multi_table.py @@ -11,9 +11,6 @@ # notes on paths: https://github.com/orgs/scverse/projects/17/views/1?pane=issue&itemId=44066734 test_shapes = _get_shapes() -# shuffle the indices of the dataframe -# np.random.default_rng(0).shuffle(test_shapes["poly"].index) - class TestMultiTable: def test_set_get_tables_from_spatialdata(self, full_sdata: SpatialData, tmp_path: str): @@ -167,22 +164,6 @@ def test_single_table(self, tmp_path: str, region: str): assert isinstance(sdata["shape_annotate"], AnnData) assert_equal(test_sdata["shape_annotate"], sdata["shape_annotate"]) - # note (to keep in the code): these tests here should silmulate the interactions from teh users; if the syntax - # here we are matching the table to the shapes and viceversa (= subset + reordeing) - # there is already a function to do one of these two join operations which is match_table_to_element() - # is too verbose/complex we need to adjust the internals to make it smoother - # # use case example 1 - # # sorting the shapes to match the order of the table - # alternatively, we can have a helper function (join, and simpler ones "match_table_to_element()" - # "match_element_to_table()", "match_annotations_order(...)", "mathc_reference_eleemnt_order??(...)") - # sdata["visium0"][SpatialData.get_instance_key_column(sdata.table['visium0'])] - # assert ... - # # use case example 2 - # # sorting the table to match the order of the shapes - # sdata.table.obs.set_index(keys=["__instance_id__"]) - # sdata.table.obs[sdata["visium0"]] - # assert ... - def test_paired_elements_tables(self, tmp_path: str): tmpdir = Path(tmp_path) / "tmp.zarr" table = _get_table(region="poly") @@ -219,17 +200,6 @@ def test_single_table_multiple_elements(self, tmp_path: str): test_sdata.write(tmpdir) SpatialData.read(tmpdir) - # # use case example 1 - # # sorting the shapes visium0 to match the order of the table - # sdata["visium0"][sdata.table.obs["__instance_id__"][sdata.table.obs["__spatial_element__"] == "visium0"]] - # assert ... - # # use case example 2 - # # subsetting and sorting the table to match the order of the shapes visium0 - # sub_table = sdata.table[sdata.table.obs["__spatial_element"] == "visium0"] - # sub_table.set_index(keys=["__instance_id__"]) - # sub_table.obs[sdata["visium0"]] - # assert ... 
- def test_multiple_table_without_element(self, tmp_path: str): tmpdir = Path(tmp_path) / "tmp.zarr" table = _get_table(region=None, region_key=None, instance_key=None) From 57e9f613efb5ff019e4a32131f14d69dabb02f89 Mon Sep 17 00:00:00 2001 From: wmv_hpomen Date: Thu, 14 Mar 2024 17:47:04 +0100 Subject: [PATCH 27/27] remove comments --- tests/io/test_multi_table.py | 26 -------------------------- 1 file changed, 26 deletions(-) diff --git a/tests/io/test_multi_table.py b/tests/io/test_multi_table.py index fc05d6e0..50d02677 100644 --- a/tests/io/test_multi_table.py +++ b/tests/io/test_multi_table.py @@ -246,29 +246,3 @@ def test_concatenate_sdata_multitables(): assert merged_sdata.tables["table2"].n_obs == 300 assert all(merged_sdata.tables["table"].obs.region.unique() == ["poly_1", "poly_2", "poly_3"]) assert all(merged_sdata.tables["table2"].obs.region.unique() == ["multipoly_1", "multipoly_2", "multipoly_3"]) - - -# The following use cases needs to be put in the tutorial notebook, let's keep the comment here until we have the -# notebook ready. -# # these use cases could be the preferred one for the users; we need to choose one/two preferred ones (either this, -# either helper function, ...) -# # use cases -# # use case example 1 -# # sorting the shapes to match the order of the table -# sdata["visium0"][sdata.table.obs["__instance_id__"]] -# assert ... -# # use case example 2 -# # sorting the table to match the order of the shapes -# sdata.table.obs.set_index(keys=["__instance_id__"]) -# sdata.table.obs[sdata["visium0"]] -# assert ... -# -# We can postpone the implemntation of this test when the functions "match_table_to_element" etc. are ready. -# def test_partial_match(): -# # the function spatialdata._core.query.relational_query.match_table_to_element(no s) needs to be modified (will be -# # simpler), we need also a function match_element_to_table. Maybe we can have just one function doing both the -# things, -# # called match_table_and_elements test that tables and elements do not need to have the same indices -# pass -# # the test would check that we cna call SpatiaLData() on such combinations of mismatching elements and that the -# # match_table_to_element-like functions return the correct subset of the data
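
The matching use cases sketched in the comments removed above are now exercised through the joins added in [PATCH 25/27] (`tests/core/query/test_relational_query.py`). A minimal sketch of that join-based pattern (the import path is assumed from the test module's usage; `sdata` stands in for a SpatialData object whose elements `values_circles` and `values_polygons` are annotated by the table named "table"):

```python
# hypothetical usage, mirroring test_join_using_string_instance_id_and_index above;
# the import path is an assumption based on the test module tests/core/query/test_relational_query.py
from spatialdata._core.query.relational_query import join_sdata_spatialelement_table

# inner join: subset and reorder the elements and the table so that they match each other
element_dict, matched_table = join_sdata_spatialelement_table(
    sdata, ["values_circles", "values_polygons"], "table", "inner"
)
```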