Skip to content

Commit

Permalink
Remove checks for subsets
Browse files Browse the repository at this point in the history
  • Loading branch information
KarelZe committed Nov 26, 2023
1 parent b696120 commit 1f8c20b
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 21 deletions.
10 changes: 2 additions & 8 deletions src/tclf/classical_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,6 @@
"nan",
)

allowed_subsets = ("all", "ex", "best")


class ClassicalClassifier(ClassifierMixin, BaseEstimator):
"""ClassicalClassifier implements several trade classification rules.
Expand Down Expand Up @@ -88,7 +86,7 @@ def __init__(
>>> pred = clf.predict_proba(X)
Args:
layers (List[ tuple[ str, str, ] ]): Layers of classical rule.
layers (List[ tuple[str, str] ]): Layers of classical rule.
features (List[str] | None, optional): List of feature names in order of columns. Required to match columns in feature matrix with label. Can be `None`, if `pd.DataFrame` is passed. Defaults to None.
random_state (float | None, optional): random seed. Defaults to 42.
strategy (Literal["random", "const"], optional): Strategy to fill unclassfied. Randomly with uniform probability or with constant 0. Defaults to "random".
Expand Down Expand Up @@ -450,11 +448,7 @@ def fit(
f"Expected {len(self.columns_)} columns, got {X.shape[1]}."
)

for func_str, subset in self.layers:
if subset not in allowed_subsets:
raise ValueError(
f"Unknown subset: {subset}, expected one of {allowed_subsets}."
)
for func_str, _ in self.layers:
if func_str not in allowed_func_str:
raise ValueError(
f"Unknown function string: {func_str},"
Expand Down
13 changes: 0 additions & 13 deletions tests/test_classical_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,19 +91,6 @@ def test_invalid_func(self) -> None:
with pytest.raises(ValueError, match=r"Unknown function string"):
classifier.fit(self.x_train, self.y_train)

def test_invalid_subset(self) -> None:
"""Test, if only valid subset strings can be passed.
An exception should be raised for invalid subsets.
Test for 'bar', which is no valid subset.
"""
classifier = ClassicalClassifier(
layers=[("tick", "bar")],
random_state=42,
)
with pytest.raises(ValueError, match=r"Unknown subset"):
classifier.fit(self.x_train, self.y_train)

def test_invalid_col_length(self) -> None:
"""Test, if only valid column length can be passed.
Expand Down

0 comments on commit 1f8c20b

Please sign in to comment.