From e5ada70efdccba6da6c3d34d9b0f812f3b2d0c4f Mon Sep 17 00:00:00 2001
From: rdsharma26 <65777064+rdsharma26@users.noreply.github.com>
Date: Sun, 10 Mar 2024 17:08:37 -0400
Subject: [PATCH] [MinLength/MaxLength] Apply filtered row behavior at the row level evaluation (#547)

* [MinLength/MaxLength] Apply filtered row behavior at the row level evaluation

- For certain scenarios, the filtered row behavior for MinLength and MaxLength was not working correctly.
- For example, when a single check used both a minLength and a maxLength constraint, each with == as the assertion, the row-level outcome for filtered rows was false. This was because we replaced filtered-row values with MaxValue for Min and with MinValue for Max, and a single length cannot equal both at the same time.
- Updated the row-level assertion logic for MinLength/MaxLength to match what was done for Min/Max.
---
 .../com/amazon/deequ/analyzers/Analyzer.scala | 11 +-
 .../amazon/deequ/analyzers/MaxLength.scala | 47 ++--
 .../com/amazon/deequ/analyzers/Maximum.scala | 20 +-
 .../amazon/deequ/analyzers/MinLength.scala | 47 ++--
 .../com/amazon/deequ/analyzers/Minimum.scala | 20 +-
 .../amazon/deequ/constraints/Constraint.scala | 61 +++--
 .../amazon/deequ/VerificationSuiteTest.scala | 208 +++++++++++++++++-
 .../deequ/analyzers/MaxLengthTest.scala | 114 ++--------
 .../amazon/deequ/analyzers/MaximumTest.scala | 2 -
 .../deequ/analyzers/MinLengthTest.scala | 117 ++--------
 .../amazon/deequ/utils/FixtureSupport.scala | 10 +-
 11 files changed, 367 insertions(+), 290 deletions(-)

diff --git a/src/main/scala/com/amazon/deequ/analyzers/Analyzer.scala b/src/main/scala/com/amazon/deequ/analyzers/Analyzer.scala
index bc05adb5..9367f31e 100644
--- a/src/main/scala/com/amazon/deequ/analyzers/Analyzer.scala
+++ b/src/main/scala/com/amazon/deequ/analyzers/Analyzer.scala
@@ -519,13 +519,16 @@ private[deequ] object Analyzers {
   }
 
   def conditionalSelectionWithAugmentedOutcome(selection: Column,
-                                               condition: Option[String],
-                                               replaceWith: Double): Column = {
+                                               condition: Option[String]): Column = {
     val origSelection = array(lit(InScopeData.name).as("source"), selection.as("selection"))
-    val filteredSelection = array(lit(FilteredData.name).as("source"), lit(replaceWith).as("selection"))
+
+    // The 2nd value in the array is set to null, but it can be set to anything.
+    // The value is not used to evaluate the row level outcome for filtered rows (to true/null).
+    // That decision is made using the 1st value, which is set to "FilteredData" here.
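// Illustrative sketch, separate from the diff above (object and column names here are hypothetical).
// It shows how the augmented outcome produced by conditionalSelectionWithAugmentedOutcome is expected
// to behave, assuming a local SparkSession: in-scope rows carry ["InScopeData", <value>] and filtered
// rows carry ["FilteredData", null]; aggregations read the 2nd element, row-level outcomes read the 1st.
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{array, col, element_at, expr, lit, max, when}
import org.apache.spark.sql.types.DoubleType

object AugmentedOutcomeSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("augmented-outcome").getOrCreate()
    import spark.implicits._

    val df = Seq((1, 4.0), (2, 7.0), (3, 9.0)).toDF("id", "att1")
    val whereClause = "id <= 2"

    // Same shape as the patched selection: keep the real value for rows matching the where clause,
    // and tag everything else as "FilteredData" with a null placeholder value.
    val augmented = when(expr(whereClause), array(lit("InScopeData"), col("att1")))
      .otherwise(array(lit("FilteredData"), lit(null)))

    // Metrics only aggregate the 2nd element, so the null placeholder of filtered rows cannot
    // distort them (Spark's max ignores nulls): prints 7.0, not 9.0.
    df.select(max(element_at(augmented, 2).cast(DoubleType))).show()

    // The 1st element is what the row-level assertion later uses to map filtered rows
    // to true (or to null, when FilteredRowOutcome.NULL is configured).
    df.select(col("id"), element_at(augmented, 1).as("source")).show()

    spark.stop()
  }
}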
+ val filteredSelection = array(lit(FilteredData.name).as("source"), lit(null).as("selection")) condition - .map { cond => when(not(expr(cond)), filteredSelection).otherwise(origSelection) } + .map { cond => when(expr(cond), origSelection).otherwise(filteredSelection) } .getOrElse(origSelection) } diff --git a/src/main/scala/com/amazon/deequ/analyzers/MaxLength.scala b/src/main/scala/com/amazon/deequ/analyzers/MaxLength.scala index 3b55d4fa..141d92fb 100644 --- a/src/main/scala/com/amazon/deequ/analyzers/MaxLength.scala +++ b/src/main/scala/com/amazon/deequ/analyzers/MaxLength.scala @@ -23,10 +23,11 @@ import com.amazon.deequ.analyzers.Preconditions.isString import org.apache.spark.sql.Column import org.apache.spark.sql.Row import org.apache.spark.sql.functions.col -import org.apache.spark.sql.functions.expr +import org.apache.spark.sql.functions.element_at import org.apache.spark.sql.functions.length +import org.apache.spark.sql.functions.lit import org.apache.spark.sql.functions.max -import org.apache.spark.sql.functions.not +import org.apache.spark.sql.functions.when import org.apache.spark.sql.types.DoubleType import org.apache.spark.sql.types.StructType @@ -35,12 +36,15 @@ case class MaxLength(column: String, where: Option[String] = None, analyzerOptio with FilterableAnalyzer { override def aggregationFunctions(): Seq[Column] = { - max(criterion) :: Nil + // The criterion returns a column where each row contains an array of 2 elements. + // The first element of the array is a string that indicates if the row is "in scope" or "filtered" out. + // The second element is the value used for calculating the metric. We use "element_at" to extract it. + max(element_at(criterion, 2).cast(DoubleType)) :: Nil } override def fromAggregationResult(result: Row, offset: Int): Option[MaxState] = { ifNoNullsIn(result, offset) { _ => - MaxState(result.getDouble(offset), Some(rowLevelResults)) + MaxState(result.getDouble(offset), Some(criterion)) } } @@ -51,35 +55,16 @@ case class MaxLength(column: String, where: Option[String] = None, analyzerOptio override def filterCondition: Option[String] = where private[deequ] def criterion: Column = { - transformColForNullBehavior - } - - private[deequ] def rowLevelResults: Column = { - transformColForFilteredRow(criterion) - } - - private def transformColForFilteredRow(col: Column): Column = { - val whereNotCondition = where.map { expression => not(expr(expression)) } - getRowLevelFilterTreatment(analyzerOptions) match { - case FilteredRowOutcome.TRUE => - conditionSelectionGivenColumn(col, whereNotCondition, replaceWith = Double.MinValue) - case _ => - conditionSelectionGivenColumn(col, whereNotCondition, replaceWith = null) - } - } - - private def transformColForNullBehavior: Column = { val isNullCheck = col(column).isNull - val colLengths: Column = length(conditionalSelection(column, where)).cast(DoubleType) - getNullBehavior match { - case NullBehavior.Fail => - conditionSelectionGivenColumn(colLengths, Option(isNullCheck), replaceWith = Double.MaxValue) - case NullBehavior.EmptyString => - // Empty String is 0 length string - conditionSelectionGivenColumn(colLengths, Option(isNullCheck), replaceWith = 0.0).cast(DoubleType) - case _ => - colLengths + val colLength = length(col(column)).cast(DoubleType) + val updatedColumn = getNullBehavior match { + case NullBehavior.Fail => when(isNullCheck, Double.MaxValue).otherwise(colLength) + // Empty String is 0 length string + case NullBehavior.EmptyString => when(isNullCheck, lit(0.0)).otherwise(colLength) + case 
NullBehavior.Ignore => colLength } + + conditionalSelectionWithAugmentedOutcome(updatedColumn, where) } private def getNullBehavior: NullBehavior = { diff --git a/src/main/scala/com/amazon/deequ/analyzers/Maximum.scala b/src/main/scala/com/amazon/deequ/analyzers/Maximum.scala index abeee6d9..1e52a7ae 100644 --- a/src/main/scala/com/amazon/deequ/analyzers/Maximum.scala +++ b/src/main/scala/com/amazon/deequ/analyzers/Maximum.scala @@ -16,13 +16,18 @@ package com.amazon.deequ.analyzers -import com.amazon.deequ.analyzers.Preconditions.{hasColumn, isNumeric} -import org.apache.spark.sql.{Column, Row} -import org.apache.spark.sql.functions.{col, element_at, max} -import org.apache.spark.sql.types.{DoubleType, StructType} -import Analyzers._ +import com.amazon.deequ.analyzers.Analyzers._ +import com.amazon.deequ.analyzers.Preconditions.hasColumn +import com.amazon.deequ.analyzers.Preconditions.isNumeric import com.amazon.deequ.metrics.FullColumn import com.google.common.annotations.VisibleForTesting +import org.apache.spark.sql.functions.col +import org.apache.spark.sql.functions.element_at +import org.apache.spark.sql.functions.max +import org.apache.spark.sql.types.DoubleType +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.Column +import org.apache.spark.sql.Row case class MaxState(maxValue: Double, override val fullColumn: Option[Column] = None) extends DoubleValuedState[MaxState] with FullColumn { @@ -41,6 +46,9 @@ case class Maximum(column: String, where: Option[String] = None, analyzerOptions with FilterableAnalyzer { override def aggregationFunctions(): Seq[Column] = { + // The criterion returns a column where each row contains an array of 2 elements. + // The first element of the array is a string that indicates if the row is "in scope" or "filtered" out. + // The second element is the value used for calculating the metric. We use "element_at" to extract it. max(element_at(criterion, 2).cast(DoubleType)) :: Nil } @@ -57,5 +65,5 @@ case class Maximum(column: String, where: Option[String] = None, analyzerOptions override def filterCondition: Option[String] = where @VisibleForTesting - private def criterion: Column = conditionalSelectionWithAugmentedOutcome(col(column), where, Double.MinValue) + private def criterion: Column = conditionalSelectionWithAugmentedOutcome(col(column), where) } diff --git a/src/main/scala/com/amazon/deequ/analyzers/MinLength.scala b/src/main/scala/com/amazon/deequ/analyzers/MinLength.scala index a6627d2d..ddc4497b 100644 --- a/src/main/scala/com/amazon/deequ/analyzers/MinLength.scala +++ b/src/main/scala/com/amazon/deequ/analyzers/MinLength.scala @@ -23,10 +23,11 @@ import com.amazon.deequ.analyzers.Preconditions.isString import org.apache.spark.sql.Column import org.apache.spark.sql.Row import org.apache.spark.sql.functions.col -import org.apache.spark.sql.functions.expr +import org.apache.spark.sql.functions.element_at import org.apache.spark.sql.functions.length +import org.apache.spark.sql.functions.lit import org.apache.spark.sql.functions.min -import org.apache.spark.sql.functions.not +import org.apache.spark.sql.functions.when import org.apache.spark.sql.types.DoubleType import org.apache.spark.sql.types.StructType @@ -35,12 +36,15 @@ case class MinLength(column: String, where: Option[String] = None, analyzerOptio with FilterableAnalyzer { override def aggregationFunctions(): Seq[Column] = { - min(criterion) :: Nil + // The criterion returns a column where each row contains an array of 2 elements. 
+ // The first element of the array is a string that indicates if the row is "in scope" or "filtered" out. + // The second element is the value used for calculating the metric. We use "element_at" to extract it. + min(element_at(criterion, 2).cast(DoubleType)) :: Nil } override def fromAggregationResult(result: Row, offset: Int): Option[MinState] = { ifNoNullsIn(result, offset) { _ => - MinState(result.getDouble(offset), Some(rowLevelResults)) + MinState(result.getDouble(offset), Some(criterion)) } } @@ -51,35 +55,16 @@ case class MinLength(column: String, where: Option[String] = None, analyzerOptio override def filterCondition: Option[String] = where private[deequ] def criterion: Column = { - transformColForNullBehavior - } - - private[deequ] def rowLevelResults: Column = { - transformColForFilteredRow(criterion) - } - - private def transformColForFilteredRow(col: Column): Column = { - val whereNotCondition = where.map { expression => not(expr(expression)) } - getRowLevelFilterTreatment(analyzerOptions) match { - case FilteredRowOutcome.TRUE => - conditionSelectionGivenColumn(col, whereNotCondition, replaceWith = Double.MaxValue) - case _ => - conditionSelectionGivenColumn(col, whereNotCondition, replaceWith = null) - } - } - - private def transformColForNullBehavior: Column = { val isNullCheck = col(column).isNull - val colLengths: Column = length(conditionalSelection(column, where)).cast(DoubleType) - getNullBehavior match { - case NullBehavior.Fail => - conditionSelectionGivenColumn(colLengths, Option(isNullCheck), replaceWith = Double.MinValue) - case NullBehavior.EmptyString => - // Empty String is 0 length string - conditionSelectionGivenColumn(colLengths, Option(isNullCheck), replaceWith = 0.0).cast(DoubleType) - case _ => - colLengths + val colLength = length(col(column)).cast(DoubleType) + val updatedColumn = getNullBehavior match { + case NullBehavior.Fail => when(isNullCheck, Double.MinValue).otherwise(colLength) + // Empty String is 0 length string + case NullBehavior.EmptyString => when(isNullCheck, lit(0.0)).otherwise(colLength) + case NullBehavior.Ignore => colLength } + + conditionalSelectionWithAugmentedOutcome(updatedColumn, where) } private def getNullBehavior: NullBehavior = { diff --git a/src/main/scala/com/amazon/deequ/analyzers/Minimum.scala b/src/main/scala/com/amazon/deequ/analyzers/Minimum.scala index b17507fc..701ae0f0 100644 --- a/src/main/scala/com/amazon/deequ/analyzers/Minimum.scala +++ b/src/main/scala/com/amazon/deequ/analyzers/Minimum.scala @@ -16,13 +16,18 @@ package com.amazon.deequ.analyzers -import com.amazon.deequ.analyzers.Preconditions.{hasColumn, isNumeric} -import org.apache.spark.sql.{Column, Row} -import org.apache.spark.sql.functions.{col, element_at, min} -import org.apache.spark.sql.types.{DoubleType, StructType} -import Analyzers._ +import com.amazon.deequ.analyzers.Analyzers._ +import com.amazon.deequ.analyzers.Preconditions.hasColumn +import com.amazon.deequ.analyzers.Preconditions.isNumeric import com.amazon.deequ.metrics.FullColumn import com.google.common.annotations.VisibleForTesting +import org.apache.spark.sql.functions.col +import org.apache.spark.sql.functions.element_at +import org.apache.spark.sql.functions.min +import org.apache.spark.sql.types.DoubleType +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.Column +import org.apache.spark.sql.Row case class MinState(minValue: Double, override val fullColumn: Option[Column] = None) extends DoubleValuedState[MinState] with FullColumn { @@ -41,6 +46,9 @@ case 
class Minimum(column: String, where: Option[String] = None, analyzerOptions with FilterableAnalyzer { override def aggregationFunctions(): Seq[Column] = { + // The criterion returns a column where each row contains an array of 2 elements. + // The first element of the array is a string that indicates if the row is "in scope" or "filtered" out. + // The second element is the value used for calculating the metric. We use "element_at" to extract it. min(element_at(criterion, 2).cast(DoubleType)) :: Nil } @@ -57,5 +65,5 @@ case class Minimum(column: String, where: Option[String] = None, analyzerOptions override def filterCondition: Option[String] = where @VisibleForTesting - private def criterion: Column = conditionalSelectionWithAugmentedOutcome(col(column), where, Double.MaxValue) + private def criterion: Column = conditionalSelectionWithAugmentedOutcome(col(column), where) } diff --git a/src/main/scala/com/amazon/deequ/constraints/Constraint.scala b/src/main/scala/com/amazon/deequ/constraints/Constraint.scala index a28b6f2e..8df88165 100644 --- a/src/main/scala/com/amazon/deequ/constraints/Constraint.scala +++ b/src/main/scala/com/amazon/deequ/constraints/Constraint.scala @@ -558,7 +558,8 @@ object Constraint { val constraint = AnalysisBasedConstraint[MaxState, Double, Double](maxLength, assertion, hint = hint) - val sparkAssertion = org.apache.spark.sql.functions.udf(assertion) + val updatedAssertion = getUpdatedRowLevelAssertionForLengthConstraint(assertion, maxLength.analyzerOptions) + val sparkAssertion = org.apache.spark.sql.functions.udf(updatedAssertion) new RowLevelAssertedConstraint( constraint, @@ -593,7 +594,8 @@ object Constraint { val constraint = AnalysisBasedConstraint[MinState, Double, Double](minLength, assertion, hint = hint) - val sparkAssertion = org.apache.spark.sql.functions.udf(assertion) + val updatedAssertion = getUpdatedRowLevelAssertionForLengthConstraint(assertion, minLength.analyzerOptions) + val sparkAssertion = org.apache.spark.sql.functions.udf(updatedAssertion) new RowLevelAssertedConstraint( constraint, @@ -953,26 +955,57 @@ object Constraint { } } - def filteredRowOutcome: java.lang.Boolean = { - analyzerOptions match { - case Some(opts) => - opts.filteredRow match { - case FilteredRowOutcome.TRUE => true - case FilteredRowOutcome.NULL => null - } - // https://github.com/awslabs/deequ/issues/530 - // Filtered rows should be marked as true by default. - // They can be set to null using the FilteredRowOutcome option. - case None => true + scope match { + case FilteredData.name => filteredRowOutcome(analyzerOptions) + case InScopeData.name => inScopeRowOutcome(value) + } + } + } + + private[this] def getUpdatedRowLevelAssertionForLengthConstraint(assertion: Double => Boolean, + analyzerOptions: Option[AnalyzerOptions]) + : Seq[String] => java.lang.Boolean = { + (d: Seq[String]) => { + val (scope, value) = (d.head, Option(d.last).map(_.toDouble)) + + def inScopeRowOutcome(value: Option[Double]): java.lang.Boolean = { + if (value.isDefined) { + // If value is defined, run it through the assertion. + assertion(value.get) + } else { + // If value is not defined (value is null), apply NullBehavior. 
+ analyzerOptions match { + case Some(opts) => + opts.nullBehavior match { + case NullBehavior.EmptyString => assertion(0.0) + case NullBehavior.Fail => false + case NullBehavior.Ignore => null + } + case None => null + } } } scope match { - case FilteredData.name => filteredRowOutcome + case FilteredData.name => filteredRowOutcome(analyzerOptions) case InScopeData.name => inScopeRowOutcome(value) } } } + + private def filteredRowOutcome(analyzerOptions: Option[AnalyzerOptions]): java.lang.Boolean = { + analyzerOptions match { + case Some(opts) => + opts.filteredRow match { + case FilteredRowOutcome.TRUE => true + case FilteredRowOutcome.NULL => null + } + // https://github.com/awslabs/deequ/issues/530 + // Filtered rows should be marked as true by default. + // They can be set to null using the FilteredRowOutcome option. + case None => true + } + } } /** diff --git a/src/test/scala/com/amazon/deequ/VerificationSuiteTest.scala b/src/test/scala/com/amazon/deequ/VerificationSuiteTest.scala index f7684e96..587f8bf5 100644 --- a/src/test/scala/com/amazon/deequ/VerificationSuiteTest.scala +++ b/src/test/scala/com/amazon/deequ/VerificationSuiteTest.scala @@ -1759,7 +1759,6 @@ class VerificationSuiteTest extends WordSpec with Matchers with SparkContextSpec .addChecks(Seq(mkEqualityCheck(analyzerOptions))) .run() - val passResult = verificationResult.checkResults assertCheckResults(verificationResult, CheckStatus.Error) val rowLevelResultsDF = VerificationResult.rowLevelResultsAsDataFrame(sparkSession, verificationResult, df) @@ -1778,7 +1777,6 @@ class VerificationSuiteTest extends WordSpec with Matchers with SparkContextSpec .addChecks(Seq(mkEqualityCheck(analyzerOptions))) .run() - val passResult = verificationResult.checkResults assertCheckResults(verificationResult, CheckStatus.Error) val rowLevelResultsDF = VerificationResult.rowLevelResultsAsDataFrame(sparkSession, verificationResult, df) @@ -1789,6 +1787,212 @@ class VerificationSuiteTest extends WordSpec with Matchers with SparkContextSpec } } + "Verification Suite with ==/!= based MinLength/MaxLength checks and filtered row behavior" should { + val col1 = "Company" + val col2 = "ZipCode" + val col3 = "City" + + val check1Description = "length-equality-check-1" + val check2Description = "length-equality-check-2" + val check3Description = "length-equality-check-3" + + val check1WhereClause = "ID > 2" + val check2WhereClause = "ID in (1, 2, 3)" + val check3WhereClause = "ID <= 2" + + def mkLengthCheck1(analyzerOptions: AnalyzerOptions): Check = new Check(CheckLevel.Error, check1Description) + .hasMinLength(col1, _ == 8, analyzerOptions = Some(analyzerOptions)).where(check1WhereClause) + .hasMaxLength(col1, _ == 8, analyzerOptions = Some(analyzerOptions)).where(check1WhereClause) + + def mkLengthCheck2(analyzerOptions: AnalyzerOptions): Check = new Check(CheckLevel.Error, check2Description) + .hasMinLength(col2, _ == 4, analyzerOptions = Some(analyzerOptions)).where(check2WhereClause) + .hasMaxLength(col2, _ == 4, analyzerOptions = Some(analyzerOptions)).where(check2WhereClause) + + def mkLengthCheck3(analyzerOptions: AnalyzerOptions): Check = new Check(CheckLevel.Error, check3Description) + .hasMinLength(col3, _ != 0, analyzerOptions = Some(analyzerOptions)).where(check3WhereClause) + .hasMaxLength(col3, _ != 0, analyzerOptions = Some(analyzerOptions)).where(check3WhereClause) + + def getRowLevelResults(df: DataFrame): Seq[java.lang.Boolean] = + df.collect().map { r => r.getAs[java.lang.Boolean](0) }.toSeq + + def 
assertCheckResults(verificationResult: VerificationResult): Unit = { + val passResult = verificationResult.checkResults + + val equalityCheck1Result = passResult.values.find(_.check.description == check1Description) + val equalityCheck2Result = passResult.values.find(_.check.description == check2Description) + val equalityCheck3Result = passResult.values.find(_.check.description == check3Description) + + assert(equalityCheck1Result.isDefined && equalityCheck1Result.get.status == CheckStatus.Success) + assert(equalityCheck2Result.isDefined && equalityCheck2Result.get.status == CheckStatus.Error) + assert(equalityCheck3Result.isDefined && equalityCheck3Result.get.status == CheckStatus.Success) + } + + def assertRowLevelResults(rowLevelResults: DataFrame, + analyzerOptions: AnalyzerOptions): Unit = { + val equalityCheck1Results = getRowLevelResults(rowLevelResults.select(check1Description)) + val equalityCheck2Results = getRowLevelResults(rowLevelResults.select(check2Description)) + val equalityCheck3Results = getRowLevelResults(rowLevelResults.select(check3Description)) + + val filteredOutcome: java.lang.Boolean = analyzerOptions.filteredRow match { + case FilteredRowOutcome.TRUE => true + case FilteredRowOutcome.NULL => null + } + + assert(equalityCheck1Results == Seq(filteredOutcome, filteredOutcome, true, true)) + assert(equalityCheck2Results == Seq(false, false, false, filteredOutcome)) + assert(equalityCheck3Results == Seq(true, true, filteredOutcome, filteredOutcome)) + } + + def assertMetrics(metricsDF: DataFrame): Unit = { + val metricsMap = getMetricsAsMap(metricsDF) + assert(metricsMap(s"$col1|MinLength (where: $check1WhereClause)") == 8.0) + assert(metricsMap(s"$col1|MaxLength (where: $check1WhereClause)") == 8.0) + assert(metricsMap(s"$col2|MinLength (where: $check2WhereClause)") == 0.0) + assert(metricsMap(s"$col2|MaxLength (where: $check2WhereClause)") == 5.0) + assert(metricsMap(s"$col3|MinLength (where: $check3WhereClause)") == 11.0) + assert(metricsMap(s"$col3|MaxLength (where: $check3WhereClause)") == 11.0) + } + + "mark filtered rows as null" in withSparkSession { + sparkSession => + val df = getDfForWhereClause(sparkSession) + val analyzerOptions = AnalyzerOptions( + nullBehavior = NullBehavior.EmptyString, filteredRow = FilteredRowOutcome.NULL + ) + + val equalityCheck1 = mkLengthCheck1(analyzerOptions) + val equalityCheck2 = mkLengthCheck2(analyzerOptions) + val equalityCheck3 = mkLengthCheck3(analyzerOptions) + + val verificationResult = VerificationSuite() + .onData(df) + .addChecks(Seq(equalityCheck1, equalityCheck2, equalityCheck3)) + .run() + + val rowLevelResultsDF = VerificationResult.rowLevelResultsAsDataFrame(sparkSession, verificationResult, df) + val metricsDF = VerificationResult.successMetricsAsDataFrame(sparkSession, verificationResult) + + assertCheckResults(verificationResult) + assertRowLevelResults(rowLevelResultsDF, analyzerOptions) + assertMetrics(metricsDF) + } + + "mark filtered rows as true" in withSparkSession { + sparkSession => + val df = getDfForWhereClause(sparkSession) + val analyzerOptions = AnalyzerOptions( + nullBehavior = NullBehavior.EmptyString, filteredRow = FilteredRowOutcome.TRUE + ) + + val equalityCheck1 = mkLengthCheck1(analyzerOptions) + val equalityCheck2 = mkLengthCheck2(analyzerOptions) + val equalityCheck3 = mkLengthCheck3(analyzerOptions) + + val verificationResult = VerificationSuite() + .onData(df) + .addChecks(Seq(equalityCheck1, equalityCheck2, equalityCheck3)) + .run() + + val rowLevelResultsDF = 
VerificationResult.rowLevelResultsAsDataFrame(sparkSession, verificationResult, df) + val metricsDF = VerificationResult.successMetricsAsDataFrame(sparkSession, verificationResult) + + assertCheckResults(verificationResult) + assertRowLevelResults(rowLevelResultsDF, analyzerOptions) + assertMetrics(metricsDF) + } + } + + "Verification Suite with ==/!= based MinLength/MaxLength checks and null row behavior" should { + val col = "City" + val checkDescription = "length-check" + val assertion = (d: Double) => d >= 0.0 && d <= 8.0 + + def mkLengthCheck(analyzerOptions: AnalyzerOptions): Check = new Check(CheckLevel.Error, checkDescription) + .hasMinLength(col, assertion, analyzerOptions = Some(analyzerOptions)) + .hasMaxLength(col, assertion, analyzerOptions = Some(analyzerOptions)) + + def assertCheckResults(verificationResult: VerificationResult, checkStatus: CheckStatus.Value): Unit = { + val passResult = verificationResult.checkResults + val equalityCheckResult = passResult.values.find(_.check.description == checkDescription) + assert(equalityCheckResult.isDefined && equalityCheckResult.get.status == checkStatus) + } + + def getRowLevelResults(df: DataFrame): Seq[java.lang.Boolean] = + df.collect().map { r => r.getAs[java.lang.Boolean](0) }.toSeq + + def assertRowLevelResults(rowLevelResults: DataFrame, + analyzerOptions: AnalyzerOptions): Unit = { + val equalityCheckResults = getRowLevelResults(rowLevelResults.select(checkDescription)) + val nullOutcome: java.lang.Boolean = analyzerOptions.nullBehavior match { + case NullBehavior.Fail => false + case NullBehavior.Ignore => null + case NullBehavior.EmptyString => true + } + + assert(equalityCheckResults == Seq(false, false, nullOutcome, true)) + } + + def assertMetrics(metricsDF: DataFrame, minLength: Double, maxLength: Double): Unit = { + val metricsMap = getMetricsAsMap(metricsDF) + assert(metricsMap(s"$col|MinLength") == minLength) + assert(metricsMap(s"$col|MaxLength") == maxLength) + } + + "keep non-filtered null rows as null" in withSparkSession { + sparkSession => + val df = getDfForWhereClause(sparkSession) + val analyzerOptions = AnalyzerOptions(nullBehavior = NullBehavior.Ignore) + val verificationResult = VerificationSuite() + .onData(df) + .addCheck(mkLengthCheck(analyzerOptions)) + .run() + + assertCheckResults(verificationResult, CheckStatus.Error) + + val rowLevelResultsDF = VerificationResult.rowLevelResultsAsDataFrame(sparkSession, verificationResult, df) + assertRowLevelResults(rowLevelResultsDF, analyzerOptions) + + val metricsDF = VerificationResult.successMetricsAsDataFrame(sparkSession, verificationResult) + assertMetrics(metricsDF, 8.0, 11.0) + } + + "mark non-filtered null rows as false" in withSparkSession { + sparkSession => + val df = getDfForWhereClause(sparkSession) + val analyzerOptions = AnalyzerOptions(nullBehavior = NullBehavior.Fail) + val verificationResult = VerificationSuite() + .onData(df) + .addCheck(mkLengthCheck(analyzerOptions)) + .run() + + assertCheckResults(verificationResult, CheckStatus.Error) + + val rowLevelResultsDF = VerificationResult.rowLevelResultsAsDataFrame(sparkSession, verificationResult, df) + assertRowLevelResults(rowLevelResultsDF, analyzerOptions) + + val metricsDF = VerificationResult.successMetricsAsDataFrame(sparkSession, verificationResult) + assertMetrics(metricsDF, Double.MinValue, Double.MaxValue) + } + + "mark non-filtered null rows as empty string" in withSparkSession { + sparkSession => + val df = getDfForWhereClause(sparkSession) + val analyzerOptions = 
AnalyzerOptions(nullBehavior = NullBehavior.EmptyString) + val verificationResult = VerificationSuite() + .onData(df) + .addCheck(mkLengthCheck(analyzerOptions)) + .run() + + assertCheckResults(verificationResult, CheckStatus.Error) + + val rowLevelResultsDF = VerificationResult.rowLevelResultsAsDataFrame(sparkSession, verificationResult, df) + assertRowLevelResults(rowLevelResultsDF, analyzerOptions) + + val metricsDF = VerificationResult.successMetricsAsDataFrame(sparkSession, verificationResult) + assertMetrics(metricsDF, 0.0, 11.0) + } + } + /** Run anomaly detection using a repository with some previous analysis results for testing */ private[this] def evaluateWithRepositoryWithHistory(test: MetricsRepository => Unit): Unit = { diff --git a/src/test/scala/com/amazon/deequ/analyzers/MaxLengthTest.scala b/src/test/scala/com/amazon/deequ/analyzers/MaxLengthTest.scala index 456bd4c6..fd302a4d 100644 --- a/src/test/scala/com/amazon/deequ/analyzers/MaxLengthTest.scala +++ b/src/test/scala/com/amazon/deequ/analyzers/MaxLengthTest.scala @@ -13,145 +13,73 @@ * permissions and limitations under the License. * */ + package com.amazon.deequ.analyzers import com.amazon.deequ.SparkContextSpec import com.amazon.deequ.metrics.DoubleMetric import com.amazon.deequ.metrics.FullColumn import com.amazon.deequ.utils.FixtureSupport +import org.apache.spark.sql.Column +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.Row +import org.apache.spark.sql.functions.element_at +import org.apache.spark.sql.types.DoubleType import org.scalatest.matchers.should.Matchers import org.scalatest.wordspec.AnyWordSpec class MaxLengthTest extends AnyWordSpec with Matchers with SparkContextSpec with FixtureSupport { + private val tempColName = "new" + + private def getValuesDF(df: DataFrame, outcomeColumn: Column): Seq[Row] = { + df.withColumn(tempColName, element_at(outcomeColumn, 2).cast(DoubleType)).collect() + } "MaxLength" should { "return row-level results for non-null columns" in withSparkSession { session => - val data = getDfWithStringColumns(session) val countryLength = MaxLength("Country") // It's "India" in every row val state: Option[MaxState] = countryLength.computeStateFrom(data) val metric: DoubleMetric with FullColumn = countryLength.computeMetricFrom(state) - data.withColumn("new", metric.fullColumn.get) - .collect().map(_.getAs[Double]("new")) shouldBe Seq(5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0) + val values = getValuesDF(data, metric.fullColumn.get).map(_.getAs[Double](tempColName)) + values shouldBe Seq(5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0) } "return row-level results for null columns" in withSparkSession { session => - val data = getEmptyColumnDataDf(session) val addressLength = MaxLength("att3") // It's null in two rows val state: Option[MaxState] = addressLength.computeStateFrom(data) val metric: DoubleMetric with FullColumn = addressLength.computeMetricFrom(state) - data.withColumn("new", metric.fullColumn.get) - .collect().map(_.getAs[Double]("new")) shouldBe Seq(1.0, 1.0, 0.0, 1.0, 0.0, 1.0) + val values = getValuesDF(data, metric.fullColumn.get) + .map(r => if (r == null) null else r.getAs[Double](tempColName)) + values shouldBe Seq(1.0, 1.0, null, 1.0, null, 1.0) } "return row-level results for null columns with NullBehavior fail option" in withSparkSession { session => - val data = getEmptyColumnDataDf(session) // It's null in two rows - val addressLength = MaxLength("att3") + val addressLength = MaxLength("att3", None, Option(AnalyzerOptions(NullBehavior.Fail))) val state: 
Option[MaxState] = addressLength.computeStateFrom(data) val metric: DoubleMetric with FullColumn = addressLength.computeMetricFrom(state) - data.withColumn("new", metric.fullColumn.get) - .collect().map(r => if (r == null) null else r.getAs[Double]("new")) shouldBe - Seq(1.0, 1.0, null, 1.0, null, 1.0) + val values = getValuesDF(data, metric.fullColumn.get).map(_.getAs[Double](tempColName)) + values shouldBe Seq(1.0, 1.0, Double.MaxValue, 1.0, Double.MaxValue, 1.0) } "return row-level results for blank strings" in withSparkSession { session => - val data = getEmptyColumnDataDf(session) val addressLength = MaxLength("att1") // It's empty strings val state: Option[MaxState] = addressLength.computeStateFrom(data) val metric: DoubleMetric with FullColumn = addressLength.computeMetricFrom(state) - data.withColumn("new", metric.fullColumn.get) - .collect().map(_.getAs[Double]("new")) shouldBe Seq(0.0, 0.0, 0.0, 0.0, 0.0, 0.0) - } - - "return row-level results with NullBehavior fail and filtered as true" in withSparkSession { session => - - val data = getEmptyColumnDataDf(session) - - val addressLength = MaxLength("att3", Option("id < 4"), Option(AnalyzerOptions(NullBehavior.Fail))) - val state: Option[MaxState] = addressLength.computeStateFrom(data) - val metric: DoubleMetric with FullColumn = addressLength.computeMetricFrom(state) - - data.withColumn("new", metric.fullColumn.get) - .collect().map(_.getAs[Any]("new")) shouldBe - Seq(1.0, 1.0, Double.MaxValue, 1.0, Double.MinValue, Double.MinValue) - } - - "return row-level results with NullBehavior fail and filtered as null" in withSparkSession { session => - - val data = getEmptyColumnDataDf(session) - - val addressLength = MaxLength("att3", Option("id < 4"), - Option(AnalyzerOptions(NullBehavior.Fail, FilteredRowOutcome.NULL))) - val state: Option[MaxState] = addressLength.computeStateFrom(data) - val metric: DoubleMetric with FullColumn = addressLength.computeMetricFrom(state) - - data.withColumn("new", metric.fullColumn.get) - .collect().map(_.getAs[Any]("new")) shouldBe Seq(1.0, 1.0, Double.MaxValue, 1.0, null, null) - } - - "return row-level results with NullBehavior empty and filtered as true" in withSparkSession { session => - - val data = getEmptyColumnDataDf(session) - - val addressLength = MaxLength("att3", Option("id < 4"), Option(AnalyzerOptions(NullBehavior.EmptyString))) - val state: Option[MaxState] = addressLength.computeStateFrom(data) - val metric: DoubleMetric with FullColumn = addressLength.computeMetricFrom(state) - - data.withColumn("new", metric.fullColumn.get) - .collect().map(_.getAs[Any]("new")) shouldBe - Seq(1.0, 1.0, 0.0, 1.0, Double.MinValue, Double.MinValue) - } - - "return row-level results with NullBehavior empty and filtered as null" in withSparkSession { session => - - val data = getEmptyColumnDataDf(session) - - val addressLength = MaxLength("att3", Option("id < 4"), - Option(AnalyzerOptions(NullBehavior.EmptyString, FilteredRowOutcome.NULL))) - val state: Option[MaxState] = addressLength.computeStateFrom(data) - val metric: DoubleMetric with FullColumn = addressLength.computeMetricFrom(state) - - data.withColumn("new", metric.fullColumn.get) - .collect().map(_.getAs[Any]("new")) shouldBe Seq(1.0, 1.0, 0.0, 1.0, null, null) - } - - "return row-level results with NullBehavior ignore and filtered as true" in withSparkSession { session => - - val data = getEmptyColumnDataDf(session) - - val addressLength = MaxLength("att3", Option("id < 4"), Option(AnalyzerOptions(NullBehavior.Ignore))) - val state: 
Option[MaxState] = addressLength.computeStateFrom(data) - val metric: DoubleMetric with FullColumn = addressLength.computeMetricFrom(state) - - data.withColumn("new", metric.fullColumn.get) - .collect().map(_.getAs[Any]("new")) shouldBe - Seq(1.0, 1.0, null, 1.0, Double.MinValue, Double.MinValue) - } - - "return row-level results with NullBehavior ignore and filtered as null" in withSparkSession { session => - - val data = getEmptyColumnDataDf(session) - - val addressLength = MaxLength("att3", Option("id < 4"), - Option(AnalyzerOptions(NullBehavior.Ignore, FilteredRowOutcome.NULL))) - val state: Option[MaxState] = addressLength.computeStateFrom(data) - val metric: DoubleMetric with FullColumn = addressLength.computeMetricFrom(state) - - data.withColumn("new", metric.fullColumn.get) - .collect().map(_.getAs[Any]("new")) shouldBe Seq(1.0, 1.0, null, 1.0, null, null) + val values = getValuesDF(data, metric.fullColumn.get).map(_.getAs[Double](tempColName)) + values shouldBe Seq(0.0, 0.0, 0.0, 0.0, 0.0, 0.0) } } - } diff --git a/src/test/scala/com/amazon/deequ/analyzers/MaximumTest.scala b/src/test/scala/com/amazon/deequ/analyzers/MaximumTest.scala index 1d13a8df..983e6bca 100644 --- a/src/test/scala/com/amazon/deequ/analyzers/MaximumTest.scala +++ b/src/test/scala/com/amazon/deequ/analyzers/MaximumTest.scala @@ -14,7 +14,6 @@ * */ - package com.amazon.deequ.analyzers import com.amazon.deequ.SparkContextSpec @@ -50,7 +49,6 @@ class MaximumTest extends AnyWordSpec with Matchers with SparkContextSpec with F } "return row-level results for columns with null" in withSparkSession { session => - val data = getDfWithNumericValues(session) val att1Maximum = Maximum("attNull") diff --git a/src/test/scala/com/amazon/deequ/analyzers/MinLengthTest.scala b/src/test/scala/com/amazon/deequ/analyzers/MinLengthTest.scala index 0f88e377..23e99574 100644 --- a/src/test/scala/com/amazon/deequ/analyzers/MinLengthTest.scala +++ b/src/test/scala/com/amazon/deequ/analyzers/MinLengthTest.scala @@ -14,45 +14,52 @@ * */ - package com.amazon.deequ.analyzers import com.amazon.deequ.SparkContextSpec import com.amazon.deequ.metrics.DoubleMetric import com.amazon.deequ.metrics.FullColumn import com.amazon.deequ.utils.FixtureSupport +import org.apache.spark.sql.Column +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.Row +import org.apache.spark.sql.functions.element_at +import org.apache.spark.sql.types.DoubleType import org.scalatest.matchers.should.Matchers import org.scalatest.wordspec.AnyWordSpec class MinLengthTest extends AnyWordSpec with Matchers with SparkContextSpec with FixtureSupport { + private val tempColName = "new" + + private def getValuesDF(df: DataFrame, outcomeColumn: Column): Seq[Row] = { + df.withColumn(tempColName, element_at(outcomeColumn, 2).cast(DoubleType)).collect() + } "MinLength" should { "return row-level results for non-null columns" in withSparkSession { session => - val data = getDfWithStringColumns(session) val countryLength = MinLength("Country") // It's "India" in every row val state: Option[MinState] = countryLength.computeStateFrom(data) val metric: DoubleMetric with FullColumn = countryLength.computeMetricFrom(state) - data.withColumn("new", metric.fullColumn.get) - .collect().map(_.getAs[Double]("new")) shouldBe Seq(5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0) + val values = getValuesDF(data, metric.fullColumn.get).map(_.getAs[Double](tempColName)) + values shouldBe Seq(5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0) } "return row-level results for null columns" in withSparkSession 
{ session => - val data = getEmptyColumnDataDf(session) val addressLength = MinLength("att3") // It's null in two rows val state: Option[MinState] = addressLength.computeStateFrom(data) val metric: DoubleMetric with FullColumn = addressLength.computeMetricFrom(state) - data.withColumn("new", metric.fullColumn.get) - .collect().map(_.getAs[Double]("new")) shouldBe Seq(1.0, 1.0, 0.0, 1.0, 0.0, 1.0) + val values = getValuesDF(data, metric.fullColumn.get) + .map(r => if (r == null) null else r.getAs[Double](tempColName)) + values shouldBe Seq(1.0, 1.0, null, 1.0, null, 1.0) } "return row-level results for null columns with NullBehavior fail option" in withSparkSession { session => - val data = getEmptyColumnDataDf(session) // It's null in two rows @@ -60,13 +67,11 @@ class MinLengthTest extends AnyWordSpec with Matchers with SparkContextSpec with val state: Option[MinState] = addressLength.computeStateFrom(data) val metric: DoubleMetric with FullColumn = addressLength.computeMetricFrom(state) - data.withColumn("new", metric.fullColumn.get) - .collect().map(_.getAs[Double]("new") - ) shouldBe Seq(1.0, 1.0, Double.MinValue, 1.0, Double.MinValue, 1.0) + val values = getValuesDF(data, metric.fullColumn.get).map(_.getAs[Double](tempColName)) + values shouldBe Seq(1.0, 1.0, Double.MinValue, 1.0, Double.MinValue, 1.0) } "return row-level results for null columns with NullBehavior empty option" in withSparkSession { session => - val data = getEmptyColumnDataDf(session) // It's null in two rows @@ -74,99 +79,19 @@ class MinLengthTest extends AnyWordSpec with Matchers with SparkContextSpec with val state: Option[MinState] = addressLength.computeStateFrom(data) val metric: DoubleMetric with FullColumn = addressLength.computeMetricFrom(state) - data.withColumn("new", metric.fullColumn.get) - .collect().map(_.getAs[Double]("new")) shouldBe Seq(1.0, 1.0, 0.0, 1.0, 0.0, 1.0) + val values = getValuesDF(data, metric.fullColumn.get).map(_.getAs[Double](tempColName)) + values shouldBe Seq(1.0, 1.0, 0.0, 1.0, 0.0, 1.0) } "return row-level results for blank strings" in withSparkSession { session => - val data = getEmptyColumnDataDf(session) val addressLength = MinLength("att1") // It's empty strings val state: Option[MinState] = addressLength.computeStateFrom(data) val metric: DoubleMetric with FullColumn = addressLength.computeMetricFrom(state) - data.withColumn("new", metric.fullColumn.get) - .collect().map(_.getAs[Double]("new")) shouldBe Seq(0.0, 0.0, 0.0, 0.0, 0.0, 0.0) - } - - "return row-level results with NullBehavior fail and filtered as true" in withSparkSession { session => - - val data = getEmptyColumnDataDf(session) - - val addressLength = MinLength("att3", Option("id < 4"), Option(AnalyzerOptions(NullBehavior.Fail))) - val state: Option[MinState] = addressLength.computeStateFrom(data) - val metric: DoubleMetric with FullColumn = addressLength.computeMetricFrom(state) - - data.withColumn("new", metric.fullColumn.get) - .collect().map(_.getAs[Any]("new")) shouldBe - Seq(1.0, 1.0, Double.MinValue, 1.0, Double.MaxValue, Double.MaxValue) - } - - "return row-level results with NullBehavior fail and filtered as null" in withSparkSession { session => - - val data = getEmptyColumnDataDf(session) - - val addressLength = MinLength("att3", Option("id < 4"), - Option(AnalyzerOptions(NullBehavior.Fail, FilteredRowOutcome.NULL))) - val state: Option[MinState] = addressLength.computeStateFrom(data) - val metric: DoubleMetric with FullColumn = addressLength.computeMetricFrom(state) - - data.withColumn("new", 
metric.fullColumn.get) - .collect().map(_.getAs[Any]("new")) shouldBe - Seq(1.0, 1.0, Double.MinValue, 1.0, null, null) - } - - "return row-level results with NullBehavior empty and filtered as true" in withSparkSession { session => - - val data = getEmptyColumnDataDf(session) - - val addressLength = MinLength("att3", Option("id < 4"), Option(AnalyzerOptions(NullBehavior.EmptyString))) - val state: Option[MinState] = addressLength.computeStateFrom(data) - val metric: DoubleMetric with FullColumn = addressLength.computeMetricFrom(state) - - data.withColumn("new", metric.fullColumn.get) - .collect().map(_.getAs[Any]("new")) shouldBe - Seq(1.0, 1.0, 0.0, 1.0, Double.MaxValue, Double.MaxValue) - } - - "return row-level results with NullBehavior empty and filtered as null" in withSparkSession { session => - - val data = getEmptyColumnDataDf(session) - - val addressLength = MinLength("att3", Option("id < 4"), - Option(AnalyzerOptions(NullBehavior.EmptyString, FilteredRowOutcome.NULL))) - val state: Option[MinState] = addressLength.computeStateFrom(data) - val metric: DoubleMetric with FullColumn = addressLength.computeMetricFrom(state) - - data.withColumn("new", metric.fullColumn.get) - .collect().map(_.getAs[Any]("new")) shouldBe Seq(1.0, 1.0, 0.0, 1.0, null, null) - } - - "return row-level results NullBehavior ignore and filtered as true" in withSparkSession { session => - - val data = getEmptyColumnDataDf(session) - - val addressLength = MinLength("att3", Option("id < 4"), Option(AnalyzerOptions(NullBehavior.Ignore))) - val state: Option[MinState] = addressLength.computeStateFrom(data) - val metric: DoubleMetric with FullColumn = addressLength.computeMetricFrom(state) - - data.withColumn("new", metric.fullColumn.get) - .collect().map(_.getAs[Any]("new")) shouldBe - Seq(1.0, 1.0, null, 1.0, Double.MaxValue, Double.MaxValue) - } - - "return row-level results with NullBehavior ignore and filtered as null" in withSparkSession { session => - - val data = getEmptyColumnDataDf(session) - - val addressLength = MinLength("att3", Option("id < 4"), - Option(AnalyzerOptions(NullBehavior.Ignore, FilteredRowOutcome.NULL))) - val state: Option[MinState] = addressLength.computeStateFrom(data) - val metric: DoubleMetric with FullColumn = addressLength.computeMetricFrom(state) - - data.withColumn("new", metric.fullColumn.get) - .collect().map(_.getAs[Any]("new")) shouldBe Seq(1.0, 1.0, null, 1.0, null, null) + val values = getValuesDF(data, metric.fullColumn.get).map(_.getAs[Double](tempColName)) + values shouldBe Seq(0.0, 0.0, 0.0, 0.0, 0.0, 0.0) } } } diff --git a/src/test/scala/com/amazon/deequ/utils/FixtureSupport.scala b/src/test/scala/com/amazon/deequ/utils/FixtureSupport.scala index 5c56ed4b..3a0866d2 100644 --- a/src/test/scala/com/amazon/deequ/utils/FixtureSupport.scala +++ b/src/test/scala/com/amazon/deequ/utils/FixtureSupport.scala @@ -439,11 +439,11 @@ trait FixtureSupport { import sparkSession.implicits._ Seq( - ("Acme", "90210", "CA", "Los Angeles"), - ("Acme", "90211", "CA", "Los Angeles"), - ("Robocorp", null, "NJ", null), - ("Robocorp", null, "NY", "New York") - ).toDF("Company", "ZipCode", "State", "City") + (1, "Acme", "90210", "CA", "Los Angeles"), + (2, "Acme", "90211", "CA", "Los Angeles"), + (3, "Robocorp", null, "NJ", null), + (4, "Robocorp", null, "NY", "New York") + ).toDF("ID", "Company", "ZipCode", "State", "City") } def getDfCompleteAndInCompleteColumnsWithPeriod(sparkSession: SparkSession): DataFrame = {
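A minimal, self-contained sketch (not from the patch; LengthAssertionSketch, myAssertion and the sample rows are hypothetical) of how the updated Seq[String] => java.lang.Boolean row-level assertion built in Constraint.scala is expected to behave. The real helper additionally applies NullBehavior when an in-scope value is null; this sketch simply returns null in that case, which corresponds to the Ignore behavior.

object LengthAssertionSketch {
  def main(args: Array[String]): Unit = {
    // The user-supplied assertion, e.g. from hasMinLength(column, _ == 8.0).
    val myAssertion: Double => Boolean = _ == 8.0

    // Each row-level value is a 2-element array: [source, value].
    val rows: Seq[Seq[String]] = Seq(
      Seq("InScopeData", "8.0"),  // in scope, passes the assertion
      Seq("InScopeData", "5.0"),  // in scope, fails the assertion
      Seq("FilteredData", null)   // excluded by the where clause
    )

    // Filtered rows are decided by the first element alone (true by default, null with
    // FilteredRowOutcome.NULL); in-scope rows are run through the user assertion.
    val rowLevelOutcome: Seq[String] => java.lang.Boolean = { d =>
      val (source, value) = (d.head, Option(d.last).map(_.toDouble))
      source match {
        case "FilteredData" => true
        case _ => value.map(v => java.lang.Boolean.valueOf(myAssertion(v))).orNull
      }
    }

    rows.foreach(r => println(s"$r -> ${rowLevelOutcome(r)}"))  // true, false, true
  }
}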
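And a hedged end-to-end usage sketch of the scenario the commit message describes: one check combining hasMinLength and hasMaxLength with == assertions and a where clause. With this fix, rows excluded by the where clause get the filtered-row outcome (true by default) instead of false. It assumes deequ with this patch and Spark are on the classpath; the object name and toy DataFrame are made up here.

import com.amazon.deequ.{VerificationResult, VerificationSuite}
import com.amazon.deequ.checks.{Check, CheckLevel}
import org.apache.spark.sql.SparkSession

object FilteredRowLengthSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("filtered-row-length").getOrCreate()
    import spark.implicits._

    val df = Seq(
      (1, "Acme"), (2, "Acme"), (3, "Robocorp"), (4, "Robocorp")
    ).toDF("ID", "Company")

    // Only rows with ID > 2 are in scope; "Robocorp" has length 8, so both constraints pass.
    val check = new Check(CheckLevel.Error, "length-equality-check")
      .hasMinLength("Company", _ == 8).where("ID > 2")
      .hasMaxLength("Company", _ == 8).where("ID > 2")

    val result = VerificationSuite().onData(df).addChecks(Seq(check)).run()

    // Rows 1 and 2 are filtered out; their row-level outcome is now true
    // (or null when FilteredRowOutcome.NULL is configured), not false as before the fix.
    VerificationResult.rowLevelResultsAsDataFrame(spark, result, df).show()

    spark.stop()
  }
}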