Skip to content

Commit

Permalink
Fix collection_ops_test for [databricks] 14.3 (NVIDIA#11623)
Browse files Browse the repository at this point in the history
* Fix `collection_ops_test` for Databricks 14.3

Fixes NVIDIA#11532.

This commit introduces a RapidsErrorUtils shim for Databricks 14.3, to
deal with the new error messages thrown for large array/sequences.

This should fix the failure in `collection_ops_test.py::test_sequence_too_long_sequence`
on Databricks 14.3.

Signed-off-by: MithunR <[email protected]>
  • Loading branch information
mythrocks authored Oct 22, 2024
1 parent b9a1a49 commit 732b25b
Show file tree
Hide file tree
Showing 5 changed files with 60 additions and 12 deletions.
8 changes: 4 additions & 4 deletions integration_tests/src/main/python/collection_ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,10 @@
from data_gen import *
from pyspark.sql.types import *

from spark_session import is_before_spark_400
from string_test import mk_str_gen
import pyspark.sql.functions as f
import pyspark.sql.utils
from spark_session import with_cpu_session, with_gpu_session, is_before_spark_334, is_before_spark_351, is_before_spark_342, is_before_spark_340, is_spark_350
from spark_session import with_cpu_session, with_gpu_session, is_before_spark_334, is_before_spark_342, is_before_spark_340, is_databricks_version_or_later, is_spark_350, is_spark_400_or_later
from conftest import get_datagen_seed
from marks import allow_non_gpu

Expand Down Expand Up @@ -330,8 +329,9 @@ def test_sequence_illegal_boundaries(start_gen, stop_gen, step_gen):
def test_sequence_too_long_sequence(stop_gen):
msg = "Too long sequence" if is_before_spark_334() \
or (not is_before_spark_340() and is_before_spark_342()) \
or is_spark_350() \
else "Can't create array" if not is_before_spark_400() \
or (is_spark_350() and not is_databricks_version_or_later(14, 3)) \
else "Can't create array" if ((is_databricks_version_or_later(14, 3))
or is_spark_400_or_later()) \
else "Unsuccessful try to create array with"
assert_gpu_and_cpu_error(
# To avoid OOM, reduce the row number to 1, it is enough to verify this case.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,5 @@
spark-rapids-shim-json-lines ***/
package org.apache.spark.sql.rapids.shims

import org.apache.spark.sql.errors.QueryExecutionErrors

object RapidsErrorUtils extends RapidsErrorUtilsBase
with RapidsQueryErrorUtils with SequenceSizeTooLongErrorBuilder {
def sqlArrayIndexNotStartAtOneError(): RuntimeException = {
QueryExecutionErrors.invalidIndexOfZeroError(context = null)
}
}
object RapidsErrorUtils extends RapidsErrorUtils341DBPlusBase
with SequenceSizeTooLongErrorBuilder
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

/*** spark-rapids-shim-json-lines
{"spark": "341db"}
{"spark": "350db"}
spark-rapids-shim-json-lines ***/
package org.apache.spark.sql.rapids.shims

import org.apache.spark.sql.errors.QueryExecutionErrors

trait RapidsErrorUtils341DBPlusBase extends RapidsErrorUtilsBase
with RapidsQueryErrorUtils {
def sqlArrayIndexNotStartAtOneError(): RuntimeException = {
QueryExecutionErrors.invalidIndexOfZeroError(context = null)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

/*** spark-rapids-shim-json-lines
{"spark": "350db"}
spark-rapids-shim-json-lines ***/
package org.apache.spark.sql.rapids.shims

object RapidsErrorUtils extends RapidsErrorUtils341DBPlusBase
with SequenceSizeExceededLimitErrorBuilder
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/

/*** spark-rapids-shim-json-lines
{"spark": "350db"}
{"spark": "400"}
spark-rapids-shim-json-lines ***/
package org.apache.spark.sql.rapids.shims
Expand Down

0 comments on commit 732b25b

Please sign in to comment.