From 4ddf7aa8fe5e399207c0fdbbe01ccaa3fae173fd Mon Sep 17 00:00:00 2001 From: Markus Bilz Date: Mon, 20 Nov 2023 20:37:02 +0100 Subject: [PATCH] Add simplified type hints --- src/tclf/classical_classifier.py | 46 ++++++++++++++---------------- src/tclf/types.py | 9 ++++++ tests/templates.py | 7 +---- tests/test_classical_classifier.py | 1 - 4 files changed, 32 insertions(+), 31 deletions(-) create mode 100644 src/tclf/types.py diff --git a/src/tclf/classical_classifier.py b/src/tclf/classical_classifier.py index ca681fd..61fac3f 100644 --- a/src/tclf/classical_classifier.py +++ b/src/tclf/classical_classifier.py @@ -15,6 +15,8 @@ from sklearn.utils.multiclass import check_classification_targets from sklearn.utils.validation import _check_sample_weight, check_is_fitted, check_X_y +from tclf.types import ArrayLike, MatrixLike + allowed_func_str = ( "tick", "rev_tick", @@ -37,34 +39,30 @@ class ClassicalClassifier(ClassifierMixin, BaseEstimator): """ClassicalClassifier implements several trade classification rules. Including: - * Tick test - * Reverse tick test - * Quote rule - * LR algorithm - * LR algorithm with reverse tick test - * EMO algorithm - * EMO algorithm with reverse tick test - * CLNV algorithm - * CLNV algorithm with reverse tick test - * Trade size rule - * Depth rule - * nan + Tick test, + Reverse tick test, + Quote rule, + LR algorithm, + EMO algorithm, + CLNV algorithm, + Trade size rule, + Depth rule, + and nan Args: - ---- - ClassifierMixin (_type_): ClassifierMixin - BaseEstimator (_type_): Baseestimator + classifier mixin (ClassifierMixin): mixin for classifier functionality, such as `predict_proba()` + base estimator (BaseEstimator): base estimator for basic functionality, such as `transform()` """ def __init__( self, - *, layers: list[ tuple[ str, str, ] ], + *, features: list[str] | None = None, random_state: float | None = 42, strategy: Literal["random", "const"] = "random", @@ -383,15 +381,15 @@ def _nan(self, *args: Any) -> npt.NDArray: def fit( self, - X: npt.NDArray | pd.DataFrame, - y: npt.NDArray | pd.Series, + X: MatrixLike, + y: ArrayLike, sample_weight: npt.NDArray | None = None, ) -> ClassicalClassifier: """Fit the classifier. Args: - X (npt.NDArray | pd.DataFrame): features - y (npt.NDArray | pd.Series): ground truth (ignored) + X (MatrixLike): features + y (ArrayLike): ground truth (ignored) sample_weight (npt.NDArray | None, optional): Sample weights. Defaults to None. Raises: @@ -458,11 +456,11 @@ def fit( return self - def predict(self, X: npt.NDArray | pd.DataFrame) -> npt.NDArray: + def predict(self, X: MatrixLike) -> npt.NDArray: """Perform classification on test vectors `X`. Args: - X (npt.NDArray | pd.DataFrame): feature matrix. + X (MatrixLike): feature matrix. Returns: npt.NDArray: Predicted traget values for X. @@ -498,7 +496,7 @@ def predict(self, X: npt.NDArray | pd.DataFrame) -> npt.NDArray: del self.X_ return pred - def predict_proba(self, X: npt.NDArray | pd.DataFrame) -> npt.NDArray: + def predict_proba(self, X: MatrixLike) -> npt.NDArray: """Predict class probabilities for X. Probabilities are either 0 or 1 depending on the class. @@ -506,7 +504,7 @@ def predict_proba(self, X: npt.NDArray | pd.DataFrame) -> npt.NDArray: For strategy 'constant' probabilities are (0.5,0.5) for unclassified classes. Args: - X (npt.NDArray | pd.DataFrame): feature matrix + X (MatrixLike): feature matrix Returns: npt.NDArray: probabilities diff --git a/src/tclf/types.py b/src/tclf/types.py new file mode 100644 index 0000000..2fdaa74 --- /dev/null +++ b/src/tclf/types.py @@ -0,0 +1,9 @@ +"""Common type hints.""" + +import numpy as np +import numpy.typing as npt +import pandas as pd +from scipy.sparse import spmatrix + +MatrixLike = np.ndarray | pd.DataFrame | spmatrix +ArrayLike = npt.ArrayLike | pd.Series diff --git a/tests/templates.py b/tests/templates.py index 596acbb..669cf87 100644 --- a/tests/templates.py +++ b/tests/templates.py @@ -12,12 +12,7 @@ class ClassifierMixin: - """Perform automated tests for Classifiers. - - Args: - ---- - unittest (_type_): unittest module - """ + """Perform automated tests for Classifiers.""" clf: BaseEstimator x_test: pd.DataFrame diff --git a/tests/test_classical_classifier.py b/tests/test_classical_classifier.py index 3674d9c..4b84837 100644 --- a/tests/test_classical_classifier.py +++ b/tests/test_classical_classifier.py @@ -16,7 +16,6 @@ class TestClassicalClassifier(ClassifierMixin): """Perform automated tests for ClassicalClassifier. Args: - ---- unittest (_type_): unittest module """