diff --git a/src/main/scala/com/amazon/deequ/analyzers/Analyzer.scala b/src/main/scala/com/amazon/deequ/analyzers/Analyzer.scala
index bc05adb5..9367f31e 100644
--- a/src/main/scala/com/amazon/deequ/analyzers/Analyzer.scala
+++ b/src/main/scala/com/amazon/deequ/analyzers/Analyzer.scala
@@ -519,13 +519,16 @@ private[deequ] object Analyzers {
   }
 
   def conditionalSelectionWithAugmentedOutcome(selection: Column,
-                                               condition: Option[String],
-                                               replaceWith: Double): Column = {
+                                               condition: Option[String]): Column = {
     val origSelection = array(lit(InScopeData.name).as("source"), selection.as("selection"))
-    val filteredSelection = array(lit(FilteredData.name).as("source"), lit(replaceWith).as("selection"))
+
+    // The second value in the array is set to null here, but any placeholder would do:
+    // it is not used when evaluating the row-level outcome (true/null) for filtered rows.
+    // That decision is based solely on the first value, which is set to "FilteredData" here.
+    val filteredSelection = array(lit(FilteredData.name).as("source"), lit(null).as("selection"))
 
     condition
-      .map { cond => when(not(expr(cond)), filteredSelection).otherwise(origSelection) }
+      .map { cond => when(expr(cond), origSelection).otherwise(filteredSelection) }
       .getOrElse(origSelection)
   }
 
diff --git a/src/main/scala/com/amazon/deequ/analyzers/MaxLength.scala b/src/main/scala/com/amazon/deequ/analyzers/MaxLength.scala
index 3b55d4fa..141d92fb 100644
--- a/src/main/scala/com/amazon/deequ/analyzers/MaxLength.scala
+++ b/src/main/scala/com/amazon/deequ/analyzers/MaxLength.scala
@@ -23,10 +23,11 @@ import com.amazon.deequ.analyzers.Preconditions.isString
 import org.apache.spark.sql.Column
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.functions.col
-import org.apache.spark.sql.functions.expr
+import org.apache.spark.sql.functions.element_at
 import org.apache.spark.sql.functions.length
+import org.apache.spark.sql.functions.lit
 import org.apache.spark.sql.functions.max
-import org.apache.spark.sql.functions.not
+import org.apache.spark.sql.functions.when
 import org.apache.spark.sql.types.DoubleType
 import org.apache.spark.sql.types.StructType
 
@@ -35,12 +36,15 @@ case class MaxLength(column: String, where: Option[String] = None, analyzerOptio
   with FilterableAnalyzer {
 
   override def aggregationFunctions(): Seq[Column] = {
-    max(criterion) :: Nil
+    // The criterion returns a column where each row contains an array of 2 elements.
+    // The first element of the array is a string that indicates if the row is "in scope" or "filtered" out.
+    // The second element is the value used for calculating the metric. We use "element_at" to extract it.
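+    // For example, an in-scope row whose value has length 4 yields ["InScopeData", 4.0],
+    // while a row excluded by the where clause yields ["FilteredData", null]; the max
+    // aggregate below simply skips those null second elements.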
+ max(element_at(criterion, 2).cast(DoubleType)) :: Nil } override def fromAggregationResult(result: Row, offset: Int): Option[MaxState] = { ifNoNullsIn(result, offset) { _ => - MaxState(result.getDouble(offset), Some(rowLevelResults)) + MaxState(result.getDouble(offset), Some(criterion)) } } @@ -51,35 +55,16 @@ case class MaxLength(column: String, where: Option[String] = None, analyzerOptio override def filterCondition: Option[String] = where private[deequ] def criterion: Column = { - transformColForNullBehavior - } - - private[deequ] def rowLevelResults: Column = { - transformColForFilteredRow(criterion) - } - - private def transformColForFilteredRow(col: Column): Column = { - val whereNotCondition = where.map { expression => not(expr(expression)) } - getRowLevelFilterTreatment(analyzerOptions) match { - case FilteredRowOutcome.TRUE => - conditionSelectionGivenColumn(col, whereNotCondition, replaceWith = Double.MinValue) - case _ => - conditionSelectionGivenColumn(col, whereNotCondition, replaceWith = null) - } - } - - private def transformColForNullBehavior: Column = { val isNullCheck = col(column).isNull - val colLengths: Column = length(conditionalSelection(column, where)).cast(DoubleType) - getNullBehavior match { - case NullBehavior.Fail => - conditionSelectionGivenColumn(colLengths, Option(isNullCheck), replaceWith = Double.MaxValue) - case NullBehavior.EmptyString => - // Empty String is 0 length string - conditionSelectionGivenColumn(colLengths, Option(isNullCheck), replaceWith = 0.0).cast(DoubleType) - case _ => - colLengths + val colLength = length(col(column)).cast(DoubleType) + val updatedColumn = getNullBehavior match { + case NullBehavior.Fail => when(isNullCheck, Double.MaxValue).otherwise(colLength) + // Empty String is 0 length string + case NullBehavior.EmptyString => when(isNullCheck, lit(0.0)).otherwise(colLength) + case NullBehavior.Ignore => colLength } + + conditionalSelectionWithAugmentedOutcome(updatedColumn, where) } private def getNullBehavior: NullBehavior = { diff --git a/src/main/scala/com/amazon/deequ/analyzers/Maximum.scala b/src/main/scala/com/amazon/deequ/analyzers/Maximum.scala index abeee6d9..1e52a7ae 100644 --- a/src/main/scala/com/amazon/deequ/analyzers/Maximum.scala +++ b/src/main/scala/com/amazon/deequ/analyzers/Maximum.scala @@ -16,13 +16,18 @@ package com.amazon.deequ.analyzers -import com.amazon.deequ.analyzers.Preconditions.{hasColumn, isNumeric} -import org.apache.spark.sql.{Column, Row} -import org.apache.spark.sql.functions.{col, element_at, max} -import org.apache.spark.sql.types.{DoubleType, StructType} -import Analyzers._ +import com.amazon.deequ.analyzers.Analyzers._ +import com.amazon.deequ.analyzers.Preconditions.hasColumn +import com.amazon.deequ.analyzers.Preconditions.isNumeric import com.amazon.deequ.metrics.FullColumn import com.google.common.annotations.VisibleForTesting +import org.apache.spark.sql.functions.col +import org.apache.spark.sql.functions.element_at +import org.apache.spark.sql.functions.max +import org.apache.spark.sql.types.DoubleType +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.Column +import org.apache.spark.sql.Row case class MaxState(maxValue: Double, override val fullColumn: Option[Column] = None) extends DoubleValuedState[MaxState] with FullColumn { @@ -41,6 +46,9 @@ case class Maximum(column: String, where: Option[String] = None, analyzerOptions with FilterableAnalyzer { override def aggregationFunctions(): Seq[Column] = { + // The criterion returns a column where each row 
contains an array of 2 elements. + // The first element of the array is a string that indicates if the row is "in scope" or "filtered" out. + // The second element is the value used for calculating the metric. We use "element_at" to extract it. max(element_at(criterion, 2).cast(DoubleType)) :: Nil } @@ -57,5 +65,5 @@ case class Maximum(column: String, where: Option[String] = None, analyzerOptions override def filterCondition: Option[String] = where @VisibleForTesting - private def criterion: Column = conditionalSelectionWithAugmentedOutcome(col(column), where, Double.MinValue) + private def criterion: Column = conditionalSelectionWithAugmentedOutcome(col(column), where) } diff --git a/src/main/scala/com/amazon/deequ/analyzers/MinLength.scala b/src/main/scala/com/amazon/deequ/analyzers/MinLength.scala index a6627d2d..ddc4497b 100644 --- a/src/main/scala/com/amazon/deequ/analyzers/MinLength.scala +++ b/src/main/scala/com/amazon/deequ/analyzers/MinLength.scala @@ -23,10 +23,11 @@ import com.amazon.deequ.analyzers.Preconditions.isString import org.apache.spark.sql.Column import org.apache.spark.sql.Row import org.apache.spark.sql.functions.col -import org.apache.spark.sql.functions.expr +import org.apache.spark.sql.functions.element_at import org.apache.spark.sql.functions.length +import org.apache.spark.sql.functions.lit import org.apache.spark.sql.functions.min -import org.apache.spark.sql.functions.not +import org.apache.spark.sql.functions.when import org.apache.spark.sql.types.DoubleType import org.apache.spark.sql.types.StructType @@ -35,12 +36,15 @@ case class MinLength(column: String, where: Option[String] = None, analyzerOptio with FilterableAnalyzer { override def aggregationFunctions(): Seq[Column] = { - min(criterion) :: Nil + // The criterion returns a column where each row contains an array of 2 elements. + // The first element of the array is a string that indicates if the row is "in scope" or "filtered" out. + // The second element is the value used for calculating the metric. We use "element_at" to extract it. 
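+    // As in MaxLength, rows excluded by the where clause carry a null second element
+    // (["FilteredData", null]), so they do not contribute to the min aggregate below.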
+ min(element_at(criterion, 2).cast(DoubleType)) :: Nil } override def fromAggregationResult(result: Row, offset: Int): Option[MinState] = { ifNoNullsIn(result, offset) { _ => - MinState(result.getDouble(offset), Some(rowLevelResults)) + MinState(result.getDouble(offset), Some(criterion)) } } @@ -51,35 +55,16 @@ case class MinLength(column: String, where: Option[String] = None, analyzerOptio override def filterCondition: Option[String] = where private[deequ] def criterion: Column = { - transformColForNullBehavior - } - - private[deequ] def rowLevelResults: Column = { - transformColForFilteredRow(criterion) - } - - private def transformColForFilteredRow(col: Column): Column = { - val whereNotCondition = where.map { expression => not(expr(expression)) } - getRowLevelFilterTreatment(analyzerOptions) match { - case FilteredRowOutcome.TRUE => - conditionSelectionGivenColumn(col, whereNotCondition, replaceWith = Double.MaxValue) - case _ => - conditionSelectionGivenColumn(col, whereNotCondition, replaceWith = null) - } - } - - private def transformColForNullBehavior: Column = { val isNullCheck = col(column).isNull - val colLengths: Column = length(conditionalSelection(column, where)).cast(DoubleType) - getNullBehavior match { - case NullBehavior.Fail => - conditionSelectionGivenColumn(colLengths, Option(isNullCheck), replaceWith = Double.MinValue) - case NullBehavior.EmptyString => - // Empty String is 0 length string - conditionSelectionGivenColumn(colLengths, Option(isNullCheck), replaceWith = 0.0).cast(DoubleType) - case _ => - colLengths + val colLength = length(col(column)).cast(DoubleType) + val updatedColumn = getNullBehavior match { + case NullBehavior.Fail => when(isNullCheck, Double.MinValue).otherwise(colLength) + // Empty String is 0 length string + case NullBehavior.EmptyString => when(isNullCheck, lit(0.0)).otherwise(colLength) + case NullBehavior.Ignore => colLength } + + conditionalSelectionWithAugmentedOutcome(updatedColumn, where) } private def getNullBehavior: NullBehavior = { diff --git a/src/main/scala/com/amazon/deequ/analyzers/Minimum.scala b/src/main/scala/com/amazon/deequ/analyzers/Minimum.scala index b17507fc..701ae0f0 100644 --- a/src/main/scala/com/amazon/deequ/analyzers/Minimum.scala +++ b/src/main/scala/com/amazon/deequ/analyzers/Minimum.scala @@ -16,13 +16,18 @@ package com.amazon.deequ.analyzers -import com.amazon.deequ.analyzers.Preconditions.{hasColumn, isNumeric} -import org.apache.spark.sql.{Column, Row} -import org.apache.spark.sql.functions.{col, element_at, min} -import org.apache.spark.sql.types.{DoubleType, StructType} -import Analyzers._ +import com.amazon.deequ.analyzers.Analyzers._ +import com.amazon.deequ.analyzers.Preconditions.hasColumn +import com.amazon.deequ.analyzers.Preconditions.isNumeric import com.amazon.deequ.metrics.FullColumn import com.google.common.annotations.VisibleForTesting +import org.apache.spark.sql.functions.col +import org.apache.spark.sql.functions.element_at +import org.apache.spark.sql.functions.min +import org.apache.spark.sql.types.DoubleType +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.Column +import org.apache.spark.sql.Row case class MinState(minValue: Double, override val fullColumn: Option[Column] = None) extends DoubleValuedState[MinState] with FullColumn { @@ -41,6 +46,9 @@ case class Minimum(column: String, where: Option[String] = None, analyzerOptions with FilterableAnalyzer { override def aggregationFunctions(): Seq[Column] = { + // The criterion returns a column where each row 
contains an array of 2 elements. + // The first element of the array is a string that indicates if the row is "in scope" or "filtered" out. + // The second element is the value used for calculating the metric. We use "element_at" to extract it. min(element_at(criterion, 2).cast(DoubleType)) :: Nil } @@ -57,5 +65,5 @@ case class Minimum(column: String, where: Option[String] = None, analyzerOptions override def filterCondition: Option[String] = where @VisibleForTesting - private def criterion: Column = conditionalSelectionWithAugmentedOutcome(col(column), where, Double.MaxValue) + private def criterion: Column = conditionalSelectionWithAugmentedOutcome(col(column), where) } diff --git a/src/main/scala/com/amazon/deequ/constraints/Constraint.scala b/src/main/scala/com/amazon/deequ/constraints/Constraint.scala index a28b6f2e..8df88165 100644 --- a/src/main/scala/com/amazon/deequ/constraints/Constraint.scala +++ b/src/main/scala/com/amazon/deequ/constraints/Constraint.scala @@ -558,7 +558,8 @@ object Constraint { val constraint = AnalysisBasedConstraint[MaxState, Double, Double](maxLength, assertion, hint = hint) - val sparkAssertion = org.apache.spark.sql.functions.udf(assertion) + val updatedAssertion = getUpdatedRowLevelAssertionForLengthConstraint(assertion, maxLength.analyzerOptions) + val sparkAssertion = org.apache.spark.sql.functions.udf(updatedAssertion) new RowLevelAssertedConstraint( constraint, @@ -593,7 +594,8 @@ object Constraint { val constraint = AnalysisBasedConstraint[MinState, Double, Double](minLength, assertion, hint = hint) - val sparkAssertion = org.apache.spark.sql.functions.udf(assertion) + val updatedAssertion = getUpdatedRowLevelAssertionForLengthConstraint(assertion, minLength.analyzerOptions) + val sparkAssertion = org.apache.spark.sql.functions.udf(updatedAssertion) new RowLevelAssertedConstraint( constraint, @@ -953,26 +955,57 @@ object Constraint { } } - def filteredRowOutcome: java.lang.Boolean = { - analyzerOptions match { - case Some(opts) => - opts.filteredRow match { - case FilteredRowOutcome.TRUE => true - case FilteredRowOutcome.NULL => null - } - // https://github.com/awslabs/deequ/issues/530 - // Filtered rows should be marked as true by default. - // They can be set to null using the FilteredRowOutcome option. - case None => true + scope match { + case FilteredData.name => filteredRowOutcome(analyzerOptions) + case InScopeData.name => inScopeRowOutcome(value) + } + } + } + + private[this] def getUpdatedRowLevelAssertionForLengthConstraint(assertion: Double => Boolean, + analyzerOptions: Option[AnalyzerOptions]) + : Seq[String] => java.lang.Boolean = { + (d: Seq[String]) => { + val (scope, value) = (d.head, Option(d.last).map(_.toDouble)) + + def inScopeRowOutcome(value: Option[Double]): java.lang.Boolean = { + if (value.isDefined) { + // If value is defined, run it through the assertion. + assertion(value.get) + } else { + // If value is not defined (value is null), apply NullBehavior. 
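+          // When no AnalyzerOptions are supplied, this falls back to the Ignore behavior
+          // below, i.e. the row-level outcome stays null.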
+ analyzerOptions match { + case Some(opts) => + opts.nullBehavior match { + case NullBehavior.EmptyString => assertion(0.0) + case NullBehavior.Fail => false + case NullBehavior.Ignore => null + } + case None => null + } } } scope match { - case FilteredData.name => filteredRowOutcome + case FilteredData.name => filteredRowOutcome(analyzerOptions) case InScopeData.name => inScopeRowOutcome(value) } } } + + private def filteredRowOutcome(analyzerOptions: Option[AnalyzerOptions]): java.lang.Boolean = { + analyzerOptions match { + case Some(opts) => + opts.filteredRow match { + case FilteredRowOutcome.TRUE => true + case FilteredRowOutcome.NULL => null + } + // https://github.com/awslabs/deequ/issues/530 + // Filtered rows should be marked as true by default. + // They can be set to null using the FilteredRowOutcome option. + case None => true + } + } } /** diff --git a/src/test/scala/com/amazon/deequ/VerificationSuiteTest.scala b/src/test/scala/com/amazon/deequ/VerificationSuiteTest.scala index f7684e96..587f8bf5 100644 --- a/src/test/scala/com/amazon/deequ/VerificationSuiteTest.scala +++ b/src/test/scala/com/amazon/deequ/VerificationSuiteTest.scala @@ -1759,7 +1759,6 @@ class VerificationSuiteTest extends WordSpec with Matchers with SparkContextSpec .addChecks(Seq(mkEqualityCheck(analyzerOptions))) .run() - val passResult = verificationResult.checkResults assertCheckResults(verificationResult, CheckStatus.Error) val rowLevelResultsDF = VerificationResult.rowLevelResultsAsDataFrame(sparkSession, verificationResult, df) @@ -1778,7 +1777,6 @@ class VerificationSuiteTest extends WordSpec with Matchers with SparkContextSpec .addChecks(Seq(mkEqualityCheck(analyzerOptions))) .run() - val passResult = verificationResult.checkResults assertCheckResults(verificationResult, CheckStatus.Error) val rowLevelResultsDF = VerificationResult.rowLevelResultsAsDataFrame(sparkSession, verificationResult, df) @@ -1789,6 +1787,212 @@ class VerificationSuiteTest extends WordSpec with Matchers with SparkContextSpec } } + "Verification Suite with ==/!= based MinLength/MaxLength checks and filtered row behavior" should { + val col1 = "Company" + val col2 = "ZipCode" + val col3 = "City" + + val check1Description = "length-equality-check-1" + val check2Description = "length-equality-check-2" + val check3Description = "length-equality-check-3" + + val check1WhereClause = "ID > 2" + val check2WhereClause = "ID in (1, 2, 3)" + val check3WhereClause = "ID <= 2" + + def mkLengthCheck1(analyzerOptions: AnalyzerOptions): Check = new Check(CheckLevel.Error, check1Description) + .hasMinLength(col1, _ == 8, analyzerOptions = Some(analyzerOptions)).where(check1WhereClause) + .hasMaxLength(col1, _ == 8, analyzerOptions = Some(analyzerOptions)).where(check1WhereClause) + + def mkLengthCheck2(analyzerOptions: AnalyzerOptions): Check = new Check(CheckLevel.Error, check2Description) + .hasMinLength(col2, _ == 4, analyzerOptions = Some(analyzerOptions)).where(check2WhereClause) + .hasMaxLength(col2, _ == 4, analyzerOptions = Some(analyzerOptions)).where(check2WhereClause) + + def mkLengthCheck3(analyzerOptions: AnalyzerOptions): Check = new Check(CheckLevel.Error, check3Description) + .hasMinLength(col3, _ != 0, analyzerOptions = Some(analyzerOptions)).where(check3WhereClause) + .hasMaxLength(col3, _ != 0, analyzerOptions = Some(analyzerOptions)).where(check3WhereClause) + + def getRowLevelResults(df: DataFrame): Seq[java.lang.Boolean] = + df.collect().map { r => r.getAs[java.lang.Boolean](0) }.toSeq + + def 
assertCheckResults(verificationResult: VerificationResult): Unit = { + val passResult = verificationResult.checkResults + + val equalityCheck1Result = passResult.values.find(_.check.description == check1Description) + val equalityCheck2Result = passResult.values.find(_.check.description == check2Description) + val equalityCheck3Result = passResult.values.find(_.check.description == check3Description) + + assert(equalityCheck1Result.isDefined && equalityCheck1Result.get.status == CheckStatus.Success) + assert(equalityCheck2Result.isDefined && equalityCheck2Result.get.status == CheckStatus.Error) + assert(equalityCheck3Result.isDefined && equalityCheck3Result.get.status == CheckStatus.Success) + } + + def assertRowLevelResults(rowLevelResults: DataFrame, + analyzerOptions: AnalyzerOptions): Unit = { + val equalityCheck1Results = getRowLevelResults(rowLevelResults.select(check1Description)) + val equalityCheck2Results = getRowLevelResults(rowLevelResults.select(check2Description)) + val equalityCheck3Results = getRowLevelResults(rowLevelResults.select(check3Description)) + + val filteredOutcome: java.lang.Boolean = analyzerOptions.filteredRow match { + case FilteredRowOutcome.TRUE => true + case FilteredRowOutcome.NULL => null + } + + assert(equalityCheck1Results == Seq(filteredOutcome, filteredOutcome, true, true)) + assert(equalityCheck2Results == Seq(false, false, false, filteredOutcome)) + assert(equalityCheck3Results == Seq(true, true, filteredOutcome, filteredOutcome)) + } + + def assertMetrics(metricsDF: DataFrame): Unit = { + val metricsMap = getMetricsAsMap(metricsDF) + assert(metricsMap(s"$col1|MinLength (where: $check1WhereClause)") == 8.0) + assert(metricsMap(s"$col1|MaxLength (where: $check1WhereClause)") == 8.0) + assert(metricsMap(s"$col2|MinLength (where: $check2WhereClause)") == 0.0) + assert(metricsMap(s"$col2|MaxLength (where: $check2WhereClause)") == 5.0) + assert(metricsMap(s"$col3|MinLength (where: $check3WhereClause)") == 11.0) + assert(metricsMap(s"$col3|MaxLength (where: $check3WhereClause)") == 11.0) + } + + "mark filtered rows as null" in withSparkSession { + sparkSession => + val df = getDfForWhereClause(sparkSession) + val analyzerOptions = AnalyzerOptions( + nullBehavior = NullBehavior.EmptyString, filteredRow = FilteredRowOutcome.NULL + ) + + val equalityCheck1 = mkLengthCheck1(analyzerOptions) + val equalityCheck2 = mkLengthCheck2(analyzerOptions) + val equalityCheck3 = mkLengthCheck3(analyzerOptions) + + val verificationResult = VerificationSuite() + .onData(df) + .addChecks(Seq(equalityCheck1, equalityCheck2, equalityCheck3)) + .run() + + val rowLevelResultsDF = VerificationResult.rowLevelResultsAsDataFrame(sparkSession, verificationResult, df) + val metricsDF = VerificationResult.successMetricsAsDataFrame(sparkSession, verificationResult) + + assertCheckResults(verificationResult) + assertRowLevelResults(rowLevelResultsDF, analyzerOptions) + assertMetrics(metricsDF) + } + + "mark filtered rows as true" in withSparkSession { + sparkSession => + val df = getDfForWhereClause(sparkSession) + val analyzerOptions = AnalyzerOptions( + nullBehavior = NullBehavior.EmptyString, filteredRow = FilteredRowOutcome.TRUE + ) + + val equalityCheck1 = mkLengthCheck1(analyzerOptions) + val equalityCheck2 = mkLengthCheck2(analyzerOptions) + val equalityCheck3 = mkLengthCheck3(analyzerOptions) + + val verificationResult = VerificationSuite() + .onData(df) + .addChecks(Seq(equalityCheck1, equalityCheck2, equalityCheck3)) + .run() + + val rowLevelResultsDF = 
VerificationResult.rowLevelResultsAsDataFrame(sparkSession, verificationResult, df) + val metricsDF = VerificationResult.successMetricsAsDataFrame(sparkSession, verificationResult) + + assertCheckResults(verificationResult) + assertRowLevelResults(rowLevelResultsDF, analyzerOptions) + assertMetrics(metricsDF) + } + } + + "Verification Suite with ==/!= based MinLength/MaxLength checks and null row behavior" should { + val col = "City" + val checkDescription = "length-check" + val assertion = (d: Double) => d >= 0.0 && d <= 8.0 + + def mkLengthCheck(analyzerOptions: AnalyzerOptions): Check = new Check(CheckLevel.Error, checkDescription) + .hasMinLength(col, assertion, analyzerOptions = Some(analyzerOptions)) + .hasMaxLength(col, assertion, analyzerOptions = Some(analyzerOptions)) + + def assertCheckResults(verificationResult: VerificationResult, checkStatus: CheckStatus.Value): Unit = { + val passResult = verificationResult.checkResults + val equalityCheckResult = passResult.values.find(_.check.description == checkDescription) + assert(equalityCheckResult.isDefined && equalityCheckResult.get.status == checkStatus) + } + + def getRowLevelResults(df: DataFrame): Seq[java.lang.Boolean] = + df.collect().map { r => r.getAs[java.lang.Boolean](0) }.toSeq + + def assertRowLevelResults(rowLevelResults: DataFrame, + analyzerOptions: AnalyzerOptions): Unit = { + val equalityCheckResults = getRowLevelResults(rowLevelResults.select(checkDescription)) + val nullOutcome: java.lang.Boolean = analyzerOptions.nullBehavior match { + case NullBehavior.Fail => false + case NullBehavior.Ignore => null + case NullBehavior.EmptyString => true + } + + assert(equalityCheckResults == Seq(false, false, nullOutcome, true)) + } + + def assertMetrics(metricsDF: DataFrame, minLength: Double, maxLength: Double): Unit = { + val metricsMap = getMetricsAsMap(metricsDF) + assert(metricsMap(s"$col|MinLength") == minLength) + assert(metricsMap(s"$col|MaxLength") == maxLength) + } + + "keep non-filtered null rows as null" in withSparkSession { + sparkSession => + val df = getDfForWhereClause(sparkSession) + val analyzerOptions = AnalyzerOptions(nullBehavior = NullBehavior.Ignore) + val verificationResult = VerificationSuite() + .onData(df) + .addCheck(mkLengthCheck(analyzerOptions)) + .run() + + assertCheckResults(verificationResult, CheckStatus.Error) + + val rowLevelResultsDF = VerificationResult.rowLevelResultsAsDataFrame(sparkSession, verificationResult, df) + assertRowLevelResults(rowLevelResultsDF, analyzerOptions) + + val metricsDF = VerificationResult.successMetricsAsDataFrame(sparkSession, verificationResult) + assertMetrics(metricsDF, 8.0, 11.0) + } + + "mark non-filtered null rows as false" in withSparkSession { + sparkSession => + val df = getDfForWhereClause(sparkSession) + val analyzerOptions = AnalyzerOptions(nullBehavior = NullBehavior.Fail) + val verificationResult = VerificationSuite() + .onData(df) + .addCheck(mkLengthCheck(analyzerOptions)) + .run() + + assertCheckResults(verificationResult, CheckStatus.Error) + + val rowLevelResultsDF = VerificationResult.rowLevelResultsAsDataFrame(sparkSession, verificationResult, df) + assertRowLevelResults(rowLevelResultsDF, analyzerOptions) + + val metricsDF = VerificationResult.successMetricsAsDataFrame(sparkSession, verificationResult) + assertMetrics(metricsDF, Double.MinValue, Double.MaxValue) + } + + "mark non-filtered null rows as empty string" in withSparkSession { + sparkSession => + val df = getDfForWhereClause(sparkSession) + val analyzerOptions = 
AnalyzerOptions(nullBehavior = NullBehavior.EmptyString) + val verificationResult = VerificationSuite() + .onData(df) + .addCheck(mkLengthCheck(analyzerOptions)) + .run() + + assertCheckResults(verificationResult, CheckStatus.Error) + + val rowLevelResultsDF = VerificationResult.rowLevelResultsAsDataFrame(sparkSession, verificationResult, df) + assertRowLevelResults(rowLevelResultsDF, analyzerOptions) + + val metricsDF = VerificationResult.successMetricsAsDataFrame(sparkSession, verificationResult) + assertMetrics(metricsDF, 0.0, 11.0) + } + } + /** Run anomaly detection using a repository with some previous analysis results for testing */ private[this] def evaluateWithRepositoryWithHistory(test: MetricsRepository => Unit): Unit = { diff --git a/src/test/scala/com/amazon/deequ/analyzers/MaxLengthTest.scala b/src/test/scala/com/amazon/deequ/analyzers/MaxLengthTest.scala index 456bd4c6..fd302a4d 100644 --- a/src/test/scala/com/amazon/deequ/analyzers/MaxLengthTest.scala +++ b/src/test/scala/com/amazon/deequ/analyzers/MaxLengthTest.scala @@ -13,145 +13,73 @@ * permissions and limitations under the License. * */ + package com.amazon.deequ.analyzers import com.amazon.deequ.SparkContextSpec import com.amazon.deequ.metrics.DoubleMetric import com.amazon.deequ.metrics.FullColumn import com.amazon.deequ.utils.FixtureSupport +import org.apache.spark.sql.Column +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.Row +import org.apache.spark.sql.functions.element_at +import org.apache.spark.sql.types.DoubleType import org.scalatest.matchers.should.Matchers import org.scalatest.wordspec.AnyWordSpec class MaxLengthTest extends AnyWordSpec with Matchers with SparkContextSpec with FixtureSupport { + private val tempColName = "new" + + private def getValuesDF(df: DataFrame, outcomeColumn: Column): Seq[Row] = { + df.withColumn(tempColName, element_at(outcomeColumn, 2).cast(DoubleType)).collect() + } "MaxLength" should { "return row-level results for non-null columns" in withSparkSession { session => - val data = getDfWithStringColumns(session) val countryLength = MaxLength("Country") // It's "India" in every row val state: Option[MaxState] = countryLength.computeStateFrom(data) val metric: DoubleMetric with FullColumn = countryLength.computeMetricFrom(state) - data.withColumn("new", metric.fullColumn.get) - .collect().map(_.getAs[Double]("new")) shouldBe Seq(5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0) + val values = getValuesDF(data, metric.fullColumn.get).map(_.getAs[Double](tempColName)) + values shouldBe Seq(5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0) } "return row-level results for null columns" in withSparkSession { session => - val data = getEmptyColumnDataDf(session) val addressLength = MaxLength("att3") // It's null in two rows val state: Option[MaxState] = addressLength.computeStateFrom(data) val metric: DoubleMetric with FullColumn = addressLength.computeMetricFrom(state) - data.withColumn("new", metric.fullColumn.get) - .collect().map(_.getAs[Double]("new")) shouldBe Seq(1.0, 1.0, 0.0, 1.0, 0.0, 1.0) + val values = getValuesDF(data, metric.fullColumn.get) + .map(r => if (r == null) null else r.getAs[Double](tempColName)) + values shouldBe Seq(1.0, 1.0, null, 1.0, null, 1.0) } "return row-level results for null columns with NullBehavior fail option" in withSparkSession { session => - val data = getEmptyColumnDataDf(session) // It's null in two rows - val addressLength = MaxLength("att3") + val addressLength = MaxLength("att3", None, Option(AnalyzerOptions(NullBehavior.Fail))) val state: 
Option[MaxState] = addressLength.computeStateFrom(data) val metric: DoubleMetric with FullColumn = addressLength.computeMetricFrom(state) - data.withColumn("new", metric.fullColumn.get) - .collect().map(r => if (r == null) null else r.getAs[Double]("new")) shouldBe - Seq(1.0, 1.0, null, 1.0, null, 1.0) + val values = getValuesDF(data, metric.fullColumn.get).map(_.getAs[Double](tempColName)) + values shouldBe Seq(1.0, 1.0, Double.MaxValue, 1.0, Double.MaxValue, 1.0) } "return row-level results for blank strings" in withSparkSession { session => - val data = getEmptyColumnDataDf(session) val addressLength = MaxLength("att1") // It's empty strings val state: Option[MaxState] = addressLength.computeStateFrom(data) val metric: DoubleMetric with FullColumn = addressLength.computeMetricFrom(state) - data.withColumn("new", metric.fullColumn.get) - .collect().map(_.getAs[Double]("new")) shouldBe Seq(0.0, 0.0, 0.0, 0.0, 0.0, 0.0) - } - - "return row-level results with NullBehavior fail and filtered as true" in withSparkSession { session => - - val data = getEmptyColumnDataDf(session) - - val addressLength = MaxLength("att3", Option("id < 4"), Option(AnalyzerOptions(NullBehavior.Fail))) - val state: Option[MaxState] = addressLength.computeStateFrom(data) - val metric: DoubleMetric with FullColumn = addressLength.computeMetricFrom(state) - - data.withColumn("new", metric.fullColumn.get) - .collect().map(_.getAs[Any]("new")) shouldBe - Seq(1.0, 1.0, Double.MaxValue, 1.0, Double.MinValue, Double.MinValue) - } - - "return row-level results with NullBehavior fail and filtered as null" in withSparkSession { session => - - val data = getEmptyColumnDataDf(session) - - val addressLength = MaxLength("att3", Option("id < 4"), - Option(AnalyzerOptions(NullBehavior.Fail, FilteredRowOutcome.NULL))) - val state: Option[MaxState] = addressLength.computeStateFrom(data) - val metric: DoubleMetric with FullColumn = addressLength.computeMetricFrom(state) - - data.withColumn("new", metric.fullColumn.get) - .collect().map(_.getAs[Any]("new")) shouldBe Seq(1.0, 1.0, Double.MaxValue, 1.0, null, null) - } - - "return row-level results with NullBehavior empty and filtered as true" in withSparkSession { session => - - val data = getEmptyColumnDataDf(session) - - val addressLength = MaxLength("att3", Option("id < 4"), Option(AnalyzerOptions(NullBehavior.EmptyString))) - val state: Option[MaxState] = addressLength.computeStateFrom(data) - val metric: DoubleMetric with FullColumn = addressLength.computeMetricFrom(state) - - data.withColumn("new", metric.fullColumn.get) - .collect().map(_.getAs[Any]("new")) shouldBe - Seq(1.0, 1.0, 0.0, 1.0, Double.MinValue, Double.MinValue) - } - - "return row-level results with NullBehavior empty and filtered as null" in withSparkSession { session => - - val data = getEmptyColumnDataDf(session) - - val addressLength = MaxLength("att3", Option("id < 4"), - Option(AnalyzerOptions(NullBehavior.EmptyString, FilteredRowOutcome.NULL))) - val state: Option[MaxState] = addressLength.computeStateFrom(data) - val metric: DoubleMetric with FullColumn = addressLength.computeMetricFrom(state) - - data.withColumn("new", metric.fullColumn.get) - .collect().map(_.getAs[Any]("new")) shouldBe Seq(1.0, 1.0, 0.0, 1.0, null, null) - } - - "return row-level results with NullBehavior ignore and filtered as true" in withSparkSession { session => - - val data = getEmptyColumnDataDf(session) - - val addressLength = MaxLength("att3", Option("id < 4"), Option(AnalyzerOptions(NullBehavior.Ignore))) - val state: 
Option[MaxState] = addressLength.computeStateFrom(data) - val metric: DoubleMetric with FullColumn = addressLength.computeMetricFrom(state) - - data.withColumn("new", metric.fullColumn.get) - .collect().map(_.getAs[Any]("new")) shouldBe - Seq(1.0, 1.0, null, 1.0, Double.MinValue, Double.MinValue) - } - - "return row-level results with NullBehavior ignore and filtered as null" in withSparkSession { session => - - val data = getEmptyColumnDataDf(session) - - val addressLength = MaxLength("att3", Option("id < 4"), - Option(AnalyzerOptions(NullBehavior.Ignore, FilteredRowOutcome.NULL))) - val state: Option[MaxState] = addressLength.computeStateFrom(data) - val metric: DoubleMetric with FullColumn = addressLength.computeMetricFrom(state) - - data.withColumn("new", metric.fullColumn.get) - .collect().map(_.getAs[Any]("new")) shouldBe Seq(1.0, 1.0, null, 1.0, null, null) + val values = getValuesDF(data, metric.fullColumn.get).map(_.getAs[Double](tempColName)) + values shouldBe Seq(0.0, 0.0, 0.0, 0.0, 0.0, 0.0) } } - } diff --git a/src/test/scala/com/amazon/deequ/analyzers/MaximumTest.scala b/src/test/scala/com/amazon/deequ/analyzers/MaximumTest.scala index 1d13a8df..983e6bca 100644 --- a/src/test/scala/com/amazon/deequ/analyzers/MaximumTest.scala +++ b/src/test/scala/com/amazon/deequ/analyzers/MaximumTest.scala @@ -14,7 +14,6 @@ * */ - package com.amazon.deequ.analyzers import com.amazon.deequ.SparkContextSpec @@ -50,7 +49,6 @@ class MaximumTest extends AnyWordSpec with Matchers with SparkContextSpec with F } "return row-level results for columns with null" in withSparkSession { session => - val data = getDfWithNumericValues(session) val att1Maximum = Maximum("attNull") diff --git a/src/test/scala/com/amazon/deequ/analyzers/MinLengthTest.scala b/src/test/scala/com/amazon/deequ/analyzers/MinLengthTest.scala index 0f88e377..23e99574 100644 --- a/src/test/scala/com/amazon/deequ/analyzers/MinLengthTest.scala +++ b/src/test/scala/com/amazon/deequ/analyzers/MinLengthTest.scala @@ -14,45 +14,52 @@ * */ - package com.amazon.deequ.analyzers import com.amazon.deequ.SparkContextSpec import com.amazon.deequ.metrics.DoubleMetric import com.amazon.deequ.metrics.FullColumn import com.amazon.deequ.utils.FixtureSupport +import org.apache.spark.sql.Column +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.Row +import org.apache.spark.sql.functions.element_at +import org.apache.spark.sql.types.DoubleType import org.scalatest.matchers.should.Matchers import org.scalatest.wordspec.AnyWordSpec class MinLengthTest extends AnyWordSpec with Matchers with SparkContextSpec with FixtureSupport { + private val tempColName = "new" + + private def getValuesDF(df: DataFrame, outcomeColumn: Column): Seq[Row] = { + df.withColumn(tempColName, element_at(outcomeColumn, 2).cast(DoubleType)).collect() + } "MinLength" should { "return row-level results for non-null columns" in withSparkSession { session => - val data = getDfWithStringColumns(session) val countryLength = MinLength("Country") // It's "India" in every row val state: Option[MinState] = countryLength.computeStateFrom(data) val metric: DoubleMetric with FullColumn = countryLength.computeMetricFrom(state) - data.withColumn("new", metric.fullColumn.get) - .collect().map(_.getAs[Double]("new")) shouldBe Seq(5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0) + val values = getValuesDF(data, metric.fullColumn.get).map(_.getAs[Double](tempColName)) + values shouldBe Seq(5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0) } "return row-level results for null columns" in withSparkSession 
{ session => - val data = getEmptyColumnDataDf(session) val addressLength = MinLength("att3") // It's null in two rows val state: Option[MinState] = addressLength.computeStateFrom(data) val metric: DoubleMetric with FullColumn = addressLength.computeMetricFrom(state) - data.withColumn("new", metric.fullColumn.get) - .collect().map(_.getAs[Double]("new")) shouldBe Seq(1.0, 1.0, 0.0, 1.0, 0.0, 1.0) + val values = getValuesDF(data, metric.fullColumn.get) + .map(r => if (r == null) null else r.getAs[Double](tempColName)) + values shouldBe Seq(1.0, 1.0, null, 1.0, null, 1.0) } "return row-level results for null columns with NullBehavior fail option" in withSparkSession { session => - val data = getEmptyColumnDataDf(session) // It's null in two rows @@ -60,13 +67,11 @@ class MinLengthTest extends AnyWordSpec with Matchers with SparkContextSpec with val state: Option[MinState] = addressLength.computeStateFrom(data) val metric: DoubleMetric with FullColumn = addressLength.computeMetricFrom(state) - data.withColumn("new", metric.fullColumn.get) - .collect().map(_.getAs[Double]("new") - ) shouldBe Seq(1.0, 1.0, Double.MinValue, 1.0, Double.MinValue, 1.0) + val values = getValuesDF(data, metric.fullColumn.get).map(_.getAs[Double](tempColName)) + values shouldBe Seq(1.0, 1.0, Double.MinValue, 1.0, Double.MinValue, 1.0) } "return row-level results for null columns with NullBehavior empty option" in withSparkSession { session => - val data = getEmptyColumnDataDf(session) // It's null in two rows @@ -74,99 +79,19 @@ class MinLengthTest extends AnyWordSpec with Matchers with SparkContextSpec with val state: Option[MinState] = addressLength.computeStateFrom(data) val metric: DoubleMetric with FullColumn = addressLength.computeMetricFrom(state) - data.withColumn("new", metric.fullColumn.get) - .collect().map(_.getAs[Double]("new")) shouldBe Seq(1.0, 1.0, 0.0, 1.0, 0.0, 1.0) + val values = getValuesDF(data, metric.fullColumn.get).map(_.getAs[Double](tempColName)) + values shouldBe Seq(1.0, 1.0, 0.0, 1.0, 0.0, 1.0) } "return row-level results for blank strings" in withSparkSession { session => - val data = getEmptyColumnDataDf(session) val addressLength = MinLength("att1") // It's empty strings val state: Option[MinState] = addressLength.computeStateFrom(data) val metric: DoubleMetric with FullColumn = addressLength.computeMetricFrom(state) - data.withColumn("new", metric.fullColumn.get) - .collect().map(_.getAs[Double]("new")) shouldBe Seq(0.0, 0.0, 0.0, 0.0, 0.0, 0.0) - } - - "return row-level results with NullBehavior fail and filtered as true" in withSparkSession { session => - - val data = getEmptyColumnDataDf(session) - - val addressLength = MinLength("att3", Option("id < 4"), Option(AnalyzerOptions(NullBehavior.Fail))) - val state: Option[MinState] = addressLength.computeStateFrom(data) - val metric: DoubleMetric with FullColumn = addressLength.computeMetricFrom(state) - - data.withColumn("new", metric.fullColumn.get) - .collect().map(_.getAs[Any]("new")) shouldBe - Seq(1.0, 1.0, Double.MinValue, 1.0, Double.MaxValue, Double.MaxValue) - } - - "return row-level results with NullBehavior fail and filtered as null" in withSparkSession { session => - - val data = getEmptyColumnDataDf(session) - - val addressLength = MinLength("att3", Option("id < 4"), - Option(AnalyzerOptions(NullBehavior.Fail, FilteredRowOutcome.NULL))) - val state: Option[MinState] = addressLength.computeStateFrom(data) - val metric: DoubleMetric with FullColumn = addressLength.computeMetricFrom(state) - - data.withColumn("new", 
metric.fullColumn.get) - .collect().map(_.getAs[Any]("new")) shouldBe - Seq(1.0, 1.0, Double.MinValue, 1.0, null, null) - } - - "return row-level results with NullBehavior empty and filtered as true" in withSparkSession { session => - - val data = getEmptyColumnDataDf(session) - - val addressLength = MinLength("att3", Option("id < 4"), Option(AnalyzerOptions(NullBehavior.EmptyString))) - val state: Option[MinState] = addressLength.computeStateFrom(data) - val metric: DoubleMetric with FullColumn = addressLength.computeMetricFrom(state) - - data.withColumn("new", metric.fullColumn.get) - .collect().map(_.getAs[Any]("new")) shouldBe - Seq(1.0, 1.0, 0.0, 1.0, Double.MaxValue, Double.MaxValue) - } - - "return row-level results with NullBehavior empty and filtered as null" in withSparkSession { session => - - val data = getEmptyColumnDataDf(session) - - val addressLength = MinLength("att3", Option("id < 4"), - Option(AnalyzerOptions(NullBehavior.EmptyString, FilteredRowOutcome.NULL))) - val state: Option[MinState] = addressLength.computeStateFrom(data) - val metric: DoubleMetric with FullColumn = addressLength.computeMetricFrom(state) - - data.withColumn("new", metric.fullColumn.get) - .collect().map(_.getAs[Any]("new")) shouldBe Seq(1.0, 1.0, 0.0, 1.0, null, null) - } - - "return row-level results NullBehavior ignore and filtered as true" in withSparkSession { session => - - val data = getEmptyColumnDataDf(session) - - val addressLength = MinLength("att3", Option("id < 4"), Option(AnalyzerOptions(NullBehavior.Ignore))) - val state: Option[MinState] = addressLength.computeStateFrom(data) - val metric: DoubleMetric with FullColumn = addressLength.computeMetricFrom(state) - - data.withColumn("new", metric.fullColumn.get) - .collect().map(_.getAs[Any]("new")) shouldBe - Seq(1.0, 1.0, null, 1.0, Double.MaxValue, Double.MaxValue) - } - - "return row-level results with NullBehavior ignore and filtered as null" in withSparkSession { session => - - val data = getEmptyColumnDataDf(session) - - val addressLength = MinLength("att3", Option("id < 4"), - Option(AnalyzerOptions(NullBehavior.Ignore, FilteredRowOutcome.NULL))) - val state: Option[MinState] = addressLength.computeStateFrom(data) - val metric: DoubleMetric with FullColumn = addressLength.computeMetricFrom(state) - - data.withColumn("new", metric.fullColumn.get) - .collect().map(_.getAs[Any]("new")) shouldBe Seq(1.0, 1.0, null, 1.0, null, null) + val values = getValuesDF(data, metric.fullColumn.get).map(_.getAs[Double](tempColName)) + values shouldBe Seq(0.0, 0.0, 0.0, 0.0, 0.0, 0.0) } } } diff --git a/src/test/scala/com/amazon/deequ/utils/FixtureSupport.scala b/src/test/scala/com/amazon/deequ/utils/FixtureSupport.scala index 5c56ed4b..3a0866d2 100644 --- a/src/test/scala/com/amazon/deequ/utils/FixtureSupport.scala +++ b/src/test/scala/com/amazon/deequ/utils/FixtureSupport.scala @@ -439,11 +439,11 @@ trait FixtureSupport { import sparkSession.implicits._ Seq( - ("Acme", "90210", "CA", "Los Angeles"), - ("Acme", "90211", "CA", "Los Angeles"), - ("Robocorp", null, "NJ", null), - ("Robocorp", null, "NY", "New York") - ).toDF("Company", "ZipCode", "State", "City") + (1, "Acme", "90210", "CA", "Los Angeles"), + (2, "Acme", "90211", "CA", "Los Angeles"), + (3, "Robocorp", null, "NJ", null), + (4, "Robocorp", null, "NY", "New York") + ).toDF("ID", "Company", "ZipCode", "State", "City") } def getDfCompleteAndInCompleteColumnsWithPeriod(sparkSession: SparkSession): DataFrame = {
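
A minimal, standalone sketch of how the augmented criterion array and the analyzer options combine into a row-level outcome, condensing the logic introduced in Analyzer.scala and Constraint.scala above. The object and method names are simplified stand-ins for the actual deequ types, not the implementation itself.

object RowLevelOutcomeSketch {
  sealed trait NullBehavior
  object NullBehavior {
    case object Ignore extends NullBehavior
    case object EmptyString extends NullBehavior
    case object Fail extends NullBehavior
  }

  sealed trait FilteredRowOutcome
  object FilteredRowOutcome {
    case object TRUE extends FilteredRowOutcome
    case object NULL extends FilteredRowOutcome
  }

  // scope is the first array element ("InScopeData" or "FilteredData");
  // value is the second element parsed to Double (None when it is null).
  def rowOutcome(scope: String,
                 value: Option[Double],
                 assertion: Double => Boolean,
                 nullBehavior: NullBehavior,
                 filteredRow: FilteredRowOutcome): java.lang.Boolean = scope match {
    case "FilteredData" =>
      // Filtered rows never consult the value; only the configured outcome matters.
      filteredRow match {
        case FilteredRowOutcome.TRUE => true
        case FilteredRowOutcome.NULL => null
      }
    case "InScopeData" =>
      value match {
        case Some(v) => assertion(v)                        // normal case: run the user assertion
        case None => nullBehavior match {                   // null column value: apply NullBehavior
          case NullBehavior.EmptyString => assertion(0.0)   // an empty string has length 0
          case NullBehavior.Fail        => false
          case NullBehavior.Ignore      => null
        }
      }
  }
}

With nullBehavior = EmptyString and the assertion _ == 4, for instance, an in-scope null ZipCode evaluates as assertion(0.0), i.e. false, which is what the new length-equality test expects for the third row.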