[Spark] Define and use DeltaTableV2.startTransaction helper method #2053

Closed
@@ -22,7 +22,7 @@ import java.{util => ju}
import scala.collection.JavaConverters._
import scala.collection.mutable

import org.apache.spark.sql.delta.{ColumnWithDefaultExprUtils, DeltaColumnMapping, DeltaErrors, DeltaLog, DeltaOptions, DeltaTableIdentifier, DeltaTableUtils, DeltaTimeTravelSpec, GeneratedColumn, Snapshot}
import org.apache.spark.sql.delta._
import org.apache.spark.sql.delta.commands.WriteIntoDelta
import org.apache.spark.sql.delta.commands.cdc.CDCReader
import org.apache.spark.sql.delta.metering.DeltaLogging
@@ -32,10 +32,13 @@ import org.apache.hadoop.fs.Path

import org.apache.spark.sql.{DataFrame, SaveMode, SparkSession}
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.{ResolvedTable, UnresolvedTable}
import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogTableType, CatalogUtils}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.connector.catalog.{SupportsWrite, Table, TableCapability, TableCatalog, V2TableWithV1Fallback}
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
import org.apache.spark.sql.connector.catalog.TableCapability._
import org.apache.spark.sql.connector.catalog.V1Table
import org.apache.spark.sql.connector.expressions._
import org.apache.spark.sql.connector.write.{LogicalWriteInfo, SupportsDynamicOverwrite, SupportsOverwrite, SupportsTruncate, V1Write, WriteBuilder}
import org.apache.spark.sql.errors.QueryCompilationErrors
@@ -189,6 +192,23 @@ case class DeltaTableV2(
deltaLog, info.options, spark.sessionState.conf.useNullsForMissingDefaultColumnValues)
}

/**
* Starts a transaction for this table, using the snapshot captured during table resolution.
*
* WARNING: Caller is responsible to ensure that table resolution was recent (e.g. if working with
* [[DataFrame]] or [[DeltaTable]] API, where the table could have been resolved long ago).
*/
def startTransactionWithInitialSnapshot(): OptimisticTransaction =
startTransaction(Some(snapshot))

/**
* Starts a transaction for this table, using Some provided snapshot, or a fresh snapshot if None
[review — Contributor] s/Some/some
[reply — Collaborator (PR author)] That was intentional... Some/None as in Option.
* was provided.
*/
def startTransaction(snapshotOpt: Option[Snapshot] = None): OptimisticTransaction = {
[review — Contributor] s/snapshotOpt/snapshot

deltaLog.startTransaction(snapshotOpt)
}

/**
* Creates a V1 BaseRelation from this Table to allow read APIs to go through V1 DataSource code
* paths.
@@ -266,6 +286,31 @@
}
}

object DeltaTableV2 {
/** Resolves a path into a DeltaTableV2, leveraging standard v2 table resolution. */
def apply(spark: SparkSession, tablePath: Path, cmd: String): DeltaTableV2 =
resolve(spark, UnresolvedPathBasedDeltaTable(tablePath.toString, cmd), cmd)

/** Resolves a table identifier into a DeltaTableV2, leveraging standard v2 table resolution. */
def apply(spark: SparkSession, tableId: TableIdentifier, cmd: String): DeltaTableV2 =
resolve(spark, UnresolvedTable(tableId.nameParts, cmd, None), cmd)

/** Applies standard v2 table resolution to an unresolved Delta table plan node */
def resolve(spark: SparkSession, unresolved: LogicalPlan, cmd: String): DeltaTableV2 =
extractFrom(spark.sessionState.analyzer.ResolveRelations(unresolved), cmd)

/**
* Extracts the DeltaTableV2 from a resolved Delta table plan node, throwing "table not found" if
* the node does not actually represent a resolved Delta table.
*/
def extractFrom(plan: LogicalPlan, cmd: String): DeltaTableV2 = plan match {
case ResolvedTable(_, _, d: DeltaTableV2, _) => d
case ResolvedTable(_, _, t: V1Table, _) if DeltaTableUtils.isDeltaTable(t.catalogTable) =>
DeltaTableV2(SparkSession.active, new Path(t.v1Table.location), Some(t.v1Table))
case _ => throw DeltaErrors.notADeltaTableException(cmd)
}
}

private class WriteIntoDeltaBuilder(
log: DeltaLog,
writeOptions: CaseInsensitiveStringMap,
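For orientation, a minimal usage sketch of the helpers added in this file; the table name and the command string used for error messages are illustrative, not taken from the PR:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.delta.OptimisticTransaction
import org.apache.spark.sql.delta.catalog.DeltaTableV2

val spark = SparkSession.active

// Resolve a catalog table into a DeltaTableV2 via the new companion-object apply;
// the last argument only appears in "not a Delta table" error messages.
val table: DeltaTableV2 = DeltaTableV2(spark, TableIdentifier("events"), "EXAMPLE")

// Pin the transaction to the snapshot captured at table resolution time. Per the
// WARNING above, this is only safe when resolution happened recently.
val pinnedTxn: OptimisticTransaction = table.startTransactionWithInitialSnapshot()

// Default: take a fresh snapshot at transaction start.
val freshTxn: OptimisticTransaction = table.startTransaction()
```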
@@ -286,12 +286,8 @@ trait DeltaCommand extends DeltaLogging {
* other cases this method will throw a "Table not found" exception.
*/
def getDeltaTable(target: LogicalPlan, cmd: String): DeltaTableV2 = {
target match {
case ResolvedTable(_, _, d: DeltaTableV2, _) => d
case ResolvedTable(_, _, t: V1Table, _) if DeltaTableUtils.isDeltaTable(t.catalogTable) =>
DeltaTableV2(SparkSession.active, new Path(t.v1Table.location), Some(t.v1Table))
case _ => throw DeltaErrors.notADeltaTableException(cmd)
}
// TODO: Remove this wrapper and let former callers invoke DeltaTableV2.extractFrom directly.
DeltaTableV2.extractFrom(target, cmd)
}

/**
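A sketch of the call pattern this refactor enables; the method below is hypothetical, and existing commands keep going through getDeltaTable for now, per the TODO above:

```scala
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.delta.catalog.DeltaTableV2

// Hypothetical command body: pull the DeltaTableV2 out of an analyzer-resolved
// plan node, or fail with a "not a Delta table" error naming the command.
def runExampleCommand(target: LogicalPlan): Unit = {
  val table: DeltaTableV2 = DeltaTableV2.extractFrom(target, "EXAMPLE COMMAND")
  val txn = table.startTransaction()
  // ... validate txn.metadata, stage actions, and commit ...
}
```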
@@ -49,15 +49,16 @@ trait AlterDeltaTableCommand extends DeltaCommand {

def table: DeltaTableV2

protected def startTransaction(spark: SparkSession): OptimisticTransaction = {
val txn = table.deltaLog.startTransaction()
protected def startTransaction(): OptimisticTransaction = {
// WARNING: It's not safe to use startTransactionWithInitialSnapshot here. Some commands call
// this method more than once, and some commands can be created with a stale table.
val txn = table.startTransaction()
if (txn.readVersion == -1) {
throw DeltaErrors.notADeltaTableException(table.name())
}
txn
}


/**
* Check if the column to change has any dependent expressions:
* - generated column expressions
@@ -106,7 +107,7 @@ case class AlterTableSetPropertiesDeltaCommand(
override def run(sparkSession: SparkSession): Seq[Row] = {
val deltaLog = table.deltaLog
recordDeltaOperation(deltaLog, "delta.ddl.alter.setProperties") {
val txn = startTransaction(sparkSession)
val txn = startTransaction()

val metadata = txn.metadata
val filteredConfs = configuration.filterKeys {
@@ -154,7 +155,7 @@ case class AlterTableUnsetPropertiesDeltaCommand(
override def run(sparkSession: SparkSession): Seq[Row] = {
val deltaLog = table.deltaLog
recordDeltaOperation(deltaLog, "delta.ddl.alter.unsetProperties") {
val txn = startTransaction(sparkSession)
val txn = startTransaction()
val metadata = txn.metadata

val normalizedKeys = DeltaConfigs.normalizeConfigKeys(propKeys)
@@ -292,7 +293,7 @@ case class AlterTableDropFeatureDeltaCommand(
featureName, table.snapshot.metadata)
}

val txn = startTransaction(sparkSession)
val txn = table.startTransaction()
val snapshot = txn.snapshot

// Verify whether all requirements hold before performing the protocol downgrade.
@@ -348,7 +349,7 @@ case class AlterTableAddColumnsDeltaCommand(
override def run(sparkSession: SparkSession): Seq[Row] = {
val deltaLog = table.deltaLog
recordDeltaOperation(deltaLog, "delta.ddl.alter.addColumns") {
val txn = startTransaction(sparkSession)
val txn = startTransaction()

if (SchemaUtils.filterRecursively(
StructType(colsToAddWithPosition.map {
@@ -444,7 +445,7 @@ case class AlterTableDropColumnsDeltaCommand(
}
val deltaLog = table.deltaLog
recordDeltaOperation(deltaLog, "delta.ddl.alter.dropColumns") {
val txn = startTransaction(sparkSession)
val txn = startTransaction()
val metadata = txn.metadata
if (txn.metadata.columnMappingMode == NoMapping) {
throw DeltaErrors.dropColumnNotSupported(suggestUpgrade = true)
@@ -504,7 +505,7 @@ case class AlterTableChangeColumnDeltaCommand(
override def run(sparkSession: SparkSession): Seq[Row] = {
val deltaLog = table.deltaLog
recordDeltaOperation(deltaLog, "delta.ddl.alter.changeColumns") {
val txn = startTransaction(sparkSession)
val txn = startTransaction()
val metadata = txn.metadata
val oldSchema = metadata.schema
val resolver = sparkSession.sessionState.conf.resolver
@@ -708,7 +709,7 @@ case class AlterTableReplaceColumnsDeltaCommand(

override def run(sparkSession: SparkSession): Seq[Row] = {
recordDeltaOperation(table.deltaLog, "delta.ddl.alter.replaceColumns") {
val txn = startTransaction(sparkSession)
val txn = startTransaction()

val metadata = txn.metadata
val existingSchema = metadata.schema
@@ -837,7 +838,7 @@ case class AlterTableAddConstraintDeltaCommand(
throw DeltaErrors.invalidConstraintName(name)
}
recordDeltaOperation(deltaLog, "delta.ddl.alter.addConstraint") {
val txn = startTransaction(sparkSession)
val txn = startTransaction()

getConstraintWithName(table, name, txn.metadata, sparkSession).foreach { oldExpr =>
throw DeltaErrors.constraintAlreadyExists(name, oldExpr)
@@ -889,7 +890,7 @@ case class AlterTableDropConstraintDeltaCommand(
override def run(sparkSession: SparkSession): Seq[Row] = {
val deltaLog = table.deltaLog
recordDeltaOperation(deltaLog, "delta.ddl.alter.dropConstraint") {
val txn = startTransaction(sparkSession)
val txn = startTransaction()

val oldExprText = Constraints.getExprTextByName(name, txn.metadata, sparkSession)
if (oldExprText.isEmpty && !ifExists && !sparkSession.sessionState.conf.getConf(
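The WARNING introduced above is the main subtlety in this file. A rough illustration, assuming a command object that can be run more than once against a table that may have changed since it was created:

```scala
import org.apache.spark.sql.delta.catalog.DeltaTableV2

// Illustrative only: each run must see the table's latest committed version,
// so the alter commands keep asking for a fresh snapshot on every call.
def runAlterOnce(table: DeltaTableV2): Long = {
  val txn = table.startTransaction()
  // Using table.startTransactionWithInitialSnapshot() here would pin every run
  // to the snapshot captured when the (possibly stale) table was resolved.
  txn.readVersion
}
```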
@@ -21,8 +21,10 @@ import java.util.UUID
import org.apache.spark.sql.delta.DeltaOperations.ManualUpdate
import org.apache.spark.sql.delta.actions._
import org.apache.spark.sql.delta.actions.TableFeatureProtocolUtils.{TABLE_FEATURES_MIN_READER_VERSION, TABLE_FEATURES_MIN_WRITER_VERSION}
import org.apache.spark.sql.delta.catalog.DeltaTableV2
import org.apache.spark.sql.delta.sources.DeltaSQLConf
import org.apache.spark.sql.delta.test.DeltaSQLCommandTest
import org.apache.spark.sql.delta.test.DeltaTestImplicits._
import org.apache.spark.sql.delta.util.{FileNames, JsonUtils}
import org.apache.hadoop.fs.Path

@@ -357,12 +359,14 @@ class ActionSerializerSuite extends QueryTest with SharedSparkSession with Delta
| tblproperties
| ('${TableFeatureProtocolUtils.propertyKey(DomainMetadataTableFeature)}' = 'enabled')
|""".stripMargin)
val deltaLog = DeltaLog.forTable(spark, TableIdentifier(table))
val deltaTable = DeltaTableV2(spark, TableIdentifier(table))
val deltaLog = deltaTable.deltaLog
val domainMetadatas = DomainMetadata(
domain = "testDomain",
configuration = JsonUtils.toJson(Map("key1" -> "value1")),
removed = false) :: Nil
val version = deltaLog.startTransaction().commit(domainMetadatas, ManualUpdate)
val version = deltaTable.startTransactionWithInitialSnapshot()
.commit(domainMetadatas, ManualUpdate)
val committedActions = deltaLog.store.read(
FileNames.deltaFile(deltaLog.logPath, version),
deltaLog.newDeltaHadoopConf())
@@ -25,6 +25,7 @@ import scala.collection.mutable

import org.apache.spark.sql.delta.DeltaOperations.ManualUpdate
import org.apache.spark.sql.delta.actions.{Action, AddCDCFile, AddFile, Metadata => MetadataAction, Protocol, SetTransaction}
import org.apache.spark.sql.delta.catalog.DeltaTableV2
import org.apache.spark.sql.delta.schema.SchemaMergingUtils
import org.apache.spark.sql.delta.sources.DeltaSQLConf
import org.apache.spark.sql.delta.test.DeltaSQLCommandTest
@@ -1587,8 +1588,8 @@ class DeltaColumnMappingSuite extends QueryTest
"t1",
props = Map(DeltaConfigs.CHANGE_DATA_FEED.key -> cdfEnabled.toString))

val log = DeltaLog.forTable(spark, TableIdentifier("t1"))
val currMetadata = log.snapshot.metadata
val table = DeltaTableV2(spark, TableIdentifier("t1"))
val currMetadata = table.snapshot.metadata
val upgradeMetadata = currMetadata.copy(
configuration = currMetadata.configuration ++ Map(
DeltaConfigs.MIN_READER_VERSION.key -> "2",
@@ -1597,7 +1598,7 @@
)
)

val txn = log.startTransaction()
val txn = table.startTransactionWithInitialSnapshot()
txn.updateMetadata(upgradeMetadata)

if (shouldBlock) {
@@ -23,8 +23,10 @@ import scala.util.{Failure, Success, Try}

import org.apache.spark.sql.delta.DeltaOperations.{ManualUpdate, Truncate}
import org.apache.spark.sql.delta.actions.{DomainMetadata, TableFeatureProtocolUtils}
import org.apache.spark.sql.delta.catalog.DeltaTableV2
import org.apache.spark.sql.delta.sources.DeltaSQLConf
import org.apache.spark.sql.delta.test.DeltaSQLCommandTest
import org.apache.spark.sql.delta.test.DeltaTestImplicits._
import org.apache.spark.sql.delta.util.{FileNames, JsonUtils}
import org.junit.Assert._

@@ -64,12 +66,13 @@ class DomainMetadataSuite
|""".stripMargin)
(1 to 100).toDF("id").write.format("delta").mode("append").saveAsTable(table)

var deltaLog = DeltaLog.forTable(spark, TableIdentifier(table))
assert(deltaLog.unsafeVolatileSnapshot.domainMetadata.isEmpty)
var deltaTable = DeltaTableV2(spark, TableIdentifier(table))
def deltaLog = deltaTable.deltaLog
assert(deltaTable.snapshot.domainMetadata.isEmpty)

val domainMetadata = DomainMetadata("testDomain1", "", false) ::
DomainMetadata("testDomain2", "{\"key1\":\"value1\"", false) :: Nil
deltaLog.startTransaction().commit(domainMetadata, Truncate())
deltaTable.startTransactionWithInitialSnapshot().commit(domainMetadata, Truncate())
assertEquals(sortByDomain(domainMetadata), sortByDomain(deltaLog.update().domainMetadata))
assert(deltaLog.update().logSegment.checkpointProvider.version === -1)

@@ -78,12 +81,12 @@
// Clear the DeltaLog cache to force creating a new DeltaLog instance which will build
// the Snapshot from the checkpoint file.
DeltaLog.clearCache()
deltaLog = DeltaLog.forTable(spark, TableIdentifier(table))
assert(!deltaLog.unsafeVolatileSnapshot.logSegment.checkpointProvider.isEmpty)
deltaTable = DeltaTableV2(spark, TableIdentifier(table))
assert(!deltaTable.snapshot.logSegment.checkpointProvider.isEmpty)

assertEquals(
sortByDomain(domainMetadata),
sortByDomain(deltaLog.unsafeVolatileSnapshot.domainMetadata))
sortByDomain(deltaTable.snapshot.domainMetadata))
}

}
@@ -106,18 +109,19 @@
(1 to 100).toDF("id").write.format("delta").mode("append").saveAsTable(table)

DeltaLog.clearCache()
var deltaLog = DeltaLog.forTable(spark, TableIdentifier(table))
assert(deltaLog.unsafeVolatileSnapshot.domainMetadata.isEmpty)
val deltaTable = DeltaTableV2(spark, TableIdentifier(table))
val deltaLog = deltaTable.deltaLog
assert(deltaTable.snapshot.domainMetadata.isEmpty)

val domainMetadata = DomainMetadata("testDomain1", "", false) ::
DomainMetadata("testDomain2", "{\"key1\":\"value1\"}", false) :: Nil

deltaLog.startTransaction().commit(domainMetadata, Truncate())
deltaTable.startTransactionWithInitialSnapshot().commit(domainMetadata, Truncate())
assertEquals(sortByDomain(domainMetadata), sortByDomain(deltaLog.update().domainMetadata))
assert(deltaLog.update().logSegment.checkpointProvider.version === -1)

// Delete testDomain1.
deltaLog.startTransaction().commit(
deltaTable.startTransaction().commit(
DomainMetadata("testDomain1", "", true) :: Nil, Truncate())
val domainMetadatasAfterDeletion = DomainMetadata(
"testDomain2",
@@ -128,21 +132,19 @@

// Create a new commit and validate the incrementally built snapshot state respects the
// DomainMetadata deletion.
deltaLog.startTransaction().commit(Nil, ManualUpdate)
deltaLog.update()
assertEquals(
sortByDomain(domainMetadatasAfterDeletion),
deltaLog.unsafeVolatileSnapshot.domainMetadata)
deltaTable.startTransaction().commit(Nil, ManualUpdate)
var snapshot = deltaLog.update()
assertEquals(sortByDomain(domainMetadatasAfterDeletion), snapshot.domainMetadata)
if (doCheckpoint) {
deltaLog.checkpoint(deltaLog.unsafeVolatileSnapshot)
deltaLog.checkpoint(snapshot)
assertEquals(
sortByDomain(domainMetadatasAfterDeletion),
deltaLog.update().domainMetadata)
}

// force state reconstruction and validate it respects the DomainMetadata retention.
DeltaLog.clearCache()
val snapshot = DeltaLog.forTableWithSnapshot(spark, TableIdentifier(table))._2
snapshot = DeltaLog.forTableWithSnapshot(spark, TableIdentifier(table))._2
assertEquals(sortByDomain(domainMetadatasAfterDeletion), snapshot.domainMetadata)
}
}
@@ -191,12 +193,12 @@ class DomainMetadataSuite
| ('${TableFeatureProtocolUtils.propertyKey(DomainMetadataTableFeature)}' = 'enabled')
|""".stripMargin)
(1 to 100).toDF("id").write.format("delta").mode("append").saveAsTable(table)
val deltaLog = DeltaLog.forTable(spark, TableIdentifier(table))
val deltaTable = DeltaTableV2(spark, TableIdentifier(table))
val domainMetadata =
DomainMetadata("testDomain1", "", false) ::
DomainMetadata("testDomain1", "", false) :: Nil
val e = intercept[DeltaIllegalArgumentException] {
deltaLog.startTransaction().commit(domainMetadata, Truncate())
deltaTable.startTransactionWithInitialSnapshot().commit(domainMetadata, Truncate())
}
assertEquals(e.getMessage,
"Internal error: two DomainMetadata actions within the same transaction have " +
@@ -24,6 +24,7 @@ import scala.collection.JavaConverters._

import org.apache.spark.sql.delta.{CheckConstraintsTableFeature, DeltaLog, DeltaOperations}
import org.apache.spark.sql.delta.actions.{Metadata, TableFeatureProtocolUtils}
import org.apache.spark.sql.delta.catalog.DeltaTableV2
import org.apache.spark.sql.delta.constraints.{Constraint, Constraints, Invariants}
import org.apache.spark.sql.delta.constraints.Constraints.NotNull
import org.apache.spark.sql.delta.constraints.Invariants.PersistedExpression
@@ -400,8 +401,8 @@ class InvariantEnforcementSuite extends QueryTest
withTable("constraint") {
spark.range(10).selectExpr("id AS valueA", "id AS valueB", "id AS valueC")
.write.format("delta").saveAsTable("constraint")
val log = DeltaLog.forTable(spark, TableIdentifier("constraint", None))
val txn = log.startTransaction()
val table = DeltaTableV2(spark, TableIdentifier("constraint", None))
val txn = table.startTransactionWithInitialSnapshot()
val newMetadata = txn.metadata.copy(
configuration = txn.metadata.configuration +
("delta.constraints.mychk" -> "valueA < valueB"))
@@ -412,7 +413,7 @@
} else {
CheckConstraintsTableFeature.minWriterVersion
}
assert(log.snapshot.protocol.minWriterVersion === upVersion)
assert(table.deltaLog.unsafeVolatileSnapshot.protocol.minWriterVersion === upVersion)
spark.sql("INSERT INTO constraint VALUES (50, 100, null)")
val e = intercept[InvariantViolationException] {
spark.sql("INSERT INTO constraint VALUES (100, 50, null)")
@@ -438,8 +439,8 @@
withTable("constraint") {
spark.range(10).selectExpr("id AS valueA", "id AS valueB")
.write.format("delta").saveAsTable("constraint")
val log = DeltaLog.forTable(spark, TableIdentifier("constraint", None))
val txn = log.startTransaction()
val table = DeltaTableV2(spark, TableIdentifier("constraint", None))
val txn = table.startTransactionWithInitialSnapshot()
val newMetadata = txn.metadata.copy(
configuration = txn.metadata.configuration +
("delta.constraints.mychk" -> "valueA < valueB"))