Skip to content

Commit

Permalink
Fix full test errors (#255)
Browse files Browse the repository at this point in the history
* debugging

* working fast tests

* passing tests

* pin lifelines<0.28 as 0.28 does not support python 3.8

* debugging lifelines files error

* lifelines==0.27.7

* fix version pin

* revert to strict pin

* lifelines version constraints as generic as possible

* split core tests into fast and slow and increase timeout

* split slow tests into two

* update version
  • Loading branch information
robsdavis authored Feb 29, 2024
1 parent a7956c8 commit 10dbe78
Show file tree
Hide file tree
Showing 36 changed files with 120 additions and 57 deletions.
15 changes: 13 additions & 2 deletions .github/workflows/test_full.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,21 @@ jobs:
run: |
python -m pip install -U pip
pip install -r prereq.txt
- name: Test Core
- name: Test Core - slow part one
timeout-minutes: 1000
run: |
pip install .[testing]
pytest -vvvs --durations=50
pytest -vvvs --durations=50 -m "slow_1"
- name: Test Core - slow part two
timeout-minutes: 1000
run: |
pip install .[testing]
pytest -vvvs --durations=50 -m "slow_2"
- name: Test Core - fast
timeout-minutes: 1000
run: |
pip install .[testing]
pytest -vvvs --durations=50 -m "not slow"
- name: Test GOGGLE
run: |
pip install .[testing,goggle]
Expand Down
4 changes: 3 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ install_requires =
scikit-learn>=1.2
nflows>=0.14
numpy>=1.20, <1.24
lifelines>=0.27,!= 0.27.5
lifelines>=0.27,!= 0.27.5, <0.27.8
opacus>=1.3
decaf-synthetic-data>=0.1.6
optuna>=3.1
Expand Down Expand Up @@ -117,6 +117,8 @@ testpaths = tests
# Use pytest markers to select/deselect specific tests
markers =
slow: mark tests as slow (deselect with '-m "not slow"')
slow_1: mark tests as slow (deselect with '-m "not slow_1"')
slow_2: mark tests as slow (deselect with '-m "not slow_1"')

[devpi:upload]
# Options for the devpi: PyPI server and packaging tool
Expand Down
3 changes: 2 additions & 1 deletion src/synthcity/plugins/core/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -928,12 +928,13 @@ def unpack(self, as_numpy: bool = False, pad: bool = False) -> Any:
self.data["observation_times"],
self.data["outcome"],
)

if as_numpy:
longest_observation_seq = max([len(seq) for seq in temporal_data])
return (
np.asarray(static_data),
np.asarray(
pd.concat(temporal_data)
temporal_data
), # TODO: check this works with time series benchmarks
# masked array to handle variable length sequences
ma.vstack(
Expand Down
2 changes: 0 additions & 2 deletions src/synthcity/plugins/core/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -560,7 +560,6 @@ class PluginLoader:

@validate_arguments
def __init__(self, plugins: list, expected_type: Type, categories: list) -> None:
# self.reload()
global PLUGIN_CATEGORY_REGISTRY
PLUGIN_CATEGORY_REGISTRY = {cat: [] for cat in categories}
self._refresh()
Expand Down Expand Up @@ -639,7 +638,6 @@ def list(self) -> List[str]:
for plugin in all_plugins:
if self.get_type(plugin).type() in self._categories:
plugins.append(plugin)

return list(set(plugins))

def types(self) -> List[Type]:
Expand Down
2 changes: 2 additions & 0 deletions src/synthcity/plugins/privacy/plugin_dpgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,8 @@ class DPGANPlugin(Plugin):
>>>
>>> plugin.generate(50)
Note: There is a known issue with the training step for training GANs with conditionals with dp_enabled set to True, as is the case for DPGAN.
"""

@validate_arguments(config=dict(arbitrary_types_allowed=True))
Expand Down
2 changes: 1 addition & 1 deletion src/synthcity/version.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "0.2.9"
__version__ = "0.2.10"

MAJOR_VERSION = ".".join(__version__.split(".")[:-1])
PATCH_VERSION = __version__.split(".")[-1]
1 change: 1 addition & 0 deletions tests/metrics/test_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ def test_detect_synth_timeseries(test_plugin: Plugin, evaluator_t: Type) -> None
assert evaluator.direction() == "minimize"


@pytest.mark.slow_1
@pytest.mark.slow
def test_image_support_detection() -> None:
dataset = datasets.MNIST(".", download=True)
Expand Down
6 changes: 6 additions & 0 deletions tests/metrics/test_performance.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ def test_evaluate_performance_classifier(
@pytest.mark.xfail
@pytest.mark.skipif(sys.platform != "linux", reason="Linux only for faster results")
@pytest.mark.skipif(sys.version_info < (3, 9), reason="requires python3.9 or higher")
@pytest.mark.slow_1
@pytest.mark.slow
def test_evaluate_feature_importance_rank_dist_clf(
distance: str, test_plugin: Plugin
Expand Down Expand Up @@ -183,6 +184,7 @@ def test_evaluate_performance_regression(
@pytest.mark.xfail
@pytest.mark.skipif(sys.platform != "linux", reason="Linux only for faster results")
@pytest.mark.skipif(sys.version_info < (3, 9), reason="requires python3.9 or higher")
@pytest.mark.slow_1
@pytest.mark.slow
def test_evaluate_feature_importance_rank_dist_reg(
distance: str, test_plugin: Plugin
Expand Down Expand Up @@ -211,6 +213,7 @@ def test_evaluate_feature_importance_rank_dist_reg(
assert score["pvalue"] > 0


@pytest.mark.slow_1
@pytest.mark.slow
@pytest.mark.parametrize("test_plugin", [Plugins().get("marginal_distributions")])
@pytest.mark.parametrize(
Expand Down Expand Up @@ -296,6 +299,7 @@ def test_evaluate_performance_survival_analysis(
@pytest.mark.xfail
@pytest.mark.skipif(sys.platform != "linux", reason="Linux only for faster results")
@pytest.mark.skipif(sys.version_info < (3, 9), reason="requires python3.9 or higher")
@pytest.mark.slow_1
@pytest.mark.slow
def test_evaluate_feature_importance_rank_dist_surv(
distance: str, test_plugin: Plugin
Expand Down Expand Up @@ -362,6 +366,7 @@ def test_evaluate_performance_custom_labels(
assert "syn_ood" in good_score


@pytest.mark.slow_1
@pytest.mark.slow
@pytest.mark.parametrize("test_plugin", [Plugins().get("timegan")])
@pytest.mark.parametrize(
Expand Down Expand Up @@ -472,6 +477,7 @@ def test_evaluate_performance_time_series_survival(
assert def_score == good_score["syn_id.c_index"] - good_score["syn_id.brier_score"]


@pytest.mark.slow_1
@pytest.mark.slow
def test_image_support_perf() -> None:
dataset = datasets.MNIST(".", download=True)
Expand Down
1 change: 1 addition & 0 deletions tests/plugins/core/models/test_tabular_gan.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ def test_gan_generation_with_early_stopping(patience_metric: Tuple[str, str]) ->
assert generated.shape == (10, X.shape[1])


@pytest.mark.slow_1
@pytest.mark.slow
def test_gan_sampling_adjustment() -> None:
X = get_airfoil_dataset()
Expand Down
1 change: 1 addition & 0 deletions tests/plugins/core/models/test_ts_gan.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ def test_ts_gan_generation(source: Any) -> None:
assert observation_times_gen.shape == (10, temporal.shape[1])


@pytest.mark.slow_1
@pytest.mark.slow
@pytest.mark.parametrize("source", [GoogleStocksDataloader])
def test_ts_gan_generation_schema(source: Any) -> None:
Expand Down
3 changes: 3 additions & 0 deletions tests/plugins/core/models/test_ts_tabular_gan.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def test_network_config() -> None:
assert net.model.embedding_penalty == 2


@pytest.mark.slow_1
@pytest.mark.slow
@pytest.mark.parametrize("source", [SineDataloader, GoogleStocksDataloader])
def test_ts_gan_generation(source: Any) -> None:
Expand All @@ -86,6 +87,7 @@ def test_ts_gan_generation(source: Any) -> None:
)


@pytest.mark.slow_1
@pytest.mark.slow
@pytest.mark.parametrize("source", [GoogleStocksDataloader])
def test_ts_gan_generation_schema(source: Any) -> None:
Expand Down Expand Up @@ -118,6 +120,7 @@ def test_ts_gan_generation_schema(source: Any) -> None:
assert reference_schema.as_constraints().filter(seq_df).sum() > 0


@pytest.mark.slow_1
@pytest.mark.slow
@pytest.mark.parametrize("source", [SineDataloader, GoogleStocksDataloader])
def test_ts_tabular_gan_conditional(source: Any) -> None:
Expand Down
1 change: 1 addition & 0 deletions tests/plugins/core/models/test_ts_tabular_vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ def test_ts_vae_generation(source: Any) -> None:
)


@pytest.mark.slow_1
@pytest.mark.slow
@pytest.mark.parametrize("source", [GoogleStocksDataloader])
def test_ts_vae_generation_schema(source: Any) -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def test_train_prediction_coxph(rnn_type: str, output_type: str) -> None:
assert score["clf"]["c_index"][0] > 0.5


@pytest.mark.slow_1
@pytest.mark.slow
def test_hyperparam_search() -> None:
static, temporal, observation_times, outcome = PBCDataloader(as_numpy=True).load()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def test_train_prediction_dyn_deephit(rnn_type: str, output_type: str) -> None:
assert score["clf"]["c_index"][0] > 0.5


@pytest.mark.slow_1
@pytest.mark.slow
def test_hyperparam_search() -> None:
static, temporal, observation_times, outcome = PBCDataloader(as_numpy=True).load()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def test_train_prediction(emb_rnn_type: str) -> None:
assert score["clf"]["c_index"][0] > 0.5


@pytest.mark.slow_1
@pytest.mark.slow
def test_hyperparam_search() -> None:
static, temporal, observation_times, outcome = PBCDataloader(as_numpy=True).load()
Expand Down
1 change: 1 addition & 0 deletions tests/plugins/domain_adaptation/test_radialgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ def test_sample_hyperparams() -> None:
assert plugin(**args) is not None


@pytest.mark.slow_1
@pytest.mark.slow
def test_eval_performance_radialgan() -> None:
results = []
Expand Down
2 changes: 2 additions & 0 deletions tests/plugins/generic/test_arf.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ def test_sample_hyperparams() -> None:
assert plugin(**args) is not None


@pytest.mark.slow_1
@pytest.mark.slow
@pytest.mark.parametrize("compress_dataset", [True, False])
def test_eval_performance_arf(compress_dataset: bool) -> None:
Expand Down Expand Up @@ -151,6 +152,7 @@ def gen_datetime(min_year: int = 2000, max_year: int = datetime.now().year) -> d
return start + (end - start) * random.random()


@pytest.mark.slow_1
@pytest.mark.slow
def test_plugin_encoding() -> None:
assert plugin is not None
Expand Down
2 changes: 2 additions & 0 deletions tests/plugins/generic/test_ctgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ def test_sample_hyperparams() -> None:
assert plugin(**args) is not None


@pytest.mark.slow_1
@pytest.mark.slow
@pytest.mark.parametrize("compress_dataset", [True, False])
def test_eval_performance_ctgan(compress_dataset: bool) -> None:
Expand Down Expand Up @@ -169,6 +170,7 @@ def gen_datetime(min_year: int = 2000, max_year: int = datetime.now().year) -> d
return start + (end - start) * random.random()


@pytest.mark.slow_1
@pytest.mark.slow
def test_plugin_encoding() -> None:
data = [[gen_datetime(), i % 2 == 0, i] for i in range(1000)]
Expand Down
1 change: 1 addition & 0 deletions tests/plugins/generic/test_ddpm.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ def test_sample_hyperparams() -> None:
assert plugin(**args) is not None


@pytest.mark.slow_1
@pytest.mark.slow
@pytest.mark.parametrize("compress_dataset", [True, False])
def test_eval_performance_ddpm(compress_dataset: bool) -> None:
Expand Down
41 changes: 18 additions & 23 deletions tests/plugins/generic/test_goggle.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from sklearn.datasets import load_diabetes, load_iris

# synthcity absolute
from synthcity.metrics.eval import PerformanceEvaluatorXGB
from synthcity.metrics.eval import AlphaPrecision
from synthcity.plugins import Plugin
from synthcity.plugins.core.constraints import Constraints
from synthcity.plugins.core.dataloader import GenericDataLoader
Expand Down Expand Up @@ -149,39 +149,34 @@ def test_sample_hyperparams() -> None:
assert plugin(**args) is not None


# TODO: Known issue goggle seems to have a performance issue.
# Testing fidelity instead. Also need to test more architectures
@pytest.mark.skipif(is_missing_goggle_deps, reason="Goggle dependencies not installed")
@pytest.mark.slow_2
@pytest.mark.slow
@pytest.mark.parametrize(
"compress_dataset,decoder_arch",
[
(True, "het"),
(False, "het"),
(True, "gcn"),
(False, "gcn"),
(True, "sage"),
(False, "sage"),
],
)
def test_eval_performance_goggle(compress_dataset: bool, decoder_arch: str) -> None:
def test_eval_fidelity_goggle(compress_dataset: bool, decoder_arch: str) -> None:
results = []

Xraw, y = load_diabetes(return_X_y=True, as_frame=True)
Xraw, y = load_iris(return_X_y=True, as_frame=True)
Xraw["target"] = y
X = GenericDataLoader(Xraw)

assert plugin is not None
for retry in range(2):
for retry in range(3):
test_plugin = plugin(
n_iter=5000,
compress_dataset=compress_dataset,
decoder_arch=decoder_arch,
encoder_dim=32,
encoder_l=4,
decoder_dim=32,
decoder_l=4,
data_encoder_max_clusters=20,
compress_dataset=False,
decoder_arch="gcn",
random_state=retry,
)
evaluator = PerformanceEvaluatorXGB()
evaluator = AlphaPrecision()

test_plugin.fit(X)
X_syn = test_plugin.generate()

results.append(evaluator.evaluate(X, X_syn)["syn_id"])
X_syn = test_plugin.generate(count=len(X), random_state=retry)
eval_results = evaluator.evaluate(X, X_syn)
results.append(eval_results["authenticity_OC"])

assert np.mean(results) > 0.7
2 changes: 2 additions & 0 deletions tests/plugins/generic/test_great.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ def test_plugin_generate(test_plugin: Plugin, serialize: bool) -> None:
assert (X_gen1.numpy() != X_gen3.numpy()).any()


@pytest.mark.slow_2
@pytest.mark.slow
@pytest.mark.skipif(sys.version_info < (3, 9), reason="GReaT requires Python 3.9+")
@pytest.mark.skipif(
Expand Down Expand Up @@ -185,6 +186,7 @@ def gen_datetime(min_year: int = 2000, max_year: int = datetime.now().year) -> d
return start + (end - start) * random.random()


@pytest.mark.slow_2
@pytest.mark.slow
@pytest.mark.skipif(sys.version_info < (3, 9), reason="GReaT requires Python 3.9+")
@pytest.mark.skipif(
Expand Down
1 change: 1 addition & 0 deletions tests/plugins/generic/test_nflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ def test_sample_hyperparams() -> None:
assert plugin(**args) is not None


@pytest.mark.slow_2
@pytest.mark.slow
@pytest.mark.parametrize("compress_dataset", [True, False])
def test_eval_performance_nflow(compress_dataset: bool) -> None:
Expand Down
1 change: 1 addition & 0 deletions tests/plugins/generic/test_rtvae.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ def test_sample_hyperparams() -> None:
assert plugin(**args) is not None


@pytest.mark.slow_2
@pytest.mark.slow
def test_eval_performance_rtvae() -> None:
results = []
Expand Down
1 change: 1 addition & 0 deletions tests/plugins/generic/test_tvae.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ def test_sample_hyperparams() -> None:
assert plugin(**args) is not None


@pytest.mark.slow_2
@pytest.mark.slow
def test_eval_performance_tvae() -> None:
results = []
Expand Down
2 changes: 2 additions & 0 deletions tests/plugins/images/test_image_adsgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def test_plugin_generate() -> None:
assert len(X_gen) == 50


@pytest.mark.slow_2
@pytest.mark.slow
def test_plugin_generate_with_conditional() -> None:
test_plugin = plugin(n_iter=10, n_units_latent=13)
Expand All @@ -71,6 +72,7 @@ def test_plugin_generate_with_conditional() -> None:
assert len(X_gen) == 50


@pytest.mark.slow_2
@pytest.mark.slow
def test_plugin_generate_with_stop_conditional() -> None:
test_plugin = plugin(n_iter=10, n_units_latent=13, n_iter_print=2)
Expand Down
Loading

0 comments on commit 10dbe78

Please sign in to comment.