Skip to content

Commit

Permalink
Fix for satisfies row level results bug (#553)
Browse files Browse the repository at this point in the history
- The satisfies constraint was incorrectly using the provided assertion to evaluate the row level outcomes. The assertion should only be used to evaluate the final outcome.
- As part of this change, the row level results are now returned as booleans (true/false). The cast to an integer happens only inside the aggregation function, when the matching rows are summed.
- Added a test to verify the row level results using checks made up of different assertions.
  • Loading branch information
rdsharma26 authored and svanvari committed May 2, 2024
1 parent 6ca0f15 commit 6da724e
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 34 deletions.
15 changes: 4 additions & 11 deletions src/main/scala/com/amazon/deequ/analyzers/Compliance.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ import org.apache.spark.sql.functions._
import Analyzers._
import com.amazon.deequ.analyzers.Preconditions.hasColumn
import com.google.common.annotations.VisibleForTesting
import org.apache.spark.sql.types.DoubleType

/**
* Compliance is a measure of the fraction of rows that complies with the given column constraint.
Expand All @@ -43,37 +42,31 @@ case class Compliance(instance: String,
where: Option[String] = None,
columns: List[String] = List.empty[String],
analyzerOptions: Option[AnalyzerOptions] = None)
extends StandardScanShareableAnalyzer[NumMatchesAndCount]("Compliance", instance)
with FilterableAnalyzer {
extends StandardScanShareableAnalyzer[NumMatchesAndCount]("Compliance", instance) with FilterableAnalyzer {

/**
 * Builds the analyzer state from the aggregation result row.
 *
 * Reads two long values from `result`, at `offset` and `offset + 1`, and passes them to
 * `NumMatchesAndCount` together with the row-level results column. This analyzer's
 * `aggregationFunctions` emits the summed criterion first and the conditional count second.
 *
 * @param result row holding the aggregated values
 * @param offset index in `result` of the first value belonging to this analyzer
 * @return the state wrapped in an Option.
 *         NOTE(review): presumably None when either of the two values is null
 *         (guarded by `ifNoNullsIn` with `howMany = 2`) - confirm against that helper.
 */
override def fromAggregationResult(result: Row, offset: Int): Option[NumMatchesAndCount] = {

ifNoNullsIn(result, offset, howMany = 2) { _ =>
NumMatchesAndCount(result.getLong(offset), result.getLong(offset + 1), Some(rowLevelResults))
}
}

override def aggregationFunctions(): Seq[Column] = {

val summation = sum(criterion)

val summation = sum(criterion.cast(IntegerType))
summation :: conditionalCount(where) :: Nil
}

override def filterCondition: Option[String] = where

@VisibleForTesting
private def criterion: Column = {
conditionalSelection(expr(predicate), where).cast(IntegerType)
}
private def criterion: Column = conditionalSelection(expr(predicate), where)

private def rowLevelResults: Column = {
val filteredRowOutcome = getRowLevelFilterTreatment(analyzerOptions)
val whereNotCondition = where.map { expression => not(expr(expression)) }

filteredRowOutcome match {
case FilteredRowOutcome.TRUE =>
conditionSelectionGivenColumn(expr(predicate), whereNotCondition, replaceWith = true).cast(IntegerType)
conditionSelectionGivenColumn(expr(predicate), whereNotCondition, replaceWith = true)
case _ =>
// The default behavior when using filtering for rows is to treat them as nulls. No special treatment needed.
criterion
Expand Down
6 changes: 2 additions & 4 deletions src/main/scala/com/amazon/deequ/constraints/Constraint.scala
Original file line number Diff line number Diff line change
Expand Up @@ -359,12 +359,10 @@ object Constraint {
val constraint = AnalysisBasedConstraint[NumMatchesAndCount, Double, Double](
compliance, assertion, hint = hint)

val sparkAssertion = org.apache.spark.sql.functions.udf(assertion)
new RowLevelAssertedConstraint(
new RowLevelConstraint(
constraint,
s"ComplianceConstraint($compliance)",
s"ColumnsCompliance-${compliance.predicate}",
sparkAssertion)
s"ColumnsCompliance-${compliance.predicate}")
}

/**
Expand Down
70 changes: 70 additions & 0 deletions src/test/scala/com/amazon/deequ/VerificationSuiteTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import com.amazon.deequ.constraints.Constraint
import com.amazon.deequ.io.DfsUtils
import com.amazon.deequ.metrics.DoubleMetric
import com.amazon.deequ.metrics.Entity
import com.amazon.deequ.metrics.Metric
import com.amazon.deequ.repository.MetricsRepository
import com.amazon.deequ.repository.ResultKey
import com.amazon.deequ.repository.memory.InMemoryMetricsRepository
Expand Down Expand Up @@ -1993,6 +1994,75 @@ class VerificationSuiteTest extends WordSpec with Matchers with SparkContextSpec
}
}

"Verification Suite's Row Level Results" should {
"yield correct results for satisfies check" in withSparkSession { sparkSession =>
import sparkSession.implicits._
val df = Seq(
(1, "blue"),
(2, "green"),
(3, "blue"),
(4, "red"),
(5, "purple")
).toDF("id", "color")

val columnCondition = "color in ('blue')"
val whereClause = "id <= 3"

case class CheckConfig(checkName: String,
assertion: Double => Boolean,
checkStatus: CheckStatus.Value,
whereClause: Option[String] = None)

val success = CheckStatus.Success
val error = CheckStatus.Error

val checkConfigs = Seq(
// Without where clause: Expected compliance metric for full dataset for given condition is 0.4
CheckConfig("check with >", (d: Double) => d > 0.5, error),
CheckConfig("check with >=", (d: Double) => d >= 0.35, success),
CheckConfig("check with <", (d: Double) => d < 0.3, error),
CheckConfig("check with <=", (d: Double) => d <= 0.4, success),
CheckConfig("check with =", (d: Double) => d == 0.4, success),
CheckConfig("check with > / <", (d: Double) => d > 0.0 && d < 0.5, success),
CheckConfig("check with >= / <=", (d: Double) => d >= 0.41 && d <= 1.1, error),

// With where Clause: Expected compliance metric for full dataset for given condition with where clause is 0.67
CheckConfig("check w/ where and with >", (d: Double) => d > 0.7, error, Some(whereClause)),
CheckConfig("check w/ where and with >=", (d: Double) => d >= 0.66, success, Some(whereClause)),
CheckConfig("check w/ where and with <", (d: Double) => d < 0.6, error, Some(whereClause)),
CheckConfig("check w/ where and with <=", (d: Double) => d <= 0.67, success, Some(whereClause)),
CheckConfig("check w/ where and with =", (d: Double) => d == 0.66, error, Some(whereClause)),
CheckConfig("check w/ where and with > / <", (d: Double) => d > 0.0 && d < 0.5, error, Some(whereClause)),
CheckConfig("check w/ where and with >= / <=", (d: Double) => d >= 0.41 && d <= 1.1, success, Some(whereClause))
)

val checks = checkConfigs.map { checkConfig =>
val constraintName = s"Constraint for check: ${checkConfig.checkName}"
val check = Check(CheckLevel.Error, checkConfig.checkName)
.satisfies(columnCondition, constraintName, checkConfig.assertion)
checkConfig.whereClause.map(check.where).getOrElse(check)
}

val verificationResult = VerificationSuite().onData(df).addChecks(checks).run()
val actualResults = verificationResult.checkResults.map { case (c, r) => c.description -> r.status }
val expectedResults = checkConfigs.map { c => c.checkName -> c.checkStatus}.toMap
assert(actualResults == expectedResults)

verificationResult.metrics.values.foreach { metric =>
val metricValue = metric.asInstanceOf[Metric[Double]].value.toOption.getOrElse(0.0)
if (metric.instance.contains("where")) assert(math.abs(metricValue - 0.66) < 0.1)
else assert(metricValue == 0.4)
}

val rowLevelResults = VerificationResult.rowLevelResultsAsDataFrame(sparkSession, verificationResult, df)
checkConfigs.foreach { checkConfig =>
val results = rowLevelResults.select(checkConfig.checkName).collect().map { r => r.getAs[Boolean](0)}.toSeq
if (checkConfig.whereClause.isDefined) assert(results == Seq(true, false, true, true, true))
else assert(results == Seq(true, false, true, false, false))
}
}
}

/** Run anomaly detection using a repository with some previous analysis results for testing */
private[this] def evaluateWithRepositoryWithHistory(test: MetricsRepository => Unit): Unit = {

Expand Down
30 changes: 11 additions & 19 deletions src/test/scala/com/amazon/deequ/analyzers/ComplianceTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
*
*/


package com.amazon.deequ.analyzers

import com.amazon.deequ.SparkContextSpec
Expand All @@ -25,34 +24,30 @@ import org.scalatest.matchers.should.Matchers
import org.scalatest.wordspec.AnyWordSpec

class ComplianceTest extends AnyWordSpec with Matchers with SparkContextSpec with FixtureSupport {

"Compliance" should {
"return row-level results for columns" in withSparkSession { session =>

val data = getDfWithNumericValues(session)

val att1Compliance = Compliance("rule1", "att1 > 3", columns = List("att1"))
val state = att1Compliance.computeStateFrom(data)
val metric: DoubleMetric with FullColumn = att1Compliance.computeMetricFrom(state)

data.withColumn("new", metric.fullColumn.get).collect().map(_.getAs[Int]("new")
) shouldBe Seq(0, 0, 0, 1, 1, 1)
data.withColumn("new", metric.fullColumn.get).collect().map(_.getAs[Boolean]("new")
) shouldBe Seq(false, false, false, true, true, true)
}

"return row-level results for null columns" in withSparkSession { session =>

val data = getDfWithNumericValues(session)

val att1Compliance = Compliance("rule1", "attNull > 3", columns = List("att1"))
val state = att1Compliance.computeStateFrom(data)
val metric: DoubleMetric with FullColumn = att1Compliance.computeMetricFrom(state)

data.withColumn("new", metric.fullColumn.get).collect().map(r =>
if (r == null) null else r.getAs[Int]("new")) shouldBe Seq(null, null, null, 1, 1, 1)
if (r == null) null else r.getAs[Boolean]("new")) shouldBe Seq(null, null, null, true, true, true)
}

"return row-level results filtered with null" in withSparkSession { session =>

val data = getDfWithNumericValues(session)

val att1Compliance = Compliance("rule1", "att1 > 4", where = Option("att2 != 0"),
Expand All @@ -61,11 +56,10 @@ class ComplianceTest extends AnyWordSpec with Matchers with SparkContextSpec wit
val metric: DoubleMetric with FullColumn = att1Compliance.computeMetricFrom(state)

data.withColumn("new", metric.fullColumn.get).collect().map(r =>
if (r == null) null else r.getAs[Int]("new")) shouldBe Seq(null, null, null, 0, 1, 1)
if (r == null) null else r.getAs[Boolean]("new")) shouldBe Seq(null, null, null, false, true, true)
}

"return row-level results filtered with true" in withSparkSession { session =>

val data = getDfWithNumericValues(session)

val att1Compliance = Compliance("rule1", "att1 > 4", where = Option("att2 != 0"),
Expand All @@ -74,7 +68,7 @@ class ComplianceTest extends AnyWordSpec with Matchers with SparkContextSpec wit
val metric: DoubleMetric with FullColumn = att1Compliance.computeMetricFrom(state)

data.withColumn("new", metric.fullColumn.get).collect().map(r =>
if (r == null) null else r.getAs[Int]("new")) shouldBe Seq(1, 1, 1, 0, 1, 1)
if (r == null) null else r.getAs[Boolean]("new")) shouldBe Seq(true, true, true, false, true, true)
}

"return row-level results for compliance in bounds" in withSparkSession { session =>
Expand All @@ -93,7 +87,7 @@ class ComplianceTest extends AnyWordSpec with Matchers with SparkContextSpec wit
val metric: DoubleMetric with FullColumn = att1Compliance.computeMetricFrom(state)

data.withColumn("new", metric.fullColumn.get).collect().map(r =>
if (r == null) null else r.getAs[Int]("new")) shouldBe Seq(0, 1, 1, 1, 1, 0)
if (r == null) null else r.getAs[Boolean]("new")) shouldBe Seq(false, true, true, true, true, false)
}

"return row-level results for compliance in bounds filtered as null" in withSparkSession { session =>
Expand All @@ -114,7 +108,7 @@ class ComplianceTest extends AnyWordSpec with Matchers with SparkContextSpec wit
val metric: DoubleMetric with FullColumn = att1Compliance.computeMetricFrom(state)

data.withColumn("new", metric.fullColumn.get).collect().map(r =>
if (r == null) null else r.getAs[Int]("new")) shouldBe Seq(0, 1, 1, null, null, null)
if (r == null) null else r.getAs[Boolean]("new")) shouldBe Seq(false, true, true, null, null, null)
}

"return row-level results for compliance in bounds filtered as true" in withSparkSession { session =>
Expand All @@ -135,7 +129,7 @@ class ComplianceTest extends AnyWordSpec with Matchers with SparkContextSpec wit
val metric: DoubleMetric with FullColumn = att1Compliance.computeMetricFrom(state)

data.withColumn("new", metric.fullColumn.get).collect().map(r =>
if (r == null) null else r.getAs[Int]("new")) shouldBe Seq(0, 1, 1, 1, 1, 1)
if (r == null) null else r.getAs[Boolean]("new")) shouldBe Seq(false, true, true, true, true, true)
}

"return row-level results for compliance in array" in withSparkSession { session =>
Expand All @@ -157,7 +151,7 @@ class ComplianceTest extends AnyWordSpec with Matchers with SparkContextSpec wit
val metric: DoubleMetric with FullColumn = att1Compliance.computeMetricFrom(state)

data.withColumn("new", metric.fullColumn.get).collect().map(r =>
if (r == null) null else r.getAs[Int]("new")) shouldBe Seq(0, 0, 1, 1, 1, 0)
if (r == null) null else r.getAs[Boolean]("new")) shouldBe Seq(false, false, true, true, true, false)
}

"return row-level results for compliance in array filtered as null" in withSparkSession { session =>
Expand All @@ -180,7 +174,7 @@ class ComplianceTest extends AnyWordSpec with Matchers with SparkContextSpec wit
val metric: DoubleMetric with FullColumn = att1Compliance.computeMetricFrom(state)

data.withColumn("new", metric.fullColumn.get).collect().map(r =>
if (r == null) null else r.getAs[Int]("new")) shouldBe Seq(0, 0, 1, 1, null, null)
if (r == null) null else r.getAs[Boolean]("new")) shouldBe Seq(false, false, true, true, null, null)
}

"return row-level results for compliance in array filtered as true" in withSparkSession { session =>
Expand All @@ -196,16 +190,14 @@ class ComplianceTest extends AnyWordSpec with Matchers with SparkContextSpec wit

val data = getDfWithNumericValues(session)


val att1Compliance = Compliance(s"$column contained in ${allowedValues.mkString(",")}", predicate,
where = Option("att1 < 5"), columns = List("att3"),
analyzerOptions = Option(AnalyzerOptions(filteredRow = FilteredRowOutcome.TRUE)))
val state = att1Compliance.computeStateFrom(data)
val metric: DoubleMetric with FullColumn = att1Compliance.computeMetricFrom(state)

data.withColumn("new", metric.fullColumn.get).collect().map(r =>
if (r == null) null else r.getAs[Int]("new")) shouldBe Seq(0, 0, 1, 1, 1, 1)
if (r == null) null else r.getAs[Boolean]("new")) shouldBe Seq(false, false, true, true, true, true)
}
}

}

0 comments on commit 6da724e

Please sign in to comment.