Skip to content

Commit

Permalink
Extending get_values() for tables annotating points (#537)
Browse files Browse the repository at this point in the history
* Extending get_values() for tables annotating points

made get_values work for tables that annotate points

* fix docstring

---------

Co-authored-by: Luca Marconato <[email protected]>
  • Loading branch information
dylanclam12 and LucaMarconato authored Jun 14, 2024
1 parent 3c69b26 commit 5034933
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 4 deletions.
7 changes: 4 additions & 3 deletions src/spatialdata/_core/query/relational_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,7 @@ def get_element_instances(
Returns
-------
pd.Series
The instances (index values) of the SpatialElement.
pd.Series with the instances (index values) of the SpatialElement.
"""
raise ValueError(f"The object type {type(element)} is not supported.")

Expand Down Expand Up @@ -184,6 +183,8 @@ def _filter_table_by_elements(
instances = np.sort(instances)
elif get_model(element) == ShapesModel:
instances = element.index.to_numpy()
elif get_model(element) == PointsModel:
instances = element.compute().index.to_numpy()
else:
continue
indices = ((table.obs[region_key] == name) & (table.obs[instance_key].isin(instances))).to_numpy()
Expand Down Expand Up @@ -785,7 +786,7 @@ def _locate_value(
origins = _get_table_origins(element=el, value_key=value_key, origins=origins)

# adding from the obs columns or var
if model in [ShapesModel, PointsModel, Labels2DModel, Labels3DModel] and sdata is not None:
if model in [PointsModel, ShapesModel, Labels2DModel, Labels3DModel] and sdata is not None:
table = sdata.tables.get(table_name) if table_name is not None else None
if table is not None:
# check if the table is annotating the element
Expand Down
24 changes: 23 additions & 1 deletion tests/core/query/test_relational_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -520,7 +520,7 @@ def _check_location(locations: list[_ValueOrigin], origin: str, is_categorical:
)


def test_get_values_df(sdata_query_aggregation):
def test_get_values_df_shapes(sdata_query_aggregation):
# test with a single value, in the dataframe; using sdata + element_name
v = get_values(
value_key="numerical_in_gdf", sdata=sdata_query_aggregation, element_name="values_circles", table_name="table"
Expand Down Expand Up @@ -621,6 +621,28 @@ def test_get_values_df(sdata_query_aggregation):
)


def test_get_values_df_points(points):
# testing get_values() for points, we keep the test more minimalistic than the one for shapes
p = points["points_0"]
p = p.drop("instance_id", axis=1)
p.index.compute()
n = len(p)
obs = pd.DataFrame(index=p.index, data={"region": ["points_0"] * n, "instance_id": range(n)})
obs["region"] = obs["region"].astype("category")
table = TableModel.parse(
AnnData(shape=(n, 0), obs=obs), region="points_0", region_key="region", instance_key="instance_id"
)
points["points_0"] = p
points["table"] = table

assert get_values(value_key="region", element_name="points_0", sdata=points, table_name="table").shape == (300, 1)
get_values(value_key="instance_id", element_name="points_0", sdata=points, table_name="table")
get_values(value_key=["x", "y"], element_name="points_0", sdata=points, table_name="table")
get_values(value_key="genes", element_name="points_0", sdata=points, table_name="table")

pass


def test_get_values_obsm(adata_labels: AnnData):
get_values(value_key="tensor", element=adata_labels)

Expand Down

0 comments on commit 5034933

Please sign in to comment.