Skip to content

Commit

Permalink
Propagate threadlocals
Browse files Browse the repository at this point in the history
  • Loading branch information
fred-db committed Oct 9, 2023
1 parent 4f9c8b9 commit 60faa14
Show file tree
Hide file tree
Showing 3 changed files with 272 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import scala.concurrent.duration.Duration

import org.apache.spark.sql.delta.DeltaErrors
import org.apache.spark.sql.delta.metering.DeltaLogging
import org.apache.spark.sql.delta.util.threads.SparkThreadLocalForwardingThreadPoolExecutor

import org.apache.spark.sql.SparkSession
import org.apache.spark.util.ThreadUtils
Expand Down Expand Up @@ -56,7 +57,8 @@ private[delta] class DeltaThreadPool(tpe: ThreadPoolExecutor) {
/** Convenience constructor that creates a [[ThreadPoolExecutor]] with sensible defaults. */
private[delta] object DeltaThreadPool {
def apply(prefix: String, numThreads: Int): DeltaThreadPool =
new DeltaThreadPool(ThreadUtils.newDaemonCachedThreadPool(prefix, numThreads))
new DeltaThreadPool(
SparkThreadLocalForwardingThreadPoolExecutor.newDaemonCachedThreadPool(prefix, numThreads))
}

/**
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
/*
* Copyright (2021) The Delta Lake Project Authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.delta.util.threads

import java.util.Properties
import java.util.concurrent._

import scala.collection.JavaConverters._

import org.apache.spark.{SparkContext, TaskContext}
import org.apache.spark.internal.Logging
import org.apache.spark.util.ThreadUtils.namedThreadFactory

/**
* Implementation of ThreadPoolExecutor that captures the Spark ThreadLocals present at submit time
* and inserts them into the thread before executing the provided runnable.
*/
/**
 * A [[ThreadPoolExecutor]] that snapshots the Spark thread locals (the current [[TaskContext]]
 * and the active SparkContext's local properties) at submit time, and installs that snapshot on
 * the worker thread for the duration of each task's execution.
 */
class SparkThreadLocalForwardingThreadPoolExecutor(
    corePoolSize: Int,
    maximumPoolSize: Int,
    keepAliveTime: Long,
    unit: TimeUnit,
    workQueue: BlockingQueue[Runnable],
    threadFactory: ThreadFactory,
    rejectedExecutionHandler: RejectedExecutionHandler = new ThreadPoolExecutor.AbortPolicy)
  extends ThreadPoolExecutor(
    corePoolSize, maximumPoolSize, keepAliveTime,
    unit, workQueue, threadFactory, rejectedExecutionHandler) {

  // Wrap every incoming task so that it carries the submitter's Spark thread locals with it.
  // The capture happens here, on the submitting thread, when the wrapper is constructed.
  override def execute(command: Runnable): Unit = {
    super.execute(new SparkThreadLocalCapturingRunnable(command))
  }
}

object SparkThreadLocalForwardingThreadPoolExecutor {
  /**
   * Creates a cached, daemon thread pool backed by a
   * [[SparkThreadLocalForwardingThreadPoolExecutor]]. Thread names are formatted as prefix-ID,
   * where ID is a unique, sequentially assigned integer.
   *
   * @param prefix prefix used for naming the pool's threads
   * @param maxThreadNumber upper bound on the number of threads in the pool
   * @param keepAliveSeconds how long an idle thread is kept alive before it terminates
   * @param queueSize capacity of the task queue
   */
  def newDaemonCachedThreadPool(
      prefix: String,
      maxThreadNumber: Int,
      keepAliveSeconds: Int = 60,
      queueSize: Int = Integer.MAX_VALUE): ThreadPoolExecutor = {
    val pool = new SparkThreadLocalForwardingThreadPoolExecutor(
      // core size == max size: once maxThreadNumber threads exist, further tasks are queued in
      // the LinkedBlockingQueue instead of triggering creation of additional threads, so the
      // maximumPoolSize argument below never comes into play.
      maxThreadNumber,
      maxThreadNumber,
      keepAliveSeconds,
      TimeUnit.SECONDS,
      new LinkedBlockingQueue[Runnable](queueSize),
      namedThreadFactory(prefix))
    // Allow idle core threads to time out so the pool can shrink to zero when unused.
    pool.allowCoreThreadTimeOut(true)
    pool
  }
}

/**
 * Captures the Spark thread locals (the current [[TaskContext]] and the active SparkContext's
 * local properties) present on the thread that CONSTRUCTS an instance of this trait, and lets
 * callers run a body with that captured state temporarily installed via [[runWithCaptured]].
 */
trait SparkThreadLocalCapturingHelper extends Logging {
  // At the time of creating this instance we capture the task context and command context.
  val capturedTaskContext = TaskContext.get()
  // The active SparkContext, if any, at construction time (None when no context is running).
  val sparkContext = SparkContext.getActive
  // Capture an immutable threadsafe snapshot of the current local properties
  // (cloned first, so later mutations by the constructing thread are not observed).
  val capturedProperties = sparkContext
    .map(sc => CapturedSparkThreadLocals.toValuesArray(
      org.apache.spark.util.Utils.cloneProperties(sc.getLocalProperties)))

  /**
   * Runs `body` with the captured TaskContext and local properties installed on the current
   * thread, restoring the thread's previous state afterwards (even on failure).
   */
  def runWithCaptured[T](body: => T): T = {
    // Save the previous contexts, overwrite them with the captured contexts, and then restore the
    // previous when execution completes.
    // This has the unfortunate side effect of writing nulls to these thread locals if they were
    // empty beforehand.
    val previousTaskContext = TaskContext.get()
    val previousProperties = sparkContext.map(_.getLocalProperties)

    // Install the captured state. Properties are rebuilt from the immutable snapshot so each
    // invocation gets a fresh Properties object (mutations inside `body` cannot leak back).
    TaskContext.setTaskContext(capturedTaskContext)
    capturedProperties.foreach { p =>
      sparkContext.foreach(_.setLocalProperties(CapturedSparkThreadLocals.toProperties(p)))
    }

    try {
      body
    } catch {
      // Log with the worker thread's name for diagnosability, then rethrow unchanged: this
      // intentionally catches Throwable only to log, never to swallow.
      case t: Throwable =>
        logError(s"Exception in thread ${Thread.currentThread().getName}", t)
        throw t
    } finally {
      // Always restore the thread's previous state, whether `body` succeeded or threw.
      TaskContext.setTaskContext(previousTaskContext)
      previousProperties.foreach(p => sparkContext.foreach(_.setLocalProperties(p)))
    }
  }
}

/** A snapshot of the Spark thread locals present on the thread that constructed this instance. */
class CapturedSparkThreadLocals extends SparkThreadLocalCapturingHelper

object CapturedSparkThreadLocals {
  /** Captures the current thread's Spark thread locals at the moment of this call. */
  def apply(): CapturedSparkThreadLocals = new CapturedSparkThreadLocals()

  /** Rebuilds a mutable [[Properties]] object from an array of key/value pairs. */
  def toProperties(props: Array[(String, String)]): Properties = {
    val result = new Properties()
    for ((key, value) <- props) {
      result.put(key, value)
    }
    result
  }

  /** Snapshots a [[Properties]] object into an immutable array of key/value pairs. */
  def toValuesArray(props: Properties): Array[(String, String)] = {
    props.asScala.toArray
  }
}

/**
 * A [[Runnable]] wrapper that captures the submitting thread's Spark thread locals when it is
 * constructed (via [[SparkThreadLocalCapturingHelper]]'s initializers) and runs the wrapped
 * runnable with that captured state installed.
 */
class SparkThreadLocalCapturingRunnable(runnable: Runnable)
  extends Runnable with SparkThreadLocalCapturingHelper {
  override def run(): Unit = {
    runWithCaptured {
      runnable.run()
    }
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
/*
* Copyright (2021) The Delta Lake Project Authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.delta.util

import java.util.Properties

import scala.concurrent.{ExecutionContext, ExecutionContextExecutor, Future}
import scala.concurrent.duration._

import org.apache.spark._
import org.apache.spark.sql.delta.util.threads.{CapturedSparkThreadLocals, SparkThreadLocalForwardingThreadPoolExecutor}
import org.apache.spark.util.ThreadUtils

/**
 * Tests that [[SparkThreadLocalForwardingThreadPoolExecutor]] and [[CapturedSparkThreadLocals]]
 * propagate the TaskContext and Spark local properties from the submitting thread to the
 * executing thread, and restore the executing thread's previous state afterwards.
 */
class SparkThreadLocalForwardingSuite extends SparkFunSuite {

  test("SparkThreadLocalForwardingThreadPoolExecutor properly propagates" +
    " TaskContext and Spark Local Properties") {
    val sc = SparkContext.getOrCreate(new SparkConf().setAppName("test").setMaster("local"))
    // A single-threaded pool: every future runs on one worker thread, so any state observed
    // there must have been forwarded from the submitting thread rather than set locally.
    val executor =
      SparkThreadLocalForwardingThreadPoolExecutor.newDaemonCachedThreadPool("test-threads", 1)
    implicit val executionContext: ExecutionContextExecutor =
      ExecutionContext.fromExecutor(executor)

    val prevTaskContext = TaskContext.get()
    try {
      // assert that each instance of submitting a task to the execution context captures the
      // current task context
      val futures = (1 to 10) map { i =>
        // Mutate the submitting thread's state BEFORE each submit; the future body must see the
        // value captured at its own submit time (i), not a later one.
        setTaskAndProperties(i, sc)

        Future {
          checkTaskAndProperties(i, sc)
        }(executionContext)
      }

      assert(ThreadUtils.awaitResult(Future.sequence(futures), 10.seconds).forall(identity))
    } finally {
      // Restore the driver thread's original TaskContext and tear down the pool and context.
      TaskContext.setTaskContext(prevTaskContext)
      ThreadUtils.shutdown(executor)
      sc.stop()
    }
  }

  /** Builds a minimal TaskContext whose stageId is `id`, used as a recognizable marker. */
  def makeTaskContext(id: Int): TaskContext = {
    new TaskContextImpl(id, 0, 0, 0, attemptNumber = 45613, 0, null, new Properties(), null)
  }

  /** Tags the current thread with task context `i` and local property "test" = i. */
  def setTaskAndProperties(i: Int, sc: SparkContext = SparkContext.getActive.get): Unit = {
    val tc = makeTaskContext(i)
    TaskContext.setTaskContext(tc)
    sc.setLocalProperty("test", i.toString)
  }

  /** True iff the current thread's task context and local property both carry marker `i`. */
  def checkTaskAndProperties(i: Int, sc: SparkContext = SparkContext.getActive.get): Boolean = {
    TaskContext.get() != null &&
      // makeTaskContext(i) passes i as the stageId, so stageId is the marker checked here.
      TaskContext.get().stageId() == i &&
      sc.getLocalProperty("test") == i.toString
  }

  test("That CapturedSparkThreadLocals properly restores the existing state") {
    val sc = SparkContext.getOrCreate(new SparkConf().setAppName("test").setMaster("local"))
    val prevTaskContext = TaskContext.get()
    try {
      // Capture state 10, then move the thread to state 11.
      setTaskAndProperties(10)
      val capturedSparkThreadLocals = CapturedSparkThreadLocals()
      setTaskAndProperties(11)
      assert(!checkTaskAndProperties(10, sc))
      assert(checkTaskAndProperties(11, sc))
      // Inside runWithCaptured the snapshot (10) is visible; afterwards state 11 is restored.
      capturedSparkThreadLocals.runWithCaptured {
        assert(checkTaskAndProperties(10, sc))
      }
      assert(checkTaskAndProperties(11, sc))
    } finally {
      TaskContext.setTaskContext(prevTaskContext)
      sc.stop()
    }
  }

  test("That CapturedSparkThreadLocals properly restores the existing spark properties." +
    " Changes to local properties inside a task do not affect the original properties") {
    val sc = SparkContext.getOrCreate(new SparkConf().setAppName("test").setMaster("local"))
    try {
      sc.setLocalProperty("TestProp", "1")
      val capturedSparkThreadLocals = CapturedSparkThreadLocals()
      assert(sc.getLocalProperty("TestProp") == "1")
      // A mutation made while running under the captured snapshot is visible inside the block
      // but must not leak into the thread's original properties once the block exits.
      capturedSparkThreadLocals.runWithCaptured {
        sc.setLocalProperty("TestProp", "2")
        assert(sc.getLocalProperty("TestProp") == "2")
      }
      assert(sc.getLocalProperty("TestProp") == "1")
    } finally {
      sc.stop()
    }
  }

  test("captured spark thread locals are immutable") {
    val sc = SparkContext.getOrCreate(new SparkConf().setAppName("test").setMaster("local"))
    try {
      sc.setLocalProperty("test1", "good")
      sc.setLocalProperty("test2", "good")
      val threadLocals = CapturedSparkThreadLocals()
      // Mutating the live properties after capture must not alter the snapshot...
      sc.setLocalProperty("test2", "bad")
      assert(sc.getLocalProperty("test1") == "good")
      assert(sc.getLocalProperty("test2") == "bad")
      threadLocals.runWithCaptured {
        assert(sc.getLocalProperty("test1") == "good")
        assert(sc.getLocalProperty("test2") == "good")
        // ...and mutations made while running under the snapshot must not alter it either.
        sc.setLocalProperty("test1", "bad")
        sc.setLocalProperty("test2", "maybe")
        assert(sc.getLocalProperty("test1") == "bad")
        assert(sc.getLocalProperty("test2") == "maybe")
      }
      assert(sc.getLocalProperty("test1") == "good")
      assert(sc.getLocalProperty("test2") == "bad")
      // A second run against the same snapshot sees the pristine captured values again.
      threadLocals.runWithCaptured {
        assert(sc.getLocalProperty("test1") == "good")
        assert(sc.getLocalProperty("test2") == "good")
      }
    } finally {
      sc.stop()
    }
  }
}

0 comments on commit 60faa14

Please sign in to comment.