From 17ee0d2b77179881e57a3ba0976b760fd8343af8 Mon Sep 17 00:00:00 2001 From: Suraj Aralihalli Date: Wed, 23 Oct 2024 15:02:24 -0700 Subject: [PATCH] update transpiler Signed-off-by: Suraj Aralihalli --- .../com/nvidia/spark/rapids/RegexParser.scala | 39 ++++++++----------- .../spark/sql/rapids/stringFunctions.scala | 35 ++++++++++------- 2 files changed, 36 insertions(+), 38 deletions(-) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala index 362a9cce293..a9f3b10d929 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala @@ -847,23 +847,12 @@ class CudfRegexTranspiler(mode: RegexMode) { // the end of a line of an input character sequence. // this method produces a RegexAST which outputs a regular expression to match any possible // combination of line terminators - private def lineTerminatorMatcher(exclude: Set[Char], excludeCRLF: Boolean, - capture: Boolean): RegexAST = { - val terminatorChars = new ListBuffer[RegexCharacterClassComponent]() - terminatorChars ++= lineTerminatorChars.filter(!exclude.contains(_)).map(RegexChar) - - if (terminatorChars.size == 0 && excludeCRLF) { + private def lineTerminatorMatcher(excludeCRLF: Boolean, capture: Boolean): RegexAST = { + if (excludeCRLF) { RegexEmpty() - } else if (terminatorChars.size == 0) { + } else { RegexGroup(capture = capture, RegexSequence(ListBuffer(RegexChar('\r'), RegexChar('\n'))), None) - } else if (excludeCRLF) { - RegexGroup(capture = capture, - RegexCharacterClass(negated = false, characters = terminatorChars), - None - ) - } else { - RegexGroup(capture = capture, RegexParser.parse("\r|\u0085|\u2028|\u2029|\r\n"), None) } } @@ -1104,8 +1093,8 @@ class CudfRegexTranspiler(mode: RegexMode) { } } RegexSequence(ListBuffer( - RegexRepetition(lineTerminatorMatcher(Set(ch), true, - mode == RegexReplaceMode), SimpleQuantifier('?')), + RegexRepetition(lineTerminatorMatcher(excludeCRLF = true, + capture = mode == RegexReplaceMode), SimpleQuantifier('?')), RegexChar('$'))) case Some(RegexEscaped('b')) | Some(RegexEscaped('B')) => throw new RegexUnsupportedException( @@ -1119,8 +1108,8 @@ class CudfRegexTranspiler(mode: RegexMode) { } } RegexSequence(ListBuffer( - RegexRepetition(lineTerminatorMatcher(Set.empty, false, - mode == RegexReplaceMode), SimpleQuantifier('?')), + RegexRepetition(lineTerminatorMatcher(excludeCRLF = false, + capture = mode == RegexReplaceMode), SimpleQuantifier('?')), RegexChar('$'))) } case '^' if mode == RegexSplitMode => @@ -1367,18 +1356,21 @@ class CudfRegexTranspiler(mode: RegexMode) { case RegexGroup(capture, RegexSequence( ListBuffer(RegexCharacterClass(true, parts))), _) if parts.forall(!isBeginOrEndLineAnchor(_)) => - r(j) = RegexSequence(ListBuffer(lineTerminatorMatcher(Set.empty, true, capture), + r(j) = RegexSequence( + ListBuffer(lineTerminatorMatcher(excludeCRLF = true, capture = capture), RegexChar('$'))) popBackrefIfNecessary(capture) case RegexGroup(capture, RegexCharacterClass(true, parts), _) if parts.forall(!isBeginOrEndLineAnchor(_)) => - r(j) = RegexSequence(ListBuffer(lineTerminatorMatcher(Set.empty, true, capture), + r(j) = RegexSequence(ListBuffer( + lineTerminatorMatcher(excludeCRLF = true, capture = capture), RegexChar('$'))) popBackrefIfNecessary(capture) case RegexCharacterClass(true, parts) if parts.forall(!isBeginOrEndLineAnchor(_)) => - r(j) = RegexSequence( - ListBuffer(lineTerminatorMatcher(Set.empty, true, false), RegexChar('$'))) + r(j) = RegexSequence(ListBuffer( + lineTerminatorMatcher(excludeCRLF = true, capture = false), + RegexChar('$'))) popBackrefIfNecessary(false) case RegexChar(ch) if ch == '\n' => // what's really needed here is negative lookahead, but that is not @@ -1391,7 +1383,8 @@ class CudfRegexTranspiler(mode: RegexMode) { ListBuffer( rewrite(part, replacement, None, flags), RegexSequence(ListBuffer( - RegexRepetition(lineTerminatorMatcher(Set(ch), true, false), + RegexRepetition( + lineTerminatorMatcher(excludeCRLF = true, capture = false), SimpleQuantifier('?')), RegexChar('$'))))) popBackrefIfNecessary(false) case RegexEscaped('z') => diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala index c8a90dc80ad..f51d12e7145 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala @@ -18,11 +18,9 @@ package org.apache.spark.sql.rapids import java.nio.charset.Charset import java.text.DecimalFormatSymbols -import java.util.{Locale, Optional} - +import java.util.{EnumSet, Locale, Optional} import scala.collection.mutable.ArrayBuffer - -import ai.rapids.cudf.{BinaryOp, BinaryOperable, CaptureGroups, ColumnVector, ColumnView, DType, PadSide, RegexProgram, RoundMode, Scalar} +import ai.rapids.cudf.{BinaryOp, BinaryOperable, CaptureGroups, ColumnVector, ColumnView, DType, PadSide, RegexFlag, RegexProgram, RoundMode, Scalar} import com.nvidia.spark.rapids._ import com.nvidia.spark.rapids.Arm._ import com.nvidia.spark.rapids.RapidsPluginImplicits._ @@ -30,7 +28,6 @@ import com.nvidia.spark.rapids.jni.CastStrings import com.nvidia.spark.rapids.jni.GpuSubstringIndexUtils import com.nvidia.spark.rapids.jni.RegexRewriteUtils import com.nvidia.spark.rapids.shims.{ShimExpression, SparkShimImpl} - import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ import org.apache.spark.sql.vectorized.ColumnarBatch @@ -1114,7 +1111,8 @@ case class GpuRLike(left: Expression, right: Expression, pattern: String) override def toString: String = s"$left gpurlike $right" override def doColumnar(lhs: GpuColumnVector, rhs: GpuScalar): ColumnVector = { - lhs.getBase.containsRe(new RegexProgram(pattern, CaptureGroups.NON_CAPTURE)) + lhs.getBase.containsRe(new RegexProgram(pattern, + EnumSet.of(RegexFlag.EXT_NEWLINE), CaptureGroups.NON_CAPTURE)) } override def doColumnar(numRows: Int, lhs: GpuScalar, rhs: GpuScalar): ColumnVector = { @@ -1231,7 +1229,8 @@ case class GpuRegExpReplace( throw new IllegalStateException("Need a replace") } case _ => - val prog = new RegexProgram(cudfRegexPattern, CaptureGroups.NON_CAPTURE) + val prog = new RegexProgram(cudfRegexPattern, + EnumSet.of(RegexFlag.EXT_NEWLINE), CaptureGroups.NON_CAPTURE) if (SparkShimImpl.reproduceEmptyStringBug && GpuRegExpUtils.isEmptyRepetition(javaRegexpPattern)) { val isEmpty = withResource(strExpr.getBase.getCharLengths) { len => @@ -1275,7 +1274,7 @@ case class GpuRegExpReplaceWithBackref( override def dataType: DataType = StringType override protected def doColumnar(input: GpuColumnVector): ColumnVector = { - val prog = new RegexProgram(cudfRegexPattern) + val prog = new RegexProgram(cudfRegexPattern, EnumSet.of(RegexFlag.EXT_NEWLINE)) if (SparkShimImpl.reproduceEmptyStringBug && GpuRegExpUtils.isEmptyRepetition(javaRegexpPattern)) { val isEmpty = withResource(input.getBase.getCharLengths) { len => @@ -1416,7 +1415,8 @@ case class GpuRegExpExtract( // | 'a1a' | '1' | '1' | // | '1a1' | '' | NULL | - withResource(str.getBase.extractRe(new RegexProgram(extractPattern))) { extract => + withResource(str.getBase.extractRe(new RegexProgram(extractPattern, + EnumSet.of(RegexFlag.EXT_NEWLINE)))) { extract => withResource(GpuScalar.from("", DataTypes.StringType)) { emptyString => val outputNullAndInputNotNull = withResource(extract.getColumn(groupIndex).isNull) { outputNull => @@ -1514,7 +1514,8 @@ case class GpuRegExpExtractAll( idx: GpuScalar): ColumnVector = { idx.getValue.asInstanceOf[Int] match { case 0 => - val prog = new RegexProgram(cudfRegexPattern, CaptureGroups.NON_CAPTURE) + val prog = new RegexProgram(cudfRegexPattern, + EnumSet.of(RegexFlag.EXT_NEWLINE), CaptureGroups.NON_CAPTURE) str.getBase.extractAllRecord(prog, 0) case _ => // Extract matches corresponding to idx. cuDF's extract_all_record does not support @@ -1529,7 +1530,7 @@ case class GpuRegExpExtractAll( // 2nd element afterwards from the cuDF list val rowCount = str.getRowCount - val prog = new RegexProgram(cudfRegexPattern) + val prog = new RegexProgram(cudfRegexPattern, EnumSet.of(RegexFlag.EXT_NEWLINE)) val extractedWithNulls = withResource( // Now the index is always 1 because we have transpiled all the capture groups to the @@ -1795,7 +1796,8 @@ case class GpuStringSplit(str: Expression, regex: Expression, limit: Expression, case 0 => // Same as splitting as many times as possible if (isRegExp) { - str.getBase.stringSplitRecord(new RegexProgram(pattern, CaptureGroups.NON_CAPTURE), -1) + str.getBase.stringSplitRecord(new RegexProgram(pattern, + EnumSet.of(RegexFlag.EXT_NEWLINE) ,CaptureGroups.NON_CAPTURE), -1) } else { str.getBase.stringSplitRecord(pattern, -1) } @@ -1810,7 +1812,8 @@ case class GpuStringSplit(str: Expression, regex: Expression, limit: Expression, } case n => if (isRegExp) { - str.getBase.stringSplitRecord(new RegexProgram(pattern, CaptureGroups.NON_CAPTURE), n) + str.getBase.stringSplitRecord(new RegexProgram(pattern, + EnumSet.of(RegexFlag.EXT_NEWLINE) ,CaptureGroups.NON_CAPTURE), n) } else { str.getBase.stringSplitRecord(pattern, n) } @@ -1923,7 +1926,8 @@ case class GpuStringToMap(strExpr: Expression, private def toMap(str: GpuColumnVector): GpuColumnVector = { // Firstly, split the input strings into lists of strings. val listsOfStrings = if (isPairDelimRegExp) { - str.getBase.stringSplitRecord(new RegexProgram(pairDelim, CaptureGroups.NON_CAPTURE)) + str.getBase.stringSplitRecord(new RegexProgram(pairDelim, + EnumSet.of(RegexFlag.EXT_NEWLINE), CaptureGroups.NON_CAPTURE)) } else { str.getBase.stringSplitRecord(pairDelim) } @@ -1932,7 +1936,8 @@ case class GpuStringToMap(strExpr: Expression, withResource(listsOfStrings.getChildColumnView(0)) { stringsCol => // Split the key-value strings into pairs of strings of key-value (using limit = 2). val keysValuesTable = if (isKeyValueDelimRegExp) { - stringsCol.stringSplit(new RegexProgram(keyValueDelim, CaptureGroups.NON_CAPTURE), 2) + stringsCol.stringSplit(new RegexProgram(keyValueDelim, + EnumSet.of(RegexFlag.EXT_NEWLINE), CaptureGroups.NON_CAPTURE), 2) } else { stringsCol.stringSplit(keyValueDelim, 2) }