From 732b25bccd0e6839b574b4b9bf48e85f7268d946 Mon Sep 17 00:00:00 2001 From: MithunR Date: Tue, 22 Oct 2024 13:52:42 -0700 Subject: [PATCH] Fix `collection_ops_test` for [databricks] 14.3 (#11623) * Fix `collection_ops_test` for Databricks 14.3 Fixes #11532. This commit introduces a RapidsErrorUtils shim for Databricks 14.3, to deal with the new error messages thrown for large array/sequences. This should fix the failure in `collection_ops_test.py::test_sequence_too_long_sequence` on Databricks 14.3. Signed-off-by: MithunR --- .../src/main/python/collection_ops_test.py | 8 ++--- .../sql/rapids/shims/RapidsErrorUtils.scala | 10 ++----- .../shims/RapidsErrorUtils341DBPlusBase.scala | 30 +++++++++++++++++++ .../sql/rapids/shims/RapidsErrorUtils.scala | 23 ++++++++++++++ ...equenceSizeExceededLimitErrorBuilder.scala | 1 + 5 files changed, 60 insertions(+), 12 deletions(-) create mode 100644 sql-plugin/src/main/spark341db/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils341DBPlusBase.scala create mode 100644 sql-plugin/src/main/spark350db/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala rename sql-plugin/src/main/{spark400 => spark350db}/scala/org/apache/spark/sql/rapids/shims/SequenceSizeExceededLimitErrorBuilder.scala (98%) diff --git a/integration_tests/src/main/python/collection_ops_test.py b/integration_tests/src/main/python/collection_ops_test.py index 813f1a77c94..4aef35b0b59 100644 --- a/integration_tests/src/main/python/collection_ops_test.py +++ b/integration_tests/src/main/python/collection_ops_test.py @@ -18,11 +18,10 @@ from data_gen import * from pyspark.sql.types import * -from spark_session import is_before_spark_400 from string_test import mk_str_gen import pyspark.sql.functions as f import pyspark.sql.utils -from spark_session import with_cpu_session, with_gpu_session, is_before_spark_334, is_before_spark_351, is_before_spark_342, is_before_spark_340, is_spark_350 +from spark_session import with_cpu_session, with_gpu_session, is_before_spark_334, is_before_spark_342, is_before_spark_340, is_databricks_version_or_later, is_spark_350, is_spark_400_or_later from conftest import get_datagen_seed from marks import allow_non_gpu @@ -330,8 +329,9 @@ def test_sequence_illegal_boundaries(start_gen, stop_gen, step_gen): def test_sequence_too_long_sequence(stop_gen): msg = "Too long sequence" if is_before_spark_334() \ or (not is_before_spark_340() and is_before_spark_342()) \ - or is_spark_350() \ - else "Can't create array" if not is_before_spark_400() \ + or (is_spark_350() and not is_databricks_version_or_later(14, 3)) \ + else "Can't create array" if ((is_databricks_version_or_later(14, 3)) + or is_spark_400_or_later()) \ else "Unsuccessful try to create array with" assert_gpu_and_cpu_error( # To avoid OOM, reduce the row number to 1, it is enough to verify this case. diff --git a/sql-plugin/src/main/spark341db/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala b/sql-plugin/src/main/spark341db/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala index f3aa56d5f4d..78813c8c0b0 100644 --- a/sql-plugin/src/main/spark341db/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala +++ b/sql-plugin/src/main/spark341db/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala @@ -19,11 +19,5 @@ spark-rapids-shim-json-lines ***/ package org.apache.spark.sql.rapids.shims -import org.apache.spark.sql.errors.QueryExecutionErrors - -object RapidsErrorUtils extends RapidsErrorUtilsBase - with RapidsQueryErrorUtils with SequenceSizeTooLongErrorBuilder { - def sqlArrayIndexNotStartAtOneError(): RuntimeException = { - QueryExecutionErrors.invalidIndexOfZeroError(context = null) - } -} +object RapidsErrorUtils extends RapidsErrorUtils341DBPlusBase + with SequenceSizeTooLongErrorBuilder diff --git a/sql-plugin/src/main/spark341db/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils341DBPlusBase.scala b/sql-plugin/src/main/spark341db/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils341DBPlusBase.scala new file mode 100644 index 00000000000..3e668708d03 --- /dev/null +++ b/sql-plugin/src/main/spark341db/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils341DBPlusBase.scala @@ -0,0 +1,30 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*** spark-rapids-shim-json-lines +{"spark": "341db"} +{"spark": "350db"} +spark-rapids-shim-json-lines ***/ +package org.apache.spark.sql.rapids.shims + +import org.apache.spark.sql.errors.QueryExecutionErrors + +trait RapidsErrorUtils341DBPlusBase extends RapidsErrorUtilsBase + with RapidsQueryErrorUtils { + def sqlArrayIndexNotStartAtOneError(): RuntimeException = { + QueryExecutionErrors.invalidIndexOfZeroError(context = null) + } +} diff --git a/sql-plugin/src/main/spark350db/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala b/sql-plugin/src/main/spark350db/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala new file mode 100644 index 00000000000..518fd2bf133 --- /dev/null +++ b/sql-plugin/src/main/spark350db/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala @@ -0,0 +1,23 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*** spark-rapids-shim-json-lines +{"spark": "350db"} +spark-rapids-shim-json-lines ***/ +package org.apache.spark.sql.rapids.shims + +object RapidsErrorUtils extends RapidsErrorUtils341DBPlusBase + with SequenceSizeExceededLimitErrorBuilder diff --git a/sql-plugin/src/main/spark400/scala/org/apache/spark/sql/rapids/shims/SequenceSizeExceededLimitErrorBuilder.scala b/sql-plugin/src/main/spark350db/scala/org/apache/spark/sql/rapids/shims/SequenceSizeExceededLimitErrorBuilder.scala similarity index 98% rename from sql-plugin/src/main/spark400/scala/org/apache/spark/sql/rapids/shims/SequenceSizeExceededLimitErrorBuilder.scala rename to sql-plugin/src/main/spark350db/scala/org/apache/spark/sql/rapids/shims/SequenceSizeExceededLimitErrorBuilder.scala index 741634aea3f..81ba52f4665 100644 --- a/sql-plugin/src/main/spark400/scala/org/apache/spark/sql/rapids/shims/SequenceSizeExceededLimitErrorBuilder.scala +++ b/sql-plugin/src/main/spark350db/scala/org/apache/spark/sql/rapids/shims/SequenceSizeExceededLimitErrorBuilder.scala @@ -15,6 +15,7 @@ */ /*** spark-rapids-shim-json-lines +{"spark": "350db"} {"spark": "400"} spark-rapids-shim-json-lines ***/ package org.apache.spark.sql.rapids.shims