Skip to content

Commit

Permalink
Enable estimator checks 🚀 (#42)
Browse files Browse the repository at this point in the history
* Enable sklearn estimator checks

* Fix docstring
  • Loading branch information
KarelZe authored Dec 22, 2023
1 parent 8bcd32d commit 2f0f7e9
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 15 deletions.
27 changes: 19 additions & 8 deletions src/tclf/classical_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,8 @@ def __init__(
str,
str,
]
],
]
| None = None,
*,
features: list[str] | None = None,
random_state: float | None = 42,
Expand All @@ -84,7 +85,7 @@ def __init__(
>>> pred = clf.predict_proba(X)
Args:
layers (List[ tuple[str, str] ]): Layers of classical rule.
layers (List[tuple[str, str]]): Layers of classical rule. Defaults to None, which results in classification by 'strategy' parameter.
features (List[str] | None, optional): List of feature names in order of columns. Required to match columns in feature matrix with label. Can be `None`, if `pd.DataFrame` is passed. Defaults to None.
random_state (float | None, optional): random seed. Defaults to 42.
strategy (Literal["random", "const"], optional): Strategy to fill unclassified. Randomly with uniform probability or with constant 0. Defaults to "random".
Expand All @@ -94,18 +95,26 @@ def __init__(
self.features = features
self.strategy = strategy

def _more_tags(self) -> dict[str, bool]:
def _more_tags(self) -> dict[str, bool | dict[str, str]]:
"""Set tags for sklearn.
See: https://scikit-learn.org/stable/developers/develop.html#estimator-tags
"""
# FIXME: Try enabling _skip_test again. Skip tests, as prediction is not
# invariant and parameters immutable.
return {
"allow_nan": True,
"binary_only": True,
"_skip_test": True,
"requires_y": False,
"poor_score": True,
"_xfail_checks": {
"check_classifiers_classes": "Disabled due to partly random classification.",
"check_classifiers_train": "No check, as unsupervised classifier.",
"check_classifiers_one_label": "Disabled due to partly random classification.",
"check_methods_subset_invariance": "No check, as unsupervised classifier.",
"check_methods_sample_order_invariance": "No check, as unsupervised classifier.",
"check_supervised_y_no_nan": "No check, as unsupervised classifier.",
"check_supervised_y_2d": "No check, as unsupervised classifier.",
"check_classifiers_regression_target": "No check, as unsupervised classifier.",
},
}

def _tick(self, subset: str) -> npt.NDArray:
Expand Down Expand Up @@ -429,6 +438,7 @@ def fit(

X = self._validate_data(
X,
y="no_validation",
dtype=[np.float64, np.float32],
accept_sparse=False,
force_all_finite=False,
Expand All @@ -445,7 +455,8 @@ def fit(
f"Expected {len(self.columns_)} columns, got {X.shape[1]}."
)

for func_str, _ in self.layers:
self._layers = self.layers if self.layers is not None else []
for func_str, _ in self._layers:
if func_str not in allowed_func_str:
raise ValueError(
f"Unknown function string: {func_str},"
Expand Down Expand Up @@ -476,7 +487,7 @@ def predict(self, X: MatrixLike) -> npt.NDArray:
self.X_ = pd.DataFrame(data=X, columns=self.columns_)
pred = np.full(shape=(X.shape[0],), fill_value=np.nan)

for func_str, subset in self.layers:
for func_str, subset in self._layers:
func = self.func_mapping_[func_str]
pred = np.where(
np.isnan(pred),
Expand Down
17 changes: 10 additions & 7 deletions tests/test_classical_classifier.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
"""Tests for the classical classifier."""

from typing import Callable

import numpy as np
import pandas as pd
import pytest
from numpy.testing import assert_allclose
from sklearn.utils.estimator_checks import check_estimator
from sklearn.base import BaseEstimator
from sklearn.utils.estimator_checks import parametrize_with_checks
from sklearn.utils.validation import check_is_fitted

from tclf.classical_classifier import ClassicalClassifier
Expand Down Expand Up @@ -74,14 +77,14 @@ def clf(self, x_train: pd.DataFrame) -> ClassicalClassifier:
Returns:
ClassicalClassifier: fitted clf
"""
return ClassicalClassifier(
layers=[("nan", "ex")],
random_state=7,
).fit(x_train[["ask_best", "bid_best"]])
return ClassicalClassifier().fit(x_train[["ask_best", "bid_best"]])

def test_sklearn_compatibility(self, clf: ClassicalClassifier) -> None:
@parametrize_with_checks([ClassicalClassifier()])
def test_sklearn_compatibility(
self, estimator: BaseEstimator, check: Callable
) -> None:
"""Test, if classifier is compatible with sklearn."""
check_estimator(clf)
check(estimator)

def test_shapes(
self, clf: ClassicalClassifier, x_test: pd.DataFrame, y_test: pd.Series
Expand Down

0 comments on commit 2f0f7e9

Please sign in to comment.