diff --git a/src/spatialdata/_core/query/relational_query.py b/src/spatialdata/_core/query/relational_query.py index 090cbca4..c9498a9c 100644 --- a/src/spatialdata/_core/query/relational_query.py +++ b/src/spatialdata/_core/query/relational_query.py @@ -97,8 +97,7 @@ def get_element_instances( Returns ------- - pd.Series - The instances (index values) of the SpatialElement. + pd.Series with the instances (index values) of the SpatialElement. """ raise ValueError(f"The object type {type(element)} is not supported.") @@ -184,6 +183,8 @@ def _filter_table_by_elements( instances = np.sort(instances) elif get_model(element) == ShapesModel: instances = element.index.to_numpy() + elif get_model(element) == PointsModel: + instances = element.compute().index.to_numpy() else: continue indices = ((table.obs[region_key] == name) & (table.obs[instance_key].isin(instances))).to_numpy() @@ -785,7 +786,7 @@ def _locate_value( origins = _get_table_origins(element=el, value_key=value_key, origins=origins) # adding from the obs columns or var - if model in [ShapesModel, PointsModel, Labels2DModel, Labels3DModel] and sdata is not None: + if model in [PointsModel, ShapesModel, Labels2DModel, Labels3DModel] and sdata is not None: table = sdata.tables.get(table_name) if table_name is not None else None if table is not None: # check if the table is annotating the element diff --git a/tests/core/query/test_relational_query.py b/tests/core/query/test_relational_query.py index 80cd6d39..60b7bd36 100644 --- a/tests/core/query/test_relational_query.py +++ b/tests/core/query/test_relational_query.py @@ -520,7 +520,7 @@ def _check_location(locations: list[_ValueOrigin], origin: str, is_categorical: ) -def test_get_values_df(sdata_query_aggregation): +def test_get_values_df_shapes(sdata_query_aggregation): # test with a single value, in the dataframe; using sdata + element_name v = get_values( value_key="numerical_in_gdf", sdata=sdata_query_aggregation, element_name="values_circles", table_name="table" @@ -621,6 +621,28 @@ def test_get_values_df(sdata_query_aggregation): ) +def test_get_values_df_points(points): + # testing get_values() for points, we keep the test more minimalistic than the one for shapes + p = points["points_0"] + p = p.drop("instance_id", axis=1) + p.index.compute() + n = len(p) + obs = pd.DataFrame(index=p.index, data={"region": ["points_0"] * n, "instance_id": range(n)}) + obs["region"] = obs["region"].astype("category") + table = TableModel.parse( + AnnData(shape=(n, 0), obs=obs), region="points_0", region_key="region", instance_key="instance_id" + ) + points["points_0"] = p + points["table"] = table + + assert get_values(value_key="region", element_name="points_0", sdata=points, table_name="table").shape == (300, 1) + get_values(value_key="instance_id", element_name="points_0", sdata=points, table_name="table") + get_values(value_key=["x", "y"], element_name="points_0", sdata=points, table_name="table") + get_values(value_key="genes", element_name="points_0", sdata=points, table_name="table") + + pass + + def test_get_values_obsm(adata_labels: AnnData): get_values(value_key="tensor", element=adata_labels)