Skip to content

Commit

Permalink
Optimized 'SelectFirstTransformer.transform(X)' method
Browse files Browse the repository at this point in the history
  • Loading branch information
vruusmann committed Nov 1, 2023
1 parent 67c2533 commit c4d2a9d
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 7 deletions.
14 changes: 8 additions & 6 deletions sklearn2pmml/preprocessing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -651,10 +651,11 @@ def fit(self, X, y = None):
X_eval = self._to_evaluation_dataset(X)
mask = numpy.zeros(X.shape[0], dtype = bool)
for name, transformer, predicate in self.steps:
step_mask = eval_expr_rows(X_eval, predicate, dtype = bool)
step_mask[mask] = False
if numpy.sum(step_mask) < 1:
step_mask = numpy.logical_not(mask)
step_mask_eval = eval_expr_rows(X_eval[step_mask], predicate, dtype = bool)
if numpy.sum(step_mask_eval) < 1:
raise ValueError(predicate)
step_mask[step_mask] = step_mask_eval
step_X = X[step_mask]
step_y = y[step_mask] if y is not None else None
transformer.fit(step_X, step_y)
Expand All @@ -666,10 +667,11 @@ def transform(self, X):
X_eval = self._to_evaluation_dataset(X)
mask = numpy.zeros(X.shape[0], dtype = bool)
for name, transformer, predicate in self.steps:
step_mask = eval_expr_rows(X_eval, predicate, dtype = bool)
step_mask[mask] = False
if numpy.sum(step_mask) < 1:
step_mask = numpy.logical_not(mask)
step_mask_eval = eval_expr_rows(X_eval[step_mask], predicate, dtype = bool)
if numpy.sum(step_mask_eval) < 1:
continue
step_mask[step_mask] = step_mask_eval
step_X = X[step_mask]
step_result = transformer.transform(step_X)
step_result = _to_sparse(X, step_mask, step_result)
Expand Down
13 changes: 12 additions & 1 deletion sklearn2pmml/preprocessing/tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from sklearn.pipeline import FeatureUnion, Pipeline
from sklearn.impute import SimpleImputer
from sklearn2pmml.decoration import Alias, DateDomain, DateTimeDomain
from sklearn2pmml.preprocessing import Aggregator, CastTransformer, ConcatTransformer, CutTransformer, DataFrameConstructor, DateTimeFormatter, DaysSinceYearTransformer, ExpressionTransformer, FilterLookupTransformer, IdentityTransformer, LookupTransformer, MatchesTransformer, MultiLookupTransformer, NumberFormatter, PMMLLabelBinarizer, PMMLLabelEncoder, PowerFunctionTransformer, ReplaceTransformer, SecondsSinceMidnightTransformer, SecondsSinceYearTransformer, StringNormalizer, SubstringTransformer, WordCountTransformer
from sklearn2pmml.preprocessing import Aggregator, CastTransformer, ConcatTransformer, CutTransformer, DataFrameConstructor, DateTimeFormatter, DaysSinceYearTransformer, ExpressionTransformer, FilterLookupTransformer, IdentityTransformer, LookupTransformer, MatchesTransformer, MultiLookupTransformer, NumberFormatter, PMMLLabelBinarizer, PMMLLabelEncoder, PowerFunctionTransformer, ReplaceTransformer, SecondsSinceMidnightTransformer, SecondsSinceYearTransformer, SelectFirstTransformer, StringNormalizer, SubstringTransformer, WordCountTransformer
from sklearn2pmml.preprocessing.h2o import H2OFrameConstructor, H2OFrameCreator
from sklearn2pmml.preprocessing.lightgbm import make_lightgbm_column_transformer, make_lightgbm_dataframe_mapper
from sklearn2pmml.preprocessing.xgboost import make_xgboost_column_transformer, make_xgboost_dataframe_mapper
Expand Down Expand Up @@ -567,6 +567,17 @@ def test_transform(self):
transformer = WordCountTransformer()
self.assertEqual([[0], [2], [3], [0]], transformer.transform(X).tolist())

class SelectFirstTransformerTest(TestCase):
	"""Exercise SelectFirstTransformer on a small frame that includes an unmatched row."""

	def test_fit_transform(self):
		"""Fit and transform a four-row frame, then compare against the expected column."""
		rows = [["A", 1.0], ["B", 0], ["A", 3.0], ["C", -1.5]]
		X = DataFrame(rows, columns = ["subset", "value"])
		# Each step is a (name, transformer, predicate) triple.
		increment_step = ("A", ExpressionTransformer("X['value'] + 1.0"), "X['subset'] == 'A'")
		decrement_step = ("B", ExpressionTransformer("X['value'] - 1.0"), "X['subset'] not in ['A', 'C']")
		transformer = SelectFirstTransformer([increment_step, decrement_step])
		Xt = transformer.fit_transform(X)
		# The "C" row satisfies neither predicate, so the expected value for it is None.
		expected = [[2.0], [-1.0], [4.0], [None]]
		self.assertEqual(expected, Xt.tolist())

class H2OFrameCreatorTest(TestCase):

def test_init(self):
Expand Down

0 comments on commit c4d2a9d

Please sign in to comment.