Skip to content

Commit

Permalink
fixing plugin registry and goggle dependancy check
Browse files Browse the repository at this point in the history
  • Loading branch information
robsdavis committed Oct 9, 2023
1 parent e7834e2 commit 14c57cc
Show file tree
Hide file tree
Showing 6 changed files with 34 additions and 28 deletions.
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ package_dir =
python_requires = >=3.8

install_requires =
importlib-metadata
pandas>=1.4,<2
torch>=1.10.0,<2.0
scikit-learn>=1.0
Expand Down
1 change: 1 addition & 0 deletions src/synthcity/plugins/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"time_series",
"domain_adaptation",
"images",
"debug",
]
plugins = {}

Expand Down
10 changes: 4 additions & 6 deletions src/synthcity/plugins/core/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -560,7 +560,7 @@ class PluginLoader:

@validate_arguments
def __init__(self, plugins: list, expected_type: Type, categories: list) -> None:
self.reload()
# self.reload()
global PLUGIN_CATEGORY_REGISTRY
PLUGIN_CATEGORY_REGISTRY = {cat: [] for cat in categories}
self._refresh()
Expand All @@ -577,7 +577,6 @@ def _refresh(self) -> None:
"""Refresh the list of available plugins"""
self._plugins: Dict[str, Type[Plugin]] = PLUGIN_REGISTRY
self._categories: Dict[str, List[str]] = PLUGIN_CATEGORY_REGISTRY
print("Refreshing: ", self._plugins, self._categories)

@validate_arguments
def _load_single_plugin_impl(self, plugin_name: str) -> Optional[Type]:
Expand Down Expand Up @@ -666,7 +665,8 @@ def _add_category(self, category: str, name: str) -> "PluginLoader":

def add(self, name: str, cls: Type) -> "PluginLoader":
"""Add a new plugin"""
print("Adding: ", name, cls)
global PLUGIN_REGISTRY
global PLUGIN_CATEGORY_REGISTRY
self._refresh()
if name in self._plugins:
log.info(f"Plugin {name} already exists. Overwriting")
Expand Down Expand Up @@ -701,7 +701,6 @@ def get(self, name: str, *args: Any, **kwargs: Any) -> Any:
"""
self._refresh()
if name not in self._plugins and name not in self._available_plugins:
print(self._plugins, self._available_plugins)
raise ValueError(f"Plugin {name} doesn't exist.")

if name not in self._plugins:
Expand Down Expand Up @@ -748,9 +747,8 @@ def __getitem__(self, key: str) -> Any:
return self.get(key)

def reload(self) -> "PluginLoader":
print("Reloading")
global PLUGIN_CATEGORY_REGISTRY
PLUGIN_CATEGORY_REGISTRY = dict()
global PLUGIN_REGISTRY
PLUGIN_CATEGORY_REGISTRY = dict()
PLUGIN_REGISTRY = dict()
return self
11 changes: 8 additions & 3 deletions tests/plugins/generic/test_goggle.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# third party
import numpy as np
import pkg_resources
import pytest
from generic_helpers import generate_fixtures
from importlib_metadata import PackageNotFoundError, distribution
from sklearn.datasets import load_diabetes, load_iris

# synthcity absolute
Expand All @@ -24,8 +24,13 @@

if not is_missing_goggle_deps:
goggle_dependencies = {"dgl", "torch-scatter", "torch-sparse", "torch-geometric"}
installed = {pkg.key for pkg in pkg_resources.working_set}
is_missing_goggle_deps = len(goggle_dependencies - installed) > 0
missing_deps = []
for dep in goggle_dependencies:
try:
distribution(dep)
except PackageNotFoundError:
missing_deps.append(dep)
is_missing_goggle_deps = len(missing_deps) > 0


@pytest.mark.skipif(is_missing_goggle_deps, reason="Goggle dependencies not installed")
Expand Down
18 changes: 9 additions & 9 deletions tests/plugins/generic/test_great.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,31 +34,31 @@
}


@pytest.mark.skipif(sys.version_info < (3, 9))
@pytest.mark.skipif(sys.version_info < (3, 9), reason="GReaT requires Python 3.9+")
@pytest.mark.parametrize("test_plugin", generate_fixtures(plugin_name, plugin))
def test_plugin_sanity(test_plugin: Plugin) -> None:
assert test_plugin is not None


@pytest.mark.skipif(sys.version_info < (3, 9))
@pytest.mark.skipif(sys.version_info < (3, 9), reason="GReaT requires Python 3.9+")
@pytest.mark.parametrize("test_plugin", generate_fixtures(plugin_name, plugin))
def test_plugin_name(test_plugin: Plugin) -> None:
assert test_plugin.name() == plugin_name


@pytest.mark.skipif(sys.version_info < (3, 9))
@pytest.mark.skipif(sys.version_info < (3, 9), reason="GReaT requires Python 3.9+")
@pytest.mark.parametrize("test_plugin", generate_fixtures(plugin_name, plugin))
def test_plugin_type(test_plugin: Plugin) -> None:
assert test_plugin.type() == "generic"


@pytest.mark.skipif(sys.version_info < (3, 9))
@pytest.mark.skipif(sys.version_info < (3, 9), reason="GReaT requires Python 3.9+")
@pytest.mark.parametrize("test_plugin", generate_fixtures(plugin_name, plugin))
def test_plugin_hyperparams(test_plugin: Plugin) -> None:
assert len(test_plugin.hyperparameter_space()) == 1


@pytest.mark.skipif(sys.version_info < (3, 9))
@pytest.mark.skipif(sys.version_info < (3, 9), reason="GReaT requires Python 3.9+")
@pytest.mark.parametrize(
"test_plugin", generate_fixtures(plugin_name, plugin, plugin_args)
)
Expand All @@ -67,7 +67,7 @@ def test_plugin_fit(test_plugin: Plugin) -> None:
test_plugin.fit(GenericDataLoader(X))


@pytest.mark.skipif(sys.version_info < (3, 9))
@pytest.mark.skipif(sys.version_info < (3, 9), reason="GReaT requires Python 3.9+")
@pytest.mark.skipif(
IN_GITHUB_ACTIONS,
reason="GReaT generate required too much memory to reliably run in GitHub Actions",
Expand Down Expand Up @@ -104,7 +104,7 @@ def test_plugin_generate(test_plugin: Plugin, serialize: bool) -> None:


@pytest.mark.slow
@pytest.mark.skipif(sys.version_info < (3, 9))
@pytest.mark.skipif(sys.version_info < (3, 9), reason="GReaT requires Python 3.9+")
@pytest.mark.skipif(
IN_GITHUB_ACTIONS,
reason="GReaT generate required too much memory to reliably run in GitHub Actions",
Expand Down Expand Up @@ -147,7 +147,7 @@ def test_sample_hyperparams() -> None:
assert plugin(**args) is not None


@pytest.mark.skipif(sys.version_info < (3, 9))
@pytest.mark.skipif(sys.version_info < (3, 9), reason="GReaT requires Python 3.9+")
@pytest.mark.skipif(
IN_GITHUB_ACTIONS,
reason="GReaT generate required too much memory to reliably run in GitHub Actions",
Expand Down Expand Up @@ -182,7 +182,7 @@ def gen_datetime(min_year: int = 2000, max_year: int = datetime.now().year) -> d


@pytest.mark.slow
@pytest.mark.skipif(sys.version_info < (3, 9))
@pytest.mark.skipif(sys.version_info < (3, 9), reason="GReaT requires Python 3.9+")
@pytest.mark.skipif(
IN_GITHUB_ACTIONS,
reason="GReaT generate required too much memory to reliably run in GitHub Actions",
Expand Down
21 changes: 11 additions & 10 deletions tests/plugins/test_plugin_add.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
# stdlib
import glob
from pathlib import Path
from typing import Any, List

# third party
Expand Down Expand Up @@ -45,15 +43,17 @@ def test_add_dummy_plugin() -> None:
# get the list of plugins that are loaded
generators = Plugins()

# Get the list of plugins that come with the package
plugins_dir = Path.cwd() / "src/synthcity/plugins"
plugins_list = []
for plugin_type in plugins_dir.iterdir():
plugin_paths = glob.glob(str(plugins_dir / plugin_type / "plugin*.py"))
plugins_list.extend([Path(path).stem for path in plugin_paths])
# # Get the list of plugins that come with the package
# plugins_dir = Path.cwd() / "src/synthcity/plugins"
# plugins_list = []
# for plugin_type in plugins_dir.iterdir():
# plugin_paths = glob.glob(str(plugins_dir / plugin_type / "plugin*.py"))
# plugins_list.extend([Path(path).stem for path in plugin_paths])
available_plugins = Plugins().list()
print(sorted(available_plugins))

# Test that the new plugin is not in the list plugins in the package
assert "copy_data" not in plugins_list
assert "copy_data" not in available_plugins

# Add the new plugin
generators.add("copy_data", DummyCopyDataPlugin)
Expand All @@ -71,5 +71,6 @@ def test_add_dummy_plugin() -> None:
gen.generate(count=10)

# Test that the new plugin is now in the list of available plugins
available_plugins = Plugins().list()
available_plugins = Plugins(categories=["debug"]).list()
print(sorted(available_plugins))
assert "copy_data" in available_plugins

0 comments on commit 14c57cc

Please sign in to comment.