Commit 00b2bf9: chunk-read-impl

Signed-off-by: Weichen Xu <[email protected]>
WeichenXu123 committed Feb 7, 2024 (1 parent: b74b159)

Showing 14 changed files with 691 additions and 0 deletions.
20 changes: 20 additions & 0 deletions core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -42,6 +42,7 @@ import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat, Job => NewHad
import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFormat}

import org.apache.spark.annotation.{DeveloperApi, Experimental}
import org.apache.spark.api.python.CachedArrowBatchServer
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.deploy.{LocalSparkCluster, SparkHadoopUtil}
import org.apache.spark.errors.SparkCoreErrors
@@ -379,6 +380,14 @@ class SparkContext(config: SparkConf) extends Logging {
override protected def initialValue(): Properties = new Properties()
}

private[spark] def cachedArrowBatchServerPort: Option[Int] = {
_cachedArrowBatchServer.map(_.serverSocket.getLocalPort)
}

private[spark] def cachedArrowBatchServerSecret: Option[String] = {
_cachedArrowBatchServer.map(_.authHelper.secret)
}

/* ------------------------------------------------------------------------------------- *
| Initialization. This code initializes the context in a manner that is exception-safe. |
| All internal fields holding state are initialized here, and any error prompts the |
@@ -401,6 +410,8 @@ class SparkContext(config: SparkConf) extends Logging {
}
}

private var _cachedArrowBatchServer: Option[CachedArrowBatchServer] = None

try {
_conf = config.clone()
_conf.get(SPARK_LOG_LEVEL).foreach { level =>
@@ -486,6 +497,12 @@ class SparkContext(config: SparkConf) extends Logging {
_env = createSparkEnv(_conf, isLocal, listenerBus)
SparkEnv.set(_env)

if (SparkEnv.get.conf.get(Python.PYTHON_DATAFRAME_CHUNK_READ_ENABLED)) {
val server = new CachedArrowBatchServer(SparkEnv.get.conf, SparkEnv.get.blockManager)
server.start()
_cachedArrowBatchServer = Some(server)
}

// If running the REPL, register the repl's output dir with the file server.
_conf.getOption("spark.repl.class.outputDir").foreach { path =>
val replUri = _env.rpcEnv.fileServer.addDirectory("/classes", new File(path))
@@ -2333,6 +2350,9 @@ class SparkContext(config: SparkConf) extends Logging {
Utils.tryLogNonFatalError {
_progressBar.foreach(_.stop())
}
Utils.tryLogNonFatalError {
_cachedArrowBatchServer.foreach(_.stop())
}
_taskScheduler = null
// TODO: Cache.stop()?
if (_env != null) {
4 changes: 4 additions & 0 deletions core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -99,6 +99,10 @@ class SparkEnv (

private[spark] var executorBackend: Option[ExecutorBackend] = None

private[spark] var cachedArrowBatchServerPort: Option[Int] = None

private[spark] var cachedArrowBatchServerSecret: Option[String] = None

private[spark] def stop(): Unit = {

if (!isStopped) {
147 changes: 147 additions & 0 deletions core/src/main/scala/org/apache/spark/api/python/CachedArrowBatchServer.scala
@@ -0,0 +1,147 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.api.python

import java.io.{BufferedOutputStream, DataInputStream, DataOutputStream}
import java.net.{InetAddress, ServerSocket, Socket, SocketException}
import java.nio.charset.StandardCharsets.UTF_8

import org.apache.spark.SparkConf
import org.apache.spark.internal.Logging
import org.apache.spark.network.util.JavaUtils
import org.apache.spark.security.SocketAuthHelper
import org.apache.spark.storage.{ArrowBatchBlockId, BlockId, BlockManager}


/**
* A helper class that runs the cached Arrow batch server.
* The cached Arrow batch server serves chunk data: when a user calls the
* `pyspark.sql.chunk.read_chunk` API, the client connects to this server and
* sends a chunk id, and the server responds with the chunk's binary data.
* The server looks up chunk data by chunk id in the block manager; for the
* chunk data storage logic, see the
* `PersistDataFrameAsArrowBatchChunksPartitionEvaluator` class.
*/
class CachedArrowBatchServer(
val sparkConf: SparkConf,
val blockManager: BlockManager
) extends Logging {

val authHelper = new SocketAuthHelper(sparkConf)

val serverSocket = new ServerSocket(
0, 1, InetAddress.getLoopbackAddress()
)

protected def readUtf8(s: Socket): String = {
val din = new DataInputStream(s.getInputStream())
val len = din.readInt()
val bytes = new Array[Byte](len)
din.readFully(bytes)
new String(bytes, UTF_8)
}

protected def writeUtf8(str: String, s: Socket): Unit = {
val bytes = str.getBytes(UTF_8)
val dout = new DataOutputStream(s.getOutputStream())
dout.writeInt(bytes.length)
dout.write(bytes, 0, bytes.length)
dout.flush()
}

private def handleConnection(sock: Socket): Unit = {
val blockId = BlockId(readUtf8(sock))
assert(blockId.isInstanceOf[ArrowBatchBlockId])

var errMessage = "ok"
var blockDataOpt: Option[Array[Byte]] = None

try {
val blockResult = blockManager.get[Array[Byte]](blockId)
if (blockResult.isDefined) {
blockDataOpt = Some(blockResult.get.data.next().asInstanceOf[Array[Byte]])
} else {
errMessage = s"The chunk $blockId data cache does not exist or has been removed"
}
} catch {
case e: Exception =>
errMessage = e.getMessage
}

writeUtf8(errMessage, sock)

if (blockDataOpt.isDefined) {
val out = new BufferedOutputStream(sock.getOutputStream())
out.write(blockDataOpt.get)
out.flush()
}
}

def createConnectionThread(sock: Socket, threadName: String): Unit = {
new Thread(threadName) {
setDaemon(true)

override def run(): Unit = {
try {
authHelper.authClient(sock)
handleConnection(sock)
} finally {
JavaUtils.closeQuietly(sock)
}
}
}.start()
}

def start(): Unit = {
logTrace("Creating listening socket")

new Thread("CachedArrowBatchServer-listener") {
setDaemon(true)

override def run(): Unit = {
var sock: Socket = null

var connectionCount = 0
try {
while (true) {
sock = serverSocket.accept()
connectionCount += 1
createConnectionThread(
sock, s"CachedArrowBatchServer-connection-$connectionCount"
)
}
} catch {
case e: SocketException =>
// if serverSocket is closed, it means the server is shut down.
// swallow the exception.
if (!serverSocket.isClosed) {
throw e
}
} finally {
logTrace("Closing server")
JavaUtils.closeQuietly(serverSocket)
}
}
}.start()
}

def stop(): Unit = {
JavaUtils.closeQuietly(serverSocket)
}
}
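
The wire protocol above is small: after the SocketAuthHelper handshake, the client sends the block id as a 4-byte big-endian length followed by UTF-8 bytes, the server answers with a length-prefixed status string ("ok" or an error message), and on success it streams the raw Arrow batch bytes and closes the connection. Below is a minimal Python sketch of such a client. It assumes the standard SocketAuthHelper exchange (client sends the length-prefixed secret, server replies "ok") and the PYSPARK_EXECUTOR_CACHED_ARROW_BATCH_SERVER_* environment variables exported elsewhere in this commit; the helper names are hypothetical, and the real pyspark.sql.chunk.read_chunk client is not part of this diff and may differ.

import os
import socket
import struct

def _write_utf8(sock_file, text):
    # Mirrors the server's readUtf8: 4-byte big-endian length, then UTF-8 bytes.
    data = text.encode("utf-8")
    sock_file.write(struct.pack(">i", len(data)))
    sock_file.write(data)
    sock_file.flush()

def _read_utf8(sock_file):
    # Mirrors the server's writeUtf8: 4-byte big-endian length, then UTF-8 bytes.
    (length,) = struct.unpack(">i", sock_file.read(4))
    return sock_file.read(length).decode("utf-8")

def fetch_arrow_batch(block_id_name):
    """Fetch the raw Arrow batch bytes for a block id such as 'arrow_batch_<uuid>'."""
    port = int(os.environ["PYSPARK_EXECUTOR_CACHED_ARROW_BATCH_SERVER_PORT"])
    secret = os.environ["PYSPARK_EXECUTOR_CACHED_ARROW_BATCH_SERVER_SECRET"]
    with socket.create_connection(("127.0.0.1", port)) as sock:
        sock_file = sock.makefile("rwb")
        # Assumed SocketAuthHelper handshake: send the secret, expect "ok" back.
        _write_utf8(sock_file, secret)
        if _read_utf8(sock_file) != "ok":
            raise RuntimeError("authentication with the cached arrow batch server failed")
        # Request the block; the server replies with a status string first.
        _write_utf8(sock_file, block_id_name)
        status = _read_utf8(sock_file)
        if status != "ok":
            raise RuntimeError(status)
        # The server writes the payload and then closes the socket, so read to EOF.
        return sock_file.read()
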
core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
@@ -107,6 +107,22 @@ private[spark] class PythonWorkerFactory(
}
}

private def appendCachedArrowBatchServerEnvVars(
workerEnv: java.util.Map[String, String]
): Unit = {
val env = SparkEnv.get
if (env.conf.get(PYTHON_DATAFRAME_CHUNK_READ_ENABLED)) {
workerEnv.put(
"PYSPARK_EXECUTOR_CACHED_ARROW_BATCH_SERVER_PORT",
env.cachedArrowBatchServerPort.get.toString
)
workerEnv.put(
"PYSPARK_EXECUTOR_CACHED_ARROW_BATCH_SERVER_SECRET",
env.cachedArrowBatchServerSecret.get
)
}
}

/**
* Connect to a worker launched through pyspark/daemon.py (by default), which forks python
* processes itself to avoid the high cost of forking from Java. This currently only works
@@ -170,6 +186,7 @@
}
val workerEnv = pb.environment()
workerEnv.putAll(envVars.asJava)
appendCachedArrowBatchServerEnvVars(workerEnv)
workerEnv.put("PYTHONPATH", pythonPath)
// This is equivalent to setting the -u flag; we use it because ipython doesn't support -u:
workerEnv.put("PYTHONUNBUFFERED", "YES")
@@ -247,6 +264,7 @@
}
val workerEnv = pb.environment()
workerEnv.putAll(envVars.asJava)
appendCachedArrowBatchServerEnvVars(workerEnv)
workerEnv.put("PYTHONPATH", pythonPath)
workerEnv.put("PYTHON_WORKER_FACTORY_SECRET", authHelper.secret)
if (Utils.preferIPv6) {
22 changes: 22 additions & 0 deletions core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -40,6 +40,7 @@ import com.google.common.util.concurrent.ThreadFactoryBuilder
import org.slf4j.MDC

import org.apache.spark._
import org.apache.spark.api.python.CachedArrowBatchServer
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config._
@@ -347,6 +348,21 @@ private[spark] class Executor(

metricsPoller.start()

val cachedArrowBatchServer: Option[CachedArrowBatchServer] = if (
SparkEnv.get.conf.get(Python.PYTHON_DATAFRAME_CHUNK_READ_ENABLED)
) {
val server = new CachedArrowBatchServer(env.conf, env.blockManager)

server.start()

env.cachedArrowBatchServerPort = Some(server.serverSocket.getLocalPort)
env.cachedArrowBatchServerSecret = Some(server.authHelper.secret)

Some(server)
} else {
None
}

private[executor] def numRunningTasks: Int = runningTasks.size()

/**
Expand Down Expand Up @@ -418,6 +434,12 @@ private[spark] class Executor(
if (!executorShutdown.getAndSet(true)) {
ShutdownHookManager.removeShutdownHook(stopHookReference)
env.metricsSystem.report()
try {
cachedArrowBatchServer.foreach(_.stop())
} catch {
case NonFatal(e) =>
logWarning("Unable to stop arrow batch server", e)
}
try {
if (metricsPoller != null) {
metricsPoller.stop()
core/src/main/scala/org/apache/spark/internal/config/Python.scala
@@ -69,4 +69,12 @@ private[spark] object Python {
.version("3.2.0")
.booleanConf
.createWithDefault(false)

val PYTHON_DATAFRAME_CHUNK_READ_ENABLED =
ConfigBuilder("spark.python.dataFrameChunkRead.enabled")
.doc("When true, driver and executors launch local cached arrow batch servers for serving " +
"PySpark DataFrame 'pyspark.sql.chunk.read_chunk' API requests.")
.version("4.0.0")
.booleanConf
.createWithDefault(false)
}
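
The new flag defaults to false and is read while the SparkContext is being constructed (see the SparkContext.scala change above), so it must be set before the context starts. A minimal sketch of enabling it from PySpark; the application name is just a placeholder:

from pyspark.sql import SparkSession

spark = (
    SparkSession.builder
    .appName("chunk-read-demo")  # placeholder name
    .config("spark.python.dataFrameChunkRead.enabled", "true")
    .getOrCreate()
)
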
9 changes: 9 additions & 0 deletions core/src/main/scala/org/apache/spark/storage/BlockId.scala
@@ -185,6 +185,11 @@ private[spark] case class TestBlockId(id: String) extends BlockId {
override def name: String = "test_" + id
}

private[spark] case class ArrowBatchBlockId(id: UUID) extends BlockId {
override def name: String = "arrow_batch_" + id
}


@DeveloperApi
class UnrecognizedBlockId(name: String)
extends SparkException(s"Failed to parse $name into a block ID")
@@ -215,6 +220,7 @@ object BlockId {
val STREAM = "input-([0-9]+)-([0-9]+)".r
val TEMP_LOCAL = "temp_local_([-A-Fa-f0-9]+)".r
val TEMP_SHUFFLE = "temp_shuffle_([-A-Fa-f0-9]+)".r
val ARROW_BATCH = "arrow_batch_([-A-Fa-f0-9]+)".r
val TEST = "test_(.*)".r

def apply(name: String): BlockId = name match {
@@ -254,8 +260,11 @@ object BlockId {
TempLocalBlockId(UUID.fromString(uuid))
case TEMP_SHUFFLE(uuid) =>
TempShuffleBlockId(UUID.fromString(uuid))
case ARROW_BATCH(uuid) =>
ArrowBatchBlockId(UUID.fromString(uuid))
case TEST(value) =>
TestBlockId(value)

case _ => throw SparkCoreErrors.unrecognizedBlockIdError(name)
}
}
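
For reference, ArrowBatchBlockId.name and the ARROW_BATCH pattern above mean that chunk block ids travel as strings of the form arrow_batch_<uuid>, which is also what a client sends to the cached Arrow batch server. A tiny Python illustration of the format:

import uuid

# The string form that BlockId.apply parses back into an ArrowBatchBlockId.
block_id_name = "arrow_batch_" + str(uuid.uuid4())
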
1 change: 1 addition & 0 deletions dev/sparktestsupport/modules.py
@@ -516,6 +516,7 @@ def __hash__(self):
"pyspark.sql.tests.pandas.test_pandas_udf_typehints_with_future_annotations",
"pyspark.sql.tests.pandas.test_pandas_udf_window",
"pyspark.sql.tests.pandas.test_converter",
"pyspark.sql.tests.test_chunk_read_api",
"pyspark.sql.tests.test_pandas_sqlmetrics",
"pyspark.sql.tests.test_python_datasource",
"pyspark.sql.tests.test_readwriter",
9 changes: 9 additions & 0 deletions python/pyspark/context.py
@@ -294,9 +294,18 @@ def _do_init(

# Create the Java SparkContext through Py4J
self._jsc = jsc or self._initialize_context(self._conf._jconf)

# Reset the SparkConf to the one actually used by the SparkContext in JVM.
self._conf = SparkConf(_jconf=self._jsc.sc().conf())

if self.getConf().get("spark.python.dataFrameChunkRead.enabled", "false").lower() == "true":
os.environ["PYSPARK_EXECUTOR_CACHED_ARROW_BATCH_SERVER_PORT"] = str(
self._jsc.sc().cachedArrowBatchServerPort().get()
)
os.environ["PYSPARK_EXECUTOR_CACHED_ARROW_BATCH_SERVER_SECRET"] = (
self._jsc.sc().cachedArrowBatchServerSecret().get()
)

# Create a single Accumulator in Java that we'll send all our updates through;
# they will be passed back to us through a TCP server
assert self._gateway is not None
5 changes: 5 additions & 0 deletions python/pyspark/serializers.py
@@ -613,6 +613,11 @@ def write_with_length(obj, stream):
stream.write(obj)


def read_with_length(stream):
length = read_int(stream)
return stream.read(length)


class ChunkedStream:

"""
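
read_with_length is the inverse of the existing write_with_length shown above: a 4-byte length prefix followed by that many bytes. A small usage sketch against an in-memory buffer, assuming this commit's helper is importable from pyspark.serializers:

import io

from pyspark.serializers import read_with_length, write_with_length

buf = io.BytesIO()
write_with_length(b"chunk payload", buf)  # writes the 4-byte length, then the bytes
buf.seek(0)
assert read_with_length(buf) == b"chunk payload"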