diff --git a/pyproject.toml b/pyproject.toml index 6262865..3d7cbc0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,7 +43,7 @@ keywords = [ requires-python = ">=3.7" dependencies = [ - "equistore @ https://github.com/lab-cosmo/equistore/archive/a9b9a2a.zip", + "equistore @ https://github.com/lab-cosmo/equistore/archive/c022fde.zip", "numpy", "scipy", "skmatter" diff --git a/tests/equisolve_tests/numpy/feature_selection.py b/tests/equisolve_tests/numpy/feature_selection.py index ded306d..f1a830d 100644 --- a/tests/equisolve_tests/numpy/feature_selection.py +++ b/tests/equisolve_tests/numpy/feature_selection.py @@ -29,7 +29,7 @@ def X(self): def test_fit(self, X, selector_class, skmatter_selector_class): selector = selector_class(n_to_select=2) selector.fit(X) - support = selector.support[0].properties["properties"] + support = selector.support[0].properties skmatter_selector = skmatter_selector_class(n_to_select=2) skmatter_selector.fit(X[0].values) @@ -41,7 +41,7 @@ def test_fit(self, X, selector_class, skmatter_selector_class): ), ) - assert_equal(support, skmatter_support_labels) + assert support == skmatter_support_labels @pytest.mark.parametrize( "selector_class, skmatter_selector_class", diff --git a/tests/equisolve_tests/numpy/sample_selection.py b/tests/equisolve_tests/numpy/sample_selection.py index f5cfda8..3efe657 100644 --- a/tests/equisolve_tests/numpy/sample_selection.py +++ b/tests/equisolve_tests/numpy/sample_selection.py @@ -29,19 +29,20 @@ def X(self): def test_fit(self, X, selector_class, skmatter_selector_class): selector = selector_class(n_to_select=2) selector.fit(X) - support = selector.support[0].samples["structure"] + support = selector.support[0].samples skmatter_selector = skmatter_selector_class(n_to_select=2) skmatter_selector.fit(X[0].values) skmatter_support = skmatter_selector.get_support(indices=True) skmatter_support_labels = Labels( - names=["structure"], + names=["sample", "structure"], values=np.array( - [[support_i] for support_i in skmatter_support], dtype=np.int32 + [[support_i, support_i] for support_i in skmatter_support], + dtype=np.int32, ), ) - assert_equal(support, skmatter_support_labels) + assert support == skmatter_support_labels @pytest.mark.parametrize( "selector_class, skmatter_selector_class",