From 1fa5918b66585f8485d741651246b8d256e678da Mon Sep 17 00:00:00 2001 From: Villu Ruusmann Date: Mon, 14 Oct 2024 13:37:14 +0300 Subject: [PATCH] Added 'MatcherTransformer.re_flavour' and 'ReplaceTransformer.re_flavour' attributes See https://github.com/jpmml/sklearn2pmml/issues/228 --- sklearn2pmml/preprocessing/__init__.py | 10 +++++---- sklearn2pmml/preprocessing/regex.py | 23 +++++++++++++++----- sklearn2pmml/preprocessing/tests/__init__.py | 4 ++-- 3 files changed, 26 insertions(+), 11 deletions(-) diff --git a/sklearn2pmml/preprocessing/__init__.py b/sklearn2pmml/preprocessing/__init__.py index 971cc7b..50358bc 100644 --- a/sklearn2pmml/preprocessing/__init__.py +++ b/sklearn2pmml/preprocessing/__init__.py @@ -574,8 +574,9 @@ def transform(self, X): class MatchesTransformer(BaseEstimator, TransformerMixin): """Match RE pattern.""" - def __init__(self, pattern): + def __init__(self, pattern, re_flavour = None): self.pattern = pattern + self.re_flavour = re_flavour def fit(self, X, y = None): to_1d(X) @@ -583,7 +584,7 @@ def fit(self, X, y = None): def transform(self, X): X1d = to_1d(X) - regex_engine = make_regex_engine(self.pattern) + regex_engine = make_regex_engine(self.pattern, self.re_flavour) func = lambda x: bool(regex_engine.matches(x)) Xt = eval_rows(X1d, func, shape = X.shape) return Xt @@ -591,9 +592,10 @@ def transform(self, X): class ReplaceTransformer(BaseEstimator, TransformerMixin): """Replace all RE pattern matches.""" - def __init__(self, pattern, replacement): + def __init__(self, pattern, replacement, re_flavour = None): self.pattern = pattern self.replacement = replacement + self.re_flavour = re_flavour def fit(self, X, y = None): to_1d(X) @@ -601,7 +603,7 @@ def fit(self, X, y = None): def transform(self, X): X1d = to_1d(X) - regex_engine = make_regex_engine(self.pattern) + regex_engine = make_regex_engine(self.pattern, self.re_flavour) func = lambda x: regex_engine.replace(self.replacement, x) Xt = eval_rows(X1d, func, shape = X.shape) return Xt diff --git a/sklearn2pmml/preprocessing/regex.py b/sklearn2pmml/preprocessing/regex.py index 74bde44..c2eb6cd 100644 --- a/sklearn2pmml/preprocessing/regex.py +++ b/sklearn2pmml/preprocessing/regex.py @@ -58,9 +58,22 @@ def matches(self, x): def replace(self, replacement, x): return self.pattern_.substitute(replacement, x) -def make_regex_engine(pattern): - try: +def make_regex_engine(pattern, re_flavour): + if re_flavour is None: + try: + import pcre2 + + re_flavour = "pcre2" + except ImportError: + warnings.warn("Perl Compatible Regular Expressions (PCRE) library is not available, falling back to built-in Regular Expressions (RE) library. Transformation results might not be reproducible between Python and PMML environments when using more complex patterns", Warning) + re_flavour = "re" + + if re_flavour == "pcre": return PCREEngine(pattern) - except ImportError: - warnings.warn("Perl Compatible Regular Expressions (PCRE) library is not available, falling back to built-in Regular Expressions (RE) library. Transformation results might not be reproducible between Python and PMML environments when using more complex patterns", Warning) - return REEngine(pattern) \ No newline at end of file + elif re_flavour == "pcre2": + return PCRE2Engine(pattern) + elif re_flavour == "re": + return REEngine(pattern) + else: + re_flavours = ["pcre", "pcre2", "re"] + raise ValueError("Regular Expressions flavour {0} not in {1}".format(re_flavour, re_flavours)) diff --git a/sklearn2pmml/preprocessing/tests/__init__.py b/sklearn2pmml/preprocessing/tests/__init__.py index c9e6fbe..a78704f 100644 --- a/sklearn2pmml/preprocessing/tests/__init__.py +++ b/sklearn2pmml/preprocessing/tests/__init__.py @@ -648,14 +648,14 @@ class MatchesTransformerTest(TestCase): def test_transform(self): X = numpy.asarray(["January", "February", "March", "April", "May", "June", "July", "August", "September", "October", "November", "December"]) - transformer = MatchesTransformer("ar?y") + transformer = MatchesTransformer("ar?y", re_flavour = "re") self.assertEqual([True, True, False, False, True, False, False, False, False, False, False, False], transformer.transform(X).tolist()) class ReplaceTransformerTest(TestCase): def test_transform(self): X = numpy.asarray(["A", "B", "BA", "BB", "BAB", "ABBA", "BBBB"]) - transformer = ReplaceTransformer("B+", "c") + transformer = ReplaceTransformer("B+", "c", re_flavour = "re") self.assertEqual(["A", "c", "cA", "c", "cAc", "AcA", "c"], transformer.transform(X).tolist()) vectorizer = CountVectorizer() pipeline = make_pipeline(transformer, vectorizer)