Skip to content

Commit

Permalink
tests: update resource setup in tests
Browse files Browse the repository at this point in the history
  • Loading branch information
KarelZe committed Dec 2, 2024
1 parent 4fb3b95 commit bd295f3
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 27 deletions.
15 changes: 8 additions & 7 deletions tests/test_classical_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,23 +20,24 @@ class TestClassicalClassifier(ClassifierMixin):
unittest (_type_): unittest module
"""

def setup(self) -> None:
@classmethod
def setup_class(cls):
"""Set up basic classifier and data.
Prepares inputs and expected outputs for testing.
"""
self.x_train = pd.DataFrame(
cls.x_train = pd.DataFrame(
[[1, 2], [3, 4], [1, 2], [3, 4]], columns=["BEST_ASK", "BEST_BID"]
)
self.y_train = pd.Series([1, 1, -1, -1])
self.x_test = pd.DataFrame(
cls.y_train = pd.Series([1, 1, -1, -1])
cls.x_test = pd.DataFrame(
[[1, 2], [3, 4], [1, 2], [3, 4]], columns=["BEST_ASK", "BEST_BID"]
)
self.y_test = pd.Series([1, -1, 1, -1])
self.clf = ClassicalClassifier(
cls.y_test = pd.Series([1, -1, 1, -1])
cls.clf = ClassicalClassifier(
layers=[("nan", "ex")],
random_state=7,
).fit(self.x_train, self.y_train)
).fit(cls.x_train, cls.y_train)

def test_random_state(self) -> None:
"""Test, if random state is correctly set.
Expand Down
27 changes: 14 additions & 13 deletions tests/test_fttransformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,35 +32,36 @@ class TestFTTransformer(NeuralNetTestsMixin):
NeuralNetTestsMixin (neural net mixin): mixin
"""

def setup(self) -> None:
@classmethod
def setup_class(cls):
"""Set up basic network and data.
Prepares inputs and expected outputs for testing.
"""
self.num_features_cont = 5
self.num_features_cat = 1
self.cat_cardinalities = [2]
self.batch_size = 64
cls.num_features_cont = 5
cls.num_features_cat = 1
cls.cat_cardinalities = [2]
cls.batch_size = 64

set_seed()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

self.x_cat = torch.randint(0, 1, (self.batch_size, self.num_features_cat)).to(
cls.x_cat = torch.randint(0, 1, (cls.batch_size, cls.num_features_cat)).to(
device
)
self.x_cont = (
torch.randn(self.batch_size, self.num_features_cont).float().to(device)
cls.x_cont = (
torch.randn(cls.batch_size, cls.num_features_cont).float().to(device)
)
self.expected_outputs = (
torch.randint(0, 1, (self.batch_size, 1)).float().to(device)
cls.expected_outputs = (
torch.randint(0, 1, (cls.batch_size, 1)).float().to(device)
)

# https://github.com/Yura52/rtdl/blob/main/rtdl/modules.py

params_feature_tokenizer: Dict[str, Any] = {
"num_continous": self.num_features_cont,
"cat_cardinalities": self.cat_cardinalities,
"num_continous": cls.num_features_cont,
"cat_cardinalities": cls.cat_cardinalities,
"d_token": 96,
}
feature_tokenizer = FeatureTokenizer(**params_feature_tokenizer)
Expand Down Expand Up @@ -89,7 +90,7 @@ def setup(self) -> None:

transformer = Transformer(**params_transformer)

self.net = FTTransformer(feature_tokenizer, transformer).to(device)
cls.net = FTTransformer(feature_tokenizer, transformer).to(device)

def test_numerical_feature_tokenizer(self) -> None:
"""Test numerical feature tokenizer.
Expand Down
15 changes: 8 additions & 7 deletions tests/test_objective.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,23 +24,24 @@ class TestObjectives:
metaclass (_type_, optional): parent. Defaults to abc.ABCMeta.
"""

def setup(self) -> None:
@classmethod
def setup_class(cls):
"""Set up basic data.
Construct feature matrix and target.
"""
self._old_cwd = Path.cwd()
cls._old_cwd = Path.cwd()
start = dt.datetime(2020, 1, 1).replace(tzinfo=dt.timezone.utc)
end = dt.datetime(2021, 12, 31).replace(tzinfo=dt.timezone.utc)
index = pd.date_range(start=start, end=end, freq="15min")

# make 1 const feature and 1 non-const feature, as catboost requires non-const
self._x_train = pd.DataFrame(data={"feature_1": 1}, index=index)
cls._x_train = pd.DataFrame(data={"feature_1": 1}, index=index)
rng = np.random.default_rng()
self._x_train["feature_2"] = rng.integers(1, 6, self._x_train.shape[0])
self._y_train = self._x_train["feature_2"]
self._x_val = self._x_train.copy()
self._y_val = self._y_train.copy()
cls._x_train["feature_2"] = rng.integers(1, 6, cls._x_train.shape[0])
cls._y_train = cls._x_train["feature_2"]
cls._x_val = cls._x_train.copy()
cls._y_val = cls._y_train.copy()

def test_classical_objective(self) -> None:
"""Test if classical objective returns a valid value.
Expand Down

0 comments on commit bd295f3

Please sign in to comment.