From 2f0c66904f8f19640a5a88cde2c727da9014d9ab Mon Sep 17 00:00:00 2001 From: EJ Song Date: Wed, 4 Oct 2023 14:29:29 -0700 Subject: [PATCH 1/4] set active session for commands Signed-off-by: Eunjin Song Co-authored-by: Chungmin Lee --- .../io/delta/tables/DeltaMergeBuilder.scala | 38 +++--- .../scala/io/delta/tables/DeltaTable.scala | 4 +- .../io/delta/tables/DeltaTableBuilder.scala | 114 +++++++++--------- .../delta/tables/execution/DeltaConvert.scala | 15 ++- .../execution/DeltaTableOperations.scala | 36 ++++-- .../apache/spark/sql/delta/DeltaTable.scala | 10 ++ 6 files changed, 123 insertions(+), 94 deletions(-) diff --git a/core/src/main/scala/io/delta/tables/DeltaMergeBuilder.scala b/core/src/main/scala/io/delta/tables/DeltaMergeBuilder.scala index f1255e851a1..0607847f8c4 100644 --- a/core/src/main/scala/io/delta/tables/DeltaMergeBuilder.scala +++ b/core/src/main/scala/io/delta/tables/DeltaMergeBuilder.scala @@ -19,7 +19,7 @@ package io.delta.tables import scala.collection.JavaConverters._ import scala.collection.Map -import org.apache.spark.sql.delta.{DeltaErrors, PreprocessTableMerge} +import org.apache.spark.sql.delta.{DeltaErrors, DeltaTableUtils, PreprocessTableMerge} import org.apache.spark.sql.delta.metering.DeltaLogging import org.apache.spark.sql.delta.util.AnalysisHelper @@ -203,24 +203,26 @@ class DeltaMergeBuilder private( */ def execute(): Unit = improveUnsupportedOpError { val sparkSession = targetTable.toDF.sparkSession - // Note: We are explicitly resolving DeltaMergeInto plan rather than going to through the - // Analyzer using `Dataset.ofRows()` because the Analyzer incorrectly resolves all - // references in the DeltaMergeInto using both source and target child plans, even before - // DeltaAnalysis rule kicks in. This is because the Analyzer understands only MergeIntoTable, - // and handles that separately by skipping resolution (for Delta) and letting the - // DeltaAnalysis rule do the resolving correctly. This can be solved by generating - // MergeIntoTable instead, which blocked by the different issue with MergeIntoTable as explained - // in the function `mergePlan` and https://issues.apache.org/jira/browse/SPARK-34962. - val resolvedMergeInto = - DeltaMergeInto.resolveReferences(mergePlan, sparkSession.sessionState.conf)( - tryResolveReferences(sparkSession) _) - if (!resolvedMergeInto.resolved) { - throw DeltaErrors.analysisException("Failed to resolve\n", plan = Some(resolvedMergeInto)) + DeltaTableUtils.withActiveSession(sparkSession) { + // Analyzer using `Dataset.ofRows()` because the Analyzer incorrectly resolves all + // references in the DeltaMergeInto using both source and target child plans, even before + // DeltaAnalysis rule kicks in. This is because the Analyzer understands only MergeIntoTable, + // and handles that separately by skipping resolution (for Delta) and letting the + // DeltaAnalysis rule do the resolving correctly. This can be solved by generating + // MergeIntoTable instead, which blocked by the different issue with MergeIntoTable as + // explained in the function `mergePlan` and + // https://issues.apache.org/jira/browse/SPARK-34962. 
+ val resolvedMergeInto = + DeltaMergeInto.resolveReferences(mergePlan, sparkSession.sessionState.conf)( + tryResolveReferences(sparkSession) _) + if (!resolvedMergeInto.resolved) { + throw DeltaErrors.analysisException("Failed to resolve\n", plan = Some(resolvedMergeInto)) + } + // Preprocess the actions and verify + val mergeIntoCommand = PreprocessTableMerge(sparkSession.sessionState.conf)(resolvedMergeInto) + sparkSession.sessionState.analyzer.checkAnalysis(mergeIntoCommand) + mergeIntoCommand.run(sparkSession) } - // Preprocess the actions and verify - val mergeIntoCommand = PreprocessTableMerge(sparkSession.sessionState.conf)(resolvedMergeInto) - sparkSession.sessionState.analyzer.checkAnalysis(mergeIntoCommand) - mergeIntoCommand.run(sparkSession) } /** diff --git a/core/src/main/scala/io/delta/tables/DeltaTable.scala b/core/src/main/scala/io/delta/tables/DeltaTable.scala index b353e28b0a9..0182f86be3d 100644 --- a/core/src/main/scala/io/delta/tables/DeltaTable.scala +++ b/core/src/main/scala/io/delta/tables/DeltaTable.scala @@ -807,7 +807,9 @@ object DeltaTable { */ @Evolving def createOrReplace(spark: SparkSession): DeltaTableBuilder = { - new DeltaTableBuilder(spark, ReplaceTableOptions(orCreate = true)) + DeltaTableUtils.withActiveSession(spark) { + new DeltaTableBuilder(spark, ReplaceTableOptions(orCreate = true)) + } } /** diff --git a/core/src/main/scala/io/delta/tables/DeltaTableBuilder.scala b/core/src/main/scala/io/delta/tables/DeltaTableBuilder.scala index 8397d76fddd..3fd0bae11df 100644 --- a/core/src/main/scala/io/delta/tables/DeltaTableBuilder.scala +++ b/core/src/main/scala/io/delta/tables/DeltaTableBuilder.scala @@ -296,70 +296,72 @@ class DeltaTableBuilder private[tables]( */ @Evolving def execute(): DeltaTable = { - if (identifier == null && location.isEmpty) { - throw DeltaErrors.analysisException("Table name or location has to be specified") - } + DeltaTableUtils.withActiveSession(spark) { + if (identifier == null && location.isEmpty) { + throw DeltaErrors.analysisException("Table name or location has to be specified") + } - if (this.identifier == null) { - identifier = s"delta.`${location.get}`" - } + if (this.identifier == null) { + identifier = s"delta.`${location.get}`" + } - // Return DeltaTable Object. - val tableId: TableIdentifier = spark.sessionState.sqlParser.parseTableIdentifier(identifier) + // Return DeltaTable Object. + val tableId: TableIdentifier = spark.sessionState.sqlParser.parseTableIdentifier(identifier) - if (DeltaTableUtils.isValidPath(tableId) && location.nonEmpty - && tableId.table != location.get) { - throw DeltaErrors.analysisException( - s"Creating path-based Delta table with a different location isn't supported. " - + s"Identifier: $identifier, Location: ${location.get}") - } + if (DeltaTableUtils.isValidPath(tableId) && location.nonEmpty + && tableId.table != location.get) { + throw DeltaErrors.analysisException( + s"Creating path-based Delta table with a different location isn't supported. 
" + + s"Identifier: $identifier, Location: ${location.get}") + } - val table = spark.sessionState.sqlParser.parseMultipartIdentifier(identifier) + val table = spark.sessionState.sqlParser.parseMultipartIdentifier(identifier) - val partitioning = partitioningColumns.map { colNames => - colNames.map(name => DeltaTableUtils.parseColToTransform(name)) - }.getOrElse(Seq.empty[Transform]) + val partitioning = partitioningColumns.map { colNames => + colNames.map(name => DeltaTableUtils.parseColToTransform(name)) + }.getOrElse(Seq.empty[Transform]) - val stmt = builderOption match { - case CreateTableOptions(ifNotExists) => - CreateTableStatement( - table, - StructType(columns), - partitioning, - None, - this.properties, - Some(FORMAT_NAME), - Map.empty, - location, - tblComment, - None, - false, - ifNotExists - ) - case ReplaceTableOptions(orCreate) => - ReplaceTableStatement( - table, - StructType(columns), - partitioning, - None, - this.properties, - Some(FORMAT_NAME), - Map.empty, - location, - tblComment, - None, - orCreate - ) - } - val qe = spark.sessionState.executePlan(stmt) - // call `QueryExecution.toRDD` to trigger the execution of commands. - SQLExecution.withNewExecutionId(qe, Some("create delta table"))(qe.toRdd) + val stmt = builderOption match { + case CreateTableOptions(ifNotExists) => + CreateTableStatement( + table, + StructType(columns), + partitioning, + None, + this.properties, + Some(FORMAT_NAME), + Map.empty, + location, + tblComment, + None, + false, + ifNotExists + ) + case ReplaceTableOptions(orCreate) => + ReplaceTableStatement( + table, + StructType(columns), + partitioning, + None, + this.properties, + Some(FORMAT_NAME), + Map.empty, + location, + tblComment, + None, + orCreate + ) + } + val qe = spark.sessionState.executePlan(stmt) + // call `QueryExecution.toRDD` to trigger the execution of commands. + SQLExecution.withNewExecutionId(qe, Some("create delta table"))(qe.toRdd) - // Return DeltaTable Object. - if (DeltaTableUtils.isValidPath(tableId)) { + // Return DeltaTable Object. 
+ if (DeltaTableUtils.isValidPath(tableId)) { DeltaTable.forPath(location.get) - } else { - DeltaTable.forName(this.identifier) + } else { + DeltaTable.forName(this.identifier) + } } } } diff --git a/core/src/main/scala/io/delta/tables/execution/DeltaConvert.scala b/core/src/main/scala/io/delta/tables/execution/DeltaConvert.scala index 5a7e842cbc1..b1fa25de214 100644 --- a/core/src/main/scala/io/delta/tables/execution/DeltaConvert.scala +++ b/core/src/main/scala/io/delta/tables/execution/DeltaConvert.scala @@ -16,6 +16,7 @@ package io.delta.tables.execution +import org.apache.spark.sql.delta.DeltaTableUtils import org.apache.spark.sql.delta.commands.ConvertToDeltaCommand import io.delta.tables.DeltaTable @@ -29,12 +30,14 @@ trait DeltaConvertBase { tableIdentifier: TableIdentifier, partitionSchema: Option[StructType], deltaPath: Option[String]): DeltaTable = { - val cvt = ConvertToDeltaCommand(tableIdentifier, partitionSchema, deltaPath) - cvt.run(spark) - if (cvt.isCatalogTable(spark.sessionState.analyzer, tableIdentifier)) { - DeltaTable.forName(spark, tableIdentifier.toString) - } else { - DeltaTable.forPath(spark, tableIdentifier.table) + DeltaTableUtils.withActiveSession(spark) { + val cvt = ConvertToDeltaCommand(tableIdentifier, partitionSchema, deltaPath) + cvt.run(spark) + if (cvt.isCatalogTable(spark.sessionState.analyzer, tableIdentifier)) { + DeltaTable.forName(spark, tableIdentifier.toString) + } else { + DeltaTable.forPath(spark, tableIdentifier.table) + } } } } diff --git a/core/src/main/scala/io/delta/tables/execution/DeltaTableOperations.scala b/core/src/main/scala/io/delta/tables/execution/DeltaTableOperations.scala index 72a08e514b1..ef100119d6a 100644 --- a/core/src/main/scala/io/delta/tables/execution/DeltaTableOperations.scala +++ b/core/src/main/scala/io/delta/tables/execution/DeltaTableOperations.scala @@ -18,7 +18,7 @@ package io.delta.tables.execution import scala.collection.Map -import org.apache.spark.sql.delta.{DeltaErrors, DeltaHistoryManager, DeltaLog, PreprocessTableUpdate} +import org.apache.spark.sql.delta.{DeltaErrors, DeltaHistoryManager, DeltaLog, DeltaTableUtils, PreprocessTableUpdate} import org.apache.spark.sql.delta.commands.{DeleteCommand, DeltaGenerateCommand, VacuumCommand} import org.apache.spark.sql.delta.util.AnalysisHelper import io.delta.tables.DeltaTable @@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.expressions.{Expression, SubqueryExpression} import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.SparkSession /** * Interface to provide the actual implementations of DeltaTable operations. 
@@ -35,17 +36,21 @@ import org.apache.spark.sql.catalyst.plans.logical._ trait DeltaTableOperations extends AnalysisHelper { self: DeltaTable => protected def executeDelete(condition: Option[Expression]): Unit = improveUnsupportedOpError { - val delete = DeleteFromTable(self.toDF.queryExecution.analyzed, condition) - toDataset(sparkSession, delete) + DeltaTableUtils.withActiveSession(sparkSession) { + val delete = DeleteFromTable(self.toDF.queryExecution.analyzed, condition) + toDataset(sparkSession, delete) + } } protected def executeHistory( deltaLog: DeltaLog, limit: Option[Int] = None, tableId: Option[TableIdentifier] = None): DataFrame = { - val history = deltaLog.history - val spark = self.toDF.sparkSession - spark.createDataFrame(history.getHistory(limit)) + DeltaTableUtils.withActiveSession(sparkSession) { + val history = deltaLog.history + val spark = self.toDF.sparkSession + spark.createDataFrame(history.getHistory(limit)) + } } protected def executeGenerate(tblIdentifier: String, mode: String): Unit = { @@ -60,19 +65,24 @@ trait DeltaTableOperations extends AnalysisHelper { self: DeltaTable => protected def executeUpdate( set: Map[String, Column], condition: Option[Column]): Unit = improveUnsupportedOpError { - val assignments = set.map { case (targetColName, column) => - Assignment(UnresolvedAttribute.quotedString(targetColName), column.expr) - }.toSeq - val update = UpdateTable(self.toDF.queryExecution.analyzed, assignments, condition.map(_.expr)) - toDataset(sparkSession, update) + DeltaTableUtils.withActiveSession(sparkSession) { + val assignments = set.map { case (targetColName, column) => + Assignment(UnresolvedAttribute.quotedString(targetColName), column.expr) + }.toSeq + val update = + UpdateTable(self.toDF.queryExecution.analyzed, assignments, condition.map(_.expr)) + toDataset(sparkSession, update) + } } protected def executeVacuum( deltaLog: DeltaLog, retentionHours: Option[Double], tableId: Option[TableIdentifier] = None): DataFrame = { - VacuumCommand.gc(sparkSession, deltaLog, false, retentionHours) - sparkSession.emptyDataFrame + DeltaTableUtils.withActiveSession(sparkSession) { + VacuumCommand.gc(sparkSession, deltaLog, false, retentionHours) + sparkSession.emptyDataFrame + } } protected def toStrColumnMap(map: Map[String, String]): Map[String, Column] = { diff --git a/core/src/main/scala/org/apache/spark/sql/delta/DeltaTable.scala b/core/src/main/scala/org/apache/spark/sql/delta/DeltaTable.scala index e07d132c3ba..b4eed043fb7 100644 --- a/core/src/main/scala/org/apache/spark/sql/delta/DeltaTable.scala +++ b/core/src/main/scala/org/apache/spark/sql/delta/DeltaTable.scala @@ -356,4 +356,14 @@ object DeltaTableUtils extends PredicateHelper def parseColToTransform(col: String): IdentityTransform = { IdentityTransform(FieldReference(Seq(col))) } + + def withActiveSession[T](spark: SparkSession)(body: => T): T = { + val old = SparkSession.getActiveSession + SparkSession.setActiveSession(spark) + try { + body + } finally { + SparkSession.setActiveSession(old.getOrElse(null)) + } + } } From 901cdaac3812c52a1e32ef2a93bc7ddd72a5d261 Mon Sep 17 00:00:00 2001 From: EJ Song Date: Wed, 4 Oct 2023 15:38:41 -0700 Subject: [PATCH 2/4] test --- .../spark/sql/delta/MergeIntoSuiteBase.scala | 36 +++++++++++++++++++ python/delta/tests/test_deltatable.py | 30 ++++++++++++++++ 2 files changed, 66 insertions(+) diff --git a/core/src/test/scala/org/apache/spark/sql/delta/MergeIntoSuiteBase.scala b/core/src/test/scala/org/apache/spark/sql/delta/MergeIntoSuiteBase.scala index 
52400d9d7bf..587ce1565ac 100644 --- a/core/src/test/scala/org/apache/spark/sql/delta/MergeIntoSuiteBase.scala +++ b/core/src/test/scala/org/apache/spark/sql/delta/MergeIntoSuiteBase.scala @@ -2919,4 +2919,40 @@ abstract class MergeIntoSuiteBase customConditionErrorRegex = Option("Aggregate functions are not supported in the .* condition of MERGE operation.*") ) + + Seq(true, false).foreach { differentActiveSession => + test("merge should use the same SparkSession consistently, differentActiveSession: " + + s"$differentActiveSession") { + withTempDir { dir => + withSQLConf(DeltaSQLConf.DELTA_SCHEMA_AUTO_MIGRATE.key -> "false") { + val r = dir.getCanonicalPath + val sourcePath = s"$r/source" + val targetPath = s"$r/target" + val numSourceRecords = 20 + spark.range(numSourceRecords) + .withColumn("x", $"id") + .withColumn("y", $"id") + .write.mode("overwrite").format("delta").save(sourcePath) + spark.range(1) + .withColumn("x", $"id") + .write.mode("overwrite").format("delta").save(targetPath) + val spark2 = if (differentActiveSession) { + spark.newSession + } else { + spark + } + spark2.conf.set(DeltaSQLConf.DELTA_SCHEMA_AUTO_MIGRATE.key, "true") + val target = io.delta.tables.DeltaTable.forPath(spark2, targetPath) + val source = spark.read.format("delta").load(sourcePath).alias("s") + val merge = target.alias("t") + .merge(source, "t.id = s.id") + .whenMatched.updateExpr(Map("t.x" -> "t.x + 1")) + .whenNotMatched.insertAll() + .execute() + // The target table should have the same number of rows as the source after the merge + assert(spark.read.format("delta").load(targetPath).count() == numSourceRecords) + } + } + } + } } diff --git a/python/delta/tests/test_deltatable.py b/python/delta/tests/test_deltatable.py index ed67b936a67..d8dd69d31ef 100644 --- a/python/delta/tests/test_deltatable.py +++ b/python/delta/tests/test_deltatable.py @@ -16,6 +16,7 @@ import unittest import os +from multiprocessing.pool import ThreadPool from pyspark.sql import Row from pyspark.sql.column import _to_seq @@ -279,6 +280,35 @@ def reset_table(): dt.merge(source, "key = k").whenNotMatchedInsert( values="k = 'a'", condition={"value": 1}) + def test_merge_with_inconsistent_sessions(self) -> None: + source_path = os.path.join(self.tempFile, "source") + target_path = os.path.join(self.tempFile, "target") + spark = self.spark + + def f(spark): + spark.range(20) \ + .withColumn("x", col("id")) \ + .withColumn("y", col("id")) \ + .write.mode("overwrite").format("delta").save(source_path) + spark.range(1) \ + .withColumn("x", col("id")) \ + .write.mode("overwrite").format("delta").save(target_path) + target = DeltaTable.forPath(spark, target_path) + source = spark.read.format("delta").load(source_path).alias("s") + target.alias("t") \ + .merge(source, "t.id = s.id") \ + .whenMatchedUpdate(set={"t.x": "t.x + 1"}) \ + .whenNotMatchedInsertAll() \ + .execute() + assert(spark.read.format("delta").load(target_path).count() == 20) + + pool = ThreadPool(3) + spark.conf.set("spark.databricks.delta.schema.autoMerge.enabled", "true") + try: + pool.starmap(f, [(spark,)]) + finally: + spark.conf.unset("spark.databricks.delta.schema.autoMerge.enabled") + def test_history(self): self.__writeDeltaTable([('a', 1), ('b', 2), ('c', 3)]) self.__overwriteDeltaTable([('a', 3), ('b', 2), ('c', 1)]) From 7ac6adc49a81c134d08da0d03267ce36c98f3ba8 Mon Sep 17 00:00:00 2001 From: EJ Song Date: Thu, 5 Oct 2023 14:30:43 -0700 Subject: [PATCH 3/4] review commit --- .../io/delta/tables/DeltaMergeBuilder.scala | 5 +- 
.../scala/io/delta/tables/DeltaTable.scala | 7 +- .../io/delta/tables/DeltaTableBuilder.scala | 119 +++++++++--------- .../delta/tables/execution/DeltaConvert.scala | 18 ++- .../execution/DeltaTableOperations.scala | 25 ++-- .../apache/spark/sql/delta/DeltaTable.scala | 11 +- .../spark/sql/delta/MergeIntoSuiteBase.scala | 57 ++++----- 7 files changed, 110 insertions(+), 132 deletions(-) diff --git a/core/src/main/scala/io/delta/tables/DeltaMergeBuilder.scala b/core/src/main/scala/io/delta/tables/DeltaMergeBuilder.scala index 0607847f8c4..91a6dc2675c 100644 --- a/core/src/main/scala/io/delta/tables/DeltaMergeBuilder.scala +++ b/core/src/main/scala/io/delta/tables/DeltaMergeBuilder.scala @@ -19,7 +19,8 @@ package io.delta.tables import scala.collection.JavaConverters._ import scala.collection.Map -import org.apache.spark.sql.delta.{DeltaErrors, DeltaTableUtils, PreprocessTableMerge} +import org.apache.spark.sql.delta.{DeltaErrors, PreprocessTableMerge} +import org.apache.spark.sql.delta.DeltaTableUtils.withActiveSession import org.apache.spark.sql.delta.metering.DeltaLogging import org.apache.spark.sql.delta.util.AnalysisHelper @@ -203,7 +204,7 @@ class DeltaMergeBuilder private( */ def execute(): Unit = improveUnsupportedOpError { val sparkSession = targetTable.toDF.sparkSession - DeltaTableUtils.withActiveSession(sparkSession) { + withActiveSession(sparkSession) { // Analyzer using `Dataset.ofRows()` because the Analyzer incorrectly resolves all // references in the DeltaMergeInto using both source and target child plans, even before // DeltaAnalysis rule kicks in. This is because the Analyzer understands only MergeIntoTable, diff --git a/core/src/main/scala/io/delta/tables/DeltaTable.scala b/core/src/main/scala/io/delta/tables/DeltaTable.scala index 0182f86be3d..5942f46b167 100644 --- a/core/src/main/scala/io/delta/tables/DeltaTable.scala +++ b/core/src/main/scala/io/delta/tables/DeltaTable.scala @@ -19,6 +19,7 @@ package io.delta.tables import scala.collection.JavaConverters._ import org.apache.spark.sql.delta._ +import org.apache.spark.sql.delta.DeltaTableUtils.withActiveSession import org.apache.spark.sql.delta.actions.Protocol import org.apache.spark.sql.delta.catalog.DeltaTableV2 import org.apache.spark.sql.delta.sources.DeltaSQLConf @@ -806,10 +807,8 @@ object DeltaTable { * @since 1.0.0 */ @Evolving - def createOrReplace(spark: SparkSession): DeltaTableBuilder = { - DeltaTableUtils.withActiveSession(spark) { - new DeltaTableBuilder(spark, ReplaceTableOptions(orCreate = true)) - } + def createOrReplace(spark: SparkSession): DeltaTableBuilder = withActiveSession(spark) { + new DeltaTableBuilder(spark, ReplaceTableOptions(orCreate = true)) } /** diff --git a/core/src/main/scala/io/delta/tables/DeltaTableBuilder.scala b/core/src/main/scala/io/delta/tables/DeltaTableBuilder.scala index 3fd0bae11df..c424b788664 100644 --- a/core/src/main/scala/io/delta/tables/DeltaTableBuilder.scala +++ b/core/src/main/scala/io/delta/tables/DeltaTableBuilder.scala @@ -19,6 +19,7 @@ package io.delta.tables import scala.collection.mutable import org.apache.spark.sql.delta.{DeltaErrors, DeltaTableUtils} +import org.apache.spark.sql.delta.DeltaTableUtils.withActiveSession import io.delta.tables.execution._ import org.apache.spark.annotation._ @@ -295,73 +296,71 @@ class DeltaTableBuilder private[tables]( * @since 1.0.0 */ @Evolving - def execute(): DeltaTable = { - DeltaTableUtils.withActiveSession(spark) { - if (identifier == null && location.isEmpty) { - throw DeltaErrors.analysisException("Table name or 
location has to be specified") - } + def execute(): DeltaTable = withActiveSession(spark) { + if (identifier == null && location.isEmpty) { + throw DeltaErrors.analysisException("Table name or location has to be specified") + } - if (this.identifier == null) { - identifier = s"delta.`${location.get}`" - } + if (this.identifier == null) { + identifier = s"delta.`${location.get}`" + } - // Return DeltaTable Object. - val tableId: TableIdentifier = spark.sessionState.sqlParser.parseTableIdentifier(identifier) + // Return DeltaTable Object. + val tableId: TableIdentifier = spark.sessionState.sqlParser.parseTableIdentifier(identifier) - if (DeltaTableUtils.isValidPath(tableId) && location.nonEmpty - && tableId.table != location.get) { - throw DeltaErrors.analysisException( - s"Creating path-based Delta table with a different location isn't supported. " - + s"Identifier: $identifier, Location: ${location.get}") - } + if (DeltaTableUtils.isValidPath(tableId) && location.nonEmpty + && tableId.table != location.get) { + throw DeltaErrors.analysisException( + s"Creating path-based Delta table with a different location isn't supported. " + + s"Identifier: $identifier, Location: ${location.get}") + } - val table = spark.sessionState.sqlParser.parseMultipartIdentifier(identifier) + val table = spark.sessionState.sqlParser.parseMultipartIdentifier(identifier) - val partitioning = partitioningColumns.map { colNames => - colNames.map(name => DeltaTableUtils.parseColToTransform(name)) - }.getOrElse(Seq.empty[Transform]) + val partitioning = partitioningColumns.map { colNames => + colNames.map(name => DeltaTableUtils.parseColToTransform(name)) + }.getOrElse(Seq.empty[Transform]) - val stmt = builderOption match { - case CreateTableOptions(ifNotExists) => - CreateTableStatement( - table, - StructType(columns), - partitioning, - None, - this.properties, - Some(FORMAT_NAME), - Map.empty, - location, - tblComment, - None, - false, - ifNotExists - ) - case ReplaceTableOptions(orCreate) => - ReplaceTableStatement( - table, - StructType(columns), - partitioning, - None, - this.properties, - Some(FORMAT_NAME), - Map.empty, - location, - tblComment, - None, - orCreate - ) - } - val qe = spark.sessionState.executePlan(stmt) - // call `QueryExecution.toRDD` to trigger the execution of commands. - SQLExecution.withNewExecutionId(qe, Some("create delta table"))(qe.toRdd) + val stmt = builderOption match { + case CreateTableOptions(ifNotExists) => + CreateTableStatement( + table, + StructType(columns), + partitioning, + None, + this.properties, + Some(FORMAT_NAME), + Map.empty, + location, + tblComment, + None, + false, + ifNotExists + ) + case ReplaceTableOptions(orCreate) => + ReplaceTableStatement( + table, + StructType(columns), + partitioning, + None, + this.properties, + Some(FORMAT_NAME), + Map.empty, + location, + tblComment, + None, + orCreate + ) + } + val qe = spark.sessionState.executePlan(stmt) + // call `QueryExecution.toRDD` to trigger the execution of commands. + SQLExecution.withNewExecutionId(qe, Some("create delta table"))(qe.toRdd) - // Return DeltaTable Object. - if (DeltaTableUtils.isValidPath(tableId)) { - DeltaTable.forPath(location.get) - } else { - DeltaTable.forName(this.identifier) - } + // Return DeltaTable Object. 
+ if (DeltaTableUtils.isValidPath(tableId)) { + DeltaTable.forPath(location.get) + } else { + DeltaTable.forName(this.identifier) } } } diff --git a/core/src/main/scala/io/delta/tables/execution/DeltaConvert.scala b/core/src/main/scala/io/delta/tables/execution/DeltaConvert.scala index b1fa25de214..9684914b2c8 100644 --- a/core/src/main/scala/io/delta/tables/execution/DeltaConvert.scala +++ b/core/src/main/scala/io/delta/tables/execution/DeltaConvert.scala @@ -16,7 +16,7 @@ package io.delta.tables.execution -import org.apache.spark.sql.delta.DeltaTableUtils +import org.apache.spark.sql.delta.DeltaTableUtils.withActiveSession import org.apache.spark.sql.delta.commands.ConvertToDeltaCommand import io.delta.tables.DeltaTable @@ -29,15 +29,13 @@ trait DeltaConvertBase { spark: SparkSession, tableIdentifier: TableIdentifier, partitionSchema: Option[StructType], - deltaPath: Option[String]): DeltaTable = { - DeltaTableUtils.withActiveSession(spark) { - val cvt = ConvertToDeltaCommand(tableIdentifier, partitionSchema, deltaPath) - cvt.run(spark) - if (cvt.isCatalogTable(spark.sessionState.analyzer, tableIdentifier)) { - DeltaTable.forName(spark, tableIdentifier.toString) - } else { - DeltaTable.forPath(spark, tableIdentifier.table) - } + deltaPath: Option[String]): DeltaTable = withActiveSession(spark) { + val cvt = ConvertToDeltaCommand(tableIdentifier, partitionSchema, deltaPath) + cvt.run(spark) + if (cvt.isCatalogTable(spark.sessionState.analyzer, tableIdentifier)) { + DeltaTable.forName(spark, tableIdentifier.toString) + } else { + DeltaTable.forPath(spark, tableIdentifier.table) } } } diff --git a/core/src/main/scala/io/delta/tables/execution/DeltaTableOperations.scala b/core/src/main/scala/io/delta/tables/execution/DeltaTableOperations.scala index ef100119d6a..26dd87239b6 100644 --- a/core/src/main/scala/io/delta/tables/execution/DeltaTableOperations.scala +++ b/core/src/main/scala/io/delta/tables/execution/DeltaTableOperations.scala @@ -18,7 +18,8 @@ package io.delta.tables.execution import scala.collection.Map -import org.apache.spark.sql.delta.{DeltaErrors, DeltaHistoryManager, DeltaLog, DeltaTableUtils, PreprocessTableUpdate} +import org.apache.spark.sql.delta.{DeltaErrors, DeltaHistoryManager, DeltaLog, PreprocessTableUpdate} +import org.apache.spark.sql.delta.DeltaTableUtils.withActiveSession import org.apache.spark.sql.delta.commands.{DeleteCommand, DeltaGenerateCommand, VacuumCommand} import org.apache.spark.sql.delta.util.AnalysisHelper import io.delta.tables.DeltaTable @@ -28,7 +29,6 @@ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.expressions.{Expression, SubqueryExpression} import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.SparkSession /** * Interface to provide the actual implementations of DeltaTable operations. 
@@ -36,7 +36,7 @@ import org.apache.spark.sql.SparkSession trait DeltaTableOperations extends AnalysisHelper { self: DeltaTable => protected def executeDelete(condition: Option[Expression]): Unit = improveUnsupportedOpError { - DeltaTableUtils.withActiveSession(sparkSession) { + withActiveSession(sparkSession) { val delete = DeleteFromTable(self.toDF.queryExecution.analyzed, condition) toDataset(sparkSession, delete) } @@ -45,12 +45,9 @@ trait DeltaTableOperations extends AnalysisHelper { self: DeltaTable => protected def executeHistory( deltaLog: DeltaLog, limit: Option[Int] = None, - tableId: Option[TableIdentifier] = None): DataFrame = { - DeltaTableUtils.withActiveSession(sparkSession) { - val history = deltaLog.history - val spark = self.toDF.sparkSession - spark.createDataFrame(history.getHistory(limit)) - } + tableId: Option[TableIdentifier] = None): DataFrame = withActiveSession(sparkSession) { + val history = deltaLog.history + sparkSession.createDataFrame(history.getHistory(limit)) } protected def executeGenerate(tblIdentifier: String, mode: String): Unit = { @@ -65,7 +62,7 @@ trait DeltaTableOperations extends AnalysisHelper { self: DeltaTable => protected def executeUpdate( set: Map[String, Column], condition: Option[Column]): Unit = improveUnsupportedOpError { - DeltaTableUtils.withActiveSession(sparkSession) { + withActiveSession(sparkSession) { val assignments = set.map { case (targetColName, column) => Assignment(UnresolvedAttribute.quotedString(targetColName), column.expr) }.toSeq @@ -78,11 +75,9 @@ trait DeltaTableOperations extends AnalysisHelper { self: DeltaTable => protected def executeVacuum( deltaLog: DeltaLog, retentionHours: Option[Double], - tableId: Option[TableIdentifier] = None): DataFrame = { - DeltaTableUtils.withActiveSession(sparkSession) { - VacuumCommand.gc(sparkSession, deltaLog, false, retentionHours) - sparkSession.emptyDataFrame - } + tableId: Option[TableIdentifier] = None): DataFrame = withActiveSession(sparkSession) { + VacuumCommand.gc(sparkSession, deltaLog, false, retentionHours) + sparkSession.emptyDataFrame } protected def toStrColumnMap(map: Map[String, String]): Map[String, Column] = { diff --git a/core/src/main/scala/org/apache/spark/sql/delta/DeltaTable.scala b/core/src/main/scala/org/apache/spark/sql/delta/DeltaTable.scala index b4eed043fb7..33a7558561b 100644 --- a/core/src/main/scala/org/apache/spark/sql/delta/DeltaTable.scala +++ b/core/src/main/scala/org/apache/spark/sql/delta/DeltaTable.scala @@ -357,13 +357,6 @@ object DeltaTableUtils extends PredicateHelper IdentityTransform(FieldReference(Seq(col))) } - def withActiveSession[T](spark: SparkSession)(body: => T): T = { - val old = SparkSession.getActiveSession - SparkSession.setActiveSession(spark) - try { - body - } finally { - SparkSession.setActiveSession(old.getOrElse(null)) - } - } + // Workaround for withActive not being visible in io/delta. 
+ def withActiveSession[T](spark: SparkSession)(body: => T): T = spark.withActive(body) } diff --git a/core/src/test/scala/org/apache/spark/sql/delta/MergeIntoSuiteBase.scala b/core/src/test/scala/org/apache/spark/sql/delta/MergeIntoSuiteBase.scala index 587ce1565ac..a3ff28670e2 100644 --- a/core/src/test/scala/org/apache/spark/sql/delta/MergeIntoSuiteBase.scala +++ b/core/src/test/scala/org/apache/spark/sql/delta/MergeIntoSuiteBase.scala @@ -2920,38 +2920,31 @@ abstract class MergeIntoSuiteBase Option("Aggregate functions are not supported in the .* condition of MERGE operation.*") ) - Seq(true, false).foreach { differentActiveSession => - test("merge should use the same SparkSession consistently, differentActiveSession: " + - s"$differentActiveSession") { - withTempDir { dir => - withSQLConf(DeltaSQLConf.DELTA_SCHEMA_AUTO_MIGRATE.key -> "false") { - val r = dir.getCanonicalPath - val sourcePath = s"$r/source" - val targetPath = s"$r/target" - val numSourceRecords = 20 - spark.range(numSourceRecords) - .withColumn("x", $"id") - .withColumn("y", $"id") - .write.mode("overwrite").format("delta").save(sourcePath) - spark.range(1) - .withColumn("x", $"id") - .write.mode("overwrite").format("delta").save(targetPath) - val spark2 = if (differentActiveSession) { - spark.newSession - } else { - spark - } - spark2.conf.set(DeltaSQLConf.DELTA_SCHEMA_AUTO_MIGRATE.key, "true") - val target = io.delta.tables.DeltaTable.forPath(spark2, targetPath) - val source = spark.read.format("delta").load(sourcePath).alias("s") - val merge = target.alias("t") - .merge(source, "t.id = s.id") - .whenMatched.updateExpr(Map("t.x" -> "t.x + 1")) - .whenNotMatched.insertAll() - .execute() - // The target table should have the same number of rows as the source after the merge - assert(spark.read.format("delta").load(targetPath).count() == numSourceRecords) - } + test("Merge should use the same SparkSession consistently") { + withTempDir { dir => + withSQLConf(DeltaSQLConf.DELTA_SCHEMA_AUTO_MIGRATE.key -> "false") { + val r = dir.getCanonicalPath + val sourcePath = s"$r/source" + val targetPath = s"$r/target" + val numSourceRecords = 20 + spark.range(numSourceRecords) + .withColumn("x", $"id") + .withColumn("y", $"id") + .write.mode("overwrite").format("delta").save(sourcePath) + spark.range(1) + .withColumn("x", $"id") + .write.mode("overwrite").format("delta").save(targetPath) + val spark2 = spark.newSession + spark2.conf.set(DeltaSQLConf.DELTA_SCHEMA_AUTO_MIGRATE.key, "true") + val target = io.delta.tables.DeltaTable.forPath(spark2, targetPath) + val source = spark.read.format("delta").load(sourcePath).alias("s") + val merge = target.alias("t") + .merge(source, "t.id = s.id") + .whenMatched.updateExpr(Map("t.x" -> "t.x + 1")) + .whenNotMatched.insertAll() + .execute() + // The target table should have the same number of rows as the source after the merge + assert(spark.read.format("delta").load(targetPath).count() == numSourceRecords) } } } From ec4e638977b6be6c8c7d388caaaedf58a4c20c49 Mon Sep 17 00:00:00 2001 From: EJ Song Date: Thu, 12 Oct 2023 10:28:59 -0700 Subject: [PATCH 4/4] executeGenerate --- .../tables/execution/DeltaTableOperations.scala | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/io/delta/tables/execution/DeltaTableOperations.scala b/core/src/main/scala/io/delta/tables/execution/DeltaTableOperations.scala index 26dd87239b6..e2f000f7dc7 100644 --- a/core/src/main/scala/io/delta/tables/execution/DeltaTableOperations.scala +++ 
b/core/src/main/scala/io/delta/tables/execution/DeltaTableOperations.scala @@ -51,12 +51,14 @@ trait DeltaTableOperations extends AnalysisHelper { self: DeltaTable => } protected def executeGenerate(tblIdentifier: String, mode: String): Unit = { - val tableId: TableIdentifier = sparkSession - .sessionState - .sqlParser - .parseTableIdentifier(tblIdentifier) - val generate = DeltaGenerateCommand(mode, tableId) - toDataset(sparkSession, generate) + withActiveSession(sparkSession) { + val tableId: TableIdentifier = sparkSession + .sessionState + .sqlParser + .parseTableIdentifier(tblIdentifier) + val generate = DeltaGenerateCommand(mode, tableId) + toDataset(sparkSession, generate) + } } protected def executeUpdate(
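Note on the session-pinning helper these patches introduce: each DeltaTable operation touched in the series (merge, update, delete, vacuum, history, generate, convertToDelta, createOrReplace, and the table builder's execute) is wrapped in withActiveSession so that the command runs against the SparkSession the DeltaTable was built from, rather than whatever session happens to be active on the calling thread. Below is a minimal, self-contained sketch of the hand-rolled save/restore shape from the first commit; the object and method names are illustrative only, and the restore step uses clearActiveSession where the patch itself passes old.getOrElse(null) to setActiveSession (the two have the same effect on the thread-local here):

    import org.apache.spark.sql.SparkSession

    object ActiveSessionSketch {
      // Pin `spark` as the thread-local active session while `body` runs,
      // then put back whatever was active before (or clear it).
      def withActiveSession[T](spark: SparkSession)(body: => T): T = {
        val previous = SparkSession.getActiveSession // Option[SparkSession]
        SparkSession.setActiveSession(spark)
        try {
          body
        } finally {
          previous match {
            case Some(s) => SparkSession.setActiveSession(s)
            case None => SparkSession.clearActiveSession()
          }
        }
      }
    }

The review commit replaces this hand-rolled version with a one-line delegation to Spark's own SparkSession.withActive, which performs the same save/set/restore bookkeeping; the wrapper survives in DeltaTableUtils only because, as its comment says, withActive is not visible from the io.delta.tables package. The new Scala and Python tests exercise the fix by building the DeltaTable from a different session than the one that wrote the data (or, in the Python case, running the merge on a pool thread) and asserting that a session-level config such as spark.databricks.delta.schema.autoMerge.enabled is still honored, so the merge leaves the one-row target with the same row count as the 20-row source.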