diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/commands/DMLWithDeletionVectorsHelper.scala b/spark/src/main/scala/org/apache/spark/sql/delta/commands/DMLWithDeletionVectorsHelper.scala
index 230c9eb7034..e87725340dd 100644
--- a/spark/src/main/scala/org/apache/spark/sql/delta/commands/DMLWithDeletionVectorsHelper.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/delta/commands/DMLWithDeletionVectorsHelper.scala
@@ -253,6 +253,7 @@ object DMLWithDeletionVectorsHelper extends DeltaCommand {
       conf = snapshot.deltaLog.newDeltaHadoopConf(),
       dataPath = snapshot.deltaLog.dataPath,
       addFiles = filesWithNoStats.toDS(spark),
+      numFilesOpt = Some(filesWithNoStats.size),
       columnMappingMode = snapshot.metadata.columnMappingMode,
       dataSchema = snapshot.dataSchema,
       statsSchema = snapshot.statsSchema,
diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/sources/DeltaSQLConf.scala b/spark/src/main/scala/org/apache/spark/sql/delta/sources/DeltaSQLConf.scala
index 879d7561b61..82ec85cdf3a 100644
--- a/spark/src/main/scala/org/apache/spark/sql/delta/sources/DeltaSQLConf.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/delta/sources/DeltaSQLConf.scala
@@ -1343,6 +1343,23 @@ trait DeltaSQLConfBase {
       .booleanConf
       .createWithDefault(false)
 
+  val DELTA_USE_MULTI_THREADED_STATS_COLLECTION =
+    buildConf("collectStats.useMultiThreadedStatsCollection")
+      .internal()
+      .doc("Whether to use multi-threaded statistics collection. If false, statistics will be " +
+        "collected sequentially for each partition.")
+      .booleanConf
+      .createWithDefault(true)
+
+  val DELTA_STATS_COLLECTION_NUM_FILES_PARTITION =
+    buildConf("collectStats.numFilesPerPartition")
+      .internal()
+      .doc("Controls the number of files that should be within a partition " +
+        "during multi-threaded optimized statistics collection. A larger number will lead to " +
+        "less parallelism, but can reduce scheduling overhead.")
+      .intConf
+      .checkValue(v => v >= 1, "Must be at least 1.")
+      .createWithDefault(100)
 }
 
 object DeltaSQLConf extends DeltaSQLConfBase
diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/stats/StatsCollectionUtils.scala b/spark/src/main/scala/org/apache/spark/sql/delta/stats/StatsCollectionUtils.scala
index 2d19b4a5448..653312883fc 100644
--- a/spark/src/main/scala/org/apache/spark/sql/delta/stats/StatsCollectionUtils.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/delta/stats/StatsCollectionUtils.scala
@@ -18,6 +18,7 @@ package org.apache.spark.sql.delta.stats
 
 import scala.collection.JavaConverters._
 import scala.collection.mutable
+import scala.concurrent.duration.Duration
 import scala.language.existentials
 import scala.util.control.NonFatal
 
@@ -27,6 +28,7 @@ import org.apache.spark.sql.delta.actions.AddFile
 import org.apache.spark.sql.delta.sources.DeltaSQLConf
 import org.apache.spark.sql.delta.stats.DeltaStatistics._
 import org.apache.spark.sql.delta.util.{DeltaFileOperations, JsonUtils}
+import org.apache.spark.sql.delta.util.threads.DeltaThreadPool
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.FileSystem
 import org.apache.hadoop.fs.Path
@@ -55,6 +57,7 @@ object StatsCollectionUtils
   * @param conf The Hadoop configuration used to access file system.
   * @param dataPath The data path of table, to which these AddFile(s) belong.
   * @param addFiles The list of target AddFile(s) to be processed.
+  * @param numFilesOpt The number of AddFile(s) to process if known. Speeds up the query.
   * @param columnMappingMode The column mapping mode of table.
   * @param dataSchema The data schema of table.
   * @param statsSchema The stats schema to be collected.
@@ -70,13 +73,20 @@ object StatsCollectionUtils
       conf: Configuration,
       dataPath: Path,
       addFiles: Dataset[AddFile],
+      numFilesOpt: Option[Long],
       columnMappingMode: DeltaColumnMappingMode,
       dataSchema: StructType,
       statsSchema: StructType,
       ignoreMissingStats: Boolean = true,
       setBoundsToWide: Boolean = false): Dataset[AddFile] = {
-    import org.apache.spark.sql.delta.implicits._
+    val useMultiThreadedStatsCollection = spark.sessionState.conf.getConf(
+      DeltaSQLConf.DELTA_USE_MULTI_THREADED_STATS_COLLECTION)
+    val preparedAddFiles = if (useMultiThreadedStatsCollection) {
+      prepareFilesForMultiThreadedStatsCollection(spark, addFiles, numFilesOpt)
+    } else {
+      addFiles
+    }
 
     val parquetRebaseMode =
       spark.sessionState.conf.getConf(SQLConf.PARQUET_REBASE_MODE_IN_READ)
 
@@ -91,20 +101,62 @@ object StatsCollectionUtils
 
     val broadcastConf = spark.sparkContext.broadcast(serializableConf)
     val dataRootDir = dataPath.toString
-    addFiles.mapPartitions { addFileIter =>
+
+    import org.apache.spark.sql.delta.implicits._
+    preparedAddFiles.mapPartitions { addFileIter =>
       val defaultFileSystem = new Path(dataRootDir).getFileSystem(broadcastConf.value.value)
-      addFileIter.map { addFile =>
-        computeStatsForFile(
-          addFile,
-          dataRootDir,
-          defaultFileSystem,
-          broadcastConf.value,
-          setBoundsToWide,
-          statsCollector)
+      if (useMultiThreadedStatsCollection) {
+        ParallelFetchPool.parallelMap(spark, addFileIter.toSeq) { addFile =>
+          computeStatsForFile(
+            addFile,
+            dataRootDir,
+            defaultFileSystem,
+            broadcastConf.value,
+            setBoundsToWide,
+            statsCollector)
+        }.toIterator
+      } else {
+        addFileIter.map { addFile =>
+          computeStatsForFile(
+            addFile,
+            dataRootDir,
+            defaultFileSystem,
+            broadcastConf.value,
+            setBoundsToWide,
+            statsCollector)
+        }
       }
     }
   }
 
+  /**
+   * Prepares files for multi-threaded stats collection by splitting them up into more partitions
+   * if necessary. If the number of partitions is too small, not every executor might
+   * receive a partition, which reduces the achievable parallelism. By increasing the number of
+   * partitions we can achieve more parallelism.
+   */
+  private def prepareFilesForMultiThreadedStatsCollection(
+      spark: SparkSession,
+      addFiles: Dataset[AddFile],
+      numFilesOpt: Option[Long]): Dataset[AddFile] = {
+
+    val numFiles = numFilesOpt.getOrElse(addFiles.count())
+    val currNumPartitions = addFiles.rdd.getNumPartitions
+    val numFilesPerPartition = spark.sessionState.conf.getConf(
+      DeltaSQLConf.DELTA_STATS_COLLECTION_NUM_FILES_PARTITION)
+
+    // We should not create more partitions than there are cores in the cluster.
+    val minNumPartitions = Math.min(
+      spark.sparkContext.defaultParallelism,
+      numFiles / numFilesPerPartition + 1).toInt
+    // Only repartition if it would increase the achievable parallelism
+    if (currNumPartitions < minNumPartitions) {
+      addFiles.repartition(minNumPartitions)
+    } else {
+      addFiles
+    }
+  }
+
   private def computeStatsForFile(
       addFile: AddFile,
       dataRootDir: String,
@@ -137,6 +189,21 @@ object StatsCollectionUtils
   }
 }
 
+object ParallelFetchPool {
+  val NUM_THREADS_PER_CORE = 10
+  val MAX_THREADS = 1024
+
+  val NUM_THREADS = Math.min(
+    Runtime.getRuntime.availableProcessors() * NUM_THREADS_PER_CORE, MAX_THREADS)
+
+  lazy val threadPool = DeltaThreadPool("stats-collection", NUM_THREADS)
+  def parallelMap[T, R](
+      spark: SparkSession,
+      items: Iterable[T],
+      timeout: Duration = Duration.Inf)(
+      f: T => R): Iterable[R] = threadPool.parallelMap(spark, items, timeout)(f)
+}
+
 /**
  * A helper class to collect stats of parquet data files for Delta table and its equivalent (tables
  * that can be converted into Delta table like Parquet/Iceberg table).
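Note on configuration: the two knobs added in DeltaSQLConf.scala are internal session confs. A minimal sketch of toggling them at the session level, assuming the usual "spark.databricks.delta." prefix that DeltaSQLConf applies to keys declared via buildConf; the session setup and values are illustrative only.

import org.apache.spark.sql.SparkSession

// Illustrative session; the key names below assume DeltaSQLConf's standard
// "spark.databricks.delta." prefix for confs declared with buildConf(...).
val spark = SparkSession.builder()
  .appName("stats-collection-config-demo")
  .getOrCreate()

// Fall back to the pre-existing sequential, per-partition stats collection.
spark.conf.set("spark.databricks.delta.collectStats.useMultiThreadedStatsCollection", "false")

// Or keep the (default) multi-threaded path but pack more files into each
// partition, trading parallelism for lower scheduling overhead.
spark.conf.set("spark.databricks.delta.collectStats.numFilesPerPartition", "200")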
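Note on partition sizing: prepareFilesForMultiThreadedStatsCollection repartitions only when that would raise parallelism, targeting roughly numFilesPerPartition files per partition while never exceeding the cluster's default parallelism. A standalone sketch of the same arithmetic with illustrative numbers (not code from this patch).

// Mirrors the minNumPartitions computation: cap at defaultParallelism,
// otherwise aim for about numFilesPerPartition files per partition.
def targetPartitions(numFiles: Long, numFilesPerPartition: Int, defaultParallelism: Int): Int =
  Math.min(defaultParallelism.toLong, numFiles / numFilesPerPartition + 1).toInt

// 1050 files at the default of 100 files per partition asks for 11 partitions
// (1050 / 100 + 1), but a cluster with default parallelism 8 caps the answer at 8.
assert(targetPartitions(1050L, 100, 8) == 8)
// With 64 cores available, the file count drives the result: 11 partitions.
assert(targetPartitions(1050L, 100, 64) == 11)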