Commit: add code
fred-db committed Oct 18, 2023
1 parent a7971a9 commit 682266f
Showing 3 changed files with 95 additions and 10 deletions.
@@ -253,6 +253,7 @@ object DMLWithDeletionVectorsHelper extends DeltaCommand {
conf = snapshot.deltaLog.newDeltaHadoopConf(),
dataPath = snapshot.deltaLog.dataPath,
addFiles = filesWithNoStats.toDS(spark),
numFilesOpt = Some(filesWithNoStats.size),
columnMappingMode = snapshot.metadata.columnMappingMode,
dataSchema = snapshot.dataSchema,
statsSchema = snapshot.statsSchema,
@@ -1343,6 +1343,23 @@ trait DeltaSQLConfBase {
.booleanConf
.createWithDefault(false)

val DELTA_USE_MULTI_THREADED_STATS_COLLECTION =
buildConf("collectStats.useMultiThreadedStatsCollection")
.internal()
.doc("Whether to use multi-threaded statistics collection. If false, statistics will be " +
"collected sequentially for each partition.")
.booleanConf
.createWithDefault(true)

val DELTA_STATS_COLLECTION_NUM_FILES_PARTITION =
buildConf("collectStats.numFilesPerPartition")
.internal()
.doc("Controls the number of files per partition during multi-threaded " +
"statistics collection. A larger number leads to less parallelism but " +
"can reduce scheduling overhead.")
.intConf
.checkValue(v => v >= 1, "Must be at least 1.")
.createWithDefault(100)
}

object DeltaSQLConf extends DeltaSQLConfBase
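
A minimal sketch of how these two new knobs could be toggled per session. The
spark.databricks.delta. key prefix is an assumption here: buildConf applies the
prefix outside the lines shown in this diff.

  // Hypothetical session-level overrides (key prefix assumed, see above).
  spark.conf.set("spark.databricks.delta.collectStats.useMultiThreadedStatsCollection", "true")
  // Larger values mean fewer, larger partitions: less parallelism, lower scheduling overhead.
  spark.conf.set("spark.databricks.delta.collectStats.numFilesPerPartition", "200")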
@@ -18,6 +18,7 @@ package org.apache.spark.sql.delta.stats

import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.concurrent.duration.Duration
import scala.language.existentials
import scala.util.control.NonFatal

@@ -27,6 +28,7 @@ import org.apache.spark.sql.delta.actions.AddFile
import org.apache.spark.sql.delta.sources.DeltaSQLConf
import org.apache.spark.sql.delta.stats.DeltaStatistics._
import org.apache.spark.sql.delta.util.{DeltaFileOperations, JsonUtils}
import org.apache.spark.sql.delta.util.threads.DeltaThreadPool
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.FileSystem
import org.apache.hadoop.fs.Path
@@ -55,6 +57,7 @@ object StatsCollectionUtils
* @param conf The Hadoop configuration used to access file system.
* @param dataPath The data path of table, to which these AddFile(s) belong.
* @param addFiles The list of target AddFile(s) to be processed.
* @param numFilesOpt The number of AddFile(s) to process, if known; providing it avoids an extra count() over the dataset.
* @param columnMappingMode The column mapping mode of table.
* @param dataSchema The data schema of table.
* @param statsSchema The stats schema to be collected.
@@ -70,13 +73,20 @@ object StatsCollectionUtils
conf: Configuration,
dataPath: Path,
addFiles: Dataset[AddFile],
numFilesOpt: Option[Long],
columnMappingMode: DeltaColumnMappingMode,
dataSchema: StructType,
statsSchema: StructType,
ignoreMissingStats: Boolean = true,
setBoundsToWide: Boolean = false): Dataset[AddFile] = {

val useMultiThreadedStatsCollection = spark.sessionState.conf.getConf(
DeltaSQLConf.DELTA_USE_MULTI_THREADED_STATS_COLLECTION)
val preparedAddFiles = if (useMultiThreadedStatsCollection) {
prepareFilesForMultiThreadedStatsCollection(spark, addFiles, numFilesOpt)
} else {
addFiles
}

val parquetRebaseMode =
spark.sessionState.conf.getConf(SQLConf.PARQUET_REBASE_MODE_IN_READ)
@@ -91,20 +101,62 @@ object StatsCollectionUtils
val broadcastConf = spark.sparkContext.broadcast(serializableConf)

val dataRootDir = dataPath.toString

import org.apache.spark.sql.delta.implicits._
preparedAddFiles.mapPartitions { addFileIter =>
val defaultFileSystem = new Path(dataRootDir).getFileSystem(broadcastConf.value.value)
addFileIter.map { addFile =>
computeStatsForFile(
addFile,
dataRootDir,
defaultFileSystem,
broadcastConf.value,
setBoundsToWide,
statsCollector)
if (useMultiThreadedStatsCollection) {
ParallelFetchPool.parallelMap(spark, addFileIter.toSeq) { addFile =>
computeStatsForFile(
addFile,
dataRootDir,
defaultFileSystem,
broadcastConf.value,
setBoundsToWide,
statsCollector)
}.toIterator
} else {
addFileIter.map { addFile =>
computeStatsForFile(
addFile,
dataRootDir,
defaultFileSystem,
broadcastConf.value,
setBoundsToWide,
statsCollector)
}
}
}
}
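// Note on the two levels of parallelism above (illustrative numbers, not part
// of this diff): Spark first spreads partitions across executor cores, and
// ParallelFetchPool (defined below) then fans out across the files within each
// partition. With, say, 8 cores per executor and NUM_THREADS_PER_CORE = 10,
// up to 80 file footers can be read concurrently per executor JVM, rather
// than one per core.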

/**
* Prepares files for multi-threaded stats collection by splitting them into more partitions
* if necessary. If there are too few partitions, some executors may receive no work at all,
* which caps the achievable parallelism; repartitioning spreads the files across the cluster.
*/
private def prepareFilesForMultiThreadedStatsCollection(
spark: SparkSession,
addFiles: Dataset[AddFile],
numFilesOpt: Option[Long]): Dataset[AddFile] = {

val numFiles = numFilesOpt.getOrElse(addFiles.count())
val currNumPartitions = addFiles.rdd.getNumPartitions
val numFilesPerPartition = spark.sessionState.conf.getConf(
DeltaSQLConf.DELTA_STATS_COLLECTION_NUM_FILES_PARTITION)

// We should not create more partitions than there are cores in the cluster.
val minNumPartitions = Math.min(
spark.sparkContext.defaultParallelism,
numFiles / numFilesPerPartition + 1).toInt
// Only repartition if it would increase the achievable parallelism
if (currNumPartitions < minNumPartitions) {
addFiles.repartition(minNumPartitions)
} else {
addFiles
}
}
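
// Worked example (hypothetical numbers, not from this diff): with
// defaultParallelism = 64, numFiles = 1000, and numFilesPerPartition = 100,
//   numFiles / numFilesPerPartition + 1 = 11
//   minNumPartitions = min(64, 11) = 11
// so an input Dataset with 4 partitions is repartitioned to 11, while one
// that already has 16 partitions is left unchanged.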

private def computeStatsForFile(
addFile: AddFile,
dataRootDir: String,
@@ -137,6 +189,21 @@
}
}

object ParallelFetchPool {
val NUM_THREADS_PER_CORE = 10
val MAX_THREADS = 1024

val NUM_THREADS = Math.min(
Runtime.getRuntime.availableProcessors() * NUM_THREADS_PER_CORE, MAX_THREADS)

lazy val threadPool = DeltaThreadPool("stats-collection", NUM_THREADS)
def parallelMap[T, R](
spark: SparkSession,
items: Iterable[T],
timeout: Duration = Duration.Inf)(
f: T => R): Iterable[R] = threadPool.parallelMap(spark, items, timeout)(f)
}
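
// Sizing example (illustrative): on a 16-core JVM, NUM_THREADS =
// min(16 * 10, 1024) = 160. A minimal usage sketch under the same assumption:
//   val lengths = ParallelFetchPool.parallelMap(spark, Seq("a", "bb", "ccc"))(_.length)
//   // => Iterable(1, 2, 3)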

/**
* A helper class to collect stats of parquet data files for a Delta table and its equivalents (tables
* that can be converted into a Delta table, such as Parquet or Iceberg tables).
