diff --git a/.github/workflows/test_pr.yml b/.github/workflows/test_pr.yml
index bcfa3eea..6bbcbfec 100644
--- a/.github/workflows/test_pr.yml
+++ b/.github/workflows/test_pr.yml
@@ -42,7 +42,6 @@ jobs:
       - uses: actions/checkout@v2
         with:
           submodules: true
-      - uses: gautamkrishnar/keepalive-workflow@v1
       - name: Set up Python ${{ matrix.python-version }}
         uses: actions/setup-python@v4
         with:
diff --git a/setup.cfg b/setup.cfg
index 974f8198..f9e6652c 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -32,6 +32,7 @@ package_dir =
 python_requires = >=3.8
 install_requires =
+    importlib-metadata
     pandas>=1.4,<2
     torch>=1.10.0,<2.0
     scikit-learn>=1.0
diff --git a/src/synthcity/metrics/eval_performance.py b/src/synthcity/metrics/eval_performance.py
index 633a44b0..91a3a5f2 100644
--- a/src/synthcity/metrics/eval_performance.py
+++ b/src/synthcity/metrics/eval_performance.py
@@ -425,12 +425,10 @@ def ts_eval_cbk(
         temporal_train_data = id_temporal_gt[train_idx]
         observation_times_train_data = id_observation_times_gt[train_idx]
         outcome_train_data = id_outcome_gt[train_idx]
-
         static_test_data = id_static_gt[test_idx]
         temporal_test_data = id_temporal_gt[test_idx]
         observation_times_test_data = id_observation_times_gt[test_idx]
         outcome_test_data = id_outcome_gt[test_idx]
-
         real_score = ts_eval_cbk(
             static_train_data,
             temporal_train_data,
diff --git a/src/synthcity/plugins/__init__.py b/src/synthcity/plugins/__init__.py
index f1c64a27..e2057f29 100644
--- a/src/synthcity/plugins/__init__.py
+++ b/src/synthcity/plugins/__init__.py
@@ -15,6 +15,7 @@
     "time_series",
     "domain_adaptation",
     "images",
+    "debug",
 ]

 plugins = {}
diff --git a/src/synthcity/plugins/core/dataloader.py b/src/synthcity/plugins/core/dataloader.py
index c01cede7..7615e9a4 100644
--- a/src/synthcity/plugins/core/dataloader.py
+++ b/src/synthcity/plugins/core/dataloader.py
@@ -932,7 +932,9 @@ def unpack(self, as_numpy: bool = False, pad: bool = False) -> Any:
             longest_observation_seq = max([len(seq) for seq in temporal_data])
             return (
                 np.asarray(static_data),
-                np.asarray(temporal_data),
+                np.asarray(
+                    pd.concat(temporal_data)
+                ),  # TODO: check this works with time series benchmarks
                 # masked array to handle variable length sequences
                 ma.vstack(
                     [
diff --git a/src/synthcity/plugins/core/models/aim.py b/src/synthcity/plugins/core/models/aim.py
index e5165ad9..cc2adc15 100644
--- a/src/synthcity/plugins/core/models/aim.py
+++ b/src/synthcity/plugins/core/models/aim.py
@@ -126,12 +126,10 @@ def __init__(self, epsilon: float, delta: float):
         Base class for a mechanism.
         :param epsilon: privacy parameter
         :param delta: privacy parameter
-        :param prng: pseudo random number generator
         """
         self.epsilon = epsilon
         self.delta = delta
         self.rho = 0 if delta == 0 else cdp_rho(epsilon, delta)
-        self.prng = np.random

     def run(self, dataset: Dataset, workload: List[Tuple]) -> Any:
         pass
@@ -204,7 +202,7 @@ def exponential_mechanism(
         else:
             p = softmax(0.5 * epsilon / sensitivity * q + base_measure)

-        return keys[self.prng.choice(p.size, p=p)]
+        return keys[np.random.choice(p.size, p=p)]

     # def gaussian_noise_scale(self, l2_sensitivity, epsilon, delta):
     #     """Return the Gaussian noise necessary to attain (epsilon, delta)-DP"""
@@ -223,11 +221,11 @@ def exponential_mechanism(

     def gaussian_noise(self, sigma: float, size: Union[int, Tuple]) -> np.ndarray:
         """Generate iid Gaussian noise of a given scale and size"""
-        return self.prng.normal(0, sigma, size)
+        return np.random.normal(0, sigma, size)

     # def laplace_noise(self, b, size):
     #     """Generate iid Laplace noise of a given scale and size"""
-    #     return self.prng.laplace(0, b, size)
+    #     return np.random.laplace(0, b, size)

     # def best_noise_distribution(self, l1_sensitivity, l2_sensitivity, epsilon, delta):
     #     """Adaptively determine if Laplace or Gaussian noise will be better, and
diff --git a/src/synthcity/plugins/core/models/mbi/clique_vector.py b/src/synthcity/plugins/core/models/mbi/clique_vector.py
index 8005a83f..d2b96c4b 100644
--- a/src/synthcity/plugins/core/models/mbi/clique_vector.py
+++ b/src/synthcity/plugins/core/models/mbi/clique_vector.py
@@ -36,21 +36,21 @@ def uniform(domain, cliques):
         return CliqueVector({cl: Factor.uniform(domain.project(cl)) for cl in cliques})

     @staticmethod
-    def random(domain, cliques, prng=np.random):
+    def random(domain, cliques):
         # synthcity relative
         from .factor import Factor

         return CliqueVector(
-            {cl: Factor.random(domain.project(cl), prng) for cl in cliques}
+            {cl: Factor.random(domain.project(cl), np.random) for cl in cliques}
         )

     @staticmethod
-    def normal(domain, cliques, prng=np.random):
+    def normal(domain, cliques):
         # synthcity relative
         from .factor import Factor

         return CliqueVector(
-            {cl: Factor.normal(domain.project(cl), prng) for cl in cliques}
+            {cl: Factor.normal(domain.project(cl), np.random) for cl in cliques}
         )

     @staticmethod
diff --git a/src/synthcity/plugins/core/models/tabular_aim.py b/src/synthcity/plugins/core/models/tabular_aim.py
index 1b2c68cd..f3c031f6 100644
--- a/src/synthcity/plugins/core/models/tabular_aim.py
+++ b/src/synthcity/plugins/core/models/tabular_aim.py
@@ -1,5 +1,6 @@
 # stdlib
 import itertools
+from abc import ABCMeta
 from typing import Any, Optional, Union

 # third party
@@ -17,7 +18,7 @@
 from .mbi.domain import Domain


-class TabularAIM:
+class TabularAIM(metaclass=ABCMeta):
     """
     .. inheritance-diagram:: synthcity.plugins.core.models.tabular_aim.TabularAIM
         :parts: 1
@@ -68,7 +69,6 @@ def __init__(
         self.degree = degree
         self.num_marginals = num_marginals
         self.max_cells = max_cells
-        self.prng = np.random

     @validate_arguments(config=dict(arbitrary_types_allowed=True))
     def fit(
@@ -101,7 +101,7 @@ def fit(
         if self.num_marginals is not None:
             workload = [
                 workload[i]
-                for i in self.prng.choice(
+                for i in np.random.choice(
                     len(workload), self.num_marginals, replace=False
                 )
             ]
diff --git a/src/synthcity/plugins/core/models/tabular_arf.py b/src/synthcity/plugins/core/models/tabular_arf.py
index 30aaba56..97d844a6 100644
--- a/src/synthcity/plugins/core/models/tabular_arf.py
+++ b/src/synthcity/plugins/core/models/tabular_arf.py
@@ -1,4 +1,5 @@
 # stdlib
+from abc import ABCMeta
 from typing import Any, Union

 # third party
@@ -21,7 +22,7 @@
 from synthcity.utils.constants import DEVICE


-class TabularARF:
+class TabularARF(metaclass=ABCMeta):
     def __init__(
         self,
         # ARF parameters
diff --git a/src/synthcity/plugins/core/models/tabular_goggle.py b/src/synthcity/plugins/core/models/tabular_goggle.py
index bcd8e54f..2a04fb9f 100644
--- a/src/synthcity/plugins/core/models/tabular_goggle.py
+++ b/src/synthcity/plugins/core/models/tabular_goggle.py
@@ -1,4 +1,5 @@
 # stdlib
+from abc import ABCMeta
 from typing import Any, Optional, Union

 # third party
@@ -15,7 +16,7 @@
 from .tabular_encoder import TabularEncoder


-class TabularGoggle:
+class TabularGoggle(metaclass=ABCMeta):
     def __init__(
         self,
         X: pd.DataFrame,
diff --git a/src/synthcity/plugins/core/models/tabular_great.py b/src/synthcity/plugins/core/models/tabular_great.py
index d79a4681..8d868d89 100644
--- a/src/synthcity/plugins/core/models/tabular_great.py
+++ b/src/synthcity/plugins/core/models/tabular_great.py
@@ -1,4 +1,5 @@
 # stdlib
+from abc import ABCMeta
 from typing import Any, Dict, Optional, Union

 # third party
@@ -20,7 +21,7 @@
 from synthcity.utils.constants import DEVICE


-class TabularGReaT:
+class TabularGReaT(metaclass=ABCMeta):
     """
     .. inheritance-diagram:: synthcity.plugins.core.models.tabular_great.TabularGReaT
         :parts: 1
diff --git a/src/synthcity/plugins/core/models/ts_model.py b/src/synthcity/plugins/core/models/ts_model.py
index e9fbde4f..d4bf51cd 100644
--- a/src/synthcity/plugins/core/models/ts_model.py
+++ b/src/synthcity/plugins/core/models/ts_model.py
@@ -255,6 +255,7 @@ def forward(
             raise ValueError("NaNs detected in the temporal horizons")

         if self.use_horizon_condition:
+            # TODO: ADD error handling for len(temporal_data.shape) != 3 or len(observation_times.shape) != 2
             temporal_data_merged = torch.cat(
                 [temporal_data, observation_times.unsqueeze(2)], dim=2
             )
diff --git a/src/synthcity/plugins/core/plugin.py b/src/synthcity/plugins/core/plugin.py
index 4ee6f6ca..a57cc3a5 100644
--- a/src/synthcity/plugins/core/plugin.py
+++ b/src/synthcity/plugins/core/plugin.py
@@ -560,6 +560,9 @@ class PluginLoader:
     @validate_arguments
     def __init__(self, plugins: list, expected_type: Type, categories: list) -> None:
+        # self.reload()
+        global PLUGIN_CATEGORY_REGISTRY
+        PLUGIN_CATEGORY_REGISTRY = {cat: [] for cat in categories}
         self._refresh()
         self._available_plugins = {}
         for plugin in plugins:
@@ -662,6 +665,8 @@ def _add_category(self, category: str, name: str) -> "PluginLoader":

     def add(self, name: str, cls: Type) -> "PluginLoader":
         """Add a new plugin"""
+        global PLUGIN_REGISTRY
+        global PLUGIN_CATEGORY_REGISTRY
         self._refresh()
         if name in self._plugins:
             log.info(f"Plugin {name} already exists. Overwriting")
Overwriting") @@ -742,5 +747,8 @@ def __getitem__(self, key: str) -> Any: return self.get(key) def reload(self) -> "PluginLoader": - self._plugins = {} + global PLUGIN_CATEGORY_REGISTRY + global PLUGIN_REGISTRY + PLUGIN_CATEGORY_REGISTRY = dict() + PLUGIN_REGISTRY = dict() return self diff --git a/src/synthcity/plugins/privacy/plugin_aim.py b/src/synthcity/plugins/privacy/plugin_aim.py index c32d6bf9..6a10067a 100644 --- a/src/synthcity/plugins/privacy/plugin_aim.py +++ b/src/synthcity/plugins/privacy/plugin_aim.py @@ -110,7 +110,7 @@ def name() -> str: @staticmethod def type() -> str: - return "generic" + return "privacy" @staticmethod def hyperparameter_space(**kwargs: Any) -> List[Distribution]: diff --git a/src/synthcity/utils/datasets/time_series/google_stocks.py b/src/synthcity/utils/datasets/time_series/google_stocks.py index ee90380c..5e731ac8 100644 --- a/src/synthcity/utils/datasets/time_series/google_stocks.py +++ b/src/synthcity/utils/datasets/time_series/google_stocks.py @@ -74,7 +74,6 @@ def load( np.asarray(observation_times), np.asarray(outcome, dtype=np.float32), ) - return ( pd.DataFrame(np.zeros((len(temporal_data), 0))), temporal_data, diff --git a/tests/metrics/test_performance.py b/tests/metrics/test_performance.py index 30d743bd..51782fe5 100644 --- a/tests/metrics/test_performance.py +++ b/tests/metrics/test_performance.py @@ -363,7 +363,7 @@ def test_evaluate_performance_custom_labels( @pytest.mark.slow -@pytest.mark.parametrize("test_plugin", [Plugins().get("marginal_distributions")]) +@pytest.mark.parametrize("test_plugin", [Plugins().get("timegan")]) @pytest.mark.parametrize( "evaluator_t", [ diff --git a/tests/plugins/generic/test_goggle.py b/tests/plugins/generic/test_goggle.py index d3f8dc29..d76c0c12 100644 --- a/tests/plugins/generic/test_goggle.py +++ b/tests/plugins/generic/test_goggle.py @@ -1,8 +1,8 @@ # third party import numpy as np -import pkg_resources import pytest from generic_helpers import generate_fixtures +from importlib_metadata import PackageNotFoundError, distribution from sklearn.datasets import load_diabetes, load_iris # synthcity absolute @@ -24,8 +24,13 @@ if not is_missing_goggle_deps: goggle_dependencies = {"dgl", "torch-scatter", "torch-sparse", "torch-geometric"} - installed = {pkg.key for pkg in pkg_resources.working_set} - is_missing_goggle_deps = len(goggle_dependencies - installed) > 0 + missing_deps = [] + for dep in goggle_dependencies: + try: + distribution(dep) + except PackageNotFoundError: + missing_deps.append(dep) + is_missing_goggle_deps = len(missing_deps) > 0 @pytest.mark.skipif(is_missing_goggle_deps, reason="Goggle dependencies not installed") diff --git a/tests/plugins/generic/test_great.py b/tests/plugins/generic/test_great.py index 26b97633..acaddc6f 100644 --- a/tests/plugins/generic/test_great.py +++ b/tests/plugins/generic/test_great.py @@ -1,6 +1,7 @@ # stdlib import os import random +import sys from datetime import datetime, timedelta # third party @@ -15,9 +16,14 @@ from synthcity.plugins import Plugin from synthcity.plugins.core.constraints import Constraints from synthcity.plugins.core.dataloader import GenericDataLoader -from synthcity.plugins.generic.plugin_great import plugin from synthcity.utils.serialization import load, save +if sys.version_info >= (3, 9): + # synthcity absolute + from synthcity.plugins.generic.plugin_great import plugin +else: + plugin = None + IN_GITHUB_ACTIONS: bool = os.getenv("GITHUB_ACTIONS") == "true" plugin_name = "great" @@ -28,34 +34,44 @@ } 
+@pytest.mark.skipif(sys.version_info < (3, 9), reason="GReaT requires Python 3.9+")
 @pytest.mark.parametrize("test_plugin", generate_fixtures(plugin_name, plugin))
 def test_plugin_sanity(test_plugin: Plugin) -> None:
     assert test_plugin is not None


+@pytest.mark.skipif(sys.version_info < (3, 9), reason="GReaT requires Python 3.9+")
 @pytest.mark.parametrize("test_plugin", generate_fixtures(plugin_name, plugin))
 def test_plugin_name(test_plugin: Plugin) -> None:
     assert test_plugin.name() == plugin_name


+@pytest.mark.skipif(sys.version_info < (3, 9), reason="GReaT requires Python 3.9+")
 @pytest.mark.parametrize("test_plugin", generate_fixtures(plugin_name, plugin))
 def test_plugin_type(test_plugin: Plugin) -> None:
     assert test_plugin.type() == "generic"


+@pytest.mark.skipif(sys.version_info < (3, 9), reason="GReaT requires Python 3.9+")
 @pytest.mark.parametrize("test_plugin", generate_fixtures(plugin_name, plugin))
 def test_plugin_hyperparams(test_plugin: Plugin) -> None:
     assert len(test_plugin.hyperparameter_space()) == 1


+@pytest.mark.skipif(sys.version_info < (3, 9), reason="GReaT requires Python 3.9+")
 @pytest.mark.parametrize(
     "test_plugin", generate_fixtures(plugin_name, plugin, plugin_args)
 )
+@pytest.mark.skipif(
+    IN_GITHUB_ACTIONS,
+    reason="GReaT generate required too much memory to reliably run in GitHub Actions",
+)
 def test_plugin_fit(test_plugin: Plugin) -> None:
     X, _ = load_iris(as_frame=True, return_X_y=True)
     test_plugin.fit(GenericDataLoader(X))


+@pytest.mark.skipif(sys.version_info < (3, 9), reason="GReaT requires Python 3.9+")
 @pytest.mark.skipif(
     IN_GITHUB_ACTIONS,
     reason="GReaT generate required too much memory to reliably run in GitHub Actions",
 )
@@ -92,6 +108,7 @@ def test_plugin_generate(test_plugin: Plugin, serialize: bool) -> None:


 @pytest.mark.slow
+@pytest.mark.skipif(sys.version_info < (3, 9), reason="GReaT requires Python 3.9+")
 @pytest.mark.skipif(
     IN_GITHUB_ACTIONS,
     reason="GReaT generate required too much memory to reliably run in GitHub Actions",
 )
@@ -102,7 +119,7 @@ def test_plugin_generate(test_plugin: Plugin, serialize: bool) -> None:
 def test_plugin_generate_constraints_great(test_plugin: Plugin) -> None:
     X, y = load_iris(as_frame=True, return_X_y=True)
     X["target"] = y
-    test_plugin.fit(GenericDataLoader(X))
+    test_plugin.fit(GenericDataLoader(X), device="cpu")

     constraints = Constraints(
         rules=[
@@ -134,6 +151,7 @@ def test_sample_hyperparams() -> None:
     assert plugin(**args) is not None


+@pytest.mark.skipif(sys.version_info < (3, 9), reason="GReaT requires Python 3.9+")
 @pytest.mark.skipif(
     IN_GITHUB_ACTIONS,
     reason="GReaT generate required too much memory to reliably run in GitHub Actions",
 )
@@ -168,6 +186,7 @@ def gen_datetime(min_year: int = 2000, max_year: int = datetime.now().year) -> d


 @pytest.mark.slow
+@pytest.mark.skipif(sys.version_info < (3, 9), reason="GReaT requires Python 3.9+")
 @pytest.mark.skipif(
     IN_GITHUB_ACTIONS,
     reason="GReaT generate required too much memory to reliably run in GitHub Actions",
 )
diff --git a/tests/plugins/privacy/test_aim.py b/tests/plugins/privacy/test_aim.py
index a950712f..cf1cad83 100644
--- a/tests/plugins/privacy/test_aim.py
+++ b/tests/plugins/privacy/test_aim.py
@@ -43,7 +43,7 @@ def test_plugin_name(test_plugin: Plugin) -> None:

 @pytest.mark.parametrize("test_plugin", generate_fixtures(plugin_name, plugin))
 def test_plugin_type(test_plugin: Plugin) -> None:
-    assert test_plugin.type() == "generic"
+    assert test_plugin.type() == "privacy"


 @pytest.mark.parametrize("test_plugin", generate_fixtures(plugin_name, plugin))
diff --git a/tests/plugins/test_plugin_add.py b/tests/plugins/test_plugin_add.py
index c3bc2c1b..494c54fe 100644
--- a/tests/plugins/test_plugin_add.py
+++ b/tests/plugins/test_plugin_add.py
@@ -1,6 +1,4 @@
 # stdlib
-import glob
-from pathlib import Path
 from typing import Any, List

 # third party
@@ -43,17 +41,12 @@ def _generate(self, count: int, syn_schema: Schema, **kwargs: Any) -> DataLoader

 def test_add_dummy_plugin() -> None:
     # get the list of plugins that are loaded
-    generators = Plugins()
+    generators = Plugins().reload()

-    # Get the list of plugins that come with the package
-    plugins_dir = Path.cwd() / "src/synthcity/plugins"
-    plugins_list = []
-    for plugin_type in plugins_dir.iterdir():
-        plugin_paths = glob.glob(str(plugins_dir / plugin_type / "plugin*.py"))
-        plugins_list.extend([Path(path).stem for path in plugin_paths])
+    available_plugins = Plugins().list()

     # Test that the new plugin is not in the list plugins in the package
-    assert "copy_data" not in plugins_list
+    assert "copy_data" not in available_plugins

     # Add the new plugin
     generators.add("copy_data", DummyCopyDataPlugin)
@@ -71,5 +64,5 @@ def test_add_dummy_plugin() -> None:
     gen.generate(count=10)

     # Test that the new plugin is now in the list of available plugins
-    available_plugins = Plugins().list()
+    available_plugins = Plugins(categories=["debug"]).list()
     assert "copy_data" in available_plugins
diff --git a/tests/plugins/test_plugin_serialization.py b/tests/plugins/test_plugin_serialization.py
index 34140785..a00bb969 100644
--- a/tests/plugins/test_plugin_serialization.py
+++ b/tests/plugins/test_plugin_serialization.py
@@ -1,4 +1,5 @@
 # stdlib
+import inspect
 from typing import Any

 # third party
@@ -17,6 +18,10 @@
 from synthcity.utils.serialization import load, save
 from synthcity.version import MAJOR_VERSION

+generic_plugins = Plugins(categories=["generic"]).list()
+privacy_plugins = Plugins(categories=["privacy"]).list()
+time_series_plugins = Plugins(categories=["time_series"]).list()
+

 def test_version() -> None:
     for plugin in Plugins().list():
@@ -74,7 +79,7 @@ def test_serialization_sanity() -> None:
     verify_serialization(syn_model, generate=True)


-@pytest.mark.parametrize("plugin", Plugins(categories=["privacy"]).list())
+@pytest.mark.parametrize("plugin", privacy_plugins)
 @pytest.mark.slow
 def test_serialization_privacy_plugins(plugin: str) -> None:
     generic_data = pd.DataFrame(load_iris()["data"])
@@ -89,7 +94,7 @@ def test_serialization_privacy_plugins(plugin: str) -> None:
     verify_serialization(syn_model, generate=True)


-@pytest.mark.parametrize("plugin", Plugins(categories=["generic"]).list())
+@pytest.mark.parametrize("plugin", generic_plugins)
 @pytest.mark.slow
 def test_serialization_generic_plugins(plugin: str) -> None:
     generic_data = pd.DataFrame(load_iris()["data"])
@@ -104,7 +109,7 @@ def test_serialization_generic_plugins(plugin: str) -> None:
     verify_serialization(syn_model, generate=True)


-@pytest.mark.parametrize("plugin", Plugins(categories=["time_series"]).list())
+@pytest.mark.parametrize("plugin", time_series_plugins)
 @pytest.mark.slow
 def test_serialization_ts_plugins(plugin: str) -> None:
     (
@@ -121,7 +126,14 @@
     )

     ts_plugins = Plugins(categories=["time_series"])
-    syn_model = ts_plugins.get(plugin, n_iter=10, strict=False)
+
+    # Use n_iter to limit the number of iterations for testing purposes, if possible
+    # TODO: consider removing this filter step and add n_iter to all models even if it's not used
+    test_params = {"n_iter": 10, "strict": False}
{"n_iter": 10, "strict": False} + accepted_params = inspect.signature(ts_plugins.get).parameters + filtered_kwargs = {k: v for k, v in test_params.items() if k in accepted_params} + + syn_model = ts_plugins.get(plugin, **filtered_kwargs) # pre-training verify_serialization(syn_model)