[Spark][1.0] Fix a data loss bug in MergeIntoCommand #2128

Open · wants to merge 4 commits into base: branch-1.0
37 changes: 20 additions & 17 deletions core/src/main/scala/io/delta/tables/DeltaMergeBuilder.scala
@@ -20,6 +20,7 @@ import scala.collection.JavaConverters._
import scala.collection.Map

import org.apache.spark.sql.delta.{DeltaErrors, PreprocessTableMerge}
import org.apache.spark.sql.delta.DeltaTableUtils.withActiveSession
import org.apache.spark.sql.delta.metering.DeltaLogging
import org.apache.spark.sql.delta.util.AnalysisHelper

@@ -203,24 +204,26 @@ class DeltaMergeBuilder private(
*/
def execute(): Unit = improveUnsupportedOpError {
val sparkSession = targetTable.toDF.sparkSession
// Note: We are explicitly resolving DeltaMergeInto plan rather than going to through the
// Analyzer using `Dataset.ofRows()` because the Analyzer incorrectly resolves all
// references in the DeltaMergeInto using both source and target child plans, even before
// DeltaAnalysis rule kicks in. This is because the Analyzer understands only MergeIntoTable,
// and handles that separately by skipping resolution (for Delta) and letting the
// DeltaAnalysis rule do the resolving correctly. This can be solved by generating
// MergeIntoTable instead, which blocked by the different issue with MergeIntoTable as explained
// in the function `mergePlan` and https://issues.apache.org/jira/browse/SPARK-34962.
val resolvedMergeInto =
DeltaMergeInto.resolveReferences(mergePlan, sparkSession.sessionState.conf)(
tryResolveReferences(sparkSession) _)
if (!resolvedMergeInto.resolved) {
throw DeltaErrors.analysisException("Failed to resolve\n", plan = Some(resolvedMergeInto))
withActiveSession(sparkSession) {
// Analyzer using `Dataset.ofRows()` because the Analyzer incorrectly resolves all
// references in the DeltaMergeInto using both source and target child plans, even before
// DeltaAnalysis rule kicks in. This is because the Analyzer understands only MergeIntoTable,
// and handles that separately by skipping resolution (for Delta) and letting the
// DeltaAnalysis rule do the resolving correctly. This can be solved by generating
// MergeIntoTable instead, which blocked by the different issue with MergeIntoTable as
// explained in the function `mergePlan` and
// https://issues.apache.org/jira/browse/SPARK-34962.
val resolvedMergeInto =
DeltaMergeInto.resolveReferences(mergePlan, sparkSession.sessionState.conf)(
tryResolveReferences(sparkSession) _)
if (!resolvedMergeInto.resolved) {
throw DeltaErrors.analysisException("Failed to resolve\n", plan = Some(resolvedMergeInto))
}
// Preprocess the actions and verify
val mergeIntoCommand = PreprocessTableMerge(sparkSession.sessionState.conf)(resolvedMergeInto)
sparkSession.sessionState.analyzer.checkAnalysis(mergeIntoCommand)
mergeIntoCommand.run(sparkSession)
}
// Preprocess the actions and verify
val mergeIntoCommand = PreprocessTableMerge(sparkSession.sessionState.conf)(resolvedMergeInto)
sparkSession.sessionState.analyzer.checkAnalysis(mergeIntoCommand)
mergeIntoCommand.run(sparkSession)
}
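A minimal usage sketch of the situation this wrapping addresses, assuming a table bound to a second session `spark2` and illustrative paths; it mirrors the test added later in this PR and uses only the public DeltaTable merge API:

```scala
// Sketch only: the DeltaTable is bound to spark2, not to the default/active session.
// Before this change, execute() could resolve and run parts of the merge under whatever
// session happened to be active on the thread, ignoring spark2's SQL conf.
import io.delta.tables.DeltaTable

val spark2 = spark.newSession()
spark2.conf.set("spark.databricks.delta.schema.autoMerge.enabled", "true")

val target = DeltaTable.forPath(spark2, "/tmp/delta/target")                  // assumed path
val source = spark2.read.format("delta").load("/tmp/delta/source").alias("s") // assumed path

target.alias("t")
  .merge(source, "t.id = s.id")
  .whenMatched.updateExpr(Map("t.x" -> "t.x + 1"))
  .whenNotMatched.insertAll()
  .execute() // with withActiveSession(sparkSession), this now runs entirely under spark2
```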

/**
3 changes: 2 additions & 1 deletion core/src/main/scala/io/delta/tables/DeltaTable.scala
@@ -19,6 +19,7 @@ package io.delta.tables
import scala.collection.JavaConverters._

import org.apache.spark.sql.delta._
import org.apache.spark.sql.delta.DeltaTableUtils.withActiveSession
import org.apache.spark.sql.delta.actions.Protocol
import org.apache.spark.sql.delta.catalog.DeltaTableV2
import org.apache.spark.sql.delta.sources.DeltaSQLConf
@@ -806,7 +807,7 @@ object DeltaTable {
* @since 1.0.0
*/
@Evolving
def createOrReplace(spark: SparkSession): DeltaTableBuilder = {
def createOrReplace(spark: SparkSession): DeltaTableBuilder = withActiveSession(spark) {
new DeltaTableBuilder(spark, ReplaceTableOptions(orCreate = true))
}

5 changes: 3 additions & 2 deletions core/src/main/scala/io/delta/tables/DeltaTableBuilder.scala
@@ -19,6 +19,7 @@ package io.delta.tables
import scala.collection.mutable

import org.apache.spark.sql.delta.{DeltaErrors, DeltaTableUtils}
import org.apache.spark.sql.delta.DeltaTableUtils.withActiveSession
import io.delta.tables.execution._

import org.apache.spark.annotation._
@@ -295,7 +296,7 @@ class DeltaTableBuilder private[tables](
* @since 1.0.0
*/
@Evolving
def execute(): DeltaTable = {
def execute(): DeltaTable = withActiveSession(spark) {
if (identifier == null && location.isEmpty) {
throw DeltaErrors.analysisException("Table name or location has to be specified")
}
@@ -357,7 +358,7 @@ class DeltaTableBuilder private[tables](

// Return DeltaTable Object.
if (DeltaTableUtils.isValidPath(tableId)) {
DeltaTable.forPath(location.get)
DeltaTable.forPath(location.get)
} else {
DeltaTable.forName(this.identifier)
}
@@ -16,6 +16,7 @@

package io.delta.tables.execution

import org.apache.spark.sql.delta.DeltaTableUtils.withActiveSession
import org.apache.spark.sql.delta.commands.ConvertToDeltaCommand
import io.delta.tables.DeltaTable

@@ -28,7 +29,7 @@ trait DeltaConvertBase {
spark: SparkSession,
tableIdentifier: TableIdentifier,
partitionSchema: Option[StructType],
deltaPath: Option[String]): DeltaTable = {
deltaPath: Option[String]): DeltaTable = withActiveSession(spark) {
val cvt = ConvertToDeltaCommand(tableIdentifier, partitionSchema, deltaPath)
cvt.run(spark)
if (cvt.isCatalogTable(spark.sessionState.analyzer, tableIdentifier)) {
@@ -19,6 +19,7 @@ package io.delta.tables.execution
import scala.collection.Map

import org.apache.spark.sql.delta.{DeltaErrors, DeltaHistoryManager, DeltaLog, PreprocessTableUpdate}
import org.apache.spark.sql.delta.DeltaTableUtils.withActiveSession
import org.apache.spark.sql.delta.commands.{DeleteCommand, DeltaGenerateCommand, VacuumCommand}
import org.apache.spark.sql.delta.util.AnalysisHelper
import io.delta.tables.DeltaTable
@@ -35,42 +36,48 @@ import org.apache.spark.sql.catalyst.plans.logical._
trait DeltaTableOperations extends AnalysisHelper { self: DeltaTable =>

protected def executeDelete(condition: Option[Expression]): Unit = improveUnsupportedOpError {
val delete = DeleteFromTable(self.toDF.queryExecution.analyzed, condition)
toDataset(sparkSession, delete)
withActiveSession(sparkSession) {
val delete = DeleteFromTable(self.toDF.queryExecution.analyzed, condition)
toDataset(sparkSession, delete)
}
}

protected def executeHistory(
deltaLog: DeltaLog,
limit: Option[Int] = None,
tableId: Option[TableIdentifier] = None): DataFrame = {
tableId: Option[TableIdentifier] = None): DataFrame = withActiveSession(sparkSession) {
val history = deltaLog.history
val spark = self.toDF.sparkSession
spark.createDataFrame(history.getHistory(limit))
sparkSession.createDataFrame(history.getHistory(limit))
}

protected def executeGenerate(tblIdentifier: String, mode: String): Unit = {
val tableId: TableIdentifier = sparkSession
.sessionState
.sqlParser
.parseTableIdentifier(tblIdentifier)
val generate = DeltaGenerateCommand(mode, tableId)
toDataset(sparkSession, generate)
withActiveSession(sparkSession) {
val tableId: TableIdentifier = sparkSession
.sessionState
.sqlParser
.parseTableIdentifier(tblIdentifier)
val generate = DeltaGenerateCommand(mode, tableId)
toDataset(sparkSession, generate)
}
}

protected def executeUpdate(
set: Map[String, Column],
condition: Option[Column]): Unit = improveUnsupportedOpError {
val assignments = set.map { case (targetColName, column) =>
Assignment(UnresolvedAttribute.quotedString(targetColName), column.expr)
}.toSeq
val update = UpdateTable(self.toDF.queryExecution.analyzed, assignments, condition.map(_.expr))
toDataset(sparkSession, update)
withActiveSession(sparkSession) {
val assignments = set.map { case (targetColName, column) =>
Assignment(UnresolvedAttribute.quotedString(targetColName), column.expr)
}.toSeq
val update =
UpdateTable(self.toDF.queryExecution.analyzed, assignments, condition.map(_.expr))
toDataset(sparkSession, update)
}
}

protected def executeVacuum(
deltaLog: DeltaLog,
retentionHours: Option[Double],
tableId: Option[TableIdentifier] = None): DataFrame = {
tableId: Option[TableIdentifier] = None): DataFrame = withActiveSession(sparkSession) {
VacuumCommand.gc(sparkSession, deltaLog, false, retentionHours)
sparkSession.emptyDataFrame
}
@@ -356,4 +356,7 @@ object DeltaTableUtils extends PredicateHelper
def parseColToTransform(col: String): IdentityTransform = {
IdentityTransform(FieldReference(Seq(col)))
}

// Workaround for withActive not being visible in io/delta.
def withActiveSession[T](spark: SparkSession)(body: => T): T = spark.withActive(body)
}
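`withActive` itself is a Spark-side helper that is not visible from the `io.delta` packages, which is why the shim above exists. A rough, hand-written approximation of what it is assumed to do (illustrative only, not the actual Spark implementation):

```scala
import org.apache.spark.sql.SparkSession

// Illustrative approximation: pin `spark` as the thread-local active session while
// `body` runs, then restore whatever session (if any) was active before.
def withActiveSessionSketch[T](spark: SparkSession)(body: => T): T = {
  val previous = SparkSession.getActiveSession // remember the prior active session
  SparkSession.setActiveSession(spark)         // code inside `body` now resolves against `spark`
  try {
    body
  } finally {
    previous match {
      case Some(s) => SparkSession.setActiveSession(s)
      case None    => SparkSession.clearActiveSession()
    }
  }
}
```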
@@ -2919,4 +2919,33 @@ abstract class MergeIntoSuiteBase
customConditionErrorRegex =
Option("Aggregate functions are not supported in the .* condition of MERGE operation.*")
)

test("Merge should use the same SparkSession consistently") {
withTempDir { dir =>
withSQLConf(DeltaSQLConf.DELTA_SCHEMA_AUTO_MIGRATE.key -> "false") {
val r = dir.getCanonicalPath
val sourcePath = s"$r/source"
val targetPath = s"$r/target"
val numSourceRecords = 20
spark.range(numSourceRecords)
.withColumn("x", $"id")
.withColumn("y", $"id")
.write.mode("overwrite").format("delta").save(sourcePath)
spark.range(1)
.withColumn("x", $"id")
.write.mode("overwrite").format("delta").save(targetPath)
val spark2 = spark.newSession
spark2.conf.set(DeltaSQLConf.DELTA_SCHEMA_AUTO_MIGRATE.key, "true")
val target = io.delta.tables.DeltaTable.forPath(spark2, targetPath)
val source = spark.read.format("delta").load(sourcePath).alias("s")
val merge = target.alias("t")
.merge(source, "t.id = s.id")
.whenMatched.updateExpr(Map("t.x" -> "t.x + 1"))
.whenNotMatched.insertAll()
.execute()
// The target table should have the same number of rows as the source after the merge
assert(spark.read.format("delta").load(targetPath).count() == numSourceRecords)
}
}
}
}
30 changes: 30 additions & 0 deletions python/delta/tests/test_deltatable.py
@@ -16,6 +16,7 @@

import unittest
import os
from multiprocessing.pool import ThreadPool

from pyspark.sql import Row
from pyspark.sql.column import _to_seq
@@ -279,6 +280,35 @@ def reset_table():
dt.merge(source, "key = k").whenNotMatchedInsert(
values="k = 'a'", condition={"value": 1})

def test_merge_with_inconsistent_sessions(self) -> None:
source_path = os.path.join(self.tempFile, "source")
target_path = os.path.join(self.tempFile, "target")
spark = self.spark

def f(spark):
spark.range(20) \
.withColumn("x", col("id")) \
.withColumn("y", col("id")) \
.write.mode("overwrite").format("delta").save(source_path)
spark.range(1) \
.withColumn("x", col("id")) \
.write.mode("overwrite").format("delta").save(target_path)
target = DeltaTable.forPath(spark, target_path)
source = spark.read.format("delta").load(source_path).alias("s")
target.alias("t") \
.merge(source, "t.id = s.id") \
.whenMatchedUpdate(set={"t.x": "t.x + 1"}) \
.whenNotMatchedInsertAll() \
.execute()
assert(spark.read.format("delta").load(target_path).count() == 20)

pool = ThreadPool(3)
spark.conf.set("spark.databricks.delta.schema.autoMerge.enabled", "true")
try:
pool.starmap(f, [(spark,)])
finally:
spark.conf.unset("spark.databricks.delta.schema.autoMerge.enabled")

def test_history(self):
self.__writeDeltaTable([('a', 1), ('b', 2), ('c', 3)])
self.__overwriteDeltaTable([('a', 3), ('b', 2), ('c', 1)])