diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index b2c64a8242de2..3ecfeb5f69aac 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -995,6 +995,7 @@ def __hash__(self): "pyspark.sql.tests.connect.test_connect_basic", "pyspark.sql.tests.connect.test_connect_function", "pyspark.sql.tests.connect.test_connect_column", + "pyspark.sql.tests.connect.test_connect_session", "pyspark.sql.tests.connect.test_parity_arrow", "pyspark.sql.tests.connect.test_parity_arrow_python_udf", "pyspark.sql.tests.connect.test_parity_datasources", diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py index cb7e286c5371c..b55875b8577da 100755 --- a/python/pyspark/sql/tests/connect/test_connect_basic.py +++ b/python/pyspark/sql/tests/connect/test_connect_basic.py @@ -23,19 +23,14 @@ import shutil import string import tempfile -import uuid -from collections import defaultdict from pyspark.errors import ( PySparkAttributeError, PySparkTypeError, - PySparkException, PySparkValueError, - RetriesExceeded, ) from pyspark.errors.exceptions.base import SessionNotSameException from pyspark.sql import SparkSession as PySparkSession, Row -from pyspark.sql.connect.client.retries import RetryPolicy from pyspark.sql.types import ( StructType, StructField, @@ -57,30 +52,25 @@ from pyspark.testing.connectutils import ( should_test_connect, ReusedConnectTestCase, - connect_requirement_message, ) from pyspark.testing.pandasutils import PandasOnSparkTestUtils from pyspark.errors.exceptions.connect import ( AnalysisException, ParseException, SparkConnectException, - SparkUpgradeException, ) if should_test_connect: - import grpc import pandas as pd import numpy as np from pyspark.sql.connect.proto import Expression as ProtoExpression from pyspark.sql.connect.session import SparkSession as RemoteSparkSession - from pyspark.sql.connect.client import DefaultChannelBuilder, ChannelBuilder from pyspark.sql.connect.column import Column from pyspark.sql.connect.readwriter import DataFrameWriterV2 from pyspark.sql.dataframe import DataFrame from pyspark.sql.connect.dataframe import DataFrame as CDataFrame from pyspark.sql import functions as SF from pyspark.sql.connect import functions as CF - from pyspark.sql.connect.client.core import Retrying, SparkConnectClient class SparkConnectSQLTestCase(ReusedConnectTestCase, SQLTestUtils, PandasOnSparkTestUtils): @@ -3410,435 +3400,6 @@ def test_df_caache(self): self.assertTrue(df.is_cached) -class SparkConnectSessionTests(ReusedConnectTestCase): - def setUp(self) -> None: - self.spark = ( - PySparkSession.builder.config(conf=self.conf()) - .appName(self.__class__.__name__) - .remote("local[4]") - .getOrCreate() - ) - - def tearDown(self): - self.spark.stop() - - def _check_no_active_session_error(self, e: PySparkException): - self.check_error(exception=e, error_class="NO_ACTIVE_SESSION", message_parameters=dict()) - - def test_stop_session(self): - df = self.spark.sql("select 1 as a, 2 as b") - catalog = self.spark.catalog - self.spark.stop() - - # _execute_and_fetch - with self.assertRaises(SparkConnectException) as e: - self.spark.sql("select 1") - self._check_no_active_session_error(e.exception) - - with self.assertRaises(SparkConnectException) as e: - catalog.tableExists("table") - self._check_no_active_session_error(e.exception) - - # _execute - with self.assertRaises(SparkConnectException) as e: - self.spark.udf.register("test_func", lambda x: x + 1) - self._check_no_active_session_error(e.exception) - - # _analyze - with self.assertRaises(SparkConnectException) as e: - df._explain_string(extended=True) - self._check_no_active_session_error(e.exception) - - # Config - with self.assertRaises(SparkConnectException) as e: - self.spark.conf.get("some.conf") - self._check_no_active_session_error(e.exception) - - def test_error_enrichment_message(self): - with self.sql_conf( - { - "spark.sql.connect.enrichError.enabled": True, - "spark.sql.connect.serverStacktrace.enabled": False, - "spark.sql.pyspark.jvmStacktrace.enabled": False, - } - ): - name = "test" * 10000 - with self.assertRaises(AnalysisException) as e: - self.spark.sql("select " + name).collect() - self.assertTrue(name in e.exception._message) - self.assertFalse("JVM stacktrace" in e.exception._message) - - def test_error_enrichment_jvm_stacktrace(self): - with self.sql_conf( - { - "spark.sql.connect.enrichError.enabled": True, - "spark.sql.pyspark.jvmStacktrace.enabled": False, - } - ): - with self.sql_conf({"spark.sql.connect.serverStacktrace.enabled": False}): - with self.assertRaises(SparkUpgradeException) as e: - self.spark.sql( - """select from_json( - '{"d": "02-29"}', 'd date', map('dateFormat', 'MM-dd'))""" - ).collect() - self.assertFalse("JVM stacktrace" in e.exception._message) - - with self.sql_conf({"spark.sql.connect.serverStacktrace.enabled": True}): - with self.assertRaises(SparkUpgradeException) as e: - self.spark.sql( - """select from_json( - '{"d": "02-29"}', 'd date', map('dateFormat', 'MM-dd'))""" - ).collect() - self.assertTrue("JVM stacktrace" in str(e.exception)) - self.assertTrue("org.apache.spark.SparkUpgradeException" in str(e.exception)) - self.assertTrue( - "at org.apache.spark.sql.errors.ExecutionErrors" - ".failToParseDateTimeInNewParserError" in str(e.exception) - ) - self.assertTrue("Caused by: java.time.DateTimeException:" in str(e.exception)) - - def test_not_hitting_netty_header_limit(self): - with self.sql_conf({"spark.sql.pyspark.jvmStacktrace.enabled": True}): - with self.assertRaises(AnalysisException): - self.spark.sql("select " + "test" * 1).collect() - - def test_error_stack_trace(self): - with self.sql_conf({"spark.sql.connect.enrichError.enabled": False}): - with self.sql_conf({"spark.sql.pyspark.jvmStacktrace.enabled": True}): - with self.assertRaises(AnalysisException) as e: - self.spark.sql("select x").collect() - self.assertTrue("JVM stacktrace" in str(e.exception)) - self.assertIsNotNone(e.exception.getStackTrace()) - self.assertTrue( - "at org.apache.spark.sql.catalyst.analysis.CheckAnalysis" in str(e.exception) - ) - - with self.sql_conf({"spark.sql.pyspark.jvmStacktrace.enabled": False}): - with self.assertRaises(AnalysisException) as e: - self.spark.sql("select x").collect() - self.assertFalse("JVM stacktrace" in str(e.exception)) - self.assertIsNone(e.exception.getStackTrace()) - self.assertFalse( - "at org.apache.spark.sql.catalyst.analysis.CheckAnalysis" in str(e.exception) - ) - - # Create a new session with a different stack trace size. - self.spark.stop() - spark = ( - PySparkSession.builder.config(conf=self.conf()) - .config("spark.connect.grpc.maxMetadataSize", 128) - .remote("local[4]") - .getOrCreate() - ) - spark.conf.set("spark.sql.connect.enrichError.enabled", False) - spark.conf.set("spark.sql.pyspark.jvmStacktrace.enabled", True) - with self.assertRaises(AnalysisException) as e: - spark.sql("select x").collect() - self.assertTrue("JVM stacktrace" in str(e.exception)) - self.assertIsNotNone(e.exception.getStackTrace()) - self.assertFalse( - "at org.apache.spark.sql.catalyst.analysis.CheckAnalysis" in str(e.exception) - ) - spark.stop() - - def test_can_create_multiple_sessions_to_different_remotes(self): - self.spark.stop() - self.assertIsNotNone(self.spark._client) - # Creates a new remote session. - other = PySparkSession.builder.remote("sc://other.remote:114/").create() - self.assertNotEqual(self.spark, other) - - # Gets currently active session. - same = PySparkSession.builder.remote("sc://other.remote.host:114/").getOrCreate() - self.assertEqual(other, same) - same.release_session_on_close = False # avoid sending release to dummy connection - same.stop() - - # Make sure the environment is clean. - self.spark.stop() - with self.assertRaises(RuntimeError) as e: - PySparkSession.builder.create() - self.assertIn("Create a new SparkSession is only supported with SparkConnect.", str(e)) - - def test_get_message_parameters_without_enriched_error(self): - with self.sql_conf({"spark.sql.connect.enrichError.enabled": False}): - exception = None - try: - self.spark.sql("""SELECT a""") - except AnalysisException as e: - exception = e - - self.assertIsNotNone(exception) - self.assertEqual(exception.getMessageParameters(), {"objectName": "`a`"}) - - def test_custom_channel_builder(self): - # Access self.spark's DefaultChannelBuilder to reuse same endpoint - endpoint = self.spark._client._builder.endpoint - - class CustomChannelBuilder(ChannelBuilder): - def toChannel(self): - return self._insecure_channel(endpoint) - - session = RemoteSparkSession.builder.channelBuilder(CustomChannelBuilder()).create() - session.sql("select 1 + 1") - - -class SparkConnectSessionWithOptionsTest(unittest.TestCase): - def setUp(self) -> None: - self.spark = ( - PySparkSession.builder.config("string", "foo") - .config("integer", 1) - .config("boolean", False) - .appName(self.__class__.__name__) - .remote("local[4]") - .getOrCreate() - ) - - def tearDown(self): - self.spark.stop() - - def test_config(self): - # Config - self.assertEqual(self.spark.conf.get("string"), "foo") - self.assertEqual(self.spark.conf.get("boolean"), "false") - self.assertEqual(self.spark.conf.get("integer"), "1") - - -class TestError(grpc.RpcError, Exception): - def __init__(self, code: grpc.StatusCode): - self._code = code - - def code(self): - return self._code - - -class TestPolicy(RetryPolicy): - # Put a small value for initial backoff so that tests don't spend - # Time waiting - def __init__(self, initial_backoff=10, **kwargs): - super().__init__(initial_backoff=initial_backoff, **kwargs) - - def can_retry(self, exception: BaseException): - return isinstance(exception, TestError) - - -class TestPolicySpecificError(TestPolicy): - def __init__(self, specific_code: grpc.StatusCode, **kwargs): - super().__init__(**kwargs) - self.specific_code = specific_code - - def can_retry(self, exception: BaseException): - return exception.code() == self.specific_code - - -@unittest.skipIf(not should_test_connect, connect_requirement_message) -class RetryTests(unittest.TestCase): - def setUp(self) -> None: - self.call_wrap = defaultdict(int) - - def stub(self, retries, code): - self.call_wrap["attempts"] += 1 - if self.call_wrap["attempts"] < retries: - self.call_wrap["raised"] += 1 - raise TestError(code) - - def test_simple(self): - # Check that max_retries 1 is only one retry so two attempts. - for attempt in Retrying(TestPolicy(max_retries=1)): - with attempt: - self.stub(2, grpc.StatusCode.INTERNAL) - - self.assertEqual(2, self.call_wrap["attempts"]) - self.assertEqual(1, self.call_wrap["raised"]) - - def test_below_limit(self): - # Check that if we have less than 4 retries all is ok. - for attempt in Retrying(TestPolicy(max_retries=4)): - with attempt: - self.stub(2, grpc.StatusCode.INTERNAL) - - self.assertLess(self.call_wrap["attempts"], 4) - self.assertEqual(self.call_wrap["raised"], 1) - - def test_exceed_retries(self): - # Exceed the retries. - with self.assertRaises(RetriesExceeded): - for attempt in Retrying(TestPolicy(max_retries=2)): - with attempt: - self.stub(5, grpc.StatusCode.INTERNAL) - - self.assertLess(self.call_wrap["attempts"], 5) - self.assertEqual(self.call_wrap["raised"], 3) - - def test_throw_not_retriable_error(self): - with self.assertRaises(ValueError): - for attempt in Retrying(TestPolicy(max_retries=2)): - with attempt: - raise ValueError - - def test_specific_exception(self): - # Check that only specific exceptions are retried. - # Check that if we have less than 4 retries all is ok. - policy = TestPolicySpecificError(max_retries=4, specific_code=grpc.StatusCode.UNAVAILABLE) - - for attempt in Retrying(policy): - with attempt: - self.stub(2, grpc.StatusCode.UNAVAILABLE) - - self.assertLess(self.call_wrap["attempts"], 4) - self.assertEqual(self.call_wrap["raised"], 1) - - def test_specific_exception_exceed_retries(self): - # Exceed the retries. - policy = TestPolicySpecificError(max_retries=2, specific_code=grpc.StatusCode.UNAVAILABLE) - with self.assertRaises(RetriesExceeded): - for attempt in Retrying(policy): - with attempt: - self.stub(5, grpc.StatusCode.UNAVAILABLE) - - self.assertLess(self.call_wrap["attempts"], 4) - self.assertEqual(self.call_wrap["raised"], 3) - - def test_rejected_by_policy(self): - # Test that another error is always thrown. - policy = TestPolicySpecificError(max_retries=4, specific_code=grpc.StatusCode.UNAVAILABLE) - - with self.assertRaises(TestError): - for attempt in Retrying(policy): - with attempt: - self.stub(5, grpc.StatusCode.INTERNAL) - - self.assertEqual(self.call_wrap["attempts"], 1) - self.assertEqual(self.call_wrap["raised"], 1) - - def test_multiple_policies(self): - policy1 = TestPolicySpecificError(max_retries=2, specific_code=grpc.StatusCode.UNAVAILABLE) - policy2 = TestPolicySpecificError(max_retries=4, specific_code=grpc.StatusCode.INTERNAL) - - # Tolerate 2 UNAVAILABLE errors and 4 INTERNAL errors - - error_suply = iter([grpc.StatusCode.UNAVAILABLE] * 2 + [grpc.StatusCode.INTERNAL] * 4) - - for attempt in Retrying([policy1, policy2]): - with attempt: - error = next(error_suply, None) - if error: - raise TestError(error) - - self.assertEqual(next(error_suply, None), None) - - def test_multiple_policies_exceed(self): - policy1 = TestPolicySpecificError(max_retries=2, specific_code=grpc.StatusCode.INTERNAL) - policy2 = TestPolicySpecificError(max_retries=4, specific_code=grpc.StatusCode.INTERNAL) - - with self.assertRaises(RetriesExceeded): - for attempt in Retrying([policy1, policy2]): - with attempt: - self.stub(10, grpc.StatusCode.INTERNAL) - - self.assertEqual(self.call_wrap["attempts"], 7) - self.assertEqual(self.call_wrap["raised"], 7) - - -@unittest.skipIf(not should_test_connect, connect_requirement_message) -class ChannelBuilderTests(unittest.TestCase): - def test_invalid_connection_strings(self): - invalid = [ - "scc://host:12", - "http://host", - "sc:/host:1234/path", - "sc://host/path", - "sc://host/;parm1;param2", - ] - for i in invalid: - self.assertRaises(PySparkValueError, DefaultChannelBuilder, i) - - def test_sensible_defaults(self): - chan = DefaultChannelBuilder("sc://host") - self.assertFalse(chan.secure, "Default URL is not secure") - - chan = DefaultChannelBuilder("sc://host/;token=abcs") - self.assertTrue(chan.secure, "specifying a token must set the channel to secure") - self.assertRegex( - chan.userAgent, r"^_SPARK_CONNECT_PYTHON spark/[^ ]+ os/[^ ]+ python/[^ ]+$" - ) - chan = DefaultChannelBuilder("sc://host/;use_ssl=abcs") - self.assertFalse(chan.secure, "Garbage in, false out") - - def test_user_agent(self): - chan = DefaultChannelBuilder("sc://host/;user_agent=Agent123%20%2F3.4") - self.assertIn("Agent123 /3.4", chan.userAgent) - - def test_user_agent_len(self): - user_agent = "x" * 2049 - chan = DefaultChannelBuilder(f"sc://host/;user_agent={user_agent}") - with self.assertRaises(SparkConnectException) as err: - chan.userAgent - self.assertRegex(err.exception._message, "'user_agent' parameter should not exceed") - - user_agent = "%C3%A4" * 341 # "%C3%A4" -> "ä"; (341 * 6 = 2046) < 2048 - expected = "ä" * 341 - chan = DefaultChannelBuilder(f"sc://host/;user_agent={user_agent}") - self.assertIn(expected, chan.userAgent) - - def test_valid_channel_creation(self): - chan = DefaultChannelBuilder("sc://host").toChannel() - self.assertIsInstance(chan, grpc.Channel) - - # Sets up a channel without tokens because ssl is not used. - chan = DefaultChannelBuilder("sc://host/;use_ssl=true;token=abc").toChannel() - self.assertIsInstance(chan, grpc.Channel) - - chan = DefaultChannelBuilder("sc://host/;use_ssl=true").toChannel() - self.assertIsInstance(chan, grpc.Channel) - - def test_channel_properties(self): - chan = DefaultChannelBuilder( - "sc://host/;use_ssl=true;token=abc;user_agent=foo;param1=120%2021" - ) - self.assertEqual("host:15002", chan.endpoint) - self.assertIn("foo", chan.userAgent.split(" ")) - self.assertEqual(True, chan.secure) - self.assertEqual("120 21", chan.get("param1")) - - def test_metadata(self): - chan = DefaultChannelBuilder( - "sc://host/;use_ssl=true;token=abc;param1=120%2021;x-my-header=abcd" - ) - md = chan.metadata() - self.assertEqual([("param1", "120 21"), ("x-my-header", "abcd")], md) - - def test_metadata(self): - id = str(uuid.uuid4()) - chan = DefaultChannelBuilder(f"sc://host/;session_id={id}") - self.assertEqual(id, chan.session_id) - - chan = DefaultChannelBuilder( - f"sc://host/;session_id={id};user_agent=acbd;token=abcd;use_ssl=true" - ) - md = chan.metadata() - for kv in md: - self.assertNotIn( - kv[0], - [ - ChannelBuilder.PARAM_SESSION_ID, - ChannelBuilder.PARAM_TOKEN, - ChannelBuilder.PARAM_USER_ID, - ChannelBuilder.PARAM_USER_AGENT, - ChannelBuilder.PARAM_USE_SSL, - ], - "Metadata must not contain fixed params", - ) - - with self.assertRaises(ValueError) as ve: - chan = DefaultChannelBuilder("sc://host/;session_id=abcd") - SparkConnectClient(chan) - self.assertIn("Parameter value session_id must be a valid UUID format", str(ve.exception)) - - chan = DefaultChannelBuilder("sc://host/") - self.assertIsNone(chan.session_id) - - if __name__ == "__main__": from pyspark.sql.tests.connect.test_connect_basic import * # noqa: F401 diff --git a/python/pyspark/sql/tests/connect/test_connect_session.py b/python/pyspark/sql/tests/connect/test_connect_session.py new file mode 100644 index 0000000000000..9dc4d2ee9e497 --- /dev/null +++ b/python/pyspark/sql/tests/connect/test_connect_session.py @@ -0,0 +1,488 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import unittest +import uuid +from collections import defaultdict + + +from pyspark.errors import ( + PySparkException, + PySparkValueError, + RetriesExceeded, +) +from pyspark.sql import SparkSession as PySparkSession +from pyspark.sql.connect.client.retries import RetryPolicy + +from pyspark.testing.connectutils import ( + should_test_connect, + ReusedConnectTestCase, + connect_requirement_message, +) +from pyspark.errors.exceptions.connect import ( + AnalysisException, + SparkConnectException, + SparkUpgradeException, +) + +if should_test_connect: + import grpc + from pyspark.sql.connect.session import SparkSession as RemoteSparkSession + from pyspark.sql.connect.client import DefaultChannelBuilder, ChannelBuilder + from pyspark.sql.connect.client.core import Retrying, SparkConnectClient + + +class SparkConnectSessionTests(ReusedConnectTestCase): + def setUp(self) -> None: + self.spark = ( + PySparkSession.builder.config(conf=self.conf()) + .appName(self.__class__.__name__) + .remote("local[4]") + .getOrCreate() + ) + + def tearDown(self): + self.spark.stop() + + def _check_no_active_session_error(self, e: PySparkException): + self.check_error(exception=e, error_class="NO_ACTIVE_SESSION", message_parameters=dict()) + + def test_stop_session(self): + df = self.spark.sql("select 1 as a, 2 as b") + catalog = self.spark.catalog + self.spark.stop() + + # _execute_and_fetch + with self.assertRaises(SparkConnectException) as e: + self.spark.sql("select 1") + self._check_no_active_session_error(e.exception) + + with self.assertRaises(SparkConnectException) as e: + catalog.tableExists("table") + self._check_no_active_session_error(e.exception) + + # _execute + with self.assertRaises(SparkConnectException) as e: + self.spark.udf.register("test_func", lambda x: x + 1) + self._check_no_active_session_error(e.exception) + + # _analyze + with self.assertRaises(SparkConnectException) as e: + df._explain_string(extended=True) + self._check_no_active_session_error(e.exception) + + # Config + with self.assertRaises(SparkConnectException) as e: + self.spark.conf.get("some.conf") + self._check_no_active_session_error(e.exception) + + def test_error_enrichment_message(self): + with self.sql_conf( + { + "spark.sql.connect.enrichError.enabled": True, + "spark.sql.connect.serverStacktrace.enabled": False, + "spark.sql.pyspark.jvmStacktrace.enabled": False, + } + ): + name = "test" * 10000 + with self.assertRaises(AnalysisException) as e: + self.spark.sql("select " + name).collect() + self.assertTrue(name in e.exception._message) + self.assertFalse("JVM stacktrace" in e.exception._message) + + def test_error_enrichment_jvm_stacktrace(self): + with self.sql_conf( + { + "spark.sql.connect.enrichError.enabled": True, + "spark.sql.pyspark.jvmStacktrace.enabled": False, + } + ): + with self.sql_conf({"spark.sql.connect.serverStacktrace.enabled": False}): + with self.assertRaises(SparkUpgradeException) as e: + self.spark.sql( + """select from_json( + '{"d": "02-29"}', 'd date', map('dateFormat', 'MM-dd'))""" + ).collect() + self.assertFalse("JVM stacktrace" in e.exception._message) + + with self.sql_conf({"spark.sql.connect.serverStacktrace.enabled": True}): + with self.assertRaises(SparkUpgradeException) as e: + self.spark.sql( + """select from_json( + '{"d": "02-29"}', 'd date', map('dateFormat', 'MM-dd'))""" + ).collect() + self.assertTrue("JVM stacktrace" in str(e.exception)) + self.assertTrue("org.apache.spark.SparkUpgradeException" in str(e.exception)) + self.assertTrue( + "at org.apache.spark.sql.errors.ExecutionErrors" + ".failToParseDateTimeInNewParserError" in str(e.exception) + ) + self.assertTrue("Caused by: java.time.DateTimeException:" in str(e.exception)) + + def test_not_hitting_netty_header_limit(self): + with self.sql_conf({"spark.sql.pyspark.jvmStacktrace.enabled": True}): + with self.assertRaises(AnalysisException): + self.spark.sql("select " + "test" * 1).collect() + + def test_error_stack_trace(self): + with self.sql_conf({"spark.sql.connect.enrichError.enabled": False}): + with self.sql_conf({"spark.sql.pyspark.jvmStacktrace.enabled": True}): + with self.assertRaises(AnalysisException) as e: + self.spark.sql("select x").collect() + self.assertTrue("JVM stacktrace" in str(e.exception)) + self.assertIsNotNone(e.exception.getStackTrace()) + self.assertTrue( + "at org.apache.spark.sql.catalyst.analysis.CheckAnalysis" in str(e.exception) + ) + + with self.sql_conf({"spark.sql.pyspark.jvmStacktrace.enabled": False}): + with self.assertRaises(AnalysisException) as e: + self.spark.sql("select x").collect() + self.assertFalse("JVM stacktrace" in str(e.exception)) + self.assertIsNone(e.exception.getStackTrace()) + self.assertFalse( + "at org.apache.spark.sql.catalyst.analysis.CheckAnalysis" in str(e.exception) + ) + + # Create a new session with a different stack trace size. + self.spark.stop() + spark = ( + PySparkSession.builder.config(conf=self.conf()) + .config("spark.connect.grpc.maxMetadataSize", 128) + .remote("local[4]") + .getOrCreate() + ) + spark.conf.set("spark.sql.connect.enrichError.enabled", False) + spark.conf.set("spark.sql.pyspark.jvmStacktrace.enabled", True) + with self.assertRaises(AnalysisException) as e: + spark.sql("select x").collect() + self.assertTrue("JVM stacktrace" in str(e.exception)) + self.assertIsNotNone(e.exception.getStackTrace()) + self.assertFalse( + "at org.apache.spark.sql.catalyst.analysis.CheckAnalysis" in str(e.exception) + ) + spark.stop() + + def test_can_create_multiple_sessions_to_different_remotes(self): + self.spark.stop() + self.assertIsNotNone(self.spark._client) + # Creates a new remote session. + other = PySparkSession.builder.remote("sc://other.remote:114/").create() + self.assertNotEqual(self.spark, other) + + # Gets currently active session. + same = PySparkSession.builder.remote("sc://other.remote.host:114/").getOrCreate() + self.assertEqual(other, same) + same.release_session_on_close = False # avoid sending release to dummy connection + same.stop() + + # Make sure the environment is clean. + self.spark.stop() + with self.assertRaises(RuntimeError) as e: + PySparkSession.builder.create() + self.assertIn("Create a new SparkSession is only supported with SparkConnect.", str(e)) + + def test_get_message_parameters_without_enriched_error(self): + with self.sql_conf({"spark.sql.connect.enrichError.enabled": False}): + exception = None + try: + self.spark.sql("""SELECT a""") + except AnalysisException as e: + exception = e + + self.assertIsNotNone(exception) + self.assertEqual(exception.getMessageParameters(), {"objectName": "`a`"}) + + def test_custom_channel_builder(self): + # Access self.spark's DefaultChannelBuilder to reuse same endpoint + endpoint = self.spark._client._builder.endpoint + + class CustomChannelBuilder(ChannelBuilder): + def toChannel(self): + return self._insecure_channel(endpoint) + + session = RemoteSparkSession.builder.channelBuilder(CustomChannelBuilder()).create() + session.sql("select 1 + 1") + + +class SparkConnectSessionWithOptionsTest(unittest.TestCase): + def setUp(self) -> None: + self.spark = ( + PySparkSession.builder.config("string", "foo") + .config("integer", 1) + .config("boolean", False) + .appName(self.__class__.__name__) + .remote("local[4]") + .getOrCreate() + ) + + def tearDown(self): + self.spark.stop() + + def test_config(self): + # Config + self.assertEqual(self.spark.conf.get("string"), "foo") + self.assertEqual(self.spark.conf.get("boolean"), "false") + self.assertEqual(self.spark.conf.get("integer"), "1") + + +class TestError(grpc.RpcError, Exception): + def __init__(self, code: grpc.StatusCode): + self._code = code + + def code(self): + return self._code + + +class TestPolicy(RetryPolicy): + # Put a small value for initial backoff so that tests don't spend + # Time waiting + def __init__(self, initial_backoff=10, **kwargs): + super().__init__(initial_backoff=initial_backoff, **kwargs) + + def can_retry(self, exception: BaseException): + return isinstance(exception, TestError) + + +class TestPolicySpecificError(TestPolicy): + def __init__(self, specific_code: grpc.StatusCode, **kwargs): + super().__init__(**kwargs) + self.specific_code = specific_code + + def can_retry(self, exception: BaseException): + return exception.code() == self.specific_code + + +@unittest.skipIf(not should_test_connect, connect_requirement_message) +class RetryTests(unittest.TestCase): + def setUp(self) -> None: + self.call_wrap = defaultdict(int) + + def stub(self, retries, code): + self.call_wrap["attempts"] += 1 + if self.call_wrap["attempts"] < retries: + self.call_wrap["raised"] += 1 + raise TestError(code) + + def test_simple(self): + # Check that max_retries 1 is only one retry so two attempts. + for attempt in Retrying(TestPolicy(max_retries=1)): + with attempt: + self.stub(2, grpc.StatusCode.INTERNAL) + + self.assertEqual(2, self.call_wrap["attempts"]) + self.assertEqual(1, self.call_wrap["raised"]) + + def test_below_limit(self): + # Check that if we have less than 4 retries all is ok. + for attempt in Retrying(TestPolicy(max_retries=4)): + with attempt: + self.stub(2, grpc.StatusCode.INTERNAL) + + self.assertLess(self.call_wrap["attempts"], 4) + self.assertEqual(self.call_wrap["raised"], 1) + + def test_exceed_retries(self): + # Exceed the retries. + with self.assertRaises(RetriesExceeded): + for attempt in Retrying(TestPolicy(max_retries=2)): + with attempt: + self.stub(5, grpc.StatusCode.INTERNAL) + + self.assertLess(self.call_wrap["attempts"], 5) + self.assertEqual(self.call_wrap["raised"], 3) + + def test_throw_not_retriable_error(self): + with self.assertRaises(ValueError): + for attempt in Retrying(TestPolicy(max_retries=2)): + with attempt: + raise ValueError + + def test_specific_exception(self): + # Check that only specific exceptions are retried. + # Check that if we have less than 4 retries all is ok. + policy = TestPolicySpecificError(max_retries=4, specific_code=grpc.StatusCode.UNAVAILABLE) + + for attempt in Retrying(policy): + with attempt: + self.stub(2, grpc.StatusCode.UNAVAILABLE) + + self.assertLess(self.call_wrap["attempts"], 4) + self.assertEqual(self.call_wrap["raised"], 1) + + def test_specific_exception_exceed_retries(self): + # Exceed the retries. + policy = TestPolicySpecificError(max_retries=2, specific_code=grpc.StatusCode.UNAVAILABLE) + with self.assertRaises(RetriesExceeded): + for attempt in Retrying(policy): + with attempt: + self.stub(5, grpc.StatusCode.UNAVAILABLE) + + self.assertLess(self.call_wrap["attempts"], 4) + self.assertEqual(self.call_wrap["raised"], 3) + + def test_rejected_by_policy(self): + # Test that another error is always thrown. + policy = TestPolicySpecificError(max_retries=4, specific_code=grpc.StatusCode.UNAVAILABLE) + + with self.assertRaises(TestError): + for attempt in Retrying(policy): + with attempt: + self.stub(5, grpc.StatusCode.INTERNAL) + + self.assertEqual(self.call_wrap["attempts"], 1) + self.assertEqual(self.call_wrap["raised"], 1) + + def test_multiple_policies(self): + policy1 = TestPolicySpecificError(max_retries=2, specific_code=grpc.StatusCode.UNAVAILABLE) + policy2 = TestPolicySpecificError(max_retries=4, specific_code=grpc.StatusCode.INTERNAL) + + # Tolerate 2 UNAVAILABLE errors and 4 INTERNAL errors + + error_suply = iter([grpc.StatusCode.UNAVAILABLE] * 2 + [grpc.StatusCode.INTERNAL] * 4) + + for attempt in Retrying([policy1, policy2]): + with attempt: + error = next(error_suply, None) + if error: + raise TestError(error) + + self.assertEqual(next(error_suply, None), None) + + def test_multiple_policies_exceed(self): + policy1 = TestPolicySpecificError(max_retries=2, specific_code=grpc.StatusCode.INTERNAL) + policy2 = TestPolicySpecificError(max_retries=4, specific_code=grpc.StatusCode.INTERNAL) + + with self.assertRaises(RetriesExceeded): + for attempt in Retrying([policy1, policy2]): + with attempt: + self.stub(10, grpc.StatusCode.INTERNAL) + + self.assertEqual(self.call_wrap["attempts"], 7) + self.assertEqual(self.call_wrap["raised"], 7) + + +@unittest.skipIf(not should_test_connect, connect_requirement_message) +class ChannelBuilderTests(unittest.TestCase): + def test_invalid_connection_strings(self): + invalid = [ + "scc://host:12", + "http://host", + "sc:/host:1234/path", + "sc://host/path", + "sc://host/;parm1;param2", + ] + for i in invalid: + self.assertRaises(PySparkValueError, DefaultChannelBuilder, i) + + def test_sensible_defaults(self): + chan = DefaultChannelBuilder("sc://host") + self.assertFalse(chan.secure, "Default URL is not secure") + + chan = DefaultChannelBuilder("sc://host/;token=abcs") + self.assertTrue(chan.secure, "specifying a token must set the channel to secure") + self.assertRegex( + chan.userAgent, r"^_SPARK_CONNECT_PYTHON spark/[^ ]+ os/[^ ]+ python/[^ ]+$" + ) + chan = DefaultChannelBuilder("sc://host/;use_ssl=abcs") + self.assertFalse(chan.secure, "Garbage in, false out") + + def test_user_agent(self): + chan = DefaultChannelBuilder("sc://host/;user_agent=Agent123%20%2F3.4") + self.assertIn("Agent123 /3.4", chan.userAgent) + + def test_user_agent_len(self): + user_agent = "x" * 2049 + chan = DefaultChannelBuilder(f"sc://host/;user_agent={user_agent}") + with self.assertRaises(SparkConnectException) as err: + chan.userAgent + self.assertRegex(err.exception._message, "'user_agent' parameter should not exceed") + + user_agent = "%C3%A4" * 341 # "%C3%A4" -> "ä"; (341 * 6 = 2046) < 2048 + expected = "ä" * 341 + chan = DefaultChannelBuilder(f"sc://host/;user_agent={user_agent}") + self.assertIn(expected, chan.userAgent) + + def test_valid_channel_creation(self): + chan = DefaultChannelBuilder("sc://host").toChannel() + self.assertIsInstance(chan, grpc.Channel) + + # Sets up a channel without tokens because ssl is not used. + chan = DefaultChannelBuilder("sc://host/;use_ssl=true;token=abc").toChannel() + self.assertIsInstance(chan, grpc.Channel) + + chan = DefaultChannelBuilder("sc://host/;use_ssl=true").toChannel() + self.assertIsInstance(chan, grpc.Channel) + + def test_channel_properties(self): + chan = DefaultChannelBuilder( + "sc://host/;use_ssl=true;token=abc;user_agent=foo;param1=120%2021" + ) + self.assertEqual("host:15002", chan.endpoint) + self.assertIn("foo", chan.userAgent.split(" ")) + self.assertEqual(True, chan.secure) + self.assertEqual("120 21", chan.get("param1")) + + def test_metadata(self): + chan = DefaultChannelBuilder( + "sc://host/;use_ssl=true;token=abc;param1=120%2021;x-my-header=abcd" + ) + md = chan.metadata() + self.assertEqual([("param1", "120 21"), ("x-my-header", "abcd")], md) + + def test_metadata(self): + id = str(uuid.uuid4()) + chan = DefaultChannelBuilder(f"sc://host/;session_id={id}") + self.assertEqual(id, chan.session_id) + + chan = DefaultChannelBuilder( + f"sc://host/;session_id={id};user_agent=acbd;token=abcd;use_ssl=true" + ) + md = chan.metadata() + for kv in md: + self.assertNotIn( + kv[0], + [ + ChannelBuilder.PARAM_SESSION_ID, + ChannelBuilder.PARAM_TOKEN, + ChannelBuilder.PARAM_USER_ID, + ChannelBuilder.PARAM_USER_AGENT, + ChannelBuilder.PARAM_USE_SSL, + ], + "Metadata must not contain fixed params", + ) + + with self.assertRaises(ValueError) as ve: + chan = DefaultChannelBuilder("sc://host/;session_id=abcd") + SparkConnectClient(chan) + self.assertIn("Parameter value session_id must be a valid UUID format", str(ve.exception)) + + chan = DefaultChannelBuilder("sc://host/") + self.assertIsNone(chan.session_id) + + +if __name__ == "__main__": + from pyspark.sql.tests.connect.test_connect_session import * # noqa: F401 + + try: + import xmlrunner + + testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) + except ImportError: + testRunner = None + + unittest.main(testRunner=testRunner, verbosity=2)