Skip to content

Commit

Permalink
update transpiler
Browse files Browse the repository at this point in the history
Signed-off-by: Suraj Aralihalli <[email protected]>
  • Loading branch information
SurajAralihalli committed Oct 23, 2024
1 parent a071efe commit 17ee0d2
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 38 deletions.
39 changes: 16 additions & 23 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala
Original file line number Diff line number Diff line change
Expand Up @@ -847,23 +847,12 @@ class CudfRegexTranspiler(mode: RegexMode) {
// the end of a line of an input character sequence.
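// This method produces a RegexAST that renders to a regular expression matching any possible
// combination of line terminators.
// NOTE(review): the updated body only yields RegexEmpty or a `\r\n` group; confirm this
// description still holds now that the exclude-set handling is removed.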
private def lineTerminatorMatcher(exclude: Set[Char], excludeCRLF: Boolean,
capture: Boolean): RegexAST = {
val terminatorChars = new ListBuffer[RegexCharacterClassComponent]()
terminatorChars ++= lineTerminatorChars.filter(!exclude.contains(_)).map(RegexChar)

if (terminatorChars.size == 0 && excludeCRLF) {
private def lineTerminatorMatcher(excludeCRLF: Boolean, capture: Boolean): RegexAST = {
if (excludeCRLF) {
RegexEmpty()
} else if (terminatorChars.size == 0) {
} else {
RegexGroup(capture = capture, RegexSequence(ListBuffer(RegexChar('\r'), RegexChar('\n'))),
None)
} else if (excludeCRLF) {
RegexGroup(capture = capture,
RegexCharacterClass(negated = false, characters = terminatorChars),
None
)
} else {
RegexGroup(capture = capture, RegexParser.parse("\r|\u0085|\u2028|\u2029|\r\n"), None)
}
}

Expand Down Expand Up @@ -1104,8 +1093,8 @@ class CudfRegexTranspiler(mode: RegexMode) {
}
}
RegexSequence(ListBuffer(
RegexRepetition(lineTerminatorMatcher(Set(ch), true,
mode == RegexReplaceMode), SimpleQuantifier('?')),
RegexRepetition(lineTerminatorMatcher(excludeCRLF = true,
capture = mode == RegexReplaceMode), SimpleQuantifier('?')),
RegexChar('$')))
case Some(RegexEscaped('b')) | Some(RegexEscaped('B')) =>
throw new RegexUnsupportedException(
Expand All @@ -1119,8 +1108,8 @@ class CudfRegexTranspiler(mode: RegexMode) {
}
}
RegexSequence(ListBuffer(
RegexRepetition(lineTerminatorMatcher(Set.empty, false,
mode == RegexReplaceMode), SimpleQuantifier('?')),
RegexRepetition(lineTerminatorMatcher(excludeCRLF = false,
capture = mode == RegexReplaceMode), SimpleQuantifier('?')),
RegexChar('$')))
}
case '^' if mode == RegexSplitMode =>
Expand Down Expand Up @@ -1367,18 +1356,21 @@ class CudfRegexTranspiler(mode: RegexMode) {
case RegexGroup(capture, RegexSequence(
ListBuffer(RegexCharacterClass(true, parts))), _)
if parts.forall(!isBeginOrEndLineAnchor(_)) =>
r(j) = RegexSequence(ListBuffer(lineTerminatorMatcher(Set.empty, true, capture),
r(j) = RegexSequence(
ListBuffer(lineTerminatorMatcher(excludeCRLF = true, capture = capture),
RegexChar('$')))
popBackrefIfNecessary(capture)
case RegexGroup(capture, RegexCharacterClass(true, parts), _)
if parts.forall(!isBeginOrEndLineAnchor(_)) =>
r(j) = RegexSequence(ListBuffer(lineTerminatorMatcher(Set.empty, true, capture),
r(j) = RegexSequence(ListBuffer(
lineTerminatorMatcher(excludeCRLF = true, capture = capture),
RegexChar('$')))
popBackrefIfNecessary(capture)
case RegexCharacterClass(true, parts)
if parts.forall(!isBeginOrEndLineAnchor(_)) =>
r(j) = RegexSequence(
ListBuffer(lineTerminatorMatcher(Set.empty, true, false), RegexChar('$')))
r(j) = RegexSequence(ListBuffer(
lineTerminatorMatcher(excludeCRLF = true, capture = false),
RegexChar('$')))
popBackrefIfNecessary(false)
case RegexChar(ch) if ch == '\n' =>
// what's really needed here is negative lookahead, but that is not
Expand All @@ -1391,7 +1383,8 @@ class CudfRegexTranspiler(mode: RegexMode) {
ListBuffer(
rewrite(part, replacement, None, flags),
RegexSequence(ListBuffer(
RegexRepetition(lineTerminatorMatcher(Set(ch), true, false),
RegexRepetition(
lineTerminatorMatcher(excludeCRLF = true, capture = false),
SimpleQuantifier('?')), RegexChar('$')))))
popBackrefIfNecessary(false)
case RegexEscaped('z') =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,16 @@ package org.apache.spark.sql.rapids

import java.nio.charset.Charset
import java.text.DecimalFormatSymbols
import java.util.{Locale, Optional}

import java.util.{EnumSet, Locale, Optional}
import scala.collection.mutable.ArrayBuffer

import ai.rapids.cudf.{BinaryOp, BinaryOperable, CaptureGroups, ColumnVector, ColumnView, DType, PadSide, RegexProgram, RoundMode, Scalar}
import ai.rapids.cudf.{BinaryOp, BinaryOperable, CaptureGroups, ColumnVector, ColumnView, DType, PadSide, RegexFlag, RegexProgram, RoundMode, Scalar}
import com.nvidia.spark.rapids._
import com.nvidia.spark.rapids.Arm._
import com.nvidia.spark.rapids.RapidsPluginImplicits._
import com.nvidia.spark.rapids.jni.CastStrings
import com.nvidia.spark.rapids.jni.GpuSubstringIndexUtils
import com.nvidia.spark.rapids.jni.RegexRewriteUtils
import com.nvidia.spark.rapids.shims.{ShimExpression, SparkShimImpl}

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._
import org.apache.spark.sql.vectorized.ColumnarBatch
Expand Down Expand Up @@ -1114,7 +1111,8 @@ case class GpuRLike(left: Expression, right: Expression, pattern: String)
override def toString: String = s"$left gpurlike $right"

override def doColumnar(lhs: GpuColumnVector, rhs: GpuScalar): ColumnVector = {
lhs.getBase.containsRe(new RegexProgram(pattern, CaptureGroups.NON_CAPTURE))
lhs.getBase.containsRe(new RegexProgram(pattern,
EnumSet.of(RegexFlag.EXT_NEWLINE), CaptureGroups.NON_CAPTURE))
}

override def doColumnar(numRows: Int, lhs: GpuScalar, rhs: GpuScalar): ColumnVector = {
Expand Down Expand Up @@ -1231,7 +1229,8 @@ case class GpuRegExpReplace(
throw new IllegalStateException("Need a replace")
}
case _ =>
val prog = new RegexProgram(cudfRegexPattern, CaptureGroups.NON_CAPTURE)
val prog = new RegexProgram(cudfRegexPattern,
EnumSet.of(RegexFlag.EXT_NEWLINE), CaptureGroups.NON_CAPTURE)
if (SparkShimImpl.reproduceEmptyStringBug &&
GpuRegExpUtils.isEmptyRepetition(javaRegexpPattern)) {
val isEmpty = withResource(strExpr.getBase.getCharLengths) { len =>
Expand Down Expand Up @@ -1275,7 +1274,7 @@ case class GpuRegExpReplaceWithBackref(
override def dataType: DataType = StringType

override protected def doColumnar(input: GpuColumnVector): ColumnVector = {
val prog = new RegexProgram(cudfRegexPattern)
val prog = new RegexProgram(cudfRegexPattern, EnumSet.of(RegexFlag.EXT_NEWLINE))
if (SparkShimImpl.reproduceEmptyStringBug &&
GpuRegExpUtils.isEmptyRepetition(javaRegexpPattern)) {
val isEmpty = withResource(input.getBase.getCharLengths) { len =>
Expand Down Expand Up @@ -1416,7 +1415,8 @@ case class GpuRegExpExtract(
// | 'a1a' | '1' | '1' |
// | '1a1' | '' | NULL |

withResource(str.getBase.extractRe(new RegexProgram(extractPattern))) { extract =>
withResource(str.getBase.extractRe(new RegexProgram(extractPattern,
EnumSet.of(RegexFlag.EXT_NEWLINE)))) { extract =>
withResource(GpuScalar.from("", DataTypes.StringType)) { emptyString =>
val outputNullAndInputNotNull =
withResource(extract.getColumn(groupIndex).isNull) { outputNull =>
Expand Down Expand Up @@ -1514,7 +1514,8 @@ case class GpuRegExpExtractAll(
idx: GpuScalar): ColumnVector = {
idx.getValue.asInstanceOf[Int] match {
case 0 =>
val prog = new RegexProgram(cudfRegexPattern, CaptureGroups.NON_CAPTURE)
val prog = new RegexProgram(cudfRegexPattern,
EnumSet.of(RegexFlag.EXT_NEWLINE), CaptureGroups.NON_CAPTURE)
str.getBase.extractAllRecord(prog, 0)
case _ =>
// Extract matches corresponding to idx. cuDF's extract_all_record does not support
Expand All @@ -1529,7 +1530,7 @@ case class GpuRegExpExtractAll(
// 2nd element afterwards from the cuDF list

val rowCount = str.getRowCount
val prog = new RegexProgram(cudfRegexPattern)
val prog = new RegexProgram(cudfRegexPattern, EnumSet.of(RegexFlag.EXT_NEWLINE))

val extractedWithNulls = withResource(
// Now the index is always 1 because we have transpiled all the capture groups to the
Expand Down Expand Up @@ -1795,7 +1796,8 @@ case class GpuStringSplit(str: Expression, regex: Expression, limit: Expression,
case 0 =>
// Same as splitting as many times as possible
if (isRegExp) {
str.getBase.stringSplitRecord(new RegexProgram(pattern, CaptureGroups.NON_CAPTURE), -1)
str.getBase.stringSplitRecord(new RegexProgram(pattern,
EnumSet.of(RegexFlag.EXT_NEWLINE) ,CaptureGroups.NON_CAPTURE), -1)
} else {
str.getBase.stringSplitRecord(pattern, -1)
}
Expand All @@ -1810,7 +1812,8 @@ case class GpuStringSplit(str: Expression, regex: Expression, limit: Expression,
}
case n =>
if (isRegExp) {
str.getBase.stringSplitRecord(new RegexProgram(pattern, CaptureGroups.NON_CAPTURE), n)
str.getBase.stringSplitRecord(new RegexProgram(pattern,
EnumSet.of(RegexFlag.EXT_NEWLINE) ,CaptureGroups.NON_CAPTURE), n)
} else {
str.getBase.stringSplitRecord(pattern, n)
}
Expand Down Expand Up @@ -1923,7 +1926,8 @@ case class GpuStringToMap(strExpr: Expression,
private def toMap(str: GpuColumnVector): GpuColumnVector = {
// Firstly, split the input strings into lists of strings.
val listsOfStrings = if (isPairDelimRegExp) {
str.getBase.stringSplitRecord(new RegexProgram(pairDelim, CaptureGroups.NON_CAPTURE))
str.getBase.stringSplitRecord(new RegexProgram(pairDelim,
EnumSet.of(RegexFlag.EXT_NEWLINE), CaptureGroups.NON_CAPTURE))
} else {
str.getBase.stringSplitRecord(pairDelim)
}
Expand All @@ -1932,7 +1936,8 @@ case class GpuStringToMap(strExpr: Expression,
withResource(listsOfStrings.getChildColumnView(0)) { stringsCol =>
// Split the key-value strings into pairs of strings of key-value (using limit = 2).
val keysValuesTable = if (isKeyValueDelimRegExp) {
stringsCol.stringSplit(new RegexProgram(keyValueDelim, CaptureGroups.NON_CAPTURE), 2)
stringsCol.stringSplit(new RegexProgram(keyValueDelim,
EnumSet.of(RegexFlag.EXT_NEWLINE), CaptureGroups.NON_CAPTURE), 2)
} else {
stringsCol.stringSplit(keyValueDelim, 2)
}
Expand Down

0 comments on commit 17ee0d2

Please sign in to comment.