Commit

[MinLength/MaxLength] Apply filtered row behavior at the row level evaluation (#547)

* [MinLength/MaxLength] Apply filtered row behavior at the row level evaluation

- For certain scenarios, the filtered row behavior for MinLength and MaxLength was not working correctly.
- For example, when a single check used both a minLength and a maxLength constraint, each asserting == <value>, the row level outcome for filtered rows was false. This was because we replaced the value of filtered rows with Double.MaxValue for MinLength and Double.MinValue for MaxLength, and a value cannot equal both at the same time (see the sketch below).
- Updated the row level assertion logic for MinLength/MaxLength to match what was done for Min/Max.
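
For concreteness, here is a minimal sketch of the failing scenario, assuming deequ's Check API (hasMinLength, hasMaxLength, where); the column, filter, and data are illustrative only.

```scala
import com.amazon.deequ.VerificationSuite
import com.amazon.deequ.checks.{Check, CheckLevel}
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").appName("length-check-sketch").getOrCreate()
import spark.implicits._

// Hypothetical data: the "US" rows are in scope, the "EU" row is excluded by the where clause.
val df = Seq(("alice", "US"), ("brian", "US"), ("x", "EU")).toDF("name", "region")

val check = Check(CheckLevel.Error, "length bounds")
  // Both constraints assert equality against the same value on the same filtered scope.
  .hasMinLength("name", _ == 5.0).where("region = 'US'")
  .hasMaxLength("name", _ == 5.0).where("region = 'US'")

// Before this change, the filtered "EU" row had its length replaced with Double.MaxValue for
// MinLength and Double.MinValue for MaxLength, so its row level outcome could never be true
// for both constraints at once.
val result = VerificationSuite().onData(df).addCheck(check).run()
```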
rdsharma26 committed Apr 16, 2024
1 parent 6dc6b70 commit a5dc5bf
Showing 11 changed files with 367 additions and 290 deletions.
11 changes: 7 additions & 4 deletions src/main/scala/com/amazon/deequ/analyzers/Analyzer.scala
@@ -519,13 +519,16 @@ private[deequ] object Analyzers {
}

def conditionalSelectionWithAugmentedOutcome(selection: Column,
condition: Option[String],
replaceWith: Double): Column = {
condition: Option[String]): Column = {
val origSelection = array(lit(InScopeData.name).as("source"), selection.as("selection"))
val filteredSelection = array(lit(FilteredData.name).as("source"), lit(replaceWith).as("selection"))

// The 2nd value in the array is set to null, but it can be set to anything.
// The value is not used to evaluate the row level outcome for filtered rows (to true/null).
// That decision is made using the 1st value which is set to "FilteredData" here.
val filteredSelection = array(lit(FilteredData.name).as("source"), lit(null).as("selection"))

condition
.map { cond => when(not(expr(cond)), filteredSelection).otherwise(origSelection) }
.map { cond => when(expr(cond), origSelection).otherwise(filteredSelection) }
.getOrElse(origSelection)
}

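For intuition about the new mechanism: instead of substituting sentinel values, each row now carries a [source, value] pair, the aggregation reads only the value, and the source tag later drives the row level outcome of filtered rows. A standalone sketch of the same idea (not the deequ internals; the tags, column, and filter are illustrative):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{array, col, element_at, expr, length, lit, max, when}
import org.apache.spark.sql.types.DoubleType

val spark = SparkSession.builder().master("local[*]").appName("augmented-outcome-sketch").getOrCreate()
import spark.implicits._

val df = Seq(("alice", "US"), ("x", "EU")).toDF("name", "region")

// In-scope rows carry their length; filtered rows carry null, because only the source tag
// is used to decide their row level outcome (true or null) later on.
val inScope   = array(lit("InScopeData"), length(col("name")).cast(DoubleType))
val filtered  = array(lit("FilteredData"), lit(null))
val criterion = when(expr("region = 'US'"), inScope).otherwise(filtered)

// The metric aggregates only the second element; the filtered row contributes null and is ignored.
df.select(max(element_at(criterion, 2).cast(DoubleType)).as("maxLength")).show()
```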
47 changes: 16 additions & 31 deletions src/main/scala/com/amazon/deequ/analyzers/MaxLength.scala
@@ -23,10 +23,11 @@ import com.amazon.deequ.analyzers.Preconditions.isString
import org.apache.spark.sql.Column
import org.apache.spark.sql.Row
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.functions.expr
import org.apache.spark.sql.functions.element_at
import org.apache.spark.sql.functions.length
import org.apache.spark.sql.functions.lit
import org.apache.spark.sql.functions.max
import org.apache.spark.sql.functions.not
import org.apache.spark.sql.functions.when
import org.apache.spark.sql.types.DoubleType
import org.apache.spark.sql.types.StructType

@@ -35,12 +36,15 @@ case class MaxLength(column: String, where: Option[String] = None, analyzerOptions
with FilterableAnalyzer {

override def aggregationFunctions(): Seq[Column] = {
max(criterion) :: Nil
// The criterion returns a column where each row contains an array of 2 elements.
// The first element of the array is a string that indicates if the row is "in scope" or "filtered" out.
// The second element is the value used for calculating the metric. We use "element_at" to extract it.
max(element_at(criterion, 2).cast(DoubleType)) :: Nil
}

override def fromAggregationResult(result: Row, offset: Int): Option[MaxState] = {
ifNoNullsIn(result, offset) { _ =>
MaxState(result.getDouble(offset), Some(rowLevelResults))
MaxState(result.getDouble(offset), Some(criterion))
}
}

@@ -51,35 +55,16 @@ case class MaxLength(column: String, where: Option[String] = None, analyzerOptions
override def filterCondition: Option[String] = where

private[deequ] def criterion: Column = {
transformColForNullBehavior
}

private[deequ] def rowLevelResults: Column = {
transformColForFilteredRow(criterion)
}

private def transformColForFilteredRow(col: Column): Column = {
val whereNotCondition = where.map { expression => not(expr(expression)) }
getRowLevelFilterTreatment(analyzerOptions) match {
case FilteredRowOutcome.TRUE =>
conditionSelectionGivenColumn(col, whereNotCondition, replaceWith = Double.MinValue)
case _ =>
conditionSelectionGivenColumn(col, whereNotCondition, replaceWith = null)
}
}

private def transformColForNullBehavior: Column = {
val isNullCheck = col(column).isNull
val colLengths: Column = length(conditionalSelection(column, where)).cast(DoubleType)
getNullBehavior match {
case NullBehavior.Fail =>
conditionSelectionGivenColumn(colLengths, Option(isNullCheck), replaceWith = Double.MaxValue)
case NullBehavior.EmptyString =>
// Empty String is 0 length string
conditionSelectionGivenColumn(colLengths, Option(isNullCheck), replaceWith = 0.0).cast(DoubleType)
case _ =>
colLengths
val colLength = length(col(column)).cast(DoubleType)
val updatedColumn = getNullBehavior match {
case NullBehavior.Fail => when(isNullCheck, Double.MaxValue).otherwise(colLength)
// Empty String is 0 length string
case NullBehavior.EmptyString => when(isNullCheck, lit(0.0)).otherwise(colLength)
case NullBehavior.Ignore => colLength
}

conditionalSelectionWithAugmentedOutcome(updatedColumn, where)
}

private def getNullBehavior: NullBehavior = {
20 changes: 14 additions & 6 deletions src/main/scala/com/amazon/deequ/analyzers/Maximum.scala
@@ -16,13 +16,18 @@

package com.amazon.deequ.analyzers

import com.amazon.deequ.analyzers.Preconditions.{hasColumn, isNumeric}
import org.apache.spark.sql.{Column, Row}
import org.apache.spark.sql.functions.{col, element_at, max}
import org.apache.spark.sql.types.{DoubleType, StructType}
import Analyzers._
import com.amazon.deequ.analyzers.Analyzers._
import com.amazon.deequ.analyzers.Preconditions.hasColumn
import com.amazon.deequ.analyzers.Preconditions.isNumeric
import com.amazon.deequ.metrics.FullColumn
import com.google.common.annotations.VisibleForTesting
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.functions.element_at
import org.apache.spark.sql.functions.max
import org.apache.spark.sql.types.DoubleType
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.Column
import org.apache.spark.sql.Row

case class MaxState(maxValue: Double, override val fullColumn: Option[Column] = None)
extends DoubleValuedState[MaxState] with FullColumn {
@@ -41,6 +46,9 @@ case class Maximum(column: String, where: Option[String] = None, analyzerOptions
with FilterableAnalyzer {

override def aggregationFunctions(): Seq[Column] = {
// The criterion returns a column where each row contains an array of 2 elements.
// The first element of the array is a string that indicates if the row is "in scope" or "filtered" out.
// The second element is the value used for calculating the metric. We use "element_at" to extract it.
max(element_at(criterion, 2).cast(DoubleType)) :: Nil
}

@@ -57,5 +65,5 @@ case class Maximum(column: String, where: Option[String] = None, analyzerOptions
override def filterCondition: Option[String] = where

@VisibleForTesting
private def criterion: Column = conditionalSelectionWithAugmentedOutcome(col(column), where, Double.MinValue)
private def criterion: Column = conditionalSelectionWithAugmentedOutcome(col(column), where)
}
47 changes: 16 additions & 31 deletions src/main/scala/com/amazon/deequ/analyzers/MinLength.scala
@@ -23,10 +23,11 @@ import com.amazon.deequ.analyzers.Preconditions.isString
import org.apache.spark.sql.Column
import org.apache.spark.sql.Row
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.functions.expr
import org.apache.spark.sql.functions.element_at
import org.apache.spark.sql.functions.length
import org.apache.spark.sql.functions.lit
import org.apache.spark.sql.functions.min
import org.apache.spark.sql.functions.not
import org.apache.spark.sql.functions.when
import org.apache.spark.sql.types.DoubleType
import org.apache.spark.sql.types.StructType

@@ -35,12 +36,15 @@ case class MinLength(column: String, where: Option[String] = None, analyzerOptions
with FilterableAnalyzer {

override def aggregationFunctions(): Seq[Column] = {
min(criterion) :: Nil
// The criterion returns a column where each row contains an array of 2 elements.
// The first element of the array is a string that indicates if the row is "in scope" or "filtered" out.
// The second element is the value used for calculating the metric. We use "element_at" to extract it.
min(element_at(criterion, 2).cast(DoubleType)) :: Nil
}

override def fromAggregationResult(result: Row, offset: Int): Option[MinState] = {
ifNoNullsIn(result, offset) { _ =>
MinState(result.getDouble(offset), Some(rowLevelResults))
MinState(result.getDouble(offset), Some(criterion))
}
}

@@ -51,35 +55,16 @@ case class MinLength(column: String, where: Option[String] = None, analyzerOptions
override def filterCondition: Option[String] = where

private[deequ] def criterion: Column = {
transformColForNullBehavior
}

private[deequ] def rowLevelResults: Column = {
transformColForFilteredRow(criterion)
}

private def transformColForFilteredRow(col: Column): Column = {
val whereNotCondition = where.map { expression => not(expr(expression)) }
getRowLevelFilterTreatment(analyzerOptions) match {
case FilteredRowOutcome.TRUE =>
conditionSelectionGivenColumn(col, whereNotCondition, replaceWith = Double.MaxValue)
case _ =>
conditionSelectionGivenColumn(col, whereNotCondition, replaceWith = null)
}
}

private def transformColForNullBehavior: Column = {
val isNullCheck = col(column).isNull
val colLengths: Column = length(conditionalSelection(column, where)).cast(DoubleType)
getNullBehavior match {
case NullBehavior.Fail =>
conditionSelectionGivenColumn(colLengths, Option(isNullCheck), replaceWith = Double.MinValue)
case NullBehavior.EmptyString =>
// Empty String is 0 length string
conditionSelectionGivenColumn(colLengths, Option(isNullCheck), replaceWith = 0.0).cast(DoubleType)
case _ =>
colLengths
val colLength = length(col(column)).cast(DoubleType)
val updatedColumn = getNullBehavior match {
case NullBehavior.Fail => when(isNullCheck, Double.MinValue).otherwise(colLength)
// Empty String is 0 length string
case NullBehavior.EmptyString => when(isNullCheck, lit(0.0)).otherwise(colLength)
case NullBehavior.Ignore => colLength
}

conditionalSelectionWithAugmentedOutcome(updatedColumn, where)
}

private def getNullBehavior: NullBehavior = {
20 changes: 14 additions & 6 deletions src/main/scala/com/amazon/deequ/analyzers/Minimum.scala
@@ -16,13 +16,18 @@

package com.amazon.deequ.analyzers

import com.amazon.deequ.analyzers.Preconditions.{hasColumn, isNumeric}
import org.apache.spark.sql.{Column, Row}
import org.apache.spark.sql.functions.{col, element_at, min}
import org.apache.spark.sql.types.{DoubleType, StructType}
import Analyzers._
import com.amazon.deequ.analyzers.Analyzers._
import com.amazon.deequ.analyzers.Preconditions.hasColumn
import com.amazon.deequ.analyzers.Preconditions.isNumeric
import com.amazon.deequ.metrics.FullColumn
import com.google.common.annotations.VisibleForTesting
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.functions.element_at
import org.apache.spark.sql.functions.min
import org.apache.spark.sql.types.DoubleType
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.Column
import org.apache.spark.sql.Row

case class MinState(minValue: Double, override val fullColumn: Option[Column] = None)
extends DoubleValuedState[MinState] with FullColumn {
@@ -41,6 +46,9 @@ case class Minimum(column: String, where: Option[String] = None, analyzerOptions
with FilterableAnalyzer {

override def aggregationFunctions(): Seq[Column] = {
// The criterion returns a column where each row contains an array of 2 elements.
// The first element of the array is a string that indicates if the row is "in scope" or "filtered" out.
// The second element is the value used for calculating the metric. We use "element_at" to extract it.
min(element_at(criterion, 2).cast(DoubleType)) :: Nil
}

@@ -57,5 +65,5 @@ case class Minimum(column: String, where: Option[String] = None, analyzerOptions
override def filterCondition: Option[String] = where

@VisibleForTesting
private def criterion: Column = conditionalSelectionWithAugmentedOutcome(col(column), where, Double.MaxValue)
private def criterion: Column = conditionalSelectionWithAugmentedOutcome(col(column), where)
}
61 changes: 47 additions & 14 deletions src/main/scala/com/amazon/deequ/constraints/Constraint.scala
@@ -558,7 +558,8 @@ object Constraint {
val constraint = AnalysisBasedConstraint[MaxState, Double, Double](maxLength, assertion,
hint = hint)

val sparkAssertion = org.apache.spark.sql.functions.udf(assertion)
val updatedAssertion = getUpdatedRowLevelAssertionForLengthConstraint(assertion, maxLength.analyzerOptions)
val sparkAssertion = org.apache.spark.sql.functions.udf(updatedAssertion)

new RowLevelAssertedConstraint(
constraint,
@@ -593,7 +594,8 @@ object Constraint {
val constraint = AnalysisBasedConstraint[MinState, Double, Double](minLength, assertion,
hint = hint)

val sparkAssertion = org.apache.spark.sql.functions.udf(assertion)
val updatedAssertion = getUpdatedRowLevelAssertionForLengthConstraint(assertion, minLength.analyzerOptions)
val sparkAssertion = org.apache.spark.sql.functions.udf(updatedAssertion)

new RowLevelAssertedConstraint(
constraint,
@@ -953,26 +955,57 @@ object Constraint {
}
}

def filteredRowOutcome: java.lang.Boolean = {
analyzerOptions match {
case Some(opts) =>
opts.filteredRow match {
case FilteredRowOutcome.TRUE => true
case FilteredRowOutcome.NULL => null
}
// https://github.com/awslabs/deequ/issues/530
// Filtered rows should be marked as true by default.
// They can be set to null using the FilteredRowOutcome option.
case None => true
scope match {
case FilteredData.name => filteredRowOutcome(analyzerOptions)
case InScopeData.name => inScopeRowOutcome(value)
}
}
}

private[this] def getUpdatedRowLevelAssertionForLengthConstraint(assertion: Double => Boolean,
analyzerOptions: Option[AnalyzerOptions])
: Seq[String] => java.lang.Boolean = {
(d: Seq[String]) => {
val (scope, value) = (d.head, Option(d.last).map(_.toDouble))

def inScopeRowOutcome(value: Option[Double]): java.lang.Boolean = {
if (value.isDefined) {
// If value is defined, run it through the assertion.
assertion(value.get)
} else {
// If value is not defined (value is null), apply NullBehavior.
analyzerOptions match {
case Some(opts) =>
opts.nullBehavior match {
case NullBehavior.EmptyString => assertion(0.0)
case NullBehavior.Fail => false
case NullBehavior.Ignore => null
}
case None => null
}
}
}

scope match {
case FilteredData.name => filteredRowOutcome
case FilteredData.name => filteredRowOutcome(analyzerOptions)
case InScopeData.name => inScopeRowOutcome(value)
}
}
}

private def filteredRowOutcome(analyzerOptions: Option[AnalyzerOptions]): java.lang.Boolean = {
analyzerOptions match {
case Some(opts) =>
opts.filteredRow match {
case FilteredRowOutcome.TRUE => true
case FilteredRowOutcome.NULL => null
}
// https://github.com/awslabs/deequ/issues/530
// Filtered rows should be marked as true by default.
// They can be set to null using the FilteredRowOutcome option.
case None => true
}
}
}

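To close the loop, here is an illustrative resolution of those [source, value] pairs by the reworked row level assertion (a simplified sketch, not the private deequ method; it assumes the default options, i.e. FilteredRowOutcome.TRUE and NullBehavior.Ignore):

```scala
// The assertion supplied by the user, e.g. via hasMinLength/hasMaxLength.
val assertion: Double => Boolean = _ == 5.0

// First element: source tag; second element: the length rendered as a (possibly null) string.
val rowLevelOutcome: Seq[String] => java.lang.Boolean = { d =>
  val (scope, value) = (d.head, Option(d.last).map(_.toDouble))
  scope match {
    case "FilteredData" => true // default: filtered rows are marked as passing
    case "InScopeData"  => value.map(v => java.lang.Boolean.valueOf(assertion(v))).orNull // null length -> null outcome
  }
}

rowLevelOutcome(Seq("InScopeData", "5.0"))  // true:  in-scope row that satisfies the assertion
rowLevelOutcome(Seq("InScopeData", "3.0"))  // false: in-scope row that fails it
rowLevelOutcome(Seq("FilteredData", null))  // true:  filtered row keeps the default outcome
```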