Skip to content

Commit

Permalink
skip great tests on python<3.9
Browse files Browse the repository at this point in the history
  • Loading branch information
robsdavis committed Sep 11, 2023
1 parent 41d1e19 commit 02eb8a0
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 31 deletions.
15 changes: 0 additions & 15 deletions src/synthcity/plugins/core/models/ts_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,24 +256,9 @@ def forward(

if self.use_horizon_condition:
# TODO: ADD error handling for len(temporal_data.shape) != 3 or len(observation_times.shape) != 2
# try:
temporal_data_merged = torch.cat(
[temporal_data, observation_times.unsqueeze(2)], dim=2
)
# print(
# 3333333333333333,
# temporal_data.shape, # when passing, = 3d tensor
# observation_times.shape, # when passing, 2d tensor
# )
# except Exception as e:
# print(temporal_data.shape, observation_times.shape)
# print(temporal_data, observation_times)
# print(
# 3333333333333333,
# temporal_data.shape,
# observation_times.shape,
# )
# raise e
else:
temporal_data_merged = temporal_data

Expand Down
15 changes: 0 additions & 15 deletions tests/metrics/test_performance.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,25 +385,10 @@ def test_evaluate_performance_time_series(
static_data=static_data,
outcome=outcome,
)
print(
111111111111, "temp: ", np.asarray(data.data.get("temporal_data")).shape
) # correct shape
print(
111111111111, "obs: ", np.asarray(data.data.get("observation_times")).shape
) # correct shape

test_plugin.fit(data)
data_gen = test_plugin.generate(100)

print(
222222222222, "temp", np.asarray(data_gen.data.get("temporal_data")).shape
) # correct shape
print(
222222222222,
"obs: ",
np.asarray(data_gen.data.get("observation_times")).shape, # correct shape
)

evaluator = evaluator_t(
task_type="time_series",
use_cache=False,
Expand Down
17 changes: 16 additions & 1 deletion tests/plugins/generic/test_great.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# stdlib
import os
import random
import sys
from datetime import datetime, timedelta

# third party
Expand All @@ -15,9 +16,14 @@
from synthcity.plugins import Plugin
from synthcity.plugins.core.constraints import Constraints
from synthcity.plugins.core.dataloader import GenericDataLoader
from synthcity.plugins.generic.plugin_great import plugin
from synthcity.utils.serialization import load, save

if sys.version_info >= (3, 9):
# synthcity absolute
from synthcity.plugins.generic.plugin_great import plugin
else:
plugin = None

IN_GITHUB_ACTIONS: bool = os.getenv("GITHUB_ACTIONS") == "true"

plugin_name = "great"
Expand All @@ -28,26 +34,31 @@
}


@pytest.mark.skipif(sys.version_info < (3, 9))
@pytest.mark.parametrize("test_plugin", generate_fixtures(plugin_name, plugin))
def test_plugin_sanity(test_plugin: Plugin) -> None:
assert test_plugin is not None


@pytest.mark.skipif(sys.version_info < (3, 9))
@pytest.mark.parametrize("test_plugin", generate_fixtures(plugin_name, plugin))
def test_plugin_name(test_plugin: Plugin) -> None:
assert test_plugin.name() == plugin_name


@pytest.mark.skipif(sys.version_info < (3, 9))
@pytest.mark.parametrize("test_plugin", generate_fixtures(plugin_name, plugin))
def test_plugin_type(test_plugin: Plugin) -> None:
assert test_plugin.type() == "generic"


@pytest.mark.skipif(sys.version_info < (3, 9))
@pytest.mark.parametrize("test_plugin", generate_fixtures(plugin_name, plugin))
def test_plugin_hyperparams(test_plugin: Plugin) -> None:
assert len(test_plugin.hyperparameter_space()) == 1


@pytest.mark.skipif(sys.version_info < (3, 9))
@pytest.mark.parametrize(
"test_plugin", generate_fixtures(plugin_name, plugin, plugin_args)
)
Expand All @@ -56,6 +67,7 @@ def test_plugin_fit(test_plugin: Plugin) -> None:
test_plugin.fit(GenericDataLoader(X))


@pytest.mark.skipif(sys.version_info < (3, 9))
@pytest.mark.skipif(
IN_GITHUB_ACTIONS,
reason="GReaT generate required too much memory to reliably run in GitHub Actions",
Expand Down Expand Up @@ -92,6 +104,7 @@ def test_plugin_generate(test_plugin: Plugin, serialize: bool) -> None:


@pytest.mark.slow
@pytest.mark.skipif(sys.version_info < (3, 9))
@pytest.mark.skipif(
IN_GITHUB_ACTIONS,
reason="GReaT generate required too much memory to reliably run in GitHub Actions",
Expand Down Expand Up @@ -134,6 +147,7 @@ def test_sample_hyperparams() -> None:
assert plugin(**args) is not None


@pytest.mark.skipif(sys.version_info < (3, 9))
@pytest.mark.skipif(
IN_GITHUB_ACTIONS,
reason="GReaT generate required too much memory to reliably run in GitHub Actions",
Expand Down Expand Up @@ -168,6 +182,7 @@ def gen_datetime(min_year: int = 2000, max_year: int = datetime.now().year) -> d


@pytest.mark.slow
@pytest.mark.skipif(sys.version_info < (3, 9))
@pytest.mark.skipif(
IN_GITHUB_ACTIONS,
reason="GReaT generate required too much memory to reliably run in GitHub Actions",
Expand Down

0 comments on commit 02eb8a0

Please sign in to comment.