From 00b2bf98568004347fdd30acf8c5150514d7042c Mon Sep 17 00:00:00 2001 From: Weichen Xu Date: Wed, 7 Feb 2024 16:52:35 +0800 Subject: [PATCH] chunk-read-impl Signed-off-by: Weichen Xu --- .../scala/org/apache/spark/SparkContext.scala | 20 +++ .../scala/org/apache/spark/SparkEnv.scala | 4 + .../api/python/CachedArrowBatchServer.scala | 147 ++++++++++++++++ .../api/python/PythonWorkerFactory.scala | 18 ++ .../org/apache/spark/executor/Executor.scala | 22 +++ .../apache/spark/internal/config/Python.scala | 8 + .../org/apache/spark/storage/BlockId.scala | 9 + dev/sparktestsupport/modules.py | 1 + python/pyspark/context.py | 9 + python/pyspark/serializers.py | 5 + python/pyspark/sql/chunk.py | 159 +++++++++++++++++ .../pyspark/sql/tests/test_chunk_read_api.py | 124 ++++++++++++++ .../spark/sql/api/python/ChunkReadUtils.scala | 161 ++++++++++++++++++ .../sql/execution/arrow/ArrowConverters.scala | 4 + 14 files changed, 691 insertions(+) create mode 100644 core/src/main/scala/org/apache/spark/api/python/CachedArrowBatchServer.scala create mode 100644 python/pyspark/sql/chunk.py create mode 100644 python/pyspark/sql/tests/test_chunk_read_api.py create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/api/python/ChunkReadUtils.scala diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index da37fa83254bc..2d3e5e4bbac4c 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -42,6 +42,7 @@ import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat, Job => NewHad import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFormat} import org.apache.spark.annotation.{DeveloperApi, Experimental} +import org.apache.spark.api.python.CachedArrowBatchServer import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.{LocalSparkCluster, SparkHadoopUtil} import org.apache.spark.errors.SparkCoreErrors @@ -379,6 +380,14 @@ class SparkContext(config: SparkConf) extends Logging { override protected def initialValue(): Properties = new Properties() } + private[spark] def cachedArrowBatchServerPort: Option[Int] = { + _cachedArrowBatchServer.map(_.serverSocket.getLocalPort) + } + + private[spark] def cachedArrowBatchServerSecret: Option[String] = { + _cachedArrowBatchServer.map(_.authHelper.secret) + } + /* ------------------------------------------------------------------------------------- * | Initialization. This code initializes the context in a manner that is exception-safe. | | All internal fields holding state are initialized here, and any error prompts the | @@ -401,6 +410,8 @@ class SparkContext(config: SparkConf) extends Logging { } } + private var _cachedArrowBatchServer: Option[CachedArrowBatchServer] = None + try { _conf = config.clone() _conf.get(SPARK_LOG_LEVEL).foreach { level => @@ -486,6 +497,12 @@ class SparkContext(config: SparkConf) extends Logging { _env = createSparkEnv(_conf, isLocal, listenerBus) SparkEnv.set(_env) + if (SparkEnv.get.conf.get(Python.PYTHON_DATAFRAME_CHUNK_READ_ENABLED)) { + val server = new CachedArrowBatchServer(SparkEnv.get.conf, SparkEnv.get.blockManager) + server.start() + _cachedArrowBatchServer = Some(server) + } + // If running the REPL, register the repl's output dir with the file server. 
_conf.getOption("spark.repl.class.outputDir").foreach { path => val replUri = _env.rpcEnv.fileServer.addDirectory("/classes", new File(path)) @@ -2333,6 +2350,9 @@ class SparkContext(config: SparkConf) extends Logging { Utils.tryLogNonFatalError { _progressBar.foreach(_.stop()) } + Utils.tryLogNonFatalError { + _cachedArrowBatchServer.foreach(_.stop()) + } _taskScheduler = null // TODO: Cache.stop()? if (_env != null) { diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 94a4debd0263c..ef7901a555cd7 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -99,6 +99,10 @@ class SparkEnv ( private[spark] var executorBackend: Option[ExecutorBackend] = None + private[spark] var cachedArrowBatchServerPort: Option[Int] = None + + private[spark] var cachedArrowBatchServerSecret: Option[String] = None + private[spark] def stop(): Unit = { if (!isStopped) { diff --git a/core/src/main/scala/org/apache/spark/api/python/CachedArrowBatchServer.scala b/core/src/main/scala/org/apache/spark/api/python/CachedArrowBatchServer.scala new file mode 100644 index 0000000000000..2d85cf50214bc --- /dev/null +++ b/core/src/main/scala/org/apache/spark/api/python/CachedArrowBatchServer.scala @@ -0,0 +1,147 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.python + +import java.io.{BufferedOutputStream, DataInputStream, DataOutputStream} +import java.net.{InetAddress, ServerSocket, Socket, SocketException} +import java.nio.charset.StandardCharsets.UTF_8 + +import org.apache.spark.SparkConf +import org.apache.spark.internal.Logging +import org.apache.spark.network.util.JavaUtils +import org.apache.spark.security.SocketAuthHelper +import org.apache.spark.storage.{ArrowBatchBlockId, BlockId, BlockManager} + + +/** + * A helper class to run cached arrow batch server. + * Cached arrow batch server is for serving chunk data, when user calls + * `pyspark.sql.chunk.readChunk` API, it creates connection to this server, + * then sends chunk_id to the server, then the server responses chunk binary data + * to client. + * The server queries chuck data using chunk_id from block manager, + * for chunk data storage logic, please refer to + * `PersistDataFrameAsArrowBatchChunksPartitionEvaluator` class. 
+ */ +class CachedArrowBatchServer( + val sparkConf: SparkConf, + val blockManager: BlockManager +) extends Logging { + + val authHelper = new SocketAuthHelper(sparkConf) + + val serverSocket = new ServerSocket( + 0, 1, InetAddress.getLoopbackAddress() + ) + + protected def readUtf8(s: Socket): String = { + val din = new DataInputStream(s.getInputStream()) + val len = din.readInt() + val bytes = new Array[Byte](len) + din.readFully(bytes) + new String(bytes, UTF_8) + } + + protected def writeUtf8(str: String, s: Socket): Unit = { + val bytes = str.getBytes(UTF_8) + val dout = new DataOutputStream(s.getOutputStream()) + dout.writeInt(bytes.length) + dout.write(bytes, 0, bytes.length) + dout.flush() + } + + private def handleConnection(sock: Socket): Unit = { + val blockId = BlockId(readUtf8(sock)) + assert(blockId.isInstanceOf[ArrowBatchBlockId]) + + var errMessage = "ok" + var blockDataOpt: Option[Array[Byte]] = None + + try { + val blockResult = blockManager.get[Array[Byte]](blockId) + if (blockResult.isDefined) { + blockDataOpt = Some(blockResult.get.data.next().asInstanceOf[Array[Byte]]) + } else { + errMessage = s"The chunk $blockId data cache does not exist or has been removed" + } + } catch { + case e: Exception => + errMessage = e.getMessage + } + + writeUtf8(errMessage, sock) + + if (blockDataOpt.isDefined) { + val out = new BufferedOutputStream(sock.getOutputStream()) + out.write(blockDataOpt.get) + out.flush() + } + } + + def createConnectionThread(sock: Socket, threadName: String): Unit = { + new Thread(threadName) { + setDaemon(true) + + override def run(): Unit = { + try { + authHelper.authClient(sock) + handleConnection(sock) + } finally { + JavaUtils.closeQuietly(sock) + } + } + }.start() + } + + def start(): Unit = { + logTrace("Creating listening socket") + + new Thread("CachedArrowBatchServer-listener") { + setDaemon(true) + + override def run(): Unit = { + var sock: Socket = null + + var connectionCount = 0 + try { + while (true) { + sock = serverSocket.accept() + connectionCount += 1 + createConnectionThread( + sock, s"CachedArrowBatchServer-connection-$connectionCount" + ) + } + } catch { + case e: SocketException => + // if serverSocket is closed, it means the server is shut down. + // swallow the exception. + if (!serverSocket.isClosed) { + throw e + } + } finally { + logTrace("Closing server") + JavaUtils.closeQuietly(serverSocket) + } + } + }.start() + } + + def stop(): Unit = { + JavaUtils.closeQuietly(serverSocket) + } +} diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala index e385bc685bfed..bf98ad7609276 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala @@ -107,6 +107,22 @@ private[spark] class PythonWorkerFactory( } } + private def appendCachedArrowBatchServerEnvVars( + workerEnv: java.util.Map[String, String] + ): Unit = { + val env = SparkEnv.get + if (env.conf.get(PYTHON_DATAFRAME_CHUNK_READ_ENABLED)) { + workerEnv.put( + "PYSPARK_EXECUTOR_CACHED_ARROW_BATCH_SERVER_PORT", + env.cachedArrowBatchServerPort.get.toString + ) + workerEnv.put( + "PYSPARK_EXECUTOR_CACHED_ARROW_BATCH_SERVER_SECRET", + env.cachedArrowBatchServerSecret.get + ) + } + } + /** * Connect to a worker launched through pyspark/daemon.py (by default), which forks python * processes itself to avoid the high cost of forking from Java. 
This currently only works @@ -170,6 +186,7 @@ private[spark] class PythonWorkerFactory( } val workerEnv = pb.environment() workerEnv.putAll(envVars.asJava) + appendCachedArrowBatchServerEnvVars(workerEnv) workerEnv.put("PYTHONPATH", pythonPath) // This is equivalent to setting the -u flag; we use it because ipython doesn't support -u: workerEnv.put("PYTHONUNBUFFERED", "YES") @@ -247,6 +264,7 @@ private[spark] class PythonWorkerFactory( } val workerEnv = pb.environment() workerEnv.putAll(envVars.asJava) + appendCachedArrowBatchServerEnvVars(workerEnv) workerEnv.put("PYTHONPATH", pythonPath) workerEnv.put("PYTHON_WORKER_FACTORY_SECRET", authHelper.secret) if (Utils.preferIPv6) { diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index f2a65aab1ba48..5df14c62601c8 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -40,6 +40,7 @@ import com.google.common.util.concurrent.ThreadFactoryBuilder import org.slf4j.MDC import org.apache.spark._ +import org.apache.spark.api.python.CachedArrowBatchServer import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ @@ -347,6 +348,21 @@ private[spark] class Executor( metricsPoller.start() + val cachedArrowBatchServer: Option[CachedArrowBatchServer] = if ( + SparkEnv.get.conf.get(Python.PYTHON_DATAFRAME_CHUNK_READ_ENABLED) + ) { + val server = new CachedArrowBatchServer(env.conf, env.blockManager) + + server.start() + + env.cachedArrowBatchServerPort = Some(server.serverSocket.getLocalPort) + env.cachedArrowBatchServerSecret = Some(server.authHelper.secret) + + Some(server) + } else { + None + } + private[executor] def numRunningTasks: Int = runningTasks.size() /** @@ -418,6 +434,12 @@ private[spark] class Executor( if (!executorShutdown.getAndSet(true)) { ShutdownHookManager.removeShutdownHook(stopHookReference) env.metricsSystem.report() + try { + cachedArrowBatchServer.foreach(_.stop()) + } catch { + case NonFatal(e) => + logWarning("Unable to stop arrow batch server", e) + } try { if (metricsPoller != null) { metricsPoller.stop() diff --git a/core/src/main/scala/org/apache/spark/internal/config/Python.scala b/core/src/main/scala/org/apache/spark/internal/config/Python.scala index 4f71e7a9e9be9..11bf204f0c836 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/Python.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/Python.scala @@ -69,4 +69,12 @@ private[spark] object Python { .version("3.2.0") .booleanConf .createWithDefault(false) + + val PYTHON_DATAFRAME_CHUNK_READ_ENABLED = + ConfigBuilder("spark.python.dataFrameChunkRead.enabled") + .doc("When true, driver and executors launch local cached arrow batch servers for serving " + + "PySpark DataFrame 'pyspark.sql.chunk.readChunk' API requests.") + .version("4.0.0") + .booleanConf + .createWithDefault(false) } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockId.scala b/core/src/main/scala/org/apache/spark/storage/BlockId.scala index 585d9a886b473..bf7e108b8e66d 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockId.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockId.scala @@ -185,6 +185,11 @@ private[spark] case class TestBlockId(id: String) extends BlockId { override def name: String = "test_" + id } +private[spark] case class ArrowBatchBlockId(id: UUID) extends BlockId { override def name:
String = "arrow_batch_" + id +} + + @DeveloperApi class UnrecognizedBlockId(name: String) extends SparkException(s"Failed to parse $name into a block ID") @@ -215,6 +220,7 @@ object BlockId { val STREAM = "input-([0-9]+)-([0-9]+)".r val TEMP_LOCAL = "temp_local_([-A-Fa-f0-9]+)".r val TEMP_SHUFFLE = "temp_shuffle_([-A-Fa-f0-9]+)".r + val ARROW_BATCH = "arrow_batch_([-A-Fa-f0-9]+)".r val TEST = "test_(.*)".r def apply(name: String): BlockId = name match { @@ -254,8 +260,11 @@ object BlockId { TempLocalBlockId(UUID.fromString(uuid)) case TEMP_SHUFFLE(uuid) => TempShuffleBlockId(UUID.fromString(uuid)) + case ARROW_BATCH(uuid) => + ArrowBatchBlockId(UUID.fromString(uuid)) case TEST(value) => TestBlockId(value) + case _ => throw SparkCoreErrors.unrecognizedBlockIdError(name) } } diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index a97e6afdc356d..9b83599593d2c 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -516,6 +516,7 @@ def __hash__(self): "pyspark.sql.tests.pandas.test_pandas_udf_typehints_with_future_annotations", "pyspark.sql.tests.pandas.test_pandas_udf_window", "pyspark.sql.tests.pandas.test_converter", + "pyspark.sql.tests.test_chunk_read_api", "pyspark.sql.tests.test_pandas_sqlmetrics", "pyspark.sql.tests.test_python_datasource", "pyspark.sql.tests.test_readwriter", diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 19d3608c3825b..d429896a68f3a 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -294,9 +294,18 @@ def _do_init( # Create the Java SparkContext through Py4J self._jsc = jsc or self._initialize_context(self._conf._jconf) + # Reset the SparkConf to the one actually used by the SparkContext in JVM. self._conf = SparkConf(_jconf=self._jsc.sc().conf()) + if self.getConf().get("spark.python.dataFrameChunkRead.enabled", "false").lower() == "true": + os.environ["PYSPARK_EXECUTOR_CACHED_ARROW_BATCH_SERVER_PORT"] = str( + self._jsc.sc().cachedArrowBatchServerPort().get() + ) + os.environ["PYSPARK_EXECUTOR_CACHED_ARROW_BATCH_SERVER_SECRET"] = ( + self._jsc.sc().cachedArrowBatchServerSecret().get() + ) + # Create a single Accumulator in Java that we'll send all our updates through; # they will be passed back to us through a TCP server assert self._gateway is not None diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index d269d55653cfb..a4db38b9955a1 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -613,6 +613,11 @@ def write_with_length(obj, stream): stream.write(obj) +def read_with_length(stream): + length = read_int(stream) + return stream.read(length) + + class ChunkedStream: """ diff --git a/python/pyspark/sql/chunk.py b/python/pyspark/sql/chunk.py new file mode 100644 index 0000000000000..f87e8202888c2 --- /dev/null +++ b/python/pyspark/sql/chunk.py @@ -0,0 +1,159 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import os +from collections import namedtuple + +import pyarrow as pa + +from pyspark.rdd import _create_local_socket +from pyspark.sql import DataFrame +from pyspark.sql import SparkSession +from pyspark.serializers import read_with_length, write_with_length +from pyspark.sql.pandas.serializers import ArrowStreamSerializer +from pyspark.sql.pandas.utils import require_minimum_pyarrow_version +from pyspark.errors import PySparkRuntimeError + + +ChunkMeta = namedtuple("ChunkMeta", ["id", "row_count", "byte_count"]) + +require_minimum_pyarrow_version() + + +def persistDataFrameAsChunks(dataframe: DataFrame, max_records_per_chunk: int) -> list[ChunkMeta]: + """Persist and materialize the Spark DataFrame as chunks; each chunk is an arrow batch. + It first tries to persist the data in Spark worker memory; if memory is not sufficient, + it falls back to spilling the data to Spark worker local disk. + Returns a list of ChunkMeta tuples (id, row_count, byte_count). + This function can only be called from the Spark driver process. + + .. versionadded:: 4.0.0 + + Parameters + ---------- + dataframe : DataFrame + the Spark DataFrame to be persisted as chunks + max_records_per_chunk : int + the maximum number of records per chunk + + Notes + ----- + This API is a developer API. + """ + spark = dataframe.sparkSession + if spark is None: + raise PySparkRuntimeError("An active Spark session is required.") + + sc = spark.sparkContext + if sc.getConf().get("spark.python.dataFrameChunkRead.enabled", "false").lower() != "true": + raise PySparkRuntimeError( + "In order to use the 'persistDataFrameAsChunks' API, you must set the Spark " + "cluster config 'spark.python.dataFrameChunkRead.enabled' to 'true'." + ) + + python_api = sc._jvm.org.apache.spark.sql.api.python # type: ignore[union-attr] + + chunk_meta_list = list( + python_api.ChunkReadUtils.persistDataFrameAsArrowBatchChunks( + dataframe._jdf, max_records_per_chunk + ) + ) + return [ + ChunkMeta(java_chunk_meta.id(), java_chunk_meta.rowCount(), java_chunk_meta.byteCount()) + for java_chunk_meta in chunk_meta_list + ] + + +def unpersistChunks(chunk_ids: list[str]) -> None: + """Unpersist chunks by chunk ids. + This function can only be called from the Spark driver process. + + .. versionadded:: 4.0.0 + + Parameters + ---------- + chunk_ids : list[str] + A list of chunk ids + + Notes + ----- + This API is a developer API. + """ + sc = SparkSession.getActiveSession().sparkContext # type: ignore[union-attr] + python_api = sc._jvm.org.apache.spark.sql.api.python # type: ignore[union-attr] + python_api.ChunkReadUtils.unpersistChunks(chunk_ids) + + +def readChunk(chunk_id: str) -> pa.Table: + """Read a chunk by id and return it as an arrow table. + This function can be called from the Spark driver, a Spark Python UDF worker, + a descendant process of the Spark driver, or a descendant process of a Spark + Python UDF worker. + + .. versionadded:: 4.0.0 + + Parameters + ---------- + chunk_id : str + the chunk id string + + Notes + ----- + This API is a developer API.
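+
+    Examples
+    --------
+    A minimal usage sketch (assumes the cluster was started with
+    'spark.python.dataFrameChunkRead.enabled' set to 'true' and that ``df`` is an
+    existing DataFrame):
+
+    >>> from pyspark.sql.chunk import persistDataFrameAsChunks, readChunk, unpersistChunks
+    >>> chunks = persistDataFrameAsChunks(df, max_records_per_chunk=1000)  # doctest: +SKIP
+    >>> table = readChunk(chunks[0].id)  # doctest: +SKIP
+    >>> unpersistChunks([chunk.id for chunk in chunks])  # doctest: +SKIP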
+ """ + + if "PYSPARK_EXECUTOR_CACHED_ARROW_BATCH_SERVER_PORT" not in os.environ: + raise PySparkRuntimeError( + "In order to use dataframe chunk read API, you must set spark " + "cluster config 'spark.python.dataFrameChunkRead.enabled' to 'true'," + "and you must call 'readChunk' API in pyspark driver, pyspark UDF," + "descendant process of pyspark driver, or descendant process of pyspark " + "UDF worker." + ) + + port = int(os.environ["PYSPARK_EXECUTOR_CACHED_ARROW_BATCH_SERVER_PORT"]) + auth_secret = os.environ["PYSPARK_EXECUTOR_CACHED_ARROW_BATCH_SERVER_SECRET"] + + sockfile = _create_local_socket((port, auth_secret)) + + try: + write_with_length(chunk_id.encode("utf-8"), sockfile) + sockfile.flush() + err_message = read_with_length(sockfile).decode("utf-8") + + if err_message != "ok": + raise PySparkRuntimeError(f"Read chunk '{chunk_id}' failed (error: {err_message}).") + + arrow_serializer = ArrowStreamSerializer() + + batch_stream = arrow_serializer.load_stream(sockfile) + + arrow_batch = list(batch_stream)[0] + + arrow_batch = pa.RecordBatch.from_arrays( + [ + # This call actually reallocates the array + pa.concat_arrays([array]) + for array in arrow_batch + ], + schema=arrow_batch.schema, + ) + + arrow_table = pa.Table.from_batches([arrow_batch]) + + return arrow_table + finally: + sockfile.close() diff --git a/python/pyspark/sql/tests/test_chunk_read_api.py b/python/pyspark/sql/tests/test_chunk_read_api.py new file mode 100644 index 0000000000000..5603a70bd9827 --- /dev/null +++ b/python/pyspark/sql/tests/test_chunk_read_api.py @@ -0,0 +1,124 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import os +import pickle +import sys +import subprocess +import tempfile +import time +import unittest +from pyspark.errors import PySparkRuntimeError +from pyspark.sql import SparkSession +from pyspark.sql.chunk import persistDataFrameAsChunks, readChunk, unpersistChunks + + +class ChunkReadApiTests(unittest.TestCase): + def setUp(self): + self._old_sys_path = list(sys.path) + class_name = self.__class__.__name__ + self.spark = ( + SparkSession.builder.master("local-cluster[2, 1, 1024]") + .appName(class_name) + .config("spark.python.dataFrameChunkRead.enabled", "true") + .config("spark.task.maxFailures", "1") + .getOrCreate() + ) + self.sc = self.spark.sparkContext + + self.test_df = self.spark.range(0, 16, 1, 2) + self.chunks = persistDataFrameAsChunks(self.test_df, 3) + self.chunk_ids = [chunk.id for chunk in self.chunks] + self.expected_chunk_data_list = [ + [0, 1, 2], + [3, 4, 5], + [6, 7], + [8, 9, 10], + [11, 12, 13], + [14, 15], + ] + assert len(self.chunks) == len(self.expected_chunk_data_list) + + self.child_proc_test_code = """ +import sys +from pyspark.sql.chunk import readChunk +chunk_ids = sys.argv[1].split(",") +for chunk_id in chunk_ids: + chunk_pd = readChunk(chunk_id).to_pandas() + chunk_pd.to_pickle(f"{chunk_id}.pkl") +""" + + def tearDown(self): + self.spark.stop() + sys.path = self._old_sys_path + + def test_readChunk_in_driver(self): + for i, chunk in enumerate(self.chunks): + chunk_data = list(readChunk(chunk.id).to_pandas().id) + self.assertEqual(chunk_data, self.expected_chunk_data_list[i]) + + def test_readChunk_in_executor(self): + def mapper(chunk_id): + return list(readChunk(chunk_id).to_pandas().id) + + chunk_data_list = self.sc.parallelize(self.chunk_ids, 4).map(mapper).collect() + self.assertEqual(chunk_data_list, self.expected_chunk_data_list) + + def _assert_saved_chunk_data_correct(self, dir_path): + for chunk_id, expected_chunk_data in zip(self.chunk_ids, self.expected_chunk_data_list): + with open(os.path.join(dir_path, f"{chunk_id}.pkl"), "rb") as f: + pdf = pickle.load(f) + chunk_data = list(pdf.id) + self.assertEqual(chunk_data, expected_chunk_data) + + def test_readChunk_in_driver_child_proc(self): + with tempfile.TemporaryDirectory() as tmp_dir: + with open(os.path.join(tmp_dir, "readChunk_and_save.py"), "w") as f: + f.write(self.child_proc_test_code) + + subprocess.check_call( + ["python", "./readChunk_and_save.py", ",".join(self.chunk_ids)], + cwd=tmp_dir, + ) + + self._assert_saved_chunk_data_correct(tmp_dir) + + def test_readChunk_in_udf_worker_child_proc(self): + with tempfile.TemporaryDirectory() as tmp_dir: + with open(os.path.join(tmp_dir, "readChunk_and_save.py"), "w") as f: + f.write(self.child_proc_test_code) + + def mapper(chunk_id): + subprocess.check_call( + ["python", "./readChunk_and_save.py", chunk_id], + cwd=tmp_dir, + ) + return True + + self.sc.parallelize(self.chunk_ids, 4).map(mapper).collect() + self._assert_saved_chunk_data_correct(tmp_dir) + + def test_unpersist_chunk(self): + df = self.spark.range(16) + chunks = persistDataFrameAsChunks(df, 16) + unpersistChunks([chunks[0].id]) + time.sleep(5) # ensure chunk removal completes + with self.assertRaisesRegex( + PySparkRuntimeError, + "cache does not exist or has been removed", + ): + readChunk(chunks[0].id) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/python/ChunkReadUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/python/ChunkReadUtils.scala new file mode 100644 index 0000000000000..093c68e291653 --- /dev/null +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/api/python/ChunkReadUtils.scala @@ -0,0 +1,161 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.python + +import java.io.ByteArrayOutputStream + +import scala.collection.mutable.ArrayBuffer +import scala.jdk.CollectionConverters._ + +import org.apache.spark.{PartitionEvaluator, PartitionEvaluatorFactory, SparkEnv, TaskContext} +import org.apache.spark.internal.Logging +import org.apache.spark.sql.{DataFrame, SparkSession} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.execution.arrow.{ArrowBatchStreamWriter, ArrowConverters} +import org.apache.spark.sql.types.StructType +import org.apache.spark.storage.{ArrowBatchBlockId, BlockId, StorageLevel} + + +case class ChunkMeta( + id: String, + rowCount: Long, + byteCount: Long +) + +/** + * A partition evaluator class to: + * 1. convert Spark DataFrame partition rows into arrow batches, and + * 2. persist the arrow batches to the block manager using storage level "MEMORY_AND_DISK"; + * each arrow batch is persisted as an `Array[Byte]` block.
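+ *
+ * The returned `ChunkMeta.id` is the block name ("arrow_batch_" + uuid), so a reader can
+ * recover the block id later (the UUID below is illustrative):
+ * {{{
+ *   val blockId = BlockId("arrow_batch_7d6f0e2a-1c9f-4f65-9a2b-3c4d5e6f7a8b")
+ *   assert(blockId.isInstanceOf[ArrowBatchBlockId])
+ * }}}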
+ */ +class PersistDataFrameAsArrowBatchChunksPartitionEvaluator( + schema: StructType, + timeZoneId: String, + errorOnDuplicatedFieldNames: Boolean, + maxRecordsPerBatch: Long +) extends PartitionEvaluator[InternalRow, ChunkMeta] with Logging { + + def eval(partitionIndex: Int, inputs: Iterator[InternalRow]*): Iterator[ChunkMeta] = { + val blockManager = SparkEnv.get.blockManager + val chunkMetaList = new ArrayBuffer[ChunkMeta]() + + val context = TaskContext.get() + val arrowBatchIter = ArrowConverters.toBatchIterator( + inputs(0), schema, maxRecordsPerBatch, timeZoneId, + errorOnDuplicatedFieldNames, context + ) + + try { + while (arrowBatchIter.hasNext) { + val arrowBatch = arrowBatchIter.next() + val rowCount = arrowBatchIter.getRowCountInLastBatch + + val uuid = java.util.UUID.randomUUID() + val blockId = ArrowBatchBlockId(uuid) + + val out = new ByteArrayOutputStream(32 * 1024 * 1024) + + val batchWriter = + new ArrowBatchStreamWriter(schema, out, timeZoneId, errorOnDuplicatedFieldNames) + + batchWriter.writeBatches(Iterator.single(arrowBatch)) + batchWriter.end() + + val blockData = out.toByteArray + + blockManager.putSingle[Array[Byte]]( + blockId, blockData, StorageLevel.MEMORY_AND_DISK, tellMaster = true + ) + chunkMetaList.append( + ChunkMeta(blockId.toString, rowCount, blockData.length) + ) + } + } catch { + case e: Exception => + // Clean cached chunks + for (chunkMeta <- chunkMetaList) { + try { + blockManager.master.removeBlock(BlockId(chunkMeta.id)) + } catch { + case _: Exception => + logWarning(s"Remove arrow batch block of ID '${chunkMeta.id}' failed.") + } + } + throw e + } + + chunkMetaList.iterator + } +} + +/** + * A partition evaluator factory class to create + * instance of `PersistDataFrameAsArrowBatchChunksPartitionEvaluator`. 
+ */ +class PersistDataFrameAsArrowBatchChunksPartitionEvaluatorFactory( + schema: StructType, + timeZoneId: String, + errorOnDuplicatedFieldNames: Boolean, + maxRecordsPerBatch: Long +) extends PartitionEvaluatorFactory[InternalRow, ChunkMeta] { + + def createEvaluator(): PartitionEvaluator[InternalRow, ChunkMeta] = { + new PersistDataFrameAsArrowBatchChunksPartitionEvaluator( + schema, timeZoneId, errorOnDuplicatedFieldNames, maxRecordsPerBatch + ) + } +} + +object ChunkReadUtils { + + def persistDataFrameAsArrowBatchChunks( + dataFrame: DataFrame, maxRecordsPerBatch: Int + ): Array[ChunkMeta] = { + val sparkSession = SparkSession.getActiveSession.get + + val maxRecordsPerBatchVal = if (maxRecordsPerBatch == -1) { + sparkSession.sessionState.conf.arrowMaxRecordsPerBatch + } else { + maxRecordsPerBatch + } + val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone + val errorOnDuplicatedFieldNames = + sparkSession.sessionState.conf.pandasStructHandlingMode == "legacy" + + dataFrame.queryExecution.toRdd.mapPartitionsWithEvaluator( + new PersistDataFrameAsArrowBatchChunksPartitionEvaluatorFactory( + schema = dataFrame.schema, + timeZoneId = timeZoneId, + errorOnDuplicatedFieldNames = errorOnDuplicatedFieldNames, + maxRecordsPerBatch = maxRecordsPerBatchVal + ) + ).collect() + } + + def unpersistChunks(chunkIds: java.util.List[String]): Unit = { + val blockManagerMaster = SparkEnv.get.blockManager.master + + for (chunkId <- chunkIds.asScala) { + try { + blockManagerMaster.removeBlock(BlockId(chunkId)) + } catch { + case _: Exception => () + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala index 9e6a99ef9fb28..b5d2ab0b646e7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala @@ -92,6 +92,9 @@ private[sql] object ArrowConverters extends Logging { private val root = VectorSchemaRoot.create(arrowSchema, allocator) protected val unloader = new VectorUnloader(root) protected val arrowWriter = ArrowWriter.create(root) + private var rowCountInLastBatch: Long = -1L + + def getRowCountInLastBatch: Long = rowCountInLastBatch Option(context).foreach {_.addTaskCompletionListener[Unit] { _ => close() @@ -117,6 +120,7 @@ private[sql] object ArrowConverters extends Logging { val batch = unloader.getRecordBatch() MessageSerializer.serialize(writeChannel, batch) batch.close() + rowCountInLastBatch = rowCount } { arrowWriter.reset() }
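A minimal end-to-end usage sketch of the API added by this patch (illustrative only; it assumes the 'local-cluster' master is available and that 'spark.python.dataFrameChunkRead.enabled' is set to 'true', as in the test above):

    from pyspark.sql import SparkSession
    from pyspark.sql.chunk import persistDataFrameAsChunks, readChunk, unpersistChunks

    spark = (
        SparkSession.builder.master("local-cluster[2, 1, 1024]")
        .config("spark.python.dataFrameChunkRead.enabled", "true")
        .getOrCreate()
    )

    df = spark.range(0, 100, 1, 4)
    chunks = persistDataFrameAsChunks(df, max_records_per_chunk=10)
    try:
        # Each chunk is read back as a pyarrow.Table; chunk ids are block names
        # of the form "arrow_batch_<uuid>".
        total_rows = sum(readChunk(chunk.id).num_rows for chunk in chunks)
        assert total_rows == 100
    finally:
        unpersistChunks([chunk.id for chunk in chunks])
        spark.stop()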