Address comments and move threadpool
fred-db committed Oct 13, 2023
1 parent 4f9c8b9 commit da5724d
Showing 5 changed files with 297 additions and 5 deletions.
@@ -26,7 +26,7 @@ import org.apache.spark.sql.delta.metering.DeltaLogging
 import org.apache.spark.sql.delta.sources.DeltaSQLConf
 import org.apache.spark.sql.delta.storage.LogStore
 import org.apache.spark.sql.delta.util.FileNames._
-import org.apache.spark.sql.delta.util.NonFateSharingFuture
+import org.apache.spark.sql.delta.util.threads.NonFateSharingFuture
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.{FileStatus, Path}
@@ -28,9 +28,9 @@ import scala.util.control.NonFatal
 import com.databricks.spark.util.TagDefinitions.TAG_ASYNC
 import org.apache.spark.sql.delta.actions.Metadata
 import org.apache.spark.sql.delta.sources.DeltaSQLConf
-import org.apache.spark.sql.delta.util.DeltaThreadPool
 import org.apache.spark.sql.delta.util.FileNames._
 import org.apache.spark.sql.delta.util.JsonUtils
+import org.apache.spark.sql.delta.util.threads.DeltaThreadPool
 import com.fasterxml.jackson.annotation.JsonIgnore
 import org.apache.hadoop.fs.{BlockLocation, FileStatus, LocatedFileStatus, Path}
@@ -14,9 +14,9 @@
  * limitations under the License.
  */
 
-package org.apache.spark.sql.delta.util
+package org.apache.spark.sql.delta.util.threads
 
-import java.util.concurrent.ThreadPoolExecutor
+import java.util.concurrent.{LinkedBlockingQueue, ThreadPoolExecutor, TimeUnit}
 
 import scala.concurrent.{ExecutionContext, Future}
 import scala.concurrent.duration.Duration
@@ -26,6 +26,7 @@ import org.apache.spark.sql.delta.metering.DeltaLogging
 
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.util.ThreadUtils
+import org.apache.spark.util.ThreadUtils.namedThreadFactory
 
 /** A wrapper for [[ThreadPoolExecutor]] whose tasks run with the caller's [[SparkSession]]. */
 private[delta] class DeltaThreadPool(tpe: ThreadPoolExecutor) {
@@ -56,7 +57,28 @@ private[delta] class DeltaThreadPool(tpe: ThreadPoolExecutor) {
 /** Convenience constructor that creates a [[ThreadPoolExecutor]] with sensible defaults. */
 private[delta] object DeltaThreadPool {
   def apply(prefix: String, numThreads: Int): DeltaThreadPool =
-    new DeltaThreadPool(ThreadUtils.newDaemonCachedThreadPool(prefix, numThreads))
+    new DeltaThreadPool(newDaemonCachedThreadPool(prefix, numThreads))
+
+  /**
+   * Create a cached thread pool whose max number of threads is `maxThreadNumber`. Thread names
+   * are formatted as prefix-ID, where ID is a unique, sequentially assigned integer.
+   */
+  def newDaemonCachedThreadPool(
+      prefix: String,
+      maxThreadNumber: Int): ThreadPoolExecutor = {
+    val keepAliveSeconds = 60
+    val queueSize = Integer.MAX_VALUE
+    val threadFactory = namedThreadFactory(prefix)
+    val threadPool = new SparkThreadLocalForwardingThreadPoolExecutor(
+      maxThreadNumber, // corePoolSize: the max number of threads to create before queuing tasks
+      maxThreadNumber, // maximumPoolSize: unused, because the LinkedBlockingQueue is unbounded
+      keepAliveSeconds,
+      TimeUnit.SECONDS,
+      new LinkedBlockingQueue[Runnable](queueSize),
+      threadFactory)
+    threadPool.allowCoreThreadTimeOut(true)
+    threadPool
+  }
 }
 
 /**
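The net effect of the new factory: a fixed-size daemon pool that still shrinks to zero threads when idle. Because corePoolSize equals maximumPoolSize and the LinkedBlockingQueue is effectively unbounded, the pool never grows past `maxThreadNumber`, and allowCoreThreadTimeOut(true) lets even core threads exit after 60 idle seconds. A minimal usage sketch, assuming caller code inside the org.apache.spark.sql.delta package (the object is private[delta]); the object name, "delta-example" prefix, and printed task below are illustrative, not part of this commit:

import java.util.concurrent.TimeUnit

import org.apache.spark.sql.delta.util.threads.DeltaThreadPool

object DeltaThreadPoolExample {
  def main(args: Array[String]): Unit = {
    // Up to 4 daemon threads, named delta-example-0, delta-example-1, ...
    val pool = DeltaThreadPool.newDaemonCachedThreadPool("delta-example", 4)

    // execute() wraps each task in SparkThreadLocalCapturingRunnable, so when a
    // SparkContext is active, the submitter's TaskContext and Spark local
    // properties travel with the task (see the new file below).
    pool.execute(() => println(s"running on ${Thread.currentThread().getName}"))

    pool.shutdown()
    pool.awaitTermination(10, TimeUnit.SECONDS)
  }
}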
@@ -0,0 +1,118 @@
/*
 * Copyright (2021) The Delta Lake Project Authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.delta.util.threads

import java.util.Properties
import java.util.concurrent._

import scala.collection.JavaConverters._

import org.apache.spark.{SparkContext, TaskContext}
import org.apache.spark.internal.Logging
import org.apache.spark.util.{Utils => SparkUtils}

/**
 * Implementation of ThreadPoolExecutor that captures the Spark ThreadLocals present at submit
 * time and inserts them into the thread before executing the provided runnable.
 */
class SparkThreadLocalForwardingThreadPoolExecutor(
    corePoolSize: Int,
    maximumPoolSize: Int,
    keepAliveTime: Long,
    unit: TimeUnit,
    workQueue: BlockingQueue[Runnable],
    threadFactory: ThreadFactory,
    rejectedExecutionHandler: RejectedExecutionHandler = new ThreadPoolExecutor.AbortPolicy)
  extends ThreadPoolExecutor(
    corePoolSize, maximumPoolSize, keepAliveTime,
    unit, workQueue, threadFactory, rejectedExecutionHandler) {

  override def execute(command: Runnable): Unit =
    super.execute(new SparkThreadLocalCapturingRunnable(command))
}

trait SparkThreadLocalCapturingHelper extends Logging {
  // At the time of creating this instance we capture the task context and command context.
  val capturedTaskContext = TaskContext.get()
  val sparkContext = SparkContext.getActive
  // Capture an immutable, thread-safe snapshot of the current local properties.
  val capturedProperties = sparkContext
    .map(sc => CapturedSparkThreadLocals.toValuesArray(
      SparkUtils.cloneProperties(sc.getLocalProperties)))

  def runWithCaptured[T](body: => T): T = {
    // Save the previous contexts, overwrite them with the captured contexts, and restore the
    // previous contexts when execution completes.
    // This has the unfortunate side effect of writing nulls to these thread locals if they were
    // empty beforehand.
    val previousTaskContext = TaskContext.get()
    val previousProperties = sparkContext.map(_.getLocalProperties)

    TaskContext.setTaskContext(capturedTaskContext)
    for {
      p <- capturedProperties
      sc <- sparkContext
    } {
      sc.setLocalProperties(CapturedSparkThreadLocals.toProperties(p))
    }

    try {
      body
    } catch {
      case t: Throwable =>
        logError(s"Exception in thread ${Thread.currentThread().getName}", t)
        throw t
    } finally {
      TaskContext.setTaskContext(previousTaskContext)
      for {
        p <- previousProperties
        sc <- sparkContext
      } {
        sc.setLocalProperties(p)
      }
    }
  }
}

class CapturedSparkThreadLocals extends SparkThreadLocalCapturingHelper

object CapturedSparkThreadLocals {
  def apply(): CapturedSparkThreadLocals = {
    new CapturedSparkThreadLocals()
  }

  def toProperties(props: Array[(String, String)]): Properties = {
    val resultProps = new Properties()
    for ((key, value) <- props) {
      resultProps.put(key, value)
    }
    resultProps
  }

  def toValuesArray(props: Properties): Array[(String, String)] = {
    props.asScala.toArray
  }
}

class SparkThreadLocalCapturingRunnable(runnable: Runnable)
  extends Runnable with SparkThreadLocalCapturingHelper {
  override def run(): Unit = {
    runWithCaptured(runnable.run())
  }
}
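What the wrapper buys, which the suite below verifies end to end: the TaskContext and Spark local properties visible on the submitting thread at execute() time are snapshotted (the properties are cloned, so later mutation by the caller has no effect), installed on the pool thread for the duration of the task, and restored afterwards. A rough sketch of the pattern, assuming a local SparkContext; the object name, pool sizing, property key, and values are illustrative only:

import java.util.concurrent.{LinkedBlockingQueue, TimeUnit}

import scala.concurrent.{Await, ExecutionContext, Future}
import scala.concurrent.duration._

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.delta.util.threads.SparkThreadLocalForwardingThreadPoolExecutor
import org.apache.spark.util.ThreadUtils

object ForwardingSketch {
  def main(args: Array[String]): Unit = {
    val sc = SparkContext.getOrCreate(new SparkConf().setAppName("demo").setMaster("local"))
    val pool = new SparkThreadLocalForwardingThreadPoolExecutor(
      1, 1, 60, TimeUnit.SECONDS,
      new LinkedBlockingQueue[Runnable](),
      ThreadUtils.namedThreadFactory("demo"))
    implicit val ec: ExecutionContext = ExecutionContext.fromExecutor(pool)

    // The property is captured when the Future is scheduled, not when it runs, so
    // the pool thread sees "first" even though the caller overwrites it right after.
    sc.setLocalProperty("demo.prop", "first")
    val seen = Future { sc.getLocalProperty("demo.prop") }
    sc.setLocalProperty("demo.prop", "second")
    assert(Await.result(seen, 10.seconds) == "first")

    pool.shutdown()
    sc.stop()
  }
}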
@@ -0,0 +1,152 @@
/*
 * Copyright (2021) The Delta Lake Project Authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.delta.util.threads

import java.util.Properties
import java.util.concurrent.{LinkedBlockingQueue, ThreadPoolExecutor, TimeUnit}

import scala.concurrent.{ExecutionContext, ExecutionContextExecutor, Future}
import scala.concurrent.duration._

import org.apache.spark._
import org.apache.spark.util.ThreadUtils
import org.apache.spark.util.ThreadUtils.namedThreadFactory

class SparkThreadLocalForwardingSuite extends SparkFunSuite {

  private def createThreadPool(nThreads: Int, prefix: String): ThreadPoolExecutor = {
    val threadFactory = namedThreadFactory(prefix)
    val keepAliveTimeSeconds = 60
    val threadPool = new SparkThreadLocalForwardingThreadPoolExecutor(
      nThreads,
      nThreads,
      keepAliveTimeSeconds,
      TimeUnit.SECONDS,
      new LinkedBlockingQueue[Runnable],
      threadFactory)
    threadPool.allowCoreThreadTimeOut(true)
    threadPool
  }

test("SparkThreadLocalForwardingThreadPoolExecutor properly propagates" +
" TaskContext and Spark Local Properties") {
val sc = SparkContext.getOrCreate(new SparkConf().setAppName("test").setMaster("local"))
val executor = createThreadPool(1, "test-threads")
implicit val executionContext: ExecutionContextExecutor =
ExecutionContext.fromExecutor(executor)

val prevTaskContext = TaskContext.get()
try {
// assert that each instance of submitting a task to the execution context captures the
// current task context
val futures = (1 to 10) map { i =>
setTaskAndProperties(i, sc)

Future {
checkTaskAndProperties(i, sc)
}(executionContext)
}

assert(ThreadUtils.awaitResult(Future.sequence(futures), 10.seconds).forall(identity))
} finally {
ThreadUtils.shutdown(executor)
TaskContext.setTaskContext(prevTaskContext)
sc.stop()
}
}

def makeTaskContext(id: Int): TaskContext = {
new TaskContextImpl(id, 0, 0, 0, attemptNumber = 45613, 0, null, new Properties(), null)
}

def setTaskAndProperties(i: Int, sc: SparkContext = SparkContext.getActive.get): Unit = {
val tc = makeTaskContext(i)
TaskContext.setTaskContext(tc)
sc.setLocalProperty("test", i.toString)
}

def checkTaskAndProperties(i: Int, sc: SparkContext = SparkContext.getActive.get): Boolean = {
TaskContext.get() != null &&
TaskContext.get().stageId() == i &&
sc.getLocalProperty("test") == i.toString
}

test("That CapturedSparkThreadLocals properly restores the existing state") {
val sc = SparkContext.getOrCreate(new SparkConf().setAppName("test").setMaster("local"))
val prevTaskContext = TaskContext.get()
try {
setTaskAndProperties(10)
val capturedSparkThreadLocals = CapturedSparkThreadLocals()
setTaskAndProperties(11)
assert(!checkTaskAndProperties(10, sc))
assert(checkTaskAndProperties(11, sc))
capturedSparkThreadLocals.runWithCaptured {
assert(checkTaskAndProperties(10, sc))
}
assert(checkTaskAndProperties(11, sc))
} finally {
TaskContext.setTaskContext(prevTaskContext)
sc.stop()
}
}

test("That CapturedSparkThreadLocals properly restores the existing spark properties." +
" Changes to local properties inside a task do not affect the original properties") {
val sc = SparkContext.getOrCreate(new SparkConf().setAppName("test").setMaster("local"))
try {
sc.setLocalProperty("TestProp", "1")
val capturedSparkThreadLocals = CapturedSparkThreadLocals()
assert(sc.getLocalProperty("TestProp") == "1")
capturedSparkThreadLocals.runWithCaptured {
sc.setLocalProperty("TestProp", "2")
assert(sc.getLocalProperty("TestProp") == "2")
}
assert(sc.getLocalProperty("TestProp") == "1")
} finally {
sc.stop()
}
}


test("captured spark thread locals are immutable") {
val sc = SparkContext.getOrCreate(new SparkConf().setAppName("test").setMaster("local"))
try {
sc.setLocalProperty("test1", "good")
sc.setLocalProperty("test2", "good")
val threadLocals = CapturedSparkThreadLocals()
sc.setLocalProperty("test2", "bad")
assert(sc.getLocalProperty("test1") == "good")
assert(sc.getLocalProperty("test2") == "bad")
threadLocals.runWithCaptured {
assert(sc.getLocalProperty("test1") == "good")
assert(sc.getLocalProperty("test2") == "good")
sc.setLocalProperty("test1", "bad")
sc.setLocalProperty("test2", "maybe")
assert(sc.getLocalProperty("test1") == "bad")
assert(sc.getLocalProperty("test2") == "maybe")
}
assert(sc.getLocalProperty("test1") == "good")
assert(sc.getLocalProperty("test2") == "bad")
threadLocals.runWithCaptured {
assert(sc.getLocalProperty("test1") == "good")
assert(sc.getLocalProperty("test2") == "good")
}
} finally {
sc.stop()
}
}
}
