Skip to content

Commit

Permalink
Add simplified type hints
Browse files Browse the repository at this point in the history
  • Loading branch information
KarelZe committed Nov 20, 2023
1 parent 584390d commit 4ddf7aa
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 31 deletions.
46 changes: 22 additions & 24 deletions src/tclf/classical_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
from sklearn.utils.multiclass import check_classification_targets
from sklearn.utils.validation import _check_sample_weight, check_is_fitted, check_X_y

from tclf.types import ArrayLike, MatrixLike

allowed_func_str = (
"tick",
"rev_tick",
Expand All @@ -37,34 +39,30 @@ class ClassicalClassifier(ClassifierMixin, BaseEstimator):
"""ClassicalClassifier implements several trade classification rules.
Including:
* Tick test
* Reverse tick test
* Quote rule
* LR algorithm
* LR algorithm with reverse tick test
* EMO algorithm
* EMO algorithm with reverse tick test
* CLNV algorithm
* CLNV algorithm with reverse tick test
* Trade size rule
* Depth rule
* nan
Tick test,
Reverse tick test,
Quote rule,
LR algorithm,
EMO algorithm,
CLNV algorithm,
Trade size rule,
Depth rule,
and nan
Args:
----
ClassifierMixin (_type_): ClassifierMixin
BaseEstimator (_type_): Baseestimator
classifier mixin (ClassifierMixin): mixin for classifier functionality, such as `predict_proba()`
base estimator (BaseEstimator): base estimator for basic functionality, such as `transform()`
"""

def __init__(
self,
*,
layers: list[
tuple[
str,
str,
]
],
*,
features: list[str] | None = None,
random_state: float | None = 42,
strategy: Literal["random", "const"] = "random",
Expand Down Expand Up @@ -383,15 +381,15 @@ def _nan(self, *args: Any) -> npt.NDArray:

def fit(
self,
X: npt.NDArray | pd.DataFrame,
y: npt.NDArray | pd.Series,
X: MatrixLike,
y: ArrayLike,
sample_weight: npt.NDArray | None = None,
) -> ClassicalClassifier:
"""Fit the classifier.
Args:
X (npt.NDArray | pd.DataFrame): features
y (npt.NDArray | pd.Series): ground truth (ignored)
X (MatrixLike): features
y (ArrayLike): ground truth (ignored)
sample_weight (npt.NDArray | None, optional): Sample weights. Defaults to None.
Raises:
Expand Down Expand Up @@ -458,11 +456,11 @@ def fit(

return self

def predict(self, X: npt.NDArray | pd.DataFrame) -> npt.NDArray:
def predict(self, X: MatrixLike) -> npt.NDArray:
"""Perform classification on test vectors `X`.
Args:
X (npt.NDArray | pd.DataFrame): feature matrix.
X (MatrixLike): feature matrix.
Returns:
npt.NDArray: Predicted traget values for X.
Expand Down Expand Up @@ -498,15 +496,15 @@ def predict(self, X: npt.NDArray | pd.DataFrame) -> npt.NDArray:
del self.X_
return pred

def predict_proba(self, X: npt.NDArray | pd.DataFrame) -> npt.NDArray:
def predict_proba(self, X: MatrixLike) -> npt.NDArray:
"""Predict class probabilities for X.
Probabilities are either 0 or 1 depending on the class.
For strategy 'constant' probabilities are (0.5,0.5) for unclassified classes.
Args:
X (npt.NDArray | pd.DataFrame): feature matrix
X (MatrixLike): feature matrix
Returns:
npt.NDArray: probabilities
Expand Down
9 changes: 9 additions & 0 deletions src/tclf/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
"""Common type hints."""

import numpy as np
import numpy.typing as npt
import pandas as pd
from scipy.sparse import spmatrix

MatrixLike = np.ndarray | pd.DataFrame | spmatrix
ArrayLike = npt.ArrayLike | pd.Series
7 changes: 1 addition & 6 deletions tests/templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,7 @@


class ClassifierMixin:
"""Perform automated tests for Classifiers.
Args:
----
unittest (_type_): unittest module
"""
"""Perform automated tests for Classifiers."""

clf: BaseEstimator
x_test: pd.DataFrame
Expand Down
1 change: 0 additions & 1 deletion tests/test_classical_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ class TestClassicalClassifier(ClassifierMixin):
"""Perform automated tests for ClassicalClassifier.
Args:
----
unittest (_type_): unittest module
"""

Expand Down

0 comments on commit 4ddf7aa

Please sign in to comment.