Skip to content

Commit

Permalink
Spark 4 parquet_writer_test.py fixes (NVIDIA#11615)
Browse files Browse the repository at this point in the history
* spark 4 parquet writer test initial fixes

Signed-off-by: Ryan Lee <[email protected]>

* change shim loader approach, deprecate Spark 4 preview release 1

Signed-off-by: Ryan Lee <[email protected]>

* extra space

Signed-off-by: Ryan Lee <[email protected]>

---------

Signed-off-by: Ryan Lee <[email protected]>
  • Loading branch information
rwlee authored Oct 22, 2024
1 parent 732b25b commit 8e2e627
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 17 deletions.
4 changes: 2 additions & 2 deletions integration_tests/run_pyspark_from_build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,10 @@ else
[[ ! -x "$(command -v zip)" ]] && { echo "fail to find zip command in $PATH"; exit 1; }
PY4J_TMP=("${SPARK_HOME}"/python/lib/py4j-*-src.zip)
PY4J_FILE=${PY4J_TMP[0]}
# PySpark uses ".dev0" for "-SNAPSHOT", ".dev" for "preview"
# PySpark uses ".dev0" for "-SNAPSHOT" and either ".dev" for "preview" or ".devN" for "previewN"
# https://github.com/apache/spark/blob/66f25e314032d562567620806057fcecc8b71f08/dev/create-release/release-build.sh#L267
VERSION_STRING=$(PYTHONPATH=${SPARK_HOME}/python:${PY4J_FILE} python -c \
"import pyspark, re; print(re.sub('\.dev0?$', '', pyspark.__version__))"
"import pyspark, re; print(re.sub('\.dev[012]?$', '', pyspark.__version__))"
)
SCALA_VERSION=`$SPARK_HOME/bin/pyspark --version 2>&1| grep Scala | awk '{split($4,v,"."); printf "%s.%s", v[1], v[2]}'`

Expand Down
32 changes: 18 additions & 14 deletions integration_tests/src/main/python/parquet_write_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,11 @@
reader_opt_confs = [original_parquet_file_reader_conf, multithreaded_parquet_file_reader_conf,
coalesce_parquet_file_reader_conf]
parquet_decimal_struct_gen= StructGen([['child'+str(ind), sub_gen] for ind, sub_gen in enumerate(decimal_gens)])
writer_confs={'spark.sql.legacy.parquet.datetimeRebaseModeInWrite': 'CORRECTED',
'spark.sql.legacy.parquet.int96RebaseModeInWrite': 'CORRECTED'}
legacy_parquet_datetimeRebaseModeInWrite='spark.sql.parquet.datetimeRebaseModeInWrite' if is_spark_400_or_later() else 'spark.sql.legacy.parquet.datetimeRebaseModeInWrite'
legacy_parquet_int96RebaseModeInWrite='spark.sql.parquet.int96RebaseModeInWrite' if is_spark_400_or_later() else 'spark.sql.legacy.parquet.int96RebaseModeInWrite'
legacy_parquet_int96RebaseModeInRead='spark.sql.parquet.int96RebaseModeInRead' if is_spark_400_or_later() else 'spark.sql.legacy.parquet.int96RebaseModeInRead'
writer_confs={legacy_parquet_datetimeRebaseModeInWrite: 'CORRECTED',
legacy_parquet_int96RebaseModeInWrite: 'CORRECTED'}

parquet_basic_gen =[byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen,
string_gen, boolean_gen, date_gen, TimestampGen(), binary_gen]
Expand Down Expand Up @@ -158,8 +161,8 @@ def test_write_ts_millis(spark_tmp_path, ts_type, ts_rebase):
lambda spark, path: unary_op_df(spark, gen).write.parquet(path),
lambda spark, path: spark.read.parquet(path),
data_path,
conf={'spark.sql.legacy.parquet.datetimeRebaseModeInWrite': ts_rebase,
'spark.sql.legacy.parquet.int96RebaseModeInWrite': ts_rebase,
conf={legacy_parquet_datetimeRebaseModeInWrite: ts_rebase,
legacy_parquet_int96RebaseModeInWrite: ts_rebase,
'spark.sql.parquet.outputTimestampType': ts_type})


Expand Down Expand Up @@ -285,8 +288,8 @@ def test_write_sql_save_table(spark_tmp_path, parquet_gens, spark_tmp_table_fact

def writeParquetUpgradeCatchException(spark, df, data_path, spark_tmp_table_factory, int96_rebase, datetime_rebase, ts_write):
spark.conf.set('spark.sql.parquet.outputTimestampType', ts_write)
spark.conf.set('spark.sql.legacy.parquet.datetimeRebaseModeInWrite', datetime_rebase)
spark.conf.set('spark.sql.legacy.parquet.int96RebaseModeInWrite', int96_rebase) # for spark 310
spark.conf.set(legacy_parquet_datetimeRebaseModeInWrite, datetime_rebase)
spark.conf.set(legacy_parquet_int96RebaseModeInWrite, int96_rebase) # for spark 310
with pytest.raises(Exception) as e_info:
df.coalesce(1).write.format("parquet").mode('overwrite').option("path", data_path).saveAsTable(spark_tmp_table_factory.get())
assert e_info.match(r".*SparkUpgradeException.*")
Expand Down Expand Up @@ -544,8 +547,8 @@ def generate_map_with_empty_validity(spark, path):
def test_parquet_write_fails_legacy_datetime(spark_tmp_path, data_gen, ts_write, ts_rebase_write):
data_path = spark_tmp_path + '/PARQUET_DATA'
all_confs = {'spark.sql.parquet.outputTimestampType': ts_write,
'spark.sql.legacy.parquet.datetimeRebaseModeInWrite': ts_rebase_write,
'spark.sql.legacy.parquet.int96RebaseModeInWrite': ts_rebase_write}
legacy_parquet_datetimeRebaseModeInWrite: ts_rebase_write,
legacy_parquet_int96RebaseModeInWrite: ts_rebase_write}
def writeParquetCatchException(spark, data_gen, data_path):
with pytest.raises(Exception) as e_info:
unary_op_df(spark, data_gen).coalesce(1).write.parquet(data_path)
Expand All @@ -563,12 +566,12 @@ def test_parquet_write_roundtrip_datetime_with_legacy_rebase(spark_tmp_path, dat
ts_rebase_write, ts_rebase_read):
data_path = spark_tmp_path + '/PARQUET_DATA'
all_confs = {'spark.sql.parquet.outputTimestampType': ts_write,
'spark.sql.legacy.parquet.datetimeRebaseModeInWrite': ts_rebase_write[0],
'spark.sql.legacy.parquet.int96RebaseModeInWrite': ts_rebase_write[1],
legacy_parquet_datetimeRebaseModeInWrite: ts_rebase_write[0],
legacy_parquet_int96RebaseModeInWrite: ts_rebase_write[1],
# The rebase modes in read configs should be ignored and overridden by the same
# modes in write configs, which are retrieved from the written files.
'spark.sql.legacy.parquet.datetimeRebaseModeInRead': ts_rebase_read[0],
'spark.sql.legacy.parquet.int96RebaseModeInRead': ts_rebase_read[1]}
legacy_parquet_int96RebaseModeInRead: ts_rebase_read[1]}
assert_gpu_and_cpu_writes_are_equal_collect(
lambda spark, path: unary_op_df(spark, data_gen).coalesce(1).write.parquet(path),
lambda spark, path: spark.read.parquet(path),
Expand Down Expand Up @@ -597,7 +600,8 @@ def test_it(spark):
spark.sql("CREATE TABLE {} LOCATION '{}/ctas' AS SELECT * FROM {}".format(
ctas_with_existing_name, data_path, src_name))
except pyspark.sql.utils.AnalysisException as e:
if allow_non_empty or e.desc.find('non-empty directory') == -1:
description = e._desc if is_spark_400_or_later() else e.desc
if allow_non_empty or description.find('non-empty directory') == -1:
raise e
with_gpu_session(test_it, conf)

Expand Down Expand Up @@ -825,8 +829,8 @@ def write_partitions(spark, table_path):
)

def hive_timestamp_value(spark_tmp_table_factory, spark_tmp_path, ts_rebase, func):
conf={'spark.sql.legacy.parquet.datetimeRebaseModeInWrite': ts_rebase,
'spark.sql.legacy.parquet.int96RebaseModeInWrite': ts_rebase}
conf={legacy_parquet_datetimeRebaseModeInWrite: ts_rebase,
legacy_parquet_int96RebaseModeInWrite: ts_rebase}

def create_table(spark, path):
tmp_table = spark_tmp_table_factory.get()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import com.nvidia.spark.rapids.SparkShimVersion

object SparkShimServiceProvider {
val VERSION = SparkShimVersion(4, 0, 0)
val VERSIONNAMES = Seq(s"$VERSION", s"$VERSION-SNAPSHOT", s"$VERSION-preview1")
val VERSIONNAMES = Seq(s"$VERSION", s"$VERSION-SNAPSHOT", s"$VERSION-preview2")
}

class SparkShimServiceProvider extends com.nvidia.spark.rapids.SparkShimServiceProvider {
Expand Down

0 comments on commit 8e2e627

Please sign in to comment.