diff --git a/setup.cfg b/setup.cfg index a9c6fb5b..842cbbc7 100755 --- a/setup.cfg +++ b/setup.cfg @@ -32,6 +32,7 @@ package_dir = python_requires = >=3.8 install_requires = + importlib-metadata pandas>=1.4,<2 torch>=1.10.0,<2.0 scikit-learn>=1.0 diff --git a/src/synthcity/plugins/__init__.py b/src/synthcity/plugins/__init__.py index f1c64a27..e2057f29 100755 --- a/src/synthcity/plugins/__init__.py +++ b/src/synthcity/plugins/__init__.py @@ -15,6 +15,7 @@ "time_series", "domain_adaptation", "images", + "debug", ] plugins = {} diff --git a/src/synthcity/plugins/core/plugin.py b/src/synthcity/plugins/core/plugin.py index 78b309aa..a57cc3a5 100755 --- a/src/synthcity/plugins/core/plugin.py +++ b/src/synthcity/plugins/core/plugin.py @@ -560,7 +560,7 @@ class PluginLoader: @validate_arguments def __init__(self, plugins: list, expected_type: Type, categories: list) -> None: - self.reload() + # self.reload() global PLUGIN_CATEGORY_REGISTRY PLUGIN_CATEGORY_REGISTRY = {cat: [] for cat in categories} self._refresh() @@ -577,7 +577,6 @@ def _refresh(self) -> None: """Refresh the list of available plugins""" self._plugins: Dict[str, Type[Plugin]] = PLUGIN_REGISTRY self._categories: Dict[str, List[str]] = PLUGIN_CATEGORY_REGISTRY - print("Refreshing: ", self._plugins, self._categories) @validate_arguments def _load_single_plugin_impl(self, plugin_name: str) -> Optional[Type]: @@ -666,7 +665,8 @@ def _add_category(self, category: str, name: str) -> "PluginLoader": def add(self, name: str, cls: Type) -> "PluginLoader": """Add a new plugin""" - print("Adding: ", name, cls) + global PLUGIN_REGISTRY + global PLUGIN_CATEGORY_REGISTRY self._refresh() if name in self._plugins: log.info(f"Plugin {name} already exists. Overwriting") @@ -701,7 +701,6 @@ def get(self, name: str, *args: Any, **kwargs: Any) -> Any: """ self._refresh() if name not in self._plugins and name not in self._available_plugins: - print(self._plugins, self._available_plugins) raise ValueError(f"Plugin {name} doesn't exist.") if name not in self._plugins: @@ -748,9 +747,8 @@ def __getitem__(self, key: str) -> Any: return self.get(key) def reload(self) -> "PluginLoader": - print("Reloading") global PLUGIN_CATEGORY_REGISTRY - PLUGIN_CATEGORY_REGISTRY = dict() global PLUGIN_REGISTRY + PLUGIN_CATEGORY_REGISTRY = dict() PLUGIN_REGISTRY = dict() return self diff --git a/tests/plugins/generic/test_goggle.py b/tests/plugins/generic/test_goggle.py index d3f8dc29..d76c0c12 100755 --- a/tests/plugins/generic/test_goggle.py +++ b/tests/plugins/generic/test_goggle.py @@ -1,8 +1,8 @@ # third party import numpy as np -import pkg_resources import pytest from generic_helpers import generate_fixtures +from importlib_metadata import PackageNotFoundError, distribution from sklearn.datasets import load_diabetes, load_iris # synthcity absolute @@ -24,8 +24,13 @@ if not is_missing_goggle_deps: goggle_dependencies = {"dgl", "torch-scatter", "torch-sparse", "torch-geometric"} - installed = {pkg.key for pkg in pkg_resources.working_set} - is_missing_goggle_deps = len(goggle_dependencies - installed) > 0 + missing_deps = [] + for dep in goggle_dependencies: + try: + distribution(dep) + except PackageNotFoundError: + missing_deps.append(dep) + is_missing_goggle_deps = len(missing_deps) > 0 @pytest.mark.skipif(is_missing_goggle_deps, reason="Goggle dependencies not installed") diff --git a/tests/plugins/generic/test_great.py b/tests/plugins/generic/test_great.py index b87d10d9..f2db4670 100755 --- a/tests/plugins/generic/test_great.py +++ b/tests/plugins/generic/test_great.py @@ -34,31 +34,31 @@ } -@pytest.mark.skipif(sys.version_info < (3, 9)) +@pytest.mark.skipif(sys.version_info < (3, 9), reason="GReaT requires Python 3.9+") @pytest.mark.parametrize("test_plugin", generate_fixtures(plugin_name, plugin)) def test_plugin_sanity(test_plugin: Plugin) -> None: assert test_plugin is not None -@pytest.mark.skipif(sys.version_info < (3, 9)) +@pytest.mark.skipif(sys.version_info < (3, 9), reason="GReaT requires Python 3.9+") @pytest.mark.parametrize("test_plugin", generate_fixtures(plugin_name, plugin)) def test_plugin_name(test_plugin: Plugin) -> None: assert test_plugin.name() == plugin_name -@pytest.mark.skipif(sys.version_info < (3, 9)) +@pytest.mark.skipif(sys.version_info < (3, 9), reason="GReaT requires Python 3.9+") @pytest.mark.parametrize("test_plugin", generate_fixtures(plugin_name, plugin)) def test_plugin_type(test_plugin: Plugin) -> None: assert test_plugin.type() == "generic" -@pytest.mark.skipif(sys.version_info < (3, 9)) +@pytest.mark.skipif(sys.version_info < (3, 9), reason="GReaT requires Python 3.9+") @pytest.mark.parametrize("test_plugin", generate_fixtures(plugin_name, plugin)) def test_plugin_hyperparams(test_plugin: Plugin) -> None: assert len(test_plugin.hyperparameter_space()) == 1 -@pytest.mark.skipif(sys.version_info < (3, 9)) +@pytest.mark.skipif(sys.version_info < (3, 9), reason="GReaT requires Python 3.9+") @pytest.mark.parametrize( "test_plugin", generate_fixtures(plugin_name, plugin, plugin_args) ) @@ -67,7 +67,7 @@ def test_plugin_fit(test_plugin: Plugin) -> None: test_plugin.fit(GenericDataLoader(X)) -@pytest.mark.skipif(sys.version_info < (3, 9)) +@pytest.mark.skipif(sys.version_info < (3, 9), reason="GReaT requires Python 3.9+") @pytest.mark.skipif( IN_GITHUB_ACTIONS, reason="GReaT generate required too much memory to reliably run in GitHub Actions", @@ -104,7 +104,7 @@ def test_plugin_generate(test_plugin: Plugin, serialize: bool) -> None: @pytest.mark.slow -@pytest.mark.skipif(sys.version_info < (3, 9)) +@pytest.mark.skipif(sys.version_info < (3, 9), reason="GReaT requires Python 3.9+") @pytest.mark.skipif( IN_GITHUB_ACTIONS, reason="GReaT generate required too much memory to reliably run in GitHub Actions", @@ -147,7 +147,7 @@ def test_sample_hyperparams() -> None: assert plugin(**args) is not None -@pytest.mark.skipif(sys.version_info < (3, 9)) +@pytest.mark.skipif(sys.version_info < (3, 9), reason="GReaT requires Python 3.9+") @pytest.mark.skipif( IN_GITHUB_ACTIONS, reason="GReaT generate required too much memory to reliably run in GitHub Actions", @@ -182,7 +182,7 @@ def gen_datetime(min_year: int = 2000, max_year: int = datetime.now().year) -> d @pytest.mark.slow -@pytest.mark.skipif(sys.version_info < (3, 9)) +@pytest.mark.skipif(sys.version_info < (3, 9), reason="GReaT requires Python 3.9+") @pytest.mark.skipif( IN_GITHUB_ACTIONS, reason="GReaT generate required too much memory to reliably run in GitHub Actions", diff --git a/tests/plugins/test_plugin_add.py b/tests/plugins/test_plugin_add.py index c3bc2c1b..d631f8b4 100755 --- a/tests/plugins/test_plugin_add.py +++ b/tests/plugins/test_plugin_add.py @@ -1,6 +1,4 @@ # stdlib -import glob -from pathlib import Path from typing import Any, List # third party @@ -45,15 +43,17 @@ def test_add_dummy_plugin() -> None: # get the list of plugins that are loaded generators = Plugins() - # Get the list of plugins that come with the package - plugins_dir = Path.cwd() / "src/synthcity/plugins" - plugins_list = [] - for plugin_type in plugins_dir.iterdir(): - plugin_paths = glob.glob(str(plugins_dir / plugin_type / "plugin*.py")) - plugins_list.extend([Path(path).stem for path in plugin_paths]) + # # Get the list of plugins that come with the package + # plugins_dir = Path.cwd() / "src/synthcity/plugins" + # plugins_list = [] + # for plugin_type in plugins_dir.iterdir(): + # plugin_paths = glob.glob(str(plugins_dir / plugin_type / "plugin*.py")) + # plugins_list.extend([Path(path).stem for path in plugin_paths]) + available_plugins = Plugins().list() + print(sorted(available_plugins)) # Test that the new plugin is not in the list plugins in the package - assert "copy_data" not in plugins_list + assert "copy_data" not in available_plugins # Add the new plugin generators.add("copy_data", DummyCopyDataPlugin) @@ -71,5 +71,6 @@ def test_add_dummy_plugin() -> None: gen.generate(count=10) # Test that the new plugin is now in the list of available plugins - available_plugins = Plugins().list() + available_plugins = Plugins(categories=["debug"]).list() + print(sorted(available_plugins)) assert "copy_data" in available_plugins