[Spark] Propagate thread locals to Delta thread pools
* The default thread pool executor in Apache Spark does not forward thread locals (e.g. the `TaskContext` and the `SparkContext`'s local properties) to the threads it spawns.
* Tasks that run in such a pool and depend on those thread locals therefore see stale or missing values.
* To fix this, we introduce a wrapper around the thread pool executor that captures the submitter's thread locals and installs them in the worker thread before each task runs (see the sketch below).
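For intuition, here is a minimal, self-contained sketch of the failure mode and of the wrapper approach, using a plain `ThreadLocal` instead of Spark's (`ThreadLocalLossDemo`, `requestId`, and `ForwardingRunnable` are illustrative names, not part of this change):

```scala
import java.util.concurrent.Executors

object ThreadLocalLossDemo extends App {
  val requestId = new ThreadLocal[String]
  val pool = Executors.newFixedThreadPool(1)

  requestId.set("req-42") // visible only on the submitting thread

  // Prints "null": the pool thread never inherited the submitter's ThreadLocal.
  pool.submit(new Runnable {
    def run(): Unit = println(s"plain runnable sees: ${requestId.get}")
  }).get()

  // The approach taken by this commit: capture at submit time (construction
  // happens on the submitting thread), install around run(), then restore.
  class ForwardingRunnable(body: Runnable) extends Runnable {
    private val captured = requestId.get
    override def run(): Unit = {
      val previous = requestId.get
      requestId.set(captured)
      try body.run() finally requestId.set(previous)
    }
  }

  // Prints "req-42": the captured value was installed before run().
  pool.submit(new ForwardingRunnable(
    () => println(s"forwarding runnable sees: ${requestId.get}"))).get()
  pool.shutdown()
}
```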

Closes delta-io#2154

GitOrigin-RevId: 9e9423e4b041232457ffaab18f5f96490bb45b88
fred-db authored and xupefei committed Oct 31, 2023
1 parent 79057fd commit cbd2a50
Showing 6 changed files with 366 additions and 5 deletions.
@@ -26,7 +26,7 @@
 import org.apache.spark.sql.delta.metering.DeltaLogging
 import org.apache.spark.sql.delta.sources.DeltaSQLConf
 import org.apache.spark.sql.delta.storage.LogStore
 import org.apache.spark.sql.delta.util.FileNames._
-import org.apache.spark.sql.delta.util.NonFateSharingFuture
+import org.apache.spark.sql.delta.util.threads.NonFateSharingFuture
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.{FileStatus, Path}

@@ -28,9 +28,9 @@
 import scala.util.control.NonFatal
 import com.databricks.spark.util.TagDefinitions.TAG_ASYNC
 import org.apache.spark.sql.delta.actions.Metadata
 import org.apache.spark.sql.delta.sources.DeltaSQLConf
-import org.apache.spark.sql.delta.util.DeltaThreadPool
 import org.apache.spark.sql.delta.util.FileNames._
 import org.apache.spark.sql.delta.util.JsonUtils
+import org.apache.spark.sql.delta.util.threads.DeltaThreadPool
 import com.fasterxml.jackson.annotation.JsonIgnore
 import org.apache.hadoop.fs.{BlockLocation, FileStatus, LocatedFileStatus, Path}

@@ -14,9 +14,9 @@
  * limitations under the License.
  */

-package org.apache.spark.sql.delta.util
+package org.apache.spark.sql.delta.util.threads

-import java.util.concurrent.ThreadPoolExecutor
+import java.util.concurrent.{LinkedBlockingQueue, ThreadPoolExecutor, TimeUnit}

 import scala.concurrent.{ExecutionContext, Future}
 import scala.concurrent.duration.Duration
@@ -26,6 +26,7 @@
 import org.apache.spark.sql.delta.metering.DeltaLogging

 import org.apache.spark.sql.SparkSession
 import org.apache.spark.util.ThreadUtils
+import org.apache.spark.util.ThreadUtils.namedThreadFactory

 /** A wrapper for [[ThreadPoolExecutor]] whose tasks run with the caller's [[SparkSession]]. */
 private[delta] class DeltaThreadPool(tpe: ThreadPoolExecutor) {
@@ -56,7 +57,28 @@ private[delta] class DeltaThreadPool(tpe: ThreadPoolExecutor) {
 /** Convenience constructor that creates a [[ThreadPoolExecutor]] with sensible defaults. */
 private[delta] object DeltaThreadPool {
   def apply(prefix: String, numThreads: Int): DeltaThreadPool =
-    new DeltaThreadPool(ThreadUtils.newDaemonCachedThreadPool(prefix, numThreads))
+    new DeltaThreadPool(newDaemonCachedThreadPool(prefix, numThreads))
+
+  /**
+   * Creates a cached thread pool whose maximum number of threads is `maxThreadNumber`. Thread
+   * names are formatted as prefix-ID, where ID is a unique, sequentially assigned integer.
+   */
+  def newDaemonCachedThreadPool(
+      prefix: String,
+      maxThreadNumber: Int): ThreadPoolExecutor = {
+    val keepAliveSeconds = 60
+    val queueSize = Integer.MAX_VALUE
+    val threadFactory = namedThreadFactory(prefix)
+    val threadPool = new SparkThreadLocalForwardingThreadPoolExecutor(
+      maxThreadNumber, // corePoolSize: the max number of threads to create before queuing tasks
+      maxThreadNumber, // maximumPoolSize: not used, because the LinkedBlockingQueue is unbounded
+      keepAliveSeconds,
+      TimeUnit.SECONDS,
+      new LinkedBlockingQueue[Runnable](queueSize),
+      threadFactory)
+    threadPool.allowCoreThreadTimeOut(true)
+    threadPool
+  }
 }

 /**
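As a usage sketch (the call shape matches the test suite below; `spark` is assumed to be an active `SparkSession`, and the prefix is illustrative): every task submitted through a pool built this way is wrapped in a `SparkThreadLocalCapturingRunnable`, so it observes the submitter's `TaskContext` and local properties:

```scala
// Hypothetical call site; pool threads are named "delta-prefetch-0", "delta-prefetch-1", ...
val pool = DeltaThreadPool("delta-prefetch", numThreads = 8)

// The body runs on a pool thread but still sees the submitter's
// local properties, because execute() wraps every Runnable.
val future = pool.submit(spark) {
  spark.sparkContext.getLocalProperty("spark.job.description")
}
```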
@@ -0,0 +1,118 @@
/*
 * Copyright (2021) The Delta Lake Project Authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.delta.util.threads

import java.util.Properties
import java.util.concurrent._

import scala.collection.JavaConverters._

import org.apache.spark.{SparkContext, TaskContext}
import org.apache.spark.internal.Logging
import org.apache.spark.util.{Utils => SparkUtils}

/**
 * Implementation of ThreadPoolExecutor that captures the Spark ThreadLocals present at submit
 * time and inserts them into the thread before executing the provided runnable.
 */
class SparkThreadLocalForwardingThreadPoolExecutor(
    corePoolSize: Int,
    maximumPoolSize: Int,
    keepAliveTime: Long,
    unit: TimeUnit,
    workQueue: BlockingQueue[Runnable],
    threadFactory: ThreadFactory,
    rejectedExecutionHandler: RejectedExecutionHandler = new ThreadPoolExecutor.AbortPolicy)
  extends ThreadPoolExecutor(
    corePoolSize, maximumPoolSize, keepAliveTime,
    unit, workQueue, threadFactory, rejectedExecutionHandler) {

  override def execute(command: Runnable): Unit =
    super.execute(new SparkThreadLocalCapturingRunnable(command))
}

trait SparkThreadLocalCapturingHelper extends Logging {
  // At the time of creating this instance we capture the task context and command context.
  val capturedTaskContext = TaskContext.get()
  val sparkContext = SparkContext.getActive
  // Capture an immutable threadsafe snapshot of the current local properties
  val capturedProperties = sparkContext
    .map(sc => CapturedSparkThreadLocals.toValuesArray(
      SparkUtils.cloneProperties(sc.getLocalProperties)))

  def runWithCaptured[T](body: => T): T = {
    // Save the previous contexts, overwrite them with the captured contexts, and then restore
    // the previous ones when execution completes.
    // This has the unfortunate side effect of writing nulls to these thread locals if they were
    // empty beforehand.
    val previousTaskContext = TaskContext.get()
    val previousProperties = sparkContext.map(_.getLocalProperties)

    TaskContext.setTaskContext(capturedTaskContext)
    for {
      p <- capturedProperties
      sc <- sparkContext
    } {
      sc.setLocalProperties(CapturedSparkThreadLocals.toProperties(p))
    }

    try {
      body
    } catch {
      case t: Throwable =>
        logError(s"Exception in thread ${Thread.currentThread().getName}", t)
        throw t
    } finally {
      TaskContext.setTaskContext(previousTaskContext)
      for {
        p <- previousProperties
        sc <- sparkContext
      } {
        sc.setLocalProperties(p)
      }
    }
  }
}

class CapturedSparkThreadLocals extends SparkThreadLocalCapturingHelper

object CapturedSparkThreadLocals {
  def apply(): CapturedSparkThreadLocals = {
    new CapturedSparkThreadLocals()
  }

  def toProperties(props: Array[(String, String)]): Properties = {
    val resultProps = new Properties()
    for ((key, value) <- props) {
      resultProps.put(key, value)
    }
    resultProps
  }

  def toValuesArray(props: Properties): Array[(String, String)] = {
    props.asScala.toArray
  }
}

class SparkThreadLocalCapturingRunnable(runnable: Runnable)
  extends Runnable with SparkThreadLocalCapturingHelper {
  override def run(): Unit = {
    runWithCaptured(runnable.run())
  }
}
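The same capture/restore contract can be exercised directly through `CapturedSparkThreadLocals` — roughly like this sketch, assuming an active `SparkSession` (`callerTag` and its value are illustrative):

```scala
// On the submitting thread: set a local property, then snapshot the thread locals.
spark.sparkContext.setLocalProperty("callerTag", "ingest-job")
val captured = CapturedSparkThreadLocals()

new Thread(() => {
  captured.runWithCaptured {
    // Inside the block this thread sees the snapshot taken above...
    assert(spark.sparkContext.getLocalProperty("callerTag") == "ingest-job")
  }
  // ...and the finally block has restored the thread's previous state here
  // (note the caveat above: previously-empty thread locals come back as nulls).
}).start()
```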
@@ -0,0 +1,69 @@
/*
 * Copyright (2021) The Delta Lake Project Authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.delta.util.threads

import java.util.Properties

import org.apache.spark.{SparkFunSuite, TaskContext, TaskContextImpl}
import org.apache.spark.sql.test.SharedSparkSession

class DeltaThreadPoolSuite extends SparkFunSuite with SharedSparkSession {

  val threadPool: DeltaThreadPool = DeltaThreadPool("test", 1)

  def makeTaskContext(id: Int): TaskContext = {
    new TaskContextImpl(id, 0, 0, 0, attemptNumber = 45613, 0, null, new Properties(), null)
  }

  def testForwarding(testName: String, id: Int)(f: => Unit): Unit = {
    test(testName) {
      val prevTaskContext = TaskContext.get()
      TaskContext.setTaskContext(makeTaskContext(id))
      sparkContext.setLocalProperty("test", id.toString)

      try {
        f
      } finally {
        TaskContext.setTaskContext(prevTaskContext)
      }
    }
  }

  def assertTaskAndProperties(id: Int): Unit = {
    assert(TaskContext.get() !== null)
    assert(TaskContext.get().stageId() === id)
    assert(sparkContext.getLocalProperty("test") === id.toString)
  }

  testForwarding("parallelMap captures TaskContext", id = 0) {
    threadPool.parallelMap(spark, 0 until 1) { _ =>
      assertTaskAndProperties(id = 0)
    }
  }

  testForwarding("submit captures TaskContext and local properties", id = 1) {
    threadPool.submit(spark) {
      assertTaskAndProperties(id = 1)
    }
  }

  testForwarding("submitNonFateSharing captures TaskContext and local properties", id = 2) {
    threadPool.submitNonFateSharing { _ =>
      assertTaskAndProperties(id = 2)
    }
  }
}
