Skip to content

Commit

Permalink
debugging
Browse files Browse the repository at this point in the history
  • Loading branch information
robsdavis committed Sep 12, 2023
1 parent 6076745 commit aaa9625
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 57 deletions.
4 changes: 4 additions & 0 deletions src/synthcity/plugins/core/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -577,6 +577,7 @@ def _refresh(self) -> None:
"""Refresh the list of available plugins"""
self._plugins: Dict[str, Type[Plugin]] = PLUGIN_REGISTRY
self._categories: Dict[str, List[str]] = PLUGIN_CATEGORY_REGISTRY
print("Refreshing: ", self._plugins, self._categories)

@validate_arguments
def _load_single_plugin_impl(self, plugin_name: str) -> Optional[Type]:
Expand Down Expand Up @@ -665,6 +666,7 @@ def _add_category(self, category: str, name: str) -> "PluginLoader":

def add(self, name: str, cls: Type) -> "PluginLoader":
"""Add a new plugin"""
print("Adding: ", name, cls)
self._refresh()
if name in self._plugins:
log.info(f"Plugin {name} already exists. Overwriting")
Expand Down Expand Up @@ -699,6 +701,7 @@ def get(self, name: str, *args: Any, **kwargs: Any) -> Any:
"""
self._refresh()
if name not in self._plugins and name not in self._available_plugins:
print(self._plugins, self._available_plugins)
raise ValueError(f"Plugin {name} doesn't exist.")

if name not in self._plugins:
Expand Down Expand Up @@ -745,6 +748,7 @@ def __getitem__(self, key: str) -> Any:
return self.get(key)

def reload(self) -> "PluginLoader":
print("Reloading")
global PLUGIN_CATEGORY_REGISTRY
PLUGIN_CATEGORY_REGISTRY = dict()
global PLUGIN_REGISTRY
Expand Down
115 changes: 58 additions & 57 deletions tests/benchmarks/test_benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import platform
from copy import copy
from pathlib import Path
from typing import Any, List

# third party
import pytest
Expand All @@ -14,15 +13,17 @@
# synthcity absolute
from synthcity.benchmark import Benchmarks
from synthcity.benchmark.utils import get_json_serializable_kwargs
from synthcity.plugins import Plugins
from synthcity.plugins.core.dataloader import (
DataLoader,

# from synthcity.plugins import Plugins
from synthcity.plugins.core.dataloader import ( # DataLoader,
GenericDataLoader,
SurvivalAnalysisDataLoader,
)
from synthcity.plugins.core.distribution import Distribution
from synthcity.plugins.core.plugin import Plugin
from synthcity.plugins.core.schema import Schema

# from typing import Any, List
# from synthcity.plugins.core.distribution import Distribution
# from synthcity.plugins.core.plugin import Plugin
# from synthcity.plugins.core.schema import Schema


def test_benchmark_sanity() -> None:
Expand Down Expand Up @@ -294,53 +295,53 @@ def test_benchmark_workspace_cache() -> None:
assert augment_generator_file.exists()


def test_benchmark_added_plugin() -> None:
X, y = load_iris(return_X_y=True, as_frame=True)
X["target"] = y

class DummyCopyDataPlugin(Plugin):
"""Dummy plugin for debugging."""

def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)

@staticmethod
def name() -> str:
return "copy_data"

@staticmethod
def type() -> str:
return "debug"

@staticmethod
def hyperparameter_space(*args: Any, **kwargs: Any) -> List[Distribution]:
return []

def _fit(
self, X: DataLoader, *args: Any, **kwargs: Any
) -> "DummyCopyDataPlugin":
self.features_count = X.shape[1]
self.X = X
return self

def _generate(
self, count: int, syn_schema: Schema, **kwargs: Any
) -> DataLoader:
return self.X.sample(count)

generators = Plugins()
# Add the new plugin to the collection
generators.add("copy_data", DummyCopyDataPlugin)

score = Benchmarks.evaluate(
[
("copy_data", "copy_data", {}),
],
GenericDataLoader(X, target_column="target"),
metrics={
"performance": [
"linear_model",
]
},
)
assert "copy_data" in score
# def test_benchmark_added_plugin() -> None:
# X, y = load_iris(return_X_y=True, as_frame=True)
# X["target"] = y

# class DummyCopyDataPlugin(Plugin):
# """Dummy plugin for debugging."""

# def __init__(self, **kwargs: Any) -> None:
# super().__init__(**kwargs)

# @staticmethod
# def name() -> str:
# return "copy_data"

# @staticmethod
# def type() -> str:
# return "debug"

# @staticmethod
# def hyperparameter_space(*args: Any, **kwargs: Any) -> List[Distribution]:
# return []

# def _fit(
# self, X: DataLoader, *args: Any, **kwargs: Any
# ) -> "DummyCopyDataPlugin":
# self.features_count = X.shape[1]
# self.X = X
# return self

# def _generate(
# self, count: int, syn_schema: Schema, **kwargs: Any
# ) -> DataLoader:
# return self.X.sample(count)

# generators = Plugins()
# # Add the new plugin to the collection
# generators.add("copy_data", DummyCopyDataPlugin)

# score = Benchmarks.evaluate(
# [
# ("copy_data", "copy_data", {}),
# ],
# GenericDataLoader(X, target_column="target"),
# metrics={
# "performance": [
# "linear_model",
# ]
# },
# )
# assert "copy_data" in score

0 comments on commit aaa9625

Please sign in to comment.