From 8e2e627a7b80c66f4cde0c83675619ee582b0bc8 Mon Sep 17 00:00:00 2001 From: Ryan Lee Date: Tue, 22 Oct 2024 14:00:27 -0700 Subject: [PATCH] Spark 4 parquet_write_test.py fixes (#11615) * spark 4 parquet writer test initial fixes Signed-off-by: Ryan Lee * change shim loader approach, deprecate Spark 4 preview release 1 Signed-off-by: Ryan Lee * extra space Signed-off-by: Ryan Lee --------- Signed-off-by: Ryan Lee --- integration_tests/run_pyspark_from_build.sh | 4 +-- .../src/main/python/parquet_write_test.py | 32 +++++++++++-------- .../spark400/SparkShimServiceProvider.scala | 2 +- 3 files changed, 21 insertions(+), 17 deletions(-) diff --git a/integration_tests/run_pyspark_from_build.sh b/integration_tests/run_pyspark_from_build.sh index 22a23349791..9bd72b2ada0 100755 --- a/integration_tests/run_pyspark_from_build.sh +++ b/integration_tests/run_pyspark_from_build.sh @@ -28,10 +28,10 @@ else [[ ! -x "$(command -v zip)" ]] && { echo "fail to find zip command in $PATH"; exit 1; } PY4J_TMP=("${SPARK_HOME}"/python/lib/py4j-*-src.zip) PY4J_FILE=${PY4J_TMP[0]} - # PySpark uses ".dev0" for "-SNAPSHOT", ".dev" for "preview" + # PySpark uses ".dev0" for "-SNAPSHOT" and either ".dev" for "preview" or ".devN" for "previewN" # https://github.com/apache/spark/blob/66f25e314032d562567620806057fcecc8b71f08/dev/create-release/release-build.sh#L267 VERSION_STRING=$(PYTHONPATH=${SPARK_HOME}/python:${PY4J_FILE} python -c \ - "import pyspark, re; print(re.sub('\.dev0?$', '', pyspark.__version__))" + "import pyspark, re; print(re.sub('\.dev[012]?$', '', pyspark.__version__))" ) SCALA_VERSION=`$SPARK_HOME/bin/pyspark --version 2>&1| grep Scala | awk '{split($4,v,"."); printf "%s.%s", v[1], v[2]}'` diff --git a/integration_tests/src/main/python/parquet_write_test.py b/integration_tests/src/main/python/parquet_write_test.py index 805a0b8137c..2acf3984f64 100644 --- a/integration_tests/src/main/python/parquet_write_test.py +++ b/integration_tests/src/main/python/parquet_write_test.py 
@@ -37,8 +37,11 @@ reader_opt_confs = [original_parquet_file_reader_conf, multithreaded_parquet_file_reader_conf, coalesce_parquet_file_reader_conf] parquet_decimal_struct_gen= StructGen([['child'+str(ind), sub_gen] for ind, sub_gen in enumerate(decimal_gens)]) -writer_confs={'spark.sql.legacy.parquet.datetimeRebaseModeInWrite': 'CORRECTED', - 'spark.sql.legacy.parquet.int96RebaseModeInWrite': 'CORRECTED'} +legacy_parquet_datetimeRebaseModeInWrite='spark.sql.parquet.datetimeRebaseModeInWrite' if is_spark_400_or_later() else 'spark.sql.legacy.parquet.datetimeRebaseModeInWrite' +legacy_parquet_int96RebaseModeInWrite='spark.sql.parquet.int96RebaseModeInWrite' if is_spark_400_or_later() else 'spark.sql.legacy.parquet.int96RebaseModeInWrite' +legacy_parquet_int96RebaseModeInRead='spark.sql.parquet.int96RebaseModeInRead' if is_spark_400_or_later() else 'spark.sql.legacy.parquet.int96RebaseModeInRead' +writer_confs={legacy_parquet_datetimeRebaseModeInWrite: 'CORRECTED', + legacy_parquet_int96RebaseModeInWrite: 'CORRECTED'} parquet_basic_gen =[byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen, string_gen, boolean_gen, date_gen, TimestampGen(), binary_gen] @@ -158,8 +161,8 @@ def test_write_ts_millis(spark_tmp_path, ts_type, ts_rebase): lambda spark, path: unary_op_df(spark, gen).write.parquet(path), lambda spark, path: spark.read.parquet(path), data_path, - conf={'spark.sql.legacy.parquet.datetimeRebaseModeInWrite': ts_rebase, - 'spark.sql.legacy.parquet.int96RebaseModeInWrite': ts_rebase, + conf={legacy_parquet_datetimeRebaseModeInWrite: ts_rebase, + legacy_parquet_int96RebaseModeInWrite: ts_rebase, 'spark.sql.parquet.outputTimestampType': ts_type}) @@ -285,8 +288,8 @@ def test_write_sql_save_table(spark_tmp_path, parquet_gens, spark_tmp_table_fact def writeParquetUpgradeCatchException(spark, df, data_path, spark_tmp_table_factory, int96_rebase, datetime_rebase, ts_write): spark.conf.set('spark.sql.parquet.outputTimestampType', ts_write) - 
spark.conf.set('spark.sql.legacy.parquet.datetimeRebaseModeInWrite', datetime_rebase) - spark.conf.set('spark.sql.legacy.parquet.int96RebaseModeInWrite', int96_rebase) # for spark 310 + spark.conf.set(legacy_parquet_datetimeRebaseModeInWrite, datetime_rebase) + spark.conf.set(legacy_parquet_int96RebaseModeInWrite, int96_rebase) # for spark 310 with pytest.raises(Exception) as e_info: df.coalesce(1).write.format("parquet").mode('overwrite').option("path", data_path).saveAsTable(spark_tmp_table_factory.get()) assert e_info.match(r".*SparkUpgradeException.*") @@ -544,8 +547,8 @@ def generate_map_with_empty_validity(spark, path): def test_parquet_write_fails_legacy_datetime(spark_tmp_path, data_gen, ts_write, ts_rebase_write): data_path = spark_tmp_path + '/PARQUET_DATA' all_confs = {'spark.sql.parquet.outputTimestampType': ts_write, - 'spark.sql.legacy.parquet.datetimeRebaseModeInWrite': ts_rebase_write, - 'spark.sql.legacy.parquet.int96RebaseModeInWrite': ts_rebase_write} + legacy_parquet_datetimeRebaseModeInWrite: ts_rebase_write, + legacy_parquet_int96RebaseModeInWrite: ts_rebase_write} def writeParquetCatchException(spark, data_gen, data_path): with pytest.raises(Exception) as e_info: unary_op_df(spark, data_gen).coalesce(1).write.parquet(data_path) @@ -563,12 +566,12 @@ def test_parquet_write_roundtrip_datetime_with_legacy_rebase(spark_tmp_path, dat ts_rebase_write, ts_rebase_read): data_path = spark_tmp_path + '/PARQUET_DATA' all_confs = {'spark.sql.parquet.outputTimestampType': ts_write, - 'spark.sql.legacy.parquet.datetimeRebaseModeInWrite': ts_rebase_write[0], - 'spark.sql.legacy.parquet.int96RebaseModeInWrite': ts_rebase_write[1], + legacy_parquet_datetimeRebaseModeInWrite: ts_rebase_write[0], + legacy_parquet_int96RebaseModeInWrite: ts_rebase_write[1], # The rebase modes in read configs should be ignored and overridden by the same # modes in write configs, which are retrieved from the written files. 
'spark.sql.legacy.parquet.datetimeRebaseModeInRead': ts_rebase_read[0], - 'spark.sql.legacy.parquet.int96RebaseModeInRead': ts_rebase_read[1]} + legacy_parquet_int96RebaseModeInRead: ts_rebase_read[1]} assert_gpu_and_cpu_writes_are_equal_collect( lambda spark, path: unary_op_df(spark, data_gen).coalesce(1).write.parquet(path), lambda spark, path: spark.read.parquet(path), @@ -597,7 +600,8 @@ def test_it(spark): spark.sql("CREATE TABLE {} LOCATION '{}/ctas' AS SELECT * FROM {}".format( ctas_with_existing_name, data_path, src_name)) except pyspark.sql.utils.AnalysisException as e: - if allow_non_empty or e.desc.find('non-empty directory') == -1: + description = e._desc if is_spark_400_or_later() else e.desc + if allow_non_empty or description.find('non-empty directory') == -1: raise e with_gpu_session(test_it, conf) @@ -825,8 +829,8 @@ def write_partitions(spark, table_path): ) def hive_timestamp_value(spark_tmp_table_factory, spark_tmp_path, ts_rebase, func): - conf={'spark.sql.legacy.parquet.datetimeRebaseModeInWrite': ts_rebase, - 'spark.sql.legacy.parquet.int96RebaseModeInWrite': ts_rebase} + conf={legacy_parquet_datetimeRebaseModeInWrite: ts_rebase, + legacy_parquet_int96RebaseModeInWrite: ts_rebase} def create_table(spark, path): tmp_table = spark_tmp_table_factory.get() diff --git a/sql-plugin/src/main/spark400/scala/com/nvidia/spark/rapids/shims/spark400/SparkShimServiceProvider.scala b/sql-plugin/src/main/spark400/scala/com/nvidia/spark/rapids/shims/spark400/SparkShimServiceProvider.scala index 454515db35e..e432c49ee0a 100644 --- a/sql-plugin/src/main/spark400/scala/com/nvidia/spark/rapids/shims/spark400/SparkShimServiceProvider.scala +++ b/sql-plugin/src/main/spark400/scala/com/nvidia/spark/rapids/shims/spark400/SparkShimServiceProvider.scala @@ -23,7 +23,7 @@ import com.nvidia.spark.rapids.SparkShimVersion object SparkShimServiceProvider { val VERSION = SparkShimVersion(4, 0, 0) - val VERSIONNAMES = Seq(s"$VERSION", s"$VERSION-SNAPSHOT", 
s"$VERSION-preview1") + val VERSIONNAMES = Seq(s"$VERSION", s"$VERSION-SNAPSHOT", s"$VERSION-preview2") } class SparkShimServiceProvider extends com.nvidia.spark.rapids.SparkShimServiceProvider {