Skip to content

Commit

Permalink
Extracted 'sklearn2pmml.util.check_expression()' and 'sklearn2pmml.ut…
Browse files Browse the repository at this point in the history
…il.check_predicate()' utility functions
  • Loading branch information
vruusmann committed Mar 30, 2024
1 parent 6fd22d7 commit d09ffb5
Show file tree
Hide file tree
Showing 7 changed files with 59 additions and 43 deletions.
5 changes: 2 additions & 3 deletions sklearn2pmml/ensemble/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from sklearn.linear_model._base import LinearClassifierMixin, LinearModel, SparseCoefMixin
from sklearn.preprocessing import OneHotEncoder
from sklearn.utils.metaestimators import _BaseComposition
from sklearn2pmml.util import eval_rows, fqn, to_expr_func, Predicate
from sklearn2pmml.util import check_predicate, eval_rows, fqn, to_expr_func

import copy
import numpy
Expand Down Expand Up @@ -155,8 +155,7 @@ def __init__(self, steps, controller):
if len(step) != 3:
raise TypeError("Step is not a three-element (name, estimator, predicate) tuple")
name, estimator, predicate = step
if not isinstance(predicate, (str, Predicate)):
raise TypeError()
check_predicate(predicate)
self.steps = steps
if controller:
if not hasattr(controller, "transform"):
Expand Down
9 changes: 3 additions & 6 deletions sklearn2pmml/expression/__init__.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
from scipy.special import expit, softmax
from sklearn.preprocessing import normalize
from sklearn.base import BaseEstimator, ClassifierMixin, RegressorMixin
from sklearn2pmml.util import eval_rows, to_expr_func, Expression
from sklearn2pmml.util import check_expression, eval_rows, to_expr_func

import numpy

class ExpressionRegressor(BaseEstimator, RegressorMixin):

def __init__(self, expr, normalization_method):
if not isinstance(expr, Expression):
raise TypeError()
self.expr = expr
self.expr = check_expression(expr)
normalization_methods = ["none", "exp"]
if normalization_method not in normalization_methods:
raise ValueError("Normalization method {0} not in {1}".format(normalization_method, normalization_methods))
Expand Down Expand Up @@ -42,8 +40,7 @@ def __init__(self, class_exprs, normalization_method):
for k, v in class_exprs.items():
if k is None:
raise ValueError()
if not isinstance(v, Expression):
raise TypeError()
check_expression(v)
self.class_exprs = class_exprs
normalization_methods = ["none", "logit", "simplemax", "softmax"]
if normalization_method not in normalization_methods:
Expand Down
8 changes: 4 additions & 4 deletions sklearn2pmml/expression/tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def test_binary_fit_predict(self):
self.assertEqual([[0.0, 1.0], [1.0, 0.0], [0.0, 1.0]], pred_proba.tolist())

class_exprs = {
"yes" : Expression("X[0] ** X[1] + X[2]")
"yes" : "X[0] ** X[1] + X[2]"
}
classifier = ExpressionClassifier(class_exprs, normalization_method = "logit")
X = numpy.asarray([[2, 1/2, 0], [2, 1, 0], [2, 2, -6]])
Expand All @@ -91,7 +91,7 @@ def test_binary_fit_predict(self):
self.assertAlmostEqual(3, numpy.sum(pred_proba))
self.assertEqual([expit(numpy.sqrt(2)), expit(2), expit(-2)], pred_proba[:, 1].tolist())

class_exprs["no"] = Expression("1.0")
class_exprs["no"] = "1.0"
classifier = ExpressionClassifier(class_exprs, normalization_method = "softmax")
classifier.fit(X, y)
self.assertEqual(["no", "yes"], classifier.classes_.tolist())
Expand All @@ -107,8 +107,8 @@ def test_binary_fit_predict(self):

def test_multiclass_fit_predict(self):
class_exprs = {
"0" : Expression("X[0]"),
"1" : Expression("X[1]"),
"0" : "X[0]",
"1" : "X[1]",
}
classifier = ExpressionClassifier(class_exprs, normalization_method = "none")
X = numpy.asarray([[0.6, 0.3], [0.2, 0.2], [0.1, 0.7]])
Expand Down
9 changes: 3 additions & 6 deletions sklearn2pmml/preprocessing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from sklearn.exceptions import NotFittedError
from sklearn.pipeline import Pipeline
from sklearn2pmml import _is_pandas_categorical, _is_proto_pandas_categorical
from sklearn2pmml.util import cast, dt_transform, ensure_def, eval_rows, is_1d, to_1d, to_expr_func, to_numpy, Expression, Predicate, Reshaper
from sklearn2pmml.util import cast, check_expression, check_predicate, dt_transform, ensure_def, eval_rows, is_1d, to_1d, to_expr_func, to_numpy, Reshaper

import numpy
import pandas
Expand Down Expand Up @@ -237,9 +237,7 @@ class ExpressionTransformer(BaseEstimator, TransformerMixin):
"""

def __init__(self, expr, map_missing_to = None, default_value = None, invalid_value_treatment = None, dtype = None):
if not isinstance(expr, (str, Expression)):
raise TypeError()
self.expr = expr
self.expr = check_expression(expr)
self.map_missing_to = map_missing_to
self.default_value = default_value
invalid_value_treatments = ["return_invalid", "as_missing"]
Expand Down Expand Up @@ -705,8 +703,7 @@ def __init__(self, steps, controller = None, eval_rows = True):
if len(step) != 3:
raise TypeError("Step is not a three-element (name, transformer, predicate) tuple")
name, transformer, predicate = step
if not isinstance(predicate, (str, Predicate)):
raise TypeError()
check_predicate(predicate)
self.steps = steps
if controller:
if not hasattr(controller, "transform"):
Expand Down
5 changes: 2 additions & 3 deletions sklearn2pmml/ruleset/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn2pmml.util import eval_rows, to_expr_func, Predicate
from sklearn2pmml.util import check_predicate, eval_rows, to_expr_func

class RuleSetClassifier(BaseEstimator, ClassifierMixin):

Expand All @@ -10,8 +10,7 @@ def __init__(self, rules, default_score = None):
if len(rule) != 2:
raise TypeError("Rule is not a two-element (predicate, score) tuple")
predicate, score = rule
if not isinstance(predicate, (str, Predicate)):
raise TypeError()
check_predicate(predicate)
self.rules = rules
self.default_score = default_score

Expand Down
10 changes: 10 additions & 0 deletions sklearn2pmml/util/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,16 @@ class Predicate(Evaluatable):
def __init__(self, expr, function_defs = []):
super(Predicate, self).__init__(expr = expr, function_defs = function_defs)

def check_expression(expression):
if not isinstance(expression, (str, Expression)):
raise TypeError("The expression object is not a string nor an instance of {0}".format(Expression.__name__))
return expression

def check_predicate(predicate):
if not isinstance(predicate, (str, Predicate)):
raise TypeError("The predicate object is not a string nor an instance of {0}".format(Predicate.__name__))
return predicate

def to_expr(expr):
if isinstance(expr, str):
return expr
Expand Down
56 changes: 35 additions & 21 deletions sklearn2pmml/util/tests/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from pandas import DataFrame
from sklearn2pmml.util import sizeof, deep_sizeof, to_expr, to_expr_func, fqn, Evaluatable, Slicer, Reshaper
from sklearn2pmml.util import check_expression, check_predicate, fqn, sizeof, deep_sizeof, to_expr, to_expr_func, Evaluatable, Expression, Predicate, Slicer, Reshaper
from unittest import TestCase

import inspect
Expand Down Expand Up @@ -55,20 +55,34 @@ def _signum(X):
else:
return 0

def _is_negative(x):
return (_trunc(x) < 0)

def _is_positive(x):
return (_trunc(x) > 0)

def _trunc(x):
return math.trunc(x)

class Dummy:
pass

class FunctionTest(TestCase):

def test_check_expression(self):
self.assertIsNotNone(check_expression("1.0"))
self.assertIsNotNone(check_expression(Expression("1.0")))
with self.assertRaises(TypeError):
check_expression(Predicate(str(True)))

def test_check_predicate(self):
self.assertIsNotNone(check_predicate(str(True)))
self.assertIsNotNone(check_predicate(str(True)))
with self.assertRaises(TypeError):
check_predicate(Expression("1.0"))

def test_fqn(self):
obj = ""
self.assertEqual("builtins.str", fqn(str))
self.assertEqual("builtins.type", fqn(str.__class__))
self.assertEqual("builtins.str", fqn(obj))

obj = Dummy()
self.assertEqual("sklearn2pmml.util.tests.Dummy", fqn(Dummy))
self.assertEqual("builtins.type", fqn(Dummy.__class__))
self.assertEqual("sklearn2pmml.util.tests.Dummy", fqn(obj))

def test_inline_expr(self):
expr = "pandas.isnull(X[0])"
expr = to_expr(expr)
Expand Down Expand Up @@ -105,6 +119,17 @@ def test_inline_def_expr(self):
self.assertEqual(0, expr_func([0]))
self.assertEqual(1, expr_func([1.5]))

def _is_negative(x):
return (_trunc(x) < 0)

def _is_positive(x):
return (_trunc(x) > 0)

def _trunc(x):
return math.trunc(x)

class EvaluatableTest(TestCase):

def test_evaluatable_expr(self):
expr = Evaluatable("_is_negative(X[0])")
expr = to_expr(expr)
Expand Down Expand Up @@ -132,17 +157,6 @@ def test_evaluatable_expr(self):
self.assertEqual(0, expr_func([0]))
self.assertEqual(1, expr_func([1.5]))

def test_fqn(self):
obj = ""
self.assertEqual("builtins.str", fqn(str))
self.assertEqual("builtins.type", fqn(str.__class__))
self.assertEqual("builtins.str", fqn(obj))

obj = Dummy()
self.assertEqual("sklearn2pmml.util.tests.Dummy", fqn(Dummy))
self.assertEqual("builtins.type", fqn(Dummy.__class__))
self.assertEqual("sklearn2pmml.util.tests.Dummy", fqn(obj))

class ReshaperTest(TestCase):

def test_transform(self):
Expand Down

0 comments on commit d09ffb5

Please sign in to comment.