diff --git a/pyproject.toml b/pyproject.toml index b1416b7..d9df207 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,16 +18,16 @@ classifiers = [ "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Topic :: Scientific/Engineering :: Artificial Intelligence", - "Topic :: Software Development :: Libraries :: Java Libraries" + "Topic :: Software Development :: Libraries :: Java Libraries", ] dependencies = [ "Jpype1==1.4.1", - "pyarrow==14.0.1", + "pyarrow==17.0.0", "matplotlib~=3.6.3", "pandas~=1.5.3", "numpy~=1.24.1", - "jupyter-bokeh~=3.0.5" + "jupyter-bokeh~=3.0.5", ] [project.optional-dependencies] @@ -38,7 +38,6 @@ dev = [ "joblib~=1.2.0", "jupyterlab~=3.5.3", "numpydoc==1.5.0", - "pyarrow==14.0.1", "pylint==2.15.6", "pytest~=7.2.1", "pytest-benchmark==4.0.0", @@ -47,25 +46,21 @@ dev = [ "setuptools", "twine==3.4.2", "wheel~=0.38.4", - "xgboost==1.4.2" -] -extras = [ - "aix360[default,tsice,tslime,tssaliency]==0.3.0" + "xgboost==1.4.2", ] +extras = ["aix360[default,tsice,tslime,tssaliency]==0.3.0"] detoxify = [ "transformers~=4.36.2", "datasets", - "scipy", - "torch", + "scipy~=1.12.0", + "torch~=2.2.1", "iter-tools", "evaluate", - "trl" + "trl", ] -api = [ - "kubernetes" -] +api = ["kubernetes"] [project.urls] homepage = "https://github.com/trustyai-explainability/trustyai-explainability-python" @@ -83,7 +78,7 @@ package-dir = { "" = "src" } log_cli = true addopts = '-m="not block_plots"' markers = [ - "block_plots: Test plots will block execution of subsequent tests until closed" + "block_plots: Test plots will block execution of subsequent tests until closed", ] [tool.setuptools.packages.find] diff --git a/tests/general/test_conversions.py b/tests/general/test_conversions.py index 7893da3..c47118e 100644 --- a/tests/general/test_conversions.py +++ b/tests/general/test_conversions.py @@ -65,7 +65,7 @@ def test_categorical_object_domain_list(): jdomain = feature_domain(domain) assert str(jdomain.getClass().getSimpleName()) == "ObjectFeatureDomain" assert jdomain.getCategories().size() == 2 - assert jdomain.getCategories().containsAll(domain) + assert sorted([o.getObject() for o in jdomain.getCategories()]) == sorted(domain) def test_categorical_object_domain_list_2(): @@ -74,7 +74,7 @@ def test_categorical_object_domain_list_2(): jdomain = feature_domain(domain) assert str(jdomain.getClass().getSimpleName()) == "ObjectFeatureDomain" assert jdomain.getCategories().size() == 2 - assert jdomain.getCategories().containsAll(domain) + assert sorted([o.getObject() for o in jdomain.getCategories()]) == sorted(domain) def test_empty_domain():