Skip to content

Commit

Permalink
update test and doc strings
Browse files Browse the repository at this point in the history
  • Loading branch information
VenkataKarthikP committed Jan 5, 2024
1 parent 08a71ab commit 6c1f97e
Show file tree
Hide file tree
Showing 4 changed files with 215 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ case class DataSynchronizationAnalyzer(dfToCompare: DataFrame,
override def computeMetricFrom(state: Option[DataSynchronizationState]): DoubleMetric = {

val metric = state match {
case Some(s) => Try(s.synchronizedDataCount.toDouble / s.dataCount.toDouble)
case Some(s) => Try(s.synchronizedDataCount.toDouble / s.totalDataCount.toDouble)
case _ => Failure(new IllegalStateException("No state available for DataSynchronizationAnalyzer"))
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,31 @@
package com.amazon.deequ.analyzers

/**
* To store state of DataSynchronization
* Represents the state of data synchronization between two DataFrames in Deequ.
* This state keeps track of the synchronized record count and the total record count.
* It is used to calculate a ratio of synchronization, which is a measure of how well the data
* in the two DataFrames are synchronized.
*
* @param synchronizedDataCount The count of records that are considered synchronized between the two DataFrames.
* @param totalDataCount The total count of records examined; used as the denominator of the synchronization ratio.
*
* The `sum` method allows for aggregation of this state with another, combining the counts from both states.
* This is useful in distributed computations where states from different partitions need to be aggregated.
*
* The `metricValue` method computes the synchronization ratio. It is the ratio of `synchronizedDataCount`
* to `totalDataCount`.
* If `totalDataCount` is zero, which means no data points were examined, the method returns `Double.NaN`
* to indicate the undefined state.
*
* @param synchronizedDataCount - Count Of rows that are in sync
* @param dataCount - total count of records to caluculate ratio.
*/
case class DataSynchronizationState(synchronizedDataCount: Long, dataCount: Long)
case class DataSynchronizationState(synchronizedDataCount: Long, totalDataCount: Long)
extends DoubleValuedState[DataSynchronizationState] {
// NOTE(review): this span is a rendered diff hunk. Where two near-identical lines are adjacent,
// the first uses the previous field name `dataCount` and the second the renamed `totalDataCount`.

// Combines two states by adding the synchronized counts together and the total counts together.
override def sum(other: DataSynchronizationState): DataSynchronizationState = {
DataSynchronizationState(synchronizedDataCount + other.synchronizedDataCount, dataCount + other.dataCount)
DataSynchronizationState(synchronizedDataCount + other.synchronizedDataCount, totalDataCount + other.totalDataCount)
}

// Ratio of synchronized records to total records; Double.NaN when the total count is zero,
// which avoids dividing by zero.
override def metricValue(): Double = {
if (dataCount == 0L) Double.NaN else synchronizedDataCount.toDouble / dataCount.toDouble
if (totalDataCount == 0L) Double.NaN else synchronizedDataCount.toDouble / totalDataCount.toDouble
}
}

Expand Down
22 changes: 22 additions & 0 deletions src/main/scala/com/amazon/deequ/checks/Check.scala
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,27 @@ case class Check(
* Utilizes [[com.amazon.deequ.analyzers.DataSynchronizationAnalyzer]] for comparing the data
* and Constraint [[com.amazon.deequ.constraints.DataSynchronizationConstraint]].
*
* Usage:
* To use this method, build a Check that calls it and add that Check to a VerificationSuite:
* {{{
* val baseDataFrame: DataFrame = ...
* val otherDataFrame: DataFrame = ...
* val columnMappings: Map[String, String] = Map("baseCol1" -> "otherCol1", "baseCol2" -> "otherCol2")
* val assertionFunction: Double => Boolean = _ > 0.7
*
* val check = new Check(CheckLevel.Error, "Data Synchronization Check")
* .isDataSynchronized(otherDataFrame, columnMappings, assertionFunction)
*
* val verificationResult = VerificationSuite()
* .onData(baseDataFrame)
* .addCheck(check)
* .run()
* }}}
*
* This will add a data synchronization check to the VerificationSuite, comparing the specified columns of
* baseDataFrame and otherDataFrame based on the provided assertion function.
*
*
* @param otherDf The DataFrame to be compared with the current one. Analyzed in conjunction with the
* current DataFrame to assess data synchronization.
* @param columnMappings A map defining the column correlations between the current DataFrame and otherDf.
Expand All @@ -372,6 +393,7 @@ case class Check(
* @return A [[com.amazon.deequ.checks.Check]] object representing the outcome
* of the synchronization check. This object can be used in Deequ's verification suite to
* assert data quality constraints.
*
*/
def isDataSynchronized(otherDf: DataFrame, columnMappings: Map[String, String], assertion: Double => Boolean,
hint: Option[String] = None): Check = {
Expand Down
178 changes: 174 additions & 4 deletions src/test/scala/com/amazon/deequ/VerificationSuiteTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,23 @@ package com.amazon.deequ
import com.amazon.deequ.analyzers._
import com.amazon.deequ.analyzers.runners.AnalyzerContext
import com.amazon.deequ.anomalydetection.AbsoluteChangeStrategy
import com.amazon.deequ.checks.{Check, CheckLevel, CheckStatus}
import com.amazon.deequ.constraints.{Constraint, ConstraintResult}
import com.amazon.deequ.checks.Check
import com.amazon.deequ.checks.CheckLevel
import com.amazon.deequ.checks.CheckStatus
import com.amazon.deequ.constraints.Constraint
import com.amazon.deequ.io.DfsUtils
import com.amazon.deequ.metrics.{DoubleMetric, Entity, Metric}
import com.amazon.deequ.metrics.DoubleMetric
import com.amazon.deequ.metrics.Entity
import com.amazon.deequ.repository.MetricsRepository
import com.amazon.deequ.repository.ResultKey
import com.amazon.deequ.repository.memory.InMemoryMetricsRepository
import com.amazon.deequ.utils.CollectionUtils.SeqExtensions
import com.amazon.deequ.utils.FixtureSupport
import com.amazon.deequ.utils.TempFileUtils
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.Row
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.functions.when
import org.scalamock.scalatest.MockFactory
import org.scalatest.Matchers
import org.scalatest.WordSpec
Expand Down Expand Up @@ -805,7 +811,7 @@ class VerificationSuiteTest extends WordSpec with Matchers with SparkContextSpec
.hasCompleteness("fake", x => x > 0)

val checkHasDataInSyncTest = Check(CheckLevel.Error, "shouldSucceedForAge")
.isDataSynchronized(df, Map("age" -> "age"), _ > 0.99, Some("shouldpass"))
.isDataSynchronized(df, Map("age" -> "age"), _ > 0.99, Some("shouldPass"))

val verificationResult = VerificationSuite()
.onData(df)
Expand Down Expand Up @@ -978,6 +984,170 @@ class VerificationSuiteTest extends WordSpec with Matchers with SparkContextSpec
List(Some("Value: 0.125 does not meet the constraint requirement!"))
assert(subsetNameFailResult.status == CheckStatus.Error)
}

"Should work Data Synchronization checks for single column" in withSparkSession { sparkSession =>
  // Reference data, trimmed to the columns used by this test.
  val baseDf = getDateDf(sparkSession).select("id", "product", "product_id", "units")
  // Copy of the reference data in which id 100 is replaced by 99, so that id no longer matches.
  val alteredDf = baseDf.withColumn("id", when(col("id") === 100, 99).otherwise(col("id")))
  // Same data as the reference, with the id column exposed under a different name.
  val renamedDf = baseDf.withColumnRenamed("id", "id_renamed")
  // Frame with the reference schema but no rows.
  val noRowsDf = sparkSession.createDataFrame(sparkSession.sparkContext.emptyRDD[Row], baseDf.schema)

  val idToId = Map("id" -> "id")

  // Same comparison against the altered frame, checked against a lenient and a strict threshold.
  val passCheck = Check(CheckLevel.Error, "data synchronization check pass")
    .isDataSynchronized(alteredDf, idToId, _ > 0.7, Some("shouldPass"))
  val failCheck = Check(CheckLevel.Error, "data synchronization check fail")
    .isDataSynchronized(alteredDf, idToId, _ > 0.9, Some("shouldFail"))

  // Comparison against a frame that has no rows.
  val emptyCheck = Check(CheckLevel.Error, "data synchronization check on empty DataFrame")
    .isDataSynchronized(noRowsDf, idToId, _ < 0.5)

  // Mappings that name a column ("id2") which neither frame contains.
  val destinationMismatchCheck =
    Check(CheckLevel.Error, "data synchronization check col mismatch in destination")
      .isDataSynchronized(alteredDf, Map("id" -> "id2"), _ < 0.5)
  val sourceMismatchCheck =
    Check(CheckLevel.Error, "data synchronization check col mismatch in source")
      .isDataSynchronized(alteredDf, Map("id2" -> "id"), _ < 0.5)

  // Identical data: once with a renamed key column, once with the very same frame.
  val renamedColumnCheck =
    Check(CheckLevel.Error, "data synchronization check col names renamed")
      .isDataSynchronized(renamedDf, Map("id" -> "id_renamed"), _ == 1.0)
  val fullMatchCheck =
    Check(CheckLevel.Error, "data synchronization check full match")
      .isDataSynchronized(baseDf, idToId, _ == 1.0)

  val verificationResult = VerificationSuite()
    .onData(baseDf)
    .addCheck(passCheck)
    .addCheck(failCheck)
    .addCheck(emptyCheck)
    .addCheck(destinationMismatchCheck)
    .addCheck(sourceMismatchCheck)
    .addCheck(renamedColumnCheck)
    .addCheck(fullMatchCheck)
    .run()

  // Small lookups so each expectation below reads as "messages, then status".
  def messagesOf(check: Check) = verificationResult.checkResults(check).constraintResults.map(_.message)
  def statusOf(check: Check) = verificationResult.checkResults(check).status

  val nanFailure = List(Some("Value: NaN does not meet the constraint requirement!"))

  messagesOf(passCheck) shouldBe List(None)
  assert(statusOf(passCheck) == CheckStatus.Success)

  messagesOf(failCheck) shouldBe List(Some("Value: 0.8 does not meet the constraint requirement! shouldFail"))
  assert(statusOf(failCheck) == CheckStatus.Error)

  messagesOf(emptyCheck) shouldBe nanFailure
  assert(statusOf(emptyCheck) == CheckStatus.Error)

  messagesOf(destinationMismatchCheck) shouldBe nanFailure
  assert(statusOf(destinationMismatchCheck) == CheckStatus.Error)

  messagesOf(sourceMismatchCheck) shouldBe nanFailure
  assert(statusOf(sourceMismatchCheck) == CheckStatus.Error)

  messagesOf(renamedColumnCheck) shouldBe List(None)
  assert(statusOf(renamedColumnCheck) == CheckStatus.Success)

  messagesOf(fullMatchCheck) shouldBe List(None)
  assert(statusOf(fullMatchCheck) == CheckStatus.Success)
}

"Should work Data Synchronization checks for multiple column" in withSparkSession {
  sparkSession =>
    // Reference data, limited to the columns this test works with.
    val df = getDateDf(sparkSession).select("id", "product", "product_id", "units")
    // Same data with id 100 rewritten to 99, so that id no longer matches the reference.
    val dfModified = df.withColumn("id", when(col("id") === 100, 99)
      .otherwise(col("id")))
    // Same data with the id column renamed; values are untouched.
    val dfColRenamed = df.withColumnRenamed("id", "id_renamed")
    val colMap = Map("id" -> "id", "product" -> "product")

    // Distinct descriptions (matching the single-column test) keep the two checks
    // distinguishable in verification output.
    val dataSyncCheckPass = Check(CheckLevel.Error, "data synchronization check pass")
      .isDataSynchronized(dfModified, colMap, _ > 0.7, Some("shouldPass"))

    val dataSyncCheckFail = Check(CheckLevel.Error, "data synchronization check fail")
      .isDataSynchronized(dfModified, colMap, _ > 0.9, Some("shouldFail"))

    val emptyDf = sparkSession.createDataFrame(sparkSession.sparkContext.emptyRDD[Row], df.schema)
    val dataSyncCheckEmpty = Check(CheckLevel.Error, "data synchronization check on empty DataFrame")
      .isDataSynchronized(emptyDf, colMap, _ < 0.5)

    // "id2" is not a column of dfModified, so this mapping is invalid on the destination side.
    // Previously this check reused the fully valid colMap and therefore only duplicated the
    // "fail" check instead of exercising a destination mismatch.
    val dataSyncCheckColMismatchDestination =
      Check(CheckLevel.Error, "data synchronization check col mismatch in destination")
        .isDataSynchronized(dfModified, Map("id" -> "id2", "product" -> "product"), _ < 0.5)

    // "id2" is not a column of df, so this mapping is invalid on the source side.
    val dataSyncCheckColMismatchSource =
      Check(CheckLevel.Error, "data synchronization check col mismatch in source")
        .isDataSynchronized(dfModified, Map("id2" -> "id", "product" -> "product"), _ < 0.5)

    val dataSyncCheckColRenamed =
      Check(CheckLevel.Error, "data synchronization check col names renamed")
        .isDataSynchronized(dfColRenamed, Map("id" -> "id_renamed", "product" -> "product"), _ == 1.0,
          Some("shouldPass"))

    val dataSyncFullMatch =
      Check(CheckLevel.Error, "data synchronization check col full match")
        .isDataSynchronized(df, colMap, _ == 1.0, Some("shouldPass"))


    val verificationResult = VerificationSuite()
      .onData(df)
      .addCheck(dataSyncCheckPass)
      .addCheck(dataSyncCheckFail)
      .addCheck(dataSyncCheckEmpty)
      .addCheck(dataSyncCheckColMismatchDestination)
      .addCheck(dataSyncCheckColMismatchSource)
      .addCheck(dataSyncCheckColRenamed)
      .addCheck(dataSyncFullMatch)
      .run()

    val passResult = verificationResult.checkResults(dataSyncCheckPass)
    passResult.constraintResults.map(_.message) shouldBe
      List(None)
    assert(passResult.status == CheckStatus.Success)

    val failResult = verificationResult.checkResults(dataSyncCheckFail)
    failResult.constraintResults.map(_.message) shouldBe
      List(Some("Value: 0.8 does not meet the constraint requirement! shouldFail"))
    assert(failResult.status == CheckStatus.Error)

    val emptyResult = verificationResult.checkResults(dataSyncCheckEmpty)
    emptyResult.constraintResults.map(_.message) shouldBe
      List(Some("Value: NaN does not meet the constraint requirement!"))
    assert(emptyResult.status == CheckStatus.Error)

    // NOTE(review): NaN is expected here by analogy with the single-column destination-mismatch
    // case and the source-mismatch case below — confirm by running the suite.
    val colMismatchDestResult = verificationResult.checkResults(dataSyncCheckColMismatchDestination)
    colMismatchDestResult.constraintResults.map(_.message) shouldBe
      List(Some("Value: NaN does not meet the constraint requirement!"))
    assert(colMismatchDestResult.status == CheckStatus.Error)

    val colMismatchSourceResult = verificationResult.checkResults(dataSyncCheckColMismatchSource)
    colMismatchSourceResult.constraintResults.map(_.message) shouldBe
      List(Some("Value: NaN does not meet the constraint requirement!"))
    assert(colMismatchSourceResult.status == CheckStatus.Error)

    val colRenamedResult = verificationResult.checkResults(dataSyncCheckColRenamed)
    colRenamedResult.constraintResults.map(_.message) shouldBe
      List(None)
    assert(colRenamedResult.status == CheckStatus.Success)

    val fullMatchResult = verificationResult.checkResults(dataSyncFullMatch)
    fullMatchResult.constraintResults.map(_.message) shouldBe
      List(None)
    assert(fullMatchResult.status == CheckStatus.Success)

}
}

/** Run anomaly detection using a repository with some previous analysis results for testing */
Expand Down

0 comments on commit 6c1f97e

Please sign in to comment.