Update to_json to be more generic and fix some bugs (NVIDIA#11642)
Signed-off-by: Robert (Bobby) Evans <[email protected]>
revans2 authored Oct 23, 2024
1 parent 8e2e627 commit a071efe
Showing 4 changed files with 150 additions and 106 deletions.
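
Editor's note: judging from the diff, this change generalizes to_json on the GPU beyond structs — arrays and maps get their own tests, map keys are restricted to strings (as JSON requires), and quoting/escaping of string values is consolidated into one helper. A minimal PySpark sketch of the call shapes the new tests exercise; the session setup, column names, and data are illustrative assumptions, not taken from the diff:

# Illustrative sketch only: the to_json input shapes the new tests cover.
# Column names and data are made up; assumes a stock PySpark install.
from pyspark.sql import SparkSession
import pyspark.sql.functions as f

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame(
    [(("x", 1), ["a", None], {"k": "v"})],
    "my_struct struct<s:string,i:long>, my_array array<string>, my_map map<string,string>")

options = {"ignoreNullFields": "true", "timeZone": "UTC"}
df.select(
    f.to_json("my_struct", options).alias("ms"),  # struct -> JSON object
    f.to_json("my_array", options).alias("ma"),   # array  -> JSON array
    f.to_json("my_map", options).alias("mm")      # map    -> JSON object
).show(truncate=False)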
108 changes: 88 additions & 20 deletions integration_tests/src/main/python/json_test.py
@@ -1013,7 +1013,7 @@ def test_read_case_col_name(spark_tmp_path, v1_enabled_list, col_name):
conf=all_confs)


@pytest.mark.parametrize('data_gen', [byte_gen,
_to_json_datagens=[byte_gen,
boolean_gen,
short_gen,
int_gen,
@@ -1033,36 +1033,84 @@ def test_read_case_col_name(spark_tmp_path, v1_enabled_list, col_name):
.with_special_case('\\\'a\\\''),
pytest.param(StringGen('\u001a', nullable=True), marks=pytest.mark.xfail(
reason='https://github.com/NVIDIA/spark-rapids/issues/9705'))
], ids=idfn)
]

@pytest.mark.parametrize('data_gen', _to_json_datagens, ids=idfn)
@pytest.mark.parametrize('ignore_null_fields', [True, False])
@pytest.mark.parametrize('pretty', [
pytest.param(True, marks=pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/9517')),
False
])
@pytest.mark.parametrize('timezone', [
'UTC',
'Etc/UTC',
pytest.param('UTC+07:00', marks=pytest.mark.allow_non_gpu('ProjectExec')),
'Etc/UTC'
])
@pytest.mark.xfail(condition = is_not_utc(), reason = 'xfail non-UTC time zone tests because of https://github.com/NVIDIA/spark-rapids/issues/9653')
def test_structs_to_json(spark_tmp_path, data_gen, ignore_null_fields, pretty, timezone):
@allow_non_gpu(*non_utc_project_allow)
def test_structs_to_json(spark_tmp_path, data_gen, ignore_null_fields, timezone):
struct_gen = StructGen([
('a', data_gen),
("b", StructGen([('child', data_gen)], nullable=True)),
("c", ArrayGen(StructGen([('child', data_gen)], nullable=True))),
("d", MapGen(LongGen(nullable=False), data_gen)),
("d", MapGen(StringGen('[A-Za-z0-9]{0,10}', nullable=False), data_gen)),
("e", ArrayGen(MapGen(LongGen(nullable=False), data_gen), nullable=True)),
("e", ArrayGen(MapGen(StringGen('[A-Z]{5}', nullable=False), data_gen), nullable=True)),
], nullable=False)
gen = StructGen([('my_struct', struct_gen)], nullable=False)

options = { 'ignoreNullFields': ignore_null_fields,
'pretty': pretty,
'timeZone': timezone}

def struct_to_json(spark):
df = gen_df(spark, gen)
return df.withColumn("my_json", f.to_json("my_struct", options)).drop("my_struct")
return df.select(
f.to_json("my_struct", options).alias("ms"))

conf = copy_and_update(_enable_all_types_conf,
{ 'spark.rapids.sql.expression.StructsToJson': True })

assert_gpu_and_cpu_are_equal_collect(
lambda spark : struct_to_json(spark),
conf=conf)

@pytest.mark.parametrize('data_gen', _to_json_datagens, ids=idfn)
@pytest.mark.parametrize('ignore_null_fields', [True, False])
@pytest.mark.parametrize('timezone', [
'UTC',
'Etc/UTC'
])
@allow_non_gpu(*non_utc_project_allow)
def test_arrays_to_json(spark_tmp_path, data_gen, ignore_null_fields, timezone):
array_gen = ArrayGen(data_gen, nullable=True)
gen = StructGen([("my_array", array_gen)], nullable=False)

options = { 'ignoreNullFields': ignore_null_fields,
'timeZone': timezone}

def struct_to_json(spark):
df = gen_df(spark, gen)
return df.select(
f.to_json("my_array", options).alias("ma"))

conf = copy_and_update(_enable_all_types_conf,
{ 'spark.rapids.sql.expression.StructsToJson': True })

assert_gpu_and_cpu_are_equal_collect(
lambda spark : struct_to_json(spark),
conf=conf)

@pytest.mark.parametrize('data_gen', _to_json_datagens, ids=idfn)
@pytest.mark.parametrize('ignore_null_fields', [True, False])
@pytest.mark.parametrize('timezone', [
'UTC',
'Etc/UTC'
])
@allow_non_gpu(*non_utc_project_allow)
def test_maps_to_json(spark_tmp_path, data_gen, ignore_null_fields, timezone):
map_gen = MapGen(StringGen('[A-Z]{1,10}', nullable=False), data_gen, nullable=True)
gen = StructGen([("my_map", map_gen)], nullable=False)

options = { 'ignoreNullFields': ignore_null_fields,
'timeZone': timezone}

def struct_to_json(spark):
df = gen_df(spark, gen)
return df.select(
f.to_json("my_map", options).alias("mm"))

conf = copy_and_update(_enable_all_types_conf,
{ 'spark.rapids.sql.expression.StructsToJson': True })
@@ -1073,16 +1121,13 @@ def struct_to_json(spark):

@pytest.mark.parametrize('data_gen', [timestamp_gen], ids=idfn)
@pytest.mark.parametrize('timestamp_format', [
'yyyy-MM-dd\'T\'HH:mm:ss[.SSS][XXX]',
pytest.param('yyyy-MM-dd\'T\'HH:mm:ss.SSSXXX', marks=pytest.mark.allow_non_gpu('ProjectExec')),
pytest.param('dd/MM/yyyy\'T\'HH:mm:ss[.SSS][XXX]', marks=pytest.mark.allow_non_gpu('ProjectExec')),
'yyyy-MM-dd\'T\'HH:mm:ss[.SSS][XXX]'
])
@pytest.mark.parametrize('timezone', [
'UTC',
'Etc/UTC',
pytest.param('UTC+07:00', marks=pytest.mark.allow_non_gpu('ProjectExec')),
'Etc/UTC'
])
@pytest.mark.skipif(is_not_utc(), reason='Duplicated as original test case designed which it is parameterized by timezone. https://github.com/NVIDIA/spark-rapids/issues/9653.')
@allow_non_gpu(*non_utc_project_allow)
def test_structs_to_json_timestamp(spark_tmp_path, data_gen, timestamp_format, timezone):
struct_gen = StructGen([
("b", StructGen([('child', data_gen)], nullable=True)),
@@ -1211,6 +1256,29 @@ def struct_to_json(spark):
conf=conf)


@allow_non_gpu('ProjectExec')
def test_structs_to_json_fallback_pretty(spark_tmp_path):
struct_gen = StructGen([
('a', long_gen),
("b", byte_gen),
("c", ArrayGen(short_gen))
], nullable=False)
gen = StructGen([('my_struct', struct_gen)], nullable=False)

options = { 'pretty': True }

def struct_to_json(spark):
df = gen_df(spark, gen)
return df.withColumn("my_json", f.to_json("my_struct", options)).drop("my_struct")

conf = copy_and_update(_enable_all_types_conf,
{ 'spark.rapids.sql.expression.StructsToJson': True })

assert_gpu_fallback_collect(
lambda spark : struct_to_json(spark),
'ProjectExec',
conf=conf)

#####################################################
# Some from_json tests ported over from Apache Spark
#####################################################
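
Editor's note on test_structs_to_json_fallback_pretty above: instead of xfail-ing every parameterization on pretty, the commit pins down the expected behavior — to_json with pretty enabled is not handled on the GPU, so the plan must fall back to the CPU ProjectExec. A hypothetical illustration of the difference being guarded, assuming a plain Spark session; the fallback claim follows from the test itself, not from this sketch:

# Hypothetical illustration: Spark's CPU to_json honors pretty=true by emitting
# multi-line JSON; presumably the GPU kernel only writes compact JSON, hence
# the required fallback to ProjectExec.
from pyspark.sql import SparkSession
import pyspark.sql.functions as f

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([((1, "a"),)], "my_struct struct<a:long,b:string>")

df.select(f.to_json("my_struct")).show(truncate=False)                      # {"a":1,"b":"a"}
df.select(f.to_json("my_struct", {"pretty": "true"})).show(truncate=False)  # pretty-printed form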
124 changes: 51 additions & 73 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala
@@ -315,6 +315,10 @@ object GpuCast {
fromDataType: DataType,
toDataType: DataType,
options: CastOptions = CastOptions.DEFAULT_CAST_OPTIONS): ColumnVector = {
if (options.castToJsonString && fromDataType == StringType && toDataType == StringType) {
// Special case because they are structurally equal
return escapeAndQuoteJsonString(input)
}
if (DataType.equalsStructurally(fromDataType, toDataType)) {
return input.copyToColumnVector()
}
@@ -707,7 +711,9 @@ object GpuCast {
def castToString(
input: ColumnView,
fromDataType: DataType, options: CastOptions): ColumnVector = fromDataType match {
case StringType if options.castToJsonString => escapeAndQuoteJsonString(input)
case StringType => input.copyToColumnVector()
case DateType if options.castToJsonString => castDateToJson(input)
case DateType => input.asStrings("%Y-%m-%d")
case TimestampType if options.castToJsonString => castTimestampToJson(input)
case TimestampType => castTimestampToString(input)
Expand Down Expand Up @@ -753,12 +759,22 @@ object GpuCast {
}
}

private def castDateToJson(input: ColumnView): ColumnVector = {
// We need to quote and escape the result.
withResource(input.asStrings("%Y-%m-%d")) { tmp =>
escapeAndQuoteJsonString(tmp)
}
}

private def castTimestampToJson(input: ColumnView): ColumnVector = {
// we fall back to CPU if the JSON timezone is not UTC, so it is safe
// to hard-code `Z` here for now, but we should really add a timestamp
// format to CastOptions when we add support for custom formats in
// https://github.com/NVIDIA/spark-rapids/issues/9602
input.asStrings("%Y-%m-%dT%H:%M:%S.%3fZ")
// We also need to quote and escape the result.
withResource(input.asStrings("%Y-%m-%dT%H:%M:%S.%3fZ")) { tmp =>
escapeAndQuoteJsonString(tmp)
}
}

/**
@@ -887,48 +903,17 @@ object GpuCast {

val numRows = input.getRowCount.toInt

/**
* Create a new column with quotes around the supplied string column. Caller
* is responsible for closing `column`.
*/
def addQuotes(column: ColumnVector, rowCount: Int): ColumnVector = {
withResource(ArrayBuffer.empty[ColumnVector]) { columns =>
withResource(Scalar.fromString("\"")) { quote =>
withResource(ColumnVector.fromScalar(quote, rowCount)) {
quoteScalar =>
columns += quoteScalar.incRefCount()
columns += escapeJsonString(column)
columns += quoteScalar.incRefCount()
}
}
withResource(Scalar.fromString("")) { emptyScalar =>
ColumnVector.stringConcatenate(emptyScalar, emptyScalar, columns.toArray)
}
}
}

// cast the key column and value column to string columns
val (strKey, strValue) = withResource(input.getChildColumnView(0)) { kvStructColumn =>
if (options.castToJsonString) {
// keys must have quotes around them in JSON mode
val strKey: ColumnVector = withResource(kvStructColumn.getChildColumnView(0)) { keyColumn =>
withResource(castToString(keyColumn, from.keyType, options)) { key =>
addQuotes(key, keyColumn.getRowCount.toInt)
}
// For JSON only Strings are supported as keys so they should already come back quoted
castToString(keyColumn, from.keyType, options)
}
// string values must have quotes around them in JSON mode, and null values need
// to be represented by the string literal `null`
// null values need to be represented by the string literal `null`
val strValue = closeOnExcept(strKey) { _ =>
withResource(kvStructColumn.getChildColumnView(1)) { valueColumn =>
val dt = valueColumn.getType
val valueStr = if (dt == DType.STRING || dt.isDurationType || dt.isTimestampType) {
withResource(castToString(valueColumn, from.valueType, options)) { valueStr =>
addQuotes(valueStr, valueColumn.getRowCount.toInt)
}
} else {
castToString(valueColumn, from.valueType, options)
}
withResource(valueStr) { _ =>
withResource(castToString(valueColumn, from.valueType, options)) { valueStr =>
withResource(Scalar.fromString("null")) { nullScalar =>
withResource(valueColumn.isNull) { isNull =>
isNull.ifElse(nullScalar, valueStr)
@@ -1088,12 +1073,8 @@ object GpuCast {
val rowCount = input.getRowCount.toInt

def castToJsonAttribute(fieldIndex: Int,
colon: ColumnVector,
quote: ColumnVector): ColumnVector = {
colon: ColumnVector): ColumnVector = {
val jsonName = StringEscapeUtils.escapeJson(inputSchema(fieldIndex).name)
val dt = inputSchema(fieldIndex).dataType
val needsQuoting = dt == DataTypes.StringType || dt == DataTypes.DateType ||
dt == DataTypes.TimestampType
withResource(input.getChildColumnView(fieldIndex)) { cv =>
withResource(ArrayBuffer.empty[ColumnVector]) { attrColumns =>
// prefix with quoted column name followed by colon
@@ -1105,13 +1086,7 @@ object GpuCast {
// write the value
withResource(castToString(cv, inputSchema(fieldIndex).dataType, options)) {
attrValue =>
if (needsQuoting) {
attrColumns += quote.incRefCount()
attrColumns += escapeJsonString(attrValue)
attrColumns += quote.incRefCount()
} else {
attrColumns += attrValue.incRefCount()
}
attrColumns += attrValue.incRefCount()
}
// now concatenate
val jsonAttr = withResource(Scalar.fromString("")) { emptyString =>
@@ -1126,23 +1101,9 @@ object GpuCast {
}
}
} else {
val jsonAttr = withResource(ArrayBuffer.empty[ColumnVector]) { attrValues =>
withResource(castToString(cv, inputSchema(fieldIndex).dataType, options)) {
attrValue =>
if (needsQuoting) {
attrValues += quote.incRefCount()
attrValues += escapeJsonString(attrValue)
attrValues += quote.incRefCount()
withResource(Scalar.fromString("")) { emptyString =>
ColumnVector.stringConcatenate(emptyString, emptyString, attrValues.toArray)
}
} else {
attrValue.incRefCount()
}
}
}
// add attribute value, or null literal string if value is null
attrColumns += withResource(jsonAttr) { _ =>
attrColumns += withResource(castToString(cv,
inputSchema(fieldIndex).dataType, options)) { jsonAttr =>
withResource(cv.isNull) { isNull =>
withResource(Scalar.fromString("null")) { nullScalar =>
isNull.ifElse(nullScalar, jsonAttr)
@@ -1158,18 +1119,18 @@ object GpuCast {
}
}

withResource(Seq("", ",", ":", "\"", "{", "}").safeMap(Scalar.fromString)) {
withResource(Seq("", ",", ":", "{", "}").safeMap(Scalar.fromString)) {
case Seq(emptyScalar, commaScalar, columnScalars@_*) =>
withResource(columnScalars.safeMap(s => ColumnVector.fromScalar(s, rowCount))) {
case Seq(colon, quote, leftBrace, rightBrace) =>
case Seq(colon, leftBrace, rightBrace) =>
val jsonAttrs = withResource(ArrayBuffer.empty[ColumnVector]) { columns =>
// create one column per attribute, which will either be in the form `"name":value` or
// empty string for rows that have null values
if (input.getNumChildren == 1) {
castToJsonAttribute(0, colon, quote)
castToJsonAttribute(0, colon)
} else {
for (i <- 0 until input.getNumChildren) {
columns += castToJsonAttribute(i, colon, quote)
columns += castToJsonAttribute(i, colon)
}
// concatenate the columns into one string
withResource(ColumnVector.stringConcatenate(commaScalar,
@@ -1195,14 +1156,31 @@ object GpuCast {
}

/**
* Escape quotes and newlines in a string column. Caller is responsible for closing `cv`.
* Add quotes to and escape quotes and newlines in a string column.
* Caller is responsible for closing `cv`.
*/
private def escapeJsonString(cv: ColumnVector): ColumnVector = {
private def escapeAndQuoteJsonString(cv: ColumnView): ColumnVector = {
val rowCount = cv.getRowCount.toInt
val chars = Seq("\r", "\n", "\\", "\"")
val escaped = chars.map(StringEscapeUtils.escapeJava)
withResource(ColumnVector.fromStrings(chars: _*)) { search =>
withResource(ColumnVector.fromStrings(escaped: _*)) { replace =>
cv.stringReplace(search, replace)
withResource(ArrayBuffer.empty[ColumnVector]) { columns =>
withResource(Scalar.fromString("\"")) { quote =>
withResource(ColumnVector.fromScalar(quote, rowCount)) {
quoteScalar =>
columns += quoteScalar.incRefCount()

withResource(ColumnVector.fromStrings(chars: _*)) { search =>
withResource(ColumnVector.fromStrings(escaped: _*)) { replace =>
columns += cv.stringReplace(search, replace)
}
}
columns += quoteScalar.incRefCount()
}
}
withResource(Scalar.fromString("")) { emptyScalar =>
withResource(Scalar.fromNull(DType.STRING)) { nullScalar =>
ColumnVector.stringConcatenate(emptyScalar, nullScalar, columns.toArray)
}
}
}
}
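
Editor's note: the escaping change in GpuCast.scala folds the old addQuotes/escapeJsonString pair into the single escapeAndQuoteJsonString above, which wraps values in double quotes, escapes carriage returns, newlines, backslashes, and quotes, and leaves null rows null (the final concatenation passes a null scalar as the null replacement). A rough per-string model in Python of what the columnar code computes — illustration only; the real implementation runs on the GPU via stringReplace/stringConcatenate:

# Rough per-string model of escapeAndQuoteJsonString. cudf's multi-target
# stringReplace runs in a single pass, so replacements are never re-scanned;
# a sequential model must therefore escape the backslash first.
def escape_and_quote_json_string(s):
    if s is None:
        return None  # null rows stay null in the columnar version
    for ch, esc in (("\\", "\\\\"), ("\r", "\\r"), ("\n", "\\n"), ('"', '\\"')):
        s = s.replace(ch, esc)
    return '"' + s + '"'

assert escape_and_quote_json_string('a"b\n') == '"a\\"b\\n"'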