From e6651fd3444985fd762c8b0fc27a8950e613ac10 Mon Sep 17 00:00:00 2001 From: Arina Danilina Date: Wed, 19 Jul 2023 18:15:32 +0200 Subject: [PATCH 01/58] _celltype_mapping draft --- src/moscot/base/problems/_mixins.py | 42 +++++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/src/moscot/base/problems/_mixins.py b/src/moscot/base/problems/_mixins.py index 8bfb25deb..945ce016f 100644 --- a/src/moscot/base/problems/_mixins.py +++ b/src/moscot/base/problems/_mixins.py @@ -272,6 +272,48 @@ def _cell_transition_online( forward=forward, ) + + def _celltype_mapping( + self: AnalysisMixinProtocol[K, B], + mapping_mode: Literal["sum", "max"], + key: Optional[str], + source: K, + target: K, + source_groups: Str_Dict_t, + target_groups: Str_Dict_t, + forward: bool = False, # return value will be row-stochastic if forward=True, else column-stochastic + aggregation_mode: Literal["annotation", "cell"] = "annotation", + other_key: Optional[str] = None, + other_adata: Optional[str] = None, + batch_size: Optional[int] = None, + normalize: bool = True, + scale_by_marginals: bool = True + ): + if mapping_mode == "sum": + return self._cell_transition( + key=key, + source=source, + target=target, + source_groups=source_groups, + target_groups=target_groups, + forward=forward, + aggregation_mode=aggregation_mode, + other_key=other_key, + other_adata=other_adata, + batch_size=batch_size, + normalize=normalize + ) + if mapping_mode == "max": + source_annotation_key, source_annotations, source_annotations_ordered = _validate_args_cell_transition( + self.adata, source_groups + ) + target_annotation_key, target_annotations, target_annotations_ordered = _validate_args_cell_transition( + self.adata if other_adata is None else other_adata, target_groups + ) + dummy = pd.get_dummies(source_annotations) + out= self.pull(dummy, scale_by_marginals=scale_by_marginals) + return pd.Categorical([dummy.columns[i] for i in np.array(out.argmax(1))]) + def _sample_from_tmap( self: AnalysisMixinProtocol[K, B], source: K, From 68c1a011f527c9bc390df85bfdf39552c2e5772d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 19 Jul 2023 16:18:42 +0000 Subject: [PATCH 02/58] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/moscot/base/problems/_mixins.py | 57 ++++++++++++++--------------- 1 file changed, 28 insertions(+), 29 deletions(-) diff --git a/src/moscot/base/problems/_mixins.py b/src/moscot/base/problems/_mixins.py index 945ce016f..6757d4bc1 100644 --- a/src/moscot/base/problems/_mixins.py +++ b/src/moscot/base/problems/_mixins.py @@ -272,46 +272,45 @@ def _cell_transition_online( forward=forward, ) - def _celltype_mapping( - self: AnalysisMixinProtocol[K, B], - mapping_mode: Literal["sum", "max"], - key: Optional[str], - source: K, - target: K, - source_groups: Str_Dict_t, - target_groups: Str_Dict_t, - forward: bool = False, # return value will be row-stochastic if forward=True, else column-stochastic - aggregation_mode: Literal["annotation", "cell"] = "annotation", - other_key: Optional[str] = None, - other_adata: Optional[str] = None, - batch_size: Optional[int] = None, - normalize: bool = True, - scale_by_marginals: bool = True + self: AnalysisMixinProtocol[K, B], + mapping_mode: Literal["sum", "max"], + key: Optional[str], + source: K, + target: K, + source_groups: Str_Dict_t, + target_groups: Str_Dict_t, + forward: bool = False, # return value will be row-stochastic if forward=True, else column-stochastic + aggregation_mode: Literal["annotation", "cell"] = "annotation", + other_key: Optional[str] = None, + other_adata: Optional[str] = None, + batch_size: Optional[int] = None, + normalize: bool = True, + scale_by_marginals: bool = True, ): if mapping_mode == "sum": return self._cell_transition( - key=key, - source=source, - target=target, - source_groups=source_groups, - target_groups=target_groups, - forward=forward, - aggregation_mode=aggregation_mode, - other_key=other_key, - other_adata=other_adata, - batch_size=batch_size, - normalize=normalize - ) + key=key, + source=source, + target=target, + source_groups=source_groups, + target_groups=target_groups, + forward=forward, + aggregation_mode=aggregation_mode, + other_key=other_key, + other_adata=other_adata, + batch_size=batch_size, + normalize=normalize, + ) if mapping_mode == "max": source_annotation_key, source_annotations, source_annotations_ordered = _validate_args_cell_transition( - self.adata, source_groups + self.adata, source_groups ) target_annotation_key, target_annotations, target_annotations_ordered = _validate_args_cell_transition( self.adata if other_adata is None else other_adata, target_groups ) dummy = pd.get_dummies(source_annotations) - out= self.pull(dummy, scale_by_marginals=scale_by_marginals) + out = self.pull(dummy, scale_by_marginals=scale_by_marginals) return pd.Categorical([dummy.columns[i] for i in np.array(out.argmax(1))]) def _sample_from_tmap( From c1acb13e6eefd3824c9f67149826a8bd21c0b746 Mon Sep 17 00:00:00 2001 From: Arina Danilina Date: Thu, 20 Jul 2023 11:20:36 +0200 Subject: [PATCH 03/58] exposing celltype_mapping in SpatialMapping and SpatialAlignment --- src/moscot/base/problems/_mixins.py | 2 +- src/moscot/problems/space/_mixins.py | 74 ++++++++++++++++++++++++++++ 2 files changed, 75 insertions(+), 1 deletion(-) diff --git a/src/moscot/base/problems/_mixins.py b/src/moscot/base/problems/_mixins.py index 945ce016f..768510c10 100644 --- a/src/moscot/base/problems/_mixins.py +++ b/src/moscot/base/problems/_mixins.py @@ -302,7 +302,7 @@ def _celltype_mapping( other_adata=other_adata, batch_size=batch_size, normalize=normalize - ) + ) if mapping_mode == "max": source_annotation_key, source_annotations, source_annotations_ordered = _validate_args_cell_transition( self.adata, source_groups diff --git a/src/moscot/problems/space/_mixins.py b/src/moscot/problems/space/_mixins.py index 429f5d0d0..87bd2887b 100644 --- a/src/moscot/problems/space/_mixins.py +++ b/src/moscot/problems/space/_mixins.py @@ -63,6 +63,13 @@ def _cell_transition( ) -> pd.DataFrame: ... + def _celltype_mapping( + self: AnalysisMixinProtocol[K, B], + *args: Any, + **kwargs: Any, + ) -> pd.DataFrame: + ... + class SpatialMappingMixinProtocol(AnalysisMixinProtocol[K, B]): """Protocol class.""" @@ -81,6 +88,9 @@ def _filter_vars( def _cell_transition(self: AnalysisMixinProtocol[K, B], *args: Any, **kwargs: Any) -> pd.DataFrame: ... + + def _celltype_mapping(self: AnalysisMixinProtocol[K, B], *args: Any, **kwargs: Any) -> pd.DataFrame: + ... class SpatialAlignmentMixin(AnalysisMixin[K, B]): @@ -273,6 +283,38 @@ def cell_transition( # type: ignore[misc] key_added=key_added, ) + def celltype_mapping( + self: SpatialAlignmentMixinProtocol[K, B], + mapping_mode: Literal["sum", "max"], + key: Optional[str], + source: K, + target: K, + source_groups: Str_Dict_t, + target_groups: Str_Dict_t, + forward: bool = False, # return value will be row-stochastic if forward=True, else column-stochastic + aggregation_mode: Literal["annotation", "cell"] = "annotation", + other_key: Optional[str] = None, + other_adata: Optional[str] = None, + batch_size: Optional[int] = None, + normalize: bool = True, + scale_by_marginals: bool = True + ) -> pd.DataFrame: + return self._celltype_mapping( + mapping_mode=mapping_mode, + key=key, + source=source, + target=target, + source_groups=source_groups, + target_groups=target_groups, + forward=forward, + aggregation_mode=aggregation_mode, + other_key=other_key, + other_adata=other_adata, + batch_size=batch_size, + normalize=normalize, + scale_by_marginals=scale_by_marginals + ) + @property def spatial_key(self) -> Optional[str]: """Spatial key in :attr:`~anndata.AnnData.obsm`.""" @@ -562,6 +604,38 @@ def cell_transition( # type: ignore[misc] key_added=key_added, ) + def celltype_mapping( + self: SpatialMappingMixinProtocol[K, B], + mapping_mode: Literal["sum", "max"], + key: Optional[str], + source: K, + target: K, + source_groups: Str_Dict_t, + target_groups: Str_Dict_t, + forward: bool = False, # return value will be row-stochastic if forward=True, else column-stochastic + aggregation_mode: Literal["annotation", "cell"] = "annotation", + other_key: Optional[str] = None, + other_adata: Optional[str] = None, + batch_size: Optional[int] = None, + normalize: bool = True, + scale_by_marginals: bool = True + ) -> pd.DataFrame: + return self._celltype_mapping( + mapping_mode=mapping_mode, + key=key, + source=source, + target=target, + source_groups=source_groups, + target_groups=target_groups, + forward=forward, + aggregation_mode=aggregation_mode, + other_key=other_key, + other_adata=other_adata, + batch_size=batch_size, + normalize=normalize, + scale_by_marginals=scale_by_marginals + ) + @property def batch_key(self) -> Optional[str]: """Batch key in :attr:`~anndata.AnnData.obs`.""" From 4324afaa7ab4069e873e192edb75d77be90961e3 Mon Sep 17 00:00:00 2001 From: Arina Danilina Date: Thu, 20 Jul 2023 11:22:06 +0200 Subject: [PATCH 04/58] fix ruff ? --- src/moscot/base/problems/_mixins.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/moscot/base/problems/_mixins.py b/src/moscot/base/problems/_mixins.py index 768510c10..6e4c7b96a 100644 --- a/src/moscot/base/problems/_mixins.py +++ b/src/moscot/base/problems/_mixins.py @@ -290,7 +290,7 @@ def _celltype_mapping( scale_by_marginals: bool = True ): if mapping_mode == "sum": - return self._cell_transition( + return self._cell_transition_online( key=key, source=source, target=target, From abcb8fc89aea1e604adb44562f69ff68748a63fe Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 20 Jul 2023 09:23:54 +0000 Subject: [PATCH 05/58] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/moscot/base/problems/_mixins.py | 24 +++++------ src/moscot/problems/space/_mixins.py | 62 ++++++++++++++-------------- 2 files changed, 43 insertions(+), 43 deletions(-) diff --git a/src/moscot/base/problems/_mixins.py b/src/moscot/base/problems/_mixins.py index b20f51e90..84193e5ab 100644 --- a/src/moscot/base/problems/_mixins.py +++ b/src/moscot/base/problems/_mixins.py @@ -290,18 +290,18 @@ def _celltype_mapping( ): if mapping_mode == "sum": return self._cell_transition_online( - key=key, - source=source, - target=target, - source_groups=source_groups, - target_groups=target_groups, - forward=forward, - aggregation_mode=aggregation_mode, - other_key=other_key, - other_adata=other_adata, - batch_size=batch_size, - normalize=normalize - ) + key=key, + source=source, + target=target, + source_groups=source_groups, + target_groups=target_groups, + forward=forward, + aggregation_mode=aggregation_mode, + other_key=other_key, + other_adata=other_adata, + batch_size=batch_size, + normalize=normalize, + ) if mapping_mode == "max": source_annotation_key, source_annotations, source_annotations_ordered = _validate_args_cell_transition( self.adata, source_groups diff --git a/src/moscot/problems/space/_mixins.py b/src/moscot/problems/space/_mixins.py index 87bd2887b..f1f4ddd84 100644 --- a/src/moscot/problems/space/_mixins.py +++ b/src/moscot/problems/space/_mixins.py @@ -88,7 +88,7 @@ def _filter_vars( def _cell_transition(self: AnalysisMixinProtocol[K, B], *args: Any, **kwargs: Any) -> pd.DataFrame: ... - + def _celltype_mapping(self: AnalysisMixinProtocol[K, B], *args: Any, **kwargs: Any) -> pd.DataFrame: ... @@ -297,23 +297,23 @@ def celltype_mapping( other_adata: Optional[str] = None, batch_size: Optional[int] = None, normalize: bool = True, - scale_by_marginals: bool = True + scale_by_marginals: bool = True, ) -> pd.DataFrame: return self._celltype_mapping( - mapping_mode=mapping_mode, - key=key, - source=source, - target=target, - source_groups=source_groups, - target_groups=target_groups, - forward=forward, - aggregation_mode=aggregation_mode, - other_key=other_key, - other_adata=other_adata, - batch_size=batch_size, - normalize=normalize, - scale_by_marginals=scale_by_marginals - ) + mapping_mode=mapping_mode, + key=key, + source=source, + target=target, + source_groups=source_groups, + target_groups=target_groups, + forward=forward, + aggregation_mode=aggregation_mode, + other_key=other_key, + other_adata=other_adata, + batch_size=batch_size, + normalize=normalize, + scale_by_marginals=scale_by_marginals, + ) @property def spatial_key(self) -> Optional[str]: @@ -618,23 +618,23 @@ def celltype_mapping( other_adata: Optional[str] = None, batch_size: Optional[int] = None, normalize: bool = True, - scale_by_marginals: bool = True + scale_by_marginals: bool = True, ) -> pd.DataFrame: return self._celltype_mapping( - mapping_mode=mapping_mode, - key=key, - source=source, - target=target, - source_groups=source_groups, - target_groups=target_groups, - forward=forward, - aggregation_mode=aggregation_mode, - other_key=other_key, - other_adata=other_adata, - batch_size=batch_size, - normalize=normalize, - scale_by_marginals=scale_by_marginals - ) + mapping_mode=mapping_mode, + key=key, + source=source, + target=target, + source_groups=source_groups, + target_groups=target_groups, + forward=forward, + aggregation_mode=aggregation_mode, + other_key=other_key, + other_adata=other_adata, + batch_size=batch_size, + normalize=normalize, + scale_by_marginals=scale_by_marginals, + ) @property def batch_key(self) -> Optional[str]: From c89236f00f2b7be2409446f68d0fb6ad3fa62c29 Mon Sep 17 00:00:00 2001 From: Arina Danilina Date: Thu, 20 Jul 2023 17:11:06 +0200 Subject: [PATCH 06/58] fixes for mypy and ruff --- src/moscot/base/problems/_mixins.py | 4 +++- src/moscot/problems/space/_mixins.py | 4 ++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/moscot/base/problems/_mixins.py b/src/moscot/base/problems/_mixins.py index 84193e5ab..bc250e269 100644 --- a/src/moscot/base/problems/_mixins.py +++ b/src/moscot/base/problems/_mixins.py @@ -302,7 +302,7 @@ def _celltype_mapping( batch_size=batch_size, normalize=normalize, ) - if mapping_mode == "max": + elif mapping_mode == "max": source_annotation_key, source_annotations, source_annotations_ordered = _validate_args_cell_transition( self.adata, source_groups ) @@ -312,6 +312,8 @@ def _celltype_mapping( dummy = pd.get_dummies(source_annotations) out = self.pull(dummy, scale_by_marginals=scale_by_marginals) return pd.Categorical([dummy.columns[i] for i in np.array(out.argmax(1))]) + else: + raise NotImplementedError(f"Mapping mode `{mapping_mode!r}` is not yet implemented.") def _sample_from_tmap( self: AnalysisMixinProtocol[K, B], diff --git a/src/moscot/problems/space/_mixins.py b/src/moscot/problems/space/_mixins.py index f1f4ddd84..de6a3b76d 100644 --- a/src/moscot/problems/space/_mixins.py +++ b/src/moscot/problems/space/_mixins.py @@ -284,7 +284,7 @@ def cell_transition( # type: ignore[misc] ) def celltype_mapping( - self: SpatialAlignmentMixinProtocol[K, B], + self: AnalysisMixinProtocol[K, B], mapping_mode: Literal["sum", "max"], key: Optional[str], source: K, @@ -605,7 +605,7 @@ def cell_transition( # type: ignore[misc] ) def celltype_mapping( - self: SpatialMappingMixinProtocol[K, B], + self: AnalysisMixinProtocol[K, B], mapping_mode: Literal["sum", "max"], key: Optional[str], source: K, From 8a9138d22e03193c8917221fbcb930b75b5b4153 Mon Sep 17 00:00:00 2001 From: Arina Danilina Date: Thu, 20 Jul 2023 18:10:17 +0200 Subject: [PATCH 07/58] renamin, adding function to protocol --- src/moscot/base/problems/_mixins.py | 36 +++++++++++++++++++++++++--- src/moscot/problems/space/_mixins.py | 12 +++++----- 2 files changed, 39 insertions(+), 9 deletions(-) diff --git a/src/moscot/base/problems/_mixins.py b/src/moscot/base/problems/_mixins.py index bc250e269..515c33131 100644 --- a/src/moscot/base/problems/_mixins.py +++ b/src/moscot/base/problems/_mixins.py @@ -82,6 +82,18 @@ def pull(self, *args: Any, **kwargs: Any) -> Optional[ApplyOutput_t[K]]: """Pull distribution.""" ... + def _cell_transition( + self: "AnalysisMixinProtocol[K, B]", + source: K, + target: K, + source_groups: Str_Dict_t, + target_groups: Str_Dict_t, + aggregation_mode: Literal["annotation", "cell"] = "annotation", + key_added: Optional[str] = _constants.CELL_TRANSITION, + **kwargs: Any, + ) -> pd.DataFrame: + ... + def _cell_transition_online( self: "AnalysisMixinProtocol[K, B]", key: Optional[str], @@ -98,6 +110,24 @@ def _cell_transition_online( ) -> pd.DataFrame: ... + def _annotation_mapping( + self: "AnalysisMixinProtocol[K, B]", + mapping_mode: Literal["sum", "max"], + key: Optional[str], + source: K, + target: K, + source_groups: Str_Dict_t, + target_groups: Str_Dict_t, + forward: bool = False, # return value will be row-stochastic if forward=True, else column-stochastic + aggregation_mode: Literal["annotation", "cell"] = "annotation", + other_key: Optional[str] = None, + other_adata: Optional[str] = None, + batch_size: Optional[int] = None, + normalize: bool = True, + scale_by_marginals: bool = True, + ) -> pd.DataFrame: + ... + class AnalysisMixin(Generic[K, B]): """Base Analysis Mixin.""" @@ -272,7 +302,7 @@ def _cell_transition_online( forward=forward, ) - def _celltype_mapping( + def _annotation_mapping( self: AnalysisMixinProtocol[K, B], mapping_mode: Literal["sum", "max"], key: Optional[str], @@ -287,7 +317,7 @@ def _celltype_mapping( batch_size: Optional[int] = None, normalize: bool = True, scale_by_marginals: bool = True, - ): + ) -> pd.DataFrame: if mapping_mode == "sum": return self._cell_transition_online( key=key, @@ -467,7 +497,7 @@ def _cell_aggregation_transition( if batch_size is None: batch_size = len(df_2) for batch in range(0, len(df_2), batch_size): - result = func( # TODO(@MUCDK) check how to make compatible with all policies + result = func( # TODO(@MUCDK) check how to make compatiAnalysisMixinProtocolcelltyble with all policies source=source, target=target, data=None, diff --git a/src/moscot/problems/space/_mixins.py b/src/moscot/problems/space/_mixins.py index de6a3b76d..5e340640f 100644 --- a/src/moscot/problems/space/_mixins.py +++ b/src/moscot/problems/space/_mixins.py @@ -63,7 +63,7 @@ def _cell_transition( ) -> pd.DataFrame: ... - def _celltype_mapping( + def _annotation_mapping( self: AnalysisMixinProtocol[K, B], *args: Any, **kwargs: Any, @@ -89,7 +89,7 @@ def _filter_vars( def _cell_transition(self: AnalysisMixinProtocol[K, B], *args: Any, **kwargs: Any) -> pd.DataFrame: ... - def _celltype_mapping(self: AnalysisMixinProtocol[K, B], *args: Any, **kwargs: Any) -> pd.DataFrame: + def _annotation_mapping(self: AnalysisMixinProtocol[K, B], *args: Any, **kwargs: Any) -> pd.DataFrame: ... @@ -283,7 +283,7 @@ def cell_transition( # type: ignore[misc] key_added=key_added, ) - def celltype_mapping( + def annotation_mapping( self: AnalysisMixinProtocol[K, B], mapping_mode: Literal["sum", "max"], key: Optional[str], @@ -299,7 +299,7 @@ def celltype_mapping( normalize: bool = True, scale_by_marginals: bool = True, ) -> pd.DataFrame: - return self._celltype_mapping( + return self._annotation_mapping( mapping_mode=mapping_mode, key=key, source=source, @@ -604,7 +604,7 @@ def cell_transition( # type: ignore[misc] key_added=key_added, ) - def celltype_mapping( + def annotation_mapping( self: AnalysisMixinProtocol[K, B], mapping_mode: Literal["sum", "max"], key: Optional[str], @@ -620,7 +620,7 @@ def celltype_mapping( normalize: bool = True, scale_by_marginals: bool = True, ) -> pd.DataFrame: - return self._celltype_mapping( + return self._annotation_mapping( mapping_mode=mapping_mode, key=key, source=source, From 453931245757e5358ca189a4385e760064093620 Mon Sep 17 00:00:00 2001 From: Arina Danilina Date: Fri, 21 Jul 2023 15:29:26 +0200 Subject: [PATCH 08/58] ruff fix? --- src/moscot/base/problems/_mixins.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/moscot/base/problems/_mixins.py b/src/moscot/base/problems/_mixins.py index 515c33131..f21ff8ae3 100644 --- a/src/moscot/base/problems/_mixins.py +++ b/src/moscot/base/problems/_mixins.py @@ -331,8 +331,8 @@ def _annotation_mapping( other_adata=other_adata, batch_size=batch_size, normalize=normalize, - ) - elif mapping_mode == "max": + ) + if mapping_mode == "max": source_annotation_key, source_annotations, source_annotations_ordered = _validate_args_cell_transition( self.adata, source_groups ) From d8f7714df75a647bf61bc7123907454101d0be63 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 21 Jul 2023 13:30:07 +0000 Subject: [PATCH 09/58] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/moscot/base/problems/_mixins.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/moscot/base/problems/_mixins.py b/src/moscot/base/problems/_mixins.py index f21ff8ae3..991870aaa 100644 --- a/src/moscot/base/problems/_mixins.py +++ b/src/moscot/base/problems/_mixins.py @@ -331,7 +331,7 @@ def _annotation_mapping( other_adata=other_adata, batch_size=batch_size, normalize=normalize, - ) + ) if mapping_mode == "max": source_annotation_key, source_annotations, source_annotations_ordered = _validate_args_cell_transition( self.adata, source_groups From 4944db6ae9730865d9b98b66501567e80c841e52 Mon Sep 17 00:00:00 2001 From: Arina Danilina Date: Tue, 1 Aug 2023 00:43:48 +0200 Subject: [PATCH 10/58] adding cell_transition_kwargs --- src/moscot/base/problems/_mixins.py | 37 ++++++---------------------- src/moscot/problems/space/_mixins.py | 36 +++------------------------ 2 files changed, 12 insertions(+), 61 deletions(-) diff --git a/src/moscot/base/problems/_mixins.py b/src/moscot/base/problems/_mixins.py index 991870aaa..c4574d580 100644 --- a/src/moscot/base/problems/_mixins.py +++ b/src/moscot/base/problems/_mixins.py @@ -1,3 +1,4 @@ +import types from typing import ( TYPE_CHECKING, Any, @@ -11,6 +12,7 @@ Sequence, Tuple, Union, + Mapping, ) import numpy as np @@ -113,18 +115,11 @@ def _cell_transition_online( def _annotation_mapping( self: "AnalysisMixinProtocol[K, B]", mapping_mode: Literal["sum", "max"], - key: Optional[str], - source: K, - target: K, source_groups: Str_Dict_t, target_groups: Str_Dict_t, - forward: bool = False, # return value will be row-stochastic if forward=True, else column-stochastic - aggregation_mode: Literal["annotation", "cell"] = "annotation", - other_key: Optional[str] = None, other_adata: Optional[str] = None, - batch_size: Optional[int] = None, - normalize: bool = True, scale_by_marginals: bool = True, + cell_transition_kwargs: Mapping[str, Any] = types.MappingProxyType({}), ) -> pd.DataFrame: ... @@ -305,34 +300,18 @@ def _cell_transition_online( def _annotation_mapping( self: AnalysisMixinProtocol[K, B], mapping_mode: Literal["sum", "max"], - key: Optional[str], - source: K, - target: K, source_groups: Str_Dict_t, target_groups: Str_Dict_t, - forward: bool = False, # return value will be row-stochastic if forward=True, else column-stochastic - aggregation_mode: Literal["annotation", "cell"] = "annotation", - other_key: Optional[str] = None, other_adata: Optional[str] = None, - batch_size: Optional[int] = None, - normalize: bool = True, scale_by_marginals: bool = True, + cell_transition_kwargs: Mapping[str, Any] = types.MappingProxyType({}), ) -> pd.DataFrame: if mapping_mode == "sum": - return self._cell_transition_online( - key=key, - source=source, - target=target, - source_groups=source_groups, - target_groups=target_groups, - forward=forward, - aggregation_mode=aggregation_mode, - other_key=other_key, - other_adata=other_adata, - batch_size=batch_size, - normalize=normalize, + return self._cell_transition( + **cell_transition_kwargs ) if mapping_mode == "max": + assert (not cell_transition_kwargs), "cell_transition_kwargs is not empty, although cell_transition is not used." source_annotation_key, source_annotations, source_annotations_ordered = _validate_args_cell_transition( self.adata, source_groups ) @@ -340,7 +319,7 @@ def _annotation_mapping( self.adata if other_adata is None else other_adata, target_groups ) dummy = pd.get_dummies(source_annotations) - out = self.pull(dummy, scale_by_marginals=scale_by_marginals) + out:ArrayLike = self.pull(dummy, scale_by_marginals=scale_by_marginals) # or assert out is not None return pd.Categorical([dummy.columns[i] for i in np.array(out.argmax(1))]) else: raise NotImplementedError(f"Mapping mode `{mapping_mode!r}` is not yet implemented.") diff --git a/src/moscot/problems/space/_mixins.py b/src/moscot/problems/space/_mixins.py index 5e340640f..75aa335d2 100644 --- a/src/moscot/problems/space/_mixins.py +++ b/src/moscot/problems/space/_mixins.py @@ -286,33 +286,19 @@ def cell_transition( # type: ignore[misc] def annotation_mapping( self: AnalysisMixinProtocol[K, B], mapping_mode: Literal["sum", "max"], - key: Optional[str], - source: K, - target: K, source_groups: Str_Dict_t, target_groups: Str_Dict_t, - forward: bool = False, # return value will be row-stochastic if forward=True, else column-stochastic - aggregation_mode: Literal["annotation", "cell"] = "annotation", - other_key: Optional[str] = None, other_adata: Optional[str] = None, - batch_size: Optional[int] = None, - normalize: bool = True, scale_by_marginals: bool = True, + cell_transition_kwargs: Mapping[str, Any] = types.MappingProxyType({}), ) -> pd.DataFrame: return self._annotation_mapping( mapping_mode=mapping_mode, - key=key, - source=source, - target=target, source_groups=source_groups, target_groups=target_groups, - forward=forward, - aggregation_mode=aggregation_mode, - other_key=other_key, other_adata=other_adata, - batch_size=batch_size, - normalize=normalize, scale_by_marginals=scale_by_marginals, + cell_transition_kwargs=cell_transition_kwargs, ) @property @@ -607,33 +593,19 @@ def cell_transition( # type: ignore[misc] def annotation_mapping( self: AnalysisMixinProtocol[K, B], mapping_mode: Literal["sum", "max"], - key: Optional[str], - source: K, - target: K, source_groups: Str_Dict_t, target_groups: Str_Dict_t, - forward: bool = False, # return value will be row-stochastic if forward=True, else column-stochastic - aggregation_mode: Literal["annotation", "cell"] = "annotation", - other_key: Optional[str] = None, other_adata: Optional[str] = None, - batch_size: Optional[int] = None, - normalize: bool = True, scale_by_marginals: bool = True, + cell_transition_kwargs: Mapping[str, Any] = types.MappingProxyType({}), ) -> pd.DataFrame: return self._annotation_mapping( mapping_mode=mapping_mode, - key=key, - source=source, - target=target, source_groups=source_groups, target_groups=target_groups, - forward=forward, - aggregation_mode=aggregation_mode, - other_key=other_key, other_adata=other_adata, - batch_size=batch_size, - normalize=normalize, scale_by_marginals=scale_by_marginals, + cell_transition_kwargs=cell_transition_kwargs, ) @property From c2a2466f8c11a03e8021238a0257090e2706e2de Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 31 Jul 2023 22:44:40 +0000 Subject: [PATCH 11/58] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/moscot/base/problems/_mixins.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/moscot/base/problems/_mixins.py b/src/moscot/base/problems/_mixins.py index c4574d580..bffbfe487 100644 --- a/src/moscot/base/problems/_mixins.py +++ b/src/moscot/base/problems/_mixins.py @@ -7,12 +7,12 @@ Iterable, List, Literal, + Mapping, Optional, Protocol, Sequence, Tuple, Union, - Mapping, ) import numpy as np @@ -307,11 +307,11 @@ def _annotation_mapping( cell_transition_kwargs: Mapping[str, Any] = types.MappingProxyType({}), ) -> pd.DataFrame: if mapping_mode == "sum": - return self._cell_transition( - **cell_transition_kwargs - ) + return self._cell_transition(**cell_transition_kwargs) if mapping_mode == "max": - assert (not cell_transition_kwargs), "cell_transition_kwargs is not empty, although cell_transition is not used." + assert ( + not cell_transition_kwargs + ), "cell_transition_kwargs is not empty, although cell_transition is not used." source_annotation_key, source_annotations, source_annotations_ordered = _validate_args_cell_transition( self.adata, source_groups ) @@ -319,7 +319,7 @@ def _annotation_mapping( self.adata if other_adata is None else other_adata, target_groups ) dummy = pd.get_dummies(source_annotations) - out:ArrayLike = self.pull(dummy, scale_by_marginals=scale_by_marginals) # or assert out is not None + out: ArrayLike = self.pull(dummy, scale_by_marginals=scale_by_marginals) # or assert out is not None return pd.Categorical([dummy.columns[i] for i in np.array(out.argmax(1))]) else: raise NotImplementedError(f"Mapping mode `{mapping_mode!r}` is not yet implemented.") From d10f82f4fb00e4cce8d4664e94a60f61f3705bf1 Mon Sep 17 00:00:00 2001 From: Arina Danilina Date: Tue, 1 Aug 2023 00:48:14 +0200 Subject: [PATCH 12/58] ruff and mypy fix ? --- src/moscot/base/problems/_mixins.py | 3 +-- src/moscot/problems/space/_mixins.py | 1 + 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/moscot/base/problems/_mixins.py b/src/moscot/base/problems/_mixins.py index bffbfe487..2494783d9 100644 --- a/src/moscot/base/problems/_mixins.py +++ b/src/moscot/base/problems/_mixins.py @@ -321,8 +321,7 @@ def _annotation_mapping( dummy = pd.get_dummies(source_annotations) out: ArrayLike = self.pull(dummy, scale_by_marginals=scale_by_marginals) # or assert out is not None return pd.Categorical([dummy.columns[i] for i in np.array(out.argmax(1))]) - else: - raise NotImplementedError(f"Mapping mode `{mapping_mode!r}` is not yet implemented.") + raise NotImplementedError(f"Mapping mode `{mapping_mode!r}` is not yet implemented.") def _sample_from_tmap( self: AnalysisMixinProtocol[K, B], diff --git a/src/moscot/problems/space/_mixins.py b/src/moscot/problems/space/_mixins.py index 75aa335d2..f8c8a6530 100644 --- a/src/moscot/problems/space/_mixins.py +++ b/src/moscot/problems/space/_mixins.py @@ -1,4 +1,5 @@ import itertools +import types from typing import ( TYPE_CHECKING, Any, From 17d48b0774eb40088028cbe39a9f7e4d6f257da1 Mon Sep 17 00:00:00 2001 From: Arina Danilina Date: Fri, 15 Sep 2023 18:29:09 +0200 Subject: [PATCH 13/58] anno_mapping changes --- src/moscot/base/problems/_mixins.py | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/src/moscot/base/problems/_mixins.py b/src/moscot/base/problems/_mixins.py index 2494783d9..a4b77e1e1 100644 --- a/src/moscot/base/problems/_mixins.py +++ b/src/moscot/base/problems/_mixins.py @@ -300,26 +300,31 @@ def _cell_transition_online( def _annotation_mapping( self: AnalysisMixinProtocol[K, B], mapping_mode: Literal["sum", "max"], - source_groups: Str_Dict_t, - target_groups: Str_Dict_t, + #source_groups: Str_Dict_t, + #target_groups: Str_Dict_t, + label: str, + forward: bool, other_adata: Optional[str] = None, scale_by_marginals: bool = True, cell_transition_kwargs: Mapping[str, Any] = types.MappingProxyType({}), ) -> pd.DataFrame: if mapping_mode == "sum": - return self._cell_transition(**cell_transition_kwargs) + return self._cell_transition(**cell_transition_kwargs) #aggregation mode should set to cell if mapping_mode == "max": assert ( not cell_transition_kwargs ), "cell_transition_kwargs is not empty, although cell_transition is not used." - source_annotation_key, source_annotations, source_annotations_ordered = _validate_args_cell_transition( - self.adata, source_groups - ) - target_annotation_key, target_annotations, target_annotations_ordered = _validate_args_cell_transition( - self.adata if other_adata is None else other_adata, target_groups - ) + if forward: + source_annotation_key, source_annotations, source_annotations_ordered = _validate_args_cell_transition( + self.adata, label + ) + elif not forward: + target_annotation_key, target_annotations, target_annotations_ordered = _validate_args_cell_transition( + self.adata if other_adata is None else other_adata, label + ) dummy = pd.get_dummies(source_annotations) out: ArrayLike = self.pull(dummy, scale_by_marginals=scale_by_marginals) # or assert out is not None + # return pd.Categorical([dummy.columns[i] for i in np.array(out.argmax(1))]) raise NotImplementedError(f"Mapping mode `{mapping_mode!r}` is not yet implemented.") From 713e8e0fb30581ee249ef0664012213db952942e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 19 Oct 2023 15:02:53 +0000 Subject: [PATCH 14/58] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/moscot/base/problems/_mixins.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/moscot/base/problems/_mixins.py b/src/moscot/base/problems/_mixins.py index a4b77e1e1..824b46903 100644 --- a/src/moscot/base/problems/_mixins.py +++ b/src/moscot/base/problems/_mixins.py @@ -300,8 +300,8 @@ def _cell_transition_online( def _annotation_mapping( self: AnalysisMixinProtocol[K, B], mapping_mode: Literal["sum", "max"], - #source_groups: Str_Dict_t, - #target_groups: Str_Dict_t, + # source_groups: Str_Dict_t, + # target_groups: Str_Dict_t, label: str, forward: bool, other_adata: Optional[str] = None, @@ -309,7 +309,7 @@ def _annotation_mapping( cell_transition_kwargs: Mapping[str, Any] = types.MappingProxyType({}), ) -> pd.DataFrame: if mapping_mode == "sum": - return self._cell_transition(**cell_transition_kwargs) #aggregation mode should set to cell + return self._cell_transition(**cell_transition_kwargs) # aggregation mode should set to cell if mapping_mode == "max": assert ( not cell_transition_kwargs @@ -324,7 +324,7 @@ def _annotation_mapping( ) dummy = pd.get_dummies(source_annotations) out: ArrayLike = self.pull(dummy, scale_by_marginals=scale_by_marginals) # or assert out is not None - # + # return pd.Categorical([dummy.columns[i] for i in np.array(out.argmax(1))]) raise NotImplementedError(f"Mapping mode `{mapping_mode!r}` is not yet implemented.") From b2a97c21dea8e6d1d04ece0f57f048e5ceece79d Mon Sep 17 00:00:00 2001 From: Arina Danilina Date: Thu, 19 Oct 2023 23:33:26 +0200 Subject: [PATCH 15/58] removed source and target groups --- src/moscot/base/problems/_mixins.py | 16 +++++++--------- src/moscot/problems/space/_mixins.py | 16 ++++++++-------- 2 files changed, 15 insertions(+), 17 deletions(-) diff --git a/src/moscot/base/problems/_mixins.py b/src/moscot/base/problems/_mixins.py index a4b77e1e1..28bf880d1 100644 --- a/src/moscot/base/problems/_mixins.py +++ b/src/moscot/base/problems/_mixins.py @@ -115,8 +115,8 @@ def _cell_transition_online( def _annotation_mapping( self: "AnalysisMixinProtocol[K, B]", mapping_mode: Literal["sum", "max"], - source_groups: Str_Dict_t, - target_groups: Str_Dict_t, + annotation_label: str, + forward: bool, other_adata: Optional[str] = None, scale_by_marginals: bool = True, cell_transition_kwargs: Mapping[str, Any] = types.MappingProxyType({}), @@ -300,9 +300,7 @@ def _cell_transition_online( def _annotation_mapping( self: AnalysisMixinProtocol[K, B], mapping_mode: Literal["sum", "max"], - #source_groups: Str_Dict_t, - #target_groups: Str_Dict_t, - label: str, + annotation_label: str, forward: bool, other_adata: Optional[str] = None, scale_by_marginals: bool = True, @@ -316,15 +314,15 @@ def _annotation_mapping( ), "cell_transition_kwargs is not empty, although cell_transition is not used." if forward: source_annotation_key, source_annotations, source_annotations_ordered = _validate_args_cell_transition( - self.adata, label + self.adata, annotation_label ) + dummy = pd.get_dummies(source_annotations) elif not forward: target_annotation_key, target_annotations, target_annotations_ordered = _validate_args_cell_transition( - self.adata if other_adata is None else other_adata, label + self.adata if other_adata is None else other_adata, annotation_label ) - dummy = pd.get_dummies(source_annotations) + dummy = pd.get_dummies(target_annotations) out: ArrayLike = self.pull(dummy, scale_by_marginals=scale_by_marginals) # or assert out is not None - # return pd.Categorical([dummy.columns[i] for i in np.array(out.argmax(1))]) raise NotImplementedError(f"Mapping mode `{mapping_mode!r}` is not yet implemented.") diff --git a/src/moscot/problems/space/_mixins.py b/src/moscot/problems/space/_mixins.py index f8c8a6530..f9d45c26d 100644 --- a/src/moscot/problems/space/_mixins.py +++ b/src/moscot/problems/space/_mixins.py @@ -287,16 +287,16 @@ def cell_transition( # type: ignore[misc] def annotation_mapping( self: AnalysisMixinProtocol[K, B], mapping_mode: Literal["sum", "max"], - source_groups: Str_Dict_t, - target_groups: Str_Dict_t, + annotation_label: str, + forward: bool, other_adata: Optional[str] = None, scale_by_marginals: bool = True, cell_transition_kwargs: Mapping[str, Any] = types.MappingProxyType({}), ) -> pd.DataFrame: return self._annotation_mapping( mapping_mode=mapping_mode, - source_groups=source_groups, - target_groups=target_groups, + annotation_label=annotation_label, + forward=forward, other_adata=other_adata, scale_by_marginals=scale_by_marginals, cell_transition_kwargs=cell_transition_kwargs, @@ -594,16 +594,16 @@ def cell_transition( # type: ignore[misc] def annotation_mapping( self: AnalysisMixinProtocol[K, B], mapping_mode: Literal["sum", "max"], - source_groups: Str_Dict_t, - target_groups: Str_Dict_t, + annotation_label: str, + forward: bool, other_adata: Optional[str] = None, scale_by_marginals: bool = True, cell_transition_kwargs: Mapping[str, Any] = types.MappingProxyType({}), ) -> pd.DataFrame: return self._annotation_mapping( mapping_mode=mapping_mode, - source_groups=source_groups, - target_groups=target_groups, + annotation_label=annotation_label, + forward=forward, other_adata=other_adata, scale_by_marginals=scale_by_marginals, cell_transition_kwargs=cell_transition_kwargs, From 189d468c70957e2513367c54a8f573b4c233f1c9 Mon Sep 17 00:00:00 2001 From: Arina Danilina Date: Thu, 2 Nov 2023 18:28:51 +0100 Subject: [PATCH 16/58] key_added logic --- src/moscot/base/problems/_mixins.py | 5 ++++- src/moscot/problems/space/_mixins.py | 27 ++++++++++++++++++++++----- 2 files changed, 26 insertions(+), 6 deletions(-) diff --git a/src/moscot/base/problems/_mixins.py b/src/moscot/base/problems/_mixins.py index 9febcac4b..c47cfba07 100644 --- a/src/moscot/base/problems/_mixins.py +++ b/src/moscot/base/problems/_mixins.py @@ -301,10 +301,13 @@ def _annotation_mapping( self: AnalysisMixinProtocol[K, B], mapping_mode: Literal["sum", "max"], annotation_label: str, - forward: bool, + forward: bool = True, + #source_label: Optional[str] = "adata", + #target_label: Optional[str] = "adata", other_adata: Optional[str] = None, scale_by_marginals: bool = True, cell_transition_kwargs: Mapping[str, Any] = types.MappingProxyType({}), + #key_added: Optional[str] = None, ) -> pd.DataFrame: if mapping_mode == "sum": return self._cell_transition(**cell_transition_kwargs) # aggregation mode should set to cell diff --git a/src/moscot/problems/space/_mixins.py b/src/moscot/problems/space/_mixins.py index f9d45c26d..b77d48f75 100644 --- a/src/moscot/problems/space/_mixins.py +++ b/src/moscot/problems/space/_mixins.py @@ -596,18 +596,35 @@ def annotation_mapping( mapping_mode: Literal["sum", "max"], annotation_label: str, forward: bool, - other_adata: Optional[str] = None, + source: str = "src", + target: str = "tgt", scale_by_marginals: bool = True, + key_added: str | None = None, cell_transition_kwargs: Mapping[str, Any] = types.MappingProxyType({}), ) -> pd.DataFrame: - return self._annotation_mapping( + # mp[("batch_0","tgt"), ("batch_1","tgt")] # from sc tgt to spatial batch + annotation = self._annotation_mapping( mapping_mode=mapping_mode, annotation_label=annotation_label, - forward=forward, - other_adata=other_adata, + forward=not forward, # inverted for MappingProblem + other_adata=self.adata_sc, scale_by_marginals=scale_by_marginals, cell_transition_kwargs=cell_transition_kwargs, - ) + ) + if key_added is None: + return annotation + if key_added not in self.adata.obs: + self.adata.obs[key_added] = pd.empty(len(self.adata)) + + # if forward add in self.adata; forward false - self.adata_sc + if forward: + idx = self.adata[self.adata.obs[self.batch_key] == source] + self.adata[idx].obs[key_added] = annotation + else: + idx = self.adata_sc[self.adata_sc.obs[self.batch_key] == target] # is target correct here? + self.adata_sc[idx].obs[key_added] = annotation + + @property def batch_key(self) -> Optional[str]: From 47b8ec5e18517297a464d40972dd09816826daa9 Mon Sep 17 00:00:00 2001 From: Arina Danilina Date: Thu, 2 Nov 2023 18:34:47 +0100 Subject: [PATCH 17/58] key_added logic --- src/moscot/base/problems/_mixins.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/moscot/base/problems/_mixins.py b/src/moscot/base/problems/_mixins.py index c47cfba07..e3a54a4e2 100644 --- a/src/moscot/base/problems/_mixins.py +++ b/src/moscot/base/problems/_mixins.py @@ -310,6 +310,7 @@ def _annotation_mapping( #key_added: Optional[str] = None, ) -> pd.DataFrame: if mapping_mode == "sum": + cell_transition_kwargs.setdefault("aggregation_mode", "cell") return self._cell_transition(**cell_transition_kwargs) # aggregation mode should set to cell if mapping_mode == "max": assert ( From 7ac7e2619de720107c35586609c6304f8c84561a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 2 Nov 2023 17:36:19 +0000 Subject: [PATCH 18/58] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/moscot/base/problems/_mixins.py | 6 +++--- src/moscot/problems/space/_mixins.py | 8 +++----- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/src/moscot/base/problems/_mixins.py b/src/moscot/base/problems/_mixins.py index e3a54a4e2..4bcbc8b8d 100644 --- a/src/moscot/base/problems/_mixins.py +++ b/src/moscot/base/problems/_mixins.py @@ -302,12 +302,12 @@ def _annotation_mapping( mapping_mode: Literal["sum", "max"], annotation_label: str, forward: bool = True, - #source_label: Optional[str] = "adata", - #target_label: Optional[str] = "adata", + # source_label: Optional[str] = "adata", + # target_label: Optional[str] = "adata", other_adata: Optional[str] = None, scale_by_marginals: bool = True, cell_transition_kwargs: Mapping[str, Any] = types.MappingProxyType({}), - #key_added: Optional[str] = None, + # key_added: Optional[str] = None, ) -> pd.DataFrame: if mapping_mode == "sum": cell_transition_kwargs.setdefault("aggregation_mode", "cell") diff --git a/src/moscot/problems/space/_mixins.py b/src/moscot/problems/space/_mixins.py index b77d48f75..6e683ca8b 100644 --- a/src/moscot/problems/space/_mixins.py +++ b/src/moscot/problems/space/_mixins.py @@ -606,11 +606,11 @@ def annotation_mapping( annotation = self._annotation_mapping( mapping_mode=mapping_mode, annotation_label=annotation_label, - forward=not forward, # inverted for MappingProblem + forward=not forward, # inverted for MappingProblem other_adata=self.adata_sc, scale_by_marginals=scale_by_marginals, cell_transition_kwargs=cell_transition_kwargs, - ) + ) if key_added is None: return annotation if key_added not in self.adata.obs: @@ -621,10 +621,8 @@ def annotation_mapping( idx = self.adata[self.adata.obs[self.batch_key] == source] self.adata[idx].obs[key_added] = annotation else: - idx = self.adata_sc[self.adata_sc.obs[self.batch_key] == target] # is target correct here? + idx = self.adata_sc[self.adata_sc.obs[self.batch_key] == target] # is target correct here? self.adata_sc[idx].obs[key_added] = annotation - - @property def batch_key(self) -> Optional[str]: From 5f5bf26f3e8d90606e57a93080f4a4f4222b4766 Mon Sep 17 00:00:00 2001 From: Arina Danilina Date: Thu, 30 Nov 2023 16:41:57 +0100 Subject: [PATCH 19/58] anno_map progress --- src/moscot/base/problems/_mixins.py | 44 +++++++++++++++++++--------- src/moscot/problems/space/_mixins.py | 37 +++++++++++++---------- 2 files changed, 51 insertions(+), 30 deletions(-) diff --git a/src/moscot/base/problems/_mixins.py b/src/moscot/base/problems/_mixins.py index 4bcbc8b8d..f951cb09c 100644 --- a/src/moscot/base/problems/_mixins.py +++ b/src/moscot/base/problems/_mixins.py @@ -301,6 +301,7 @@ def _annotation_mapping( self: AnalysisMixinProtocol[K, B], mapping_mode: Literal["sum", "max"], annotation_label: str, + source: K, forward: bool = True, # source_label: Optional[str] = "adata", # target_label: Optional[str] = "adata", @@ -309,26 +310,41 @@ def _annotation_mapping( cell_transition_kwargs: Mapping[str, Any] = types.MappingProxyType({}), # key_added: Optional[str] = None, ) -> pd.DataFrame: + if not forward: + source_annotation_key, source_annotations, source_annotations_ordered = _validate_args_cell_transition( + self.adata, annotation_label + ) + #source_df = _get_df_cell_transition(source_annotations, source) + print("source anno", source_annotations) + dummy = pd.get_dummies(source_annotations) + elif forward: + target_annotation_key, target_annotations, target_annotations_ordered = _validate_args_cell_transition( + self.adata if other_adata is None else other_adata, annotation_label + ) + dummy = pd.get_dummies(target_annotations) if mapping_mode == "sum": + cell_transition_kwargs = dict(cell_transition_kwargs) cell_transition_kwargs.setdefault("aggregation_mode", "cell") - return self._cell_transition(**cell_transition_kwargs) # aggregation mode should set to cell - if mapping_mode == "max": + cell_transition_kwargs.setdefault("key", annotation_label) + cell_transition_kwargs.setdefault("source", source) + cell_transition_kwargs.setdefault("target", 'tgt') # target always tgt + cell_transition_kwargs.setdefault("other_adata", other_adata) + out: ArrayLike = self._cell_transition(**cell_transition_kwargs) + #return pd.Categorical(out.argmax(1)) + #return self._cell_transition(**cell_transition_kwargs) # aggregation mode should set to cell + elif mapping_mode == "max": assert ( not cell_transition_kwargs ), "cell_transition_kwargs is not empty, although cell_transition is not used." - if forward: - source_annotation_key, source_annotations, source_annotations_ordered = _validate_args_cell_transition( - self.adata, annotation_label - ) - dummy = pd.get_dummies(source_annotations) - elif not forward: - target_annotation_key, target_annotations, target_annotations_ordered = _validate_args_cell_transition( - self.adata if other_adata is None else other_adata, annotation_label - ) - dummy = pd.get_dummies(target_annotations) out: ArrayLike = self.pull(dummy, scale_by_marginals=scale_by_marginals) # or assert out is not None - return pd.Categorical([dummy.columns[i] for i in np.array(out.argmax(1))]) - raise NotImplementedError(f"Mapping mode `{mapping_mode!r}` is not yet implemented.") + #return pd.Categorical([dummy.columns[i] for i in out.argmax(1)]) + else: + raise NotImplementedError(f"Mapping mode `{mapping_mode!r}` is not yet implemented.") + # return for both modes + print("forward in problems:", forward) + print(dummy) + print("res", pd.Categorical([dummy.columns[i] for i in np.argmax(out, axis=1)])) + return pd.Categorical([dummy.columns[i] for i in np.argmax(out, axis=1)]) def _sample_from_tmap( self: AnalysisMixinProtocol[K, B], diff --git a/src/moscot/problems/space/_mixins.py b/src/moscot/problems/space/_mixins.py index 6e683ca8b..5056c94fd 100644 --- a/src/moscot/problems/space/_mixins.py +++ b/src/moscot/problems/space/_mixins.py @@ -597,32 +597,37 @@ def annotation_mapping( annotation_label: str, forward: bool, source: str = "src", - target: str = "tgt", + # target: str = "tgt", scale_by_marginals: bool = True, key_added: str | None = None, cell_transition_kwargs: Mapping[str, Any] = types.MappingProxyType({}), ) -> pd.DataFrame: # mp[("batch_0","tgt"), ("batch_1","tgt")] # from sc tgt to spatial batch - annotation = self._annotation_mapping( - mapping_mode=mapping_mode, - annotation_label=annotation_label, - forward=not forward, # inverted for MappingProblem - other_adata=self.adata_sc, - scale_by_marginals=scale_by_marginals, - cell_transition_kwargs=cell_transition_kwargs, - ) + if cell_transition_kwargs: + annotation = self._annotation_mapping( + mapping_mode=mapping_mode, + annotation_label=annotation_label, + source=source, + forward=not forward, # inverted for MappingProblem + other_adata=self.adata_sc, + scale_by_marginals=scale_by_marginals, + cell_transition_kwargs=cell_transition_kwargs, + ) if key_added is None: return annotation - if key_added not in self.adata.obs: - self.adata.obs[key_added] = pd.empty(len(self.adata)) + + if key_added not in list(self.adata.obs): + self.adata.obs[key_added] = np.empty(len(self.adata)) # if forward add in self.adata; forward false - self.adata_sc if forward: - idx = self.adata[self.adata.obs[self.batch_key] == source] - self.adata[idx].obs[key_added] = annotation - else: - idx = self.adata_sc[self.adata_sc.obs[self.batch_key] == target] # is target correct here? - self.adata_sc[idx].obs[key_added] = annotation + if source != "src": + idx = self.adata[self.adata.obs[self.batch_key] == source] + self.adata[idx].obs[key_added] = annotation + else: + self.adata.obs[key_added] = annotation + else: # target is always 'tgt' + self.adata_sc.obs[key_added] = annotation @property def batch_key(self) -> Optional[str]: From f7accd47e2f1f84154c08f23de71135c666fff38 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 30 Nov 2023 15:45:30 +0000 Subject: [PATCH 20/58] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/moscot/base/problems/_mixins.py | 10 +++++----- src/moscot/problems/space/_mixins.py | 4 ++-- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/moscot/base/problems/_mixins.py b/src/moscot/base/problems/_mixins.py index f951cb09c..37142fd5f 100644 --- a/src/moscot/base/problems/_mixins.py +++ b/src/moscot/base/problems/_mixins.py @@ -314,7 +314,7 @@ def _annotation_mapping( source_annotation_key, source_annotations, source_annotations_ordered = _validate_args_cell_transition( self.adata, annotation_label ) - #source_df = _get_df_cell_transition(source_annotations, source) + # source_df = _get_df_cell_transition(source_annotations, source) print("source anno", source_annotations) dummy = pd.get_dummies(source_annotations) elif forward: @@ -327,17 +327,17 @@ def _annotation_mapping( cell_transition_kwargs.setdefault("aggregation_mode", "cell") cell_transition_kwargs.setdefault("key", annotation_label) cell_transition_kwargs.setdefault("source", source) - cell_transition_kwargs.setdefault("target", 'tgt') # target always tgt + cell_transition_kwargs.setdefault("target", "tgt") # target always tgt cell_transition_kwargs.setdefault("other_adata", other_adata) out: ArrayLike = self._cell_transition(**cell_transition_kwargs) - #return pd.Categorical(out.argmax(1)) - #return self._cell_transition(**cell_transition_kwargs) # aggregation mode should set to cell + # return pd.Categorical(out.argmax(1)) + # return self._cell_transition(**cell_transition_kwargs) # aggregation mode should set to cell elif mapping_mode == "max": assert ( not cell_transition_kwargs ), "cell_transition_kwargs is not empty, although cell_transition is not used." out: ArrayLike = self.pull(dummy, scale_by_marginals=scale_by_marginals) # or assert out is not None - #return pd.Categorical([dummy.columns[i] for i in out.argmax(1)]) + # return pd.Categorical([dummy.columns[i] for i in out.argmax(1)]) else: raise NotImplementedError(f"Mapping mode `{mapping_mode!r}` is not yet implemented.") # return for both modes diff --git a/src/moscot/problems/space/_mixins.py b/src/moscot/problems/space/_mixins.py index 5056c94fd..b806127e7 100644 --- a/src/moscot/problems/space/_mixins.py +++ b/src/moscot/problems/space/_mixins.py @@ -626,8 +626,8 @@ def annotation_mapping( self.adata[idx].obs[key_added] = annotation else: self.adata.obs[key_added] = annotation - else: # target is always 'tgt' - self.adata_sc.obs[key_added] = annotation + else: # target is always 'tgt' + self.adata_sc.obs[key_added] = annotation @property def batch_key(self) -> Optional[str]: From 69c3e76f2519b6dc6f8d7972dd20edd689e70795 Mon Sep 17 00:00:00 2001 From: Arina Danilina Date: Thu, 7 Dec 2023 23:52:29 +0100 Subject: [PATCH 21/58] mp fix, ap and tp added --- src/moscot/base/problems/_mixins.py | 42 +++++++++++------------ src/moscot/problems/space/_mixins.py | 47 +++++++++++++++++++------ src/moscot/problems/time/_mixins.py | 51 ++++++++++++++++++++++++++++ 3 files changed, 108 insertions(+), 32 deletions(-) diff --git a/src/moscot/base/problems/_mixins.py b/src/moscot/base/problems/_mixins.py index 37142fd5f..9d1c91aba 100644 --- a/src/moscot/base/problems/_mixins.py +++ b/src/moscot/base/problems/_mixins.py @@ -302,6 +302,7 @@ def _annotation_mapping( mapping_mode: Literal["sum", "max"], annotation_label: str, source: K, + target: K, forward: bool = True, # source_label: Optional[str] = "adata", # target_label: Optional[str] = "adata", @@ -310,41 +311,40 @@ def _annotation_mapping( cell_transition_kwargs: Mapping[str, Any] = types.MappingProxyType({}), # key_added: Optional[str] = None, ) -> pd.DataFrame: - if not forward: - source_annotation_key, source_annotations, source_annotations_ordered = _validate_args_cell_transition( - self.adata, annotation_label - ) - # source_df = _get_df_cell_transition(source_annotations, source) - print("source anno", source_annotations) - dummy = pd.get_dummies(source_annotations) - elif forward: - target_annotation_key, target_annotations, target_annotations_ordered = _validate_args_cell_transition( - self.adata if other_adata is None else other_adata, annotation_label + if forward: + source_df = _get_df_cell_transition( + self.adata, + annotation_keys=[annotation_label], + filter_key=self.batch_key, + filter_value=source, ) - dummy = pd.get_dummies(target_annotations) + dummy = pd.get_dummies(source_df) + elif not forward: + target_df = _get_df_cell_transition( + self.adata if other_adata is None else other_adata, + annotation_keys=[annotation_label], + filter_key=self.batch_key, + filter_value=target, + ) + dummy = pd.get_dummies(target_df) if mapping_mode == "sum": cell_transition_kwargs = dict(cell_transition_kwargs) - cell_transition_kwargs.setdefault("aggregation_mode", "cell") - cell_transition_kwargs.setdefault("key", annotation_label) + cell_transition_kwargs.setdefault("aggregation_mode", "cell") # aggregation mode should be set to cell + #cell_transition_kwargs.setdefault("key", annotation_label) + cell_transition_kwargs.setdefault("key", self.batch_key) cell_transition_kwargs.setdefault("source", source) - cell_transition_kwargs.setdefault("target", "tgt") # target always tgt + cell_transition_kwargs.setdefault("target", target) cell_transition_kwargs.setdefault("other_adata", other_adata) out: ArrayLike = self._cell_transition(**cell_transition_kwargs) - # return pd.Categorical(out.argmax(1)) - # return self._cell_transition(**cell_transition_kwargs) # aggregation mode should set to cell elif mapping_mode == "max": assert ( not cell_transition_kwargs ), "cell_transition_kwargs is not empty, although cell_transition is not used." out: ArrayLike = self.pull(dummy, scale_by_marginals=scale_by_marginals) # or assert out is not None - # return pd.Categorical([dummy.columns[i] for i in out.argmax(1)]) else: raise NotImplementedError(f"Mapping mode `{mapping_mode!r}` is not yet implemented.") # return for both modes - print("forward in problems:", forward) - print(dummy) - print("res", pd.Categorical([dummy.columns[i] for i in np.argmax(out, axis=1)])) - return pd.Categorical([dummy.columns[i] for i in np.argmax(out, axis=1)]) + return pd.Categorical(out.idxmax(axis="columns")) def _sample_from_tmap( self: AnalysisMixinProtocol[K, B], diff --git a/src/moscot/problems/space/_mixins.py b/src/moscot/problems/space/_mixins.py index b806127e7..2ad05a0a4 100644 --- a/src/moscot/problems/space/_mixins.py +++ b/src/moscot/problems/space/_mixins.py @@ -289,19 +289,43 @@ def annotation_mapping( mapping_mode: Literal["sum", "max"], annotation_label: str, forward: bool, + source: str = "src", + target: str = "tgt", other_adata: Optional[str] = None, scale_by_marginals: bool = True, + key_added: str | None = None, cell_transition_kwargs: Mapping[str, Any] = types.MappingProxyType({}), ) -> pd.DataFrame: - return self._annotation_mapping( + annotation = self._annotation_mapping( mapping_mode=mapping_mode, annotation_label=annotation_label, + source=source, + target=target, forward=forward, other_adata=other_adata, scale_by_marginals=scale_by_marginals, cell_transition_kwargs=cell_transition_kwargs, ) + if key_added is None: + return annotation + + if key_added not in list(self.adata.obs): + self.adata.obs[key_added] = np.empty(len(self.adata)) + + if forward: + if source != "src": + idx = self.adata[self.adata.obs[self.batch_key] == source] + self.adata[idx].obs[key_added] = annotation + else: + self.adata.obs[key_added] = annotation + else: + if target != "tgt": + idx = self.adata[self.adata.obs[self.batch_key] == target] + self.adata[idx].obs[key_added] = annotation + else: + self.adata.obs[key_added] = annotation + @property def spatial_key(self) -> Optional[str]: """Spatial key in :attr:`~anndata.AnnData.obsm`.""" @@ -603,16 +627,17 @@ def annotation_mapping( cell_transition_kwargs: Mapping[str, Any] = types.MappingProxyType({}), ) -> pd.DataFrame: # mp[("batch_0","tgt"), ("batch_1","tgt")] # from sc tgt to spatial batch - if cell_transition_kwargs: - annotation = self._annotation_mapping( - mapping_mode=mapping_mode, - annotation_label=annotation_label, - source=source, - forward=not forward, # inverted for MappingProblem - other_adata=self.adata_sc, - scale_by_marginals=scale_by_marginals, - cell_transition_kwargs=cell_transition_kwargs, - ) + annotation = self._annotation_mapping( + mapping_mode=mapping_mode, + annotation_label=annotation_label, + source=source, + target="tgt", # target always 'tgt' + forward=not forward, # inverted for MappingProblem + other_adata=self.adata_sc, + scale_by_marginals=scale_by_marginals, + cell_transition_kwargs=cell_transition_kwargs, + ) + if key_added is None: return annotation diff --git a/src/moscot/problems/time/_mixins.py b/src/moscot/problems/time/_mixins.py index 55b94c890..091b7834d 100644 --- a/src/moscot/problems/time/_mixins.py +++ b/src/moscot/problems/time/_mixins.py @@ -1,5 +1,6 @@ import itertools import pathlib +import types from typing import ( TYPE_CHECKING, Any, @@ -8,6 +9,7 @@ Iterator, List, Literal, + Mapping, Optional, Protocol, Sequence, @@ -65,6 +67,13 @@ def _cell_transition( ) -> pd.DataFrame: ... + def _annotation_mapping( + self: AnalysisMixinProtocol[K, B], + *args: Any, + **kwargs: Any, + ) -> pd.DataFrame: + ... + def _sample_from_tmap( self: "TemporalMixinProtocol[K, B]", source: K, @@ -232,6 +241,48 @@ def cell_transition( key_added=key_added, ) + def annotation_mapping( + self: AnalysisMixinProtocol[K, B], + mapping_mode: Literal["sum", "max"], + annotation_label: str, + forward: bool, + source: str = "src", + target: str = "tgt", + scale_by_marginals: bool = True, + other_adata: Optional[str] = None, + key_added: str | None = None, + cell_transition_kwargs: Mapping[str, Any] = types.MappingProxyType({}), + ) -> pd.DataFrame: + annotation = self._annotation_mapping( + mapping_mode=mapping_mode, + annotation_label=annotation_label, + source=source, + target=target, + forward=forward, + other_adata=other_adata, + scale_by_marginals=scale_by_marginals, + cell_transition_kwargs=cell_transition_kwargs, + ) + + if key_added is None: + return annotation + + if key_added not in list(self.adata.obs): + self.adata.obs[key_added] = np.empty(len(self.adata)) + + if forward: + if source != "src": + idx = self.adata[self.adata.obs[self.batch_key] == source] + self.adata[idx].obs[key_added] = annotation + else: + self.adata.obs[key_added] = annotation + else: + if target != "tgt": + idx = self.adata[self.adata.obs[self.batch_key] == target] + self.adata[idx].obs[key_added] = annotation + else: + self.adata.obs[key_added] = annotation + def sankey( self: "TemporalMixinProtocol[K, B]", source: K, From 878a79e0bfc10e10ab714c659abd15fa53bb92ec Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 7 Dec 2023 22:53:26 +0000 Subject: [PATCH 22/58] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/moscot/base/problems/_mixins.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/moscot/base/problems/_mixins.py b/src/moscot/base/problems/_mixins.py index 9d1c91aba..d9e2d7286 100644 --- a/src/moscot/base/problems/_mixins.py +++ b/src/moscot/base/problems/_mixins.py @@ -325,12 +325,12 @@ def _annotation_mapping( annotation_keys=[annotation_label], filter_key=self.batch_key, filter_value=target, - ) + ) dummy = pd.get_dummies(target_df) if mapping_mode == "sum": cell_transition_kwargs = dict(cell_transition_kwargs) - cell_transition_kwargs.setdefault("aggregation_mode", "cell") # aggregation mode should be set to cell - #cell_transition_kwargs.setdefault("key", annotation_label) + cell_transition_kwargs.setdefault("aggregation_mode", "cell") # aggregation mode should be set to cell + # cell_transition_kwargs.setdefault("key", annotation_label) cell_transition_kwargs.setdefault("key", self.batch_key) cell_transition_kwargs.setdefault("source", source) cell_transition_kwargs.setdefault("target", target) From 2366a324aef266aee51b987e6c26bfc5d8af9f86 Mon Sep 17 00:00:00 2001 From: giovp Date: Mon, 11 Dec 2023 16:43:50 +0100 Subject: [PATCH 23/58] make sum for temporal problem work --- src/moscot/base/problems/_mixins.py | 54 ++++++++++++++++++----------- src/moscot/problems/time/_mixins.py | 54 +++++++++-------------------- 2 files changed, 49 insertions(+), 59 deletions(-) diff --git a/src/moscot/base/problems/_mixins.py b/src/moscot/base/problems/_mixins.py index d9e2d7286..8155c988d 100644 --- a/src/moscot/base/problems/_mixins.py +++ b/src/moscot/base/problems/_mixins.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import types from typing import ( TYPE_CHECKING, @@ -62,14 +64,14 @@ def _apply( ... def _interpolate_transport( - self: "AnalysisMixinProtocol[K, B]", + self: AnalysisMixinProtocol[K, B], path: Sequence[Tuple[K, K]], scale_by_marginals: bool = True, ) -> LinearOperator: ... def _flatten( - self: "AnalysisMixinProtocol[K, B]", + self: AnalysisMixinProtocol[K, B], data: Dict[K, ArrayLike], *, key: Optional[str], @@ -85,7 +87,7 @@ def pull(self, *args: Any, **kwargs: Any) -> Optional[ApplyOutput_t[K]]: ... def _cell_transition( - self: "AnalysisMixinProtocol[K, B]", + self: AnalysisMixinProtocol[K, B], source: K, target: K, source_groups: Str_Dict_t, @@ -97,7 +99,7 @@ def _cell_transition( ... def _cell_transition_online( - self: "AnalysisMixinProtocol[K, B]", + self: AnalysisMixinProtocol[K, B], key: Optional[str], source: K, target: K, @@ -113,10 +115,13 @@ def _cell_transition_online( ... def _annotation_mapping( - self: "AnalysisMixinProtocol[K, B]", + self: AnalysisMixinProtocol[K, B], mapping_mode: Literal["sum", "max"], annotation_label: str, forward: bool, + source: K, + target: K, + key: str, other_adata: Optional[str] = None, scale_by_marginals: bool = True, cell_transition_kwargs: Mapping[str, Any] = types.MappingProxyType({}), @@ -303,6 +308,7 @@ def _annotation_mapping( annotation_label: str, source: K, target: K, + key: str | None = None, forward: bool = True, # source_label: Optional[str] = "adata", # target_label: Optional[str] = "adata", @@ -311,40 +317,46 @@ def _annotation_mapping( cell_transition_kwargs: Mapping[str, Any] = types.MappingProxyType({}), # key_added: Optional[str] = None, ) -> pd.DataFrame: + batch_key = getattr(self, "batch_key", None) + cell_transition_kwargs = dict(cell_transition_kwargs) + cell_transition_kwargs.setdefault("aggregation_mode", "cell") # aggregation mode should be set to cell + cell_transition_kwargs.setdefault("key", key) + cell_transition_kwargs.setdefault("source", source) + cell_transition_kwargs.setdefault("target", target) + cell_transition_kwargs.setdefault("other_adata", other_adata) + cell_transition_kwargs.setdefault("forward", forward) if forward: source_df = _get_df_cell_transition( self.adata, annotation_keys=[annotation_label], - filter_key=self.batch_key, + filter_key=batch_key, filter_value=source, ) dummy = pd.get_dummies(source_df) + axis = 1 # columns + cell_transition_kwargs.setdefault("source_groups", None) + cell_transition_kwargs.setdefault("target_groups", annotation_label) elif not forward: target_df = _get_df_cell_transition( self.adata if other_adata is None else other_adata, annotation_keys=[annotation_label], - filter_key=self.batch_key, + filter_key=batch_key, filter_value=target, ) dummy = pd.get_dummies(target_df) + axis = 0 # rows + cell_transition_kwargs.setdefault("source_groups", annotation_label) + cell_transition_kwargs.setdefault("target_groups", None) if mapping_mode == "sum": - cell_transition_kwargs = dict(cell_transition_kwargs) - cell_transition_kwargs.setdefault("aggregation_mode", "cell") # aggregation mode should be set to cell - # cell_transition_kwargs.setdefault("key", annotation_label) - cell_transition_kwargs.setdefault("key", self.batch_key) - cell_transition_kwargs.setdefault("source", source) - cell_transition_kwargs.setdefault("target", target) - cell_transition_kwargs.setdefault("other_adata", other_adata) - out: ArrayLike = self._cell_transition(**cell_transition_kwargs) + out: pd.DataFrame = self._cell_transition(**cell_transition_kwargs) elif mapping_mode == "max": - assert ( - not cell_transition_kwargs - ), "cell_transition_kwargs is not empty, although cell_transition is not used." - out: ArrayLike = self.pull(dummy, scale_by_marginals=scale_by_marginals) # or assert out is not None - else: + if forward: + print(self[(source, target)]) + out: ArrayLike = self[(source, target)].push(dummy, scale_by_marginals=scale_by_marginals) + return out raise NotImplementedError(f"Mapping mode `{mapping_mode!r}` is not yet implemented.") # return for both modes - return pd.Categorical(out.idxmax(axis="columns")) + return out.idxmax(axis=axis).to_frame(name=annotation_label) def _sample_from_tmap( self: AnalysisMixinProtocol[K, B], diff --git a/src/moscot/problems/time/_mixins.py b/src/moscot/problems/time/_mixins.py index 091b7834d..5d1fb9409 100644 --- a/src/moscot/problems/time/_mixins.py +++ b/src/moscot/problems/time/_mixins.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import itertools import pathlib import types @@ -41,7 +43,7 @@ class TemporalMixinProtocol(AnalysisMixinProtocol[K, B], Protocol[K, B]): # typ _temporal_key: Optional[str] def cell_transition( # noqa: D102 - self: "TemporalMixinProtocol[K, B]", + self: TemporalMixinProtocol[K, B], source: K, target: K, source_groups: Str_Dict_t, @@ -67,15 +69,8 @@ def _cell_transition( ) -> pd.DataFrame: ... - def _annotation_mapping( - self: AnalysisMixinProtocol[K, B], - *args: Any, - **kwargs: Any, - ) -> pd.DataFrame: - ... - def _sample_from_tmap( - self: "TemporalMixinProtocol[K, B]", + self: TemporalMixinProtocol[K, B], source: K, target: K, n_samples: int, @@ -89,7 +84,7 @@ def _sample_from_tmap( ... def _compute_wasserstein_distance( - self: "TemporalMixinProtocol[K, B]", + self: TemporalMixinProtocol[K, B], point_cloud_1: ArrayLike, point_cloud_2: ArrayLike, a: Optional[ArrayLike] = None, @@ -100,7 +95,7 @@ def _compute_wasserstein_distance( ... def _interpolate_gex_with_ot( - self: "TemporalMixinProtocol[K, B]", + self: TemporalMixinProtocol[K, B], number_cells: int, source_data: ArrayLike, target_data: ArrayLike, @@ -114,7 +109,7 @@ def _interpolate_gex_with_ot( ... def _get_data( - self: "TemporalMixinProtocol[K, B]", + self: TemporalMixinProtocol[K, B], source: K, intermediate: Optional[K] = None, target: Optional[K] = None, @@ -125,7 +120,7 @@ def _get_data( ... def _interpolate_gex_randomly( - self: "TemporalMixinProtocol[K, B]", + self: TemporalMixinProtocol[K, B], number_cells: int, source_data: ArrayLike, target_data: ArrayLike, @@ -136,7 +131,7 @@ def _interpolate_gex_randomly( ... def _plot_temporal( - self: "TemporalMixinProtocol[K, B]", + self: TemporalMixinProtocol[K, B], data: Dict[K, ArrayLike], source: K, target: K, @@ -242,49 +237,32 @@ def cell_transition( ) def annotation_mapping( - self: AnalysisMixinProtocol[K, B], + self: TemporalMixinProtocol[K, B], mapping_mode: Literal["sum", "max"], annotation_label: str, forward: bool, - source: str = "src", - target: str = "tgt", + source: K, + target: K, scale_by_marginals: bool = True, other_adata: Optional[str] = None, - key_added: str | None = None, cell_transition_kwargs: Mapping[str, Any] = types.MappingProxyType({}), ) -> pd.DataFrame: - annotation = self._annotation_mapping( + annotation: pd.DataFrame = self._annotation_mapping( mapping_mode=mapping_mode, annotation_label=annotation_label, source=source, target=target, + key=self._temporal_key, forward=forward, other_adata=other_adata, scale_by_marginals=scale_by_marginals, cell_transition_kwargs=cell_transition_kwargs, ) - if key_added is None: - return annotation - - if key_added not in list(self.adata.obs): - self.adata.obs[key_added] = np.empty(len(self.adata)) - - if forward: - if source != "src": - idx = self.adata[self.adata.obs[self.batch_key] == source] - self.adata[idx].obs[key_added] = annotation - else: - self.adata.obs[key_added] = annotation - else: - if target != "tgt": - idx = self.adata[self.adata.obs[self.batch_key] == target] - self.adata[idx].obs[key_added] = annotation - else: - self.adata.obs[key_added] = annotation + return annotation def sankey( - self: "TemporalMixinProtocol[K, B]", + self: TemporalMixinProtocol[K, B], source: K, target: K, source_groups: Str_Dict_t, From 5bd669673a4457cef1698d89f2d09f9a38e7f1ec Mon Sep 17 00:00:00 2001 From: giovp Date: Mon, 11 Dec 2023 16:52:17 +0100 Subject: [PATCH 24/58] fix max for temporal --- src/moscot/base/problems/_mixins.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/src/moscot/base/problems/_mixins.py b/src/moscot/base/problems/_mixins.py index 8155c988d..520c2e9a1 100644 --- a/src/moscot/base/problems/_mixins.py +++ b/src/moscot/base/problems/_mixins.py @@ -329,7 +329,7 @@ def _annotation_mapping( source_df = _get_df_cell_transition( self.adata, annotation_keys=[annotation_label], - filter_key=batch_key, + filter_key=key, filter_value=source, ) dummy = pd.get_dummies(source_df) @@ -340,7 +340,7 @@ def _annotation_mapping( target_df = _get_df_cell_transition( self.adata if other_adata is None else other_adata, annotation_keys=[annotation_label], - filter_key=batch_key, + filter_key=key, filter_value=target, ) dummy = pd.get_dummies(target_df) @@ -349,14 +349,15 @@ def _annotation_mapping( cell_transition_kwargs.setdefault("target_groups", None) if mapping_mode == "sum": out: pd.DataFrame = self._cell_transition(**cell_transition_kwargs) + return out.idxmax(axis=axis).to_frame(name=annotation_label) elif mapping_mode == "max": if forward: - print(self[(source, target)]) out: ArrayLike = self[(source, target)].push(dummy, scale_by_marginals=scale_by_marginals) - return out + else: + out: ArrayLike = self[(source, target)].pull(dummy, scale_by_marginals=scale_by_marginals) + return pd.DataFrame(out.argmax(1), columns=[annotation_label]) + else: raise NotImplementedError(f"Mapping mode `{mapping_mode!r}` is not yet implemented.") - # return for both modes - return out.idxmax(axis=axis).to_frame(name=annotation_label) def _sample_from_tmap( self: AnalysisMixinProtocol[K, B], From 5b3d50875b11dd59c6857b2163b179f46cb2f69c Mon Sep 17 00:00:00 2001 From: Arina Danilina Date: Mon, 11 Dec 2023 21:02:15 +0100 Subject: [PATCH 25/58] before merge --- src/moscot/base/problems/_mixins.py | 1 + src/moscot/problems/space/_mixins.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/moscot/base/problems/_mixins.py b/src/moscot/base/problems/_mixins.py index 9d1c91aba..8dd9d462d 100644 --- a/src/moscot/base/problems/_mixins.py +++ b/src/moscot/base/problems/_mixins.py @@ -331,6 +331,7 @@ def _annotation_mapping( cell_transition_kwargs = dict(cell_transition_kwargs) cell_transition_kwargs.setdefault("aggregation_mode", "cell") # aggregation mode should be set to cell #cell_transition_kwargs.setdefault("key", annotation_label) + cell_transition_kwargs.setdefault("forward", forward) cell_transition_kwargs.setdefault("key", self.batch_key) cell_transition_kwargs.setdefault("source", source) cell_transition_kwargs.setdefault("target", target) diff --git a/src/moscot/problems/space/_mixins.py b/src/moscot/problems/space/_mixins.py index 2ad05a0a4..525ddfa57 100644 --- a/src/moscot/problems/space/_mixins.py +++ b/src/moscot/problems/space/_mixins.py @@ -315,7 +315,7 @@ def annotation_mapping( if forward: if source != "src": - idx = self.adata[self.adata.obs[self.batch_key] == source] + idx = self.adata.obs[self.batch_key] == source self.adata[idx].obs[key_added] = annotation else: self.adata.obs[key_added] = annotation From 338e75a1f9dd45c04f4717fd999c5fb91913202e Mon Sep 17 00:00:00 2001 From: Arina Danilina Date: Tue, 12 Dec 2023 08:37:23 +0100 Subject: [PATCH 26/58] passing batch key ap --- src/moscot/problems/space/_mixins.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/moscot/problems/space/_mixins.py b/src/moscot/problems/space/_mixins.py index 525ddfa57..601d32455 100644 --- a/src/moscot/problems/space/_mixins.py +++ b/src/moscot/problems/space/_mixins.py @@ -301,6 +301,7 @@ def annotation_mapping( annotation_label=annotation_label, source=source, target=target, + key=self._batch_key, forward=forward, other_adata=other_adata, scale_by_marginals=scale_by_marginals, From 5a6d538c24ad50312ef2bcfa18cd41fb8adccc38 Mon Sep 17 00:00:00 2001 From: Arina Danilina Date: Thu, 14 Dec 2023 14:34:02 +0100 Subject: [PATCH 27/58] cleaned arguments, forward in mp --- src/moscot/base/problems/_mixins.py | 3 -- src/moscot/problems/cross_modality/_mixins.py | 4 +- src/moscot/problems/space/_mixins.py | 53 +++---------------- src/moscot/problems/time/_mixins.py | 3 +- 4 files changed, 11 insertions(+), 52 deletions(-) diff --git a/src/moscot/base/problems/_mixins.py b/src/moscot/base/problems/_mixins.py index 520c2e9a1..5d03d669c 100644 --- a/src/moscot/base/problems/_mixins.py +++ b/src/moscot/base/problems/_mixins.py @@ -310,12 +310,9 @@ def _annotation_mapping( target: K, key: str | None = None, forward: bool = True, - # source_label: Optional[str] = "adata", - # target_label: Optional[str] = "adata", other_adata: Optional[str] = None, scale_by_marginals: bool = True, cell_transition_kwargs: Mapping[str, Any] = types.MappingProxyType({}), - # key_added: Optional[str] = None, ) -> pd.DataFrame: batch_key = getattr(self, "batch_key", None) cell_transition_kwargs = dict(cell_transition_kwargs) diff --git a/src/moscot/problems/cross_modality/_mixins.py b/src/moscot/problems/cross_modality/_mixins.py index 8ea454e8e..1099c5a32 100644 --- a/src/moscot/problems/cross_modality/_mixins.py +++ b/src/moscot/problems/cross_modality/_mixins.py @@ -189,8 +189,8 @@ def annotation_mapping( mapping_mode: Literal["sum", "max"], annotation_label: str, forward: bool, - source: K, - target: K, + source: K = "src", + target: K = "tgt", scale_by_marginals: bool = True, other_adata: Optional[str] = None, cell_transition_kwargs: Mapping[str, Any] = types.MappingProxyType({}), diff --git a/src/moscot/problems/space/_mixins.py b/src/moscot/problems/space/_mixins.py index 601d32455..7dfb2d4da 100644 --- a/src/moscot/problems/space/_mixins.py +++ b/src/moscot/problems/space/_mixins.py @@ -291,9 +291,7 @@ def annotation_mapping( forward: bool, source: str = "src", target: str = "tgt", - other_adata: Optional[str] = None, scale_by_marginals: bool = True, - key_added: str | None = None, cell_transition_kwargs: Mapping[str, Any] = types.MappingProxyType({}), ) -> pd.DataFrame: annotation = self._annotation_mapping( @@ -303,29 +301,10 @@ def annotation_mapping( target=target, key=self._batch_key, forward=forward, - other_adata=other_adata, scale_by_marginals=scale_by_marginals, cell_transition_kwargs=cell_transition_kwargs, ) - - if key_added is None: - return annotation - - if key_added not in list(self.adata.obs): - self.adata.obs[key_added] = np.empty(len(self.adata)) - - if forward: - if source != "src": - idx = self.adata.obs[self.batch_key] == source - self.adata[idx].obs[key_added] = annotation - else: - self.adata.obs[key_added] = annotation - else: - if target != "tgt": - idx = self.adata[self.adata.obs[self.batch_key] == target] - self.adata[idx].obs[key_added] = annotation - else: - self.adata.obs[key_added] = annotation + return annotation @property def spatial_key(self) -> Optional[str]: @@ -620,40 +599,24 @@ def annotation_mapping( self: AnalysisMixinProtocol[K, B], mapping_mode: Literal["sum", "max"], annotation_label: str, - forward: bool, - source: str = "src", - # target: str = "tgt", + source: str, + target: str = "tgt", + forward: bool = False, scale_by_marginals: bool = True, - key_added: str | None = None, cell_transition_kwargs: Mapping[str, Any] = types.MappingProxyType({}), ) -> pd.DataFrame: - # mp[("batch_0","tgt"), ("batch_1","tgt")] # from sc tgt to spatial batch annotation = self._annotation_mapping( mapping_mode=mapping_mode, annotation_label=annotation_label, source=source, - target="tgt", # target always 'tgt' - forward=not forward, # inverted for MappingProblem + target=target, + forward=forward, + key=self.batch_key, other_adata=self.adata_sc, scale_by_marginals=scale_by_marginals, cell_transition_kwargs=cell_transition_kwargs, ) - - if key_added is None: - return annotation - - if key_added not in list(self.adata.obs): - self.adata.obs[key_added] = np.empty(len(self.adata)) - - # if forward add in self.adata; forward false - self.adata_sc - if forward: - if source != "src": - idx = self.adata[self.adata.obs[self.batch_key] == source] - self.adata[idx].obs[key_added] = annotation - else: - self.adata.obs[key_added] = annotation - else: # target is always 'tgt' - self.adata_sc.obs[key_added] = annotation + return annotation @property def batch_key(self) -> Optional[str]: diff --git a/src/moscot/problems/time/_mixins.py b/src/moscot/problems/time/_mixins.py index 5d1fb9409..0e4ea9219 100644 --- a/src/moscot/problems/time/_mixins.py +++ b/src/moscot/problems/time/_mixins.py @@ -244,7 +244,6 @@ def annotation_mapping( source: K, target: K, scale_by_marginals: bool = True, - other_adata: Optional[str] = None, cell_transition_kwargs: Mapping[str, Any] = types.MappingProxyType({}), ) -> pd.DataFrame: annotation: pd.DataFrame = self._annotation_mapping( @@ -254,7 +253,7 @@ def annotation_mapping( target=target, key=self._temporal_key, forward=forward, - other_adata=other_adata, + other_adata=None, scale_by_marginals=scale_by_marginals, cell_transition_kwargs=cell_transition_kwargs, ) From da8eac1d194612936af4e322adf5892edb4adfb5 Mon Sep 17 00:00:00 2001 From: Arina Danilina Date: Tue, 19 Dec 2023 09:13:53 +0100 Subject: [PATCH 28/58] fix for cross_modality and general label handling in max --- src/moscot/base/problems/_mixins.py | 3 ++- src/moscot/problems/cross_modality/_mixins.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/moscot/base/problems/_mixins.py b/src/moscot/base/problems/_mixins.py index 5d03d669c..bbc39ad46 100644 --- a/src/moscot/base/problems/_mixins.py +++ b/src/moscot/base/problems/_mixins.py @@ -352,7 +352,8 @@ def _annotation_mapping( out: ArrayLike = self[(source, target)].push(dummy, scale_by_marginals=scale_by_marginals) else: out: ArrayLike = self[(source, target)].pull(dummy, scale_by_marginals=scale_by_marginals) - return pd.DataFrame(out.argmax(1), columns=[annotation_label]) + categories = pd.Categorical([(dummy.columns[i]).split("_")[-1] for i in np.array(out.argmax(1))]) + return pd.DataFrame(categories, columns=[annotation_label]) else: raise NotImplementedError(f"Mapping mode `{mapping_mode!r}` is not yet implemented.") diff --git a/src/moscot/problems/cross_modality/_mixins.py b/src/moscot/problems/cross_modality/_mixins.py index 1099c5a32..367c19625 100644 --- a/src/moscot/problems/cross_modality/_mixins.py +++ b/src/moscot/problems/cross_modality/_mixins.py @@ -202,7 +202,7 @@ def annotation_mapping( target=target, key=self.batch_key, forward=forward, - other_adata=other_adata, + other_adata=self.adata_tgt if forward else self.adata_src, scale_by_marginals=scale_by_marginals, cell_transition_kwargs=cell_transition_kwargs, ) From 10a79fe79c827f513006bb3e03a43cba1d025c11 Mon Sep 17 00:00:00 2001 From: Arina Danilina Date: Tue, 19 Dec 2023 11:57:59 +0100 Subject: [PATCH 29/58] handling annottion labels --- src/moscot/base/problems/_mixins.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/moscot/base/problems/_mixins.py b/src/moscot/base/problems/_mixins.py index bbc39ad46..1edb8ba7b 100644 --- a/src/moscot/base/problems/_mixins.py +++ b/src/moscot/base/problems/_mixins.py @@ -314,7 +314,6 @@ def _annotation_mapping( scale_by_marginals: bool = True, cell_transition_kwargs: Mapping[str, Any] = types.MappingProxyType({}), ) -> pd.DataFrame: - batch_key = getattr(self, "batch_key", None) cell_transition_kwargs = dict(cell_transition_kwargs) cell_transition_kwargs.setdefault("aggregation_mode", "cell") # aggregation mode should be set to cell cell_transition_kwargs.setdefault("key", key) @@ -329,7 +328,7 @@ def _annotation_mapping( filter_key=key, filter_value=source, ) - dummy = pd.get_dummies(source_df) + dummy = pd.get_dummies(source_df, prefix="", prefix_sep="") axis = 1 # columns cell_transition_kwargs.setdefault("source_groups", None) cell_transition_kwargs.setdefault("target_groups", annotation_label) @@ -340,7 +339,7 @@ def _annotation_mapping( filter_key=key, filter_value=target, ) - dummy = pd.get_dummies(target_df) + dummy = pd.get_dummies(target_df, prefix="", prefix_sep="") axis = 0 # rows cell_transition_kwargs.setdefault("source_groups", annotation_label) cell_transition_kwargs.setdefault("target_groups", None) @@ -352,7 +351,7 @@ def _annotation_mapping( out: ArrayLike = self[(source, target)].push(dummy, scale_by_marginals=scale_by_marginals) else: out: ArrayLike = self[(source, target)].pull(dummy, scale_by_marginals=scale_by_marginals) - categories = pd.Categorical([(dummy.columns[i]).split("_")[-1] for i in np.array(out.argmax(1))]) + categories = pd.Categorical([dummy.columns[i] for i in np.array(out.argmax(1))]) return pd.DataFrame(categories, columns=[annotation_label]) else: raise NotImplementedError(f"Mapping mode `{mapping_mode!r}` is not yet implemented.") From d3f435b67beb854bb433d5bf81e403feeef97eb0 Mon Sep 17 00:00:00 2001 From: giovp Date: Tue, 19 Dec 2023 14:27:08 +0100 Subject: [PATCH 30/58] update --- src/moscot/base/problems/_mixins.py | 13 ++++++------- src/moscot/base/problems/_utils.py | 2 +- src/moscot/problems/space/_mixins.py | 9 ++++++++- 3 files changed, 15 insertions(+), 9 deletions(-) diff --git a/src/moscot/base/problems/_mixins.py b/src/moscot/base/problems/_mixins.py index 1edb8ba7b..2f3752df4 100644 --- a/src/moscot/base/problems/_mixins.py +++ b/src/moscot/base/problems/_mixins.py @@ -151,7 +151,6 @@ def _cell_transition( ) if aggregation_mode == "cell" and source_groups is None and target_groups is None: raise ValueError("At least one of `source_groups` and `target_group` must be specified.") - _check_argument_compatibility_cell_transition( source_annotation=source_groups, target_annotation=target_groups, @@ -208,13 +207,13 @@ def _cell_transition_online( ) df_source = _get_df_cell_transition( self.adata, - [source_annotation_key, target_annotation_key], + [source_annotation_key], key, source, ) df_target = _get_df_cell_transition( self.adata if other_adata is None else other_adata, - [source_annotation_key, target_annotation_key], + [target_annotation_key], key if other_adata is None else other_key, target, ) @@ -323,7 +322,7 @@ def _annotation_mapping( cell_transition_kwargs.setdefault("forward", forward) if forward: source_df = _get_df_cell_transition( - self.adata, + self.adata if other_adata is None else other_adata, annotation_keys=[annotation_label], filter_key=key, filter_value=source, @@ -332,9 +331,9 @@ def _annotation_mapping( axis = 1 # columns cell_transition_kwargs.setdefault("source_groups", None) cell_transition_kwargs.setdefault("target_groups", annotation_label) - elif not forward: + else: target_df = _get_df_cell_transition( - self.adata if other_adata is None else other_adata, + self.adata, annotation_keys=[annotation_label], filter_key=key, filter_value=target, @@ -346,7 +345,7 @@ def _annotation_mapping( if mapping_mode == "sum": out: pd.DataFrame = self._cell_transition(**cell_transition_kwargs) return out.idxmax(axis=axis).to_frame(name=annotation_label) - elif mapping_mode == "max": + if mapping_mode == "max": if forward: out: ArrayLike = self[(source, target)].push(dummy, scale_by_marginals=scale_by_marginals) else: diff --git a/src/moscot/base/problems/_utils.py b/src/moscot/base/problems/_utils.py index 6c05faa65..3fae860db 100644 --- a/src/moscot/base/problems/_utils.py +++ b/src/moscot/base/problems/_utils.py @@ -107,7 +107,7 @@ def _check_argument_compatibility_cell_transition( raise ValueError("Unable to infer distributions, missing `adata` and `key`.") if forward and target_annotation is None: raise ValueError("No target annotation provided.") - if not forward and source_annotation is None: + if aggregation_mode == "annotation" and (not forward and source_annotation is None): raise ValueError("No source annotation provided.") if (aggregation_mode == "annotation") and (source_annotation is None or target_annotation is None): raise ValueError( diff --git a/src/moscot/problems/space/_mixins.py b/src/moscot/problems/space/_mixins.py index 7dfb2d4da..ab16b932b 100644 --- a/src/moscot/problems/space/_mixins.py +++ b/src/moscot/problems/space/_mixins.py @@ -605,12 +605,19 @@ def annotation_mapping( scale_by_marginals: bool = True, cell_transition_kwargs: Mapping[str, Any] = types.MappingProxyType({}), ) -> pd.DataFrame: + cell_transition_kwargs = dict(cell_transition_kwargs) + if forward: + cell_transition_kwargs.setdefault("source_groups", None) + cell_transition_kwargs.setdefault("target_groups", annotation_label) + else: + cell_transition_kwargs.setdefault("source_groups", None) + cell_transition_kwargs.setdefault("target_groups", annotation_label) annotation = self._annotation_mapping( mapping_mode=mapping_mode, annotation_label=annotation_label, source=source, target=target, - forward=forward, + forward=not forward, key=self.batch_key, other_adata=self.adata_sc, scale_by_marginals=scale_by_marginals, From 3a3776200386a152e3043f183576f265b6dd367a Mon Sep 17 00:00:00 2001 From: giovp Date: Tue, 19 Dec 2023 14:29:29 +0100 Subject: [PATCH 31/58] update --- src/moscot/problems/space/_mixins.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/moscot/problems/space/_mixins.py b/src/moscot/problems/space/_mixins.py index ab16b932b..61c321635 100644 --- a/src/moscot/problems/space/_mixins.py +++ b/src/moscot/problems/space/_mixins.py @@ -605,6 +605,13 @@ def annotation_mapping( scale_by_marginals: bool = True, cell_transition_kwargs: Mapping[str, Any] = types.MappingProxyType({}), ) -> pd.DataFrame: + """ + + Notes + ----- + If forward is True, it means that the annotation columns (annotation label) needs to be in the target adata, + If forward is False, it means that the annotation column (annotation label) needs to be in the source adata. + """ cell_transition_kwargs = dict(cell_transition_kwargs) if forward: cell_transition_kwargs.setdefault("source_groups", None) From 72e545d8de1d10d59941fc40b476027b93a27f80 Mon Sep 17 00:00:00 2001 From: giovp Date: Tue, 19 Dec 2023 17:28:25 +0100 Subject: [PATCH 32/58] update --- src/moscot/problems/space/_mixins.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/moscot/problems/space/_mixins.py b/src/moscot/problems/space/_mixins.py index 61c321635..5f9eb619e 100644 --- a/src/moscot/problems/space/_mixins.py +++ b/src/moscot/problems/space/_mixins.py @@ -614,8 +614,8 @@ def annotation_mapping( """ cell_transition_kwargs = dict(cell_transition_kwargs) if forward: - cell_transition_kwargs.setdefault("source_groups", None) - cell_transition_kwargs.setdefault("target_groups", annotation_label) + cell_transition_kwargs.setdefault("source_groups", annotation_label) + cell_transition_kwargs.setdefault("target_groups", None) else: cell_transition_kwargs.setdefault("source_groups", None) cell_transition_kwargs.setdefault("target_groups", annotation_label) From 284c83dc57ac6467fb5661f29f3ab494a5391b3c Mon Sep 17 00:00:00 2001 From: Arina Danilina Date: Wed, 20 Dec 2023 13:45:41 +0100 Subject: [PATCH 33/58] update --- src/moscot/base/problems/_mixins.py | 4 ++-- src/moscot/problems/space/_mixins.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/moscot/base/problems/_mixins.py b/src/moscot/base/problems/_mixins.py index 2f3752df4..d69b2052e 100644 --- a/src/moscot/base/problems/_mixins.py +++ b/src/moscot/base/problems/_mixins.py @@ -322,7 +322,7 @@ def _annotation_mapping( cell_transition_kwargs.setdefault("forward", forward) if forward: source_df = _get_df_cell_transition( - self.adata if other_adata is None else other_adata, + self.adata if (other_adata is None or mapping_mode == "max") else other_adata, annotation_keys=[annotation_label], filter_key=key, filter_value=source, @@ -333,7 +333,7 @@ def _annotation_mapping( cell_transition_kwargs.setdefault("target_groups", annotation_label) else: target_df = _get_df_cell_transition( - self.adata, + self.adata if (other_adata is None or mapping_mode == "sum") else other_adata, annotation_keys=[annotation_label], filter_key=key, filter_value=target, diff --git a/src/moscot/problems/space/_mixins.py b/src/moscot/problems/space/_mixins.py index 5f9eb619e..1acc37231 100644 --- a/src/moscot/problems/space/_mixins.py +++ b/src/moscot/problems/space/_mixins.py @@ -624,7 +624,7 @@ def annotation_mapping( annotation_label=annotation_label, source=source, target=target, - forward=not forward, + forward=not forward if mapping_mode == "sum" else forward, key=self.batch_key, other_adata=self.adata_sc, scale_by_marginals=scale_by_marginals, From fc7aba0d3cb7da460669a9243d7aa661768e94d0 Mon Sep 17 00:00:00 2001 From: giovp Date: Wed, 20 Dec 2023 14:43:33 +0100 Subject: [PATCH 34/58] update --- src/moscot/base/problems/_mixins.py | 58 ++++++++++++++--------------- 1 file changed, 29 insertions(+), 29 deletions(-) diff --git a/src/moscot/base/problems/_mixins.py b/src/moscot/base/problems/_mixins.py index d69b2052e..da5cd03a1 100644 --- a/src/moscot/base/problems/_mixins.py +++ b/src/moscot/base/problems/_mixins.py @@ -313,42 +313,42 @@ def _annotation_mapping( scale_by_marginals: bool = True, cell_transition_kwargs: Mapping[str, Any] = types.MappingProxyType({}), ) -> pd.DataFrame: - cell_transition_kwargs = dict(cell_transition_kwargs) - cell_transition_kwargs.setdefault("aggregation_mode", "cell") # aggregation mode should be set to cell - cell_transition_kwargs.setdefault("key", key) - cell_transition_kwargs.setdefault("source", source) - cell_transition_kwargs.setdefault("target", target) - cell_transition_kwargs.setdefault("other_adata", other_adata) - cell_transition_kwargs.setdefault("forward", forward) - if forward: - source_df = _get_df_cell_transition( - self.adata if (other_adata is None or mapping_mode == "max") else other_adata, - annotation_keys=[annotation_label], - filter_key=key, - filter_value=source, - ) - dummy = pd.get_dummies(source_df, prefix="", prefix_sep="") - axis = 1 # columns - cell_transition_kwargs.setdefault("source_groups", None) - cell_transition_kwargs.setdefault("target_groups", annotation_label) - else: - target_df = _get_df_cell_transition( - self.adata if (other_adata is None or mapping_mode == "sum") else other_adata, - annotation_keys=[annotation_label], - filter_key=key, - filter_value=target, - ) - dummy = pd.get_dummies(target_df, prefix="", prefix_sep="") - axis = 0 # rows - cell_transition_kwargs.setdefault("source_groups", annotation_label) - cell_transition_kwargs.setdefault("target_groups", None) if mapping_mode == "sum": + cell_transition_kwargs = dict(cell_transition_kwargs) + cell_transition_kwargs.setdefault("aggregation_mode", "cell") # aggregation mode should be set to cell + cell_transition_kwargs.setdefault("key", key) + cell_transition_kwargs.setdefault("source", source) + cell_transition_kwargs.setdefault("target", target) + cell_transition_kwargs.setdefault("other_adata", other_adata) + cell_transition_kwargs.setdefault("forward", forward) + if forward: + cell_transition_kwargs.setdefault("source_groups", None) + cell_transition_kwargs.setdefault("target_groups", annotation_label) + axis = 1 # columns + else: + cell_transition_kwargs.setdefault("source_groups", annotation_label) + cell_transition_kwargs.setdefault("target_groups", None) + axis = 0 # rows out: pd.DataFrame = self._cell_transition(**cell_transition_kwargs) return out.idxmax(axis=axis).to_frame(name=annotation_label) if mapping_mode == "max": if forward: + source_df = _get_df_cell_transition( + self.adata, + annotation_keys=[annotation_label], + filter_key=key, + filter_value=source, + ) + dummy = pd.get_dummies(source_df, prefix="", prefix_sep="") out: ArrayLike = self[(source, target)].push(dummy, scale_by_marginals=scale_by_marginals) else: + target_df = _get_df_cell_transition( + other_adata, + annotation_keys=[annotation_label], + filter_key=key, + filter_value=target, + ) + dummy = pd.get_dummies(target_df, prefix="", prefix_sep="") out: ArrayLike = self[(source, target)].pull(dummy, scale_by_marginals=scale_by_marginals) categories = pd.Categorical([dummy.columns[i] for i in np.array(out.argmax(1))]) return pd.DataFrame(categories, columns=[annotation_label]) From e26266d56ee82d78baa4d96bf5cacfac8e343deb Mon Sep 17 00:00:00 2001 From: Arina Danilina Date: Tue, 26 Dec 2023 18:17:27 +0100 Subject: [PATCH 35/58] shorter returns --- src/moscot/base/problems/_mixins.py | 3 +-- src/moscot/problems/cross_modality/_mixins.py | 3 +-- src/moscot/problems/space/_mixins.py | 6 ++---- src/moscot/problems/time/_mixins.py | 3 +-- 4 files changed, 5 insertions(+), 10 deletions(-) diff --git a/src/moscot/base/problems/_mixins.py b/src/moscot/base/problems/_mixins.py index da5cd03a1..3659d33d8 100644 --- a/src/moscot/base/problems/_mixins.py +++ b/src/moscot/base/problems/_mixins.py @@ -352,8 +352,7 @@ def _annotation_mapping( out: ArrayLike = self[(source, target)].pull(dummy, scale_by_marginals=scale_by_marginals) categories = pd.Categorical([dummy.columns[i] for i in np.array(out.argmax(1))]) return pd.DataFrame(categories, columns=[annotation_label]) - else: - raise NotImplementedError(f"Mapping mode `{mapping_mode!r}` is not yet implemented.") + raise NotImplementedError(f"Mapping mode `{mapping_mode!r}` is not yet implemented.") def _sample_from_tmap( self: AnalysisMixinProtocol[K, B], diff --git a/src/moscot/problems/cross_modality/_mixins.py b/src/moscot/problems/cross_modality/_mixins.py index 367c19625..c00858449 100644 --- a/src/moscot/problems/cross_modality/_mixins.py +++ b/src/moscot/problems/cross_modality/_mixins.py @@ -195,7 +195,7 @@ def annotation_mapping( other_adata: Optional[str] = None, cell_transition_kwargs: Mapping[str, Any] = types.MappingProxyType({}), ) -> pd.DataFrame: - annotation: pd.DataFrame = self._annotation_mapping( + return self._annotation_mapping( mapping_mode=mapping_mode, annotation_label=annotation_label, source=source, @@ -207,7 +207,6 @@ def annotation_mapping( cell_transition_kwargs=cell_transition_kwargs, ) - return annotation @property def batch_key(self) -> Optional[str]: diff --git a/src/moscot/problems/space/_mixins.py b/src/moscot/problems/space/_mixins.py index 1acc37231..a8cd33b65 100644 --- a/src/moscot/problems/space/_mixins.py +++ b/src/moscot/problems/space/_mixins.py @@ -294,7 +294,7 @@ def annotation_mapping( scale_by_marginals: bool = True, cell_transition_kwargs: Mapping[str, Any] = types.MappingProxyType({}), ) -> pd.DataFrame: - annotation = self._annotation_mapping( + return self._annotation_mapping( mapping_mode=mapping_mode, annotation_label=annotation_label, source=source, @@ -304,7 +304,6 @@ def annotation_mapping( scale_by_marginals=scale_by_marginals, cell_transition_kwargs=cell_transition_kwargs, ) - return annotation @property def spatial_key(self) -> Optional[str]: @@ -619,7 +618,7 @@ def annotation_mapping( else: cell_transition_kwargs.setdefault("source_groups", None) cell_transition_kwargs.setdefault("target_groups", annotation_label) - annotation = self._annotation_mapping( + return self._annotation_mapping( mapping_mode=mapping_mode, annotation_label=annotation_label, source=source, @@ -630,7 +629,6 @@ def annotation_mapping( scale_by_marginals=scale_by_marginals, cell_transition_kwargs=cell_transition_kwargs, ) - return annotation @property def batch_key(self) -> Optional[str]: diff --git a/src/moscot/problems/time/_mixins.py b/src/moscot/problems/time/_mixins.py index 0e4ea9219..ed60413a6 100644 --- a/src/moscot/problems/time/_mixins.py +++ b/src/moscot/problems/time/_mixins.py @@ -246,7 +246,7 @@ def annotation_mapping( scale_by_marginals: bool = True, cell_transition_kwargs: Mapping[str, Any] = types.MappingProxyType({}), ) -> pd.DataFrame: - annotation: pd.DataFrame = self._annotation_mapping( + return self._annotation_mapping( mapping_mode=mapping_mode, annotation_label=annotation_label, source=source, @@ -258,7 +258,6 @@ def annotation_mapping( cell_transition_kwargs=cell_transition_kwargs, ) - return annotation def sankey( self: TemporalMixinProtocol[K, B], From 96dd1d5866075149cb088fe95d8a51825a63bdf4 Mon Sep 17 00:00:00 2001 From: Arina Danilina Date: Sun, 7 Jan 2024 15:23:54 +0100 Subject: [PATCH 36/58] fix for temporal cell_transition --- src/moscot/base/problems/_mixins.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/moscot/base/problems/_mixins.py b/src/moscot/base/problems/_mixins.py index 3659d33d8..175e570ce 100644 --- a/src/moscot/base/problems/_mixins.py +++ b/src/moscot/base/problems/_mixins.py @@ -207,13 +207,13 @@ def _cell_transition_online( ) df_source = _get_df_cell_transition( self.adata, - [source_annotation_key], + [source_annotation_key] if aggregation_mode == "cell" else [source_annotation_key, target_annotation_key], key, source, ) df_target = _get_df_cell_transition( self.adata if other_adata is None else other_adata, - [target_annotation_key], + [target_annotation_key] if aggregation_mode == "cell" else [source_annotation_key, target_annotation_key], key if other_adata is None else other_key, target, ) From 039ce1ed2caac2807db83369a314894deb236386 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 7 Jan 2024 14:24:31 +0000 Subject: [PATCH 37/58] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/moscot/problems/cross_modality/_mixins.py | 1 - src/moscot/problems/time/_mixins.py | 1 - 2 files changed, 2 deletions(-) diff --git a/src/moscot/problems/cross_modality/_mixins.py b/src/moscot/problems/cross_modality/_mixins.py index c00858449..29f38f797 100644 --- a/src/moscot/problems/cross_modality/_mixins.py +++ b/src/moscot/problems/cross_modality/_mixins.py @@ -207,7 +207,6 @@ def annotation_mapping( cell_transition_kwargs=cell_transition_kwargs, ) - @property def batch_key(self) -> Optional[str]: """Batch key in :attr:`~anndata.AnnData.obs`.""" diff --git a/src/moscot/problems/time/_mixins.py b/src/moscot/problems/time/_mixins.py index ed60413a6..0c6f1de2c 100644 --- a/src/moscot/problems/time/_mixins.py +++ b/src/moscot/problems/time/_mixins.py @@ -258,7 +258,6 @@ def annotation_mapping( cell_transition_kwargs=cell_transition_kwargs, ) - def sankey( self: TemporalMixinProtocol[K, B], source: K, From 4c694ee74f2d181dfec222c6e8d55fdc2acd5800 Mon Sep 17 00:00:00 2001 From: Arina Danilina Date: Tue, 9 Jan 2024 09:14:56 +0100 Subject: [PATCH 38/58] gt for annotation tests --- tests/conftest.py | 50 +++++++++++++++++++- tests/problems/cross_modality/test_mixins.py | 12 +++++ tests/problems/space/test_mixins.py | 11 +++++ tests/problems/time/test_mixins.py | 20 ++++++++ 4 files changed, 92 insertions(+), 1 deletion(-) diff --git a/tests/conftest.py b/tests/conftest.py index 6acc3c4b4..4fd2a8c45 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,5 @@ from math import cos, sin -from typing import Optional, Tuple +from typing import Optional, Tuple, Literal import pytest @@ -207,3 +207,51 @@ def adata_translation_split(adata_translation) -> Tuple[AnnData, AnnData]: adata_src.obsm["emb_src"] = rng.normal(size=(adata_src.shape[0], 5)) adata_tgt.obsm["emb_tgt"] = rng.normal(size=(adata_tgt.shape[0], 15)) return adata_src, adata_tgt + +@pytest.fixture() +def adata_anno( + problem_kind: Literal["temporal", "cross_modality", "alignment", "mapping"], + #forward: bool + ) -> AnnData | Tuple[AnnData, AnnData]: + rng = np.random.RandomState(31) + adata_src = AnnData(X=csr_matrix(rng.normal(size=(10, 60)))) + adata_src.obs["celltype"] = _gt_source_annotation + adata_src.obs["celltype"] = adata_src.obs["celltype"].astype("category") + adata_src.uns["expected_max"] = _gt_target_max_annotation + adata_src.uns["expected_sum"] = _gt_target_sum_annotation + adata_tgt = AnnData(X=csr_matrix(rng.normal(size=(15, 60)))) + if problem_kind == "cross_modality": + adata_src.obs["batch"] = "0" + adata_tgt.obs["batch"] = "1" + adata_src.obsm["emb_src"] = rng.normal(size=(adata_src.shape[0], 5)) + adata_tgt.obsm["emb_tgt"] = rng.normal(size=(adata_tgt.shape[0], 15)) + sc.pp.pca(adata_src) + sc.pp.pca(adata_tgt) + return adata_src, adata_tgt + if problem_kind in ["alignment","mapping"]: + adata_src.obsm["spatial"] = rng.normal(size=(adata_src.n_obs, 2)) + adata_tgt.obsm["spatial"] = rng.normal(size=(adata_tgt.n_obs, 2)) + key = "day" if problem_kind == "temporal" else "batch" + adatas = [adata_src, adata_tgt] # if forward else [adata_tgt, adata_src] + adata = ad.concat(adatas, join="outer", label=key, index_unique="-", uns_merge="unique") + adata.obs[key] = (pd.to_numeric(adata.obs[key]) if key == "day" else adata.obs[key]).astype("category") + adata.layers["counts"] = adata.X.A + sc.pp.pca(adata) + return adata + +_gt_source_annotation = np.array(["A", "A", "B", "A", "B", "C", "A", "A", "A", "A"], dtype="U1") + +_gt_target_max_annotation = np.array(["A", "A", "B", "A", "B", "C", "A", "A", "A", "A", "B", "B", "B", "B", "B"]) + +_gt_target_sum_annotation = np.array(["A", "A", "B", "A", "B", "C", "A", "A", "A", "A", "A", "A", "A", "A", "A"]) + +@pytest.fixture() +def gt_tm_annotation() -> np.ndarray: + tm = np.zeros((10, 15)) + for i in range(10): + tm[i][i] = 1 + for i in range(10, 15): + tm[0][i] = 0.3 + tm[1][i] = 0.3 + tm[2][i] = 0.4 + return tm diff --git a/tests/problems/cross_modality/test_mixins.py b/tests/problems/cross_modality/test_mixins.py index 33c3cb918..d2918c601 100644 --- a/tests/problems/cross_modality/test_mixins.py +++ b/tests/problems/cross_modality/test_mixins.py @@ -106,3 +106,15 @@ def test_cell_transition_pipeline( assert result2.shape == (3, 3) with pytest.raises(AssertionError): pd.testing.assert_frame_equal(result1, result2) + + @pytest.mark.fast() + @pytest.mark.parametrize("forward", [True])#, False]) + @pytest.mark.parametrize("mapping_mode", ["max"])#, "sum"]) + @pytest.mark.parametrize("problem_kind", ["cross_modality"]) + def test_annotation_mapping( + self, adata_anno: Tuple[AnnData, AnnData], forward: bool, mapping_mode, gt_tm_annotation + ): + rng = np.random.RandomState(0) + adata_src, adata_tgt = adata_anno + tp = TranslationProblem(adata_src, adata_tgt) + tp = tp.prepare(batch_key="batch", src_attr="emb_src", tgt_attr="emb_tgt", joint_attr="X_pca") diff --git a/tests/problems/space/test_mixins.py b/tests/problems/space/test_mixins.py index e208d7763..7506d4171 100644 --- a/tests/problems/space/test_mixins.py +++ b/tests/problems/space/test_mixins.py @@ -175,3 +175,14 @@ def test_cell_transition_pipeline(self, adata_mapping: AnnData, forward: bool, n assert isinstance(result, pd.DataFrame) assert result.shape == (3, 4) + + @pytest.mark.fast() + @pytest.mark.parametrize("forward", [True])#, False]) + @pytest.mark.parametrize("mapping_mode", ["max"])#, "sum"]) + @pytest.mark.parametrize("problem_kind", ["mapping"]) + def test_annotation_mapping(self, adata_anno: AnnData, forward: bool, mapping_mode, gt_tm_annotation): + rng = np.random.RandomState(0) + adataref, adatasp = _adata_spatial_split(adata_anno) + mp = MappingProblem(adataref, adatasp) + mp = mp.prepare(batch_key="batch", sc_attr={"attr": "obsm", "key": "X_pca"}) + diff --git a/tests/problems/time/test_mixins.py b/tests/problems/time/test_mixins.py index 612211879..3c2db975c 100644 --- a/tests/problems/time/test_mixins.py +++ b/tests/problems/time/test_mixins.py @@ -50,6 +50,26 @@ def test_cell_transition_full_pipeline(self, gt_temporal_adata: AnnData, forward present_cell_type_marginal = marginal[marginal > 0] np.testing.assert_allclose(present_cell_type_marginal, 1.0) + @pytest.mark.fast() + @pytest.mark.parametrize("forward", [True])#, False]) + @pytest.mark.parametrize("mapping_mode", ["max"])#, "sum"]) + @pytest.mark.parametrize("problem_kind", ["temporal"]) + def test_annotation_mapping(self, adata_anno: AnnData, forward: bool, mapping_mode, gt_tm_annotation): + problem = TemporalProblem(adata_anno) + problem_keys = (0, 1) + problem = problem.prepare(time_key="day", joint_attr="X_pca") + assert set(problem.problems.keys()) == {problem_keys} + problem[problem_keys]._solution = MockSolverOutput(gt_tm_annotation) + result = problem.annotation_mapping( + mapping_mode=mapping_mode, + annotation_label="celltype", + forward=forward, + source=0, + target=1 + ) + expected_result = (adata_anno.uns["expected_max"] if mapping_mode == "max" else adata_anno.uns["expected_sum"]) + assert (result["celltype"] == expected_result).all() + @pytest.mark.fast() @pytest.mark.parametrize("forward", [True, False]) def test_cell_transition_different_groups(self, gt_temporal_adata: AnnData, forward: bool): From d6968f53b375d18daa2b236b606efb10b1bf3662 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 9 Jan 2024 08:15:38 +0000 Subject: [PATCH 39/58] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/conftest.py | 13 ++++++++----- tests/problems/cross_modality/test_mixins.py | 4 ++-- tests/problems/space/test_mixins.py | 5 ++--- tests/problems/time/test_mixins.py | 12 ++++-------- 4 files changed, 16 insertions(+), 18 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 4fd2a8c45..8476c35a2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,5 @@ from math import cos, sin -from typing import Optional, Tuple, Literal +from typing import Literal, Optional, Tuple import pytest @@ -208,11 +208,12 @@ def adata_translation_split(adata_translation) -> Tuple[AnnData, AnnData]: adata_tgt.obsm["emb_tgt"] = rng.normal(size=(adata_tgt.shape[0], 15)) return adata_src, adata_tgt + @pytest.fixture() def adata_anno( problem_kind: Literal["temporal", "cross_modality", "alignment", "mapping"], - #forward: bool - ) -> AnnData | Tuple[AnnData, AnnData]: + # forward: bool +) -> AnnData | Tuple[AnnData, AnnData]: rng = np.random.RandomState(31) adata_src = AnnData(X=csr_matrix(rng.normal(size=(10, 60)))) adata_src.obs["celltype"] = _gt_source_annotation @@ -228,23 +229,25 @@ def adata_anno( sc.pp.pca(adata_src) sc.pp.pca(adata_tgt) return adata_src, adata_tgt - if problem_kind in ["alignment","mapping"]: + if problem_kind in ["alignment", "mapping"]: adata_src.obsm["spatial"] = rng.normal(size=(adata_src.n_obs, 2)) adata_tgt.obsm["spatial"] = rng.normal(size=(adata_tgt.n_obs, 2)) key = "day" if problem_kind == "temporal" else "batch" - adatas = [adata_src, adata_tgt] # if forward else [adata_tgt, adata_src] + adatas = [adata_src, adata_tgt] # if forward else [adata_tgt, adata_src] adata = ad.concat(adatas, join="outer", label=key, index_unique="-", uns_merge="unique") adata.obs[key] = (pd.to_numeric(adata.obs[key]) if key == "day" else adata.obs[key]).astype("category") adata.layers["counts"] = adata.X.A sc.pp.pca(adata) return adata + _gt_source_annotation = np.array(["A", "A", "B", "A", "B", "C", "A", "A", "A", "A"], dtype="U1") _gt_target_max_annotation = np.array(["A", "A", "B", "A", "B", "C", "A", "A", "A", "A", "B", "B", "B", "B", "B"]) _gt_target_sum_annotation = np.array(["A", "A", "B", "A", "B", "C", "A", "A", "A", "A", "A", "A", "A", "A", "A"]) + @pytest.fixture() def gt_tm_annotation() -> np.ndarray: tm = np.zeros((10, 15)) diff --git a/tests/problems/cross_modality/test_mixins.py b/tests/problems/cross_modality/test_mixins.py index d2918c601..891ae41d3 100644 --- a/tests/problems/cross_modality/test_mixins.py +++ b/tests/problems/cross_modality/test_mixins.py @@ -108,8 +108,8 @@ def test_cell_transition_pipeline( pd.testing.assert_frame_equal(result1, result2) @pytest.mark.fast() - @pytest.mark.parametrize("forward", [True])#, False]) - @pytest.mark.parametrize("mapping_mode", ["max"])#, "sum"]) + @pytest.mark.parametrize("forward", [True]) # , False]) + @pytest.mark.parametrize("mapping_mode", ["max"]) # , "sum"]) @pytest.mark.parametrize("problem_kind", ["cross_modality"]) def test_annotation_mapping( self, adata_anno: Tuple[AnnData, AnnData], forward: bool, mapping_mode, gt_tm_annotation diff --git a/tests/problems/space/test_mixins.py b/tests/problems/space/test_mixins.py index 7506d4171..161ae95ac 100644 --- a/tests/problems/space/test_mixins.py +++ b/tests/problems/space/test_mixins.py @@ -177,12 +177,11 @@ def test_cell_transition_pipeline(self, adata_mapping: AnnData, forward: bool, n assert result.shape == (3, 4) @pytest.mark.fast() - @pytest.mark.parametrize("forward", [True])#, False]) - @pytest.mark.parametrize("mapping_mode", ["max"])#, "sum"]) + @pytest.mark.parametrize("forward", [True]) # , False]) + @pytest.mark.parametrize("mapping_mode", ["max"]) # , "sum"]) @pytest.mark.parametrize("problem_kind", ["mapping"]) def test_annotation_mapping(self, adata_anno: AnnData, forward: bool, mapping_mode, gt_tm_annotation): rng = np.random.RandomState(0) adataref, adatasp = _adata_spatial_split(adata_anno) mp = MappingProblem(adataref, adatasp) mp = mp.prepare(batch_key="batch", sc_attr={"attr": "obsm", "key": "X_pca"}) - diff --git a/tests/problems/time/test_mixins.py b/tests/problems/time/test_mixins.py index 3c2db975c..19f7563ed 100644 --- a/tests/problems/time/test_mixins.py +++ b/tests/problems/time/test_mixins.py @@ -51,8 +51,8 @@ def test_cell_transition_full_pipeline(self, gt_temporal_adata: AnnData, forward np.testing.assert_allclose(present_cell_type_marginal, 1.0) @pytest.mark.fast() - @pytest.mark.parametrize("forward", [True])#, False]) - @pytest.mark.parametrize("mapping_mode", ["max"])#, "sum"]) + @pytest.mark.parametrize("forward", [True]) # , False]) + @pytest.mark.parametrize("mapping_mode", ["max"]) # , "sum"]) @pytest.mark.parametrize("problem_kind", ["temporal"]) def test_annotation_mapping(self, adata_anno: AnnData, forward: bool, mapping_mode, gt_tm_annotation): problem = TemporalProblem(adata_anno) @@ -61,13 +61,9 @@ def test_annotation_mapping(self, adata_anno: AnnData, forward: bool, mapping_mo assert set(problem.problems.keys()) == {problem_keys} problem[problem_keys]._solution = MockSolverOutput(gt_tm_annotation) result = problem.annotation_mapping( - mapping_mode=mapping_mode, - annotation_label="celltype", - forward=forward, - source=0, - target=1 + mapping_mode=mapping_mode, annotation_label="celltype", forward=forward, source=0, target=1 ) - expected_result = (adata_anno.uns["expected_max"] if mapping_mode == "max" else adata_anno.uns["expected_sum"]) + expected_result = adata_anno.uns["expected_max"] if mapping_mode == "max" else adata_anno.uns["expected_sum"] assert (result["celltype"] == expected_result).all() @pytest.mark.fast() From 112af8eb3866c68cf4bd107be356e0e10aec81e7 Mon Sep 17 00:00:00 2001 From: Arina Danilina Date: Wed, 10 Jan 2024 12:33:20 +0100 Subject: [PATCH 40/58] only passing tests --- src/moscot/base/problems/_mixins.py | 2 +- src/moscot/utils/subset_policy.py | 3 ++- tests/problems/cross_modality/test_mixins.py | 15 +++++++++++++-- tests/problems/space/test_mixins.py | 18 +++++++++++++++--- tests/problems/time/test_mixins.py | 2 +- 5 files changed, 32 insertions(+), 8 deletions(-) diff --git a/src/moscot/base/problems/_mixins.py b/src/moscot/base/problems/_mixins.py index 175e570ce..b7f548840 100644 --- a/src/moscot/base/problems/_mixins.py +++ b/src/moscot/base/problems/_mixins.py @@ -343,7 +343,7 @@ def _annotation_mapping( out: ArrayLike = self[(source, target)].push(dummy, scale_by_marginals=scale_by_marginals) else: target_df = _get_df_cell_transition( - other_adata, + self.adata if other_adata is None else other_adata, annotation_keys=[annotation_label], filter_key=key, filter_value=target, diff --git a/src/moscot/utils/subset_policy.py b/src/moscot/utils/subset_policy.py index 72c5cdcc1..5917e1f15 100644 --- a/src/moscot/utils/subset_policy.py +++ b/src/moscot/utils/subset_policy.py @@ -84,7 +84,8 @@ def __init__( self._subset_key: Optional[str] = key if verify_integrity and len(self._cat) < 2: - raise ValueError(f"Policy must contain at least `2` different values, found `{len(self._cat)}`.") + raise ValueError(f"Policy must contain at least `2` different values, found `{len(self._cat)}`.\n" + "Is it possible that there is only one `batch` in `batch_key`?") @abc.abstractmethod def _create_graph(self, **kwargs: Any) -> Set[Tuple[K, K]]: diff --git a/tests/problems/cross_modality/test_mixins.py b/tests/problems/cross_modality/test_mixins.py index d2918c601..a6dd492d5 100644 --- a/tests/problems/cross_modality/test_mixins.py +++ b/tests/problems/cross_modality/test_mixins.py @@ -109,7 +109,7 @@ def test_cell_transition_pipeline( @pytest.mark.fast() @pytest.mark.parametrize("forward", [True])#, False]) - @pytest.mark.parametrize("mapping_mode", ["max"])#, "sum"]) + @pytest.mark.parametrize("mapping_mode",["max",])# "sum"]) @pytest.mark.parametrize("problem_kind", ["cross_modality"]) def test_annotation_mapping( self, adata_anno: Tuple[AnnData, AnnData], forward: bool, mapping_mode, gt_tm_annotation @@ -117,4 +117,15 @@ def test_annotation_mapping( rng = np.random.RandomState(0) adata_src, adata_tgt = adata_anno tp = TranslationProblem(adata_src, adata_tgt) - tp = tp.prepare(batch_key="batch", src_attr="emb_src", tgt_attr="emb_tgt", joint_attr="X_pca") + tp = tp.prepare(src_attr="emb_src", tgt_attr="emb_tgt") + problem_keys = ("src", "tgt") + assert set(tp.problems.keys()) == {problem_keys} + tp[problem_keys].set_solution(MockSolverOutput(gt_tm_annotation), overwrite=True) + + result = tp.annotation_mapping( + mapping_mode=mapping_mode, + annotation_label="celltype", + forward=forward, + ) + expected_result = (adata_src.uns["expected_max"] if mapping_mode == "max" else adata_src.uns["expected_sum"]) + assert (result["celltype"] == expected_result).all() diff --git a/tests/problems/space/test_mixins.py b/tests/problems/space/test_mixins.py index 7506d4171..eacdcd641 100644 --- a/tests/problems/space/test_mixins.py +++ b/tests/problems/space/test_mixins.py @@ -177,12 +177,24 @@ def test_cell_transition_pipeline(self, adata_mapping: AnnData, forward: bool, n assert result.shape == (3, 4) @pytest.mark.fast() - @pytest.mark.parametrize("forward", [True])#, False]) - @pytest.mark.parametrize("mapping_mode", ["max"])#, "sum"]) + @pytest.mark.parametrize("forward", [False,])# True]) + @pytest.mark.parametrize("mapping_mode", ["max", "sum"]) @pytest.mark.parametrize("problem_kind", ["mapping"]) def test_annotation_mapping(self, adata_anno: AnnData, forward: bool, mapping_mode, gt_tm_annotation): rng = np.random.RandomState(0) adataref, adatasp = _adata_spatial_split(adata_anno) mp = MappingProblem(adataref, adatasp) - mp = mp.prepare(batch_key="batch", sc_attr={"attr": "obsm", "key": "X_pca"}) + mp = mp.prepare(sc_attr={"attr": "obsm", "key": "X_pca"}, joint_attr={"attr": "X"}) + problem_keys = ("src", "tgt") + assert set(mp.problems.keys()) == {problem_keys} + mp[problem_keys].set_solution(MockSolverOutput(gt_tm_annotation.T)) + + result = mp.annotation_mapping( + mapping_mode=mapping_mode, + annotation_label="celltype", + source="src", + forward=forward, + ) + expected_result = (adataref.uns["expected_max"] if mapping_mode == "max" else adataref.uns["expected_sum"]) + assert (result["celltype"] == expected_result).all() diff --git a/tests/problems/time/test_mixins.py b/tests/problems/time/test_mixins.py index 3c2db975c..80b25f911 100644 --- a/tests/problems/time/test_mixins.py +++ b/tests/problems/time/test_mixins.py @@ -51,7 +51,7 @@ def test_cell_transition_full_pipeline(self, gt_temporal_adata: AnnData, forward np.testing.assert_allclose(present_cell_type_marginal, 1.0) @pytest.mark.fast() - @pytest.mark.parametrize("forward", [True])#, False]) + @pytest.mark.parametrize("forward", [True,])# False]) @pytest.mark.parametrize("mapping_mode", ["max"])#, "sum"]) @pytest.mark.parametrize("problem_kind", ["temporal"]) def test_annotation_mapping(self, adata_anno: AnnData, forward: bool, mapping_mode, gt_tm_annotation): From e0efc2ef5c376d2fec80ccac0b27510e69679e58 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 10 Jan 2024 11:39:57 +0000 Subject: [PATCH 41/58] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/moscot/utils/subset_policy.py | 6 ++++-- tests/problems/cross_modality/test_mixins.py | 11 ++++++++--- tests/problems/space/test_mixins.py | 10 +++++++--- tests/problems/time/test_mixins.py | 9 +++++++-- 4 files changed, 26 insertions(+), 10 deletions(-) diff --git a/src/moscot/utils/subset_policy.py b/src/moscot/utils/subset_policy.py index 5917e1f15..3f86851dd 100644 --- a/src/moscot/utils/subset_policy.py +++ b/src/moscot/utils/subset_policy.py @@ -84,8 +84,10 @@ def __init__( self._subset_key: Optional[str] = key if verify_integrity and len(self._cat) < 2: - raise ValueError(f"Policy must contain at least `2` different values, found `{len(self._cat)}`.\n" - "Is it possible that there is only one `batch` in `batch_key`?") + raise ValueError( + f"Policy must contain at least `2` different values, found `{len(self._cat)}`.\n" + "Is it possible that there is only one `batch` in `batch_key`?" + ) @abc.abstractmethod def _create_graph(self, **kwargs: Any) -> Set[Tuple[K, K]]: diff --git a/tests/problems/cross_modality/test_mixins.py b/tests/problems/cross_modality/test_mixins.py index a6dd492d5..97e6b08fd 100644 --- a/tests/problems/cross_modality/test_mixins.py +++ b/tests/problems/cross_modality/test_mixins.py @@ -108,8 +108,13 @@ def test_cell_transition_pipeline( pd.testing.assert_frame_equal(result1, result2) @pytest.mark.fast() - @pytest.mark.parametrize("forward", [True])#, False]) - @pytest.mark.parametrize("mapping_mode",["max",])# "sum"]) + @pytest.mark.parametrize("forward", [True]) # , False]) + @pytest.mark.parametrize( + "mapping_mode", + [ + "max", + ], + ) # "sum"]) @pytest.mark.parametrize("problem_kind", ["cross_modality"]) def test_annotation_mapping( self, adata_anno: Tuple[AnnData, AnnData], forward: bool, mapping_mode, gt_tm_annotation @@ -127,5 +132,5 @@ def test_annotation_mapping( annotation_label="celltype", forward=forward, ) - expected_result = (adata_src.uns["expected_max"] if mapping_mode == "max" else adata_src.uns["expected_sum"]) + expected_result = adata_src.uns["expected_max"] if mapping_mode == "max" else adata_src.uns["expected_sum"] assert (result["celltype"] == expected_result).all() diff --git a/tests/problems/space/test_mixins.py b/tests/problems/space/test_mixins.py index eacdcd641..dc813219c 100644 --- a/tests/problems/space/test_mixins.py +++ b/tests/problems/space/test_mixins.py @@ -177,7 +177,12 @@ def test_cell_transition_pipeline(self, adata_mapping: AnnData, forward: bool, n assert result.shape == (3, 4) @pytest.mark.fast() - @pytest.mark.parametrize("forward", [False,])# True]) + @pytest.mark.parametrize( + "forward", + [ + False, + ], + ) # True]) @pytest.mark.parametrize("mapping_mode", ["max", "sum"]) @pytest.mark.parametrize("problem_kind", ["mapping"]) def test_annotation_mapping(self, adata_anno: AnnData, forward: bool, mapping_mode, gt_tm_annotation): @@ -195,6 +200,5 @@ def test_annotation_mapping(self, adata_anno: AnnData, forward: bool, mapping_mo source="src", forward=forward, ) - expected_result = (adataref.uns["expected_max"] if mapping_mode == "max" else adataref.uns["expected_sum"]) + expected_result = adataref.uns["expected_max"] if mapping_mode == "max" else adataref.uns["expected_sum"] assert (result["celltype"] == expected_result).all() - diff --git a/tests/problems/time/test_mixins.py b/tests/problems/time/test_mixins.py index ad27dda7a..dfe2c2d62 100644 --- a/tests/problems/time/test_mixins.py +++ b/tests/problems/time/test_mixins.py @@ -51,8 +51,13 @@ def test_cell_transition_full_pipeline(self, gt_temporal_adata: AnnData, forward np.testing.assert_allclose(present_cell_type_marginal, 1.0) @pytest.mark.fast() - @pytest.mark.parametrize("forward", [True,])# False]) - @pytest.mark.parametrize("mapping_mode", ["max"])#, "sum"]) + @pytest.mark.parametrize( + "forward", + [ + True, + ], + ) # False]) + @pytest.mark.parametrize("mapping_mode", ["max"]) # , "sum"]) @pytest.mark.parametrize("problem_kind", ["temporal"]) def test_annotation_mapping(self, adata_anno: AnnData, forward: bool, mapping_mode, gt_tm_annotation): problem = TemporalProblem(adata_anno) From 01637c880a531d11e4f9bf0bbd539dc73bd76cb1 Mon Sep 17 00:00:00 2001 From: Arina Danilina Date: Wed, 10 Jan 2024 13:17:21 +0100 Subject: [PATCH 42/58] ruff type annotation --- src/moscot/base/problems/_mixins.py | 104 ++++++++++++++-------------- 1 file changed, 52 insertions(+), 52 deletions(-) diff --git a/src/moscot/base/problems/_mixins.py b/src/moscot/base/problems/_mixins.py index 40523d22e..7b08a2429 100644 --- a/src/moscot/base/problems/_mixins.py +++ b/src/moscot/base/problems/_mixins.py @@ -4,16 +4,16 @@ from typing import ( TYPE_CHECKING, Any, - Dict, + dict, Generic, Iterable, - List, + list, Literal, Mapping, Optional, Protocol, Sequence, - Tuple, + tuple, Union, ) @@ -25,7 +25,7 @@ from anndata import AnnData from moscot import _constants -from moscot._types import ArrayLike, Numeric_t, Str_Dict_t +from moscot._types import ArrayLike, Numeric_t, Str_dict_t from moscot.base.output import BaseSolverOutput from moscot.base.problems._utils import ( _check_argument_compatibility_cell_transition, @@ -49,14 +49,14 @@ class AnalysisMixinProtocol(Protocol[K, B]): adata: AnnData _policy: SubsetPolicy[K] - solutions: Dict[Tuple[K, K], BaseSolverOutput] - problems: Dict[Tuple[K, K], B] + solutions: dict[tuple[K, K], BaseSolverOutput] + problems: dict[tuple[K, K], B] def _apply( self, - data: Optional[Union[str, ArrayLike]] = None, - source: Optional[K] = None, - target: Optional[K] = None, + data: None | Union[str, ArrayLike] = None, + source: None | K = None, + target: None | K = None, forward: bool = True, return_all: bool = False, scale_by_marginals: bool = False, @@ -66,24 +66,24 @@ def _apply( def _interpolate_transport( self: AnalysisMixinProtocol[K, B], - path: Sequence[Tuple[K, K]], + path: Sequence[tuple[K, K]], scale_by_marginals: bool = True, ) -> LinearOperator: ... def _flatten( self: AnalysisMixinProtocol[K, B], - data: Dict[K, ArrayLike], + data: dict[K, ArrayLike], *, - key: Optional[str], + key: None | str, ) -> ArrayLike: ... - def push(self, *args: Any, **kwargs: Any) -> Optional[ApplyOutput_t[K]]: + def push(self, *args: Any, **kwargs: Any) -> None | ApplyOutput_t[K]: """Push distribution.""" ... - def pull(self, *args: Any, **kwargs: Any) -> Optional[ApplyOutput_t[K]]: + def pull(self, *args: Any, **kwargs: Any) -> None | ApplyOutput_t[K]: """Pull distribution.""" ... @@ -91,26 +91,26 @@ def _cell_transition( self: AnalysisMixinProtocol[K, B], source: K, target: K, - source_groups: Str_Dict_t, - target_groups: Str_Dict_t, + source_groups: Str_dict_t, + target_groups: Str_dict_t, aggregation_mode: Literal["annotation", "cell"] = "annotation", - key_added: Optional[str] = _constants.CELL_TRANSITION, + key_added: None | str = _constants.CELL_TRANSITION, **kwargs: Any, ) -> pd.DataFrame: ... def _cell_transition_online( self: AnalysisMixinProtocol[K, B], - key: Optional[str], + key: None | str, source: K, target: K, - source_groups: Str_Dict_t, - target_groups: Str_Dict_t, + source_groups: Str_dict_t, + target_groups: Str_dict_t, forward: bool = False, # return value will be row-stochastic if forward=True, else column-stochastic aggregation_mode: Literal["annotation", "cell"] = "annotation", - other_key: Optional[str] = None, - other_adata: Optional[str] = None, - batch_size: Optional[int] = None, + other_key: None | str = None, + other_adata: None | str = None, + batch_size: None | int = None, normalize: bool = True, ) -> pd.DataFrame: ... @@ -123,7 +123,7 @@ def _annotation_mapping( source: K, target: K, key: str, - other_adata: Optional[str] = None, + other_adata: None | str = None, scale_by_marginals: bool = True, cell_transition_kwargs: Mapping[str, Any] = types.MappingProxyType({}), ) -> pd.DataFrame: @@ -140,10 +140,10 @@ def _cell_transition( self: AnalysisMixinProtocol[K, B], source: K, target: K, - source_groups: Str_Dict_t, - target_groups: Str_Dict_t, + source_groups: Str_dict_t, + target_groups: Str_dict_t, aggregation_mode: Literal["annotation", "cell"] = "annotation", - key_added: Optional[str] = _constants.CELL_TRANSITION, + key_added: None | str = _constants.CELL_TRANSITION, **kwargs: Any, ) -> pd.DataFrame: if aggregation_mode == "annotation" and (source_groups is None or target_groups is None): @@ -187,16 +187,16 @@ def _cell_transition( def _cell_transition_online( self: AnalysisMixinProtocol[K, B], - key: Optional[str], + key: None | str, source: K, target: K, - source_groups: Str_Dict_t, - target_groups: Str_Dict_t, + source_groups: Str_dict_t, + target_groups: Str_dict_t, forward: bool = False, # return value will be row-stochastic if forward=True, else column-stochastic aggregation_mode: Literal["annotation", "cell"] = "annotation", - other_key: Optional[str] = None, - other_adata: Optional[str] = None, - batch_size: Optional[int] = None, + other_key: None | str = None, + other_adata: None | str = None, + batch_size: None | int = None, normalize: bool = True, **_: Any, ) -> pd.DataFrame: @@ -310,7 +310,7 @@ def _annotation_mapping( target: K, key: str | None = None, forward: bool = True, - other_adata: Optional[str] = None, + other_adata: None | str = None, scale_by_marginals: bool = True, cell_transition_kwargs: Mapping[str, Any] = types.MappingProxyType({}), ) -> pd.DataFrame: @@ -364,9 +364,9 @@ def _sample_from_tmap( target_dim: int, batch_size: int = 256, account_for_unbalancedness: bool = False, - interpolation_parameter: Optional[Numeric_t] = None, - seed: Optional[int] = None, - ) -> Tuple[List[Any], List[ArrayLike]]: + interpolation_parameter: None | Numeric_t = None, + seed: None | int = None, + ) -> tuple[list[Any], list[ArrayLike]]: rng = np.random.RandomState(seed) if account_for_unbalancedness and interpolation_parameter is None: raise ValueError("When accounting for unbalancedness, interpolation parameter must be provided.") @@ -403,7 +403,7 @@ def _sample_from_tmap( rows_sampled = rng.choice(source_dim, p=row_probability / row_probability.sum(), size=n_samples) rows, counts = np.unique(rows_sampled, return_counts=True) - all_cols_sampled: List[str] = [] + all_cols_sampled: list[str] = [] for batch in range(0, len(rows), batch_size): rows_batch = rows[batch : batch + batch_size] counts_batch = counts[batch : batch + batch_size] @@ -436,7 +436,7 @@ def _sample_from_tmap( def _interpolate_transport( self: AnalysisMixinProtocol[K, B], # TODO(@giovp): rename this to 'explicit_steps', pass to policy.plan() and reintroduce (source_key, target_key) - path: Sequence[Tuple[K, K]], + path: Sequence[tuple[K, K]], scale_by_marginals: bool = True, **_: Any, ) -> LinearOperator: @@ -447,7 +447,7 @@ def _interpolate_transport( fst, *rest = path return self.solutions[fst].chain([self.solutions[r] for r in rest], scale_by_marginals=scale_by_marginals) - def _flatten(self: AnalysisMixinProtocol[K, B], data: Dict[K, ArrayLike], *, key: Optional[str]) -> ArrayLike: + def _flatten(self: AnalysisMixinProtocol[K, B], data: dict[K, ArrayLike], *, key: None | str) -> ArrayLike: tmp = np.full(len(self.adata), np.nan) for k, v in data.items(): mask = self.adata.obs[key] == k @@ -459,8 +459,8 @@ def _annotation_aggregation_transition( source: K, target: K, annotation_key: str, - annotations_1: List[Any], - annotations_2: List[Any], + annotations_1: list[Any], + annotations_2: list[Any], df: pd.DataFrame, tm: pd.DataFrame, forward: bool, @@ -495,12 +495,12 @@ def _cell_aggregation_transition( target: str, annotation_key: str, # TODO(MUCDK): unused variables, del below - annotations_1: List[Any], - annotations_2: List[Any], + annotations_1: list[Any], + annotations_2: list[Any], df_1: pd.DataFrame, df_2: pd.DataFrame, tm: pd.DataFrame, - batch_size: Optional[int], + batch_size: None | int, forward: bool, ) -> pd.DataFrame: func = self.push if forward else self.pull @@ -532,12 +532,12 @@ def compute_feature_correlation( obs_key: str, corr_method: Literal["pearson", "spearman"] = "pearson", significance_method: Literal["fisher", "perm_test"] = "fisher", - annotation: Optional[Dict[str, Iterable[str]]] = None, - layer: Optional[str] = None, - features: Optional[Union[List[str], Literal["human", "mouse", "drosophila"]]] = None, + annotation: None | dict[str, Iterable[str]] = None, + layer: None | str = None, + features: None | Union[list[str], Literal["human", "mouse", "drosophila"]] = None, confidence_level: float = 0.95, n_perms: int = 1000, - seed: Optional[int] = None, + seed: None | int = None, **kwargs: Any, ) -> pd.DataFrame: """Compute correlation of push-forward or pull-back distribution with features. @@ -641,9 +641,9 @@ def compute_entropy( source: K, target: K, forward: bool = True, - key_added: Optional[str] = "conditional_entropy", - batch_size: Optional[int] = None, - ) -> Optional[pd.DataFrame]: + key_added: None | str = "conditional_entropy", + batch_size: None | int = None, + ) -> None | pd.DataFrame: """Compute the conditional entropy per cell. The conditional entropy reflects the uncertainty of the mapping of a single cell. From 2d9b73e4426c252d9de8561f800be8ceabe71ff0 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 10 Jan 2024 12:18:03 +0000 Subject: [PATCH 43/58] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/moscot/base/problems/_mixins.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/moscot/base/problems/_mixins.py b/src/moscot/base/problems/_mixins.py index 7b08a2429..1e9c6bf89 100644 --- a/src/moscot/base/problems/_mixins.py +++ b/src/moscot/base/problems/_mixins.py @@ -4,17 +4,16 @@ from typing import ( TYPE_CHECKING, Any, - dict, Generic, Iterable, - list, Literal, Mapping, - Optional, Protocol, Sequence, - tuple, Union, + dict, + list, + tuple, ) import numpy as np From 6a5ad0b01f878f218292c1ec2509a73b60f64beb Mon Sep 17 00:00:00 2001 From: Arina Danilina Date: Wed, 10 Jan 2024 13:29:18 +0100 Subject: [PATCH 44/58] Revert "ruff type annotation" This reverts commit 01637c880a531d11e4f9bf0bbd539dc73bd76cb1. --- src/moscot/base/problems/_mixins.py | 104 ++++++++++++++-------------- 1 file changed, 52 insertions(+), 52 deletions(-) diff --git a/src/moscot/base/problems/_mixins.py b/src/moscot/base/problems/_mixins.py index 1e9c6bf89..8bd8baf52 100644 --- a/src/moscot/base/problems/_mixins.py +++ b/src/moscot/base/problems/_mixins.py @@ -4,16 +4,16 @@ from typing import ( TYPE_CHECKING, Any, + Dict, Generic, Iterable, + List, Literal, Mapping, Protocol, Sequence, + Tuple, Union, - dict, - list, - tuple, ) import numpy as np @@ -24,7 +24,7 @@ from anndata import AnnData from moscot import _constants -from moscot._types import ArrayLike, Numeric_t, Str_dict_t +from moscot._types import ArrayLike, Numeric_t, Str_Dict_t from moscot.base.output import BaseSolverOutput from moscot.base.problems._utils import ( _check_argument_compatibility_cell_transition, @@ -48,14 +48,14 @@ class AnalysisMixinProtocol(Protocol[K, B]): adata: AnnData _policy: SubsetPolicy[K] - solutions: dict[tuple[K, K], BaseSolverOutput] - problems: dict[tuple[K, K], B] + solutions: Dict[Tuple[K, K], BaseSolverOutput] + problems: Dict[Tuple[K, K], B] def _apply( self, - data: None | Union[str, ArrayLike] = None, - source: None | K = None, - target: None | K = None, + data: Optional[Union[str, ArrayLike]] = None, + source: Optional[K] = None, + target: Optional[K] = None, forward: bool = True, return_all: bool = False, scale_by_marginals: bool = False, @@ -65,24 +65,24 @@ def _apply( def _interpolate_transport( self: AnalysisMixinProtocol[K, B], - path: Sequence[tuple[K, K]], + path: Sequence[Tuple[K, K]], scale_by_marginals: bool = True, ) -> LinearOperator: ... def _flatten( self: AnalysisMixinProtocol[K, B], - data: dict[K, ArrayLike], + data: Dict[K, ArrayLike], *, - key: None | str, + key: Optional[str], ) -> ArrayLike: ... - def push(self, *args: Any, **kwargs: Any) -> None | ApplyOutput_t[K]: + def push(self, *args: Any, **kwargs: Any) -> Optional[ApplyOutput_t[K]]: """Push distribution.""" ... - def pull(self, *args: Any, **kwargs: Any) -> None | ApplyOutput_t[K]: + def pull(self, *args: Any, **kwargs: Any) -> Optional[ApplyOutput_t[K]]: """Pull distribution.""" ... @@ -90,26 +90,26 @@ def _cell_transition( self: AnalysisMixinProtocol[K, B], source: K, target: K, - source_groups: Str_dict_t, - target_groups: Str_dict_t, + source_groups: Str_Dict_t, + target_groups: Str_Dict_t, aggregation_mode: Literal["annotation", "cell"] = "annotation", - key_added: None | str = _constants.CELL_TRANSITION, + key_added: Optional[str] = _constants.CELL_TRANSITION, **kwargs: Any, ) -> pd.DataFrame: ... def _cell_transition_online( self: AnalysisMixinProtocol[K, B], - key: None | str, + key: Optional[str], source: K, target: K, - source_groups: Str_dict_t, - target_groups: Str_dict_t, + source_groups: Str_Dict_t, + target_groups: Str_Dict_t, forward: bool = False, # return value will be row-stochastic if forward=True, else column-stochastic aggregation_mode: Literal["annotation", "cell"] = "annotation", - other_key: None | str = None, - other_adata: None | str = None, - batch_size: None | int = None, + other_key: Optional[str] = None, + other_adata: Optional[str] = None, + batch_size: Optional[int] = None, normalize: bool = True, ) -> pd.DataFrame: ... @@ -122,7 +122,7 @@ def _annotation_mapping( source: K, target: K, key: str, - other_adata: None | str = None, + other_adata: Optional[str] = None, scale_by_marginals: bool = True, cell_transition_kwargs: Mapping[str, Any] = types.MappingProxyType({}), ) -> pd.DataFrame: @@ -139,10 +139,10 @@ def _cell_transition( self: AnalysisMixinProtocol[K, B], source: K, target: K, - source_groups: Str_dict_t, - target_groups: Str_dict_t, + source_groups: Str_Dict_t, + target_groups: Str_Dict_t, aggregation_mode: Literal["annotation", "cell"] = "annotation", - key_added: None | str = _constants.CELL_TRANSITION, + key_added: Optional[str] = _constants.CELL_TRANSITION, **kwargs: Any, ) -> pd.DataFrame: if aggregation_mode == "annotation" and (source_groups is None or target_groups is None): @@ -186,16 +186,16 @@ def _cell_transition( def _cell_transition_online( self: AnalysisMixinProtocol[K, B], - key: None | str, + key: Optional[str], source: K, target: K, - source_groups: Str_dict_t, - target_groups: Str_dict_t, + source_groups: Str_Dict_t, + target_groups: Str_Dict_t, forward: bool = False, # return value will be row-stochastic if forward=True, else column-stochastic aggregation_mode: Literal["annotation", "cell"] = "annotation", - other_key: None | str = None, - other_adata: None | str = None, - batch_size: None | int = None, + other_key: Optional[str] = None, + other_adata: Optional[str] = None, + batch_size: Optional[int] = None, normalize: bool = True, **_: Any, ) -> pd.DataFrame: @@ -309,7 +309,7 @@ def _annotation_mapping( target: K, key: str | None = None, forward: bool = True, - other_adata: None | str = None, + other_adata: Optional[str] = None, scale_by_marginals: bool = True, cell_transition_kwargs: Mapping[str, Any] = types.MappingProxyType({}), ) -> pd.DataFrame: @@ -363,9 +363,9 @@ def _sample_from_tmap( target_dim: int, batch_size: int = 256, account_for_unbalancedness: bool = False, - interpolation_parameter: None | Numeric_t = None, - seed: None | int = None, - ) -> tuple[list[Any], list[ArrayLike]]: + interpolation_parameter: Optional[Numeric_t] = None, + seed: Optional[int] = None, + ) -> Tuple[List[Any], List[ArrayLike]]: rng = np.random.RandomState(seed) if account_for_unbalancedness and interpolation_parameter is None: raise ValueError("When accounting for unbalancedness, interpolation parameter must be provided.") @@ -402,7 +402,7 @@ def _sample_from_tmap( rows_sampled = rng.choice(source_dim, p=row_probability / row_probability.sum(), size=n_samples) rows, counts = np.unique(rows_sampled, return_counts=True) - all_cols_sampled: list[str] = [] + all_cols_sampled: List[str] = [] for batch in range(0, len(rows), batch_size): rows_batch = rows[batch : batch + batch_size] counts_batch = counts[batch : batch + batch_size] @@ -435,7 +435,7 @@ def _sample_from_tmap( def _interpolate_transport( self: AnalysisMixinProtocol[K, B], # TODO(@giovp): rename this to 'explicit_steps', pass to policy.plan() and reintroduce (source_key, target_key) - path: Sequence[tuple[K, K]], + path: Sequence[Tuple[K, K]], scale_by_marginals: bool = True, **_: Any, ) -> LinearOperator: @@ -446,7 +446,7 @@ def _interpolate_transport( fst, *rest = path return self.solutions[fst].chain([self.solutions[r] for r in rest], scale_by_marginals=scale_by_marginals) - def _flatten(self: AnalysisMixinProtocol[K, B], data: dict[K, ArrayLike], *, key: None | str) -> ArrayLike: + def _flatten(self: AnalysisMixinProtocol[K, B], data: Dict[K, ArrayLike], *, key: Optional[str]) -> ArrayLike: tmp = np.full(len(self.adata), np.nan) for k, v in data.items(): mask = self.adata.obs[key] == k @@ -458,8 +458,8 @@ def _annotation_aggregation_transition( source: K, target: K, annotation_key: str, - annotations_1: list[Any], - annotations_2: list[Any], + annotations_1: List[Any], + annotations_2: List[Any], df: pd.DataFrame, tm: pd.DataFrame, forward: bool, @@ -494,12 +494,12 @@ def _cell_aggregation_transition( target: str, annotation_key: str, # TODO(MUCDK): unused variables, del below - annotations_1: list[Any], - annotations_2: list[Any], + annotations_1: List[Any], + annotations_2: List[Any], df_1: pd.DataFrame, df_2: pd.DataFrame, tm: pd.DataFrame, - batch_size: None | int, + batch_size: Optional[int], forward: bool, ) -> pd.DataFrame: func = self.push if forward else self.pull @@ -531,12 +531,12 @@ def compute_feature_correlation( obs_key: str, corr_method: Literal["pearson", "spearman"] = "pearson", significance_method: Literal["fisher", "perm_test"] = "fisher", - annotation: None | dict[str, Iterable[str]] = None, - layer: None | str = None, - features: None | Union[list[str], Literal["human", "mouse", "drosophila"]] = None, + annotation: Optional[Dict[str, Iterable[str]]] = None, + layer: Optional[str] = None, + features: Optional[Union[List[str], Literal["human", "mouse", "drosophila"]]] = None, confidence_level: float = 0.95, n_perms: int = 1000, - seed: None | int = None, + seed: Optional[int] = None, **kwargs: Any, ) -> pd.DataFrame: """Compute correlation of push-forward or pull-back distribution with features. @@ -640,9 +640,9 @@ def compute_entropy( source: K, target: K, forward: bool = True, - key_added: None | str = "conditional_entropy", - batch_size: None | int = None, - ) -> None | pd.DataFrame: + key_added: Optional[str] = "conditional_entropy", + batch_size: Optional[int] = None, + ) -> Optional[pd.DataFrame]: """Compute the conditional entropy per cell. The conditional entropy reflects the uncertainty of the mapping of a single cell. From 1cb646da77030386b13149af047762db9dcd0c26 Mon Sep 17 00:00:00 2001 From: Arina Danilina Date: Wed, 10 Jan 2024 13:30:55 +0100 Subject: [PATCH 45/58] revert --- src/moscot/base/problems/_mixins.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/moscot/base/problems/_mixins.py b/src/moscot/base/problems/_mixins.py index 8bd8baf52..c322fee11 100644 --- a/src/moscot/base/problems/_mixins.py +++ b/src/moscot/base/problems/_mixins.py @@ -14,6 +14,7 @@ Sequence, Tuple, Union, + Optional ) import numpy as np From 93b2e8de159e3e80fbb6d0bd052bf6ba6eaebdaf Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 10 Jan 2024 12:31:31 +0000 Subject: [PATCH 46/58] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/moscot/base/problems/_mixins.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/moscot/base/problems/_mixins.py b/src/moscot/base/problems/_mixins.py index c322fee11..40523d22e 100644 --- a/src/moscot/base/problems/_mixins.py +++ b/src/moscot/base/problems/_mixins.py @@ -10,11 +10,11 @@ List, Literal, Mapping, + Optional, Protocol, Sequence, Tuple, Union, - Optional ) import numpy as np From 4d3c8dbef8a958b7a624065ca82634b7635095d9 Mon Sep 17 00:00:00 2001 From: Arina Danilina Date: Wed, 17 Jan 2024 17:54:47 +0100 Subject: [PATCH 47/58] fully passing tests --- src/moscot/base/problems/_mixins.py | 55 +++++++++++++++---- src/moscot/problems/cross_modality/_mixins.py | 2 +- src/moscot/problems/space/_mixins.py | 16 +----- tests/conftest.py | 46 +++++++++------- tests/problems/cross_modality/test_mixins.py | 23 ++++---- tests/problems/space/test_mixins.py | 44 +++++++++++---- tests/problems/time/test_mixins.py | 21 ++++--- 7 files changed, 125 insertions(+), 82 deletions(-) diff --git a/src/moscot/base/problems/_mixins.py b/src/moscot/base/problems/_mixins.py index 40523d22e..36863551a 100644 --- a/src/moscot/base/problems/_mixins.py +++ b/src/moscot/base/problems/_mixins.py @@ -310,8 +310,9 @@ def _annotation_mapping( target: K, key: str | None = None, forward: bool = True, - other_adata: Optional[str] = None, + other_adata: str | None = None, scale_by_marginals: bool = True, + batch_size: int | None = None, cell_transition_kwargs: Mapping[str, Any] = types.MappingProxyType({}), ) -> pd.DataFrame: if mapping_mode == "sum": @@ -321,18 +322,19 @@ def _annotation_mapping( cell_transition_kwargs.setdefault("source", source) cell_transition_kwargs.setdefault("target", target) cell_transition_kwargs.setdefault("other_adata", other_adata) - cell_transition_kwargs.setdefault("forward", forward) + cell_transition_kwargs.setdefault("forward", not forward) if forward: - cell_transition_kwargs.setdefault("source_groups", None) - cell_transition_kwargs.setdefault("target_groups", annotation_label) - axis = 1 # columns - else: cell_transition_kwargs.setdefault("source_groups", annotation_label) cell_transition_kwargs.setdefault("target_groups", None) axis = 0 # rows + else: + cell_transition_kwargs.setdefault("source_groups", None) + cell_transition_kwargs.setdefault("target_groups", annotation_label) + axis = 1 # columns out: pd.DataFrame = self._cell_transition(**cell_transition_kwargs) return out.idxmax(axis=axis).to_frame(name=annotation_label) if mapping_mode == "max": + out = [] if forward: source_df = _get_df_cell_transition( self.adata, @@ -340,8 +342,23 @@ def _annotation_mapping( filter_key=key, filter_value=source, ) - dummy = pd.get_dummies(source_df, prefix="", prefix_sep="") - out: ArrayLike = self[(source, target)].push(dummy, scale_by_marginals=scale_by_marginals) + out_len = self[(source, target)].solution.shape[1] + batch_size = batch_size if batch_size is not None else out_len + for batch in range(0, out_len, batch_size): + tm_batch = self.push( + source=source, + target=target, + data=None, + subset=(batch, batch_size), + normalize=True, + return_all=False, + scale_by_marginals=scale_by_marginals, + split_mass=True, + key_added=None, + ) + v = np.array(tm_batch.argmax(1)) + out.extend(source_df[annotation_label][v[i]] for i in range(len(v))) + else: target_df = _get_df_cell_transition( self.adata if other_adata is None else other_adata, @@ -349,9 +366,23 @@ def _annotation_mapping( filter_key=key, filter_value=target, ) - dummy = pd.get_dummies(target_df, prefix="", prefix_sep="") - out: ArrayLike = self[(source, target)].pull(dummy, scale_by_marginals=scale_by_marginals) - categories = pd.Categorical([dummy.columns[i] for i in np.array(out.argmax(1))]) + out_len = self[(source, target)].solution.shape[0] + batch_size = batch_size if batch_size is not None else out_len + for batch in range(0, out_len, batch_size): + tm_batch = self.pull( + source=source, + target=target, + data=None, + subset=(batch, batch_size), + normalize=True, + return_all=False, + scale_by_marginals=scale_by_marginals, + split_mass=True, + key_added=None, + ) + v = np.array(tm_batch.argmax(1)) + out.extend(target_df[annotation_label][v[i]] for i in range(len(v))) + categories = pd.Categorical(out) return pd.DataFrame(categories, columns=[annotation_label]) raise NotImplementedError(f"Mapping mode `{mapping_mode!r}` is not yet implemented.") @@ -507,7 +538,7 @@ def _cell_aggregation_transition( if batch_size is None: batch_size = len(df_2) for batch in range(0, len(df_2), batch_size): - result = func( # TODO(@MUCDK) check how to make compatiAnalysisMixinProtocolcelltyble with all policies + result = func( # TODO(@MUCDK) check how to make compatible with all policies source=source, target=target, data=None, diff --git a/src/moscot/problems/cross_modality/_mixins.py b/src/moscot/problems/cross_modality/_mixins.py index d7f4c6bfc..a42ecb14b 100644 --- a/src/moscot/problems/cross_modality/_mixins.py +++ b/src/moscot/problems/cross_modality/_mixins.py @@ -202,7 +202,7 @@ def annotation_mapping( target=target, key=self.batch_key, forward=forward, - other_adata=self.adata_tgt if forward else self.adata_src, + other_adata=self.adata_tgt, scale_by_marginals=scale_by_marginals, cell_transition_kwargs=cell_transition_kwargs, ) diff --git a/src/moscot/problems/space/_mixins.py b/src/moscot/problems/space/_mixins.py index a8cd33b65..3ee462985 100644 --- a/src/moscot/problems/space/_mixins.py +++ b/src/moscot/problems/space/_mixins.py @@ -604,26 +604,12 @@ def annotation_mapping( scale_by_marginals: bool = True, cell_transition_kwargs: Mapping[str, Any] = types.MappingProxyType({}), ) -> pd.DataFrame: - """ - - Notes - ----- - If forward is True, it means that the annotation columns (annotation label) needs to be in the target adata, - If forward is False, it means that the annotation column (annotation label) needs to be in the source adata. - """ - cell_transition_kwargs = dict(cell_transition_kwargs) - if forward: - cell_transition_kwargs.setdefault("source_groups", annotation_label) - cell_transition_kwargs.setdefault("target_groups", None) - else: - cell_transition_kwargs.setdefault("source_groups", None) - cell_transition_kwargs.setdefault("target_groups", annotation_label) return self._annotation_mapping( mapping_mode=mapping_mode, annotation_label=annotation_label, source=source, target=target, - forward=not forward if mapping_mode == "sum" else forward, + forward=forward, key=self.batch_key, other_adata=self.adata_sc, scale_by_marginals=scale_by_marginals, diff --git a/tests/conftest.py b/tests/conftest.py index c10d4d043..8a2a2b6d4 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,5 @@ from math import cos, sin -from typing import Literal, Optional, Tuple +from typing import Literal, Optional, Tuple, Union import pytest @@ -211,15 +211,22 @@ def adata_translation_split(adata_translation) -> Tuple[AnnData, AnnData]: @pytest.fixture() def adata_anno( problem_kind: Literal["temporal", "cross_modality", "alignment", "mapping"], - # forward: bool -) -> AnnData | Tuple[AnnData, AnnData]: +) -> Union[AnnData, Tuple[AnnData, AnnData]]: rng = np.random.RandomState(31) adata_src = AnnData(X=csr_matrix(rng.normal(size=(10, 60)))) - adata_src.obs["celltype"] = _gt_source_annotation - adata_src.obs["celltype"] = adata_src.obs["celltype"].astype("category") - adata_src.uns["expected_max"] = _gt_target_max_annotation - adata_src.uns["expected_sum"] = _gt_target_sum_annotation + rng_src = rng.choice(["A", "B", "C"], size=5).tolist() + adata_src.obs["celltype1"] = ["C", "C", "A", "B", "B"] + rng_src + adata_src.obs["celltype1"] = adata_src.obs["celltype1"].astype("category") + adata_src.uns["expected_max1"] = ["C", "C", "A", "B", "B"] + rng_src + rng_src + adata_src.uns["expected_sum1"] = ["C", "C", "B", "B", "B"] + rng_src + rng_src + adata_tgt = AnnData(X=csr_matrix(rng.normal(size=(15, 60)))) + rng_tgt = rng.choice(["A", "B", "C"], size=5).tolist() + adata_tgt.obs["celltype2"] = ["C", "C", "A", "B", "B"] + rng_tgt + rng_tgt + adata_tgt.obs["celltype2"] = adata_tgt.obs["celltype2"].astype("category") + adata_tgt.uns["expected_max2"] = ["C", "C", "A", "B", "B"] + rng_tgt + adata_tgt.uns["expected_sum2"] = ["C", "C", "B", "B", "B"] + rng_tgt + if problem_kind == "cross_modality": adata_src.obs["batch"] = "0" adata_tgt.obs["batch"] = "1" @@ -228,11 +235,18 @@ def adata_anno( sc.pp.pca(adata_src) sc.pp.pca(adata_tgt) return adata_src, adata_tgt - if problem_kind in ["alignment", "mapping"]: + if problem_kind == "mapping": + adata_src.obs["batch"] = "0" + adata_tgt.obs["batch"] = "1" + sc.pp.pca(adata_src) + sc.pp.pca(adata_tgt) + adata_tgt.obsm["spatial"] = rng.normal(size=(adata_tgt.n_obs, 2)) + return adata_src, adata_tgt + if problem_kind == "alignment": adata_src.obsm["spatial"] = rng.normal(size=(adata_src.n_obs, 2)) adata_tgt.obsm["spatial"] = rng.normal(size=(adata_tgt.n_obs, 2)) key = "day" if problem_kind == "temporal" else "batch" - adatas = [adata_src, adata_tgt] # if forward else [adata_tgt, adata_src] + adatas = [adata_src, adata_tgt] adata = ad.concat(adatas, join="outer", label=key, index_unique="-", uns_merge="unique") adata.obs[key] = (pd.to_numeric(adata.obs[key]) if key == "day" else adata.obs[key]).astype("category") adata.layers["counts"] = adata.X.A @@ -240,20 +254,14 @@ def adata_anno( return adata -_gt_source_annotation = np.array(["A", "A", "B", "A", "B", "C", "A", "A", "A", "A"], dtype="U1") - -_gt_target_max_annotation = np.array(["A", "A", "B", "A", "B", "C", "A", "A", "A", "A", "B", "B", "B", "B", "B"]) - -_gt_target_sum_annotation = np.array(["A", "A", "B", "A", "B", "C", "A", "A", "A", "A", "A", "A", "A", "A", "A"]) - - @pytest.fixture() def gt_tm_annotation() -> np.ndarray: tm = np.zeros((10, 15)) for i in range(10): tm[i][i] = 1 for i in range(10, 15): - tm[0][i] = 0.3 - tm[1][i] = 0.3 - tm[2][i] = 0.4 + tm[i-5][i] = 1 + for j in range(2,5): + for i in range(2,5): + tm[i][j] = 0.3 if i != j else 0.4 return tm diff --git a/tests/problems/cross_modality/test_mixins.py b/tests/problems/cross_modality/test_mixins.py index 97e6b08fd..7721bf877 100644 --- a/tests/problems/cross_modality/test_mixins.py +++ b/tests/problems/cross_modality/test_mixins.py @@ -108,29 +108,28 @@ def test_cell_transition_pipeline( pd.testing.assert_frame_equal(result1, result2) @pytest.mark.fast() - @pytest.mark.parametrize("forward", [True]) # , False]) - @pytest.mark.parametrize( - "mapping_mode", - [ - "max", - ], - ) # "sum"]) + @pytest.mark.parametrize("forward", [True, False]) + @pytest.mark.parametrize("mapping_mode",["max", "sum"]) @pytest.mark.parametrize("problem_kind", ["cross_modality"]) def test_annotation_mapping( self, adata_anno: Tuple[AnnData, AnnData], forward: bool, mapping_mode, gt_tm_annotation ): - rng = np.random.RandomState(0) adata_src, adata_tgt = adata_anno tp = TranslationProblem(adata_src, adata_tgt) tp = tp.prepare(src_attr="emb_src", tgt_attr="emb_tgt") problem_keys = ("src", "tgt") assert set(tp.problems.keys()) == {problem_keys} tp[problem_keys].set_solution(MockSolverOutput(gt_tm_annotation), overwrite=True) - + annotation_label = "celltype1" if forward else "celltype2" result = tp.annotation_mapping( mapping_mode=mapping_mode, - annotation_label="celltype", + annotation_label=annotation_label, forward=forward, + source="src", + target="tgt" ) - expected_result = adata_src.uns["expected_max"] if mapping_mode == "max" else adata_src.uns["expected_sum"] - assert (result["celltype"] == expected_result).all() + if forward: + expected_result = adata_src.uns["expected_max1"] if mapping_mode == "max" else adata_src.uns["expected_sum1"] + else: + expected_result = adata_tgt.uns["expected_max2"] if mapping_mode == "max" else adata_tgt.uns["expected_sum2"] + assert (result[annotation_label] == expected_result).all() diff --git a/tests/problems/space/test_mixins.py b/tests/problems/space/test_mixins.py index dc813219c..ec337f592 100644 --- a/tests/problems/space/test_mixins.py +++ b/tests/problems/space/test_mixins.py @@ -93,6 +93,29 @@ def test_cell_transition_pipeline(self, adata_space_rotate: AnnData, forward: bo assert isinstance(result, pd.DataFrame) assert result.shape == (3, 3) + @pytest.mark.fast() + @pytest.mark.parametrize("forward", [True, False]) + @pytest.mark.parametrize("mapping_mode", ["max", "sum"]) + @pytest.mark.parametrize("problem_kind", ["alignment"]) + def test_annotation_mapping(self, adata_anno: AnnData, forward: bool, mapping_mode, gt_tm_annotation): + ap = AlignmentProblem(adata=adata_anno) + ap = ap.prepare(batch_key="batch", joint_attr={"attr": "X"}) + problem_keys = ("0", "1") + assert set(ap.problems.keys()) == {problem_keys} + ap[problem_keys].set_solution(MockSolverOutput(gt_tm_annotation)) + annotation_label = "celltype1" if forward else "celltype2" + result = ap.annotation_mapping( + mapping_mode=mapping_mode, + annotation_label=annotation_label, + source="0", + target="1", + forward=forward, + ) + if forward: + expected_result = adata_anno.uns["expected_max1"] if mapping_mode == "max" else adata_anno.uns["expected_sum1"] + else: + expected_result = adata_anno.uns["expected_max2"] if mapping_mode == "max" else adata_anno.uns["expected_sum2"] + assert (result[annotation_label] == expected_result).all() class TestSpatialMappingAnalysisMixin: @pytest.mark.parametrize("sc_attr", [{"attr": "X"}, {"attr": "obsm", "key": "X_pca"}]) @@ -177,28 +200,25 @@ def test_cell_transition_pipeline(self, adata_mapping: AnnData, forward: bool, n assert result.shape == (3, 4) @pytest.mark.fast() - @pytest.mark.parametrize( - "forward", - [ - False, - ], - ) # True]) + @pytest.mark.parametrize("forward", [True, False]) @pytest.mark.parametrize("mapping_mode", ["max", "sum"]) @pytest.mark.parametrize("problem_kind", ["mapping"]) def test_annotation_mapping(self, adata_anno: AnnData, forward: bool, mapping_mode, gt_tm_annotation): - rng = np.random.RandomState(0) - adataref, adatasp = _adata_spatial_split(adata_anno) + adataref, adatasp = adata_anno mp = MappingProblem(adataref, adatasp) mp = mp.prepare(sc_attr={"attr": "obsm", "key": "X_pca"}, joint_attr={"attr": "X"}) problem_keys = ("src", "tgt") assert set(mp.problems.keys()) == {problem_keys} mp[problem_keys].set_solution(MockSolverOutput(gt_tm_annotation.T)) - + annotation_label = "celltype1" if not forward else "celltype2" result = mp.annotation_mapping( mapping_mode=mapping_mode, - annotation_label="celltype", + annotation_label=annotation_label, source="src", forward=forward, ) - expected_result = adataref.uns["expected_max"] if mapping_mode == "max" else adataref.uns["expected_sum"] - assert (result["celltype"] == expected_result).all() + if not forward: + expected_result = adataref.uns["expected_max1"] if mapping_mode == "max" else adataref.uns["expected_sum1"] + else: + expected_result = adatasp.uns["expected_max2"] if mapping_mode == "max" else adatasp.uns["expected_sum2"] + assert (result[annotation_label] == expected_result).all() diff --git a/tests/problems/time/test_mixins.py b/tests/problems/time/test_mixins.py index dfe2c2d62..6fcb58baf 100644 --- a/tests/problems/time/test_mixins.py +++ b/tests/problems/time/test_mixins.py @@ -51,13 +51,8 @@ def test_cell_transition_full_pipeline(self, gt_temporal_adata: AnnData, forward np.testing.assert_allclose(present_cell_type_marginal, 1.0) @pytest.mark.fast() - @pytest.mark.parametrize( - "forward", - [ - True, - ], - ) # False]) - @pytest.mark.parametrize("mapping_mode", ["max"]) # , "sum"]) + @pytest.mark.parametrize("forward",[True, False]) + @pytest.mark.parametrize("mapping_mode", ["max", "sum"]) @pytest.mark.parametrize("problem_kind", ["temporal"]) def test_annotation_mapping(self, adata_anno: AnnData, forward: bool, mapping_mode, gt_tm_annotation): problem = TemporalProblem(adata_anno) @@ -65,11 +60,15 @@ def test_annotation_mapping(self, adata_anno: AnnData, forward: bool, mapping_mo problem = problem.prepare(time_key="day", joint_attr="X_pca") assert set(problem.problems.keys()) == {problem_keys} problem[problem_keys]._solution = MockSolverOutput(gt_tm_annotation) + annotation_label = "celltype1" if forward else "celltype2" result = problem.annotation_mapping( - mapping_mode=mapping_mode, annotation_label="celltype", forward=forward, source=0, target=1 - ) - expected_result = adata_anno.uns["expected_max"] if mapping_mode == "max" else adata_anno.uns["expected_sum"] - assert (result["celltype"] == expected_result).all() + mapping_mode=mapping_mode, annotation_label=annotation_label, forward=forward, source=0, target=1 + ) + if forward: + expected_result = adata_anno.uns["expected_max1"] if mapping_mode == "max" else adata_anno.uns["expected_sum1"] + else: + expected_result = adata_anno.uns["expected_max2"] if mapping_mode == "max" else adata_anno.uns["expected_sum2"] + assert (result[annotation_label] == expected_result).all() @pytest.mark.fast() @pytest.mark.parametrize("forward", [True, False]) From e2b9c5bf98e710074d485b4cd2100810a07f3561 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 17 Jan 2024 16:55:21 +0000 Subject: [PATCH 48/58] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/conftest.py | 6 +++--- tests/problems/cross_modality/test_mixins.py | 16 ++++++++-------- tests/problems/space/test_mixins.py | 9 +++++++-- tests/problems/time/test_mixins.py | 12 ++++++++---- 4 files changed, 26 insertions(+), 17 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 8a2a2b6d4..f86b26d3b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -260,8 +260,8 @@ def gt_tm_annotation() -> np.ndarray: for i in range(10): tm[i][i] = 1 for i in range(10, 15): - tm[i-5][i] = 1 - for j in range(2,5): - for i in range(2,5): + tm[i - 5][i] = 1 + for j in range(2, 5): + for i in range(2, 5): tm[i][j] = 0.3 if i != j else 0.4 return tm diff --git a/tests/problems/cross_modality/test_mixins.py b/tests/problems/cross_modality/test_mixins.py index 7721bf877..079e153a4 100644 --- a/tests/problems/cross_modality/test_mixins.py +++ b/tests/problems/cross_modality/test_mixins.py @@ -109,7 +109,7 @@ def test_cell_transition_pipeline( @pytest.mark.fast() @pytest.mark.parametrize("forward", [True, False]) - @pytest.mark.parametrize("mapping_mode",["max", "sum"]) + @pytest.mark.parametrize("mapping_mode", ["max", "sum"]) @pytest.mark.parametrize("problem_kind", ["cross_modality"]) def test_annotation_mapping( self, adata_anno: Tuple[AnnData, AnnData], forward: bool, mapping_mode, gt_tm_annotation @@ -122,14 +122,14 @@ def test_annotation_mapping( tp[problem_keys].set_solution(MockSolverOutput(gt_tm_annotation), overwrite=True) annotation_label = "celltype1" if forward else "celltype2" result = tp.annotation_mapping( - mapping_mode=mapping_mode, - annotation_label=annotation_label, - forward=forward, - source="src", - target="tgt" + mapping_mode=mapping_mode, annotation_label=annotation_label, forward=forward, source="src", target="tgt" ) if forward: - expected_result = adata_src.uns["expected_max1"] if mapping_mode == "max" else adata_src.uns["expected_sum1"] + expected_result = ( + adata_src.uns["expected_max1"] if mapping_mode == "max" else adata_src.uns["expected_sum1"] + ) else: - expected_result = adata_tgt.uns["expected_max2"] if mapping_mode == "max" else adata_tgt.uns["expected_sum2"] + expected_result = ( + adata_tgt.uns["expected_max2"] if mapping_mode == "max" else adata_tgt.uns["expected_sum2"] + ) assert (result[annotation_label] == expected_result).all() diff --git a/tests/problems/space/test_mixins.py b/tests/problems/space/test_mixins.py index ec337f592..a6b70031c 100644 --- a/tests/problems/space/test_mixins.py +++ b/tests/problems/space/test_mixins.py @@ -112,11 +112,16 @@ def test_annotation_mapping(self, adata_anno: AnnData, forward: bool, mapping_mo forward=forward, ) if forward: - expected_result = adata_anno.uns["expected_max1"] if mapping_mode == "max" else adata_anno.uns["expected_sum1"] + expected_result = ( + adata_anno.uns["expected_max1"] if mapping_mode == "max" else adata_anno.uns["expected_sum1"] + ) else: - expected_result = adata_anno.uns["expected_max2"] if mapping_mode == "max" else adata_anno.uns["expected_sum2"] + expected_result = ( + adata_anno.uns["expected_max2"] if mapping_mode == "max" else adata_anno.uns["expected_sum2"] + ) assert (result[annotation_label] == expected_result).all() + class TestSpatialMappingAnalysisMixin: @pytest.mark.parametrize("sc_attr", [{"attr": "X"}, {"attr": "obsm", "key": "X_pca"}]) @pytest.mark.parametrize("var_names", ["0", [str(i) for i in range(20)]]) diff --git a/tests/problems/time/test_mixins.py b/tests/problems/time/test_mixins.py index 6fcb58baf..cb2d9ea2a 100644 --- a/tests/problems/time/test_mixins.py +++ b/tests/problems/time/test_mixins.py @@ -51,7 +51,7 @@ def test_cell_transition_full_pipeline(self, gt_temporal_adata: AnnData, forward np.testing.assert_allclose(present_cell_type_marginal, 1.0) @pytest.mark.fast() - @pytest.mark.parametrize("forward",[True, False]) + @pytest.mark.parametrize("forward", [True, False]) @pytest.mark.parametrize("mapping_mode", ["max", "sum"]) @pytest.mark.parametrize("problem_kind", ["temporal"]) def test_annotation_mapping(self, adata_anno: AnnData, forward: bool, mapping_mode, gt_tm_annotation): @@ -63,11 +63,15 @@ def test_annotation_mapping(self, adata_anno: AnnData, forward: bool, mapping_mo annotation_label = "celltype1" if forward else "celltype2" result = problem.annotation_mapping( mapping_mode=mapping_mode, annotation_label=annotation_label, forward=forward, source=0, target=1 - ) + ) if forward: - expected_result = adata_anno.uns["expected_max1"] if mapping_mode == "max" else adata_anno.uns["expected_sum1"] + expected_result = ( + adata_anno.uns["expected_max1"] if mapping_mode == "max" else adata_anno.uns["expected_sum1"] + ) else: - expected_result = adata_anno.uns["expected_max2"] if mapping_mode == "max" else adata_anno.uns["expected_sum2"] + expected_result = ( + adata_anno.uns["expected_max2"] if mapping_mode == "max" else adata_anno.uns["expected_sum2"] + ) assert (result[annotation_label] == expected_result).all() @pytest.mark.fast() From e1608b0667f9d4dc2a03fc2905ad70844d23f264 Mon Sep 17 00:00:00 2001 From: Arina Danilina Date: Thu, 18 Jan 2024 14:54:04 +0100 Subject: [PATCH 49/58] some mypy fixes --- pyproject.toml | 2 ++ src/moscot/base/problems/_mixins.py | 10 +++++----- src/moscot/problems/cross_modality/_mixins.py | 10 ++++++---- src/moscot/problems/space/_mixins.py | 14 +++++++------- src/moscot/problems/time/_mixins.py | 7 +++++++ 5 files changed, 27 insertions(+), 16 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index ef1fc1ac5..10102d26b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -122,6 +122,8 @@ ignore = [ "D107", # Missing docstring in magic method "D105", + # Use `X | Y` for type annotations + "UP007", ] line-length = 120 select = [ diff --git a/src/moscot/base/problems/_mixins.py b/src/moscot/base/problems/_mixins.py index 36863551a..43ee8f248 100644 --- a/src/moscot/base/problems/_mixins.py +++ b/src/moscot/base/problems/_mixins.py @@ -122,7 +122,7 @@ def _annotation_mapping( forward: bool, source: K, target: K, - key: str, + key: str | None = None, other_adata: Optional[str] = None, scale_by_marginals: bool = True, cell_transition_kwargs: Mapping[str, Any] = types.MappingProxyType({}), @@ -342,10 +342,10 @@ def _annotation_mapping( filter_key=key, filter_value=source, ) - out_len = self[(source, target)].solution.shape[1] + out_len = self.solutions[(source, target)].shape[1] batch_size = batch_size if batch_size is not None else out_len for batch in range(0, out_len, batch_size): - tm_batch = self.push( + tm_batch : ArrayLike = self.push( # type: ignore[attr-defined] source=source, target=target, data=None, @@ -366,10 +366,10 @@ def _annotation_mapping( filter_key=key, filter_value=target, ) - out_len = self[(source, target)].solution.shape[0] + out_len = self.solutions[(source, target)].shape[0] batch_size = batch_size if batch_size is not None else out_len for batch in range(0, out_len, batch_size): - tm_batch = self.pull( + tm_batch : ArrayLike = self.pull( # type: ignore[attr-defined] source=source, target=target, data=None, diff --git a/src/moscot/problems/cross_modality/_mixins.py b/src/moscot/problems/cross_modality/_mixins.py index a42ecb14b..46367c9e0 100644 --- a/src/moscot/problems/cross_modality/_mixins.py +++ b/src/moscot/problems/cross_modality/_mixins.py @@ -25,6 +25,8 @@ class CrossModalityTranslationMixinProtocol(AnalysisMixinProtocol[K, B]): def _cell_transition(self: AnalysisMixinProtocol[K, B], *args: Any, **kwargs: Any) -> pd.DataFrame: ... + def _annotation_mapping(self: AnalysisMixinProtocol[K, B], *args: Any, **kwargs: Any) -> pd.DataFrame: + ... class CrossModalityTranslationMixin(AnalysisMixin[K, B]): """Cross modality translation analysis mixin class.""" @@ -184,13 +186,13 @@ def cell_transition( # type: ignore[misc] key_added=key_added, ) - def annotation_mapping( + def annotation_mapping( # type: ignore[misc] self: CrossModalityTranslationMixinProtocol[K, B], mapping_mode: Literal["sum", "max"], annotation_label: str, forward: bool, - source: K = "src", - target: K = "tgt", + source: str = "src", + target: str = "tgt", scale_by_marginals: bool = True, other_adata: Optional[str] = None, cell_transition_kwargs: Mapping[str, Any] = types.MappingProxyType({}), @@ -202,7 +204,7 @@ def annotation_mapping( target=target, key=self.batch_key, forward=forward, - other_adata=self.adata_tgt, + other_adata=self.adata_tgt if other_adata is None else other_adata, scale_by_marginals=scale_by_marginals, cell_transition_kwargs=cell_transition_kwargs, ) diff --git a/src/moscot/problems/space/_mixins.py b/src/moscot/problems/space/_mixins.py index 3ee462985..cd101c19d 100644 --- a/src/moscot/problems/space/_mixins.py +++ b/src/moscot/problems/space/_mixins.py @@ -284,8 +284,8 @@ def cell_transition( # type: ignore[misc] key_added=key_added, ) - def annotation_mapping( - self: AnalysisMixinProtocol[K, B], + def annotation_mapping( # type: ignore[misc] + self: SpatialAlignmentMixinProtocol[K, B], mapping_mode: Literal["sum", "max"], annotation_label: str, forward: bool, @@ -299,7 +299,7 @@ def annotation_mapping( annotation_label=annotation_label, source=source, target=target, - key=self._batch_key, + key=self.batch_key, forward=forward, scale_by_marginals=scale_by_marginals, cell_transition_kwargs=cell_transition_kwargs, @@ -594,12 +594,12 @@ def cell_transition( # type: ignore[misc] key_added=key_added, ) - def annotation_mapping( - self: AnalysisMixinProtocol[K, B], + def annotation_mapping( # type: ignore[misc] + self: SpatialMappingMixinProtocol[K, B], mapping_mode: Literal["sum", "max"], annotation_label: str, - source: str, - target: str = "tgt", + source: K, + target: K | str = "tgt", forward: bool = False, scale_by_marginals: bool = True, cell_transition_kwargs: Mapping[str, Any] = types.MappingProxyType({}), diff --git a/src/moscot/problems/time/_mixins.py b/src/moscot/problems/time/_mixins.py index 0c6f1de2c..d43f49005 100644 --- a/src/moscot/problems/time/_mixins.py +++ b/src/moscot/problems/time/_mixins.py @@ -69,6 +69,13 @@ def _cell_transition( ) -> pd.DataFrame: ... + def _annotation_mapping( + self: AnalysisMixinProtocol[K, B], + *args: Any, + **kwargs: Any, + ) -> pd.DataFrame: + ... + def _sample_from_tmap( self: TemporalMixinProtocol[K, B], source: K, From ab89a42747f2250d55ced4533c38b184c33e0abc Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 18 Jan 2024 13:54:58 +0000 Subject: [PATCH 50/58] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/moscot/base/problems/_mixins.py | 4 ++-- src/moscot/problems/cross_modality/_mixins.py | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/moscot/base/problems/_mixins.py b/src/moscot/base/problems/_mixins.py index 43ee8f248..ae5855ff2 100644 --- a/src/moscot/base/problems/_mixins.py +++ b/src/moscot/base/problems/_mixins.py @@ -345,7 +345,7 @@ def _annotation_mapping( out_len = self.solutions[(source, target)].shape[1] batch_size = batch_size if batch_size is not None else out_len for batch in range(0, out_len, batch_size): - tm_batch : ArrayLike = self.push( # type: ignore[attr-defined] + tm_batch: ArrayLike = self.push( # type: ignore[attr-defined] source=source, target=target, data=None, @@ -369,7 +369,7 @@ def _annotation_mapping( out_len = self.solutions[(source, target)].shape[0] batch_size = batch_size if batch_size is not None else out_len for batch in range(0, out_len, batch_size): - tm_batch : ArrayLike = self.pull( # type: ignore[attr-defined] + tm_batch: ArrayLike = self.pull( # type: ignore[attr-defined] source=source, target=target, data=None, diff --git a/src/moscot/problems/cross_modality/_mixins.py b/src/moscot/problems/cross_modality/_mixins.py index 46367c9e0..3e2818248 100644 --- a/src/moscot/problems/cross_modality/_mixins.py +++ b/src/moscot/problems/cross_modality/_mixins.py @@ -28,6 +28,7 @@ def _cell_transition(self: AnalysisMixinProtocol[K, B], *args: Any, **kwargs: An def _annotation_mapping(self: AnalysisMixinProtocol[K, B], *args: Any, **kwargs: Any) -> pd.DataFrame: ... + class CrossModalityTranslationMixin(AnalysisMixin[K, B]): """Cross modality translation analysis mixin class.""" From 57c2649da3e1cd8cfd66dca8a6139a2953d9321f Mon Sep 17 00:00:00 2001 From: Arina Danilina Date: Thu, 18 Jan 2024 15:11:43 +0100 Subject: [PATCH 51/58] ruff typing --- src/moscot/base/problems/_mixins.py | 32 ++++++++++++++-------------- src/moscot/problems/space/_mixins.py | 2 +- src/moscot/problems/time/_mixins.py | 22 +++++++++---------- 3 files changed, 28 insertions(+), 28 deletions(-) diff --git a/src/moscot/base/problems/_mixins.py b/src/moscot/base/problems/_mixins.py index ae5855ff2..75570df12 100644 --- a/src/moscot/base/problems/_mixins.py +++ b/src/moscot/base/problems/_mixins.py @@ -49,8 +49,8 @@ class AnalysisMixinProtocol(Protocol[K, B]): adata: AnnData _policy: SubsetPolicy[K] - solutions: Dict[Tuple[K, K], BaseSolverOutput] - problems: Dict[Tuple[K, K], B] + solutions: dict[tuple[K, K], BaseSolverOutput] + problems: dict[tuple[K, K], B] def _apply( self, @@ -66,14 +66,14 @@ def _apply( def _interpolate_transport( self: AnalysisMixinProtocol[K, B], - path: Sequence[Tuple[K, K]], + path: Sequence[tuple[K, K]], scale_by_marginals: bool = True, ) -> LinearOperator: ... def _flatten( self: AnalysisMixinProtocol[K, B], - data: Dict[K, ArrayLike], + data: dict[K, ArrayLike], *, key: Optional[str], ) -> ArrayLike: @@ -345,7 +345,7 @@ def _annotation_mapping( out_len = self.solutions[(source, target)].shape[1] batch_size = batch_size if batch_size is not None else out_len for batch in range(0, out_len, batch_size): - tm_batch: ArrayLike = self.push( # type: ignore[attr-defined] + tm_batch: ArrayLike = self.push( # type: ignore[no-redef] source=source, target=target, data=None, @@ -369,7 +369,7 @@ def _annotation_mapping( out_len = self.solutions[(source, target)].shape[0] batch_size = batch_size if batch_size is not None else out_len for batch in range(0, out_len, batch_size): - tm_batch: ArrayLike = self.pull( # type: ignore[attr-defined] + tm_batch: ArrayLike = self.pull( # type: ignore[no-redef] source=source, target=target, data=None, @@ -397,7 +397,7 @@ def _sample_from_tmap( account_for_unbalancedness: bool = False, interpolation_parameter: Optional[Numeric_t] = None, seed: Optional[int] = None, - ) -> Tuple[List[Any], List[ArrayLike]]: + ) -> tuple[list[Any], list[ArrayLike]]: rng = np.random.RandomState(seed) if account_for_unbalancedness and interpolation_parameter is None: raise ValueError("When accounting for unbalancedness, interpolation parameter must be provided.") @@ -434,7 +434,7 @@ def _sample_from_tmap( rows_sampled = rng.choice(source_dim, p=row_probability / row_probability.sum(), size=n_samples) rows, counts = np.unique(rows_sampled, return_counts=True) - all_cols_sampled: List[str] = [] + all_cols_sampled: list[str] = [] for batch in range(0, len(rows), batch_size): rows_batch = rows[batch : batch + batch_size] counts_batch = counts[batch : batch + batch_size] @@ -467,7 +467,7 @@ def _sample_from_tmap( def _interpolate_transport( self: AnalysisMixinProtocol[K, B], # TODO(@giovp): rename this to 'explicit_steps', pass to policy.plan() and reintroduce (source_key, target_key) - path: Sequence[Tuple[K, K]], + path: Sequence[tuple[K, K]], scale_by_marginals: bool = True, **_: Any, ) -> LinearOperator: @@ -478,7 +478,7 @@ def _interpolate_transport( fst, *rest = path return self.solutions[fst].chain([self.solutions[r] for r in rest], scale_by_marginals=scale_by_marginals) - def _flatten(self: AnalysisMixinProtocol[K, B], data: Dict[K, ArrayLike], *, key: Optional[str]) -> ArrayLike: + def _flatten(self: AnalysisMixinProtocol[K, B], data: dict[K, ArrayLike], *, key: Optional[str]) -> ArrayLike: tmp = np.full(len(self.adata), np.nan) for k, v in data.items(): mask = self.adata.obs[key] == k @@ -490,8 +490,8 @@ def _annotation_aggregation_transition( source: K, target: K, annotation_key: str, - annotations_1: List[Any], - annotations_2: List[Any], + annotations_1: list[Any], + annotations_2: list[Any], df: pd.DataFrame, tm: pd.DataFrame, forward: bool, @@ -526,8 +526,8 @@ def _cell_aggregation_transition( target: str, annotation_key: str, # TODO(MUCDK): unused variables, del below - annotations_1: List[Any], - annotations_2: List[Any], + annotations_1: list[Any], + annotations_2: list[Any], df_1: pd.DataFrame, df_2: pd.DataFrame, tm: pd.DataFrame, @@ -563,9 +563,9 @@ def compute_feature_correlation( obs_key: str, corr_method: Literal["pearson", "spearman"] = "pearson", significance_method: Literal["fisher", "perm_test"] = "fisher", - annotation: Optional[Dict[str, Iterable[str]]] = None, + annotation: Optional[dict[str, Iterable[str]]] = None, layer: Optional[str] = None, - features: Optional[Union[List[str], Literal["human", "mouse", "drosophila"]]] = None, + features: Optional[Union[list[str], Literal["human", "mouse", "drosophila"]]] = None, confidence_level: float = 0.95, n_perms: int = 1000, seed: Optional[int] = None, diff --git a/src/moscot/problems/space/_mixins.py b/src/moscot/problems/space/_mixins.py index cd101c19d..f4604bb8c 100644 --- a/src/moscot/problems/space/_mixins.py +++ b/src/moscot/problems/space/_mixins.py @@ -599,7 +599,7 @@ def annotation_mapping( # type: ignore[misc] mapping_mode: Literal["sum", "max"], annotation_label: str, source: K, - target: K | str = "tgt", + target: Union[K, str] = "tgt", forward: bool = False, scale_by_marginals: bool = True, cell_transition_kwargs: Mapping[str, Any] = types.MappingProxyType({}), diff --git a/src/moscot/problems/time/_mixins.py b/src/moscot/problems/time/_mixins.py index d43f49005..f2ee7068f 100644 --- a/src/moscot/problems/time/_mixins.py +++ b/src/moscot/problems/time/_mixins.py @@ -38,7 +38,7 @@ class TemporalMixinProtocol(AnalysisMixinProtocol[K, B], Protocol[K, B]): # type: ignore[misc] adata: AnnData - problems: Dict[Tuple[K, K], BirthDeathProblem] + problems: dict[tuple[K, K], BirthDeathProblem] temporal_key: Optional[str] _temporal_key: Optional[str] @@ -87,7 +87,7 @@ def _sample_from_tmap( account_for_unbalancedness: bool = False, interpolation_parameter: Optional[float] = None, seed: Optional[int] = None, - ) -> Tuple[List[Any], List[ArrayLike]]: + ) -> tuple[list[Any], list[ArrayLike]]: ... def _compute_wasserstein_distance( @@ -123,7 +123,7 @@ def _get_data( posterior_marginals: bool = True, *, only_start: bool = False, - ) -> Union[Tuple[ArrayLike, AnnData], Tuple[ArrayLike, ArrayLike, ArrayLike, AnnData, ArrayLike]]: + ) -> Union[tuple[ArrayLike, AnnData], tuple[ArrayLike, ArrayLike, ArrayLike, AnnData, ArrayLike]]: ... def _interpolate_gex_randomly( @@ -139,7 +139,7 @@ def _interpolate_gex_randomly( def _plot_temporal( self: TemporalMixinProtocol[K, B], - data: Dict[K, ArrayLike], + data: dict[K, ArrayLike], source: K, target: K, time_points: Optional[Iterable[K]] = None, @@ -157,7 +157,7 @@ def _get_interp_param( ) -> float: ... - def __iter__(self) -> Iterator[Tuple[K, K]]: + def __iter__(self) -> Iterator[tuple[K, K]]: ... @@ -278,7 +278,7 @@ def sankey( order_annotations: Optional[Sequence[str]] = None, key_added: Optional[str] = _constants.SANKEY, **kwargs: Any, - ) -> Optional[List[pd.DataFrame]]: + ) -> Optional[list[pd.DataFrame]]: """Compute a `Sankey diagram `_ between cells across time points. .. seealso:: @@ -392,7 +392,7 @@ def push( source: K, target: K, data: Optional[Union[str, ArrayLike]] = None, - subset: Optional[Union[str, List[str], Tuple[int, int]]] = None, + subset: Optional[Union[str, list[str], tuple[int, int]]] = None, scale_by_marginals: bool = True, key_added: Optional[str] = _constants.PUSH, return_all: bool = False, @@ -459,7 +459,7 @@ def pull( source: K, target: K, data: Optional[Union[str, ArrayLike]] = None, - subset: Optional[Union[str, List[str], Tuple[int, int]]] = None, + subset: Optional[Union[str, list[str], tuple[int, int]]] = None, scale_by_marginals: bool = True, key_added: Optional[str] = _constants.PULL, return_all: bool = False, @@ -614,7 +614,7 @@ def _get_data( posterior_marginals: bool = True, *, only_start: bool = False, - ) -> Union[Tuple[ArrayLike, AnnData], Tuple[ArrayLike, ArrayLike, ArrayLike, AnnData, ArrayLike]]: + ) -> Union[tuple[ArrayLike, AnnData], tuple[ArrayLike, ArrayLike, ArrayLike, AnnData, ArrayLike]]: # TODO: use .items() for src, tgt in self.problems: tag = self.problems[src, tgt].xy.tag # type: ignore[union-attr] @@ -821,7 +821,7 @@ def compute_time_point_distances( posterior_marginals: bool = True, backend: Literal["ott"] = "ott", **kwargs: Any, - ) -> Tuple[float, float]: + ) -> tuple[float, float]: """Compute `Wasserstein distance `_ between time points. .. seealso:: @@ -904,7 +904,7 @@ def compute_batch_distances( if len(data) != len(adata): raise ValueError(f"Expected the `data` to have length `{len(adata)}`, found `{len(data)}`.") - dist: List[float] = [] + dist: list[float] = [] for batch_1, batch_2 in itertools.combinations(adata.obs[batch_key].unique(), 2): dist.append( self._compute_wasserstein_distance( From 2b8aa48221bc5cc89378570da426550011c250e1 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 18 Jan 2024 14:15:34 +0000 Subject: [PATCH 52/58] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/moscot/base/problems/_mixins.py | 3 --- src/moscot/problems/time/_mixins.py | 3 --- 2 files changed, 6 deletions(-) diff --git a/src/moscot/base/problems/_mixins.py b/src/moscot/base/problems/_mixins.py index 75570df12..d1788c028 100644 --- a/src/moscot/base/problems/_mixins.py +++ b/src/moscot/base/problems/_mixins.py @@ -4,16 +4,13 @@ from typing import ( TYPE_CHECKING, Any, - Dict, Generic, Iterable, - List, Literal, Mapping, Optional, Protocol, Sequence, - Tuple, Union, ) diff --git a/src/moscot/problems/time/_mixins.py b/src/moscot/problems/time/_mixins.py index f2ee7068f..b10417660 100644 --- a/src/moscot/problems/time/_mixins.py +++ b/src/moscot/problems/time/_mixins.py @@ -6,16 +6,13 @@ from typing import ( TYPE_CHECKING, Any, - Dict, Iterable, Iterator, - List, Literal, Mapping, Optional, Protocol, Sequence, - Tuple, Union, ) From 158679426242e6eebded9d7186de105877fb9109 Mon Sep 17 00:00:00 2001 From: Arina Danilina Date: Thu, 18 Jan 2024 19:34:36 +0100 Subject: [PATCH 53/58] docstrings --- src/moscot/base/problems/_mixins.py | 2 +- src/moscot/problems/cross_modality/_mixins.py | 31 ++++++++++ src/moscot/problems/space/_mixins.py | 58 +++++++++++++++++++ src/moscot/problems/time/_mixins.py | 29 ++++++++++ 4 files changed, 119 insertions(+), 1 deletion(-) diff --git a/src/moscot/base/problems/_mixins.py b/src/moscot/base/problems/_mixins.py index d1788c028..cd09cd9b9 100644 --- a/src/moscot/base/problems/_mixins.py +++ b/src/moscot/base/problems/_mixins.py @@ -342,7 +342,7 @@ def _annotation_mapping( out_len = self.solutions[(source, target)].shape[1] batch_size = batch_size if batch_size is not None else out_len for batch in range(0, out_len, batch_size): - tm_batch: ArrayLike = self.push( # type: ignore[no-redef] + tm_batch: ArrayLike = self.push( source=source, target=target, data=None, diff --git a/src/moscot/problems/cross_modality/_mixins.py b/src/moscot/problems/cross_modality/_mixins.py index 3e2818248..7006d580b 100644 --- a/src/moscot/problems/cross_modality/_mixins.py +++ b/src/moscot/problems/cross_modality/_mixins.py @@ -198,6 +198,37 @@ def annotation_mapping( # type: ignore[misc] other_adata: Optional[str] = None, cell_transition_kwargs: Mapping[str, Any] = types.MappingProxyType({}), ) -> pd.DataFrame: + """Transfer annotations between distributions. + + This function transfers annotation labels (e.g. cell types) between groups of cells. + + Parameters + ---------- + mapping_mode + How to decide which label to transfer. Valid options are: + + - ``'max'`` - pick the label of the annotated cell with the highest mapping weight. + - ``'sum'`` - compute :meth:`cell_transition` of annotation labels to target cells and + pick the label with the highest transition weight. + annotation_label + Key in :attr:`~anndata.AnnData.obs` where the annotation is stored. + forward + If :obj:`True`, transfer the annotations from ``source`` to ``target``. + source + Key identifying the source distribution. + target + Key identifying the target distribution. + scale_by_marginals + Whether to scale by the source :term:`marginals`. + other_adata + The second :obj:`anndata.AnnData` if present. + cell_transition_kwargs + Keyword arguments for :meth:`cell_transition`, used only if ``mapping_mode = 'sum'``. + + Returns + ------- + :obj:`pd.DataFrame` - returns the DataFrame of transferred annotations. + """ return self._annotation_mapping( mapping_mode=mapping_mode, annotation_label=annotation_label, diff --git a/src/moscot/problems/space/_mixins.py b/src/moscot/problems/space/_mixins.py index f4604bb8c..24d2bfce8 100644 --- a/src/moscot/problems/space/_mixins.py +++ b/src/moscot/problems/space/_mixins.py @@ -294,6 +294,35 @@ def annotation_mapping( # type: ignore[misc] scale_by_marginals: bool = True, cell_transition_kwargs: Mapping[str, Any] = types.MappingProxyType({}), ) -> pd.DataFrame: + """Transfer annotations between distributions. + + This function transfers annotation labels (e.g. cell types) between groups of cells. + + Parameters + ---------- + mapping_mode + How to decide which label to transfer. Valid options are: + + - ``'max'`` - pick the label of the annotated cell with the highest mapping weight. + - ``'sum'`` - compute :meth:`cell_transition` of annotation labels to target cells and + pick the label with the highest transition weight. + annotation_label + Key in :attr:`~anndata.AnnData.obs` where the annotation is stored. + forward + If :obj:`True`, transfer the annotations from ``source`` to ``target``. + source + Key identifying the source distribution. + target + Key identifying the target distribution. + scale_by_marginals + Whether to scale by the source :term:`marginals`. + cell_transition_kwargs + Keyword arguments for :meth:`cell_transition`, used only if ``mapping_mode = 'sum'``. + + Returns + ------- + :obj:`pd.DataFrame` - returns the DataFrame of transferred annotations. + """ return self._annotation_mapping( mapping_mode=mapping_mode, annotation_label=annotation_label, @@ -604,6 +633,35 @@ def annotation_mapping( # type: ignore[misc] scale_by_marginals: bool = True, cell_transition_kwargs: Mapping[str, Any] = types.MappingProxyType({}), ) -> pd.DataFrame: + """Transfer annotations between distributions. + + This function transfers annotation labels (e.g. cell types) between groups of cells. + + Parameters + ---------- + mapping_mode + How to decide which label to transfer. Valid options are: + + - ``'max'`` - pick the label of the annotated cell with the highest mapping weight. + - ``'sum'`` - compute :meth:`cell_transition` of annotation labels to target cells and + pick the label with the highest transition weight. + annotation_label + Key in :attr:`~anndata.AnnData.obs` where the annotation is stored. + forward + If :obj:`True`, transfer the annotations from ``source`` to ``target``. + source + Key identifying the source distribution. + target + Key identifying the target distribution. + scale_by_marginals + Whether to scale by the source :term:`marginals`. + cell_transition_kwargs + Keyword arguments for :meth:`cell_transition`, used only if ``mapping_mode = 'sum'``. + + Returns + ------- + :obj:`pd.DataFrame` - returns the DataFrame of transferred annotations. + """ return self._annotation_mapping( mapping_mode=mapping_mode, annotation_label=annotation_label, diff --git a/src/moscot/problems/time/_mixins.py b/src/moscot/problems/time/_mixins.py index b10417660..38a1d5bf5 100644 --- a/src/moscot/problems/time/_mixins.py +++ b/src/moscot/problems/time/_mixins.py @@ -250,6 +250,35 @@ def annotation_mapping( scale_by_marginals: bool = True, cell_transition_kwargs: Mapping[str, Any] = types.MappingProxyType({}), ) -> pd.DataFrame: + """Transfer annotations between distributions. + + This function transfers annotation labels (e.g. cell types) between groups of cells. + + Parameters + ---------- + mapping_mode + How to decide which label to transfer. Valid options are: + + - ``'max'`` - pick the label of the annotated cell with the highest mapping weight. + - ``'sum'`` - compute :meth:`cell_transition` of annotation labels to target cells and + pick the label with the highest transition weight. + annotation_label + Key in :attr:`~anndata.AnnData.obs` where the annotation is stored. + forward + If :obj:`True`, transfer the annotations from ``source`` to ``target``. + source + Key identifying the source distribution. + target + Key identifying the target distribution. + scale_by_marginals + Whether to scale by the source :term:`marginals`. + cell_transition_kwargs + Keyword arguments for :meth:`cell_transition`, used only if ``mapping_mode = 'sum'``. + + Returns + ------- + :obj:`pd.DataFrame` - returns the DataFrame of transferred annotations. + """ return self._annotation_mapping( mapping_mode=mapping_mode, annotation_label=annotation_label, From e13ef406b924a7ce4eb5f51392776da313d49cc9 Mon Sep 17 00:00:00 2001 From: Arina Danilina Date: Thu, 18 Jan 2024 19:44:18 +0100 Subject: [PATCH 54/58] lint docstrings --- src/moscot/problems/cross_modality/_mixins.py | 4 ++-- src/moscot/problems/space/_mixins.py | 4 ++-- src/moscot/problems/time/_mixins.py | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/moscot/problems/cross_modality/_mixins.py b/src/moscot/problems/cross_modality/_mixins.py index 7006d580b..b428324f6 100644 --- a/src/moscot/problems/cross_modality/_mixins.py +++ b/src/moscot/problems/cross_modality/_mixins.py @@ -221,13 +221,13 @@ def annotation_mapping( # type: ignore[misc] scale_by_marginals Whether to scale by the source :term:`marginals`. other_adata - The second :obj:`anndata.AnnData` if present. + The second :class:`anndata.AnnData` if present. cell_transition_kwargs Keyword arguments for :meth:`cell_transition`, used only if ``mapping_mode = 'sum'``. Returns ------- - :obj:`pd.DataFrame` - returns the DataFrame of transferred annotations. + :class:`~pandas.DataFrame`. - returns the DataFrame of transferred annotations. """ return self._annotation_mapping( mapping_mode=mapping_mode, diff --git a/src/moscot/problems/space/_mixins.py b/src/moscot/problems/space/_mixins.py index 24d2bfce8..651e5b733 100644 --- a/src/moscot/problems/space/_mixins.py +++ b/src/moscot/problems/space/_mixins.py @@ -321,7 +321,7 @@ def annotation_mapping( # type: ignore[misc] Returns ------- - :obj:`pd.DataFrame` - returns the DataFrame of transferred annotations. + :class:`~pandas.DataFrame` - returns the DataFrame of transferred annotations. """ return self._annotation_mapping( mapping_mode=mapping_mode, @@ -660,7 +660,7 @@ def annotation_mapping( # type: ignore[misc] Returns ------- - :obj:`pd.DataFrame` - returns the DataFrame of transferred annotations. + :class:`~pandas.DataFrame` - returns the DataFrame of transferred annotations. """ return self._annotation_mapping( mapping_mode=mapping_mode, diff --git a/src/moscot/problems/time/_mixins.py b/src/moscot/problems/time/_mixins.py index 38a1d5bf5..b45c732d6 100644 --- a/src/moscot/problems/time/_mixins.py +++ b/src/moscot/problems/time/_mixins.py @@ -277,7 +277,7 @@ def annotation_mapping( Returns ------- - :obj:`pd.DataFrame` - returns the DataFrame of transferred annotations. + :class:`~pandas.DataFrame` - returns the DataFrame of transferred annotations. """ return self._annotation_mapping( mapping_mode=mapping_mode, From a775e701ed7dd22b31968f45c55efc2dc809643c Mon Sep 17 00:00:00 2001 From: Arina Danilina Date: Fri, 19 Jan 2024 10:29:39 +0100 Subject: [PATCH 55/58] unexpose scale_by_marginals and edits --- src/moscot/problems/cross_modality/_mixins.py | 17 ++++--------- src/moscot/problems/space/_mixins.py | 24 +++++++------------ src/moscot/problems/time/_mixins.py | 12 ++++------ 3 files changed, 17 insertions(+), 36 deletions(-) diff --git a/src/moscot/problems/cross_modality/_mixins.py b/src/moscot/problems/cross_modality/_mixins.py index b428324f6..893b71abc 100644 --- a/src/moscot/problems/cross_modality/_mixins.py +++ b/src/moscot/problems/cross_modality/_mixins.py @@ -194,22 +194,20 @@ def annotation_mapping( # type: ignore[misc] forward: bool, source: str = "src", target: str = "tgt", - scale_by_marginals: bool = True, - other_adata: Optional[str] = None, cell_transition_kwargs: Mapping[str, Any] = types.MappingProxyType({}), ) -> pd.DataFrame: """Transfer annotations between distributions. - This function transfers annotation labels (e.g. cell types) between groups of cells. + This function transfers annotations (e.g. cell type labels) between distributions of cells. Parameters ---------- mapping_mode How to decide which label to transfer. Valid options are: - - ``'max'`` - pick the label of the annotated cell with the highest mapping weight. - - ``'sum'`` - compute :meth:`cell_transition` of annotation labels to target cells and - pick the label with the highest transition weight. + - ``'max'`` - pick the label of the annotated cell with the highest matching probability. + - ``'sum'`` - aggregate the annotated cells by label then + pick the label with the highest total matching probability. annotation_label Key in :attr:`~anndata.AnnData.obs` where the annotation is stored. forward @@ -218,10 +216,6 @@ def annotation_mapping( # type: ignore[misc] Key identifying the source distribution. target Key identifying the target distribution. - scale_by_marginals - Whether to scale by the source :term:`marginals`. - other_adata - The second :class:`anndata.AnnData` if present. cell_transition_kwargs Keyword arguments for :meth:`cell_transition`, used only if ``mapping_mode = 'sum'``. @@ -236,8 +230,7 @@ def annotation_mapping( # type: ignore[misc] target=target, key=self.batch_key, forward=forward, - other_adata=self.adata_tgt if other_adata is None else other_adata, - scale_by_marginals=scale_by_marginals, + other_adata=self.adata_tgt, cell_transition_kwargs=cell_transition_kwargs, ) diff --git a/src/moscot/problems/space/_mixins.py b/src/moscot/problems/space/_mixins.py index 651e5b733..5328d2c5a 100644 --- a/src/moscot/problems/space/_mixins.py +++ b/src/moscot/problems/space/_mixins.py @@ -291,21 +291,20 @@ def annotation_mapping( # type: ignore[misc] forward: bool, source: str = "src", target: str = "tgt", - scale_by_marginals: bool = True, cell_transition_kwargs: Mapping[str, Any] = types.MappingProxyType({}), ) -> pd.DataFrame: """Transfer annotations between distributions. - This function transfers annotation labels (e.g. cell types) between groups of cells. + This function transfers annotations (e.g. cell type labels) between distributions of cells. Parameters ---------- mapping_mode How to decide which label to transfer. Valid options are: - - ``'max'`` - pick the label of the annotated cell with the highest mapping weight. - - ``'sum'`` - compute :meth:`cell_transition` of annotation labels to target cells and - pick the label with the highest transition weight. + - ``'max'`` - pick the label of the annotated cell with the highest matching probability. + - ``'sum'`` - aggregate the annotated cells by label then + pick the label with the highest total matching probability. annotation_label Key in :attr:`~anndata.AnnData.obs` where the annotation is stored. forward @@ -314,8 +313,6 @@ def annotation_mapping( # type: ignore[misc] Key identifying the source distribution. target Key identifying the target distribution. - scale_by_marginals - Whether to scale by the source :term:`marginals`. cell_transition_kwargs Keyword arguments for :meth:`cell_transition`, used only if ``mapping_mode = 'sum'``. @@ -330,7 +327,6 @@ def annotation_mapping( # type: ignore[misc] target=target, key=self.batch_key, forward=forward, - scale_by_marginals=scale_by_marginals, cell_transition_kwargs=cell_transition_kwargs, ) @@ -630,21 +626,20 @@ def annotation_mapping( # type: ignore[misc] source: K, target: Union[K, str] = "tgt", forward: bool = False, - scale_by_marginals: bool = True, cell_transition_kwargs: Mapping[str, Any] = types.MappingProxyType({}), ) -> pd.DataFrame: """Transfer annotations between distributions. - This function transfers annotation labels (e.g. cell types) between groups of cells. + This function transfers annotations (e.g. cell type labels) between distributions of cells. Parameters ---------- mapping_mode How to decide which label to transfer. Valid options are: - - ``'max'`` - pick the label of the annotated cell with the highest mapping weight. - - ``'sum'`` - compute :meth:`cell_transition` of annotation labels to target cells and - pick the label with the highest transition weight. + - ``'max'`` - pick the label of the annotated cell with the highest matching probability. + - ``'sum'`` - aggregate the annotated cells by label then + pick the label with the highest total matching probability. annotation_label Key in :attr:`~anndata.AnnData.obs` where the annotation is stored. forward @@ -653,8 +648,6 @@ def annotation_mapping( # type: ignore[misc] Key identifying the source distribution. target Key identifying the target distribution. - scale_by_marginals - Whether to scale by the source :term:`marginals`. cell_transition_kwargs Keyword arguments for :meth:`cell_transition`, used only if ``mapping_mode = 'sum'``. @@ -670,7 +663,6 @@ def annotation_mapping( # type: ignore[misc] forward=forward, key=self.batch_key, other_adata=self.adata_sc, - scale_by_marginals=scale_by_marginals, cell_transition_kwargs=cell_transition_kwargs, ) diff --git a/src/moscot/problems/time/_mixins.py b/src/moscot/problems/time/_mixins.py index b45c732d6..5752a87da 100644 --- a/src/moscot/problems/time/_mixins.py +++ b/src/moscot/problems/time/_mixins.py @@ -247,21 +247,20 @@ def annotation_mapping( forward: bool, source: K, target: K, - scale_by_marginals: bool = True, cell_transition_kwargs: Mapping[str, Any] = types.MappingProxyType({}), ) -> pd.DataFrame: """Transfer annotations between distributions. - This function transfers annotation labels (e.g. cell types) between groups of cells. + This function transfers annotations (e.g. cell type labels) between distributions of cells. Parameters ---------- mapping_mode How to decide which label to transfer. Valid options are: - - ``'max'`` - pick the label of the annotated cell with the highest mapping weight. - - ``'sum'`` - compute :meth:`cell_transition` of annotation labels to target cells and - pick the label with the highest transition weight. + - ``'max'`` - pick the label of the annotated cell with the highest matching probability. + - ``'sum'`` - aggregate the annotated cells by label then + pick the label with the highest total matching probability. annotation_label Key in :attr:`~anndata.AnnData.obs` where the annotation is stored. forward @@ -270,8 +269,6 @@ def annotation_mapping( Key identifying the source distribution. target Key identifying the target distribution. - scale_by_marginals - Whether to scale by the source :term:`marginals`. cell_transition_kwargs Keyword arguments for :meth:`cell_transition`, used only if ``mapping_mode = 'sum'``. @@ -287,7 +284,6 @@ def annotation_mapping( key=self._temporal_key, forward=forward, other_adata=None, - scale_by_marginals=scale_by_marginals, cell_transition_kwargs=cell_transition_kwargs, ) From 93b561fd6b5107ce27d6b656e0cd58c986c0f166 Mon Sep 17 00:00:00 2001 From: ArinaDanilina <98481272+ArinaDanilina@users.noreply.github.com> Date: Fri, 19 Jan 2024 10:49:43 +0100 Subject: [PATCH 56/58] Update src/moscot/problems/cross_modality/_mixins.py Co-authored-by: Giovanni Palla <25887487+giovp@users.noreply.github.com> --- src/moscot/problems/cross_modality/_mixins.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/moscot/problems/cross_modality/_mixins.py b/src/moscot/problems/cross_modality/_mixins.py index 893b71abc..18e896e5f 100644 --- a/src/moscot/problems/cross_modality/_mixins.py +++ b/src/moscot/problems/cross_modality/_mixins.py @@ -208,6 +208,7 @@ def annotation_mapping( # type: ignore[misc] - ``'max'`` - pick the label of the annotated cell with the highest matching probability. - ``'sum'`` - aggregate the annotated cells by label then pick the label with the highest total matching probability. + annotation_label Key in :attr:`~anndata.AnnData.obs` where the annotation is stored. forward From 8fb37769cc7b082f9b02da0a0d874516b98fb250 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 19 Jan 2024 09:50:16 +0000 Subject: [PATCH 57/58] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/moscot/problems/cross_modality/_mixins.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/moscot/problems/cross_modality/_mixins.py b/src/moscot/problems/cross_modality/_mixins.py index 18e896e5f..cc4079ab8 100644 --- a/src/moscot/problems/cross_modality/_mixins.py +++ b/src/moscot/problems/cross_modality/_mixins.py @@ -208,7 +208,7 @@ def annotation_mapping( # type: ignore[misc] - ``'max'`` - pick the label of the annotated cell with the highest matching probability. - ``'sum'`` - aggregate the annotated cells by label then pick the label with the highest total matching probability. - + annotation_label Key in :attr:`~anndata.AnnData.obs` where the annotation is stored. forward From 557139d5c3b99239e6c0b5185f17df841d9774c4 Mon Sep 17 00:00:00 2001 From: Arina Danilina Date: Fri, 19 Jan 2024 10:59:30 +0100 Subject: [PATCH 58/58] returns -> Returns --- src/moscot/problems/cross_modality/_mixins.py | 3 +-- src/moscot/problems/space/_mixins.py | 4 ++-- src/moscot/problems/time/_mixins.py | 2 +- 3 files changed, 4 insertions(+), 5 deletions(-) diff --git a/src/moscot/problems/cross_modality/_mixins.py b/src/moscot/problems/cross_modality/_mixins.py index 18e896e5f..ce58f84a4 100644 --- a/src/moscot/problems/cross_modality/_mixins.py +++ b/src/moscot/problems/cross_modality/_mixins.py @@ -208,7 +208,6 @@ def annotation_mapping( # type: ignore[misc] - ``'max'`` - pick the label of the annotated cell with the highest matching probability. - ``'sum'`` - aggregate the annotated cells by label then pick the label with the highest total matching probability. - annotation_label Key in :attr:`~anndata.AnnData.obs` where the annotation is stored. forward @@ -222,7 +221,7 @@ def annotation_mapping( # type: ignore[misc] Returns ------- - :class:`~pandas.DataFrame`. - returns the DataFrame of transferred annotations. + :class:`~pandas.DataFrame` - Returns the DataFrame of transferred annotations. """ return self._annotation_mapping( mapping_mode=mapping_mode, diff --git a/src/moscot/problems/space/_mixins.py b/src/moscot/problems/space/_mixins.py index 5328d2c5a..0aa84c326 100644 --- a/src/moscot/problems/space/_mixins.py +++ b/src/moscot/problems/space/_mixins.py @@ -318,7 +318,7 @@ def annotation_mapping( # type: ignore[misc] Returns ------- - :class:`~pandas.DataFrame` - returns the DataFrame of transferred annotations. + :class:`~pandas.DataFrame` - Returns the DataFrame of transferred annotations. """ return self._annotation_mapping( mapping_mode=mapping_mode, @@ -653,7 +653,7 @@ def annotation_mapping( # type: ignore[misc] Returns ------- - :class:`~pandas.DataFrame` - returns the DataFrame of transferred annotations. + :class:`~pandas.DataFrame` - Returns the DataFrame of transferred annotations. """ return self._annotation_mapping( mapping_mode=mapping_mode, diff --git a/src/moscot/problems/time/_mixins.py b/src/moscot/problems/time/_mixins.py index 5752a87da..597757901 100644 --- a/src/moscot/problems/time/_mixins.py +++ b/src/moscot/problems/time/_mixins.py @@ -274,7 +274,7 @@ def annotation_mapping( Returns ------- - :class:`~pandas.DataFrame` - returns the DataFrame of transferred annotations. + :class:`~pandas.DataFrame` - Returns the DataFrame of transferred annotations. """ return self._annotation_mapping( mapping_mode=mapping_mode,