Skip to content

Commit

Permalink
Cost Tests (#637)
Browse files Browse the repository at this point in the history
* added tests for cost utils

* add barcode distance tests

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix naming

* refactor for precommit

* added tests for leaf_distance

* fix docstring formatting

* fixed the tests

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Dominik Klein <[email protected]>
Co-authored-by: Dominik Klein <[email protected]>
  • Loading branch information
4 people authored Jan 9, 2024
1 parent b2bcf4b commit 924c78a
Show file tree
Hide file tree
Showing 4 changed files with 190 additions and 8 deletions.
7 changes: 5 additions & 2 deletions src/moscot/costs/_costs.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,9 @@ def barcodes(self) -> ArrayLike:
class LeafDistance(BaseCost):
"""`Shortest path <https://en.wikipedia.org/wiki/Shortest_path_problem>`_ distance on a weighted tree.
.. note::
This class ignores `attr` which is always set to `uns`.
.. seealso::
- See :doc:`../notebooks/examples/problems/600_leaf_distance` on how to use this cost
in the :class:`~moscot.problems.time.LineageProblem`.
Expand All @@ -83,7 +86,7 @@ class LeafDistance(BaseCost):
def __init__(
self, adata: AnnData, weight: Union[str, Callable[[Any, Any, Dict[Any, Any]], float]] = "weight", **kwargs: Any
):
kwargs["attr"] = "uns"
kwargs["attr"] = "uns" # TODO: maybe document that attr is ignored
super().__init__(adata, **kwargs)
self._weight = weight

Expand Down Expand Up @@ -132,7 +135,7 @@ def _scaled_hamming_dist(x: ArrayLike, y: ArrayLike) -> float:

# there may not be any sites where both were measured
if not len(b1):
return np.nan
raise ValueError("No shared indices.")
b2 = y[shared_indices]

differences = b1 != b2
Expand Down
19 changes: 15 additions & 4 deletions src/moscot/costs/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,7 @@ def get_available_costs(backend: Optional[str] = None) -> Dict[str, Tuple[str, .
-------
Dictionary with keys as backend names and values as registered cost functions.
"""
groups: Dict[str, List[str]] = collections.defaultdict(list)
for key in _REGISTRY:
back, *name = key.split(_SEP)
groups[back].append(_SEP.join(name))
groups: Dict[str, List[str]] = _get_available_backends_and_costs()

if backend is None:
return {k: tuple(v) for k, v in groups.items()}
Expand All @@ -46,3 +43,17 @@ def get_available_costs(backend: Optional[str] = None) -> Dict[str, Tuple[str, .
def register_cost(name: str, *, backend: str) -> Any:
"""Register cost function for a specific backend."""
return _REGISTRY.register(f"{backend}{_SEP}{name}")


def _get_available_backends_and_costs():
"""Return a dictionary of available backends with their corresponding list of costs.
Returns
-------
Default dictionary with keys as backend names and values as registered cost functions.
"""
groups: Dict[str, List[str]] = collections.defaultdict(list)
for key in _REGISTRY:
back, *name = key.split(_SEP)
groups[back].append(_SEP.join(name))
return groups
120 changes: 118 additions & 2 deletions tests/costs/test_cost.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,118 @@
class TestCost: # TODO
pass
import pytest

import numpy as np
import pandas as pd

import anndata as ad

from moscot.costs._costs import BarcodeDistance, _scaled_hamming_dist
from moscot.costs._utils import get_cost


class TestBarcodeDistance:
RNG = np.random.RandomState(0)

@staticmethod
def test_barcode_distance_init():
adata = ad.AnnData(TestBarcodeDistance.RNG.rand(3, 3), obsm={"barcodes": TestBarcodeDistance.RNG.rand(3, 3)})
# initialization failure when no adata is provided
with pytest.raises(TypeError):
get_cost("barcode_distance", backend="moscot")
# initialization failure when invalid key is provided
with pytest.raises(KeyError):
get_cost("barcode_distance", backend="moscot", adata=adata, key="invalid_key", attr="obsm")
# initialization failure when invalid attr
with pytest.raises(AttributeError):
get_cost("barcode_distance", backend="moscot", adata=adata, key="barcodes", attr="invalid_attr")
# check if not None
cost_fn: BarcodeDistance = get_cost(
"barcode_distance", backend="moscot", adata=adata, key="barcodes", attr="obsm"
)
assert cost_fn is not None

@staticmethod
def test_scaled_hamming_dist_with_sample_inputs():
# Sample input arrays
x = np.array([1, -1, 0, 1])
y = np.array([0, 1, 1, 1])

# Expected output
expected_distance = 2.0 / 3

# Compute the scaled Hamming distance
computed_distance = _scaled_hamming_dist(x, y)

# Check if the computed distance matches the expected distance
np.testing.assert_almost_equal(computed_distance, expected_distance, decimal=4)

@staticmethod
def test_scaled_hamming_dist_if_nan():
# Sample input arrays with no shared indices
x = np.array([-1, -1, 0, 1])
y = np.array([0, 1, -1, -1])

with pytest.raises(ValueError, match="No shared indices."):
_scaled_hamming_dist(x, y)

@staticmethod
def test_barcode_distance_with_sample_input():
# Example barcodes
barcodes = np.array([[1, 0, 1], [1, 1, 0], [0, 1, 1]])

# Create a dummy AnnData object with the example barcodes
adata = ad.AnnData(TestBarcodeDistance.RNG.rand(3, 3))
adata.obsm["barcodes"] = barcodes

# Initialize BarcodeDistance
cost_fn: BarcodeDistance = get_cost(
"barcode_distance", backend="moscot", adata=adata, key="barcodes", attr="obsm"
)

# Compute distances
computed_distances = cost_fn()

# Expected distance matrix
expected_distances = np.array([[0.0, 2.0, 2.0], [2.0, 0.0, 2.0], [2.0, 2.0, 0.0]]) / 3.0

# Check if the computed distances match the expected distances
np.testing.assert_almost_equal(computed_distances, expected_distances, decimal=4)


class TestLeafDistance:
@staticmethod
def create_dummy_adata_leaf():
import networkx as nx

adata: ad.AnnData = ad.AnnData(
X=np.ones((10, 10)),
obs=pd.DataFrame(data={"day": [0, 0, 0, 1, 1, 1, 2, 2, 2, 2]}),
)
g: nx.DiGraph = nx.DiGraph()
g.add_nodes_from([str(i) for i in range(3)] + ["root"])
g.add_edges_from([("root", str(i)) for i in range(3)])
adata.uns["tree"] = {0: g}
return adata

@staticmethod
def test_leaf_distance_init():
adata = TestLeafDistance.create_dummy_adata_leaf()
# initialization failure when no adata is provided
with pytest.raises(TypeError):
get_cost("leaf_distance", backend="moscot")
# initialization failure when invalid key is provided
with pytest.raises(KeyError, match="Unable to find tree in"):
get_cost("leaf_distance", backend="moscot", adata=adata, key="invalid_key", attr="uns", dist_key=0)
# initialization failure when invalid dist_key is provided
with pytest.raises(KeyError, match="Unable to find tree in"):
get_cost("leaf_distance", backend="moscot", adata=adata, key="tree", attr="uns")
# when leaves do not match adata.obs_names
with pytest.raises(ValueError, match="Leaves do not match"):
get_cost("leaf_distance", backend="moscot", adata=adata, key="tree", attr="uns", dist_key=0)()
# now giving valid input
adata0 = adata[adata.obs.day == 0]
cost_fn = get_cost("leaf_distance", backend="moscot", adata=adata0, key="tree", attr="uns", dist_key=0)
np.testing.assert_equal(cost_fn(), np.array([[0, 2, 2], [2, 0, 2], [2, 2, 0]]))
# when tree is not a networkx.Graph
adata0.uns["tree"] = {0: 1}
with pytest.raises(TypeError, match="networkx.Graph"):
get_cost("leaf_distance", backend="moscot", adata=adata0, key="tree", attr="uns", dist_key=0)
52 changes: 52 additions & 0 deletions tests/costs/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import pytest

from moscot.costs._utils import (
_get_available_backends_and_costs,
get_available_costs,
get_cost,
)


class TestCostUtils:
ALL_BACKENDS_N_COSTS = {
"moscot": ("barcode_distance", "leaf_distance"),
"ott": (
"euclidean",
"sq_euclidean",
"cosine",
"pnorm_p",
"sq_pnorm",
"elastic_l1",
"elastic_l2",
"elastic_stvs",
"elastic_sqk_overlap",
),
}

@staticmethod
def test_get_available_backends_n_costs():
assert dict(_get_available_backends_and_costs()) == {
k: list(v) for k, v in _get_available_backends_and_costs().items()
}

@staticmethod
def test_get_available_costs():
assert get_available_costs() == TestCostUtils.ALL_BACKENDS_N_COSTS
assert get_available_costs("moscot") == {"moscot": (TestCostUtils.ALL_BACKENDS_N_COSTS["moscot"])}
assert get_available_costs("ott") == {"ott": TestCostUtils.ALL_BACKENDS_N_COSTS["ott"]}
with pytest.raises(KeyError):
get_available_costs("foo")

@staticmethod
def test_get_cost_fails():
invalid_cost = "foo"
invalid_backend = "bar"
with pytest.raises(
ValueError, match=f"Cost `{invalid_cost!r}` is not available for backend `{invalid_backend!r}`."
):
get_cost(invalid_cost, backend=invalid_backend)
for backend in TestCostUtils.ALL_BACKENDS_N_COSTS:
with pytest.raises(
ValueError, match=f"Cost `{invalid_cost!r}` is not available for backend `{backend!r}`."
):
get_cost(invalid_cost, backend=backend)

0 comments on commit 924c78a

Please sign in to comment.