diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/CheckpointProvider.scala b/spark/src/main/scala/org/apache/spark/sql/delta/CheckpointProvider.scala
index 5c4ff6f852b..e63df4d5cda 100644
--- a/spark/src/main/scala/org/apache/spark/sql/delta/CheckpointProvider.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/delta/CheckpointProvider.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.delta.metering.DeltaLogging
 import org.apache.spark.sql.delta.sources.DeltaSQLConf
 import org.apache.spark.sql.delta.storage.LogStore
 import org.apache.spark.sql.delta.util.FileNames._
-import org.apache.spark.sql.delta.util.NonFateSharingFuture
+import org.apache.spark.sql.delta.util.threads.NonFateSharingFuture
 
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.{FileStatus, Path}
diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/SnapshotManagement.scala b/spark/src/main/scala/org/apache/spark/sql/delta/SnapshotManagement.scala
index 0945e89cf42..f114f3b85eb 100644
--- a/spark/src/main/scala/org/apache/spark/sql/delta/SnapshotManagement.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/delta/SnapshotManagement.scala
@@ -28,9 +28,9 @@ import scala.util.control.NonFatal
 import com.databricks.spark.util.TagDefinitions.TAG_ASYNC
 import org.apache.spark.sql.delta.actions.Metadata
 import org.apache.spark.sql.delta.sources.DeltaSQLConf
-import org.apache.spark.sql.delta.util.DeltaThreadPool
 import org.apache.spark.sql.delta.util.FileNames._
 import org.apache.spark.sql.delta.util.JsonUtils
+import org.apache.spark.sql.delta.util.threads.DeltaThreadPool
 
 import com.fasterxml.jackson.annotation.JsonIgnore
 import org.apache.hadoop.fs.{BlockLocation, FileStatus, LocatedFileStatus, Path}
diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/util/DeltaThreadPool.scala b/spark/src/main/scala/org/apache/spark/sql/delta/util/threads/DeltaThreadPool.scala
similarity index 75%
rename from spark/src/main/scala/org/apache/spark/sql/delta/util/DeltaThreadPool.scala
rename to spark/src/main/scala/org/apache/spark/sql/delta/util/threads/DeltaThreadPool.scala
index 390b3acfdab..ecf9e941c68 100644
--- a/spark/src/main/scala/org/apache/spark/sql/delta/util/DeltaThreadPool.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/delta/util/threads/DeltaThreadPool.scala
@@ -14,9 +14,9 @@
  * limitations under the License.
  */
 
-package org.apache.spark.sql.delta.util
+package org.apache.spark.sql.delta.util.threads
 
-import java.util.concurrent.ThreadPoolExecutor
+import java.util.concurrent.{LinkedBlockingQueue, ThreadPoolExecutor, TimeUnit}
 
 import scala.concurrent.{ExecutionContext, Future}
 import scala.concurrent.duration.Duration
@@ -26,6 +26,7 @@ import org.apache.spark.sql.delta.metering.DeltaLogging
 
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.util.ThreadUtils
+import org.apache.spark.util.ThreadUtils.namedThreadFactory
 
 /** A wrapper for [[ThreadPoolExecutor]] whose tasks run with the caller's [[SparkSession]]. */
 private[delta] class DeltaThreadPool(tpe: ThreadPoolExecutor) {
@@ -56,7 +57,28 @@ private[delta] class DeltaThreadPool(tpe: ThreadPoolExecutor) {
 /** Convenience constructor that creates a [[ThreadPoolExecutor]] with sensible defaults. */
 private[delta] object DeltaThreadPool {
   def apply(prefix: String, numThreads: Int): DeltaThreadPool =
-    new DeltaThreadPool(ThreadUtils.newDaemonCachedThreadPool(prefix, numThreads))
+    new DeltaThreadPool(newDaemonCachedThreadPool(prefix, numThreads))
+
+  /**
+   * Create a cached thread pool whose max number of threads is `maxThreadNumber`. Thread names
+   * are formatted as prefix-ID, where ID is a unique, sequentially assigned integer.
+   */
+  def newDaemonCachedThreadPool(
+      prefix: String,
+      maxThreadNumber: Int): ThreadPoolExecutor = {
+    val keepAliveSeconds = 60
+    val queueSize = Integer.MAX_VALUE
+    val threadFactory = namedThreadFactory(prefix)
+    val threadPool = new SparkThreadLocalForwardingThreadPoolExecutor(
+      maxThreadNumber, // corePoolSize: the max number of threads to create before queuing the tasks
+      maxThreadNumber, // maximumPoolSize: unused, because the unbounded LinkedBlockingQueue never fills
+      keepAliveSeconds,
+      TimeUnit.SECONDS,
+      new LinkedBlockingQueue[Runnable](queueSize),
+      threadFactory)
+    threadPool.allowCoreThreadTimeOut(true)
+    threadPool
+  }
 }
 
 /**
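Note on the sizing above: with `corePoolSize == maximumPoolSize` and an unbounded `LinkedBlockingQueue`, the pool grows to `maxThreadNumber` threads and then queues further tasks, so the maximum size is never exercised; `allowCoreThreadTimeOut(true)` is what gives cached-pool behavior. A minimal standalone sketch of the same configuration (plain `ThreadPoolExecutor`, hypothetical object name, not part of the patch):

```scala
import java.util.concurrent.{LinkedBlockingQueue, ThreadPoolExecutor, TimeUnit}

object CachedPoolSketch extends App {
  val pool = new ThreadPoolExecutor(
    4,                   // corePoolSize: up to 4 threads are created before tasks start queuing
    4,                   // maximumPoolSize: unreachable, since the unbounded queue never fills
    1, TimeUnit.SECONDS, // keep-alive for idle threads once core timeout is allowed
    new LinkedBlockingQueue[Runnable]())
  pool.allowCoreThreadTimeOut(true) // lets "core" threads die when idle, like a cached pool

  (1 to 8).foreach { i =>
    pool.execute(() => println(s"task $i on ${Thread.currentThread().getName}"))
  }
  Thread.sleep(2000) // wait past the keep-alive
  println(s"live threads after idle timeout: ${pool.getPoolSize}") // expected: 0
  pool.shutdown()
}
```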
diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/util/threads/SparkThreadLocalForwardingThreadPoolExecutor.scala b/spark/src/main/scala/org/apache/spark/sql/delta/util/threads/SparkThreadLocalForwardingThreadPoolExecutor.scala
new file mode 100644
index 00000000000..dbb126ccaea
--- /dev/null
+++ b/spark/src/main/scala/org/apache/spark/sql/delta/util/threads/SparkThreadLocalForwardingThreadPoolExecutor.scala
@@ -0,0 +1,118 @@
+/*
+ * Copyright (2021) The Delta Lake Project Authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.delta.util.threads
+
+import java.util.Properties
+import java.util.concurrent._
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark.{SparkContext, TaskContext}
+import org.apache.spark.internal.Logging
+import org.apache.spark.util.{Utils => SparkUtils}
+
+/**
+ * Implementation of ThreadPoolExecutor that captures the Spark ThreadLocals present at submit
+ * time and restores them on the worker thread before executing the provided runnable.
+ */
+class SparkThreadLocalForwardingThreadPoolExecutor(
+    corePoolSize: Int,
+    maximumPoolSize: Int,
+    keepAliveTime: Long,
+    unit: TimeUnit,
+    workQueue: BlockingQueue[Runnable],
+    threadFactory: ThreadFactory,
+    rejectedExecutionHandler: RejectedExecutionHandler = new ThreadPoolExecutor.AbortPolicy)
+  extends ThreadPoolExecutor(
+    corePoolSize, maximumPoolSize, keepAliveTime,
+    unit, workQueue, threadFactory, rejectedExecutionHandler) {
+
+  override def execute(command: Runnable): Unit =
+    super.execute(new SparkThreadLocalCapturingRunnable(command))
+}
+
+
+trait SparkThreadLocalCapturingHelper extends Logging {
+  // At the time of creating this instance we capture the task context and local properties.
+  val capturedTaskContext = TaskContext.get()
+  val sparkContext = SparkContext.getActive
+  // Capture an immutable, thread-safe snapshot of the current local properties.
+  val capturedProperties = sparkContext
+    .map(sc => CapturedSparkThreadLocals.toValuesArray(
+      SparkUtils.cloneProperties(sc.getLocalProperties)))
+
+  def runWithCaptured[T](body: => T): T = {
+    // Save the previous contexts, overwrite them with the captured contexts, and then restore
+    // the previous ones when execution completes.
+    // This has the unfortunate side effect of writing nulls to these thread locals if they were
+    // empty beforehand.
+    val previousTaskContext = TaskContext.get()
+    val previousProperties = sparkContext.map(_.getLocalProperties)
+
+    TaskContext.setTaskContext(capturedTaskContext)
+    for {
+      p <- capturedProperties
+      sc <- sparkContext
+    } {
+      sc.setLocalProperties(CapturedSparkThreadLocals.toProperties(p))
+    }
+
+    try {
+      body
+    } catch {
+      case t: Throwable =>
+        logError(s"Exception in thread ${Thread.currentThread().getName}", t)
+        throw t
+    } finally {
+      TaskContext.setTaskContext(previousTaskContext)
+      for {
+        p <- previousProperties
+        sc <- sparkContext
+      } {
+        sc.setLocalProperties(p)
+      }
+    }
+  }
+}
+
+class CapturedSparkThreadLocals extends SparkThreadLocalCapturingHelper
+
+object CapturedSparkThreadLocals {
+  def apply(): CapturedSparkThreadLocals = {
+    new CapturedSparkThreadLocals()
+  }
+
+  def toProperties(props: Array[(String, String)]): Properties = {
+    val resultProps = new Properties()
+    for ((key, value) <- props) {
+      resultProps.put(key, value)
+    }
+    resultProps
+  }
+
+  def toValuesArray(props: Properties): Array[(String, String)] = {
+    props.asScala.toArray
+  }
+
+}
+
+class SparkThreadLocalCapturingRunnable(runnable: Runnable)
+  extends Runnable with SparkThreadLocalCapturingHelper {
+  override def run(): Unit = {
+    runWithCaptured(runnable.run())
+  }
+}
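The capture-at-construction / restore-around-body dance in `runWithCaptured` is the general idiom for forwarding `ThreadLocal` state across a thread-pool boundary. A Spark-free sketch of the same pattern, with a hypothetical `requestId` thread local standing in for `TaskContext` and the local properties:

```scala
import java.util.concurrent.Executors

object ThreadLocalForwardingSketch extends App {
  // Hypothetical per-request state standing in for TaskContext / local properties.
  val requestId = new ThreadLocal[String]

  // Capture on the submitting thread; set/restore around the body on the pool thread.
  def capturingRunnable(body: Runnable): Runnable = {
    val captured = requestId.get() // evaluated at submit time, on the caller's thread
    () => {
      val previous = requestId.get() // whatever the worker thread held before
      requestId.set(captured)
      try body.run()
      finally requestId.set(previous) // may write null back, same caveat as runWithCaptured
    }
  }

  val pool = Executors.newFixedThreadPool(1)
  requestId.set("req-42")
  pool.execute(capturingRunnable(() => println(s"worker sees ${requestId.get()}"))) // req-42
  pool.shutdown()
}
```

Wrapping at `execute` time, as the executor above does, means every submission path (direct `execute`, `Future`, `ExecutionContext`) gets the forwarding behavior for free.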
diff --git a/spark/src/test/scala/org/apache/spark/sql/delta/util/threads/DeltaThreadPoolSuite.scala b/spark/src/test/scala/org/apache/spark/sql/delta/util/threads/DeltaThreadPoolSuite.scala
new file mode 100644
index 00000000000..51d5988ce5f
--- /dev/null
+++ b/spark/src/test/scala/org/apache/spark/sql/delta/util/threads/DeltaThreadPoolSuite.scala
@@ -0,0 +1,69 @@
+/*
+ * Copyright (2021) The Delta Lake Project Authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.delta.util.threads
+
+import java.util.Properties
+
+import org.apache.spark.{SparkFunSuite, TaskContext, TaskContextImpl}
+import org.apache.spark.sql.test.SharedSparkSession
+
+class DeltaThreadPoolSuite extends SparkFunSuite with SharedSparkSession {
+
+  val threadPool: DeltaThreadPool = DeltaThreadPool("test", 1)
+
+  def makeTaskContext(id: Int): TaskContext = {
+    new TaskContextImpl(id, 0, 0, 0, attemptNumber = 45613, 0, null, new Properties(), null)
+  }
+
+  def testForwarding(testName: String, id: Int)(f: => Unit): Unit = {
+    test(testName) {
+      val prevTaskContext = TaskContext.get()
+      TaskContext.setTaskContext(makeTaskContext(id))
+      sparkContext.setLocalProperty("test", id.toString)
+
+      try {
+        f
+      } finally {
+        TaskContext.setTaskContext(prevTaskContext)
+      }
+    }
+  }
+
+  def assertTaskAndProperties(id: Int): Unit = {
+    assert(TaskContext.get() !== null)
+    assert(TaskContext.get().stageId() === id)
+    assert(sparkContext.getLocalProperty("test") === id.toString)
+  }
+
+  testForwarding("parallelMap captures TaskContext", id = 0) {
+    threadPool.parallelMap(spark, 0 until 1) { _ =>
+      assertTaskAndProperties(id = 0)
+    }
+  }
+
+  testForwarding("submit captures TaskContext and local properties", id = 1) {
+    threadPool.submit(spark) {
+      assertTaskAndProperties(id = 1)
+    }
+  }
+
+  testForwarding("submitNonFateSharing captures TaskContext and local properties", id = 2) {
+    threadPool.submitNonFateSharing { _ =>
+      assertTaskAndProperties(id = 2)
+    }
+  }
}
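For reference, a sketch of how a caller might use the pool, assuming only the `DeltaThreadPool` surface the suite above exercises (`apply` and `parallelMap`, with `parallelMap` blocking and returning the mapped collection); the package, object, and method names here are hypothetical:

```scala
package org.apache.spark.sql.delta.examples

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.delta.util.threads.DeltaThreadPool

// Hypothetical caller. DeltaThreadPool is private[delta], so real callers live under
// org.apache.spark.sql.delta, as this sketch does.
object DeltaThreadPoolUsageSketch {
  private val pool = DeltaThreadPool("example-prefix", numThreads = 4)

  // Each element is mapped on a pool thread that still sees the submitter's TaskContext
  // and Spark local properties, thanks to the forwarding executor.
  def describeVersions(spark: SparkSession, versions: Seq[Long]): Seq[String] =
    pool.parallelMap(spark, versions)(v => s"processed version $v").toSeq
}
```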
diff --git a/spark/src/test/scala/org/apache/spark/sql/delta/util/threads/SparkThreadLocalForwardingSuite.scala b/spark/src/test/scala/org/apache/spark/sql/delta/util/threads/SparkThreadLocalForwardingSuite.scala
new file mode 100644
index 00000000000..50b946a78b7
--- /dev/null
+++ b/spark/src/test/scala/org/apache/spark/sql/delta/util/threads/SparkThreadLocalForwardingSuite.scala
@@ -0,0 +1,152 @@
+/*
+ * Copyright (2021) The Delta Lake Project Authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.delta.util.threads
+
+import java.util.Properties
+import java.util.concurrent.{LinkedBlockingQueue, ThreadPoolExecutor, TimeUnit}
+
+import scala.concurrent.{ExecutionContext, ExecutionContextExecutor, Future}
+import scala.concurrent.duration._
+
+import org.apache.spark._
+import org.apache.spark.util.ThreadUtils
+import org.apache.spark.util.ThreadUtils.namedThreadFactory
+
+class SparkThreadLocalForwardingSuite extends SparkFunSuite {
+
+  private def createThreadPool(nThreads: Int, prefix: String): ThreadPoolExecutor = {
+    val threadFactory = namedThreadFactory(prefix)
+    val keepAliveTimeSeconds = 60
+    val threadPool = new SparkThreadLocalForwardingThreadPoolExecutor(
+      nThreads,
+      nThreads,
+      keepAliveTimeSeconds,
+      TimeUnit.SECONDS,
+      new LinkedBlockingQueue[Runnable],
+      threadFactory)
+    threadPool.allowCoreThreadTimeOut(true)
+    threadPool
+  }
+
+  test("SparkThreadLocalForwardingThreadPoolExecutor properly propagates" +
+    " TaskContext and Spark Local Properties") {
+    val sc = SparkContext.getOrCreate(new SparkConf().setAppName("test").setMaster("local"))
+    val executor = createThreadPool(1, "test-threads")
+    implicit val executionContext: ExecutionContextExecutor =
+      ExecutionContext.fromExecutor(executor)
+
+    val prevTaskContext = TaskContext.get()
+    try {
+      // assert that each instance of submitting a task to the execution context captures the
+      // current task context
+      val futures = (1 to 10) map { i =>
+        setTaskAndProperties(i, sc)
+
+        Future {
+          checkTaskAndProperties(i, sc)
+        }(executionContext)
+      }
+
+      assert(ThreadUtils.awaitResult(Future.sequence(futures), 10.seconds).forall(identity))
+    } finally {
+      ThreadUtils.shutdown(executor)
+      TaskContext.setTaskContext(prevTaskContext)
+      sc.stop()
+    }
+  }
+
+  def makeTaskContext(id: Int): TaskContext = {
+    new TaskContextImpl(id, 0, 0, 0, attemptNumber = 45613, 0, null, new Properties(), null)
+  }
+
+  def setTaskAndProperties(i: Int, sc: SparkContext = SparkContext.getActive.get): Unit = {
+    val tc = makeTaskContext(i)
+    TaskContext.setTaskContext(tc)
+    sc.setLocalProperty("test", i.toString)
+  }
+
+  def checkTaskAndProperties(i: Int, sc: SparkContext = SparkContext.getActive.get): Boolean = {
+    TaskContext.get() != null &&
+      TaskContext.get().stageId() == i &&
+      sc.getLocalProperty("test") == i.toString
+  }
+
+  test("That CapturedSparkThreadLocals properly restores the existing state") {
+    val sc = SparkContext.getOrCreate(new SparkConf().setAppName("test").setMaster("local"))
+    val prevTaskContext = TaskContext.get()
+    try {
+      setTaskAndProperties(10)
+      val capturedSparkThreadLocals = CapturedSparkThreadLocals()
+      setTaskAndProperties(11)
+      assert(!checkTaskAndProperties(10, sc))
+      assert(checkTaskAndProperties(11, sc))
+      capturedSparkThreadLocals.runWithCaptured {
+        assert(checkTaskAndProperties(10, sc))
+      }
+      assert(checkTaskAndProperties(11, sc))
+    } finally {
+      TaskContext.setTaskContext(prevTaskContext)
+      sc.stop()
+    }
+  }
+
+  test("That CapturedSparkThreadLocals properly restores the existing spark properties." +
+    " Changes to local properties inside a task do not affect the original properties") {
+    val sc = SparkContext.getOrCreate(new SparkConf().setAppName("test").setMaster("local"))
+    try {
+      sc.setLocalProperty("TestProp", "1")
+      val capturedSparkThreadLocals = CapturedSparkThreadLocals()
+      assert(sc.getLocalProperty("TestProp") == "1")
+      capturedSparkThreadLocals.runWithCaptured {
+        sc.setLocalProperty("TestProp", "2")
+        assert(sc.getLocalProperty("TestProp") == "2")
+      }
+      assert(sc.getLocalProperty("TestProp") == "1")
+    } finally {
+      sc.stop()
+    }
+  }
+
+
+  test("captured spark thread locals are immutable") {
+    val sc = SparkContext.getOrCreate(new SparkConf().setAppName("test").setMaster("local"))
+    try {
+      sc.setLocalProperty("test1", "good")
+      sc.setLocalProperty("test2", "good")
+      val threadLocals = CapturedSparkThreadLocals()
+      sc.setLocalProperty("test2", "bad")
+      assert(sc.getLocalProperty("test1") == "good")
+      assert(sc.getLocalProperty("test2") == "bad")
+      threadLocals.runWithCaptured {
+        assert(sc.getLocalProperty("test1") == "good")
+        assert(sc.getLocalProperty("test2") == "good")
+        sc.setLocalProperty("test1", "bad")
+        sc.setLocalProperty("test2", "maybe")
+        assert(sc.getLocalProperty("test1") == "bad")
+        assert(sc.getLocalProperty("test2") == "maybe")
+      }
+      assert(sc.getLocalProperty("test1") == "good")
+      assert(sc.getLocalProperty("test2") == "bad")
+      threadLocals.runWithCaptured {
+        assert(sc.getLocalProperty("test1") == "good")
+        assert(sc.getLocalProperty("test2") == "good")
+      }
+    } finally {
+      sc.stop()
+    }
+  }
}