From 5338fe46ea8f0da9762204deb6f63a04a0e1c68f Mon Sep 17 00:00:00 2001 From: Hubert Date: Fri, 20 Sep 2024 20:20:35 -0400 Subject: [PATCH 01/24] added more tests to the anomaly detection with extended results changes --- .../amazon/deequ/VerificationRunBuilder.scala | 46 ++- .../applicability/Applicability.scala | 13 +- .../AnomalyDetectionStrategy.scala | 14 + .../anomalydetection/AnomalyDetector.scala | 91 +++++- .../anomalydetection/BaseChangeStrategy.scala | 56 +++- .../BatchNormalStrategy.scala | 42 ++- .../ExtendedDetectionResult.scala | 145 +++++++++ .../OnlineNormalStrategy.scala | 35 +- .../SimpleThresholdStrategy.scala | 43 ++- .../seasonal/HoltWinters.scala | 73 ++++- .../scala/com/amazon/deequ/checks/Check.scala | 157 ++++++++- .../AnomalyExtendedResultsConstraint.scala | 132 ++++++++ .../amazon/deequ/constraints/Constraint.scala | 37 ++- ...yDetectionWithExtendedResultsExample.scala | 98 ++++++ .../amazon/deequ/VerificationSuiteTest.scala | 134 +++++++- .../AbsoluteChangeStrategyTest.scala | 168 +++++++++- .../AnomalyDetectorTest.scala | 112 ++++++- .../BatchNormalStrategyTest.scala | 141 +++++++- .../OnlineNormalStrategyTest.scala | 168 +++++++++- .../RateOfChangeStrategyTest.scala | 48 ++- .../RelativeRateOfChangeStrategyTest.scala | 164 +++++++++- .../SimpleThresholdStrategyTest.scala | 75 ++++- .../seasonal/HoltWintersTest.scala | 221 ++++++++++++- .../deequ/checks/ApplicabilityTest.scala | 28 +- .../com/amazon/deequ/checks/CheckTest.scala | 231 +++++++++++++- ...AnomalyExtendedResultsConstraintTest.scala | 300 ++++++++++++++++++ .../deequ/constraints/ConstraintUtils.scala | 15 +- .../deequ/constraints/ConstraintsTest.scala | 32 ++ ...itoryAnomalyDetectionIntegrationTest.scala | 83 ++++- 29 files changed, 2763 insertions(+), 139 deletions(-) create mode 100644 src/main/scala/com/amazon/deequ/anomalydetection/ExtendedDetectionResult.scala create mode 100644 src/main/scala/com/amazon/deequ/constraints/AnomalyExtendedResultsConstraint.scala create mode 100644 src/main/scala/com/amazon/deequ/examples/AnomalyDetectionWithExtendedResultsExample.scala create mode 100644 src/test/scala/com/amazon/deequ/constraints/AnomalyExtendedResultsConstraintTest.scala diff --git a/src/main/scala/com/amazon/deequ/VerificationRunBuilder.scala b/src/main/scala/com/amazon/deequ/VerificationRunBuilder.scala index a4ee45f6b..cd4c89a49 100644 --- a/src/main/scala/com/amazon/deequ/VerificationRunBuilder.scala +++ b/src/main/scala/com/amazon/deequ/VerificationRunBuilder.scala @@ -16,7 +16,7 @@ package com.amazon.deequ -import com.amazon.deequ.anomalydetection.AnomalyDetectionStrategy +import com.amazon.deequ.anomalydetection.{AnomalyDetectionStrategy, AnomalyDetectionStrategyWithExtendedResults} import com.amazon.deequ.analyzers.Analyzer import com.amazon.deequ.analyzers.{State, _} import com.amazon.deequ.checks.{Check, CheckLevel} @@ -241,6 +241,24 @@ class VerificationRunBuilderWithRepository( anomalyDetectionStrategy, analyzer, anomalyCheckConfigOrDefault) this } + + def addAnomalyCheckWithExtendedResults[S <: State[S]]( + anomalyDetectionStrategy: AnomalyDetectionStrategyWithExtendedResults, + analyzer: Analyzer[S, Metric[Double]], + anomalyCheckConfig: Option[AnomalyCheckConfig] = None) + : this.type = { + + val anomalyCheckConfigOrDefault = anomalyCheckConfig.getOrElse { + + val checkDescription = s"Anomaly check for ${analyzer.toString}" + + AnomalyCheckConfig(CheckLevel.Warning, checkDescription) + } + + checks :+= VerificationRunBuilderHelper.getAnomalyCheckWithExtendedResults( + 
metricsRepository.get, anomalyDetectionStrategy, analyzer, anomalyCheckConfigOrDefault) + this + } } class VerificationRunBuilderWithSparkSession( @@ -316,6 +334,32 @@ private[this] object VerificationRunBuilderHelper { anomalyCheckConfig.beforeDate ) } + + /** + * Build a check using Anomaly Detection with extended results methods + * + * @param metricsRepository A metrics repository to get the previous results + * @param anomalyDetectionStrategyWithExtendedResults The anomaly detection strategy with extended results + * @param analyzer The analyzer for the metric to run anomaly detection on + * @param anomalyCheckConfig Some configuration settings for the Check + */ + def getAnomalyCheckWithExtendedResults[S <: State[S]]( + metricsRepository: MetricsRepository, + anomalyDetectionStrategyWithExtendedResults: AnomalyDetectionStrategyWithExtendedResults, + analyzer: Analyzer[S, Metric[Double]], + anomalyCheckConfig: AnomalyCheckConfig) + : Check = { + + Check(anomalyCheckConfig.level, anomalyCheckConfig.description) + .isNewestPointNonAnomalousWithExtendedResults( + metricsRepository, + anomalyDetectionStrategyWithExtendedResults, + analyzer, + anomalyCheckConfig.withTagValues, + anomalyCheckConfig.afterDate, + anomalyCheckConfig.beforeDate + ) + } } /** diff --git a/src/main/scala/com/amazon/deequ/analyzers/applicability/Applicability.scala b/src/main/scala/com/amazon/deequ/analyzers/applicability/Applicability.scala index e2c282c14..dc55c84cf 100644 --- a/src/main/scala/com/amazon/deequ/analyzers/applicability/Applicability.scala +++ b/src/main/scala/com/amazon/deequ/analyzers/applicability/Applicability.scala @@ -21,7 +21,8 @@ import java.util.Calendar import com.amazon.deequ.analyzers.{Analyzer, State} import com.amazon.deequ.checks.Check -import com.amazon.deequ.constraints.{AnalysisBasedConstraint, Constraint, ConstraintDecorator} +import com.amazon.deequ.constraints.{AnalysisBasedConstraint, AnomalyExtendedResultsConstraint, + Constraint, ConstraintDecorator} import com.amazon.deequ.metrics.Metric import org.apache.spark.sql.types._ import org.apache.spark.sql.{DataFrame, Row, SparkSession} @@ -187,9 +188,13 @@ private[deequ] class Applicability(session: SparkSession) { case (name, nc: ConstraintDecorator) => name -> nc.inner case (name, c: Constraint) => name -> c } - .collect { case (name, constraint: AnalysisBasedConstraint[_, _, _]) => - val metric = constraint.analyzer.calculate(data).value - name -> metric + .collect { + case (name, constraint: AnalysisBasedConstraint[_, _, _]) => + val metric = constraint.analyzer.calculate(data).value + name -> metric + case (name, constraint: AnomalyExtendedResultsConstraint[_, _, _]) => + val metric = constraint.analyzer.calculate(data).value + name -> metric } val constraintApplicabilities = check.constraints.zip(namedMetrics).map { diff --git a/src/main/scala/com/amazon/deequ/anomalydetection/AnomalyDetectionStrategy.scala b/src/main/scala/com/amazon/deequ/anomalydetection/AnomalyDetectionStrategy.scala index 0c3f6805e..5e48e96bc 100644 --- a/src/main/scala/com/amazon/deequ/anomalydetection/AnomalyDetectionStrategy.scala +++ b/src/main/scala/com/amazon/deequ/anomalydetection/AnomalyDetectionStrategy.scala @@ -30,3 +30,17 @@ trait AnomalyDetectionStrategy { dataSeries: Vector[Double], searchInterval: (Int, Int) = (0, Int.MaxValue)): Seq[(Int, Anomaly)] } +trait AnomalyDetectionStrategyWithExtendedResults { + + /** + * Search for anomalies in a series of data points, returns extended results. 
+ * + * @param dataSeries The data contained in a Vector of Doubles + * @param searchInterval The indices between which anomalies should be detected. [a, b). + * @return The indices of all data points with their corresponding anomaly extended results wrapper + * object. + */ + def detectWithExtendedResults( + dataSeries: Vector[Double], + searchInterval: (Int, Int) = (0, Int.MaxValue)): Seq[(Int, AnomalyDetectionDataPoint)] +} diff --git a/src/main/scala/com/amazon/deequ/anomalydetection/AnomalyDetector.scala b/src/main/scala/com/amazon/deequ/anomalydetection/AnomalyDetector.scala index e7146c0e9..96f3925af 100644 --- a/src/main/scala/com/amazon/deequ/anomalydetection/AnomalyDetector.scala +++ b/src/main/scala/com/amazon/deequ/anomalydetection/AnomalyDetector.scala @@ -56,12 +56,8 @@ case class AnomalyDetector(strategy: AnomalyDetectionStrategy) { val allDataPoints = sortedDataPoints :+ newPoint - // Run anomaly - val anomalies = detectAnomaliesInHistory(allDataPoints, (newPoint.time, Long.MaxValue)) - .anomalies - - // Create a Detection result with all anomalies - DetectionResult(anomalies) + // Run anomaly and create a Detection result with all anomalies + detectAnomaliesInHistory(allDataPoints, (newPoint.time, Long.MaxValue)) } /** @@ -100,3 +96,86 @@ case class AnomalyDetector(strategy: AnomalyDetectionStrategy) { DetectionResult(anomalies.map { case (index, anomaly) => (sortedTimestamps(index), anomaly) }) } } + +case class AnomalyDetectorWithExtendedResults(strategy: AnomalyDetectionStrategyWithExtendedResults) { + + + /** + * Given a sequence of metrics and a current value, detects if there is an anomaly by using the + * given algorithm and returns extended results. + * + * @param historicalDataPoints Sequence of tuples (Points in time with corresponding Metric). + * @param newPoint A new data point to check if there are anomalies + * @return + */ + def isNewPointAnomalousWithExtendedResults( + historicalDataPoints: Seq[DataPoint[Double]], + newPoint: DataPoint[Double]) + : ExtendedDetectionResult = { + + require(historicalDataPoints.nonEmpty, "historicalDataPoints must not be empty!") + + val sortedDataPoints = historicalDataPoints.sortBy(_.time) + + val firstDataPointTime = sortedDataPoints.head.time + val lastDataPointTime = sortedDataPoints.last.time + + val newPointTime = newPoint.time + + require(lastDataPointTime < newPointTime, + s"Can't decide which range to use for anomaly detection. New data point with time " + + s"$newPointTime is in history range ($firstDataPointTime - $lastDataPointTime)!") + + val allDataPoints = sortedDataPoints :+ newPoint + + // Run anomaly and create an Extended Detection result with all data points and anomaly details + detectAnomaliesInHistoryWithExtendedResults(allDataPoints, (newPoint.time, Long.MaxValue)) + } + + + /** + * Given a strategy, detects anomalies in a time series after some preprocessing + * and returns extended results. + * + * @param dataSeries Sequence of tuples (Points in time with corresponding value). + * @param searchInterval The interval in which anomalies should be detected. [a, b). + * @return A wrapper object, containing all data points with anomaly extended results. 
+ */ + def detectAnomaliesInHistoryWithExtendedResults( + dataSeries: Seq[DataPoint[Double]], + searchInterval: (Long, Long) = (Long.MinValue, Long.MaxValue)) + : ExtendedDetectionResult = { + + def findIndexForBound(sortedTimestamps: Seq[Long], boundValue: Long): Int = { + sortedTimestamps.search(boundValue).insertionPoint + } + + val (searchStart, searchEnd) = searchInterval + + require(searchStart <= searchEnd, + "The first interval element has to be smaller or equal to the last.") + + // Remove missing values and sort series by time + val removedMissingValues = dataSeries.filter { + _.metricValue.isDefined + } + val sortedSeries = removedMissingValues.sortBy { + _.time + } + val sortedTimestamps = sortedSeries.map { + _.time + } + + // Find indices of lower and upper bound + val lowerBoundIndex = findIndexForBound(sortedTimestamps, searchStart) + val upperBoundIndex = findIndexForBound(sortedTimestamps, searchEnd) + + val anomalies = strategy.detectWithExtendedResults( + sortedSeries.flatMap { + _.metricValue + }.toVector, (lowerBoundIndex, upperBoundIndex)) + + ExtendedDetectionResult(anomalies.map { case (index, anomaly) => (sortedTimestamps(index), anomaly) }) + } + +} diff --git a/src/main/scala/com/amazon/deequ/anomalydetection/BaseChangeStrategy.scala b/src/main/scala/com/amazon/deequ/anomalydetection/BaseChangeStrategy.scala index e00c86772..0ac353223 100644 --- a/src/main/scala/com/amazon/deequ/anomalydetection/BaseChangeStrategy.scala +++ b/src/main/scala/com/amazon/deequ/anomalydetection/BaseChangeStrategy.scala @@ -27,7 +27,7 @@ import breeze.linalg.DenseVector * Set to 1 it calculates the difference between two consecutive values. */ trait BaseChangeStrategy - extends AnomalyDetectionStrategy { + extends AnomalyDetectionStrategy with AnomalyDetectionStrategyWithExtendedResults { def maxRateDecrease: Option[Double] def maxRateIncrease: Option[Double] @@ -67,7 +67,8 @@ trait BaseChangeStrategy } /** - * Search for anomalies in a series of data points. + * Search for anomalies in a series of data points. This function uses the + * detectWithExtendedResults function and then filters and maps to return only anomaly data point objects. * * If there aren't enough data points preceding the searchInterval, * it may happen that the interval's first elements (depending on the specified order) @@ -81,6 +82,30 @@ trait BaseChangeStrategy dataSeries: Vector[Double], searchInterval: (Int, Int)) : Seq[(Int, Anomaly)] = { + + detectWithExtendedResults(dataSeries, searchInterval) + .filter { case (_, anomDataPoint) => anomDataPoint.isAnomaly } + .map { case (i, anomDataPoint) => + (i, Anomaly(Some(anomDataPoint.dataMetricValue), anomDataPoint.confidence, anomDataPoint.detail)) + } + } + + /** + * Search for anomalies in a series of data points, returns extended results. + * + * If there aren't enough data points preceding the searchInterval, + * it may happen that the interval's first elements (depending on the specified order) + * can't be flagged as anomalies. + * + * @param dataSeries The data contained in a Vector of Doubles + * @param searchInterval The indices between which anomalies should be detected. [a, b). + * @return The indices of all anomalies in the interval and their corresponding wrapper object + * with extended results. 
+ */ + override def detectWithExtendedResults( + dataSeries: Vector[Double], + searchInterval: (Int, Int)) + : Seq[(Int, AnomalyDetectionDataPoint)] = { val (start, end) = searchInterval require(start <= end, @@ -89,15 +114,24 @@ trait BaseChangeStrategy val startPoint = Seq(start - order, 0).max val data = diff(DenseVector(dataSeries.slice(startPoint, end): _*), order).data - data.zipWithIndex.filter { case (value, _) => - (value < maxRateDecrease.getOrElse(Double.MinValue) - || value > maxRateIncrease.getOrElse(Double.MaxValue)) - } - .map { case (change, index) => - (index + startPoint + order, Anomaly(Option(dataSeries(index + startPoint + order)), 1.0, - Some(s"[AbsoluteChangeStrategy]: Change of $change is not in bounds [" + - s"${maxRateDecrease.getOrElse(Double.MinValue)}, " + - s"${maxRateIncrease.getOrElse(Double.MaxValue)}]. Order=$order"))) + val lowerBound = maxRateDecrease.getOrElse(Double.MinValue) + val upperBound = maxRateIncrease.getOrElse(Double.MaxValue) + + + data.zipWithIndex.map { + case (change, index) => + val outputSequenceIndex = index + startPoint + order + val value = dataSeries(outputSequenceIndex) + val (detail, isAnomaly) = if (change < lowerBound || change > upperBound) { + (Some(s"[AbsoluteChangeStrategy]: Change of $change is not in bounds [" + + s"$lowerBound, " + + s"$upperBound]. Order=$order"), true) + } + else { + (None, false) + } + (outputSequenceIndex, AnomalyDetectionDataPoint(value, change, + Threshold(lowerBound = Bound(lowerBound), upperBound = Bound(upperBound)), isAnomaly, 1.0, detail)) } } } diff --git a/src/main/scala/com/amazon/deequ/anomalydetection/BatchNormalStrategy.scala b/src/main/scala/com/amazon/deequ/anomalydetection/BatchNormalStrategy.scala index baff49c03..7d4bb6304 100644 --- a/src/main/scala/com/amazon/deequ/anomalydetection/BatchNormalStrategy.scala +++ b/src/main/scala/com/amazon/deequ/anomalydetection/BatchNormalStrategy.scala @@ -33,7 +33,9 @@ import breeze.stats.meanAndVariance case class BatchNormalStrategy( lowerDeviationFactor: Option[Double] = Some(3.0), upperDeviationFactor: Option[Double] = Some(3.0), - includeInterval: Boolean = false) extends AnomalyDetectionStrategy { + includeInterval: Boolean = false) + extends AnomalyDetectionStrategy with AnomalyDetectionStrategyWithExtendedResults + { require(lowerDeviationFactor.isDefined || upperDeviationFactor.isDefined, "At least one factor has to be specified.") @@ -43,7 +45,8 @@ case class BatchNormalStrategy( /** - * Search for anomalies in a series of data points. + * Search for anomalies in a series of data points. This function uses the + * detectWithExtendedResults function and then filters and maps to return only anomaly objects. * * @param dataSeries The data contained in a Vector of Doubles * @param searchInterval The indices between which anomalies should be detected. [a, b). @@ -53,6 +56,25 @@ case class BatchNormalStrategy( dataSeries: Vector[Double], searchInterval: (Int, Int)): Seq[(Int, Anomaly)] = { + detectWithExtendedResults(dataSeries, searchInterval) + .filter { case (_, anomDataPoint) => anomDataPoint.isAnomaly } + .map { case (i, anomDataPoint) => + (i, Anomaly(Some(anomDataPoint.dataMetricValue), anomDataPoint.confidence, anomDataPoint.detail)) + } + } + + /** + * Search for anomalies in a series of data points, returns extended results. + * + * @param dataSeries The data contained in a Vector of Doubles + * @param searchInterval The indices between which anomalies should be detected. [a, b). 
+ * @return The indices of all anomalies in the interval and their corresponding wrapper object + * with extended results. + */ + override def detectWithExtendedResults( + dataSeries: Vector[Double], + searchInterval: (Int, Int)): Seq[(Int, AnomalyDetectionDataPoint)] = { + val (searchStart, searchEnd) = searchInterval require(searchStart <= searchEnd, "The start of the interval can't be larger than the end.") @@ -83,13 +105,17 @@ case class BatchNormalStrategy( dataSeries.zipWithIndex .slice(searchStart, searchEnd) - .filter { case (value, _) => value > upperBound || value < lowerBound } .map { case (value, index) => - - val detail = Some(s"[BatchNormalStrategy]: Value $value is not in " + - s"bounds [$lowerBound, $upperBound].") - - (index, Anomaly(Option(value), 1.0, detail)) + val (detail, isAnomaly) = if (value > upperBound || value < lowerBound) { + (Some(s"[BatchNormalStrategy]: Value $value is not in " + + s"bounds [$lowerBound, $upperBound]."), true) + } else { + (None, false) + } + (index, AnomalyDetectionDataPoint(value, value, + Threshold(lowerBound = Bound(lowerBound), upperBound = Bound(upperBound)), isAnomaly, 1.0, detail)) } } + + } diff --git a/src/main/scala/com/amazon/deequ/anomalydetection/ExtendedDetectionResult.scala b/src/main/scala/com/amazon/deequ/anomalydetection/ExtendedDetectionResult.scala new file mode 100644 index 000000000..7e024b2cf --- /dev/null +++ b/src/main/scala/com/amazon/deequ/anomalydetection/ExtendedDetectionResult.scala @@ -0,0 +1,145 @@ +/** + * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not + * use this file except in compliance with the License. A copy of the License + * is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + * + */ + +package com.amazon.deequ.anomalydetection + +/** + * The classes here provide the same anomaly detection functionality as in DetectionResult + * but also provide extended results through details contained in the AnomalyDetectionDataPoint class. + * See below. + */ + +/** + * Anomaly Detection Data Point class + * This class is different from the Anomaly Class in that this class + * wraps around all data points, not just anomalies, and provides extended results including + * if the data point is an anomaly, and the thresholds used in the anomaly calculation. + * + * @param dataMetricValue The metric value that is the data point. + * @param anomalyMetricValue The metric value that is being used in the anomaly calculation. + * This usually aligns with dataMetricValue but not always, + * like in a rate of change strategy where the rate of change is the anomaly metric + * which may not equal the actual data point value. + * @param anomalyThreshold The thresholds used in the anomaly check, the anomalyMetricValue is + * compared to this threshold. + * @param isAnomaly If the data point is an anomaly. + * @param confidence Confidence of anomaly detection. + * @param detail Detailed error message. 
+ */ +class AnomalyDetectionDataPoint( + val dataMetricValue: Double, + val anomalyMetricValue: Double, + val anomalyThreshold: Threshold, + val isAnomaly: Boolean, + val confidence: Double, + val detail: Option[String]) + { + + def canEqual(that: Any): Boolean = { + that.isInstanceOf[AnomalyDetectionDataPoint] + } + + /** + * Tests anomalyDetectionDataPoints for equality. Ignores detailed error message. + * + * @param obj The object/ anomaly to compare against. + * @return true, if and only if the dataMetricValue, anomalyMetricValue, anomalyThreshold, isAnomaly + * and confidence are the same. + */ + override def equals(obj: Any): Boolean = { + obj match { + case anomaly: AnomalyDetectionDataPoint => + anomaly.dataMetricValue == dataMetricValue && + anomaly.anomalyMetricValue == anomalyMetricValue && + anomaly.anomalyThreshold == anomalyThreshold && + anomaly.isAnomaly == isAnomaly && + anomaly.confidence == confidence + case _ => false + } + } + + override def hashCode: Int = { + val prime = 31 + var result = 1 + result = prime * result + dataMetricValue.hashCode() + result = prime * result + anomalyMetricValue.hashCode() + result = prime * result + anomalyThreshold.hashCode() + result = prime * result + isAnomaly.hashCode() + result = prime * result + confidence.hashCode() + result + } + +} + +object AnomalyDetectionDataPoint { + def apply(dataMetricValue: Double, anomalyMetricValue: Double, + anomalyThreshold: Threshold = Threshold(), isAnomaly: Boolean = false, + confidence: Double, detail: Option[String] = None + ): AnomalyDetectionDataPoint = { + new AnomalyDetectionDataPoint(dataMetricValue, anomalyMetricValue, anomalyThreshold, isAnomaly, confidence, detail) + } +} + + +/** + * Threshold class + * Defines threshold for the anomaly detection, defaults to inclusive bounds of Double.Min and Double.Max. + * @param upperBound The upper bound or threshold. + * @param lowerBound The lower bound or threshold. + */ +case class Threshold(lowerBound: Bound = Bound(Double.MinValue), upperBound: Bound = Bound(Double.MaxValue)) + +/** + * Bound Class + * Class representing a threshold/bound, with value and inclusive/exclusive boolean/ + * @param value The value of the bound as a Double. + * @param inclusive Boolean indicating if the Bound is inclusive or not. + */ +case class Bound(value: Double, inclusive: Boolean = true) + + + +/** + * ExtendedDetectionResult Class + * This class is returned from the detectAnomaliesInHistoryWithExtendedResults function. + * @param anomalyDetectionDataPointSequence The sequence of (timestamp, AnomalyDetectionDataPoint) pairs. + */ +case class ExtendedDetectionResult(anomalyDetectionDataPointSequence: + Seq[(Long, AnomalyDetectionDataPoint)] = Seq.empty) + + +/** + * AnomalyDetectionExtendedResult Class + * This class contains anomaly detection extended results through a Sequence of AnomalyDetectionDataPoints. + * This is currently an optional field in the ConstraintResult class that is exposed to users. + * + * Currently, anomaly detection only runs on "newest" data point (referring to the dataframe being + * run on by the verification suite) and not multiple data points, so the returned sequence will contain + * one AnomalyDetectionDataPoint. + * In the future, if we allow the anomaly check to detect multiple points, the returned sequence + * may be more than one AnomalyDetectionDataPoints. + * @param anomalyDetectionDataPoints Sequence of AnomalyDetectionDataPoints. 
+ */ +case class AnomalyDetectionExtendedResult(anomalyDetectionDataPoints: Seq[AnomalyDetectionDataPoint]) + +/** + * AnomalyDetectionAssertionResult Class + * This class is returned by the assertion function Check.isNewestPointNonAnomalousWithExtendedResults. + * @param hasNoAnomaly Boolean indicating if there was no anomaly detected. + * @param anomalyDetectionExtendedResult AnomalyDetectionExtendedResults class. + */ +case class AnomalyDetectionAssertionResult(hasNoAnomaly: Boolean, + anomalyDetectionExtendedResult: AnomalyDetectionExtendedResult) diff --git a/src/main/scala/com/amazon/deequ/anomalydetection/OnlineNormalStrategy.scala b/src/main/scala/com/amazon/deequ/anomalydetection/OnlineNormalStrategy.scala index 8bf8b634c..3955eae16 100644 --- a/src/main/scala/com/amazon/deequ/anomalydetection/OnlineNormalStrategy.scala +++ b/src/main/scala/com/amazon/deequ/anomalydetection/OnlineNormalStrategy.scala @@ -40,7 +40,8 @@ case class OnlineNormalStrategy( lowerDeviationFactor: Option[Double] = Some(3.0), upperDeviationFactor: Option[Double] = Some(3.0), ignoreStartPercentage: Double = 0.1, - ignoreAnomalies: Boolean = true) extends AnomalyDetectionStrategy { + ignoreAnomalies: Boolean = true) + extends AnomalyDetectionStrategy with AnomalyDetectionStrategyWithExtendedResults { require(lowerDeviationFactor.isDefined || upperDeviationFactor.isDefined, "At least one factor has to be specified.") @@ -121,9 +122,10 @@ case class OnlineNormalStrategy( /** - * Search for anomalies in a series of data points. + * Search for anomalies in a series of data points. This function uses the + * detectWithExtendedResults function and then filters and maps to return only anomaly objects. * - * @param dataSeries The data contained in a Vector of Doubles + * @param dataSeries The data contained in a Vector of Doubles. * @param searchInterval The indices between which anomalies should be detected. [a, b). * @return The indices of all anomalies in the interval and their corresponding wrapper object. */ @@ -132,6 +134,26 @@ case class OnlineNormalStrategy( searchInterval: (Int, Int)) : Seq[(Int, Anomaly)] = { + detectWithExtendedResults(dataSeries, searchInterval) + .filter { case (_, anomDataPoint) => anomDataPoint.isAnomaly } + .map { case (i, anomDataPoint) => + (i, Anomaly(Some(anomDataPoint.dataMetricValue), anomDataPoint.confidence, anomDataPoint.detail)) + } + } + + /** + * Search for anomalies in a series of data points, returns extended results. + * + * @param dataSeries The data contained in a Vector of Doubles. + * @param searchInterval The indices between which anomalies should be detected. [a, b). + * @return The indices of all anomalies in the interval and their corresponding wrapper object + * with extended results. 
+ */ + override def detectWithExtendedResults( + dataSeries: Vector[Double], + searchInterval: (Int, Int)) + : Seq[(Int, AnomalyDetectionDataPoint)] = { + val (searchStart, searchEnd) = searchInterval require(searchStart <= searchEnd, "The start of the interval can't be larger than the end.") @@ -139,7 +161,6 @@ case class OnlineNormalStrategy( computeStatsAndAnomalies(dataSeries, searchInterval) .zipWithIndex .slice(searchStart, searchEnd) - .filter { case (result, _) => result.isAnomaly } .map { case (calcRes, index) => val lowerBound = calcRes.mean - lowerDeviationFactor.getOrElse(Double.MaxValue) * calcRes.stdDev @@ -149,7 +170,11 @@ case class OnlineNormalStrategy( val detail = Some(s"[OnlineNormalStrategy]: Value ${dataSeries(index)} is not in " + s"bounds [$lowerBound, $upperBound].") - (index, Anomaly(Option(dataSeries(index)), 1.0, detail)) + val value = dataSeries(index) + + (index, AnomalyDetectionDataPoint(value, value, + Threshold(lowerBound = Bound(lowerBound), upperBound = Bound(upperBound)), + calcRes.isAnomaly, 1.0, detail)) } } } diff --git a/src/main/scala/com/amazon/deequ/anomalydetection/SimpleThresholdStrategy.scala b/src/main/scala/com/amazon/deequ/anomalydetection/SimpleThresholdStrategy.scala index ec7f5df74..03d30c7c7 100644 --- a/src/main/scala/com/amazon/deequ/anomalydetection/SimpleThresholdStrategy.scala +++ b/src/main/scala/com/amazon/deequ/anomalydetection/SimpleThresholdStrategy.scala @@ -25,34 +25,61 @@ package com.amazon.deequ.anomalydetection case class SimpleThresholdStrategy( lowerBound: Double = Double.MinValue, upperBound: Double) - extends AnomalyDetectionStrategy { + extends AnomalyDetectionStrategy with AnomalyDetectionStrategyWithExtendedResults { require(lowerBound <= upperBound, "The lower bound must be smaller or equal to the upper bound.") /** - * Search for anomalies in a series of data points. + * Search for anomalies in a series of data points. This function uses the + * detectWithExtendedResults function and then filters and maps to return only anomaly objects. * - * @param dataSeries The data contained in a Vector of Doubles + * @param dataSeries The data contained in a Vector of Doubles. * @param searchInterval The indices between which anomalies should be detected. [a, b). * @return The indices of all anomalies in the interval and their corresponding wrapper object. */ override def detect( + dataSeries: Vector[Double], + searchInterval: (Int, Int)) + : Seq[(Int, Anomaly)] = { + + detectWithExtendedResults(dataSeries, searchInterval) + .filter { case (_, anomDataPoint) => anomDataPoint.isAnomaly } + .map { case (i, anomDataPoint) => + (i, Anomaly(Some(anomDataPoint.dataMetricValue), anomDataPoint.confidence, anomDataPoint.detail)) + } + } + + /** + * Search for anomalies in a series of data points, returns extended results. + * + * @param dataSeries The data contained in a Vector of Doubles. + * @param searchInterval The indices between which anomalies should be detected. [a, b). + * @return The indices of all anomalies in the interval and their corresponding wrapper object + * with extended results. 
+ */ + override def detectWithExtendedResults( dataSeries: Vector[Double], - searchInterval: (Int, Int)): Seq[(Int, Anomaly)] = { + searchInterval: (Int, Int)): Seq[(Int, AnomalyDetectionDataPoint)] = { val (searchStart, searchEnd) = searchInterval - require (searchStart <= searchEnd, "The start of the interval can't be larger than the end.") + require(searchStart <= searchEnd, "The start of the interval can't be larger than the end.") dataSeries.zipWithIndex .slice(searchStart, searchEnd) .filter { case (value, _) => value < lowerBound || value > upperBound } .map { case (value, index) => - val detail = Some(s"[SimpleThresholdStrategy]: Value $value is not in " + - s"bounds [$lowerBound, $upperBound]") + val (detail, isAnomaly) = if (value < lowerBound || value > upperBound) { + (Some(s"[SimpleThresholdStrategy]: Value $value is not in " + + s"bounds [$lowerBound, $upperBound]"), true) + } else { + (None, false) + } - (index, Anomaly(Option(value), 1.0, detail)) + (index, AnomalyDetectionDataPoint(value, value, + Threshold(lowerBound = Bound(lowerBound), upperBound = Bound(upperBound)), + isAnomaly, 1.0, detail)) } } } diff --git a/src/main/scala/com/amazon/deequ/anomalydetection/seasonal/HoltWinters.scala b/src/main/scala/com/amazon/deequ/anomalydetection/seasonal/HoltWinters.scala index 0ee0ac25f..ec7db67e4 100644 --- a/src/main/scala/com/amazon/deequ/anomalydetection/seasonal/HoltWinters.scala +++ b/src/main/scala/com/amazon/deequ/anomalydetection/seasonal/HoltWinters.scala @@ -18,7 +18,7 @@ package com.amazon.deequ.anomalydetection.seasonal import breeze.linalg.DenseVector import breeze.optimize.{ApproximateGradientFunction, DiffFunction, LBFGSB} -import com.amazon.deequ.anomalydetection.{Anomaly, AnomalyDetectionStrategy} +import com.amazon.deequ.anomalydetection.{Anomaly, AnomalyDetectionDataPoint, AnomalyDetectionStrategy, AnomalyDetectionStrategyWithExtendedResults, Threshold, Bound} import collection.mutable.ListBuffer @@ -63,7 +63,7 @@ object HoltWinters { class HoltWinters( metricsInterval: HoltWinters.MetricInterval.Value, seasonality: HoltWinters.SeriesSeasonality.Value) - extends AnomalyDetectionStrategy { + extends AnomalyDetectionStrategy with AnomalyDetectionStrategyWithExtendedResults { import HoltWinters._ @@ -173,37 +173,76 @@ class HoltWinters( ) } - private def findAnomalies( - testSeries: Vector[Double], - forecasts: Seq[Double], - startIndex: Int, - residualSD: Double) - : Seq[(Int, Anomaly)] = { - testSeries.zip(forecasts).zipWithIndex - .collect { case ((inputValue, forecastedValue), detectionIndex) - if math.abs(inputValue - forecastedValue) > 1.96 * residualSD => + /** + * This function is renamed to add 'withExtendedResults' to the name. + * The functionality no longer filters out non anomalies, but instead leaves a flag + * of whether it's anomaly or not. The previous anomaly detection strategy uses this refactored function + * and then does the filtering to remove non anomalies and maps to previous anomaly objects. + * The new anomaly detection strategy with extended results uses this function and does not filter on it. 
+ */ + private def findAnomaliesWithExtendedResults( + testSeries: Vector[Double], + forecasts: Seq[Double], + startIndex: Int, + residualSD: Double) + : Seq[(Int, AnomalyDetectionDataPoint)] = { - detectionIndex + startIndex -> Anomaly( - value = Some(inputValue), + testSeries.zip(forecasts).zipWithIndex + .collect { case ((inputValue, forecastedValue), detectionIndex) => + val anomalyMetricValue = math.abs(inputValue - forecastedValue) + val upperBound = 1.96 * residualSD + + val (detail, isAnomaly) = if (anomalyMetricValue > upperBound) { + (Some(s"Forecasted $forecastedValue for observed value $inputValue"), true) + } else { + (None, false) + } + detectionIndex + startIndex -> AnomalyDetectionDataPoint( + dataMetricValue = inputValue, + anomalyMetricValue = anomalyMetricValue, + anomalyThreshold = Threshold(upperBound = Bound(upperBound)), + isAnomaly = isAnomaly, confidence = 1.0, - detail = Some(s"Forecasted $forecastedValue for observed value $inputValue") + detail = detail ) } } /** - * Search for anomalies in a series of data points. + * Search for anomalies in a series of data points. This function uses the + * detectWithExtendedResults function and then filters and maps to return only anomaly objects. * - * @param dataSeries The data contained in a Vector of Doubles + * @param dataSeries The data contained in a Vector of Doubles. * @param searchInterval The indices between which anomalies should be detected. [a, b). * @return The indices of all anomalies in the interval and their corresponding wrapper object. + * */ override def detect( dataSeries: Vector[Double], searchInterval: (Int, Int) = (0, Int.MaxValue)) : Seq[(Int, Anomaly)] = { + detectWithExtendedResults(dataSeries, searchInterval) + .filter { case (_, anomDataPoint) => anomDataPoint.isAnomaly } + .map { case (i, anomDataPoint) => + (i, Anomaly(Some(anomDataPoint.dataMetricValue), anomDataPoint.confidence, anomDataPoint.detail)) + } + } + + /** + * Search for anomalies in a series of data points, returns extended results. + * + * @param dataSeries The data contained in a Vector of Doubles. + * @param searchInterval The indices between which anomalies should be detected. [a, b). + * @return The indices of all anomalies in the interval and their corresponding wrapper object + * with extended results. 
+ */ + override def detectWithExtendedResults( + dataSeries: Vector[Double], + searchInterval: (Int, Int) = (0, Int.MaxValue)) + : Seq[(Int, AnomalyDetectionDataPoint)] = { + require(dataSeries.nonEmpty, "Provided data series is empty") val (start, end) = searchInterval @@ -244,6 +283,6 @@ class HoltWinters( require(modelResults.forecasts.size == numberOfObservationsToForecast) val testSeries = dataSeries.drop(start) - findAnomalies(testSeries, modelResults.forecasts, start, residualsStandardDeviation) + findAnomaliesWithExtendedResults(testSeries, modelResults.forecasts, start, residualsStandardDeviation) } } diff --git a/src/main/scala/com/amazon/deequ/checks/Check.scala b/src/main/scala/com/amazon/deequ/checks/Check.scala index 9f6f6ea03..954dc763d 100644 --- a/src/main/scala/com/amazon/deequ/checks/Check.scala +++ b/src/main/scala/com/amazon/deequ/checks/Check.scala @@ -25,10 +25,7 @@ import com.amazon.deequ.analyzers.Histogram import com.amazon.deequ.analyzers.KLLParameters import com.amazon.deequ.analyzers.Patterns import com.amazon.deequ.analyzers.State -import com.amazon.deequ.anomalydetection.HistoryUtils -import com.amazon.deequ.anomalydetection.AnomalyDetectionStrategy -import com.amazon.deequ.anomalydetection.AnomalyDetector -import com.amazon.deequ.anomalydetection.DataPoint +import com.amazon.deequ.anomalydetection.{AnomalyDetectionAssertionResult, AnomalyDetectionExtendedResult, ExtendedDetectionResult, AnomalyDetectionStrategy, AnomalyDetectionStrategyWithExtendedResults, AnomalyDetector, AnomalyDetectorWithExtendedResults, DataPoint, HistoryUtils} import com.amazon.deequ.checks.ColumnCondition.isAnyNotNull import com.amazon.deequ.checks.ColumnCondition.isEachNotNull import com.amazon.deequ.constraints.Constraint._ @@ -513,6 +510,44 @@ case class Check( addConstraint(anomalyConstraint(analyzer, anomalyAssertionFunction, hint)) } + /** + * Creates a constraint that runs AnomalyDetection with extended results on the new value. + * + * @param metricsRepository A metrics repository to get the previous results. + * @param anomalyDetectionStrategyWithExtendedResults The anomaly detection strategy with extended results. + * @param analyzer The analyzer for the metric to run anomaly detection on. + * @param withTagValues Can contain a Map with tag names and the corresponding values + * to filter for. + * @param beforeDate The maximum dateTime of previous AnalysisResults to use for + * the Anomaly Detection. + * @param afterDate The minimum dateTime of previous AnalysisResults to use for + * the Anomaly Detection. + * @param hint A hint to provide additional context why a constraint + * could have failed. + * @return + */ + private[deequ] def isNewestPointNonAnomalousWithExtendedResults[S <: State[S]]( + metricsRepository: MetricsRepository, + anomalyDetectionStrategyWithExtendedResults: AnomalyDetectionStrategyWithExtendedResults, + analyzer: Analyzer[S, Metric[Double]], + withTagValues: Map[String, String], + afterDate: Option[Long], + beforeDate: Option[Long], + hint: Option[String] = None) + : Check = { + + val anomalyAssertionFunction = Check.isNewestPointNonAnomalousWithExtendedResults( + metricsRepository, + anomalyDetectionStrategyWithExtendedResults, + analyzer, + withTagValues, + afterDate, + beforeDate + )(_) + + addConstraint(anomalyConstraintWithExtendedResults(analyzer, anomalyAssertionFunction, hint)) + } + /** * Creates a constraint that asserts on a column entropy. 
@@ -1159,6 +1194,7 @@ case class Check( } .collect { case constraint: AnalysisBasedConstraint[_, _, _] => constraint.analyzer + case constraint: AnomalyExtendedResultsConstraint[_, _, _] => constraint.analyzer } .map { _.asInstanceOf[Analyzer[_, Metric[_]]] } .toSet @@ -1251,4 +1287,117 @@ object Check { detectedAnomalies.anomalies.isEmpty } + + + /** + * Common assertion function checking if the value can be considered as normal (that no + * anomalies were detected), given the anomaly detection strategy with extended results + * and details on how to retrieve the history. + * This assertion function returns an AnomalyDetectionAssertionResult which contains + * anomaly detection extended results. + * + * @param metricsRepository A metrics repository to get the previous results. + * @param anomalyDetectionStrategyWithExtendedResults The anomaly detection strategy with extended results. + * @param analyzer The analyzer for the metric to run anomaly detection on. + * @param withTagValues Can contain a Map with tag names and the corresponding values + * to filter for. + * @param beforeDate The maximum dateTime of previous AnalysisResults to use for + * the Anomaly Detection. + * @param afterDate The minimum dateTime of previous AnalysisResults to use for + * the Anomaly Detection. + * @param currentMetricValue current metric value. + * @return The AnomalyDetectionAssertionResult with the boolean if the newest data point is anomalous + * along with the AnomalyDetectionExtendedResult object which contains the + * anomaly detection extended result details. + */ + private[deequ] def isNewestPointNonAnomalousWithExtendedResults[S <: State[S]]( + metricsRepository: MetricsRepository, + anomalyDetectionStrategyWithExtendedResults: AnomalyDetectionStrategyWithExtendedResults, + analyzer: Analyzer[S, Metric[Double]], + withTagValues: Map[String, String], + afterDate: Option[Long], + beforeDate: Option[Long])( + currentMetricValue: Double) + : AnomalyDetectionAssertionResult = { + + // Get history keys + var repositoryLoader = metricsRepository.load() + + repositoryLoader = repositoryLoader.withTagValues(withTagValues) + + beforeDate.foreach { beforeDate => + repositoryLoader = repositoryLoader.before(beforeDate) + } + + afterDate.foreach { afterDate => + repositoryLoader = repositoryLoader.after(afterDate) + } + + repositoryLoader = repositoryLoader.forAnalyzers(Seq(analyzer)) + + val analysisResults = repositoryLoader.get() + + require(analysisResults.nonEmpty, "There have to be previous results in the MetricsRepository!") + + val historicalMetrics = analysisResults + // If we have multiple DataPoints with the same dateTime, which should not happen in most + // cases, we still want consistent behaviour, so we sort them by Tags first + // (sorting is stable in Scala) + .sortBy(_.resultKey.tags.values) + .map { analysisResult => + val analyzerContextMetricMap = analysisResult.analyzerContext.metricMap + + val onlyAnalyzerMetricEntryInLoadedAnalyzerContext = analyzerContextMetricMap.headOption + + val doubleMetricOption = onlyAnalyzerMetricEntryInLoadedAnalyzerContext + .collect { case (_, metric) => metric.asInstanceOf[Metric[Double]] } + + val dataSetDate = analysisResult.resultKey.dataSetDate + + (dataSetDate, doubleMetricOption) + } + + // Ensure this is the last dataPoint + val testDateTime = analysisResults.map(_.resultKey.dataSetDate).max + 1 + require(testDateTime != Long.MaxValue, "Test DateTime cannot be Long.MaxValue, otherwise the" + + "Anomaly Detection, which works with an open upper 
interval bound, won't test anything") + + // Run given anomaly detection strategy and return false if the newest value is an Anomaly + val anomalyDetector = AnomalyDetectorWithExtendedResults(anomalyDetectionStrategyWithExtendedResults) + val anomalyDetectionResult: ExtendedDetectionResult = anomalyDetector.isNewPointAnomalousWithExtendedResults( + HistoryUtils.extractMetricValues[Double](historicalMetrics), + DataPoint(testDateTime, Some(currentMetricValue))) + + + // this function checks if the newest point is anomalous and returns a boolean for assertion, + // along with that newest point with anomaly check details + getNewestPointAnomalyResults(anomalyDetectionResult) + } + + /** + * Takes in ExtendedDetectionResult and returns AnomalyDetectionAssertionResult + * @param extendedDetectionResult Contains sequence of AnomalyDetectionDataPoints + * @return The AnomalyDetectionAssertionResult with the boolean if the newest data point is anomalous + * and the AnomalyDetectionExtendedResult containing the newest data point + * wrapped in the AnomalyDetectionDataPoint class + */ + private[deequ] def getNewestPointAnomalyResults(extendedDetectionResult: ExtendedDetectionResult): + AnomalyDetectionAssertionResult = { + val (hasNoAnomaly, anomalyDetectionExtendedResults): (Boolean, AnomalyDetectionExtendedResult) = { + + // Based on upstream code, this anomaly detection data point sequence should never be empty + require(extendedDetectionResult.anomalyDetectionDataPointSequence != Nil, + "anomalyDetectionDataPoints from AnomalyDetectionExtendedResult cannot be empty") + + // get the last anomaly detection data point of sequence (there should only be one element for now) + // and check the isAnomaly boolean, also return the last anomaly detection data point + // wrapped in the anomaly detection extended result class + extendedDetectionResult.anomalyDetectionDataPointSequence match { + case _ :+ lastAnomalyDataPointPair => + (!lastAnomalyDataPointPair._2.isAnomaly, AnomalyDetectionExtendedResult(Seq(lastAnomalyDataPointPair._2))) + } + } + AnomalyDetectionAssertionResult( + hasNoAnomaly = hasNoAnomaly, anomalyDetectionExtendedResult = anomalyDetectionExtendedResults) + } } diff --git a/src/main/scala/com/amazon/deequ/constraints/AnomalyExtendedResultsConstraint.scala b/src/main/scala/com/amazon/deequ/constraints/AnomalyExtendedResultsConstraint.scala new file mode 100644 index 000000000..c55736ddd --- /dev/null +++ b/src/main/scala/com/amazon/deequ/constraints/AnomalyExtendedResultsConstraint.scala @@ -0,0 +1,132 @@ +/** + * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not + * use this file except in compliance with the License. A copy of the License + * is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. 
+ * + */ + +package com.amazon.deequ.constraints + +import com.amazon.deequ.analyzers.{Analyzer, State} +import com.amazon.deequ.anomalydetection.AnomalyDetectionAssertionResult +import com.amazon.deequ.metrics.Metric +import org.apache.spark.sql.DataFrame + +import scala.util.{Failure, Success} + +/** + * Case class for anomaly with extended results constraints that provides unified way to access + * AnalyzerContext and metrics stored in it. + * + * Runs the analysis and get the value of the metric returned by the analysis, + * picks the numeric value that will be used in the assertion function with metric picker + * runs the assertion. + * + * @param analyzer Analyzer to be run on the data frame. + * @param assertion Assertion function that returns an AnomalyDetectionAssertionResult with + * anomaly detection extended results as well as the assertion boolean. + * @param valuePicker Optional function to pick the interested part of the metric value that the + * assertion will be running on. Absence of such function means the metric + * value would be used in the assertion as it is. + * @param hint A hint to provide additional context why a constraint could have failed. + * @tparam M : Type of the metric value generated by the Analyzer. + * @tparam V : Type of the value being used in assertion function. + * + */ +private[deequ] case class AnomalyExtendedResultsConstraint[S <: State[S], M, V]( + analyzer: Analyzer[S, Metric[M]], + private[deequ] val assertion: V => AnomalyDetectionAssertionResult, + private[deequ] val valuePicker: Option[M => V] = None, + private[deequ] val hint: Option[String] = None) + extends Constraint { + + private[deequ] def calculateAndEvaluate(data: DataFrame) = { + val metric = analyzer.calculate(data) + evaluate(Map(analyzer -> metric)) + } + + override def evaluate( + analysisResults: Map[Analyzer[_, Metric[_]], Metric[_]]) + : ConstraintResult = { + + val metric = analysisResults.get(analyzer).map(_.asInstanceOf[Metric[M]]) + + metric.map(pickValueAndAssert).getOrElse( + // Analysis is missing + ConstraintResult(this, ConstraintStatus.Failure, + message = Some(AnomalyExtendedResultsConstraint.MissingAnalysis), metric = metric) + ) + } + + private[this] def pickValueAndAssert(metric: Metric[M]): ConstraintResult = { + + metric.value match { + // Analysis done successfully and result metric is there + case Success(metricValue) => + try { + val assertOn = runPickerOnMetric(metricValue) + val anomalyAssertionResult = runAssertion(assertOn) + + if (anomalyAssertionResult.hasNoAnomaly) { + ConstraintResult(this, ConstraintStatus.Success, metric = Some(metric), + anomalyDetectionExtendedResultOption = Some(anomalyAssertionResult.anomalyDetectionExtendedResult)) + } else { + var errorMessage = s"Value: $assertOn does not meet the constraint requirement," + + s" check the anomaly detection metadata!" 
+ hint.foreach(hint => errorMessage += s" $hint") + + ConstraintResult(this, ConstraintStatus.Failure, Some(errorMessage), Some(metric), + anomalyDetectionExtendedResultOption = Some(anomalyAssertionResult.anomalyDetectionExtendedResult)) + } + + } catch { + case AnomalyExtendedResultsConstraint.ConstraintAssertionException(msg) => + ConstraintResult(this, ConstraintStatus.Failure, + message = Some(s"${AnomalyExtendedResultsConstraint.AssertionException}: $msg!"), metric = Some(metric)) + case AnomalyExtendedResultsConstraint.ValuePickerException(msg) => + ConstraintResult(this, ConstraintStatus.Failure, + message = Some(s"${AnomalyExtendedResultsConstraint.ProblematicMetricPicker}: $msg!"), + metric = Some(metric)) + } + // An exception occurred during analysis + case Failure(e) => ConstraintResult(this, + ConstraintStatus.Failure, message = Some(e.getMessage), metric = Some(metric)) + } + } + + private def runPickerOnMetric(metricValue: M): V = + try { + valuePicker.map(function => function(metricValue)).getOrElse(metricValue.asInstanceOf[V]) + } catch { + case e: Exception => throw AnomalyExtendedResultsConstraint.ValuePickerException(e.getMessage) + } + + private def runAssertion(assertOn: V): AnomalyDetectionAssertionResult = + try { + assertion(assertOn) + } catch { + case e: Exception => throw AnomalyExtendedResultsConstraint.ConstraintAssertionException(e.getMessage) + } + + // 'assertion' and 'valuePicker' are lambdas we have to represent them like '' + override def toString: String = + s"AnomalyBasedConstraint($analyzer,,${valuePicker.map(_ => "")},$hint)" +} + +private[deequ] object AnomalyExtendedResultsConstraint { + val MissingAnalysis = "Missing Analysis, can't run the constraint!" + val ProblematicMetricPicker = "Can't retrieve the value to assert on" + val AssertionException = "Can't execute the assertion" + + private case class ValuePickerException(message: String) extends RuntimeException(message) + private case class ConstraintAssertionException(message: String) extends RuntimeException(message) +} diff --git a/src/main/scala/com/amazon/deequ/constraints/Constraint.scala b/src/main/scala/com/amazon/deequ/constraints/Constraint.scala index 5bb8d477e..df020eb2f 100644 --- a/src/main/scala/com/amazon/deequ/constraints/Constraint.scala +++ b/src/main/scala/com/amazon/deequ/constraints/Constraint.scala @@ -17,6 +17,7 @@ package com.amazon.deequ.constraints import com.amazon.deequ.analyzers._ +import com.amazon.deequ.anomalydetection.{AnomalyDetectionAssertionResult, AnomalyDetectionExtendedResult} import com.amazon.deequ.metrics.BucketDistribution import com.amazon.deequ.metrics.Distribution import com.amazon.deequ.metrics.Metric @@ -30,11 +31,23 @@ object ConstraintStatus extends Enumeration { val Success, Failure = Value } +/** + * ConstraintResult Class + * + * @param constraint Constraint associated with result. + * @param status Status of constraint (Success, Failure). + * @param message Optional message for errors. + * @param metric Optional Metric from calculation. + * @param anomalyDetectionExtendedResultOption optional anomaly detection extended results + * if using anomaly detection with extended results. 
+ + */ case class ConstraintResult( constraint: Constraint, status: ConstraintStatus.Value, message: Option[String] = None, - metric: Option[Metric[_]] = None) + metric: Option[Metric[_]] = None, + anomalyDetectionExtendedResultOption: Option[AnomalyDetectionExtendedResult] = None) /** Common trait for all data quality constraints */ trait Constraint extends Serializable { @@ -234,6 +247,28 @@ object Constraint { new NamedConstraint(constraint, s"AnomalyConstraint($analyzer)") } + /** + * Runs Completeness analysis on the given column and executes the anomaly assertion + * and also returns extended results. + * + * @param analyzer Analyzer for the metric to do Anomaly Detection on. + * @param anomalyAssertion Function that receives a double input parameter + * (since the metric is double metric) and returns an AnomalyDetectionAssertionResult + * which contains a boolean and anomaly extended results. + * @param hint A hint to provide additional context why a constraint could have failed. + */ + def anomalyConstraintWithExtendedResults[S <: State[S]]( + analyzer: Analyzer[S, Metric[Double]], + anomalyAssertion: Double => AnomalyDetectionAssertionResult, + hint: Option[String] = None) + : Constraint = { + + val constraint = AnomalyExtendedResultsConstraint[S, Double, Double](analyzer, anomalyAssertion, + hint = hint) + + new NamedConstraint(constraint, s"AnomalyConstraintWithExtendedResults($analyzer)") + } + /** * Runs Uniqueness analysis on the given columns and executes the assertion * diff --git a/src/main/scala/com/amazon/deequ/examples/AnomalyDetectionWithExtendedResultsExample.scala b/src/main/scala/com/amazon/deequ/examples/AnomalyDetectionWithExtendedResultsExample.scala new file mode 100644 index 000000000..dd73b006b --- /dev/null +++ b/src/main/scala/com/amazon/deequ/examples/AnomalyDetectionWithExtendedResultsExample.scala @@ -0,0 +1,98 @@ +/** + * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not + * use this file except in compliance with the License. A copy of the License + * is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + * + */ + +package com.amazon.deequ.examples + +import com.amazon.deequ.VerificationSuite +import com.amazon.deequ.analyzers.Size +import com.amazon.deequ.anomalydetection.RelativeRateOfChangeStrategy +import com.amazon.deequ.checks.CheckStatus._ +import com.amazon.deequ.examples.ExampleUtils.{itemsAsDataframe, withSpark} +import com.amazon.deequ.repository.ResultKey +import com.amazon.deequ.repository.memory.InMemoryMetricsRepository + +private[examples] object AnomalyDetectionWithExtendedResultsExample extends App { + + withSpark { session => + + /* In this simple example, we assume that we compute metrics on a dataset every day and we want + to ensure that they don't change drastically. 
For sake of simplicity, we just look at the + size of the data */ + + /* Anomaly detection operates on metrics stored in a metric repository, so lets create one */ + val metricsRepository = new InMemoryMetricsRepository() + + /* This is the key which we use to store the metrics for the dataset from yesterday */ + val yesterdaysKey = ResultKey(System.currentTimeMillis() - 24 * 60 * 1000) + + /* Yesterday, the data had only two rows */ + val yesterdaysDataset = itemsAsDataframe(session, + Item(1, "Thingy A", "awesome thing.", "high", 0), + Item(2, "Thingy B", "available at http://thingb.com", null, 0)) + + /* We test for anomalies in the size of the data, it should not increase by more than 2x. Note + that we store the resulting metrics in our repository */ + VerificationSuite() + .onData(yesterdaysDataset) + .useRepository(metricsRepository) + .saveOrAppendResult(yesterdaysKey) + .addAnomalyCheckWithExtendedResults( + RelativeRateOfChangeStrategy(maxRateIncrease = Some(2.0)), + Size() + ) + .run() + + /* Todays data has five rows, so the data size more than doubled and our anomaly check should + catch this */ + val todaysDataset = itemsAsDataframe(session, + Item(1, "Thingy A", "awesome thing.", "high", 0), + Item(2, "Thingy B", "available at http://thingb.com", null, 0), + Item(3, null, null, "low", 5), + Item(4, "Thingy D", "checkout https://thingd.ca", "low", 10), + Item(5, "Thingy E", null, "high", 12)) + + /* The key for today's result */ + val todaysKey = ResultKey(System.currentTimeMillis()) + + /* Repeat the anomaly check for today's data */ + val verificationResult = VerificationSuite() + .onData(todaysDataset) + .useRepository(metricsRepository) + .saveOrAppendResult(todaysKey) + .addAnomalyCheckWithExtendedResults( + RelativeRateOfChangeStrategy(maxRateIncrease = Some(2.0)), + Size() + ) + .run() + + /* Did we find an anomaly? */ + if (verificationResult.status != Success) { + println("Anomaly detected in the Size() metric!") + val anomalyDetectionDataPoint = verificationResult.checkResults.head._2.constraintResults. + head.anomalyDetectionExtendedResultOption.get.anomalyDetectionDataPoints.head + println(s"Rate of change of ${anomalyDetectionDataPoint.anomalyMetricValue} was not in " + + s"${anomalyDetectionDataPoint.anomalyThreshold}") + + /* Lets have a look at the actual metrics. 
*/ + metricsRepository + .load() + .forAnalyzers(Seq(Size())) + .getSuccessMetricsAsDataFrame(session) + .show() + } + } + +} diff --git a/src/test/scala/com/amazon/deequ/VerificationSuiteTest.scala b/src/test/scala/com/amazon/deequ/VerificationSuiteTest.scala index e260d2f18..25a623c2a 100644 --- a/src/test/scala/com/amazon/deequ/VerificationSuiteTest.scala +++ b/src/test/scala/com/amazon/deequ/VerificationSuiteTest.scala @@ -18,7 +18,7 @@ package com.amazon.deequ import com.amazon.deequ.analyzers._ import com.amazon.deequ.analyzers.runners.AnalyzerContext -import com.amazon.deequ.anomalydetection.AbsoluteChangeStrategy +import com.amazon.deequ.anomalydetection.{AbsoluteChangeStrategy, AnomalyDetectionDataPoint, AnomalyDetectionExtendedResult, Bound, Threshold} import com.amazon.deequ.checks.Check import com.amazon.deequ.checks.CheckLevel import com.amazon.deequ.checks.CheckStatus @@ -681,6 +681,138 @@ class VerificationSuiteTest extends WordSpec with Matchers with SparkContextSpec } } + "addAnomalyCheckWithExtendedResults should work and output extended results" in withSparkSession { sparkSession => + evaluateWithRepositoryWithHistory { repository => + + val df = getDfWithNRows(sparkSession, 11) + val saveResultsWithKey = ResultKey(5, Map.empty) + + val analyzers = Completeness("item") :: Nil + + val verificationResultOne = VerificationSuite() + .onData(df) + .useRepository(repository) + .addRequiredAnalyzers(analyzers) + .saveOrAppendResult(saveResultsWithKey) + .addAnomalyCheckWithExtendedResults( + AbsoluteChangeStrategy(Some(-2.0), Some(2.0)), + Size(), + Some(AnomalyCheckConfig(CheckLevel.Warning, "Anomaly check to fail")) + ) + .run() + + val verificationResultTwo = VerificationSuite() + .onData(df) + .useRepository(repository) + .addRequiredAnalyzers(analyzers) + .saveOrAppendResult(saveResultsWithKey) + .addAnomalyCheckWithExtendedResults( + AbsoluteChangeStrategy(Some(-7.0), Some(7.0)), + Size(), + Some(AnomalyCheckConfig(CheckLevel.Error, "Anomaly check to succeed", + Map.empty, Some(0), Some(11))) + ) + .run() + + val verificationResultThree = VerificationSuite() + .onData(df) + .useRepository(repository) + .addRequiredAnalyzers(analyzers) + .saveOrAppendResult(saveResultsWithKey) + .addAnomalyCheckWithExtendedResults( + AbsoluteChangeStrategy(Some(-7.0), Some(7.0)), + Size() + ) + .run() + + val checkResultsOne = verificationResultOne.checkResults.head._2.status + val actualResultsOneAnomalyDetectionDataPoint = + verificationResultOne.checkResults.head._2.constraintResults.head + .anomalyDetectionExtendedResultOption.get.anomalyDetectionDataPoints.head + val expectedResultsOneAnomalyDetectionDataPoint = + AnomalyDetectionDataPoint(11.0, 7.0, Threshold(Bound(-2.0), Bound(2.0)), isAnomaly = true, 1.0) + + val checkResultsTwo = verificationResultTwo.checkResults.head._2.status + val actualResultsTwoAnomalyDetectionDataPoint = + verificationResultTwo.checkResults.head._2.constraintResults.head + .anomalyDetectionExtendedResultOption.get.anomalyDetectionDataPoints.head + val expectedResultsTwoAnomalyDetectionDataPoint = + AnomalyDetectionDataPoint(11.0, 0.0, Threshold(Bound(-7.0), Bound(7.0)), isAnomaly = false, 1.0) + + val checkResultsThree = verificationResultThree.checkResults.head._2.status + val actualResultsThreeAnomalyDetectionDataPoint = + verificationResultThree.checkResults.head._2.constraintResults.head + .anomalyDetectionExtendedResultOption.get.anomalyDetectionDataPoints.head + val expectedResultsThreeAnomalyDetectionDataPoint = + AnomalyDetectionDataPoint(11.0, 0.0, 
Threshold(Bound(-7.0), Bound(7.0)), isAnomaly = false, 1.0) + + assert(checkResultsOne == CheckStatus.Warning) + assert(checkResultsTwo == CheckStatus.Success) + assert(checkResultsThree == CheckStatus.Success) + + assert(actualResultsOneAnomalyDetectionDataPoint == expectedResultsOneAnomalyDetectionDataPoint) + assert(actualResultsTwoAnomalyDetectionDataPoint == expectedResultsTwoAnomalyDetectionDataPoint) + assert(actualResultsThreeAnomalyDetectionDataPoint == expectedResultsThreeAnomalyDetectionDataPoint) + } + } + + "addAnomalyCheckWithExtendedResults with duplicate check analyzer should work and output extended results" in + withSparkSession { sparkSession => + evaluateWithRepositoryWithHistory { repository => + + val df = getDfWithNRows(sparkSession, 11) + val saveResultsWithKey = ResultKey(5, Map.empty) + + val analyzers = Completeness("item") :: Nil + + val verificationResultOne = VerificationSuite() + .onData(df) + .addCheck(Check(CheckLevel.Error, "group-1").hasSize(_ == 11)) + .useRepository(repository) + .addRequiredAnalyzers(analyzers) + .saveOrAppendResult(saveResultsWithKey) + .addAnomalyCheckWithExtendedResults( + AbsoluteChangeStrategy(Some(-2.0), Some(2.0)), + Size(), + Some(AnomalyCheckConfig(CheckLevel.Warning, "Anomaly check to fail")) + ) + .run() + + val verificationResultTwo = VerificationSuite() + .onData(df) + .useRepository(repository) + .addRequiredAnalyzers(analyzers) + .saveOrAppendResult(saveResultsWithKey) + .addAnomalyCheckWithExtendedResults( + AbsoluteChangeStrategy(Some(-7.0), Some(7.0)), + Size(), + Some(AnomalyCheckConfig(CheckLevel.Error, "Anomaly check to succeed", + Map.empty, Some(0), Some(11))) + ) + .run() + + val checkResultsOne = verificationResultOne.checkResults.values.toSeq(1).status + val actualResultsOneAnomalyDetectionDataPoint = + verificationResultOne.checkResults.values.toSeq(1).constraintResults.head + .anomalyDetectionExtendedResultOption.get.anomalyDetectionDataPoints.head + val expectedResultsOneAnomalyDetectionDataPoint = + AnomalyDetectionDataPoint(11.0, 7.0, Threshold(Bound(-2.0), Bound(2.0)), isAnomaly = true, 1.0) + + val checkResultsTwo = verificationResultTwo.checkResults.head._2.status + val actualResultsTwoAnomalyDetectionDataPoint = + verificationResultTwo.checkResults.head._2.constraintResults.head + .anomalyDetectionExtendedResultOption.get.anomalyDetectionDataPoints.head + val expectedResultsTwoAnomalyDetectionDataPoint = + AnomalyDetectionDataPoint(11.0, 0.0, Threshold(Bound(-7.0), Bound(7.0)), isAnomaly = false, 1.0) + + assert(checkResultsOne == CheckStatus.Warning) + assert(checkResultsTwo == CheckStatus.Success) + + assert(actualResultsOneAnomalyDetectionDataPoint == expectedResultsOneAnomalyDetectionDataPoint) + assert(actualResultsTwoAnomalyDetectionDataPoint == expectedResultsTwoAnomalyDetectionDataPoint) + } + } + "write output files to specified locations" in withSparkSession { sparkSession => val df = getDfWithNumericValues(sparkSession) diff --git a/src/test/scala/com/amazon/deequ/anomalydetection/AbsoluteChangeStrategyTest.scala b/src/test/scala/com/amazon/deequ/anomalydetection/AbsoluteChangeStrategyTest.scala index 66d3c737a..f970f9812 100644 --- a/src/test/scala/com/amazon/deequ/anomalydetection/AbsoluteChangeStrategyTest.scala +++ b/src/test/scala/com/amazon/deequ/anomalydetection/AbsoluteChangeStrategyTest.scala @@ -23,14 +23,7 @@ class AbsoluteChangeStrategyTest extends WordSpec with Matchers { "Absolute Change Strategy" should { - val strategy = AbsoluteChangeStrategy(Some(-2.0), Some(2.0)) - val 
data = (for (i <- 0 to 50) yield { - if (i < 20 || i > 30) { - 1.0 - } else { - if (i % 2 == 0) i else -i - } - }).toVector + val (strategy, data) = setupDefaultStrategyAndData() "detect all anomalies if no interval specified" in { val anomalyResult = strategy.detect(data) @@ -156,7 +149,166 @@ class AbsoluteChangeStrategyTest extends WordSpec with Matchers { assert(value < lowerBound || value > upperBound) } } + } + + + "Absolute Change Strategy using Extended Results" should { + + val (strategy, data) = setupDefaultStrategyAndData() + + "detect all anomalies if no interval specified" in { + val anomalyResult = strategy.detectWithExtendedResults(data).filter({case (_, anom) => anom.isAnomaly}) + val expectedAnomalyThreshold = Threshold(Bound(-2.0), Bound(2.0)) + val expectedResult: Seq[(Int, AnomalyDetectionDataPoint)] = Seq( + (20, AnomalyDetectionDataPoint(20, 19, expectedAnomalyThreshold, isAnomaly = true, 1.0)), + (21, AnomalyDetectionDataPoint(-21, -41, expectedAnomalyThreshold, isAnomaly = true, 1.0)), + (22, AnomalyDetectionDataPoint(22, 43, expectedAnomalyThreshold, isAnomaly = true, 1.0)), + (23, AnomalyDetectionDataPoint(-23, -45, expectedAnomalyThreshold, isAnomaly = true, 1.0)), + (24, AnomalyDetectionDataPoint(24, 47, expectedAnomalyThreshold, isAnomaly = true, 1.0)), + (25, AnomalyDetectionDataPoint(-25, -49, expectedAnomalyThreshold, isAnomaly = true, 1.0)), + (26, AnomalyDetectionDataPoint(26, 51, expectedAnomalyThreshold, isAnomaly = true, 1.0)), + (27, AnomalyDetectionDataPoint(-27, -53, expectedAnomalyThreshold, isAnomaly = true, 1.0)), + (28, AnomalyDetectionDataPoint(28, 55, expectedAnomalyThreshold, isAnomaly = true, 1.0)), + (29, AnomalyDetectionDataPoint(-29, -57, expectedAnomalyThreshold, isAnomaly = true, 1.0)), + (30, AnomalyDetectionDataPoint(30, 59, expectedAnomalyThreshold, isAnomaly = true, 1.0)), + (31, AnomalyDetectionDataPoint(1, -29, expectedAnomalyThreshold, isAnomaly = true, 1.0)) + ) + assert(anomalyResult == expectedResult) + } + + "only detect anomalies in interval" in { + val anomalyResult = strategy.detectWithExtendedResults(data, (25, 50)).filter({case (_, anom) => anom.isAnomaly}) + val expectedAnomalyThreshold = Threshold(Bound(-2.0), Bound(2.0)) + val expectedResult: Seq[(Int, AnomalyDetectionDataPoint)] = Seq( + (25, AnomalyDetectionDataPoint(-25, -49, expectedAnomalyThreshold, isAnomaly = true, 1.0)), + (26, AnomalyDetectionDataPoint(26, 51, expectedAnomalyThreshold, isAnomaly = true, 1.0)), + (27, AnomalyDetectionDataPoint(-27, -53, expectedAnomalyThreshold, isAnomaly = true, 1.0)), + (28, AnomalyDetectionDataPoint(28, 55, expectedAnomalyThreshold, isAnomaly = true, 1.0)), + (29, AnomalyDetectionDataPoint(-29, -57, expectedAnomalyThreshold, isAnomaly = true, 1.0)), + (30, AnomalyDetectionDataPoint(30, 59, expectedAnomalyThreshold, isAnomaly = true, 1.0)), + (31, AnomalyDetectionDataPoint(1, -29, expectedAnomalyThreshold, isAnomaly = true, 1.0)) + ) + assert(anomalyResult == expectedResult) + } + + "ignore min rate if none is given" in { + val strategy = AbsoluteChangeStrategy(None, Some(1.0)) + val anomalyResult = strategy.detectWithExtendedResults(data).filter({case (_, anom) => anom.isAnomaly}) + val expectedAnomalyThreshold = Threshold(upperBound = Bound(1.0)) + // Anomalies with positive values only + val expectedResult: Seq[(Int, AnomalyDetectionDataPoint)] = Seq( + (20, AnomalyDetectionDataPoint(20, 19, expectedAnomalyThreshold, isAnomaly = true, 1.0)), + (22, AnomalyDetectionDataPoint(22, 43, expectedAnomalyThreshold, isAnomaly = 
true, 1.0)), + (24, AnomalyDetectionDataPoint(24, 47, expectedAnomalyThreshold, isAnomaly = true, 1.0)), + (26, AnomalyDetectionDataPoint(26, 51, expectedAnomalyThreshold, isAnomaly = true, 1.0)), + (28, AnomalyDetectionDataPoint(28, 55, expectedAnomalyThreshold, isAnomaly = true, 1.0)), + (30, AnomalyDetectionDataPoint(30, 59, expectedAnomalyThreshold, isAnomaly = true, 1.0)) + ) + + assert(anomalyResult == expectedResult) + } + + "ignore max rate if none is given" in { + val strategy = AbsoluteChangeStrategy(Some(-1.0), None) + val anomalyResult = strategy.detectWithExtendedResults(data).filter({case (_, anom) => anom.isAnomaly}) + val expectedAnomalyThreshold = Threshold(lowerBound = Bound(-1.0)) + + // Anomalies with negative values only + val expectedResult: Seq[(Int, AnomalyDetectionDataPoint)] = Seq( + (21, AnomalyDetectionDataPoint(-21, -41, expectedAnomalyThreshold, isAnomaly = true, 1.0)), + (23, AnomalyDetectionDataPoint(-23, -45, expectedAnomalyThreshold, isAnomaly = true, 1.0)), + (25, AnomalyDetectionDataPoint(-25, -49, expectedAnomalyThreshold, isAnomaly = true, 1.0)), + (27, AnomalyDetectionDataPoint(-27, -53, expectedAnomalyThreshold, isAnomaly = true, 1.0)), + (29, AnomalyDetectionDataPoint(-29, -57, expectedAnomalyThreshold, isAnomaly = true, 1.0)), + (31, AnomalyDetectionDataPoint(1, -29, expectedAnomalyThreshold, isAnomaly = true, 1.0)) + ) + assert(anomalyResult == expectedResult) + } + + "detect no anomalies if rates are set to min/ max value" in { + val strategy = AbsoluteChangeStrategy(Some(Double.MinValue), Some(Double.MaxValue)) + val anomalyResult = strategy.detectWithExtendedResults(data).filter({case (_, anom) => anom.isAnomaly}) + + val expectedResult: List[(Int, AnomalyDetectionDataPoint)] = List() + assert(anomalyResult == expectedResult) + } + + "attribute indices correctly for higher orders without search interval" in { + val data = Vector(0.0, 1.0, 3.0, 6.0, 18.0, 72.0) + val strategy = AbsoluteChangeStrategy(None, Some(8.0), order = 2) + val result = strategy.detectWithExtendedResults(data).filter({ case (_, anom) => anom.isAnomaly }) + + val expectedResult = Seq( + (4, AnomalyDetectionDataPoint(18.0, 9.0, Threshold(upperBound = Bound(8.0)), isAnomaly = true, 1.0)), + (5, AnomalyDetectionDataPoint(72.0, 42.0, Threshold(upperBound = Bound(8.0)), isAnomaly = true, 1.0)) + ) + assert(result == expectedResult) + } + + "attribute indices correctly for higher orders with search interval" in { + val data = Vector(0.0, 1.0, 3.0, 6.0, 18.0, 72.0) + val strategy = AbsoluteChangeStrategy(None, Some(8.0), order = 2) + val result = strategy.detectWithExtendedResults(data, (5, 6)).filter({case (_, anom) => anom.isAnomaly}) + + val expectedResult = Seq( + (5, AnomalyDetectionDataPoint(72.0, 42.0, Threshold(upperBound = Bound(8.0)), isAnomaly = true, 1.0)) + ) + assert(result == expectedResult) + } + + "behave like the threshold strategy when order is 0" in { + val data = Vector(1.0, -1.0, 4.0, -7.0) + val result = strategy.detectWithExtendedResults(data).filter({case (_, anom) => anom.isAnomaly}) + + val expectedResult = Seq( + (2, AnomalyDetectionDataPoint(4.0, 5.0, Threshold(Bound(-2.0), Bound(2.0)), isAnomaly = true, 1.0)), + (3, AnomalyDetectionDataPoint(-7.0, -11.0, Threshold(Bound(-2.0), Bound(2.0)), isAnomaly = true, 1.0)) + ) + assert(result == expectedResult) + } + + + "work fine with empty input" in { + val emptySeries = Vector[Double]() + val anomalyResult = strategy.detectWithExtendedResults(emptySeries).filter({case (_, anom) => anom.isAnomaly}) + + 
assert(anomalyResult == Seq[(Int, AnomalyDetectionDataPoint)]()) + } + + "produce error message with correct value and bounds" in { + val result = strategy.detectWithExtendedResults(data).filter({case (_, anom) => anom.isAnomaly}) + + result.foreach { case (_, anom) => + val (value, lowerBound, upperBound) = + AnomalyDetectionTestUtils.firstThreeDoublesFromString(anom.detail.get) + + assert(value === anom.anomalyMetricValue) + assert(value < lowerBound || value > upperBound) + } + } + "assert anomalies are outside of anomaly bounds" in { + val result = strategy.detectWithExtendedResults(data).filter({ case (_, anom) => anom.isAnomaly }) + + result.foreach { case (_, anom) => + val value = anom.anomalyMetricValue + val upperBound = anom.anomalyThreshold.upperBound.value + val lowerBound = anom.anomalyThreshold.lowerBound.value + + assert(value < lowerBound || value > upperBound) + } + } + } + private def setupDefaultStrategyAndData(): (AbsoluteChangeStrategy, Vector[Double]) = { + val strategy = AbsoluteChangeStrategy(Some(-2.0), Some(2.0)) + val data = (for (i <- 0 to 50) yield { + if (i < 20 || i > 30) { + 1.0 + } else { + if (i % 2 == 0) i else -i + } + }).toVector + (strategy, data) } } diff --git a/src/test/scala/com/amazon/deequ/anomalydetection/AnomalyDetectorTest.scala b/src/test/scala/com/amazon/deequ/anomalydetection/AnomalyDetectorTest.scala index 08f411bd1..6068d111b 100644 --- a/src/test/scala/com/amazon/deequ/anomalydetection/AnomalyDetectorTest.scala +++ b/src/test/scala/com/amazon/deequ/anomalydetection/AnomalyDetectorTest.scala @@ -21,14 +21,15 @@ import org.scalatest.{Matchers, PrivateMethodTester, WordSpec} class AnomalyDetectorTest extends WordSpec with Matchers with MockFactory with PrivateMethodTester { - private val fakeAnomalyDetector = stub[AnomalyDetectionStrategy] - val aD = AnomalyDetector(fakeAnomalyDetector) - val data = Seq((0L, -1.0), (1L, 2.0), (2L, 3.0), (3L, 0.5)).map { case (t, v) => - DataPoint[Double](t, Option(v)) - } "Anomaly Detector" should { + val fakeAnomalyDetector = stub[AnomalyDetectionStrategy] + + val aD = AnomalyDetector(fakeAnomalyDetector) + val data = Seq((0L, -1.0), (1L, 2.0), (2L, 3.0), (3L, 0.5)).map { case (t, v) => + DataPoint[Double](t, Option(v)) + } "ignore missing values" in { val data = Seq(DataPoint[Double](0L, Option(1.0)), DataPoint[Double](1L, Option(2.0)), @@ -105,4 +106,105 @@ class AnomalyDetectorTest extends WordSpec with Matchers with MockFactory with P } } + + "Anomaly Detector with ExtendedResults" should { + + val fakeAnomalyDetector = stub[AnomalyDetectionStrategyWithExtendedResults] + + val aD = AnomalyDetectorWithExtendedResults(fakeAnomalyDetector) + val data = Seq((0L, -1.0), (1L, 2.0), (2L, 3.0), (3L, 0.5)).map { case (t, v) => + DataPoint[Double](t, Option(v)) + } + + "ignore missing values" in { + val data = Seq(DataPoint[Double](0L, Option(1.0)), DataPoint[Double](1L, Option(2.0)), + DataPoint[Double](2L, None), DataPoint[Double](3L, Option(1.0))) + + (fakeAnomalyDetector.detectWithExtendedResults _ when(Vector(1.0, 2.0, 1.0), (0, 3))) + .returns(Seq((1, AnomalyDetectionDataPoint(2.0, 2.0, Threshold(), confidence = 1.0)))) + + val anomalyResult = aD.detectAnomaliesInHistoryWithExtendedResults(data, (0L, 4L)) + + assert(anomalyResult == ExtendedDetectionResult(Seq((1L, AnomalyDetectionDataPoint(2.0, 2.0, confidence = 1.0))))) + } + + "only detect values in range" in { + (fakeAnomalyDetector.detectWithExtendedResults _ when(Vector(-1.0, 2.0, 3.0, 0.5), (2, 4))) + .returns(Seq((2, AnomalyDetectionDataPoint(3.0, 
3.0, confidence = 1.0)))) + + val anomalyResult = aD.detectAnomaliesInHistoryWithExtendedResults(data, (2L, 4L)) + + assert(anomalyResult == ExtendedDetectionResult(Seq((2L, AnomalyDetectionDataPoint(3.0, 3.0, confidence = 1.0))))) + } + + "throw an error when intervals are not ordered" in { + intercept[IllegalArgumentException] { + aD.detectAnomaliesInHistoryWithExtendedResults(data, (4, 2)) + } + } + + "treat ordered values with time gaps correctly" in { + val data = (for (i <- 1 to 10) yield { + (i.toLong * 200L) -> 5.0 + }).map { case (t, v) => + DataPoint[Double](t, Option(v)) + } + + (fakeAnomalyDetector.detectWithExtendedResults _ when(data.map(_.metricValue.get).toVector, (0, 2))) + .returns ( + Seq( + (0, AnomalyDetectionDataPoint(5.0, 5.0, confidence = 1.0)), + (1, AnomalyDetectionDataPoint(5.0, 5.0, confidence = 1.0)) + ) + ) + + val anomalyResult = aD.detectAnomaliesInHistoryWithExtendedResults(data, (200L, 401L)) + + assert(anomalyResult == ExtendedDetectionResult(Seq( + (200L, AnomalyDetectionDataPoint(5.0, 5.0, confidence = 1.0)), + (400L, AnomalyDetectionDataPoint(5.0, 5.0, confidence = 1.0))))) + } + + "treat unordered values with time gaps correctly" in { + val data = Seq((10L, -1.0), (25L, 2.0), (11L, 3.0), (0L, 0.5)).map { case (t, v) => + DataPoint[Double](t, Option(v)) + } + val tS = AnomalyDetector(SimpleThresholdStrategy(lowerBound = -0.5, upperBound = 1.0)) + + (fakeAnomalyDetector.detectWithExtendedResults _ when(Vector(0.5, -1.0, 3.0, 2.0), (0, 4))) + .returns( + Seq( + (1, AnomalyDetectionDataPoint(-1.0, -1.0, confidence = 1.0)), + (2, AnomalyDetectionDataPoint(3.0, 3.0, confidence = 1.0)), + (3, AnomalyDetectionDataPoint(2.0, 2.0, confidence = 1.0)) + ) + ) + + val anomalyResult = aD.detectAnomaliesInHistoryWithExtendedResults(data) + + assert(anomalyResult == ExtendedDetectionResult( + Seq((10L, AnomalyDetectionDataPoint(-1.0, -1.0, confidence = 1.0)), + (11L, AnomalyDetectionDataPoint(3.0, 3.0, confidence = 1.0)), + (25L, AnomalyDetectionDataPoint(2.0, 2.0, confidence = 1.0))))) + } + + "treat unordered values without time gaps correctly" in { + val data = Seq((1L, -1.0), (3L, 2.0), (2L, 3.0), (0L, 0.5)).map { case (t, v) => + DataPoint[Double](t, Option(v)) + } + + (fakeAnomalyDetector.detectWithExtendedResults _ when(Vector(0.5, -1.0, 3.0, 2.0), (0, 4))) + .returns(Seq((1, AnomalyDetectionDataPoint(-1.0, -1.0, confidence = 1.0)), + (2, AnomalyDetectionDataPoint(3.0, 3.0, confidence = 1.0)), + (3, AnomalyDetectionDataPoint(2.0, 2.0, confidence = 1.0)))) + + val anomalyResult = aD.detectAnomaliesInHistoryWithExtendedResults(data) + + assert(anomalyResult == ExtendedDetectionResult(Seq( + (1L, AnomalyDetectionDataPoint(-1.0, -1.0, confidence = 1.0)), + (2L, AnomalyDetectionDataPoint(3.0, 3.0, confidence = 1.0)), + (3L, AnomalyDetectionDataPoint(2.0, 2.0, confidence = 1.0))))) + } + + } } diff --git a/src/test/scala/com/amazon/deequ/anomalydetection/BatchNormalStrategyTest.scala b/src/test/scala/com/amazon/deequ/anomalydetection/BatchNormalStrategyTest.scala index 05b9a6272..0575ad3f7 100644 --- a/src/test/scala/com/amazon/deequ/anomalydetection/BatchNormalStrategyTest.scala +++ b/src/test/scala/com/amazon/deequ/anomalydetection/BatchNormalStrategyTest.scala @@ -24,19 +24,7 @@ class BatchNormalStrategyTest extends WordSpec with Matchers { "Batch Normal Strategy" should { - val strategy = - BatchNormalStrategy(lowerDeviationFactor = Some(1.0), upperDeviationFactor = Some(1.0)) - - val r = new Random(1) - val dist = (for (_ <- 0 to 49) yield { - r.nextGaussian() 
- }).toArray - - for (i <- 20 to 30) { - dist(i) += i + (i % 2 * -2 * i) - } - - val data = dist.toVector + val (strategy, data) = setupDefaultStrategyAndData() "only detect anomalies in interval" in { val anomalyResult = strategy.detect(data, (25, 50)) @@ -120,4 +108,131 @@ class BatchNormalStrategyTest extends WordSpec with Matchers { } } } + + "Batch Normal Strategy using Extended Results " should { + + val (strategy, data) = setupDefaultStrategyAndData() + + "only detect anomalies in interval" in { + val anomalyResult = + strategy.detectWithExtendedResults(data, (25, 50)).filter({ case (_, anom) => anom.isAnomaly }) + + val expectedAnomalyThreshold = Threshold(Bound(-9.280850004177061), Bound(10.639954755150061)) + val expectedResult: Seq[(Int, AnomalyDetectionDataPoint)] = Seq( + (25, AnomalyDetectionDataPoint(data(25), data(25), expectedAnomalyThreshold, isAnomaly = true, 1.0)), + (26, AnomalyDetectionDataPoint(data(26), data(26), expectedAnomalyThreshold, isAnomaly = true, 1.0)), + (27, AnomalyDetectionDataPoint(data(27), data(27), expectedAnomalyThreshold, isAnomaly = true, 1.0)), + (28, AnomalyDetectionDataPoint(data(28), data(28), expectedAnomalyThreshold, isAnomaly = true, 1.0)), + (29, AnomalyDetectionDataPoint(data(29), data(29), expectedAnomalyThreshold, isAnomaly = true, 1.0)), + (30, AnomalyDetectionDataPoint(data(30), data(30), expectedAnomalyThreshold, isAnomaly = true, 1.0)) + ) + assert(anomalyResult == expectedResult) + } + + "ignore lower factor if none is given" in { + val strategy = BatchNormalStrategy(None, Some(1.0)) + val anomalyResult = + strategy.detectWithExtendedResults(data, (20, 31)).filter({ case (_, anom) => anom.isAnomaly }) + + val expectedAnomalyThreshold = Threshold(Bound(Double.NegativeInfinity), Bound(0.7781496015857838)) + // Anomalies with positive values only + val expectedResult: Seq[(Int, AnomalyDetectionDataPoint)] = Seq( + (20, AnomalyDetectionDataPoint(data(20), data(20), expectedAnomalyThreshold, isAnomaly = true, 1.0)), + (22, AnomalyDetectionDataPoint(data(22), data(22), expectedAnomalyThreshold, isAnomaly = true, 1.0)), + (24, AnomalyDetectionDataPoint(data(24), data(24), expectedAnomalyThreshold, isAnomaly = true, 1.0)), + (26, AnomalyDetectionDataPoint(data(26), data(26), expectedAnomalyThreshold, isAnomaly = true, 1.0)), + (28, AnomalyDetectionDataPoint(data(28), data(28), expectedAnomalyThreshold, isAnomaly = true, 1.0)), + (30, AnomalyDetectionDataPoint(data(30), data(30), expectedAnomalyThreshold, isAnomaly = true, 1.0)) + ) + assert(anomalyResult == expectedResult) + } + + "ignore upper factor if none is given" in { + val strategy = BatchNormalStrategy(Some(1.0), None) + val anomalyResult = + strategy.detectWithExtendedResults(data, (10, 30)).filter({ case (_, anom) => anom.isAnomaly }) + val expectedAnomalyThreshold = Threshold(Bound(-5.063730045618394), Bound(Double.PositiveInfinity)) + + // Anomalies with negative values only + val expectedResult: Seq[(Int, AnomalyDetectionDataPoint)] = Seq( + (21, AnomalyDetectionDataPoint(data(21), data(21), expectedAnomalyThreshold, isAnomaly = true, 1.0)), + (23, AnomalyDetectionDataPoint(data(23), data(23), expectedAnomalyThreshold, isAnomaly = true, 1.0)), + (25, AnomalyDetectionDataPoint(data(25), data(25), expectedAnomalyThreshold, isAnomaly = true, 1.0)), + (27, AnomalyDetectionDataPoint(data(27), data(27), expectedAnomalyThreshold, isAnomaly = true, 1.0)), + (29, AnomalyDetectionDataPoint(data(29), data(29), expectedAnomalyThreshold, isAnomaly = true, 1.0)) + ) + assert(anomalyResult 
== expectedResult) + } + + "ignore values in interval for mean/ stdDev if specified" in { + val data = Vector(1.0, 1.0, 1.0, 1000.0, 500.0, 1.0) + val strategy = BatchNormalStrategy(Some(3.0), Some(3.0)) + val anomalyResult = + strategy.detectWithExtendedResults(data, (3, 5)).filter({ case (_, anom) => anom.isAnomaly }) + + val expectedResult: Seq[(Int, AnomalyDetectionDataPoint)] = Seq( + (3, AnomalyDetectionDataPoint(1000, 1000, Threshold(Bound(1.0), Bound(1.0)), isAnomaly = true, 1.0)), + (4, AnomalyDetectionDataPoint(500, 500, Threshold(Bound(1.0), Bound(1.0)), isAnomaly = true, 1.0)) + ) + assert(anomalyResult == expectedResult) + } + + "throw an exception when trying to exclude all data points from calculation" in { + val strategy = BatchNormalStrategy() + intercept[IllegalArgumentException] { + strategy.detectWithExtendedResults(data).filter({ case (_, anom) => anom.isAnomaly }) + } + } + "detect no anomalies if factors are set to max value" in { + val strategy = BatchNormalStrategy(Some(Double.MaxValue), Some(Double.MaxValue)) + val anomalyResult = + strategy.detectWithExtendedResults(data, (30, 51)).filter({ case (_, anom) => anom.isAnomaly }) + + val expected: List[(Int, AnomalyDetectionDataPoint)] = List() + assert(anomalyResult == expected) + } + + "produce error message with correct value and bounds" in { + val result = strategy.detectWithExtendedResults(data, (25, 50)).filter({ case (_, anom) => anom.isAnomaly }) + + result.foreach { case (_, anom) => + val (value, lowerBound, upperBound) = + AnomalyDetectionTestUtils.firstThreeDoublesFromString(anom.detail.get) + + assert(value === anom.anomalyMetricValue) + assert(value < lowerBound || value > upperBound) + } + } + + "assert anomalies are outside of anomaly bounds" in { + val result = strategy.detectWithExtendedResults(data, (25, 50)).filter({ case (_, anom) => anom.isAnomaly }) + + result.foreach { case (_, anom) => + val value = anom.anomalyMetricValue + val upperBound = anom.anomalyThreshold.upperBound.value + val lowerBound = anom.anomalyThreshold.lowerBound.value + + assert(value < lowerBound || value > upperBound) + } + + + } + } + + private def setupDefaultStrategyAndData(): (BatchNormalStrategy, Vector[Double]) = { + val strategy = + BatchNormalStrategy(lowerDeviationFactor = Some(1.0), upperDeviationFactor = Some(1.0)) + + val r = new Random(1) + val dist = (for (_ <- 0 to 49) yield { + r.nextGaussian() + }).toArray + + for (i <- 20 to 30) { + dist(i) += i + (i % 2 * -2 * i) + } + + val data = dist.toVector + (strategy, data) + } } diff --git a/src/test/scala/com/amazon/deequ/anomalydetection/OnlineNormalStrategyTest.scala b/src/test/scala/com/amazon/deequ/anomalydetection/OnlineNormalStrategyTest.scala index 781ffb7ad..d9fdd4ebc 100644 --- a/src/test/scala/com/amazon/deequ/anomalydetection/OnlineNormalStrategyTest.scala +++ b/src/test/scala/com/amazon/deequ/anomalydetection/OnlineNormalStrategyTest.scala @@ -26,18 +26,8 @@ class OnlineNormalStrategyTest extends WordSpec with Matchers { "Online Normal Strategy" should { - val strategy = OnlineNormalStrategy(lowerDeviationFactor = Some(1.5), - upperDeviationFactor = Some(1.5), ignoreStartPercentage = 0.2) - val r = new Random(1) - - val dist = (for (_ <- 0 to 50) yield { - r.nextGaussian() - }).toArray - - for (i <- 20 to 30) - dist(i) += i + (i % 2 * -2 * i) + val (strategy, data, r) = setupDefaultStrategyAndData() - val data = dist.toVector "detect all anomalies if no interval specified" in { val strategy = OnlineNormalStrategy(lowerDeviationFactor = Some(3.5), @@ 
-168,4 +158,160 @@ class OnlineNormalStrategyTest extends WordSpec with Matchers { } } } + + "Online Normal Strategy with Extended Results" should { + + val (strategy, data, r) = setupDefaultStrategyAndData() + "detect all anomalies if no interval specified" in { + val strategy = OnlineNormalStrategy(lowerDeviationFactor = Some(3.5), + upperDeviationFactor = Some(3.5), ignoreStartPercentage = 0.2) + val anomalyResult = strategy.detectWithExtendedResults(data).filter({case (_, anom) => anom.isAnomaly}) + + val expectedResult: Seq[(Int, AnomalyDetectionDataPoint)] = Seq( + (20, AnomalyDetectionDataPoint(data(20), data(20), + Threshold(Bound(-14.868489924421404), Bound(14.255383455388895)), isAnomaly = true, 1.0)), + (21, AnomalyDetectionDataPoint(data(21), data(21), + Threshold(Bound(-13.6338479733374), Bound(13.02074150430489)), isAnomaly = true, 1.0)), + (22, AnomalyDetectionDataPoint(data(22), data(22), + Threshold(Bound(-16.71733585267535), Bound(16.104229383642842)), isAnomaly = true, 1.0)), + (23, AnomalyDetectionDataPoint(data(23), data(23), + Threshold(Bound(-17.346915620547467), Bound(16.733809151514958)), isAnomaly = true, 1.0)), + (24, AnomalyDetectionDataPoint(data(24), data(24), + Threshold(Bound(-17.496117397890874), Bound(16.883010928858365)), isAnomaly = true, 1.0)), + (25, AnomalyDetectionDataPoint(data(25), data(25), + Threshold(Bound(-17.90391150851199), Bound(17.29080503947948)), isAnomaly = true, 1.0)), + (26, AnomalyDetectionDataPoint(data(26), data(26), + Threshold(Bound(-17.028892797350824), Bound(16.415786328318315)), isAnomaly = true, 1.0)), + (27, AnomalyDetectionDataPoint(data(27), data(27), + Threshold(Bound(-17.720100310354653), Bound(17.106993841322144)), isAnomaly = true, 1.0)), + (28, AnomalyDetectionDataPoint(data(28), data(28), + Threshold(Bound(-18.23663168508628), Bound(17.62352521605377)), isAnomaly = true, 1.0)), + (29, AnomalyDetectionDataPoint(data(29), data(29), + Threshold(Bound(-19.32641622778204), Bound(18.71330975874953)), isAnomaly = true, 1.0)), + (30, AnomalyDetectionDataPoint(data(30), data(30), + Threshold(Bound(-18.96540323993527), Bound(18.35229677090276)), isAnomaly = true, 1.0)) + ) + assert(anomalyResult == expectedResult) + } + + "only detect anomalies in interval" in { + val anomalyResult = strategy.detectWithExtendedResults(data, (25, 31)).filter({case (_, anom) => anom.isAnomaly}) + + val expectedResult: Seq[(Int, AnomalyDetectionDataPoint)] = Seq( + (25, AnomalyDetectionDataPoint(data(25), data(25), + Threshold(Bound(-15.630116599125694), Bound(16.989221350098695)), isAnomaly = true, 1.0)), + (26, AnomalyDetectionDataPoint(data(26), data(26), + Threshold(Bound(-14.963376676338362), Bound(16.322481427311363)), isAnomaly = true, 1.0)), + (27, AnomalyDetectionDataPoint(data(27), data(27), + Threshold(Bound(-15.131834814393196), Bound(16.490939565366197)), isAnomaly = true, 1.0)), + (28, AnomalyDetectionDataPoint(data(28), data(28), + Threshold(Bound(-14.76810451038132), Bound(16.12720926135432)), isAnomaly = true, 1.0)), + (29, AnomalyDetectionDataPoint(data(29), data(29), + Threshold(Bound(-15.078145049879462), Bound(16.437249800852463)), isAnomaly = true, 1.0)), + (30, AnomalyDetectionDataPoint(data(30), data(30), + Threshold(Bound(-14.540171084298914), Bound(15.899275835271913)), isAnomaly = true, 1.0)) + ) + assert(anomalyResult == expectedResult) + } + + "ignore lower factor if none is given" in { + val strategy = OnlineNormalStrategy(lowerDeviationFactor = None, + upperDeviationFactor = Some(1.5)) + val anomalyResult = 
strategy.detectWithExtendedResults(data).filter({case (_, anom) => anom.isAnomaly}) + + // Anomalies with positive values only + val expectedResult: Seq[(Int, AnomalyDetectionDataPoint)] = Seq( + (20, AnomalyDetectionDataPoint(data(20), data(20), + Threshold(Bound(Double.NegativeInfinity), Bound(5.934276775443095)), isAnomaly = true, 1.0)), + (22, AnomalyDetectionDataPoint(data(22), data(22), + Threshold(Bound(Double.NegativeInfinity), Bound(7.979098353666404)), isAnomaly = true, 1.0)), + (24, AnomalyDetectionDataPoint(data(24), data(24), + Threshold(Bound(Double.NegativeInfinity), Bound(9.582136909647211)), isAnomaly = true, 1.0)), + (26, AnomalyDetectionDataPoint(data(26), data(26), + Threshold(Bound(Double.NegativeInfinity), Bound(10.320400087389258)), isAnomaly = true, 1.0)), + (28, AnomalyDetectionDataPoint(data(28), data(28), + Threshold(Bound(Double.NegativeInfinity), Bound(11.113502213504855)), isAnomaly = true, 1.0)), + (30, AnomalyDetectionDataPoint(data(30), data(30), + Threshold(Bound(Double.NegativeInfinity), Bound(11.776810456746686)), isAnomaly = true, 1.0)) + ) + assert(anomalyResult == expectedResult) + } + + "ignore upper factor if none is given" in { + val strategy = OnlineNormalStrategy(lowerDeviationFactor = Some(1.5), + upperDeviationFactor = None) + val anomalyResult = strategy.detectWithExtendedResults(data).filter({case (_, anom) => anom.isAnomaly}) + + // Anomalies with negative values only + val expectedResult: Seq[(Int, AnomalyDetectionDataPoint)] = Seq( + (21, AnomalyDetectionDataPoint(data(21), data(21), + Threshold(Bound(-7.855820681098751), Bound(Double.PositiveInfinity)), isAnomaly = true, 1.0)), + (23, AnomalyDetectionDataPoint(data(23), data(23), + Threshold(Bound(-10.14631437278386), Bound(Double.PositiveInfinity)), isAnomaly = true, 1.0)), + (25, AnomalyDetectionDataPoint(data(25), data(25), + Threshold(Bound(-11.038751996286909), Bound(Double.PositiveInfinity)), isAnomaly = true, 1.0)), + (27, AnomalyDetectionDataPoint(data(27), data(27), + Threshold(Bound(-11.359107787232386), Bound(Double.PositiveInfinity)), isAnomaly = true, 1.0)), + (29, AnomalyDetectionDataPoint(data(29), data(29), + Threshold(Bound(-12.097995027317015), Bound(Double.PositiveInfinity)), isAnomaly = true, 1.0)) + ) + assert(anomalyResult == expectedResult) + } + + "work fine with empty input" in { + val emptySeries = Vector[Double]() + val anomalyResult = strategy.detectWithExtendedResults(emptySeries).filter({case (_, anom) => anom.isAnomaly}) + + assert(anomalyResult == Seq[(Int, AnomalyDetectionDataPoint)]()) + } + + "detect no anomalies if factors are set to max value" in { + val strategy = OnlineNormalStrategy(Some(Double.MaxValue), Some(Double.MaxValue)) + val anomalyResult = strategy.detectWithExtendedResults(data).filter({case (_, anom) => anom.isAnomaly}) + + val expected: List[(Int, AnomalyDetectionDataPoint)] = List() + assert(anomalyResult == expected) + } + + "produce error message with correct value and bounds" in { + val result = strategy.detectWithExtendedResults(data).filter({case (_, anom) => anom.isAnomaly}) + + result.foreach { case (_, anom) => + val (value, lowerBound, upperBound) = + AnomalyDetectionTestUtils.firstThreeDoublesFromString(anom.detail.get) + + assert(value === anom.anomalyMetricValue) + assert(value < lowerBound || value > upperBound) + } + } + + "assert anomalies are outside of anomaly bounds" in { + val result = strategy.detectWithExtendedResults(data).filter({ case (_, anom) => anom.isAnomaly }) + + result.foreach { case (_, anom) => + val 
value = anom.anomalyMetricValue + val upperBound = anom.anomalyThreshold.upperBound.value + val lowerBound = anom.anomalyThreshold.lowerBound.value + + assert(value < lowerBound || value > upperBound) + } + } + } + + + private def setupDefaultStrategyAndData(): (OnlineNormalStrategy, Vector[Double], Random) = { + val strategy = OnlineNormalStrategy(lowerDeviationFactor = Some(1.5), + upperDeviationFactor = Some(1.5), ignoreStartPercentage = 0.2) + val r = new Random(1) + + val dist = (for (_ <- 0 to 50) yield { + r.nextGaussian() + }).toArray + + for (i <- 20 to 30) + dist(i) += i + (i % 2 * -2 * i) + + val data = dist.toVector + (strategy, data, r) + } } diff --git a/src/test/scala/com/amazon/deequ/anomalydetection/RateOfChangeStrategyTest.scala b/src/test/scala/com/amazon/deequ/anomalydetection/RateOfChangeStrategyTest.scala index 70f66f033..d0e6ccba9 100644 --- a/src/test/scala/com/amazon/deequ/anomalydetection/RateOfChangeStrategyTest.scala +++ b/src/test/scala/com/amazon/deequ/anomalydetection/RateOfChangeStrategyTest.scala @@ -26,14 +26,7 @@ class RateOfChangeStrategyTest extends WordSpec with Matchers { "RateOfChange Strategy" should { - val strategy = RateOfChangeStrategy(Some(-2.0), Some(2.0)) - val data = (for (i <- 0 to 50) yield { - if (i < 20 || i > 30) { - 1.0 - } else { - if (i % 2 == 0) i else -i - } - }).toVector + val (strategy, data) = setupDefaultStrategyAndData() "detect all anomalies if no interval specified" in { val anomalyResult = strategy.detect(data) @@ -43,4 +36,43 @@ class RateOfChangeStrategyTest extends WordSpec with Matchers { assert(anomalyResult == expected) } } + + "RateOfChange Strategy with Extended Results" should { + + val (strategy, data) = setupDefaultStrategyAndData() + + "detect all anomalies if no interval specified" in { + val anomalyResult = strategy.detectWithExtendedResults(data).filter({case (_, anom) => anom.isAnomaly}) + + val expectedAnomalyThreshold = Threshold(Bound(-2.0), Bound(2.0)) + val expectedResult: Seq[(Int, AnomalyDetectionDataPoint)] = Seq( + (20, AnomalyDetectionDataPoint(20, 19, expectedAnomalyThreshold, isAnomaly = true, 1.0)), + (21, AnomalyDetectionDataPoint(-21, -41, expectedAnomalyThreshold, isAnomaly = true, 1.0)), + (22, AnomalyDetectionDataPoint(22, 43, expectedAnomalyThreshold, isAnomaly = true, 1.0)), + (23, AnomalyDetectionDataPoint(-23, -45, expectedAnomalyThreshold, isAnomaly = true, 1.0)), + (24, AnomalyDetectionDataPoint(24, 47, expectedAnomalyThreshold, isAnomaly = true, 1.0)), + (25, AnomalyDetectionDataPoint(-25, -49, expectedAnomalyThreshold, isAnomaly = true, 1.0)), + (26, AnomalyDetectionDataPoint(26, 51, expectedAnomalyThreshold, isAnomaly = true, 1.0)), + (27, AnomalyDetectionDataPoint(-27, -53, expectedAnomalyThreshold, isAnomaly = true, 1.0)), + (28, AnomalyDetectionDataPoint(28, 55, expectedAnomalyThreshold, isAnomaly = true, 1.0)), + (29, AnomalyDetectionDataPoint(-29, -57, expectedAnomalyThreshold, isAnomaly = true, 1.0)), + (30, AnomalyDetectionDataPoint(30, 59, expectedAnomalyThreshold, isAnomaly = true, 1.0)), + (31, AnomalyDetectionDataPoint(1, -29, expectedAnomalyThreshold, isAnomaly = true, 1.0)) + ) + assert(anomalyResult == expectedResult) + } + + } + + private def setupDefaultStrategyAndData(): (RateOfChangeStrategy, Vector[Double]) = { + val strategy = RateOfChangeStrategy(Some(-2.0), Some(2.0)) + val data = (for (i <- 0 to 50) yield { + if (i < 20 || i > 30) { + 1.0 + } else { + if (i % 2 == 0) i else -i + } + }).toVector + (strategy, data) + } } diff --git 
a/src/test/scala/com/amazon/deequ/anomalydetection/RelativeRateOfChangeStrategyTest.scala b/src/test/scala/com/amazon/deequ/anomalydetection/RelativeRateOfChangeStrategyTest.scala index bfde6ba18..c6da5ae2b 100644 --- a/src/test/scala/com/amazon/deequ/anomalydetection/RelativeRateOfChangeStrategyTest.scala +++ b/src/test/scala/com/amazon/deequ/anomalydetection/RelativeRateOfChangeStrategyTest.scala @@ -23,14 +23,7 @@ class RelativeRateOfChangeStrategyTest extends WordSpec with Matchers { "Relative Rate of Change Strategy" should { - val strategy = RelativeRateOfChangeStrategy(Some(0.5), Some(2.0)) - val data = (for (i <- 0 to 50) yield { - if (i < 20 || i > 30) { - 1.0 - } else { - if (i % 2 == 0) i else 1 - } - }).toVector + val (strategy, data) = setupDefaultStrategyAndData() "detect all anomalies if no interval specified" in { val anomalyResult = strategy.detect(data) @@ -150,4 +143,159 @@ class RelativeRateOfChangeStrategyTest extends WordSpec with Matchers { } } } + + "Relative Rate of Change Strategy with Extended Results" should { + + val (strategy, data) = setupDefaultStrategyAndData() + + "detect all anomalies if no interval specified" in { + val anomalyResult = strategy.detectWithExtendedResults(data).filter({case (_, anom) => anom.isAnomaly}) + + val expectedAnomalyThreshold = Threshold(Bound(0.5), Bound(2.0)) + val expectedResult: Seq[(Int, AnomalyDetectionDataPoint)] = Seq( + (20, AnomalyDetectionDataPoint(20, 20, expectedAnomalyThreshold, isAnomaly = true, 1.0)), + (21, AnomalyDetectionDataPoint(1, 0.05, expectedAnomalyThreshold, isAnomaly = true, 1.0)), + (22, AnomalyDetectionDataPoint(22, 22, expectedAnomalyThreshold, isAnomaly = true, 1.0)), + (23, AnomalyDetectionDataPoint(1, 0.045454545454545456, expectedAnomalyThreshold, isAnomaly = true, 1.0)), + (24, AnomalyDetectionDataPoint(24, 24, expectedAnomalyThreshold, isAnomaly = true, 1.0)), + (25, AnomalyDetectionDataPoint(1, 0.041666666666666664, expectedAnomalyThreshold, isAnomaly = true, 1.0)), + (26, AnomalyDetectionDataPoint(26, 26, expectedAnomalyThreshold, isAnomaly = true, 1.0)), + (27, AnomalyDetectionDataPoint(1, 0.038461538461538464, expectedAnomalyThreshold, isAnomaly = true, 1.0)), + (28, AnomalyDetectionDataPoint(28, 28, expectedAnomalyThreshold, isAnomaly = true, 1.0)), + (29, AnomalyDetectionDataPoint(1, 0.03571428571428571, expectedAnomalyThreshold, isAnomaly = true, 1.0)), + (30, AnomalyDetectionDataPoint(30, 30, expectedAnomalyThreshold, isAnomaly = true, 1.0)), + (31, AnomalyDetectionDataPoint(1, 0.03333333333333333, expectedAnomalyThreshold, isAnomaly = true, 1.0)) + ) + assert(anomalyResult == expectedResult) + } + + "only detect anomalies in interval" in { + val anomalyResult = strategy.detectWithExtendedResults(data, (25, 50)).filter({case (_, anom) => anom.isAnomaly}) + + val expectedAnomalyThreshold = Threshold(Bound(0.5), Bound(2.0)) + val expectedResult: Seq[(Int, AnomalyDetectionDataPoint)] = Seq( + (25, AnomalyDetectionDataPoint(1, 0.041666666666666664, expectedAnomalyThreshold, isAnomaly = true, 1.0)), + (26, AnomalyDetectionDataPoint(26, 26, expectedAnomalyThreshold, isAnomaly = true, 1.0)), + (27, AnomalyDetectionDataPoint(1, 0.038461538461538464, expectedAnomalyThreshold, isAnomaly = true, 1.0)), + (28, AnomalyDetectionDataPoint(28, 28, expectedAnomalyThreshold, isAnomaly = true, 1.0)), + (29, AnomalyDetectionDataPoint(1, 0.03571428571428571, expectedAnomalyThreshold, isAnomaly = true, 1.0)), + (30, AnomalyDetectionDataPoint(30, 30, expectedAnomalyThreshold, isAnomaly = true, 1.0)), + (31, 
AnomalyDetectionDataPoint(1, 0.03333333333333333, expectedAnomalyThreshold, isAnomaly = true, 1.0)) + ) + assert(anomalyResult == expectedResult) + } + + "ignore min rate if none is given" in { + val strategy = RelativeRateOfChangeStrategy(None, Some(1.0)) + val anomalyResult = strategy.detectWithExtendedResults(data).filter({case (_, anom) => anom.isAnomaly}) + + // Anomalies with positive values only + val expectedAnomalyThreshold = Threshold(Bound(-1.7976931348623157E308), Bound(1.0)) + val expectedResult: Seq[(Int, AnomalyDetectionDataPoint)] = Seq( + (20, AnomalyDetectionDataPoint(20, 20, expectedAnomalyThreshold, isAnomaly = true, 1.0)), + (22, AnomalyDetectionDataPoint(22, 22, expectedAnomalyThreshold, isAnomaly = true, 1.0)), + (24, AnomalyDetectionDataPoint(24, 24, expectedAnomalyThreshold, isAnomaly = true, 1.0)), + (26, AnomalyDetectionDataPoint(26, 26, expectedAnomalyThreshold, isAnomaly = true, 1.0)), + (28, AnomalyDetectionDataPoint(28, 28, expectedAnomalyThreshold, isAnomaly = true, 1.0)), + (30, AnomalyDetectionDataPoint(30, 30, expectedAnomalyThreshold, isAnomaly = true, 1.0)) + ) + assert(anomalyResult == expectedResult) + } + + "ignore max rate if none is given" in { + val strategy = RelativeRateOfChangeStrategy(Some(0.5), None) + val anomalyResult = strategy.detectWithExtendedResults(data).filter({case (_, anom) => anom.isAnomaly}) + + // Anomalies with negative values only + val expectedAnomalyThreshold = Threshold(Bound(0.5), Bound(1.7976931348623157E308)) + val expectedResult: Seq[(Int, AnomalyDetectionDataPoint)] = Seq( + (21, AnomalyDetectionDataPoint(1, 0.05, expectedAnomalyThreshold, isAnomaly = true, 1.0)), + (23, AnomalyDetectionDataPoint(1, 0.045454545454545456, expectedAnomalyThreshold, isAnomaly = true, 1.0)), + (25, AnomalyDetectionDataPoint(1, 0.041666666666666664, expectedAnomalyThreshold, isAnomaly = true, 1.0)), + (27, AnomalyDetectionDataPoint(1, 0.038461538461538464, expectedAnomalyThreshold, isAnomaly = true, 1.0)), + (29, AnomalyDetectionDataPoint(1, 0.03571428571428571, expectedAnomalyThreshold, isAnomaly = true, 1.0)), + (31, AnomalyDetectionDataPoint(1, 0.03333333333333333, expectedAnomalyThreshold, isAnomaly = true, 1.0)) + ) + assert(anomalyResult == expectedResult) + } + + "detect no anomalies if rates are set to min/ max value" in { + val strategy = RelativeRateOfChangeStrategy(Some(Double.MinValue), Some(Double.MaxValue)) + val anomalyResult = strategy.detectWithExtendedResults(data).filter({case (_, anom) => anom.isAnomaly}) + + val expected: List[(Int, AnomalyDetectionDataPoint)] = List() + assert(anomalyResult == expected) + } + + "attribute indices correctly for higher orders without search interval" in { + val data = Vector(0.0, 1.0, 3.0, 6.0, 18.0, 72.0) + val strategy = RelativeRateOfChangeStrategy(None, Some(8.0), order = 2) + val anomalyResult = strategy.detectWithExtendedResults(data).filter({case (_, anom) => anom.isAnomaly}) + + val expectedResult: Seq[(Int, AnomalyDetectionDataPoint)] = Seq( + (2, AnomalyDetectionDataPoint(3, Double.PositiveInfinity, + Threshold(Bound(-1.7976931348623157E308), Bound(8.0)), isAnomaly = true, 1.0)), + (5, AnomalyDetectionDataPoint(72, 12, + Threshold(Bound(-1.7976931348623157E308), Bound(8.0)), isAnomaly = true, 1.0)) + ) + assert(anomalyResult == expectedResult) + } + + "attribute indices correctly for higher orders with search interval" in { + val data = Vector(0.0, 1.0, 3.0, 6.0, 18.0, 72.0) + val strategy = RelativeRateOfChangeStrategy(None, Some(8.0), order = 2) + val anomalyResult = 
strategy.detectWithExtendedResults(data, (5, 6)).filter({case (_, anom) => anom.isAnomaly}) + + val expectedResult: Seq[(Int, AnomalyDetectionDataPoint)] = Seq( + (5, AnomalyDetectionDataPoint(72, 12, + Threshold(Bound(-1.7976931348623157E308), Bound(8.0)), isAnomaly = true, 1.0)) + ) + assert(anomalyResult == expectedResult) + } + + "work fine with empty input" in { + val emptySeries = Vector[Double]() + val anomalyResult = strategy.detectWithExtendedResults(emptySeries).filter({case (_, anom) => anom.isAnomaly}) + + assert(anomalyResult == Seq[(Int, AnomalyDetectionDataPoint)]()) + } + + "produce error message with correct value and bounds" in { + val result = strategy.detectWithExtendedResults(data).filter({case (_, anom) => anom.isAnomaly}) + + result.foreach { case (_, anom) => + val (value, lowerBound, upperBound) = + AnomalyDetectionTestUtils.firstThreeDoublesFromString(anom.detail.get) + + assert(value === anom.anomalyMetricValue) + assert(value < lowerBound || value > upperBound) + } + } + + "assert anomalies are outside of anomaly bounds" in { + val result = strategy.detectWithExtendedResults(data).filter({ case (_, anom) => anom.isAnomaly }) + + result.foreach { case (_, anom) => + val value = anom.anomalyMetricValue + val upperBound = anom.anomalyThreshold.upperBound.value + val lowerBound = anom.anomalyThreshold.lowerBound.value + + assert(value < lowerBound || value > upperBound) + } + } + + + } + + private def setupDefaultStrategyAndData(): (RelativeRateOfChangeStrategy, Vector[Double]) = { + val strategy = RelativeRateOfChangeStrategy(Some(0.5), Some(2.0)) + val data = (for (i <- 0 to 50) yield { + if (i < 20 || i > 30) { + 1.0 + } else { + if (i % 2 == 0) i else 1 + } + }).toVector + (strategy, data) + } } diff --git a/src/test/scala/com/amazon/deequ/anomalydetection/SimpleThresholdStrategyTest.scala b/src/test/scala/com/amazon/deequ/anomalydetection/SimpleThresholdStrategyTest.scala index 92ead9e48..f8396c677 100644 --- a/src/test/scala/com/amazon/deequ/anomalydetection/SimpleThresholdStrategyTest.scala +++ b/src/test/scala/com/amazon/deequ/anomalydetection/SimpleThresholdStrategyTest.scala @@ -22,8 +22,7 @@ class SimpleThresholdStrategyTest extends WordSpec with Matchers { "Simple Threshold Strategy" should { - val strategy = SimpleThresholdStrategy(upperBound = 1.0) - val data = Vector(-1.0, 2.0, 3.0, 0.5) + val (strategy, data) = setupDefaultStrategyAndData() val expected = Seq((1, Anomaly(Option(2.0), 1.0)), (2, Anomaly(Option(3.0), 1.0))) "detect values above threshold" in { @@ -70,5 +69,77 @@ class SimpleThresholdStrategyTest extends WordSpec with Matchers { assert(value < lowerBound || value > upperBound) } } + + "Simple Threshold Strategy with Extended Results" should { + + val (strategy, data) = setupDefaultStrategyAndData() + val expectedAnomalyThreshold = Threshold(upperBound = Bound(1.0)) + val expectedResult = Seq( + (1, AnomalyDetectionDataPoint(2.0, 2.0, expectedAnomalyThreshold, isAnomaly = true, 1.0)), + (2, AnomalyDetectionDataPoint(3.0, 3.0, expectedAnomalyThreshold, isAnomaly = true, 1.0))) + + "detect values above threshold" in { + val anomalyResult = + strategy.detectWithExtendedResults(data, (0, 4)).filter({ case (_, anom) => anom.isAnomaly }) + + assert(anomalyResult == expectedResult) + } + + "detect all values without range specified" in { + val anomalyResult = strategy.detectWithExtendedResults(data).filter({ case (_, anom) => anom.isAnomaly }) + + assert(anomalyResult == expectedResult) + } + + "work fine with empty input" in { + val 
emptySeries = Vector[Double]() + val anomalyResult = + strategy.detectWithExtendedResults(emptySeries).filter({ case (_, anom) => anom.isAnomaly }) + + assert(anomalyResult == Seq[(Int, AnomalyDetectionDataPoint)]()) + } + + "work with upper and lower threshold" in { + val tS = SimpleThresholdStrategy(lowerBound = -0.5, upperBound = 1.0) + val anomalyResult = tS.detectWithExtendedResults(data).filter({ case (_, anom) => anom.isAnomaly }) + val expectedAnomalyThreshold = Threshold(Bound(-0.5), Bound(1.0)) + val expectedResult = Seq( + (0, AnomalyDetectionDataPoint(-1.0, -1.0, expectedAnomalyThreshold, isAnomaly = true, 1.0)), + (1, AnomalyDetectionDataPoint(2.0, 2.0, expectedAnomalyThreshold, isAnomaly = true, 1.0)), + (2, AnomalyDetectionDataPoint(3.0, 3.0, expectedAnomalyThreshold, isAnomaly = true, 1.0))) + + assert(anomalyResult == expectedResult) + } + + "produce error message with correct value and bounds" in { + val result = strategy.detectWithExtendedResults(data).filter({ case (_, anom) => anom.isAnomaly }) + + result.foreach { case (_, anom) => + val (value, lowerBound, upperBound) = + AnomalyDetectionTestUtils.firstThreeDoublesFromString(anom.detail.get) + + assert(value === anom.anomalyMetricValue) + assert(value < lowerBound || value > upperBound) + } + } + + "assert anomalies are outside of anomaly bounds" in { + val result = strategy.detectWithExtendedResults(data).filter({ case (_, anom) => anom.isAnomaly }) + + result.foreach { case (_, anom) => + val value = anom.anomalyMetricValue + val upperBound = anom.anomalyThreshold.upperBound.value + val lowerBound = anom.anomalyThreshold.lowerBound.value + + assert(value < lowerBound || value > upperBound) + } + } + } + } + + private def setupDefaultStrategyAndData(): (SimpleThresholdStrategy, Vector[Double]) = { + val strategy = SimpleThresholdStrategy(upperBound = 1.0) + val data = Vector(-1.0, 2.0, 3.0, 0.5) + (strategy, data) } } diff --git a/src/test/scala/com/amazon/deequ/anomalydetection/seasonal/HoltWintersTest.scala b/src/test/scala/com/amazon/deequ/anomalydetection/seasonal/HoltWintersTest.scala index decf5a91c..8d8140366 100644 --- a/src/test/scala/com/amazon/deequ/anomalydetection/seasonal/HoltWintersTest.scala +++ b/src/test/scala/com/amazon/deequ/anomalydetection/seasonal/HoltWintersTest.scala @@ -16,19 +16,19 @@ package com.amazon.deequ.anomalydetection.seasonal -import com.amazon.deequ.anomalydetection.Anomaly +import com.amazon.deequ.anomalydetection.{Anomaly, AnomalyDetectionDataPoint} import org.scalatest.matchers.should.Matchers import org.scalatest.wordspec.AnyWordSpec +import scala.util.Random + class HoltWintersTest extends AnyWordSpec with Matchers { import HoltWintersTest._ "Additive Holt-Winters" should { - val rng = new util.Random(seed = 42L) - val twoWeeksOfData = Vector.fill(2)( - Vector[Double](1, 1, 1.2, 1.3, 1.5, 2.1, 1.9) - ).flatten.map(_ + rng.nextGaussian()) + + val twoWeeksOfData = setupData() "fail if start after or equal to end" in { val caught = intercept[IllegalArgumentException]( @@ -207,6 +207,205 @@ class HoltWintersTest extends AnyWordSpec with Matchers { anomalies should have size 3 } } + + "Additive Holt-Winters with Extended Results" should { + + val twoWeeksOfData = setupData() + + "fail if start after or equal to end" in { + val caught = intercept[IllegalArgumentException]( + dailyMetricsWithWeeklySeasonalityAnomaliesWithExtendedResults(twoWeeksOfData, 1 -> 1)) + + caught.getMessage shouldBe "requirement failed: Start must be before end" + } + + "fail if no at least two cycles are 
available" in { + val fullInterval = 0 -> Int.MaxValue + + val caught = intercept[IllegalArgumentException]( + dailyMetricsWithWeeklySeasonalityAnomaliesWithExtendedResults(Vector.empty, fullInterval)) + + caught.getMessage shouldBe "requirement failed: Provided data series is empty" + } + + "fail for negative search interval" in { + val negativeInterval = -2 -> -1 + + val caught = intercept[IllegalArgumentException]( + dailyMetricsWithWeeklySeasonalityAnomaliesWithExtendedResults(twoWeeksOfData, negativeInterval)) + + caught.getMessage shouldBe + "requirement failed: The search interval needs to be strictly positive" + } + + "fail for too few data" in { + val fullInterval = 0 -> Int.MaxValue + val shortSeries = Vector[Double](1, 2, 3) + + val caught = intercept[IllegalArgumentException]( + dailyMetricsWithWeeklySeasonalityAnomaliesWithExtendedResults(shortSeries, fullInterval)) + + caught.getMessage shouldBe + "requirement failed: Need at least two full cycles of data to estimate model" + } + + "run anomaly detection on the last data point if search interval beyond series size" in { + val interval = 100 -> 110 + val anomalies = dailyMetricsWithWeeklySeasonalityAnomaliesWithExtendedResults(twoWeeksOfData, interval) + anomalies shouldBe empty + } + + "predict no anomaly for normally distributed errors" in { + val seriesWithOutlier = twoWeeksOfData ++ Vector(twoWeeksOfData.head) + val anomalies = + dailyMetricsWithWeeklySeasonalityAnomaliesWithExtendedResults(seriesWithOutlier, 14 -> 15) + .filter({case (_, anom) => anom.isAnomaly}) + anomalies shouldBe empty + } + + "predict an anomaly" in { + val seriesWithOutlier = twoWeeksOfData ++ Vector(0.0d) + val anomalies = dailyMetricsWithWeeklySeasonalityAnomaliesWithExtendedResults( + seriesWithOutlier, 14 -> Int.MaxValue) + + anomalies should have size 1 + val (anomalyIndex, _) = anomalies.head + anomalyIndex shouldBe 14 + } + + "predict no anomalies on longer series" in { + val seriesWithOutlier = twoWeeksOfData ++ twoWeeksOfData + val anomalies = dailyMetricsWithWeeklySeasonalityAnomaliesWithExtendedResults( + seriesWithOutlier, 26 -> Int.MaxValue).filter({case (_, anom) => anom.isAnomaly}) + anomalies shouldBe empty + } + + "detect no anomalies on constant series" in { + val series = (0 until 21).map(_ => 1.0).toVector + val anomalies = dailyMetricsWithWeeklySeasonalityAnomaliesWithExtendedResults(series, 14 -> Int.MaxValue) + .filter({case (_, anom) => anom.isAnomaly}) + anomalies shouldBe empty + } + + "detect a single anomaly in constant series with a single error" in { + val series = ((0 until 20).map(_ => 1.0) ++ Seq(0.0)).toVector + val anomalies = dailyMetricsWithWeeklySeasonalityAnomaliesWithExtendedResults(series, 14 -> Int.MaxValue) + .filter({case (_, anom) => anom.isAnomaly}) + + anomalies should have size 1 + val (detectionIndex, _) = anomalies.head + detectionIndex shouldBe 20 + } + + "detect no anomalies on exact linear trend series" in { + val series = (0 until 48).map(_.toDouble).toVector + val anomalies = dailyMetricsWithWeeklySeasonalityAnomaliesWithExtendedResults(series, 36 -> Int.MaxValue) + .filter({case (_, anom) => anom.isAnomaly}) + anomalies shouldBe empty + } + + "detect no anomalies on exact linear and seasonal effects" in { + val periodicity = 7 + val series = (0 until 48).map(t => math.sin(2 * math.Pi / periodicity * t)) + .zipWithIndex.map { case (s, level) => s + level }.toVector + + val anomalies = dailyMetricsWithWeeklySeasonalityAnomaliesWithExtendedResults(series, 36 -> Int.MaxValue) + .filter({case (_, 
anom) => anom.isAnomaly}) + anomalies shouldBe empty + } + + "detect anomalies if the training data is wrong" in { + val train = Vector.fill(2)(Vector[Double](0, 1, 1, 1, 1, 1, 1)).flatten + val test = Vector[Double](1, 1, 1, 1, 1, 1, 1) + val series = train ++ test + + val anomalies = dailyMetricsWithWeeklySeasonalityAnomaliesWithExtendedResults(series, 14 -> 21) + .filter({case (_, anom) => anom.isAnomaly}) + + anomalies should have size 1 + val (detectionIndex, _) = anomalies.head + detectionIndex shouldBe 14 + } + + "work on monthly data with yearly seasonality" in { + // https://datamarket.com/data/set/22ox/monthly-milk-production-pounds-per-cow-jan-62-dec-75 + val monthlyMilkProduction = Vector[Double]( + 589, 561, 640, 656, 727, 697, 640, 599, 568, 577, 553, 582, + 600, 566, 653, 673, 742, 716, 660, 617, 583, 587, 565, 598, + 628, 618, 688, 705, 770, 736, 678, 639, 604, 611, 594, 634, + 658, 622, 709, 722, 782, 756, 702, 653, 615, 621, 602, 635, + 677, 635, 736, 755, 811, 798, 735, 697, 661, 667, 645, 688, + 713, 667, 762, 784, 837, 817, 767, 722, 681, 687, 660, 698, + 717, 696, 775, 796, 858, 826, 783, 740, 701, 706, 677, 711, + 734, 690, 785, 805, 871, 845, 801, 764, 725, 723, 690, 734, + 750, 707, 807, 824, 886, 859, 819, 783, 740, 747, 711, 751, + 804, 756, 860, 878, 942, 913, 869, 834, 790, 800, 763, 800, + 826, 799, 890, 900, 961, 935, 894, 855, 809, 810, 766, 805, + 821, 773, 883, 898, 957, 924, 881, 837, 784, 791, 760, 802, + 828, 778, 889, 902, 969, 947, 908, 867, 815, 812, 773, 813, + 834, 782, 892, 903, 966, 937, 896, 858, 817, 827, 797, 843 + ) + + val strategy = new HoltWinters( + HoltWinters.MetricInterval.Monthly, + HoltWinters.SeriesSeasonality.Yearly) + + val nYearsTrain = 3 + val nYearsTest = 1 + val trainSize = nYearsTrain * 12 + val testSize = nYearsTest * 12 + val nTotal = trainSize + testSize + + val anomalies = strategy.detectWithExtendedResults( + monthlyMilkProduction.take(nTotal), + trainSize -> nTotal + ).filter({case (_, anom) => anom.isAnomaly}) + + anomalies should have size 7 + } + + "work on an additional series with yearly seasonality" in { + // https://datamarket.com/data/set/22n4/monthly-car-sales-in-quebec-1960-1968 + val monthlyCarSalesQuebec = Vector[Double]( + 6550, 8728, 12026, 14395, 14587, 13791, 9498, 8251, 7049, 9545, 9364, 8456, + 7237, 9374, 11837, 13784, 15926, 13821, 11143, 7975, 7610, 10015, 12759, 8816, + 10677, 10947, 15200, 17010, 20900, 16205, 12143, 8997, 5568, 11474, 12256, 10583, + 10862, 10965, 14405, 20379, 20128, 17816, 12268, 8642, 7962, 13932, 15936, 12628, + 12267, 12470, 18944, 21259, 22015, 18581, 15175, 10306, 10792, 14752, 13754, 11738, + 12181, 12965, 19990, 23125, 23541, 21247, 15189, 14767, 10895, 17130, 17697, 16611, + 12674, 12760, 20249, 22135, 20677, 19933, 15388, 15113, 13401, 16135, 17562, 14720, + 12225, 11608, 20985, 19692, 24081, 22114, 14220, 13434, 13598, 17187, 16119, 13713, + 13210, 14251, 20139, 21725, 26099, 21084, 18024, 16722, 14385, 21342, 17180, 14577 + ) + + val strategy = new HoltWinters( + HoltWinters.MetricInterval.Monthly, + HoltWinters.SeriesSeasonality.Yearly) + + val nYearsTrain = 3 + val nYearsTest = 1 + val trainSize = nYearsTrain * 12 + val testSize = nYearsTest * 12 + val nTotal = trainSize + testSize + + val anomalies = strategy.detectWithExtendedResults( + monthlyCarSalesQuebec.take(nTotal), + trainSize -> nTotal + ).filter({case (_, anom) => anom.isAnomaly}) + + anomalies should have size 3 + } + } + + private def setupData(): Vector[Double] = { + val rng = new util.Random(seed = 
42L) + val twoWeeksOfData = Vector.fill(2)( + Vector[Double](1, 1, 1.2, 1.3, 1.5, 2.1, 1.9) + ).flatten.map(_ + rng.nextGaussian()) + twoWeeksOfData + } + + } object HoltWintersTest { @@ -223,4 +422,16 @@ object HoltWintersTest { strategy.detect(series, interval) } + def dailyMetricsWithWeeklySeasonalityAnomaliesWithExtendedResults( + series: Vector[Double], + interval: (Int, Int)): Seq[(Int, AnomalyDetectionDataPoint)] = { + + val strategy = new HoltWinters( + HoltWinters.MetricInterval.Daily, + HoltWinters.SeriesSeasonality.Weekly + ) + + strategy.detectWithExtendedResults(series, interval) + } + } diff --git a/src/test/scala/com/amazon/deequ/checks/ApplicabilityTest.scala b/src/test/scala/com/amazon/deequ/checks/ApplicabilityTest.scala index 542f40fcf..73e589886 100644 --- a/src/test/scala/com/amazon/deequ/checks/ApplicabilityTest.scala +++ b/src/test/scala/com/amazon/deequ/checks/ApplicabilityTest.scala @@ -18,11 +18,14 @@ package com.amazon.deequ package checks import com.amazon.deequ.analyzers.applicability.Applicability -import com.amazon.deequ.analyzers.{Completeness, Compliance, Maximum, Minimum} +import com.amazon.deequ.analyzers.{Completeness, Compliance, Maximum, Minimum, Size} +import com.amazon.deequ.anomalydetection.{AnomalyDetectionStrategy, AnomalyDetectionStrategyWithExtendedResults} +import com.amazon.deequ.repository.MetricsRepository import org.apache.spark.sql.types._ +import org.scalamock.scalatest.MockFactory import org.scalatest.wordspec.AnyWordSpec -class ApplicabilityTest extends AnyWordSpec with SparkContextSpec { +class ApplicabilityTest extends AnyWordSpec with SparkContextSpec with MockFactory { private[this] val schema = StructType(Array( StructField("stringCol", StringType, nullable = true), @@ -48,7 +51,7 @@ class ApplicabilityTest extends AnyWordSpec with SparkContextSpec { "Applicability tests for checks" should { - "recognize applicable checks as applicable" in withSparkSession { session => + "recognize applicable analysis based checks as applicable" in withSparkSession { session => val applicability = new Applicability(session) @@ -66,6 +69,25 @@ class ApplicabilityTest extends AnyWordSpec with SparkContextSpec { } } + "recognize applicable anomaly based checks with extended results as applicable" in withSparkSession { session => + + val applicability = new Applicability(session) + val fakeAnomalyDetector = mock[AnomalyDetectionStrategyWithExtendedResults] + val repository = mock[MetricsRepository] + val validCheck = Check(CheckLevel.Error, "anomaly test") + .isNewestPointNonAnomalousWithExtendedResults(repository, fakeAnomalyDetector, Size(), Map.empty, + None, None) + + val resultForValidCheck = applicability.isApplicable(validCheck, schema) + + assert(resultForValidCheck.isApplicable) + assert(resultForValidCheck.failures.isEmpty) + assert(resultForValidCheck.constraintApplicabilities.size == validCheck.constraints.size) + resultForValidCheck.constraintApplicabilities.foreach { case (_, applicable) => + assert(applicable) + } + } + "detect checks with non existing columns" in withSparkSession { session => val applicability = new Applicability(session) diff --git a/src/test/scala/com/amazon/deequ/checks/CheckTest.scala b/src/test/scala/com/amazon/deequ/checks/CheckTest.scala index 505e6d137..bb565258c 100644 --- a/src/test/scala/com/amazon/deequ/checks/CheckTest.scala +++ b/src/test/scala/com/amazon/deequ/checks/CheckTest.scala @@ -19,7 +19,8 @@ package checks import com.amazon.deequ.analyzers._ import 
com.amazon.deequ.analyzers.runners.{AnalysisRunner, AnalyzerContext} -import com.amazon.deequ.anomalydetection.{Anomaly, AnomalyDetectionStrategy} +import com.amazon.deequ.anomalydetection.{Anomaly, AnomalyDetectionAssertionResult, AnomalyDetectionDataPoint, AnomalyDetectionStrategy, AnomalyDetectionStrategyWithExtendedResults, ExtendedDetectionResult} +import com.amazon.deequ.checks.Check.getNewestPointAnomalyResults import com.amazon.deequ.constraints.{ConstrainableDataTypes, ConstraintStatus} import com.amazon.deequ.metrics.{DoubleMetric, Entity} import com.amazon.deequ.repository.memory.InMemoryMetricsRepository @@ -1108,6 +1109,234 @@ class CheckTest extends AnyWordSpec with Matchers with SparkContextSpec with Fix } } + "Check isNewestPointNonAnomalousWithExtendedResults" should { + + "return the correct check status for anomaly detection for different analyzers" in + withSparkSession { sparkSession => + evaluateWithRepository { repository => + // Fake Anomaly Detector + val fakeAnomalyDetector = mock[AnomalyDetectionStrategyWithExtendedResults] + inSequence { + // Size results + (fakeAnomalyDetector.detectWithExtendedResults _) + .expects(Vector(1.0, 2.0, 3.0, 4.0, 11.0), (4, 5)) + .returns(Seq( + (0, AnomalyDetectionDataPoint(1.0, 1.0, confidence = 1.0)), + (1, AnomalyDetectionDataPoint(2.0, 2.0, confidence = 1.0)), + (2, AnomalyDetectionDataPoint(3.0, 3.0, confidence = 1.0)), + (3, AnomalyDetectionDataPoint(4.0, 4.0, confidence = 1.0)), + (4, AnomalyDetectionDataPoint(11.0, 11.0, confidence = 1.0)))) + .once() + (fakeAnomalyDetector.detectWithExtendedResults _).expects(Vector(1.0, 2.0, 3.0, 4.0, 4.0), (4, 5)) + .returns(Seq( + (0, AnomalyDetectionDataPoint(1.0, 1.0, confidence = 1.0)), + (1, AnomalyDetectionDataPoint(2.0, 2.0, confidence = 1.0)), + (2, AnomalyDetectionDataPoint(3.0, 3.0, confidence = 1.0)), + (3, AnomalyDetectionDataPoint(4.0, 4.0, confidence = 1.0)), + (4, AnomalyDetectionDataPoint(4.0, 4.0, isAnomaly = true, confidence = 1.0)))) + .once() + // Distinctness results + (fakeAnomalyDetector.detectWithExtendedResults _) + .expects(Vector(1.0, 2.0, 3.0, 4.0, 1), (4, 5)) + .returns(Seq( + (0, AnomalyDetectionDataPoint(1.0, 1.0, confidence = 1.0)), + (1, AnomalyDetectionDataPoint(2.0, 2.0, confidence = 1.0)), + (2, AnomalyDetectionDataPoint(3.0, 3.0, confidence = 1.0)), + (3, AnomalyDetectionDataPoint(4.0, 4.0, confidence = 1.0)), + (4, AnomalyDetectionDataPoint(1.0, 1.0, confidence = 1.0)))) + .once() + (fakeAnomalyDetector.detectWithExtendedResults _) + .expects(Vector(1.0, 2.0, 3.0, 4.0, 1), (4, 5)) + .returns(Seq( + (0, AnomalyDetectionDataPoint(1.0, 1.0, confidence = 1.0)), + (1, AnomalyDetectionDataPoint(2.0, 2.0, confidence = 1.0)), + (2, AnomalyDetectionDataPoint(3.0, 3.0, confidence = 1.0)), + (3, AnomalyDetectionDataPoint(4.0, 4.0, confidence = 1.0)), + (4, AnomalyDetectionDataPoint(1.0, 1.0, isAnomaly = true, confidence = 1.0)))) + .once() + } + + // Get test AnalyzerContexts + val analysis = Analysis().addAnalyzers(Seq(Size(), Distinctness(Seq("c0", "c1")))) + + val context11Rows = AnalysisRunner.run(getDfWithNRows(sparkSession, 11), analysis) + val context4Rows = AnalysisRunner.run(getDfWithNRows(sparkSession, 4), analysis) + val contextNoRows = AnalysisRunner.run(getDfEmpty(sparkSession), analysis) + + // Check isNewestPointNonAnomalousWithExtendedResults using Size + val sizeAnomalyCheck = Check(CheckLevel.Error, "anomaly test") + .isNewestPointNonAnomalousWithExtendedResults(repository, fakeAnomalyDetector, Size(), Map.empty, + None, None) + + 
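+          // The mocked strategy flags the newest (fifth) data point as anomalous only in the second
+          // Size result, so the 11-row context is expected to pass while the 4-row context should fail.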
assert(sizeAnomalyCheck.evaluate(context11Rows).status == CheckStatus.Success) + assert(sizeAnomalyCheck.evaluate(context4Rows).status == CheckStatus.Error) + assert(sizeAnomalyCheck.evaluate(contextNoRows).status == CheckStatus.Error) + + // Now with Distinctness + val distinctnessAnomalyCheck = Check(CheckLevel.Error, "anomaly test") + .isNewestPointNonAnomalousWithExtendedResults(repository, fakeAnomalyDetector, + Distinctness(Seq("c0", "c1")), Map.empty, None, None) + + assert(distinctnessAnomalyCheck.evaluate(context11Rows).status == CheckStatus.Success) + assert(distinctnessAnomalyCheck.evaluate(context4Rows).status == CheckStatus.Error) + assert(distinctnessAnomalyCheck.evaluate(contextNoRows).status == CheckStatus.Error) + } + } + + "only use historic results filtered by tagValues if specified" in + withSparkSession { sparkSession => + evaluateWithRepository { repository => + // Fake Anomaly Detector + val fakeAnomalyDetector = mock[AnomalyDetectionStrategyWithExtendedResults] + inSequence { + // Size results + (fakeAnomalyDetector.detectWithExtendedResults _) + .expects(Vector(1.0, 2.0, 11.0), (2, 3)) + .returns(Seq( + (0, AnomalyDetectionDataPoint(1.0, 1.0, confidence = 1.0)), + (1, AnomalyDetectionDataPoint(2.0, 2.0, confidence = 1.0)), + (2, AnomalyDetectionDataPoint(11.0, 11.0, confidence = 1.0)))) + .once() + (fakeAnomalyDetector.detectWithExtendedResults _).expects(Vector(1.0, 2.0, 4.0), (2, 3)) + .returns(Seq( + (0, AnomalyDetectionDataPoint(1.0, 1.0, confidence = 1.0)), + (1, AnomalyDetectionDataPoint(2.0, 2.0, confidence = 1.0)), + (2, AnomalyDetectionDataPoint(4.0, 4.0, isAnomaly = true, confidence = 1.0)))) + .once() + } + + // Get test AnalyzerContexts + val analysis = Analysis().addAnalyzer(Size()) + + val context11Rows = AnalysisRunner.run(getDfWithNRows(sparkSession, 11), analysis) + val context4Rows = AnalysisRunner.run(getDfWithNRows(sparkSession, 4), analysis) + val contextNoRows = AnalysisRunner.run(getDfEmpty(sparkSession), analysis) + + // Check isNewestPointNonAnomalousWithExtendedResults using Size + val sizeAnomalyCheck = Check(CheckLevel.Error, "anomaly test") + .isNewestPointNonAnomalousWithExtendedResults(repository, fakeAnomalyDetector, Size(), + Map("Region" -> "EU"), None, None) + + assert(sizeAnomalyCheck.evaluate(context11Rows).status == CheckStatus.Success) + assert(sizeAnomalyCheck.evaluate(context4Rows).status == CheckStatus.Error) + assert(sizeAnomalyCheck.evaluate(contextNoRows).status == CheckStatus.Error) + } + } + + "only use historic results after some dateTime if specified" in + withSparkSession { sparkSession => + evaluateWithRepository { repository => + // Fake Anomaly Detector + val fakeAnomalyDetector = mock[AnomalyDetectionStrategyWithExtendedResults] + inSequence { + // Size results + (fakeAnomalyDetector.detectWithExtendedResults _) + .expects(Vector(3.0, 4.0, 11.0), (2, 3)) + .returns(Seq( + (0, AnomalyDetectionDataPoint(3.0, 3.0, confidence = 1.0)), + (1, AnomalyDetectionDataPoint(4.0, 4.0, confidence = 1.0)), + (2, AnomalyDetectionDataPoint(11.0, 11.0, confidence = 1.0)))) + .once() + (fakeAnomalyDetector.detectWithExtendedResults _).expects(Vector(3.0, 4.0, 4.0), (2, 3)) + .returns(Seq( + (0, AnomalyDetectionDataPoint(3.0, 3.0, confidence = 1.0)), + (1, AnomalyDetectionDataPoint(4.0, 4.0, confidence = 1.0)), + (2, AnomalyDetectionDataPoint(4.0, 4.0, isAnomaly = true, confidence = 1.0)))) + .once() + } + + // Get test AnalyzerContexts + val analysis = Analysis().addAnalyzer(Size()) + + val context11Rows = 
AnalysisRunner.run(getDfWithNRows(sparkSession, 11), analysis) + val context4Rows = AnalysisRunner.run(getDfWithNRows(sparkSession, 4), analysis) + val contextNoRows = AnalysisRunner.run(getDfEmpty(sparkSession), analysis) + + // Check isNewestPointNonAnomalousWithExtendedResults using Size + val sizeAnomalyCheck = Check(CheckLevel.Error, "anomaly test") + .isNewestPointNonAnomalousWithExtendedResults(repository, fakeAnomalyDetector, Size(), + Map.empty, Some(3), None) + + assert(sizeAnomalyCheck.evaluate(context11Rows).status == CheckStatus.Success) + assert(sizeAnomalyCheck.evaluate(context4Rows).status == CheckStatus.Error) + assert(sizeAnomalyCheck.evaluate(contextNoRows).status == CheckStatus.Error) + } + } + + "only use historic results before some dateTime if specified" in + withSparkSession { sparkSession => + evaluateWithRepository { repository => + // Fake Anomaly Detector + val fakeAnomalyDetector = mock[AnomalyDetectionStrategyWithExtendedResults] + inSequence { + // Size results + (fakeAnomalyDetector.detectWithExtendedResults _) + .expects(Vector(1.0, 2.0, 11.0), (2, 3)) + .returns(Seq( + (0, AnomalyDetectionDataPoint(1.0, 1.0, confidence = 1.0)), + (1, AnomalyDetectionDataPoint(2.0, 2.0, confidence = 1.0)), + (2, AnomalyDetectionDataPoint(11.0, 11.0, confidence = 1.0)))) + .once() + (fakeAnomalyDetector.detectWithExtendedResults _).expects(Vector(1.0, 2.0, 4.0), (2, 3)) + .returns(Seq( + (0, AnomalyDetectionDataPoint(1.0, 1.0, confidence = 1.0)), + (1, AnomalyDetectionDataPoint(2.0, 2.0, confidence = 1.0)), + (2, AnomalyDetectionDataPoint(4.0, 4.0, isAnomaly = true, confidence = 1.0)))) + .once() + } + + // Get test AnalyzerContexts + val analysis = Analysis().addAnalyzer(Size()) + + val context11Rows = AnalysisRunner.run(getDfWithNRows(sparkSession, 11), analysis) + val context4Rows = AnalysisRunner.run(getDfWithNRows(sparkSession, 4), analysis) + val contextNoRows = AnalysisRunner.run(getDfEmpty(sparkSession), analysis) + + // Check isNewestPointNonAnomalousWithExtendedResults using Size + val sizeAnomalyCheck = Check(CheckLevel.Error, "anomaly test") + .isNewestPointNonAnomalousWithExtendedResults(repository, fakeAnomalyDetector, Size(), + Map.empty, None, Some(2)) + + assert(sizeAnomalyCheck.evaluate(context11Rows).status == CheckStatus.Success) + assert(sizeAnomalyCheck.evaluate(context4Rows).status == CheckStatus.Error) + assert(sizeAnomalyCheck.evaluate(contextNoRows).status == CheckStatus.Error) + } + } + } + + "getNewestPointAnomalyResults returns correct assertion result from anomaly detection data point sequence " + + "with multiple data points" in { + val anomalySequence: Seq[(Long, AnomalyDetectionDataPoint)] = + Seq( + (0, AnomalyDetectionDataPoint(1.0, 1.0, confidence = 1.0)), + (1, AnomalyDetectionDataPoint(2.0, 2.0, confidence = 1.0)), + (2, AnomalyDetectionDataPoint(11.0, 11.0, isAnomaly = true, confidence = 1.0))) + val result: AnomalyDetectionAssertionResult = + getNewestPointAnomalyResults(ExtendedDetectionResult(anomalySequence)) + assert(!result.hasNoAnomaly) + assert(result.anomalyDetectionExtendedResult.anomalyDetectionDataPoints.head == + AnomalyDetectionDataPoint(11.0, 11.0, isAnomaly = true, confidence = 1.0)) + } + + "getNewestPointAnomalyResults returns correct assertion result from anomaly detection data point sequence " + + "with one data point" in { + val anomalySequence: Seq[(Long, AnomalyDetectionDataPoint)] = + Seq( + (0, AnomalyDetectionDataPoint(11.0, 11.0, confidence = 1.0))) + val result: AnomalyDetectionAssertionResult = + 
getNewestPointAnomalyResults(ExtendedDetectionResult(anomalySequence)) + assert(result.hasNoAnomaly) + assert(result.anomalyDetectionExtendedResult.anomalyDetectionDataPoints.head == + AnomalyDetectionDataPoint(11.0, 11.0, confidence = 1.0)) + } + + "assert getNewestPointAnomalyResults throws exception from empty anomaly detection sequence" in { + val anomalySequence: Seq[(Long, AnomalyDetectionDataPoint)] = Seq() + intercept[IllegalArgumentException] { + getNewestPointAnomalyResults(ExtendedDetectionResult(anomalySequence)) + } + } + /** * Test for DataSync in verification suite. */ diff --git a/src/test/scala/com/amazon/deequ/constraints/AnomalyExtendedResultsConstraintTest.scala b/src/test/scala/com/amazon/deequ/constraints/AnomalyExtendedResultsConstraintTest.scala new file mode 100644 index 000000000..213123a74 --- /dev/null +++ b/src/test/scala/com/amazon/deequ/constraints/AnomalyExtendedResultsConstraintTest.scala @@ -0,0 +1,300 @@ +/** + * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not + * use this file except in compliance with the License. A copy of the License + * is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + * + */ + +package com.amazon.deequ.constraints + +import com.amazon.deequ.SparkContextSpec +import com.amazon.deequ.analyzers._ +import com.amazon.deequ.analyzers.runners.MetricCalculationException +import com.amazon.deequ.anomalydetection.{AnomalyDetectionAssertionResult, AnomalyDetectionDataPoint, AnomalyDetectionExtendedResult} +import com.amazon.deequ.constraints.ConstraintUtils.calculate +import com.amazon.deequ.metrics.{DoubleMetric, Entity, Metric} +import com.amazon.deequ.utils.FixtureSupport +import org.apache.spark.sql.DataFrame +import org.scalamock.scalatest.MockFactory +import org.scalatest.{Matchers, PrivateMethodTester, WordSpec} + +import scala.util.{Failure, Try} + +class AnomalyExtendedResultsConstraintTest extends WordSpec with Matchers with SparkContextSpec + with FixtureSupport with MockFactory with PrivateMethodTester { + + /** + * Sample function to use as value picker + * + * @return Returns input multiplied by 2 + */ + def valueDoubler(value: Double): Double = { + value * 2 + } + + /** + * Sample analyzer that returns a 1.0 value if the given column exists and fails otherwise. 
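+   * The constant value keeps these tests focused on the anomaly assertion functions rather than on
+   * metric computation.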
+ */ + case class SampleAnalyzer(column: String) extends Analyzer[NumMatches, DoubleMetric] { + override def toFailureMetric(exception: Exception): DoubleMetric = { + DoubleMetric(Entity.Column, "sample", column, Failure(MetricCalculationException + .wrapIfNecessary(exception))) + } + + + override def calculate( + data: DataFrame, + stateLoader: Option[StateLoader], + statePersister: Option[StatePersister]) + : DoubleMetric = { + val value: Try[Double] = Try { + require(data.columns.contains(column), s"Missing column $column") + 1.0 + } + DoubleMetric(Entity.Column, "sample", column, value) + } + + override def computeStateFrom(data: DataFrame): Option[NumMatches] = { + throw new NotImplementedError() + } + + + override def computeMetricFrom(state: Option[NumMatches]): DoubleMetric = { + throw new NotImplementedError() + } + } + + "Anomaly extended results constraint" should { + + "assert correctly on values if analysis is successful" in + withSparkSession { sparkSession => + val df = getDfMissing(sparkSession) + + // Analysis result should equal to 1.0 for an existing column + + val anomalyAssertionFunctionA = (metric: Double) => { + AnomalyDetectionAssertionResult(metric == 1.0, + AnomalyDetectionExtendedResult(Seq(AnomalyDetectionDataPoint(1.0, 1.0, confidence = 1.0)))) + } + + val resultA = calculate( + AnomalyExtendedResultsConstraint[NumMatches, Double, Double]( + SampleAnalyzer("att1"), anomalyAssertionFunctionA), df) + + assert(resultA.status == ConstraintStatus.Success) + assert(resultA.message.isEmpty) + assert(resultA.metric.isDefined) + + val anomalyAssertionFunctionB = (metric: Double) => { + AnomalyDetectionAssertionResult(metric != 1.0, + AnomalyDetectionExtendedResult( + Seq(AnomalyDetectionDataPoint(1.0, 1.0, confidence = 1.0, isAnomaly = true)))) + } + + // Analysis result should equal to 1.0 for an existing column + val resultB = calculate(AnomalyExtendedResultsConstraint[NumMatches, Double, Double]( + SampleAnalyzer("att1"), anomalyAssertionFunctionB), df) + + assert(resultB.status == ConstraintStatus.Failure) + assert(resultB.message.contains( + "Value: 1.0 does not meet the constraint requirement, check the anomaly detection metadata!")) + assert(resultB.metric.isDefined) + + val anomalyAssertionFunctionC = anomalyAssertionFunctionA + + // Analysis should fail for a non existing column + val resultC = calculate(AnomalyExtendedResultsConstraint[NumMatches, Double, Double]( + SampleAnalyzer("someMissingColumn"), anomalyAssertionFunctionC), df) + + assert(resultC.status == ConstraintStatus.Failure) + assert(resultC.message.contains("requirement failed: Missing column someMissingColumn")) + assert(resultC.metric.isDefined) + } + + "execute value picker on the analysis result value, if provided" in + withSparkSession { sparkSession => + + + val df = getDfMissing(sparkSession) + + val anomalyAssertionFunctionA = (metric: Double) => { + AnomalyDetectionAssertionResult(metric == 2.0, + AnomalyDetectionExtendedResult(Seq(AnomalyDetectionDataPoint(2.0, 2.0, confidence = 1.0)))) + } + + // Analysis result should equal to 100.0 for an existing column + assert(calculate(AnomalyExtendedResultsConstraint[NumMatches, Double, Double]( + SampleAnalyzer("att1"), anomalyAssertionFunctionA, Some(valueDoubler)), df).status == + ConstraintStatus.Success) + + val anomalyAssertionFunctionB = (metric: Double) => { + AnomalyDetectionAssertionResult(metric != 2.0, + AnomalyDetectionExtendedResult( + Seq(AnomalyDetectionDataPoint(2.0, 2.0, confidence = 1.0, isAnomaly = true)))) + } + + 
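+      // valueDoubler turns the analyzer's 1.0 result into 2.0 before the assertion runs, so an
+      // assertion requiring the metric to differ from 2.0 is expected to fail here.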
assert(calculate(AnomalyExtendedResultsConstraint[NumMatches, Double, Double]( + SampleAnalyzer("att1"), anomalyAssertionFunctionB, Some(valueDoubler)), df).status == + ConstraintStatus.Failure) + + val anomalyAssertionFunctionC = anomalyAssertionFunctionA + + // Analysis should fail for a non existing column + assert(calculate(AnomalyExtendedResultsConstraint[NumMatches, Double, Double]( + SampleAnalyzer("someMissingColumn"), anomalyAssertionFunctionC, Some(valueDoubler)), df).status == + ConstraintStatus.Failure) + } + + "get the analysis from the context, if provided" in withSparkSession { sparkSession => + val df = getDfMissing(sparkSession) + + val emptyResults = Map.empty[Analyzer[_, Metric[_]], Metric[_]] + + val validResults = Map[Analyzer[_, Metric[_]], Metric[_]]( + SampleAnalyzer("att1") -> SampleAnalyzer("att1").calculate(df), + SampleAnalyzer("someMissingColumn") -> SampleAnalyzer("someMissingColumn").calculate(df) + ) + + val anomalyAssertionFunctionA = (metric: Double) => { + AnomalyDetectionAssertionResult(metric == 1.0, + AnomalyDetectionExtendedResult(Seq(AnomalyDetectionDataPoint(1.0, 1.0, confidence = 1.0)))) + } + val anomalyAssertionFunctionB = (metric: Double) => { + AnomalyDetectionAssertionResult(metric != 1.0, + AnomalyDetectionExtendedResult( + Seq(AnomalyDetectionDataPoint(1.0, 1.0, confidence = 1.0, isAnomaly = true)))) + } + + // Analysis result should equal to 1.0 for an existing column + assert(AnomalyExtendedResultsConstraint[NumMatches, Double, Double] + (SampleAnalyzer("att1"), anomalyAssertionFunctionA) + .evaluate(validResults).status == ConstraintStatus.Success) + assert(AnomalyExtendedResultsConstraint[NumMatches, Double, Double] + (SampleAnalyzer("att1"), anomalyAssertionFunctionB) + .evaluate(validResults).status == ConstraintStatus.Failure) + assert(AnomalyExtendedResultsConstraint[NumMatches, Double, Double] + (SampleAnalyzer("someMissingColumn"), anomalyAssertionFunctionA) + .evaluate(validResults).status == ConstraintStatus.Failure) + + // Although assertion would pass, since analysis result is missing, + // constraint fails with missing analysis message + AnomalyExtendedResultsConstraint[NumMatches, Double, Double](SampleAnalyzer("att1"), anomalyAssertionFunctionA) + .evaluate(emptyResults) match { + case result => + assert(result.status == ConstraintStatus.Failure) + assert(result.message.contains("Missing Analysis, can't run the constraint!")) + assert(result.metric.isEmpty) + } + } + + "execute value picker on the analysis result value retrieved from context, if provided" in + withSparkSession { sparkSession => + val df = getDfMissing(sparkSession) + val validResults = Map[Analyzer[_, Metric[_]], Metric[_]]( + SampleAnalyzer("att1") -> SampleAnalyzer("att1").calculate(df)) + + val anomalyAssertionFunction = (metric: Double) => { + AnomalyDetectionAssertionResult(metric == 2.0, + AnomalyDetectionExtendedResult(Seq(AnomalyDetectionDataPoint(2.0, 2.0, confidence = 1.0)))) + } + + assert(AnomalyExtendedResultsConstraint[NumMatches, Double, Double]( + SampleAnalyzer("att1"), anomalyAssertionFunction, Some(valueDoubler)) + .evaluate(validResults).status == ConstraintStatus.Success) + } + + + "fail on analysis if value picker is provided but fails" in withSparkSession { sparkSession => + def problematicValuePicker(value: Double): Double = { + throw new RuntimeException("Something wrong with this picker") + } + + val df = getDfMissing(sparkSession) + + val emptyResults = Map.empty[Analyzer[_, Metric[_]], Metric[_]] + val validResults = 
Map[Analyzer[_, Metric[_]], Metric[_]]( + SampleAnalyzer("att1") -> SampleAnalyzer("att1").calculate(df)) + + val anomalyAssertionFunction = (metric: Double) => { + AnomalyDetectionAssertionResult(metric == 1.0, + AnomalyDetectionExtendedResult(Seq(AnomalyDetectionDataPoint(1.0, 1.0, confidence = 1.0)))) + } + val constraint = AnomalyExtendedResultsConstraint[NumMatches, Double, Double]( + SampleAnalyzer("att1"), anomalyAssertionFunction, Some(problematicValuePicker)) + + calculate(constraint, df) match { + case result => + assert(result.status == ConstraintStatus.Failure) + assert(result.message.get.contains("Can't retrieve the value to assert on")) + assert(result.metric.isDefined) + } + + constraint.evaluate(validResults) match { + case result => + assert(result.status == ConstraintStatus.Failure) + assert(result.message.isDefined) + assert(result.message.get.startsWith("Can't retrieve the value to assert on")) + assert(result.metric.isDefined) + } + + constraint.evaluate(emptyResults) match { + case result => + assert(result.status == ConstraintStatus.Failure) + assert(result.message.contains("Missing Analysis, can't run the constraint!")) + assert(result.metric.isEmpty) + } + + } + + "fail on failed assertion function with hint in exception message if provided" in + withSparkSession { sparkSession => + + val df = getDfMissing(sparkSession) + + val anomalyAssertionFunction = (metric: Double) => { + AnomalyDetectionAssertionResult(metric == 0.9, + AnomalyDetectionExtendedResult( + Seq(AnomalyDetectionDataPoint(1.0, 1.0, confidence = 1.0, isAnomaly = true)))) + } + + val failingConstraint = AnomalyExtendedResultsConstraint[NumMatches, Double, Double]( + SampleAnalyzer("att1"), anomalyAssertionFunction, hint = Some("Value should be like ...!")) + + calculate(failingConstraint, df) match { + case result => + assert(result.status == ConstraintStatus.Failure) + assert(result.message.isDefined) + assert(result.message.get == "Value: 1.0 does not meet the constraint requirement, " + + "check the anomaly detection metadata! 
Value should be like ...!") + assert(result.metric.isDefined) + } + } + + "return failed constraint for a failing assertion" in withSparkSession { session => + val msg = "-test-" + val exception = new RuntimeException(msg) + val df = getDfMissing(session) + + def failingAssertion(value: Double): AnomalyDetectionAssertionResult = throw exception + + val constraintResult = calculate( + AnomalyExtendedResultsConstraint[NumMatches, Double, Double]( + SampleAnalyzer("att1"), failingAssertion), df + ) + + assert(constraintResult.status == ConstraintStatus.Failure) + assert(constraintResult.metric.isDefined) + assert(constraintResult.message.contains(s"Can't execute the assertion: $msg!")) + } + + } +} diff --git a/src/test/scala/com/amazon/deequ/constraints/ConstraintUtils.scala b/src/test/scala/com/amazon/deequ/constraints/ConstraintUtils.scala index 5782bc18c..27065cafb 100644 --- a/src/test/scala/com/amazon/deequ/constraints/ConstraintUtils.scala +++ b/src/test/scala/com/amazon/deequ/constraints/ConstraintUtils.scala @@ -21,12 +21,15 @@ import org.apache.spark.sql.DataFrame object ConstraintUtils { def calculate(constraint: Constraint, df: DataFrame): ConstraintResult = { - - val analysisBasedConstraint = constraint match { - case nc: ConstraintDecorator => nc.inner - case c: Constraint => c + val finalConstraint = constraint match { + case nc: ConstraintDecorator => nc.inner + case c: Constraint => c + } + finalConstraint match { + case _: AnalysisBasedConstraint[_, _, _] => + finalConstraint.asInstanceOf[AnalysisBasedConstraint[_, _, _]].calculateAndEvaluate(df) + case _: AnomalyExtendedResultsConstraint[_, _, _] => + finalConstraint.asInstanceOf[AnomalyExtendedResultsConstraint[_, _, _]].calculateAndEvaluate(df) } - - analysisBasedConstraint.asInstanceOf[AnalysisBasedConstraint[_, _, _]].calculateAndEvaluate(df) } } diff --git a/src/test/scala/com/amazon/deequ/constraints/ConstraintsTest.scala b/src/test/scala/com/amazon/deequ/constraints/ConstraintsTest.scala index e4a8ba898..ac426ef55 100644 --- a/src/test/scala/com/amazon/deequ/constraints/ConstraintsTest.scala +++ b/src/test/scala/com/amazon/deequ/constraints/ConstraintsTest.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.Row import org.apache.spark.sql.types.{DoubleType, StringType} import Constraint._ import com.amazon.deequ.SparkContextSpec +import com.amazon.deequ.anomalydetection.{AnomalyDetectionAssertionResult, AnomalyDetectionDataPoint, AnomalyDetectionExtendedResult} class ConstraintsTest extends WordSpec with Matchers with SparkContextSpec with FixtureSupport { @@ -174,4 +175,35 @@ class ConstraintsTest extends WordSpec with Matchers with SparkContextSpec with Completeness("att2"), _ < 0.7), df).status == ConstraintStatus.Failure) } } + + "Anomaly constraint with Extended Results" should { + "assert on anomaly analyzer values" in withSparkSession { sparkSession => + val df = getDfMissing(sparkSession) + assert(calculate(Constraint.anomalyConstraintWithExtendedResults[NumMatchesAndCount]( + Completeness("att1"), (metric: Double) => { + AnomalyDetectionAssertionResult(metric > 0.4, + AnomalyDetectionExtendedResult(Seq(AnomalyDetectionDataPoint(1.0, 1.0, confidence = 1.0)))) + } ), df) + .status == ConstraintStatus.Success) + assert(calculate(Constraint.anomalyConstraintWithExtendedResults[NumMatchesAndCount]( + Completeness("att1"), (metric: Double) => { + AnomalyDetectionAssertionResult(metric < 0.4, + AnomalyDetectionExtendedResult(Seq(AnomalyDetectionDataPoint(1.0, 1.0, confidence = 1.0)))) + }), df) + .status == 
ConstraintStatus.Failure) + + assert(calculate(Constraint.anomalyConstraintWithExtendedResults[NumMatchesAndCount]( + Completeness("att2"), (metric: Double) => { + AnomalyDetectionAssertionResult(metric > 0.7, + AnomalyDetectionExtendedResult(Seq(AnomalyDetectionDataPoint(1.0, 1.0, confidence = 1.0)))) + }), df) + .status == ConstraintStatus.Success) + assert(calculate(Constraint.anomalyConstraintWithExtendedResults[NumMatchesAndCount]( + Completeness("att2"), (metric: Double) => { + AnomalyDetectionAssertionResult(metric < 0.7, + AnomalyDetectionExtendedResult(Seq(AnomalyDetectionDataPoint(1.0, 1.0, confidence = 1.0)))) + }), df) + .status == ConstraintStatus.Failure) + } + } } diff --git a/src/test/scala/com/amazon/deequ/repository/MetricsRepositoryAnomalyDetectionIntegrationTest.scala b/src/test/scala/com/amazon/deequ/repository/MetricsRepositoryAnomalyDetectionIntegrationTest.scala index c73ac95b0..2cd475ac6 100644 --- a/src/test/scala/com/amazon/deequ/repository/MetricsRepositoryAnomalyDetectionIntegrationTest.scala +++ b/src/test/scala/com/amazon/deequ/repository/MetricsRepositoryAnomalyDetectionIntegrationTest.scala @@ -57,9 +57,29 @@ class MetricsRepositoryAnomalyDetectionIntegrationTest extends AnyWordSpec with } } + "Anomaly Detection with Extended Results" should { + + "work using the InMemoryMetricsRepository" in withSparkSession { session => + + val repository = new InMemoryMetricsRepository() + + testAnomalyDetection(session, repository, useExtendedResults = true) + + } + + "work using the FileSystemMetricsRepository" in withSparkSession { session => + + val tempDir = TempFileUtils.tempDir("fileSystemRepositoryTest") + val repository = new FileSystemMetricsRepository(session, tempDir + "repository-test.json") + + testAnomalyDetection(session, repository, useExtendedResults = true) + } + } + private[this] def testAnomalyDetection( session: SparkSession, - repository: MetricsRepository) + repository: MetricsRepository, + useExtendedResults: Boolean = false) : Unit = { val data = getTestData(session) @@ -71,8 +91,15 @@ class MetricsRepositoryAnomalyDetectionIntegrationTest extends AnyWordSpec with val (otherCheck, additionalRequiredAnalyzers) = getNormalCheckAndRequiredAnalyzers() // This method is where the interesting stuff happens - val verificationResult = createAnomalyChecksAndRunEverything(data, repository, otherCheck, - additionalRequiredAnalyzers) + val verificationResult = + if (useExtendedResults) { + createAnomalyChecksWithExtendedResultsAndRunEverything( + data, repository, otherCheck, additionalRequiredAnalyzers) + } + else + { + createAnomalyChecksAndRunEverything(data, repository, otherCheck, additionalRequiredAnalyzers) + } printConstraintResults(verificationResult) @@ -189,6 +216,56 @@ class MetricsRepositoryAnomalyDetectionIntegrationTest extends AnyWordSpec with .run() } + private[this] def createAnomalyChecksWithExtendedResultsAndRunEverything( + data: DataFrame, + repository: MetricsRepository, + otherCheck: Check, + additionalRequiredAnalyzers: Seq[Analyzer[_, Metric[_]]]) + : VerificationResult = { + + // We only want to use historic data with the EU tag for the anomaly checks since the new + // data point is from the EU marketplace + val filterEU = Map("marketplace" -> "EU") + + // We only want to use data points before the date time associated with the current + // data point and only ones that are from 2018 + val afterDateTime = createDate(2018, 1, 1) + val beforeDateTime = createDate(2018, 8, 1) + + // Config for the size anomaly check + val 
sizeAnomalyCheckConfig = AnomalyCheckConfig(CheckLevel.Error, "Size only increases", + filterEU, Some(afterDateTime), Some(beforeDateTime)) + val sizeAnomalyDetectionStrategy = AbsoluteChangeStrategy(Some(0)) + + // Config for the mean sales anomaly check + val meanSalesAnomalyCheckConfig = AnomalyCheckConfig( + CheckLevel.Warning, + "Sales mean within 2 standard deviations", + filterEU, + Some(afterDateTime), + Some(beforeDateTime) + ) + val meanSalesAnomalyDetectionStrategy = OnlineNormalStrategy(upperDeviationFactor = Some(2), + ignoreAnomalies = false) + + // ResultKey to be used when saving the results of this run + val currentRunResultKey = ResultKey(createDate(2018, 8, 1), Map("marketplace" -> "EU")) + + VerificationSuite() + .onData(data) + .addCheck(otherCheck) + .addRequiredAnalyzers(additionalRequiredAnalyzers) + .useRepository(repository) + // Add the Size anomaly check + .addAnomalyCheckWithExtendedResults(sizeAnomalyDetectionStrategy, Size(), Some(sizeAnomalyCheckConfig)) + // Add the Mean sales anomaly check + .addAnomalyCheckWithExtendedResults(meanSalesAnomalyDetectionStrategy, Mean("sales"), + Some(meanSalesAnomalyCheckConfig)) + // Save new data point in the repository after we calculated everything + .saveOrAppendResult(currentRunResultKey) + .run() + } + private[this] def assertAnomalyCheckResultsAreCorrect( verificationResult: VerificationResult) : Unit = { From 6040cef090fe21c6d35e6a69d237d1f6978a7629 Mon Sep 17 00:00:00 2001 From: Vincent Chee Jia Hong <33974196+jhchee@users.noreply.github.com> Date: Fri, 9 Feb 2024 01:40:20 +0800 Subject: [PATCH 02/24] Add Spark 3.5 support (#514) * Add Spark 3.5 support * Replace with DataTypeUtils.fromAttributes * Remove unintended new line --- pom.xml | 4 ++-- .../deequ/analyzers/catalyst/StatefulHyperloglogPlus.scala | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/pom.xml b/pom.xml index ad5a8c582..ccf2acef5 100644 --- a/pom.xml +++ b/pom.xml @@ -6,7 +6,7 @@ com.amazon.deequ deequ - 2.0.6-spark-3.4 + 2.0.6-spark-3.5 1.8 @@ -18,7 +18,7 @@ ${scala.major.version} 4.8.1 - 3.4.1 + 3.5.0 deequ diff --git a/src/main/scala/com/amazon/deequ/analyzers/catalyst/StatefulHyperloglogPlus.scala b/src/main/scala/com/amazon/deequ/analyzers/catalyst/StatefulHyperloglogPlus.scala index 52e175b17..28105a99b 100644 --- a/src/main/scala/com/amazon/deequ/analyzers/catalyst/StatefulHyperloglogPlus.scala +++ b/src/main/scala/com/amazon/deequ/analyzers/catalyst/StatefulHyperloglogPlus.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ import org.apache.spark.sql.catalyst.expressions.aggregate.HLLConstants._ import org.apache.spark.sql.catalyst.trees.UnaryLike +import org.apache.spark.sql.catalyst.types.DataTypeUtils /** Adjusted version of org.apache.spark.sql.catalyst.expressions.aggregate.HyperloglogPlus */ private[sql] case class StatefulHyperloglogPlus( @@ -59,7 +60,7 @@ private[sql] case class StatefulHyperloglogPlus( override def dataType: DataType = BinaryType - override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes) + override def aggBufferSchema: StructType = DataTypeUtils.fromAttributes(aggBufferAttributes) /** Allocate enough words to store all registers. 
*/ override val aggBufferAttributes: Seq[AttributeReference] = Seq.tabulate(NUM_WORDS) { i => From 02ed720fb1b972e3e0691bdcc6fe6abb8638f6fd Mon Sep 17 00:00:00 2001 From: Hubert Date: Fri, 1 Nov 2024 19:16:57 -0400 Subject: [PATCH 03/24] fix merge conflicts --- .../analyzers/DataSynchronizationState.scala | 48 ------------ ...lyzer.scala => DatasetMatchAnalyzer.scala} | 54 +++++++------ .../deequ/analyzers/DatasetMatchState.scala | 46 +++++++++++ .../scala/com/amazon/deequ/checks/Check.scala | 58 +++++++------- .../deequ/comparison/ComparisonResult.scala | 6 +- .../comparison/DataSynchronization.scala | 32 ++++---- .../amazon/deequ/constraints/Constraint.scala | 6 +- .../amazon/deequ/VerificationSuiteTest.scala | 62 +++++++++++---- .../com/amazon/deequ/checks/CheckTest.scala | 78 +++++++++++++------ .../comparison/DataSynchronizationTest.scala | 52 ++++++------- 10 files changed, 257 insertions(+), 185 deletions(-) delete mode 100644 src/main/scala/com/amazon/deequ/analyzers/DataSynchronizationState.scala rename src/main/scala/com/amazon/deequ/analyzers/{DataSynchronizationAnalyzer.scala => DatasetMatchAnalyzer.scala} (51%) create mode 100644 src/main/scala/com/amazon/deequ/analyzers/DatasetMatchState.scala diff --git a/src/main/scala/com/amazon/deequ/analyzers/DataSynchronizationState.scala b/src/main/scala/com/amazon/deequ/analyzers/DataSynchronizationState.scala deleted file mode 100644 index e0321df35..000000000 --- a/src/main/scala/com/amazon/deequ/analyzers/DataSynchronizationState.scala +++ /dev/null @@ -1,48 +0,0 @@ -/** - * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). You may not - * use this file except in compliance with the License. A copy of the License - * is located at - * - * http://aws.amazon.com/apache2.0/ - * - * or in the "license" file accompanying this file. This file is distributed on - * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either - * express or implied. See the License for the specific language governing - * permissions and limitations under the License. - * - */ - -package com.amazon.deequ.analyzers - -/** - * Represents the state of data synchronization between two DataFrames in Deequ. - * This state keeps track of the count of synchronized record count and the total record count. - * It is used to calculate a ratio of synchronization, which is a measure of how well the data - * in the two DataFrames are synchronized. - * - * @param synchronizedDataCount The count of records that are considered synchronized between the two DataFrames. - * @param totalDataCount The total count of records for check. - * - * The `sum` method allows for aggregation of this state with another, combining the counts from both states. - * This is useful in distributed computations where states from different partitions need to be aggregated. - * - * The `metricValue` method computes the synchronization ratio. It is the ratio of `synchronizedDataCount` - * to `dataCount`. - * If `dataCount` is zero, which means no data points were examined, the method returns `Double.NaN` - * to indicate the undefined state. 
- * - */ -case class DataSynchronizationState(synchronizedDataCount: Long, totalDataCount: Long) - extends DoubleValuedState[DataSynchronizationState] { - override def sum(other: DataSynchronizationState): DataSynchronizationState = { - DataSynchronizationState(synchronizedDataCount + other.synchronizedDataCount, totalDataCount + other.totalDataCount) - } - - override def metricValue(): Double = { - if (totalDataCount == 0L) Double.NaN else synchronizedDataCount.toDouble / totalDataCount.toDouble - } -} - -object DataSynchronizationState diff --git a/src/main/scala/com/amazon/deequ/analyzers/DataSynchronizationAnalyzer.scala b/src/main/scala/com/amazon/deequ/analyzers/DatasetMatchAnalyzer.scala similarity index 51% rename from src/main/scala/com/amazon/deequ/analyzers/DataSynchronizationAnalyzer.scala rename to src/main/scala/com/amazon/deequ/analyzers/DatasetMatchAnalyzer.scala index 1d7e37533..cdf0e5061 100644 --- a/src/main/scala/com/amazon/deequ/analyzers/DataSynchronizationAnalyzer.scala +++ b/src/main/scala/com/amazon/deequ/analyzers/DatasetMatchAnalyzer.scala @@ -18,8 +18,8 @@ package com.amazon.deequ.analyzers import com.amazon.deequ.analyzers.Analyzers.metricFromFailure import com.amazon.deequ.comparison.DataSynchronization -import com.amazon.deequ.comparison.DataSynchronizationFailed -import com.amazon.deequ.comparison.DataSynchronizationSucceeded +import com.amazon.deequ.comparison.DatasetMatchFailed +import com.amazon.deequ.comparison.DatasetMatchSucceeded import com.amazon.deequ.metrics.DoubleMetric import com.amazon.deequ.metrics.Entity import org.apache.spark.sql.DataFrame @@ -29,59 +29,67 @@ import scala.util.Try /** - * An Analyzer for Deequ that performs a data synchronization check between two DataFrames. - * It evaluates the degree of synchronization based on specified column mappings and an assertion function. + * An Analyzer for Deequ that performs a dataset match check between two DataFrames. + * It evaluates the degree of match based on specified column mappings and an assertion function. * - * The analyzer computes a ratio of synchronized data points to the total data points, represented as a DoubleMetric. - * Refer to [[com.amazon.deequ.comparison.DataSynchronization.columnMatch]] for DataSynchronization implementation + * The analyzer computes a ratio of matched data points to the total data points, represented as a DoubleMetric. + * Refer to [[com.amazon.deequ.comparison.DataSynchronization.columnMatch]] for dataset match implementation * * @param dfToCompare The DataFrame to compare with the primary DataFrame that is setup * during [[com.amazon.deequ.VerificationSuite.onData]] setup. * @param columnMappings A map where each key-value pair represents a column in the primary DataFrame * and its corresponding column in dfToCompare. + * @param matchColumnMappings A map defining the column correlations between the current DataFrame and otherDf. + * These are the columns which we will check for equality, post joining. + * It's an optional value with defaults to None. * @param assertion A function that takes a Double (the match ratio) and returns a Boolean. * It defines the condition for successful synchronization. * * Usage: - * This analyzer is used in Deequ's VerificationSuite based if `isDataSynchronized` check is defined or could be used + * This analyzer is used in Deequ's VerificationSuite based if `doesDatasetMatch` check is defined or could be used * manually as well. 
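+ * The reported metric value is the fraction of rows in the primary DataFrame that find a match in
+ * dfToCompare on both the key columns and the compared columns.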
* * Example: - * val analyzer = DataSynchronizationAnalyzer(dfToCompare, Map("col1" -> "col2"), _ > 0.8) + * val analyzer = DatasetMatchAnalyzer(dfToCompare, Map("col1" -> "col2"), _ > 0.8) * val verificationResult = VerificationSuite().onData(df).addAnalyzer(analyzer).run() * * // or could do something like below - * val verificationResult = VerificationSuite().onData(df).isDataSynchronized(dfToCompare, Map("col1" -> "col2"), + * val verificationResult = VerificationSuite().onData(df).doesDatasetMatch(dfToCompare, Map("col1" -> "col2"), * _ > 0.8).run() * * - * The computeStateFrom method calculates the synchronization state by comparing the specified columns of the two + * The computeStateFrom method calculates the datasetmatch state by comparing the specified columns of the two * DataFrames. - * The computeMetricFrom method then converts this state into a DoubleMetric representing the synchronization ratio. + * The computeMetricFrom method then converts this state into a DoubleMetric representing the match ratio. * */ -case class DataSynchronizationAnalyzer(dfToCompare: DataFrame, - columnMappings: Map[String, String], - assertion: Double => Boolean) - extends Analyzer[DataSynchronizationState, DoubleMetric] { +case class DatasetMatchAnalyzer(dfToCompare: DataFrame, + columnMappings: Map[String, String], + assertion: Double => Boolean, + matchColumnMappings: Option[Map[String, String]] = None) + extends Analyzer[DatasetMatchState, DoubleMetric] { - override def computeStateFrom(data: DataFrame): Option[DataSynchronizationState] = { + override def computeStateFrom(data: DataFrame): Option[DatasetMatchState] = { - val result = DataSynchronization.columnMatch(data, dfToCompare, columnMappings, assertion) + val result = if (matchColumnMappings.isDefined) { + DataSynchronization.columnMatch(data, dfToCompare, columnMappings, matchColumnMappings.get, assertion) + } else { + DataSynchronization.columnMatch(data, dfToCompare, columnMappings, assertion) + } result match { - case succeeded: DataSynchronizationSucceeded => - Some(DataSynchronizationState(succeeded.passedCount, succeeded.totalCount)) - case failed: DataSynchronizationFailed => - Some(DataSynchronizationState(failed.passedCount.getOrElse(0), failed.totalCount.getOrElse(0))) + case succeeded: DatasetMatchSucceeded => + Some(DatasetMatchState(succeeded.passedCount, succeeded.totalCount)) + case failed: DatasetMatchFailed => + Some(DatasetMatchState(failed.passedCount.getOrElse(0), failed.totalCount.getOrElse(0))) case _ => None } } - override def computeMetricFrom(state: Option[DataSynchronizationState]): DoubleMetric = { + override def computeMetricFrom(state: Option[DatasetMatchState]): DoubleMetric = { val metric = state match { - case Some(s) => Try(s.synchronizedDataCount.toDouble / s.totalDataCount.toDouble) + case Some(s) => Try(s.matchedDataCount.toDouble / s.totalDataCount.toDouble) case _ => Failure(new IllegalStateException("No state available for DataSynchronizationAnalyzer")) } diff --git a/src/main/scala/com/amazon/deequ/analyzers/DatasetMatchState.scala b/src/main/scala/com/amazon/deequ/analyzers/DatasetMatchState.scala new file mode 100644 index 000000000..9e1c45e9b --- /dev/null +++ b/src/main/scala/com/amazon/deequ/analyzers/DatasetMatchState.scala @@ -0,0 +1,46 @@ +/** + * Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not + * use this file except in compliance with the License. 
A copy of the License + * is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + * + */ + +package com.amazon.deequ.analyzers + +/** + * Represents the state of datasetMatch between two DataFrames in Deequ. + * This state keeps track of the count of matched record count and the total record count. + * It measures how well the data in the two DataFrames matches. + * + * @param matchedDataCount The count of records that are considered match between the two DataFrames. + * @param totalDataCount The total count of records for check. + * + * The `sum` method allows for aggregation of this state with another, combining the counts from both states. + * This is useful in distributed computations where states from different partitions need to be aggregated. + * + * The `metricValue` method computes the synchronization ratio. It is the ratio of `matchedDataCount` to `dataCount`. + * If `dataCount` is zero, which means no data points were examined, the method returns `Double.NaN` to indicate + * the undefined state. + * + */ +case class DatasetMatchState(matchedDataCount: Long, totalDataCount: Long) + extends DoubleValuedState[DatasetMatchState] { + override def sum(other: DatasetMatchState): DatasetMatchState = { + DatasetMatchState(matchedDataCount + other.matchedDataCount, totalDataCount + other.totalDataCount) + } + + override def metricValue(): Double = { + if (totalDataCount == 0L) Double.NaN else matchedDataCount.toDouble / totalDataCount.toDouble + } +} + +object DatasetMatchState diff --git a/src/main/scala/com/amazon/deequ/checks/Check.scala b/src/main/scala/com/amazon/deequ/checks/Check.scala index 954dc763d..c38ee0e0f 100644 --- a/src/main/scala/com/amazon/deequ/checks/Check.scala +++ b/src/main/scala/com/amazon/deequ/checks/Check.scala @@ -1,5 +1,5 @@ /** - * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"). You may not * use this file except in compliance with the License. A copy of the License @@ -19,8 +19,8 @@ package com.amazon.deequ.checks import com.amazon.deequ.analyzers.runners.AnalyzerContext import com.amazon.deequ.analyzers.Analyzer import com.amazon.deequ.analyzers.AnalyzerOptions -import com.amazon.deequ.analyzers.DataSynchronizationAnalyzer -import com.amazon.deequ.analyzers.DataSynchronizationState +import com.amazon.deequ.analyzers.DatasetMatchAnalyzer +import com.amazon.deequ.analyzers.DatasetMatchState import com.amazon.deequ.analyzers.Histogram import com.amazon.deequ.analyzers.KLLParameters import com.amazon.deequ.analyzers.Patterns @@ -348,13 +348,13 @@ case class Check( } /** - * Performs a data synchronization check between the base DataFrame supplied to + * Performs a dataset check between the base DataFrame supplied to * [[com.amazon.deequ.VerificationSuite.onData]] and other DataFrame supplied to this check using Deequ's * [[com.amazon.deequ.comparison.DataSynchronization.columnMatch]] framework. - * This method compares specified columns of both DataFrames and assesses synchronization based on a custom assertion. 
+ * This method compares specified columns of both DataFrames and assesses match based on a custom assertion. * - * Utilizes [[com.amazon.deequ.analyzers.DataSynchronizationAnalyzer]] for comparing the data - * and Constraint [[com.amazon.deequ.constraints.DataSynchronizationConstraint]]. + * Utilizes [[com.amazon.deequ.analyzers.DatasetMatchAnalyzer]] for comparing the data + * and Constraint [[com.amazon.deequ.constraints.DatasetMatchConstraint]]. * * Usage: * To use this method, create a VerificationSuite and invoke this method as part of adding checks: @@ -365,7 +365,7 @@ case class Check( * val assertionFunction: Double => Boolean = _ > 0.7 * * val check = new Check(CheckLevel.Error, "Data Synchronization Check") - * .isDataSynchronized(otherDataFrame, columnMappings, assertionFunction) + * .doesDatasetMatch(otherDataFrame, columnMappings, assertionFunction) * * val verificationResult = VerificationSuite() * .onData(baseDataFrame) @@ -373,29 +373,33 @@ case class Check( * .run() * }}} * - * This will add a data synchronization check to the VerificationSuite, comparing the specified columns of + * This will add a dataset match check to the VerificationSuite, comparing the specified columns of * baseDataFrame and otherDataFrame based on the provided assertion function. * - * - * @param otherDf The DataFrame to be compared with the current one. Analyzed in conjunction with the - * current DataFrame to assess data synchronization. - * @param columnMappings A map defining the column correlations between the current DataFrame and otherDf. - * Keys represent column names in the current DataFrame, - * and values are corresponding column names in otherDf. - * @param assertion A function that takes a Double (result of the comparison) and returns a Boolean. - * Defines the condition under which the data in both DataFrames is considered synchronized. - * For example (_ > 0.7) denoting metric value > 0.7 or 70% of records. - * @param hint Optional. Additional context or information about the synchronization check. - * Helpful for understanding the intent or specifics of the check. Default is None. - * @return A [[com.amazon.deequ.checks.Check]] object representing the outcome - * of the synchronization check. This object can be used in Deequ's verification suite to - * assert data quality constraints. + * @param otherDataset The DataFrame to be compared with the current one. Analyzed in conjunction with the + * current DataFrame to assess data synchronization. + * @param keyColumnMappings A map defining the column correlations between the current DataFrame and otherDf. + * Keys represent column names in the current DataFrame, and values are corresponding + * column names in otherDf. + * @param assertion A function that takes a Double (result of the comparison) and returns a Boolean. Defines the + * condition under which the data in both DataFrames is considered synchronized. For example + * (_ > 0.7) denoting metric value > 0.7 or 70% of records. + * @param matchColumnMappings A map defining the column correlations between the current DataFrame and otherDf. + * These are the columns which we will check for equality, post joining. It's an optional + * value with defaults to None, which will be derived from `keyColumnMappings` if None. + * @param hint Optional. Additional context or information about the synchronization check. + * Helpful for understanding the intent or specifics of the check. Default is None. 
+ * @return A [[com.amazon.deequ.checks.Check]] object representing the outcome of the dataset match check. + * This object can be used in Deequ's verification suite to assert data quality constraints. * */ - def isDataSynchronized(otherDf: DataFrame, columnMappings: Map[String, String], assertion: Double => Boolean, - hint: Option[String] = None): Check = { - val dataSyncAnalyzer = DataSynchronizationAnalyzer(otherDf, columnMappings, assertion) - val constraint = AnalysisBasedConstraint[DataSynchronizationState, Double, Double](dataSyncAnalyzer, assertion, + def doesDatasetMatch(otherDataset: DataFrame, + keyColumnMappings: Map[String, String], + assertion: Double => Boolean, + matchColumnMappings: Option[Map[String, String]] = None, + hint: Option[String] = None): Check = { + val dataMatchAnalyzer = DatasetMatchAnalyzer(otherDataset, keyColumnMappings, assertion, matchColumnMappings) + val constraint = AnalysisBasedConstraint[DatasetMatchState, Double, Double](dataMatchAnalyzer, assertion, hint = hint) addConstraint(constraint) } diff --git a/src/main/scala/com/amazon/deequ/comparison/ComparisonResult.scala b/src/main/scala/com/amazon/deequ/comparison/ComparisonResult.scala index 67b4d4b47..643fb0360 100644 --- a/src/main/scala/com/amazon/deequ/comparison/ComparisonResult.scala +++ b/src/main/scala/com/amazon/deequ/comparison/ComparisonResult.scala @@ -21,6 +21,6 @@ sealed trait ComparisonResult case class ComparisonFailed(errorMessage: String, ratio: Double = 0) extends ComparisonResult case class ComparisonSucceeded(ratio: Double = 0) extends ComparisonResult -case class DataSynchronizationFailed(errorMessage: String, passedCount: Option[Long] = None, - totalCount: Option[Long] = None) extends ComparisonResult -case class DataSynchronizationSucceeded(passedCount: Long, totalCount: Long) extends ComparisonResult +case class DatasetMatchFailed(errorMessage: String, passedCount: Option[Long] = None, + totalCount: Option[Long] = None) extends ComparisonResult +case class DatasetMatchSucceeded(passedCount: Long, totalCount: Long) extends ComparisonResult diff --git a/src/main/scala/com/amazon/deequ/comparison/DataSynchronization.scala b/src/main/scala/com/amazon/deequ/comparison/DataSynchronization.scala index 992dc48d0..de207823c 100644 --- a/src/main/scala/com/amazon/deequ/comparison/DataSynchronization.scala +++ b/src/main/scala/com/amazon/deequ/comparison/DataSynchronization.scala @@ -102,13 +102,13 @@ object DataSynchronization extends ComparisonBase { val nonKeyColsMatch = colsDS1.forall(columnExists(ds2, _)) if (!nonKeyColsMatch) { - DataSynchronizationFailed("Non key columns in the given data frames do not match.") + DatasetMatchFailed("Non key columns in the given data frames do not match.") } else { val mergedMaps = colKeyMap ++ colsDS1.map(x => x -> x).toMap finalAssertion(ds1, ds2, mergedMaps, assertion) } } else { - DataSynchronizationFailed(columnErrors.get) + DatasetMatchFailed(columnErrors.get) } } @@ -138,17 +138,17 @@ object DataSynchronization extends ComparisonBase { val nonKeyColumns2NotInDataset = compCols.values.filterNot(columnExists(ds2, _)) if (nonKeyColumns1NotInDataset.nonEmpty) { - DataSynchronizationFailed(s"The following columns were not found in the first dataset: " + + DatasetMatchFailed(s"The following columns were not found in the first dataset: " + s"${nonKeyColumns1NotInDataset.mkString(", ")}") } else if (nonKeyColumns2NotInDataset.nonEmpty) { - DataSynchronizationFailed(s"The following columns were not found in the second dataset: " + + 
DatasetMatchFailed(s"The following columns were not found in the second dataset: " + s"${nonKeyColumns2NotInDataset.mkString(", ")}") } else { val mergedMaps = colKeyMap ++ compCols finalAssertion(ds1, ds2, mergedMaps, assertion) } } else { - DataSynchronizationFailed(keyColumnErrors.get) + DatasetMatchFailed(keyColumnErrors.get) } } @@ -157,23 +157,23 @@ object DataSynchronization extends ComparisonBase { colKeyMap: Map[String, String], optionalCompCols: Option[Map[String, String]] = None, optionalOutcomeColumnName: Option[String] = None): - Either[DataSynchronizationFailed, DataFrame] = { + Either[DatasetMatchFailed, DataFrame] = { val columnErrors = areKeyColumnsValid(ds1, ds2, colKeyMap) if (columnErrors.isEmpty) { - val compColsEither: Either[DataSynchronizationFailed, Map[String, String]] = if (optionalCompCols.isDefined) { + val compColsEither: Either[DatasetMatchFailed, Map[String, String]] = if (optionalCompCols.isDefined) { optionalCompCols.get match { - case compCols if compCols.isEmpty => Left(DataSynchronizationFailed("Empty column comparison map provided.")) + case compCols if compCols.isEmpty => Left(DatasetMatchFailed("Empty column comparison map provided.")) case compCols => val ds1CompColsNotInDataset = compCols.keys.filterNot(columnExists(ds1, _)) val ds2CompColsNotInDataset = compCols.values.filterNot(columnExists(ds2, _)) if (ds1CompColsNotInDataset.nonEmpty) { Left( - DataSynchronizationFailed(s"The following columns were not found in the first dataset: " + + DatasetMatchFailed(s"The following columns were not found in the first dataset: " + s"${ds1CompColsNotInDataset.mkString(", ")}") ) } else if (ds2CompColsNotInDataset.nonEmpty) { Left( - DataSynchronizationFailed(s"The following columns were not found in the second dataset: " + + DatasetMatchFailed(s"The following columns were not found in the second dataset: " + s"${ds2CompColsNotInDataset.mkString(", ")}") ) } else { @@ -186,7 +186,7 @@ object DataSynchronization extends ComparisonBase { val nonKeyColsMatch = ds1NonKeyCols.forall(columnExists(ds2, _)) if (!nonKeyColsMatch) { - Left(DataSynchronizationFailed("Non key columns in the given data frames do not match.")) + Left(DatasetMatchFailed("Non key columns in the given data frames do not match.")) } else { Right(ds1NonKeyCols.map { c => c -> c}.toMap) } @@ -198,11 +198,11 @@ object DataSynchronization extends ComparisonBase { case Success(df) => Right(df) case Failure(ex) => ex.printStackTrace() - Left(DataSynchronizationFailed(s"Comparison failed due to ${ex.getCause.getClass}")) + Left(DatasetMatchFailed(s"Comparison failed due to ${ex.getCause.getClass}")) } } } else { - Left(DataSynchronizationFailed(columnErrors.get)) + Left(DatasetMatchFailed(columnErrors.get)) } } @@ -255,7 +255,7 @@ object DataSynchronization extends ComparisonBase { val ds2Count = ds2.count() if (ds1Count != ds2Count) { - DataSynchronizationFailed(s"The row counts of the two data frames do not match.") + DatasetMatchFailed(s"The row counts of the two data frames do not match.") } else { val joinExpression: Column = mergedMaps .map { case (col1, col2) => ds1(col1) === ds2(col2)} @@ -267,9 +267,9 @@ object DataSynchronization extends ComparisonBase { val ratio = passedCount.toDouble / totalCount.toDouble if (assertion(ratio)) { - DataSynchronizationSucceeded(passedCount, totalCount) + DatasetMatchSucceeded(passedCount, totalCount) } else { - DataSynchronizationFailed(s"Data Synchronization Comparison Metric Value: $ratio does not meet the constraint" + + DatasetMatchFailed(s"Data 
Synchronization Comparison Metric Value: $ratio does not meet the constraint" + s"requirement.", Some(passedCount), Some(totalCount)) } } diff --git a/src/main/scala/com/amazon/deequ/constraints/Constraint.scala b/src/main/scala/com/amazon/deequ/constraints/Constraint.scala index df020eb2f..c910325c9 100644 --- a/src/main/scala/com/amazon/deequ/constraints/Constraint.scala +++ b/src/main/scala/com/amazon/deequ/constraints/Constraint.scala @@ -938,17 +938,17 @@ object Constraint { } /** - * Data Synchronization Constraint + * DatasetMatch Constraint * @param analyzer Data Synchronization Analyzer * @param hint hint */ -case class DataSynchronizationConstraint(analyzer: DataSynchronizationAnalyzer, hint: Option[String]) +case class DatasetMatchConstraint(analyzer: DatasetMatchAnalyzer, hint: Option[String]) extends Constraint { override def evaluate(metrics: Map[Analyzer[_, Metric[_]], Metric[_]]): ConstraintResult = { metrics.collectFirst { - case (_: DataSynchronizationAnalyzer, metric: Metric[Double]) => metric + case (_: DatasetMatchAnalyzer, metric: Metric[Double]) => metric } match { case Some(metric) => val result = metric.value match { diff --git a/src/test/scala/com/amazon/deequ/VerificationSuiteTest.scala b/src/test/scala/com/amazon/deequ/VerificationSuiteTest.scala index 25a623c2a..1cc09b811 100644 --- a/src/test/scala/com/amazon/deequ/VerificationSuiteTest.scala +++ b/src/test/scala/com/amazon/deequ/VerificationSuiteTest.scala @@ -1,5 +1,5 @@ /** - * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"). You may not * use this file except in compliance with the License. A copy of the License @@ -35,6 +35,7 @@ import com.amazon.deequ.utils.TempFileUtils import org.apache.spark.sql.DataFrame import org.apache.spark.sql.Row import org.apache.spark.sql.functions.col +import org.apache.spark.sql.functions.lit import org.apache.spark.sql.functions.when import org.scalamock.scalatest.MockFactory import org.scalatest.Matchers @@ -943,7 +944,7 @@ class VerificationSuiteTest extends WordSpec with Matchers with SparkContextSpec .hasCompleteness("fake", x => x > 0) val checkHasDataInSyncTest = Check(CheckLevel.Error, "shouldSucceedForAge") - .isDataSynchronized(df, Map("age" -> "age"), _ > 0.99, Some("shouldPass")) + .doesDatasetMatch(df, Map("age" -> "age"), _ > 0.99, hint = Some("shouldPass")) val verificationResult = VerificationSuite() .onData(df) @@ -1125,30 +1126,30 @@ class VerificationSuiteTest extends WordSpec with Matchers with SparkContextSpec val dfColRenamed = df.withColumnRenamed("id", "id_renamed") val dataSyncCheckPass = Check(CheckLevel.Error, "data synchronization check pass") - .isDataSynchronized(dfModified, Map("id" -> "id"), _ > 0.7, Some("shouldPass")) + .doesDatasetMatch(dfModified, Map("id" -> "id"), _ > 0.7, hint = Some("shouldPass")) val dataSyncCheckFail = Check(CheckLevel.Error, "data synchronization check fail") - .isDataSynchronized(dfModified, Map("id" -> "id"), _ > 0.9, Some("shouldFail")) + .doesDatasetMatch(dfModified, Map("id" -> "id"), _ > 0.9, hint = Some("shouldFail")) val emptyDf = sparkSession.createDataFrame(sparkSession.sparkContext.emptyRDD[Row], df.schema) val dataSyncCheckEmpty = Check(CheckLevel.Error, "data synchronization check on empty DataFrame") - .isDataSynchronized(emptyDf, Map("id" -> "id"), _ < 0.5) + .doesDatasetMatch(emptyDf, Map("id" -> "id"), _ < 0.5) val 
dataSyncCheckColMismatchDestination = Check(CheckLevel.Error, "data synchronization check col mismatch in destination") - .isDataSynchronized(dfModified, Map("id" -> "id2"), _ < 0.5) + .doesDatasetMatch(dfModified, Map("id" -> "id2"), _ < 0.5) val dataSyncCheckColMismatchSource = Check(CheckLevel.Error, "data synchronization check col mismatch in source") - .isDataSynchronized(dfModified, Map("id2" -> "id"), _ < 0.5) + .doesDatasetMatch(dfModified, Map("id2" -> "id"), _ < 0.5) val dataSyncCheckColRenamed = Check(CheckLevel.Error, "data synchronization check col names renamed") - .isDataSynchronized(dfColRenamed, Map("id" -> "id_renamed"), _ == 1.0) + .doesDatasetMatch(dfColRenamed, Map("id" -> "id_renamed"), _ == 1.0) val dataSyncFullMatch = Check(CheckLevel.Error, "data synchronization check full match") - .isDataSynchronized(df, Map("id" -> "id"), _ == 1.0) + .doesDatasetMatch(df, Map("id" -> "id"), _ == 1.0) val verificationResult = VerificationSuite() @@ -1205,32 +1206,46 @@ class VerificationSuiteTest extends WordSpec with Matchers with SparkContextSpec val dfColRenamed = df.withColumnRenamed("id", "id_renamed") val colMap = Map("id" -> "id", "product" -> "product") + // Additional DataFrames for testing matchColumnMappings + val dfWithAdditionalColumns = df.withColumn("newColumn", lit(1)) + + val matchColMap = Map("product" -> "product") + val dataSyncCheckWithMatchColumns = Check(CheckLevel.Error, + "data synchronization check with matchColumnMappings") + .doesDatasetMatch(df, colMap, _ > 0.7, Some(matchColMap), + hint = Some("Check with matchColumnMappings")) + + val dataSyncCheckWithAdditionalCols = Check(CheckLevel.Error, + "data synchronization check with additional columns") + .doesDatasetMatch(dfWithAdditionalColumns, colMap, _ > 0.7, Some(matchColMap), + hint = Some("Check with additional columns and matchColumnMappings")) + val dataSyncCheckPass = Check(CheckLevel.Error, "data synchronization check") - .isDataSynchronized(dfModified, colMap, _ > 0.7, Some("shouldPass")) + .doesDatasetMatch(dfModified, colMap, _ > 0.7, hint = Some("shouldPass")) val dataSyncCheckFail = Check(CheckLevel.Error, "data synchronization check") - .isDataSynchronized(dfModified, colMap, _ > 0.9, Some("shouldFail")) + .doesDatasetMatch(dfModified, colMap, _ > 0.9, hint = Some("shouldFail")) val emptyDf = sparkSession.createDataFrame(sparkSession.sparkContext.emptyRDD[Row], df.schema) val dataSyncCheckEmpty = Check(CheckLevel.Error, "data synchronization check on empty DataFrame") - .isDataSynchronized(emptyDf, colMap, _ < 0.5) + .doesDatasetMatch(emptyDf, colMap, _ < 0.5) val dataSyncCheckColMismatchDestination = Check(CheckLevel.Error, "data synchronization check col mismatch in destination") - .isDataSynchronized(dfModified, colMap, _ > 0.9) + .doesDatasetMatch(dfModified, colMap, _ > 0.9) val dataSyncCheckColMismatchSource = Check(CheckLevel.Error, "data synchronization check col mismatch in source") - .isDataSynchronized(dfModified, Map("id2" -> "id", "product" -> "product"), _ < 0.5) + .doesDatasetMatch(dfModified, Map("id2" -> "id", "product" -> "product"), _ < 0.5) val dataSyncCheckColRenamed = Check(CheckLevel.Error, "data synchronization check col names renamed") - .isDataSynchronized(dfColRenamed, Map("id" -> "id_renamed", "product" -> "product"), _ == 1.0, - Some("shouldPass")) + .doesDatasetMatch(dfColRenamed, Map("id" -> "id_renamed", "product" -> "product"), _ == 1.0, + hint = Some("shouldPass")) val dataSyncFullMatch = Check(CheckLevel.Error, "data synchronization check col full 
match") - .isDataSynchronized(df, colMap, _ == 1, Some("shouldPass")) + .doesDatasetMatch(df, colMap, _ == 1, hint = Some("shouldPass")) val verificationResult = VerificationSuite() @@ -1242,6 +1257,8 @@ class VerificationSuiteTest extends WordSpec with Matchers with SparkContextSpec .addCheck(dataSyncCheckColMismatchSource) .addCheck(dataSyncCheckColRenamed) .addCheck(dataSyncFullMatch) + .addCheck(dataSyncCheckWithMatchColumns) + .addCheck(dataSyncCheckWithAdditionalCols) .run() val passResult = verificationResult.checkResults(dataSyncCheckPass) @@ -1279,6 +1296,17 @@ class VerificationSuiteTest extends WordSpec with Matchers with SparkContextSpec List(None) assert(fullMatchResult.status == CheckStatus.Success) + // Assertions for the new checks + val matchColumnsResult = verificationResult.checkResults(dataSyncCheckWithMatchColumns) + matchColumnsResult.constraintResults.map(_.message) shouldBe + List(None) // or any expected result + assert(matchColumnsResult.status == CheckStatus.Success) // or expected status + + val additionalColsResult = verificationResult.checkResults(dataSyncCheckWithAdditionalCols) + additionalColsResult.constraintResults.map(_.message) shouldBe + List(None) // or any expected result + assert(additionalColsResult.status == CheckStatus.Success) // or expected status + } } diff --git a/src/test/scala/com/amazon/deequ/checks/CheckTest.scala b/src/test/scala/com/amazon/deequ/checks/CheckTest.scala index bb565258c..43657d7ce 100644 --- a/src/test/scala/com/amazon/deequ/checks/CheckTest.scala +++ b/src/test/scala/com/amazon/deequ/checks/CheckTest.scala @@ -1,5 +1,5 @@ /** - * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"). You may not * use this file except in compliance with the License. 
A copy of the License @@ -18,22 +18,35 @@ package com.amazon.deequ package checks import com.amazon.deequ.analyzers._ -import com.amazon.deequ.analyzers.runners.{AnalysisRunner, AnalyzerContext} -import com.amazon.deequ.anomalydetection.{Anomaly, AnomalyDetectionAssertionResult, AnomalyDetectionDataPoint, AnomalyDetectionStrategy, AnomalyDetectionStrategyWithExtendedResults, ExtendedDetectionResult} +import com.amazon.deequ.analyzers.runners.AnalysisRunner +import com.amazon.deequ.analyzers.runners.AnalyzerContext +import com.amazon.deequ.anomalydetection.Anomaly +import com.amazon.deequ.anomalydetection.AnomalyDetectionAssertionResult +import com.amazon.deequ.anomalydetection.AnomalyDetectionDataPoint +import com.amazon.deequ.anomalydetection.AnomalyDetectionStrategy +import com.amazon.deequ.anomalydetection.AnomalyDetectionStrategyWithExtendedResults +import com.amazon.deequ.anomalydetection.ExtendedDetectionResult import com.amazon.deequ.checks.Check.getNewestPointAnomalyResults -import com.amazon.deequ.constraints.{ConstrainableDataTypes, ConstraintStatus} -import com.amazon.deequ.metrics.{DoubleMetric, Entity} +import com.amazon.deequ.constraints.ConstrainableDataTypes +import com.amazon.deequ.constraints.ConstraintStatus +import com.amazon.deequ.metrics.DoubleMetric +import com.amazon.deequ.metrics.Entity import com.amazon.deequ.repository.memory.InMemoryMetricsRepository -import com.amazon.deequ.repository.{MetricsRepository, ResultKey} +import com.amazon.deequ.repository.MetricsRepository +import com.amazon.deequ.repository.ResultKey import com.amazon.deequ.utils.FixtureSupport -import org.apache.spark.sql.functions.{col, when} +import org.apache.spark.sql.functions.col +import org.apache.spark.sql.functions.when import org.apache.spark.sql.types._ -import org.apache.spark.sql.{DataFrame, Row, SparkSession} +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.Row +import org.apache.spark.sql.SparkSession import org.scalamock.scalatest.MockFactory import org.scalatest.matchers.should.Matchers import org.scalatest.wordspec.AnyWordSpec -import scala.util.{Success, Try} +import scala.util.Success +import scala.util.Try class CheckTest extends AnyWordSpec with Matchers with SparkContextSpec with FixtureSupport with MockFactory { @@ -1349,10 +1362,16 @@ class CheckTest extends AnyWordSpec with Matchers with SparkContextSpec with Fix val dfInformative = getDfWithConditionallyInformativeColumns(sparkSession) val check = Check(CheckLevel.Error, "must have data in sync") - .isDataSynchronized(dfInformative, colMapAtt1, _ > 0.9, Some("show be in sync")) + .doesDatasetMatch(dfInformative, colMapAtt1, _ > 0.9, hint = Some("show be in sync")) val context = runChecks(dfInformative, check) assertSuccess(check, context) + + val check2 = Check(CheckLevel.Error, "must have data in sync") + .doesDatasetMatch(dfInformative, colMapAtt1, _ > 0.9, Some(colMapAtt1), Some("show be in sync with match col")) + val context2 = runChecks(dfInformative, check2) + + assertSuccess(check2, context2) } "yield failure when column doesnt exist in data sync test for 1 col" in withSparkSession { sparkSession => @@ -1360,10 +1379,11 @@ class CheckTest extends AnyWordSpec with Matchers with SparkContextSpec with Fix val dfInformativeRenamed = dfInformative.withColumnRenamed("att1", "att1_renamed") val check = Check(CheckLevel.Error, "must fail as columns does not exist") - .isDataSynchronized(dfInformativeRenamed, colMapAtt1, _ > 0.9, Some("must fail as columns does not exist")) + 
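// Editor's note: illustrative usage sketch, not part of this patch. A minimal end-to-end
// example of the renamed doesDatasetMatch API, assuming two hypothetical DataFrames
// `ordersDf` and `ordersReplicaDf` that share an `order_id` key and a `status` column.
import com.amazon.deequ.VerificationSuite
import com.amazon.deequ.checks.{Check, CheckLevel}

val replicaMatchCheck = Check(CheckLevel.Error, "orders replica match")
  .doesDatasetMatch(
    ordersReplicaDf,
    keyColumnMappings = Map("order_id" -> "order_id"),
    assertion = _ >= 0.99,
    matchColumnMappings = Some(Map("status" -> "status")),
    hint = Some("at least 99% of keyed rows should agree on status"))

val verificationResult = VerificationSuite().onData(ordersDf).addCheck(replicaMatchCheck).run()
val replicaMatchStatus = verificationResult.checkResults(replicaMatchCheck).status
// CheckStatus.Success when at least 99% of the keyed rows agree on the mapped columns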
.doesDatasetMatch(dfInformativeRenamed, colMapAtt1, _ > 0.9, + hint = Some("must fail as columns does not exist")) val context = runChecks(dfInformative, check) assertEvaluatesTo(check, context, CheckStatus.Error) - println(context) + } "yield failure when row count varies in data sync test for 1 col" in withSparkSession { sparkSession => @@ -1371,7 +1391,8 @@ class CheckTest extends AnyWordSpec with Matchers with SparkContextSpec with Fix val dfInformativeFiltered = dfInformative.filter("att1 > 2") val check = Check(CheckLevel.Error, "must fail as columns does not exist") - .isDataSynchronized(dfInformativeFiltered, colMapAtt1, _ > 0.9, Some("must fail as columns does not exist")) + .doesDatasetMatch(dfInformativeFiltered, colMapAtt1, _ > 0.9, + hint = Some("must fail as columns does not exist")) val context = runChecks(dfInformative, check) assertEvaluatesTo(check, context, CheckStatus.Error) } @@ -1382,7 +1403,7 @@ class CheckTest extends AnyWordSpec with Matchers with SparkContextSpec with Fix .otherwise(col("att1"))) val check = Check(CheckLevel.Error, "must fail as rows mismatches") - .isDataSynchronized(modifiedDf, colMapAtt1, _ > 0.9, Some("must fail as rows mismatches")) + .doesDatasetMatch(modifiedDf, colMapAtt1, _ > 0.9, hint = Some("must fail as rows mismatches")) val context = runChecks(df, check) assertEvaluatesTo(check, context, CheckStatus.Error) @@ -1394,8 +1415,8 @@ class CheckTest extends AnyWordSpec with Matchers with SparkContextSpec with Fix .otherwise(col("att1"))) val check = Check(CheckLevel.Error, "must be success as rows count mismatches at assertion 0.6") - .isDataSynchronized(modifiedDf, colMapAtt1, _ > 0.6, - Some("must be success as rows count mismatches at assertion 0.6")) + .doesDatasetMatch(modifiedDf, colMapAtt1, _ > 0.6, + hint = Some("must be success as rows count mismatches at assertion 0.6")) val context = runChecks(df, check) assertSuccess(check, context) } @@ -1405,19 +1426,31 @@ class CheckTest extends AnyWordSpec with Matchers with SparkContextSpec with Fix val dfInformative = getDfWithConditionallyInformativeColumns(sparkSession) val check = Check(CheckLevel.Error, "must have data in sync") - .isDataSynchronized(dfInformative, colMapTwoCols, _ > 0.9, Some("show be in sync")) + .doesDatasetMatch(dfInformative, colMapTwoCols, _ > 0.9, hint = Some("show be in sync")) val context = runChecks(dfInformative, check) assertSuccess(check, context) } + "yield success for basic data sync test for multiple columns and one col match" in + withSparkSession { sparkSession => + val dfInformative = getDfWithConditionallyInformativeColumns(sparkSession) + + val check = Check(CheckLevel.Error, "must have data in sync") + .doesDatasetMatch(dfInformative, colMapTwoCols, _ > 0.9, Some(colMapAtt1), hint = Some("show be in sync")) + val context = runChecks(dfInformative, check) + + assertSuccess(check, context) + } + "yield failure when column doesnt exist in data sync test for multiple columns" in withSparkSession { sparkSession => val dfInformative = getDfWithConditionallyInformativeColumns(sparkSession) val dfInformativeRenamed = dfInformative.withColumnRenamed("att1", "att1_renamed") val check = Check(CheckLevel.Error, "must fail as columns does not exist") - .isDataSynchronized(dfInformativeRenamed, colMapTwoCols, _ > 0.9, Some("must fail as columns does not exist")) + .doesDatasetMatch(dfInformativeRenamed, colMapTwoCols, _ > 0.9, + hint = Some("must fail as columns does not exist")) val context = runChecks(dfInformative, check) assertEvaluatesTo(check, context, 
CheckStatus.Error) @@ -1428,7 +1461,8 @@ class CheckTest extends AnyWordSpec with Matchers with SparkContextSpec with Fix val dfInformativeFiltered = dfInformative.filter("att1 > 2") val check = Check(CheckLevel.Error, "must fail as columns does not exist") - .isDataSynchronized(dfInformativeFiltered, colMapTwoCols, _ > 0.9, Some("must fail as columns does not exist")) + .doesDatasetMatch(dfInformativeFiltered, colMapTwoCols, _ > 0.9, + hint = Some("must fail as columns does not exist")) val context = runChecks(dfInformative, check) assertEvaluatesTo(check, context, CheckStatus.Error) @@ -1440,7 +1474,7 @@ class CheckTest extends AnyWordSpec with Matchers with SparkContextSpec with Fix .otherwise(col("att1"))) val check = Check(CheckLevel.Error, "must fail as rows mismatches") - .isDataSynchronized(modifiedDf, colMapTwoCols, _ > 0.9, Some("must fail as rows mismatches")) + .doesDatasetMatch(modifiedDf, colMapTwoCols, _ > 0.9, hint = Some("must fail as rows mismatches")) val context = runChecks(df, check) assertEvaluatesTo(check, context, CheckStatus.Error) @@ -1453,8 +1487,8 @@ class CheckTest extends AnyWordSpec with Matchers with SparkContextSpec with Fix .otherwise(col("att1"))) val check = Check(CheckLevel.Error, "must be success as metric value is 0.66") - .isDataSynchronized(modifiedDf, colMapTwoCols, _ > 0.6, - Some("must be success as metric value is 0.66")) + .doesDatasetMatch(modifiedDf, colMapTwoCols, _ > 0.6, + hint = Some("must be success as metric value is 0.66")) val context = runChecks(df, check) assertSuccess(check, context) diff --git a/src/test/scala/com/amazon/deequ/comparison/DataSynchronizationTest.scala b/src/test/scala/com/amazon/deequ/comparison/DataSynchronizationTest.scala index dd3a002da..f7b7e30f1 100644 --- a/src/test/scala/com/amazon/deequ/comparison/DataSynchronizationTest.scala +++ b/src/test/scala/com/amazon/deequ/comparison/DataSynchronizationTest.scala @@ -57,7 +57,7 @@ class DataSynchronizationTest extends AnyWordSpec with SparkContextSpec { val assertion: Double => Boolean = _ >= 0.60 val result = DataSynchronization.columnMatch(ds1, ds2, colKeyMap, compCols, assertion) - assert(result.isInstanceOf[DataSynchronizationSucceeded]) + assert(result.isInstanceOf[DatasetMatchSucceeded]) } "match == 0.83 when id is colKey and state is compCols" in withSparkSession { spark => @@ -88,7 +88,7 @@ class DataSynchronizationTest extends AnyWordSpec with SparkContextSpec { val assertion: Double => Boolean = _ >= 0.80 val result = DataSynchronization.columnMatch(ds1, ds2, colKeyMap, compCols, assertion) - assert(result.isInstanceOf[DataSynchronizationSucceeded]) + assert(result.isInstanceOf[DatasetMatchSucceeded]) } "return false because col name isn't unique" in withSparkSession { spark => @@ -119,7 +119,7 @@ class DataSynchronizationTest extends AnyWordSpec with SparkContextSpec { val assertion: Double => Boolean = _ >= 0.66 val result = DataSynchronization.columnMatch(ds1, ds2, colKeyMap, compCols, assertion) - assert(result.isInstanceOf[DataSynchronizationFailed]) + assert(result.isInstanceOf[DatasetMatchFailed]) } "match >= 0.66 when id is unique col, rest compCols" in withSparkSession { spark => @@ -150,7 +150,7 @@ class DataSynchronizationTest extends AnyWordSpec with SparkContextSpec { val assertion: Double => Boolean = _ >= 0.60 val result = DataSynchronization.columnMatch(ds1, ds2, colKeyMap, compCols, assertion) - assert(result.isInstanceOf[DataSynchronizationSucceeded]) + assert(result.isInstanceOf[DatasetMatchSucceeded]) } "match >= 0.66 (same test as 
above only the data sets change)" in withSparkSession{ spark => @@ -181,7 +181,7 @@ class DataSynchronizationTest extends AnyWordSpec with SparkContextSpec { val assertion: Double => Boolean = _ >= 0.60 val result = DataSynchronization.columnMatch(ds1, ds2, colKeyMap, compCols, assertion) - assert(result.isInstanceOf[DataSynchronizationSucceeded]) + assert(result.isInstanceOf[DatasetMatchSucceeded]) } "return false because the id col in ds1 isn't unique" in withSparkSession { spark => @@ -213,7 +213,7 @@ class DataSynchronizationTest extends AnyWordSpec with SparkContextSpec { val assertion: Double => Boolean = _ >= 0.40 val result = DataSynchronization.columnMatch(ds1, ds2, colKeyMap, compCols, assertion) - assert(result.asInstanceOf[DataSynchronizationFailed].errorMessage == + assert(result.asInstanceOf[DatasetMatchFailed].errorMessage == "The selected columns are not comparable due to duplicates present in the dataset." + "Comparison keys must be unique, but in Dataframe 1, there are 6 unique records and 7 rows, " + "and in Dataframe 2, there are 6 unique records and 6 rows, based on the combination of keys " + @@ -249,7 +249,7 @@ class DataSynchronizationTest extends AnyWordSpec with SparkContextSpec { val assertion: Double => Boolean = _ >= 0.40 val result = DataSynchronization.columnMatch(ds1, ds2, colKeyMap, compCols, assertion) - assert(result.isInstanceOf[DataSynchronizationFailed]) + assert(result.isInstanceOf[DatasetMatchFailed]) } "return false because col state isn't unique" in withSparkSession { spark => @@ -280,7 +280,7 @@ class DataSynchronizationTest extends AnyWordSpec with SparkContextSpec { val assertion: Double => Boolean = _ >= 0.66 val result = (DataSynchronization.columnMatch(ds1, ds2, colKeyMap, compCols, assertion)) - assert(result.isInstanceOf[DataSynchronizationFailed]) + assert(result.isInstanceOf[DatasetMatchFailed]) } "check all columns and return an assertion of .66" in withSparkSession { spark => @@ -310,7 +310,7 @@ class DataSynchronizationTest extends AnyWordSpec with SparkContextSpec { val assertion: Double => Boolean = _ >= 0.66 val result = DataSynchronization.columnMatch(ds1, ds2, colKeyMap, assertion) - assert(result.isInstanceOf[DataSynchronizationSucceeded]) + assert(result.isInstanceOf[DatasetMatchSucceeded]) } "return false because state column isn't unique" in withSparkSession { spark => @@ -340,7 +340,7 @@ class DataSynchronizationTest extends AnyWordSpec with SparkContextSpec { val assertion: Double => Boolean = _ >= 0.66 val result = DataSynchronization.columnMatch(ds1, ds2, colKeyMap, assertion) - assert(result.isInstanceOf[DataSynchronizationFailed]) + assert(result.isInstanceOf[DatasetMatchFailed]) } "check all columns" in withSparkSession { spark => @@ -370,7 +370,7 @@ class DataSynchronizationTest extends AnyWordSpec with SparkContextSpec { val assertion: Double => Boolean = _ >= 0.66 val result = DataSynchronization.columnMatch(ds1, ds2, colKeyMap, assertion) - assert(result.isInstanceOf[DataSynchronizationSucceeded]) + assert(result.isInstanceOf[DatasetMatchSucceeded]) } "cols exist but 0 matches" in withSparkSession { spark => import spark.implicits._ @@ -399,7 +399,7 @@ class DataSynchronizationTest extends AnyWordSpec with SparkContextSpec { val assertion: Double => Boolean = _ >= 0 val result = DataSynchronization.columnMatch(ds1, ds2, colKeyMap, assertion) - assert(result.isInstanceOf[DataSynchronizationSucceeded]) + assert(result.isInstanceOf[DatasetMatchSucceeded]) } } @@ -643,7 +643,7 @@ class DataSynchronizationTest extends 
AnyWordSpec with SparkContextSpec { // Overall val assertion: Double => Boolean = _ >= 0.6 // 4 out of 6 rows match val overallResult = DataSynchronization.columnMatch(ds1, ds2, colKeyMap, assertion) - assert(overallResult.isInstanceOf[DataSynchronizationSucceeded]) + assert(overallResult.isInstanceOf[DatasetMatchSucceeded]) // Row Level val outcomeColName = "outcome" @@ -670,7 +670,7 @@ class DataSynchronizationTest extends AnyWordSpec with SparkContextSpec { // Overall val assertion: Double => Boolean = _ >= 0.6 // 4 out of 6 rows match val overallResult = DataSynchronization.columnMatch(ds1, ds2, colKeyMap, assertion) - assert(overallResult.isInstanceOf[DataSynchronizationSucceeded]) + assert(overallResult.isInstanceOf[DatasetMatchSucceeded]) // Row Level val outcomeColName = "outcome" @@ -700,8 +700,8 @@ class DataSynchronizationTest extends AnyWordSpec with SparkContextSpec { val colKeyMap1 = Map(nonExistCol1 -> nonExistCol2) val overallResult1 = DataSynchronization.columnMatch(ds1, ds2, colKeyMap1, assertion) - assert(overallResult1.isInstanceOf[DataSynchronizationFailed]) - val failedOverallResult1 = overallResult1.asInstanceOf[DataSynchronizationFailed] + assert(overallResult1.isInstanceOf[DatasetMatchFailed]) + val failedOverallResult1 = overallResult1.asInstanceOf[DatasetMatchFailed] assert(failedOverallResult1.errorMessage.contains("key columns were not found in the first dataset")) assert(failedOverallResult1.errorMessage.contains(nonExistCol1)) @@ -716,8 +716,8 @@ class DataSynchronizationTest extends AnyWordSpec with SparkContextSpec { val colKeyMap2 = Map(nonExistCol1 -> idColumnName) val overallResult2 = DataSynchronization.columnMatch(ds1, ds2, colKeyMap2, assertion) - assert(overallResult2.isInstanceOf[DataSynchronizationFailed]) - val failedOverallResult2 = overallResult2.asInstanceOf[DataSynchronizationFailed] + assert(overallResult2.isInstanceOf[DatasetMatchFailed]) + val failedOverallResult2 = overallResult2.asInstanceOf[DatasetMatchFailed] assert(failedOverallResult2.errorMessage.contains("key columns were not found in the first dataset")) assert(failedOverallResult2.errorMessage.contains(nonExistCol1)) @@ -732,8 +732,8 @@ class DataSynchronizationTest extends AnyWordSpec with SparkContextSpec { val colKeyMap3 = Map(idColumnName -> nonExistCol2) val overallResult3 = DataSynchronization.columnMatch(ds1, ds2, colKeyMap3, assertion) - assert(overallResult3.isInstanceOf[DataSynchronizationFailed]) - val failedOverallResult3 = overallResult3.asInstanceOf[DataSynchronizationFailed] + assert(overallResult3.isInstanceOf[DatasetMatchFailed]) + val failedOverallResult3 = overallResult3.asInstanceOf[DatasetMatchFailed] assert(failedOverallResult3.errorMessage.contains("key columns were not found in the second dataset")) assert(failedOverallResult3.errorMessage.contains(nonExistCol2)) @@ -759,8 +759,8 @@ class DataSynchronizationTest extends AnyWordSpec with SparkContextSpec { val compColsMap1 = Map(nonExistCol1 -> nonExistCol2) val overallResult1 = DataSynchronization.columnMatch(ds1, ds2, colKeyMap, compColsMap1, assertion) - assert(overallResult1.isInstanceOf[DataSynchronizationFailed]) - val failedOverallResult1 = overallResult1.asInstanceOf[DataSynchronizationFailed] + assert(overallResult1.isInstanceOf[DatasetMatchFailed]) + val failedOverallResult1 = overallResult1.asInstanceOf[DatasetMatchFailed] assert(failedOverallResult1.errorMessage.contains( s"The following columns were not found in the first dataset: $nonExistCol1")) @@ -775,8 +775,8 @@ class 
DataSynchronizationTest extends AnyWordSpec with SparkContextSpec { val compColsMap2 = Map(nonExistCol1 -> "State") val overallResult2 = DataSynchronization.columnMatch(ds1, ds2, colKeyMap, compColsMap2, assertion) - assert(overallResult2.isInstanceOf[DataSynchronizationFailed]) - val failedOverallResult2 = overallResult2.asInstanceOf[DataSynchronizationFailed] + assert(overallResult2.isInstanceOf[DatasetMatchFailed]) + val failedOverallResult2 = overallResult2.asInstanceOf[DatasetMatchFailed] assert(failedOverallResult2.errorMessage.contains( s"The following columns were not found in the first dataset: $nonExistCol1")) @@ -791,8 +791,8 @@ class DataSynchronizationTest extends AnyWordSpec with SparkContextSpec { val compColsMap3 = Map("state" -> nonExistCol2) val overallResult3 = DataSynchronization.columnMatch(ds1, ds2, colKeyMap, compColsMap3, assertion) - assert(overallResult3.isInstanceOf[DataSynchronizationFailed]) - val failedOverallResult3 = overallResult3.asInstanceOf[DataSynchronizationFailed] + assert(overallResult3.isInstanceOf[DatasetMatchFailed]) + val failedOverallResult3 = overallResult3.asInstanceOf[DatasetMatchFailed] assert(failedOverallResult3.errorMessage.contains( s"The following columns were not found in the second dataset: $nonExistCol2")) From a8780b79c4dc58c8916d27cf3ac59399541a6291 Mon Sep 17 00:00:00 2001 From: Edward Cho <114528615+eycho-am@users.noreply.github.com> Date: Thu, 15 Feb 2024 15:03:04 -0500 Subject: [PATCH 04/24] Feature: Add Row Level Result Treatment Options for Uniqueness and Completeness (#532) * Modified Completeness analyzer to label filtered rows as null for row-level results * Modified GroupingAnalyzers and Uniqueness analyzer to label filtered rows as null for row-level results * Adjustments for modifying the calculate method to take in a filterCondition * Add RowLevelFilterTreatement trait and object to determine how filtered rows will be labeled (default True) * Modify VerificationRunBuilder to have RowLevelFilterTreatment as variable instead of extending, create RowLevelAnalyzer trait * Do row-level filtering in AnalyzerOptions rather than with RowLevelFilterTreatment trait * Modify computeStateFrom to take in optional filterCondition --- .../amazon/deequ/VerificationRunBuilder.scala | 3 +- .../com/amazon/deequ/analyzers/Analyzer.scala | 31 +++- .../amazon/deequ/analyzers/Completeness.scala | 22 ++- .../amazon/deequ/analyzers/CustomSql.scala | 2 +- .../analyzers/DatasetMatchAnalyzer.scala | 2 +- .../deequ/analyzers/GroupingAnalyzers.scala | 16 +- .../amazon/deequ/analyzers/Histogram.scala | 3 +- .../deequ/analyzers/UniqueValueRatio.scala | 26 +++- .../amazon/deequ/analyzers/Uniqueness.scala | 29 +++- .../scala/com/amazon/deequ/checks/Check.scala | 60 +++++++- .../amazon/deequ/constraints/Constraint.scala | 14 +- .../amazon/deequ/VerificationResultTest.scala | 18 ++- .../amazon/deequ/VerificationSuiteTest.scala | 144 ++++++++++++++++++ .../deequ/analyzers/AnalyzerTests.scala | 4 +- .../deequ/analyzers/CompletenessTest.scala | 33 ++++ .../deequ/analyzers/UniquenessTest.scala | 64 ++++++++ .../runners/AnalysisRunnerTests.scala | 22 ++- .../runners/AnalyzerContextTest.scala | 5 +- .../com/amazon/deequ/checks/CheckTest.scala | 32 +++- .../AnalysisBasedConstraintTest.scala | 6 +- .../repository/AnalysisResultSerdeTest.scala | 4 +- .../deequ/repository/AnalysisResultTest.scala | 5 +- ...sRepositoryMultipleResultsLoaderTest.scala | 5 +- .../ConstraintSuggestionResultTest.scala | 32 ++-- .../amazon/deequ/utils/FixtureSupport.scala | 24 +++ 25 
files changed, 528 insertions(+), 78 deletions(-) diff --git a/src/main/scala/com/amazon/deequ/VerificationRunBuilder.scala b/src/main/scala/com/amazon/deequ/VerificationRunBuilder.scala index cd4c89a49..929b2319b 100644 --- a/src/main/scala/com/amazon/deequ/VerificationRunBuilder.scala +++ b/src/main/scala/com/amazon/deequ/VerificationRunBuilder.scala @@ -25,7 +25,7 @@ import com.amazon.deequ.repository._ import org.apache.spark.sql.{DataFrame, SparkSession} /** A class to build a VerificationRun using a fluent API */ -class VerificationRunBuilder(val data: DataFrame) { +class VerificationRunBuilder(val data: DataFrame) { protected var requiredAnalyzers: Seq[Analyzer[_, Metric[_]]] = Seq.empty @@ -159,7 +159,6 @@ class VerificationRunBuilder(val data: DataFrame) { new VerificationRunBuilderWithSparkSession(this, Option(sparkSession)) } - def run(): VerificationResult = { VerificationSuite().doVerificationRun( data, diff --git a/src/main/scala/com/amazon/deequ/analyzers/Analyzer.scala b/src/main/scala/com/amazon/deequ/analyzers/Analyzer.scala index a80405825..bc241fe72 100644 --- a/src/main/scala/com/amazon/deequ/analyzers/Analyzer.scala +++ b/src/main/scala/com/amazon/deequ/analyzers/Analyzer.scala @@ -17,6 +17,7 @@ package com.amazon.deequ.analyzers import com.amazon.deequ.analyzers.Analyzers._ +import com.amazon.deequ.analyzers.FilteredRow.FilteredRow import com.amazon.deequ.analyzers.NullBehavior.NullBehavior import com.amazon.deequ.analyzers.runners._ import com.amazon.deequ.metrics.DoubleMetric @@ -69,7 +70,7 @@ trait Analyzer[S <: State[_], +M <: Metric[_]] extends Serializable { * @param data data frame * @return */ - def computeStateFrom(data: DataFrame): Option[S] + def computeStateFrom(data: DataFrame, filterCondition: Option[String] = None): Option[S] /** * Compute the metric from the state (sufficient statistics) @@ -97,13 +98,14 @@ trait Analyzer[S <: State[_], +M <: Metric[_]] extends Serializable { def calculate( data: DataFrame, aggregateWith: Option[StateLoader] = None, - saveStatesWith: Option[StatePersister] = None) + saveStatesWith: Option[StatePersister] = None, + filterCondition: Option[String] = None) : M = { try { preconditions.foreach { condition => condition(data.schema) } - val state = computeStateFrom(data) + val state = computeStateFrom(data, filterCondition) calculateMetric(state, aggregateWith, saveStatesWith) } catch { @@ -170,7 +172,6 @@ trait Analyzer[S <: State[_], +M <: Metric[_]] extends Serializable { private[deequ] def copyStateTo(source: StateLoader, target: StatePersister): Unit = { source.load[S](this).foreach { state => target.persist(this, state) } } - } /** An analyzer that runs a set of aggregation functions over the data, @@ -184,7 +185,7 @@ trait ScanShareableAnalyzer[S <: State[_], +M <: Metric[_]] extends Analyzer[S, private[deequ] def fromAggregationResult(result: Row, offset: Int): Option[S] /** Runs aggregation functions directly, without scan sharing */ - override def computeStateFrom(data: DataFrame): Option[S] = { + override def computeStateFrom(data: DataFrame, where: Option[String] = None): Option[S] = { val aggregations = aggregationFunctions() val result = data.agg(aggregations.head, aggregations.tail: _*).collect().head fromAggregationResult(result, 0) @@ -255,12 +256,18 @@ case class NumMatchesAndCount(numMatches: Long, count: Long, override val fullCo } } -case class AnalyzerOptions(nullBehavior: NullBehavior = NullBehavior.Ignore) +case class AnalyzerOptions(nullBehavior: NullBehavior = NullBehavior.Ignore, + filteredRow: 
FilteredRow = FilteredRow.TRUE) object NullBehavior extends Enumeration { type NullBehavior = Value val Ignore, EmptyString, Fail = Value } +object FilteredRow extends Enumeration { + type FilteredRow = Value + val NULL, TRUE = Value +} + /** Base class for analyzers that compute ratios of matching predicates */ abstract class PredicateMatchingAnalyzer( name: String, @@ -490,6 +497,18 @@ private[deequ] object Analyzers { conditionalSelectionFromColumns(selection, conditionColumn) } + def conditionalSelectionFilteredFromColumns( + selection: Column, + conditionColumn: Option[Column], + filterTreatment: String) + : Column = { + conditionColumn + .map { condition => { + when(not(condition), expr(filterTreatment)).when(condition, selection) + } } + .getOrElse(selection) + } + private[this] def conditionalSelectionFromColumns( selection: Column, conditionColumn: Option[Column]) diff --git a/src/main/scala/com/amazon/deequ/analyzers/Completeness.scala b/src/main/scala/com/amazon/deequ/analyzers/Completeness.scala index 5e80e2f6e..399cbb06a 100644 --- a/src/main/scala/com/amazon/deequ/analyzers/Completeness.scala +++ b/src/main/scala/com/amazon/deequ/analyzers/Completeness.scala @@ -20,19 +20,21 @@ import com.amazon.deequ.analyzers.Preconditions.{hasColumn, isNotNested} import org.apache.spark.sql.functions.sum import org.apache.spark.sql.types.{IntegerType, StructType} import Analyzers._ +import com.amazon.deequ.analyzers.FilteredRow.FilteredRow import com.google.common.annotations.VisibleForTesting -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.functions.col +import org.apache.spark.sql.functions.expr import org.apache.spark.sql.{Column, Row} /** Completeness is the fraction of non-null values in a column of a DataFrame. */ -case class Completeness(column: String, where: Option[String] = None) extends +case class Completeness(column: String, where: Option[String] = None, + analyzerOptions: Option[AnalyzerOptions] = None) extends StandardScanShareableAnalyzer[NumMatchesAndCount]("Completeness", column) with FilterableAnalyzer { override def fromAggregationResult(result: Row, offset: Int): Option[NumMatchesAndCount] = { - ifNoNullsIn(result, offset, howMany = 2) { _ => - NumMatchesAndCount(result.getLong(offset), result.getLong(offset + 1), Some(criterion)) + NumMatchesAndCount(result.getLong(offset), result.getLong(offset + 1), Some(rowLevelResults)) } } @@ -51,4 +53,16 @@ case class Completeness(column: String, where: Option[String] = None) extends @VisibleForTesting // required by some tests that compare analyzer results to an expected state private[deequ] def criterion: Column = conditionalSelection(column, where).isNotNull + + @VisibleForTesting + private[deequ] def rowLevelResults: Column = { + val whereCondition = where.map { expression => expr(expression)} + conditionalSelectionFilteredFromColumns(col(column).isNotNull, whereCondition, getRowLevelFilterTreatment.toString) + } + + private def getRowLevelFilterTreatment: FilteredRow = { + analyzerOptions + .map { options => options.filteredRow } + .getOrElse(FilteredRow.TRUE) + } } diff --git a/src/main/scala/com/amazon/deequ/analyzers/CustomSql.scala b/src/main/scala/com/amazon/deequ/analyzers/CustomSql.scala index b8dc2692a..e07e2d11f 100644 --- a/src/main/scala/com/amazon/deequ/analyzers/CustomSql.scala +++ b/src/main/scala/com/amazon/deequ/analyzers/CustomSql.scala @@ -33,7 +33,7 @@ case class CustomSql(expression: String) extends Analyzer[CustomSqlState, Double * @param data data frame * @return */ - override def 
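// Editor's note: illustrative sketch, not part of this patch. It shows how the new
// filteredRow option could be passed to the Completeness analyzer so that rows excluded
// by the where-filter are labelled null (instead of the default true) in the row-level
// results. The DataFrame `data` and the column names are hypothetical.
import com.amazon.deequ.analyzers.{AnalyzerOptions, Completeness, FilteredRow}

val completenessWithNullForFiltered = Completeness(
  "att2",
  where = Some("att1 = 'a'"),
  analyzerOptions = Some(AnalyzerOptions(filteredRow = FilteredRow.NULL)))

val metric = completenessWithNullForFiltered.calculate(data)
// metric.value (a Try[Double]) holds the completeness of att2 over the rows matching the filter;
// metric.fullColumn carries the per-row outcome (true / false / null for filtered-out rows).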
computeStateFrom(data: DataFrame): Option[CustomSqlState] = { + override def computeStateFrom(data: DataFrame, filterCondition: Option[String] = None): Option[CustomSqlState] = { Try { data.sqlContext.sql(expression) diff --git a/src/main/scala/com/amazon/deequ/analyzers/DatasetMatchAnalyzer.scala b/src/main/scala/com/amazon/deequ/analyzers/DatasetMatchAnalyzer.scala index cdf0e5061..f2aefb57f 100644 --- a/src/main/scala/com/amazon/deequ/analyzers/DatasetMatchAnalyzer.scala +++ b/src/main/scala/com/amazon/deequ/analyzers/DatasetMatchAnalyzer.scala @@ -69,7 +69,7 @@ case class DatasetMatchAnalyzer(dfToCompare: DataFrame, matchColumnMappings: Option[Map[String, String]] = None) extends Analyzer[DatasetMatchState, DoubleMetric] { - override def computeStateFrom(data: DataFrame): Option[DatasetMatchState] = { + override def computeStateFrom(data: DataFrame, filterCondition: Option[String] = None): Option[DatasetMatchState] = { val result = if (matchColumnMappings.isDefined) { DataSynchronization.columnMatch(data, dfToCompare, columnMappings, matchColumnMappings.get, assertion) diff --git a/src/main/scala/com/amazon/deequ/analyzers/GroupingAnalyzers.scala b/src/main/scala/com/amazon/deequ/analyzers/GroupingAnalyzers.scala index 2090d8231..30bd89621 100644 --- a/src/main/scala/com/amazon/deequ/analyzers/GroupingAnalyzers.scala +++ b/src/main/scala/com/amazon/deequ/analyzers/GroupingAnalyzers.scala @@ -32,6 +32,7 @@ import org.apache.spark.sql.functions.count import org.apache.spark.sql.functions.expr import org.apache.spark.sql.functions.lit import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.functions.when /** Base class for all analyzers that operate the frequencies of groups in the data */ abstract class FrequencyBasedAnalyzer(columnsToGroupOn: Seq[String]) @@ -39,8 +40,9 @@ abstract class FrequencyBasedAnalyzer(columnsToGroupOn: Seq[String]) override def groupingColumns(): Seq[String] = { columnsToGroupOn } - override def computeStateFrom(data: DataFrame): Option[FrequenciesAndNumRows] = { - Some(FrequencyBasedAnalyzer.computeFrequencies(data, groupingColumns())) + override def computeStateFrom(data: DataFrame, + filterCondition: Option[String] = None): Option[FrequenciesAndNumRows] = { + Some(FrequencyBasedAnalyzer.computeFrequencies(data, groupingColumns(), filterCondition)) } /** We need at least one grouping column, and all specified columns must exist */ @@ -88,7 +90,15 @@ object FrequencyBasedAnalyzer { .count() // Set rows with value count 1 to true, and otherwise false - val fullColumn: Column = count(UNIQUENESS_ID).over(Window.partitionBy(columnsToGroupBy: _*)) + val fullColumn: Column = { + val window = Window.partitionBy(columnsToGroupBy: _*) + where.map { + condition => { + count(when(expr(condition), UNIQUENESS_ID)).over(window) + } + }.getOrElse(count(UNIQUENESS_ID).over(window)) + } + FrequenciesAndNumRows(frequencies, numRows, Option(fullColumn)) } diff --git a/src/main/scala/com/amazon/deequ/analyzers/Histogram.scala b/src/main/scala/com/amazon/deequ/analyzers/Histogram.scala index 42a7e72e5..742b2ba68 100644 --- a/src/main/scala/com/amazon/deequ/analyzers/Histogram.scala +++ b/src/main/scala/com/amazon/deequ/analyzers/Histogram.scala @@ -59,7 +59,8 @@ case class Histogram( } } - override def computeStateFrom(data: DataFrame): Option[FrequenciesAndNumRows] = { + override def computeStateFrom(data: DataFrame, + filterCondition: Option[String] = None): Option[FrequenciesAndNumRows] = { // TODO figure out a way to pass this in if its known before hand val 
totalCount = if (computeFrequenciesAsRatio) { diff --git a/src/main/scala/com/amazon/deequ/analyzers/UniqueValueRatio.scala b/src/main/scala/com/amazon/deequ/analyzers/UniqueValueRatio.scala index d3c8aeb68..c2fce1f14 100644 --- a/src/main/scala/com/amazon/deequ/analyzers/UniqueValueRatio.scala +++ b/src/main/scala/com/amazon/deequ/analyzers/UniqueValueRatio.scala @@ -17,13 +17,17 @@ package com.amazon.deequ.analyzers import com.amazon.deequ.analyzers.Analyzers.COUNT_COL +import com.amazon.deequ.analyzers.FilteredRow.FilteredRow import com.amazon.deequ.metrics.DoubleMetric +import org.apache.spark.sql.functions.expr +import org.apache.spark.sql.functions.not import org.apache.spark.sql.functions.when import org.apache.spark.sql.{Column, Row} import org.apache.spark.sql.functions.{col, count, lit, sum} import org.apache.spark.sql.types.DoubleType -case class UniqueValueRatio(columns: Seq[String], where: Option[String] = None) +case class UniqueValueRatio(columns: Seq[String], where: Option[String] = None, + analyzerOptions: Option[AnalyzerOptions] = None) extends ScanShareableFrequencyBasedAnalyzer("UniqueValueRatio", columns) with FilterableAnalyzer { @@ -34,11 +38,27 @@ case class UniqueValueRatio(columns: Seq[String], where: Option[String] = None) override def fromAggregationResult(result: Row, offset: Int, fullColumn: Option[Column] = None): DoubleMetric = { val numUniqueValues = result.getDouble(offset) val numDistinctValues = result.getLong(offset + 1).toDouble - val fullColumnUniqueness = when((fullColumn.getOrElse(null)).equalTo(1), true).otherwise(false) - toSuccessMetric(numUniqueValues / numDistinctValues, Option(fullColumnUniqueness)) + val conditionColumn = where.map { expression => expr(expression) } + val fullColumnUniqueness = fullColumn.map { + rowLevelColumn => { + conditionColumn.map { + condition => { + when(not(condition), expr(getRowLevelFilterTreatment.toString)) + .when(rowLevelColumn.equalTo(1), true).otherwise(false) + } + }.getOrElse(when(rowLevelColumn.equalTo(1), true).otherwise(false)) + } + } + toSuccessMetric(numUniqueValues / numDistinctValues, fullColumnUniqueness) } override def filterCondition: Option[String] = where + + private def getRowLevelFilterTreatment: FilteredRow = { + analyzerOptions + .map { options => options.filteredRow } + .getOrElse(FilteredRow.TRUE) + } } object UniqueValueRatio { diff --git a/src/main/scala/com/amazon/deequ/analyzers/Uniqueness.scala b/src/main/scala/com/amazon/deequ/analyzers/Uniqueness.scala index 959f4734c..78ba4c418 100644 --- a/src/main/scala/com/amazon/deequ/analyzers/Uniqueness.scala +++ b/src/main/scala/com/amazon/deequ/analyzers/Uniqueness.scala @@ -17,31 +17,52 @@ package com.amazon.deequ.analyzers import com.amazon.deequ.analyzers.Analyzers.COUNT_COL +import com.amazon.deequ.analyzers.FilteredRow.FilteredRow import com.amazon.deequ.metrics.DoubleMetric +import com.google.common.annotations.VisibleForTesting import org.apache.spark.sql.Column import org.apache.spark.sql.Row import org.apache.spark.sql.functions.when import org.apache.spark.sql.functions.col +import org.apache.spark.sql.functions.not +import org.apache.spark.sql.functions.expr import org.apache.spark.sql.functions.lit import org.apache.spark.sql.functions.sum import org.apache.spark.sql.types.DoubleType /** Uniqueness is the fraction of unique values of a column(s), i.e., * values that occur exactly once. 
*/ -case class Uniqueness(columns: Seq[String], where: Option[String] = None) +case class Uniqueness(columns: Seq[String], where: Option[String] = None, + analyzerOptions: Option[AnalyzerOptions] = None) extends ScanShareableFrequencyBasedAnalyzer("Uniqueness", columns) with FilterableAnalyzer { override def aggregationFunctions(numRows: Long): Seq[Column] = { - (sum(col(COUNT_COL).equalTo(lit(1)).cast(DoubleType)) / numRows) :: Nil + (sum(col(COUNT_COL).equalTo(lit(1)).cast(DoubleType)) / numRows) :: Nil } override def fromAggregationResult(result: Row, offset: Int, fullColumn: Option[Column]): DoubleMetric = { - val fullColumnUniqueness = when((fullColumn.getOrElse(null)).equalTo(1), true).otherwise(false) - super.fromAggregationResult(result, offset, Option(fullColumnUniqueness)) + val conditionColumn = where.map { expression => expr(expression) } + val fullColumnUniqueness = fullColumn.map { + rowLevelColumn => { + conditionColumn.map { + condition => { + when(not(condition), expr(getRowLevelFilterTreatment.toString)) + .when(rowLevelColumn.equalTo(1), true).otherwise(false) + } + }.getOrElse(when(rowLevelColumn.equalTo(1), true).otherwise(false)) + } + } + super.fromAggregationResult(result, offset, fullColumnUniqueness) } override def filterCondition: Option[String] = where + + private def getRowLevelFilterTreatment: FilteredRow = { + analyzerOptions + .map { options => options.filteredRow } + .getOrElse(FilteredRow.TRUE) + } } object Uniqueness { diff --git a/src/main/scala/com/amazon/deequ/checks/Check.scala b/src/main/scala/com/amazon/deequ/checks/Check.scala index c38ee0e0f..bdae62ab7 100644 --- a/src/main/scala/com/amazon/deequ/checks/Check.scala +++ b/src/main/scala/com/amazon/deequ/checks/Check.scala @@ -129,10 +129,12 @@ case class Check( * * @param column Column to run the assertion on * @param hint A hint to provide additional context why a constraint could have failed + * @param analyzerOptions Options to configure analyzer behavior (NullTreatment, FilteredRow) * @return */ - def isComplete(column: String, hint: Option[String] = None): CheckWithLastConstraintFilterable = { - addFilterableConstraint { filter => completenessConstraint(column, Check.IsOne, filter, hint) } + def isComplete(column: String, hint: Option[String] = None, + analyzerOptions: Option[AnalyzerOptions] = None): CheckWithLastConstraintFilterable = { + addFilterableConstraint { filter => completenessConstraint(column, Check.IsOne, filter, hint, analyzerOptions) } } /** @@ -143,14 +145,16 @@ case class Check( * @param column Column to run the assertion on * @param assertion Function that receives a double input parameter and returns a boolean * @param hint A hint to provide additional context why a constraint could have failed + * @param analyzerOptions Options to configure analyzer behavior (NullTreatment, FilteredRow) * @return */ def hasCompleteness( column: String, assertion: Double => Boolean, - hint: Option[String] = None) + hint: Option[String] = None, + analyzerOptions: Option[AnalyzerOptions] = None) : CheckWithLastConstraintFilterable = { - addFilterableConstraint { filter => completenessConstraint(column, assertion, filter, hint) } + addFilterableConstraint { filter => completenessConstraint(column, assertion, filter, hint, analyzerOptions) } } /** @@ -218,11 +222,13 @@ case class Check( * * @param column Column to run the assertion on * @param hint A hint to provide additional context why a constraint could have failed + * @param analyzerOptions Options to configure analyzer behavior 
(NullTreatment, FilteredRow) * @return */ - def isUnique(column: String, hint: Option[String] = None): CheckWithLastConstraintFilterable = { + def isUnique(column: String, hint: Option[String] = None, + analyzerOptions: Option[AnalyzerOptions] = None): CheckWithLastConstraintFilterable = { addFilterableConstraint { filter => - uniquenessConstraint(Seq(column), Check.IsOne, filter, hint) } + uniquenessConstraint(Seq(column), Check.IsOne, filter, hint, analyzerOptions) } } /** @@ -266,6 +272,24 @@ case class Check( addFilterableConstraint { filter => uniquenessConstraint(columns, assertion, filter) } } + /** + * Creates a constraint that asserts on uniqueness in a single or combined set of key columns. + * + * @param columns Key columns + * @param assertion Function that receives a double input parameter and returns a boolean. + * Refers to the fraction of unique values + * @param hint A hint to provide additional context why a constraint could have failed + * @return + */ + def hasUniqueness( + columns: Seq[String], + assertion: Double => Boolean, + hint: Option[String]) + : CheckWithLastConstraintFilterable = { + + addFilterableConstraint { filter => uniquenessConstraint(columns, assertion, filter, hint) } + } + /** * Creates a constraint that asserts on uniqueness in a single or combined set of key columns. * @@ -273,15 +297,17 @@ case class Check( * @param assertion Function that receives a double input parameter and returns a boolean. * Refers to the fraction of unique values * @param hint A hint to provide additional context why a constraint could have failed + * @param analyzerOptions Options to configure analyzer behavior (NullTreatment, FilteredRow) * @return */ def hasUniqueness( columns: Seq[String], assertion: Double => Boolean, - hint: Option[String]) + hint: Option[String], + analyzerOptions: Option[AnalyzerOptions]) : CheckWithLastConstraintFilterable = { - addFilterableConstraint { filter => uniquenessConstraint(columns, assertion, filter, hint) } + addFilterableConstraint { filter => uniquenessConstraint(columns, assertion, filter, hint, analyzerOptions) } } /** @@ -311,6 +337,22 @@ case class Check( hasUniqueness(Seq(column), assertion, hint) } + /** + * Creates a constraint that asserts on the uniqueness of a key column. + * + * @param column Key column + * @param assertion Function that receives a double input parameter and returns a boolean. + * Refers to the fraction of unique values. + * @param hint A hint to provide additional context why a constraint could have failed + * @param analyzerOptions Options to configure analyzer behavior (NullTreatment, FilteredRow) + * @return + */ + def hasUniqueness(column: String, assertion: Double => Boolean, hint: Option[String], + analyzerOptions: Option[AnalyzerOptions]) + : CheckWithLastConstraintFilterable = { + hasUniqueness(Seq(column), assertion, hint, analyzerOptions) + } + /** * Creates a constraint on the distinctness in a single or combined set of key columns. 
* @@ -636,6 +678,7 @@ case class Check( * @param column Column to run the assertion on * @param assertion Function that receives a double input parameter and returns a boolean * @param hint A hint to provide additional context why a constraint could have failed + * @param analyzerOptions Options to configure analyzer behavior (NullTreatment, FilteredRow) * @return */ def hasMinLength( @@ -654,6 +697,7 @@ case class Check( * @param column Column to run the assertion on * @param assertion Function that receives a double input parameter and returns a boolean * @param hint A hint to provide additional context why a constraint could have failed + * @param analyzerOptions Options to configure analyzer behavior (NullTreatment, FilteredRow) * @return */ def hasMaxLength( diff --git a/src/main/scala/com/amazon/deequ/constraints/Constraint.scala b/src/main/scala/com/amazon/deequ/constraints/Constraint.scala index c910325c9..b9a15901b 100644 --- a/src/main/scala/com/amazon/deequ/constraints/Constraint.scala +++ b/src/main/scala/com/amazon/deequ/constraints/Constraint.scala @@ -205,15 +205,17 @@ object Constraint { * @param assertion Function that receives a double input parameter (since the metric is * double metric) and returns a boolean * @param hint A hint to provide additional context why a constraint could have failed + * @param analyzerOptions Options to configure analyzer behavior (NullTreatment, FilteredRow) */ def completenessConstraint( column: String, assertion: Double => Boolean, where: Option[String] = None, - hint: Option[String] = None) + hint: Option[String] = None, + analyzerOptions: Option[AnalyzerOptions] = None) : Constraint = { - val completeness = Completeness(column, where) + val completeness = Completeness(column, where, analyzerOptions) this.fromAnalyzer(completeness, assertion, hint) } @@ -277,15 +279,17 @@ object Constraint { * (since the metric is double metric) and returns a boolean * @param where Additional filter to apply before the analyzer is run. 
* @param hint A hint to provide additional context why a constraint could have failed + * @param analyzerOptions Options to configure analyzer behavior (NullTreatment, FilteredRow) */ def uniquenessConstraint( columns: Seq[String], assertion: Double => Boolean, where: Option[String] = None, - hint: Option[String] = None) + hint: Option[String] = None, + analyzerOptions: Option[AnalyzerOptions] = None) : Constraint = { - val uniqueness = Uniqueness(columns, where) + val uniqueness = Uniqueness(columns, where, analyzerOptions) fromAnalyzer(uniqueness, assertion, hint) } @@ -563,6 +567,7 @@ object Constraint { * @param column Column to run the assertion on * @param assertion Function that receives a double input parameter and returns a boolean * @param hint A hint to provide additional context why a constraint could have failed + * @param analyzerOptions Options to configure analyzer behavior (NullTreatment, FilteredRow) */ def maxLengthConstraint( column: String, @@ -597,6 +602,7 @@ object Constraint { * @param column Column to run the assertion on * @param assertion Function that receives a double input parameter and returns a boolean * @param hint A hint to provide additional context why a constraint could have failed + * @param analyzerOptions Options to configure analyzer behavior (NullTreatment, FilteredRow) */ def minLengthConstraint( column: String, diff --git a/src/test/scala/com/amazon/deequ/VerificationResultTest.scala b/src/test/scala/com/amazon/deequ/VerificationResultTest.scala index 93aa73201..0a90c8f77 100644 --- a/src/test/scala/com/amazon/deequ/VerificationResultTest.scala +++ b/src/test/scala/com/amazon/deequ/VerificationResultTest.scala @@ -78,6 +78,13 @@ class VerificationResultTest extends WordSpec with Matchers with SparkContextSpe val successMetricsResultsJson = VerificationResult.successMetricsAsJson(results) + val expectedJsonSet = Set("""{"entity":"Column","instance":"item","name":"Distinctness","value":1.0}""", + """{"entity": "Column", "instance":"att2","name":"Completeness","value":1.0}""", + """{"entity":"Column","instance":"att1","name":"Completeness","value":1.0}""", + """{"entity":"Multicolumn","instance":"att1,att2", + "name":"Uniqueness","value":0.25}""", + """{"entity":"Dataset","instance":"*","name":"Size","value":4.0}""") + val expectedJson = """[{"entity":"Column","instance":"item","name":"Distinctness","value":1.0}, |{"entity": "Column", "instance":"att2","name":"Completeness","value":1.0}, @@ -123,11 +130,11 @@ class VerificationResultTest extends WordSpec with Matchers with SparkContextSpe import session.implicits._ val expected = Seq( - ("group-1", "Error", "Success", "CompletenessConstraint(Completeness(att1,None))", + ("group-1", "Error", "Success", "CompletenessConstraint(Completeness(att1,None,None))", "Success", ""), ("group-2-E", "Error", "Error", "SizeConstraint(Size(None))", "Failure", "Value: 4 does not meet the constraint requirement! 
Should be greater than 5!"), - ("group-2-E", "Error", "Error", "CompletenessConstraint(Completeness(att2,None))", + ("group-2-E", "Error", "Error", "CompletenessConstraint(Completeness(att2,None,None))", "Success", ""), ("group-2-W", "Warning", "Warning", "DistinctnessConstraint(Distinctness(List(item),None))", @@ -150,7 +157,7 @@ class VerificationResultTest extends WordSpec with Matchers with SparkContextSpe val expectedJson = """[{"check":"group-1","check_level":"Error","check_status":"Success", - |"constraint":"CompletenessConstraint(Completeness(att1,None))", + |"constraint":"CompletenessConstraint(Completeness(att1,None,None))", |"constraint_status":"Success","constraint_message":""}, | |{"check":"group-2-E","check_level":"Error","check_status":"Error", @@ -159,7 +166,7 @@ class VerificationResultTest extends WordSpec with Matchers with SparkContextSpe | Should be greater than 5!"}, | |{"check":"group-2-E","check_level":"Error","check_status":"Error", - |"constraint":"CompletenessConstraint(Completeness(att2,None))", + |"constraint":"CompletenessConstraint(Completeness(att2,None,None))", |"constraint_status":"Success","constraint_message":""}, | |{"check":"group-2-W","check_level":"Warning","check_status":"Warning", @@ -214,7 +221,6 @@ class VerificationResultTest extends WordSpec with Matchers with SparkContextSpe } private[this] def assertSameResultsJson(jsonA: String, jsonB: String): Unit = { - assert(SimpleResultSerde.deserialize(jsonA) == - SimpleResultSerde.deserialize(jsonB)) + assert(SimpleResultSerde.deserialize(jsonA).toSet.sameElements(SimpleResultSerde.deserialize(jsonB).toSet)) } } diff --git a/src/test/scala/com/amazon/deequ/VerificationSuiteTest.scala b/src/test/scala/com/amazon/deequ/VerificationSuiteTest.scala index 1cc09b811..7588ee914 100644 --- a/src/test/scala/com/amazon/deequ/VerificationSuiteTest.scala +++ b/src/test/scala/com/amazon/deequ/VerificationSuiteTest.scala @@ -304,6 +304,91 @@ class VerificationSuiteTest extends WordSpec with Matchers with SparkContextSpec assert(Seq(true, true, true, false, false, false).sameElements(rowLevel8)) } + "generate a result that contains row-level results with true for filtered rows" in withSparkSession { session => + val data = getDfCompleteAndInCompleteColumns(session) + + val completeness = new Check(CheckLevel.Error, "rule1") + .hasCompleteness("att2", _ > 0.7, None) + .where("att1 = \"a\"") + val uniqueness = new Check(CheckLevel.Error, "rule2") + .hasUniqueness("att1", _ > 0.5, None) + val uniquenessWhere = new Check(CheckLevel.Error, "rule3") + .isUnique("att1") + .where("item < 3") + val expectedColumn1 = completeness.description + val expectedColumn2 = uniqueness.description + val expectedColumn3 = uniquenessWhere.description + + + val suite = new VerificationSuite().onData(data) + .addCheck(completeness) + .addCheck(uniqueness) + .addCheck(uniquenessWhere) + + val result: VerificationResult = suite.run() + + assert(result.status == CheckStatus.Error) + + val resultData = VerificationResult.rowLevelResultsAsDataFrame(session, result, data).orderBy("item") + resultData.show(false) + val expectedColumns: Set[String] = + data.columns.toSet + expectedColumn1 + expectedColumn2 + expectedColumn3 + assert(resultData.columns.toSet == expectedColumns) + + val rowLevel1 = resultData.select(expectedColumn1).collect().map(r => r.getAs[Any](0)) + assert(Seq(true, true, false, true, true, true).sameElements(rowLevel1)) + + val rowLevel2 = resultData.select(expectedColumn2).collect().map(r => r.getAs[Any](0)) + 
assert(Seq(false, false, false, false, false, false).sameElements(rowLevel2)) + + val rowLevel3 = resultData.select(expectedColumn3).collect().map(r => r.getAs[Any](0)) + assert(Seq(true, true, true, true, true, true).sameElements(rowLevel3)) + + } + + "generate a result that contains row-level results with null for filtered rows" in withSparkSession { session => + val data = getDfCompleteAndInCompleteColumns(session) + + val analyzerOptions = Option(AnalyzerOptions(filteredRow = FilteredRow.NULL)) + + val completeness = new Check(CheckLevel.Error, "rule1") + .hasCompleteness("att2", _ > 0.7, None, analyzerOptions) + .where("att1 = \"a\"") + val uniqueness = new Check(CheckLevel.Error, "rule2") + .hasUniqueness("att1", _ > 0.5, None, analyzerOptions) + val uniquenessWhere = new Check(CheckLevel.Error, "rule3") + .isUnique("att1", None, analyzerOptions) + .where("item < 3") + val expectedColumn1 = completeness.description + val expectedColumn2 = uniqueness.description + val expectedColumn3 = uniquenessWhere.description + + val suite = new VerificationSuite().onData(data) + .addCheck(completeness) + .addCheck(uniqueness) + .addCheck(uniquenessWhere) + + val result: VerificationResult = suite.run() + + assert(result.status == CheckStatus.Error) + + val resultData = VerificationResult.rowLevelResultsAsDataFrame(session, result, data).orderBy("item") + resultData.show(false) + val expectedColumns: Set[String] = + data.columns.toSet + expectedColumn1 + expectedColumn2 + expectedColumn3 + assert(resultData.columns.toSet == expectedColumns) + + val rowLevel1 = resultData.select(expectedColumn1).collect().map(r => r.getAs[Any](0)) + assert(Seq(true, null, false, true, null, true).sameElements(rowLevel1)) + + val rowLevel2 = resultData.select(expectedColumn2).collect().map(r => r.getAs[Any](0)) + assert(Seq(false, false, false, false, false, false).sameElements(rowLevel2)) + + val rowLevel3 = resultData.select(expectedColumn3).collect().map(r => r.getAs[Any](0)) + assert(Seq(true, true, null, null, null, null).sameElements(rowLevel3)) + + } + "generate a result that contains row-level results for null column values" in withSparkSession { session => val data = getDfCompleteAndInCompleteColumnsAndVarLengthStrings(session) @@ -459,6 +544,38 @@ class VerificationSuiteTest extends WordSpec with Matchers with SparkContextSpec } + "accept analysis config for mandatory analysis for checks with filters" in withSparkSession { sparkSession => + + import sparkSession.implicits._ + val df = getDfCompleteAndInCompleteColumns(sparkSession) + + val result = { + val checkToSucceed = Check(CheckLevel.Error, "group-1") + .hasCompleteness("att2", _ > 0.7, null) // 0.75 + .where("att1 = \"a\"") + val uniquenessCheck = Check(CheckLevel.Error, "group-2") + .isUnique("att1") + .where("item < 3") + + + VerificationSuite().onData(df).addCheck(checkToSucceed).addCheck(uniquenessCheck).run() + } + + assert(result.status == CheckStatus.Success) + + val analysisDf = AnalyzerContext.successMetricsAsDataFrame(sparkSession, + AnalyzerContext(result.metrics)) + + val expected = Seq( + ("Column", "att2", "Completeness (where: att1 = \"a\")", 0.75), + ("Column", "att1", "Uniqueness (where: item < 3)", 1.0)) + .toDF("entity", "instance", "name", "value") + + + assertSameRows(analysisDf, expected) + + } + "run the analysis even there are no constraints" in withSparkSession { sparkSession => import sparkSession.implicits._ @@ -918,6 +1035,33 @@ class VerificationSuiteTest extends WordSpec with Matchers with SparkContextSpec } } + "A 
well-defined check should pass even if an ill-defined check is also configured quotes" in withSparkSession { + sparkSession => + val df = getDfWithDistinctValuesQuotes(sparkSession) + + val rangeCheck = Check(CheckLevel.Error, "a") + .isContainedIn("att2", Array("can't", "help", "but", "wouldn't")) + + val reasonCheck = Check(CheckLevel.Error, "a") + .isContainedIn("reason", Array("Already Has ", " Can't Proceed")) + + val verificationResult = VerificationSuite() + .onData(df) + .addCheck(rangeCheck) + .addCheck(reasonCheck) + .run() + + val checkSuccessResult = verificationResult.checkResults(rangeCheck) + checkSuccessResult.constraintResults.map(_.message) shouldBe List(None) + println(checkSuccessResult.constraintResults.map(_.message)) + assert(checkSuccessResult.status == CheckStatus.Success) + + val reasonResult = verificationResult.checkResults(reasonCheck) + checkSuccessResult.constraintResults.map(_.message) shouldBe List(None) + println(checkSuccessResult.constraintResults.map(_.message)) + assert(checkSuccessResult.status == CheckStatus.Success) + } + "A well-defined check should pass even if an ill-defined check is also configured" in withSparkSession { sparkSession => val df = getDfWithNameAndAge(sparkSession) diff --git a/src/test/scala/com/amazon/deequ/analyzers/AnalyzerTests.scala b/src/test/scala/com/amazon/deequ/analyzers/AnalyzerTests.scala index 03787b886..1c0b28d1a 100644 --- a/src/test/scala/com/amazon/deequ/analyzers/AnalyzerTests.scala +++ b/src/test/scala/com/amazon/deequ/analyzers/AnalyzerTests.scala @@ -63,7 +63,9 @@ class AnalyzerTests extends AnyWordSpec with Matchers with SparkContextSpec with val result2 = Completeness("att2").calculate(dfMissing) assert(result2 == DoubleMetric(Entity.Column, "Completeness", "att2", Success(0.75), result2.fullColumn)) - + val result3 = Completeness("att2", Option("att1 is NOT NULL")).calculate(dfMissing) + assert(result3 == DoubleMetric(Entity.Column, + "Completeness", "att2", Success(4.0/6.0), result3.fullColumn)) } "fail on wrong column input" in withSparkSession { sparkSession => diff --git a/src/test/scala/com/amazon/deequ/analyzers/CompletenessTest.scala b/src/test/scala/com/amazon/deequ/analyzers/CompletenessTest.scala index b1cdf3014..54e26f867 100644 --- a/src/test/scala/com/amazon/deequ/analyzers/CompletenessTest.scala +++ b/src/test/scala/com/amazon/deequ/analyzers/CompletenessTest.scala @@ -23,6 +23,8 @@ import com.amazon.deequ.utils.FixtureSupport import org.scalatest.matchers.should.Matchers import org.scalatest.wordspec.AnyWordSpec +import scala.util.Success + class CompletenessTest extends AnyWordSpec with Matchers with SparkContextSpec with FixtureSupport { "Completeness" should { @@ -37,5 +39,36 @@ class CompletenessTest extends AnyWordSpec with Matchers with SparkContextSpec w data.withColumn("new", metric.fullColumn.get).collect().map(_.getAs[Boolean]("new")) shouldBe Seq(true, true, true, true, false, true, true, false) } + + "return row-level results for columns filtered as null" in withSparkSession { session => + + val data = getDfCompleteAndInCompleteColumns(session) + + // Explicitly setting RowLevelFilterTreatment for test purposes, this should be set at the VerificationRunBuilder + val completenessAtt2 = Completeness("att2", Option("att1 = \"a\""), + Option(AnalyzerOptions(filteredRow = FilteredRow.NULL))) + val state = completenessAtt2.computeStateFrom(data) + val metric: DoubleMetric with FullColumn = completenessAtt2.computeMetricFrom(state) + + val df = data.withColumn("new", 
metric.fullColumn.get) + df.show(false) + df.collect().map(_.getAs[Any]("new")).toSeq shouldBe + Seq(true, null, false, true, null, true) + } + + "return row-level results for columns filtered as true" in withSparkSession { session => + + val data = getDfCompleteAndInCompleteColumns(session) + + // Explicitly setting RowLevelFilterTreatment for test purposes, this should be set at the VerificationRunBuilder + val completenessAtt2 = Completeness("att2", Option("att1 = \"a\"")) + val state = completenessAtt2.computeStateFrom(data) + val metric: DoubleMetric with FullColumn = completenessAtt2.computeMetricFrom(state) + + val df = data.withColumn("new", metric.fullColumn.get) + df.show(false) + df.collect().map(_.getAs[Any]("new")).toSeq shouldBe + Seq(true, true, false, true, true, true) + } } } diff --git a/src/test/scala/com/amazon/deequ/analyzers/UniquenessTest.scala b/src/test/scala/com/amazon/deequ/analyzers/UniquenessTest.scala index 5d6d6808f..d50995b55 100644 --- a/src/test/scala/com/amazon/deequ/analyzers/UniquenessTest.scala +++ b/src/test/scala/com/amazon/deequ/analyzers/UniquenessTest.scala @@ -117,4 +117,68 @@ class UniquenessTest extends AnyWordSpec with Matchers with SparkContextSpec wit .withColumn("new", metric.fullColumn.get).orderBy("unique") .collect().map(_.getAs[Boolean]("new")) shouldBe Seq(true, true, true, true, true, true) } + + "return filtered row-level results for uniqueness with null" in withSparkSession { session => + + val data = getDfWithUniqueColumns(session) + + val addressLength = Uniqueness(Seq("onlyUniqueWithOtherNonUnique"), Option("unique < 4"), + Option(AnalyzerOptions(filteredRow = FilteredRow.NULL))) + val state: Option[FrequenciesAndNumRows] = addressLength.computeStateFrom(data, Option("unique < 4")) + val metric: DoubleMetric with FullColumn = addressLength.computeMetricFrom(state) + + // Adding column with UNIQUENESS_ID, since it's only added in VerificationResult.getRowLevelResults + val resultDf = data.withColumn(UNIQUENESS_ID, monotonically_increasing_id()) + .withColumn("new", metric.fullColumn.get).orderBy("unique") + resultDf + .collect().map(_.getAs[Any]("new")) shouldBe Seq(true, true, true, null, null, null) + } + + "return filtered row-level results for uniqueness with null on multiple columns" in withSparkSession { session => + + val data = getDfWithUniqueColumns(session) + + val addressLength = Uniqueness(Seq("halfUniqueCombinedWithNonUnique", "nonUnique"), Option("unique > 2"), + Option(AnalyzerOptions(filteredRow = FilteredRow.NULL))) + val state: Option[FrequenciesAndNumRows] = addressLength.computeStateFrom(data, Option("unique > 2")) + val metric: DoubleMetric with FullColumn = addressLength.computeMetricFrom(state) + + // Adding column with UNIQUENESS_ID, since it's only added in VerificationResult.getRowLevelResults + val resultDf = data.withColumn(UNIQUENESS_ID, monotonically_increasing_id()) + .withColumn("new", metric.fullColumn.get).orderBy("unique") + resultDf + .collect().map(_.getAs[Any]("new")) shouldBe Seq(null, null, true, true, true, true) + } + + "return filtered row-level results for uniqueness true null" in withSparkSession { session => + + val data = getDfWithUniqueColumns(session) + + // Explicitly setting RowLevelFilterTreatment for test purposes, this should be set at the VerificationRunBuilder + val addressLength = Uniqueness(Seq("onlyUniqueWithOtherNonUnique"), Option("unique < 4")) + val state: Option[FrequenciesAndNumRows] = addressLength.computeStateFrom(data, Option("unique < 4")) + val metric: 
DoubleMetric with FullColumn = addressLength.computeMetricFrom(state) + + // Adding column with UNIQUENESS_ID, since it's only added in VerificationResult.getRowLevelResults + val resultDf = data.withColumn(UNIQUENESS_ID, monotonically_increasing_id()) + .withColumn("new", metric.fullColumn.get).orderBy("unique") + resultDf + .collect().map(_.getAs[Any]("new")) shouldBe Seq(true, true, true, true, true, true) + } + + "return filtered row-level results for uniqueness with true on multiple columns" in withSparkSession { session => + + val data = getDfWithUniqueColumns(session) + + // Explicitly setting RowLevelFilterTreatment for test purposes, this should be set at the VerificationRunBuilder + val addressLength = Uniqueness(Seq("halfUniqueCombinedWithNonUnique", "nonUnique"), Option("unique > 2")) + val state: Option[FrequenciesAndNumRows] = addressLength.computeStateFrom(data, Option("unique > 2")) + val metric: DoubleMetric with FullColumn = addressLength.computeMetricFrom(state) + + // Adding column with UNIQUENESS_ID, since it's only added in VerificationResult.getRowLevelResults + val resultDf = data.withColumn(UNIQUENESS_ID, monotonically_increasing_id()) + .withColumn("new", metric.fullColumn.get).orderBy("unique") + resultDf + .collect().map(_.getAs[Any]("new")) shouldBe Seq(true, true, true, true, true, true) + } } diff --git a/src/test/scala/com/amazon/deequ/analyzers/runners/AnalysisRunnerTests.scala b/src/test/scala/com/amazon/deequ/analyzers/runners/AnalysisRunnerTests.scala index 4ffc9eeb9..ce9bda69b 100644 --- a/src/test/scala/com/amazon/deequ/analyzers/runners/AnalysisRunnerTests.scala +++ b/src/test/scala/com/amazon/deequ/analyzers/runners/AnalysisRunnerTests.scala @@ -137,7 +137,8 @@ class AnalysisRunnerTests extends AnyWordSpec UniqueValueRatio(Seq("att1"), Some("att3 > 0")) :: Nil val (separateResults, numSeparateJobs) = sparkMonitor.withMonitoringSession { stat => - val results = analyzers.map { _.calculate(df) }.toSet + val results = analyzers.map { analyzer => + analyzer.calculate(df, filterCondition = analyzer.filterCondition) }.toSet (results, stat.jobCount) } @@ -160,7 +161,9 @@ class AnalysisRunnerTests extends AnyWordSpec UniqueValueRatio(Seq("att1", "att2"), Some("att3 > 0")) :: Nil val (separateResults, numSeparateJobs) = sparkMonitor.withMonitoringSession { stat => - val results = analyzers.map { _.calculate(df) }.toSet + val results = analyzers.map { analyzer => + analyzer.calculate(df, filterCondition = analyzer.filterCondition) + }.toSet (results, stat.jobCount) } @@ -184,7 +187,9 @@ class AnalysisRunnerTests extends AnyWordSpec Uniqueness("att1", Some("att3 = 0")) :: Nil val (separateResults, numSeparateJobs) = sparkMonitor.withMonitoringSession { stat => - val results = analyzers.map { _.calculate(df) }.toSet + val results = analyzers.map { analyzer => + analyzer.calculate(df, filterCondition = analyzer.filterCondition) + }.toSet (results, stat.jobCount) } @@ -195,7 +200,14 @@ class AnalysisRunnerTests extends AnyWordSpec assert(numSeparateJobs == analyzers.length * 2) assert(numCombinedJobs == analyzers.length * 2) - assert(separateResults.toString == runnerResults.toString) + // assert(separateResults == runnerResults.toString) + // Used to be tested with the above line, but adding filters changed the order of the results. 
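      // Editor's note, not part of the patch: a minimal sketch of the order-insensitive
      // comparison described in the comment above, using hypothetical metric values.
      // Assuming two sets that hold the same DoubleMetrics in a different order:
      //   val separate = Set(metricAtt1, metricAtt2)   // names are illustrative only
      //   val combined = Set(metricAtt2, metricAtt1)
      //   assert(separate.size == combined.size)
      //   separate.foreach { m => assert(combined.toString.contains(m.toString)) }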
+ assert(separateResults.asInstanceOf[Set[DoubleMetric]].size == + runnerResults.asInstanceOf[Set[DoubleMetric]].size) + separateResults.asInstanceOf[Set[DoubleMetric]].foreach( result => { + assert(runnerResults.toString.contains(result.toString)) + } + ) } "reuse existing results" in @@ -272,7 +284,7 @@ class AnalysisRunnerTests extends AnyWordSpec assert(exception.getMessage == "Could not find all necessary results in the " + "MetricsRepository, the calculation of the metrics for these analyzers " + - "would be needed: Uniqueness(List(item, att2),None), Size(None)") + "would be needed: Uniqueness(List(item, att2),None,None), Size(None)") } "save results if specified" in diff --git a/src/test/scala/com/amazon/deequ/analyzers/runners/AnalyzerContextTest.scala b/src/test/scala/com/amazon/deequ/analyzers/runners/AnalyzerContextTest.scala index 254fac9b4..9133d5ae4 100644 --- a/src/test/scala/com/amazon/deequ/analyzers/runners/AnalyzerContextTest.scala +++ b/src/test/scala/com/amazon/deequ/analyzers/runners/AnalyzerContextTest.scala @@ -145,7 +145,8 @@ class AnalyzerContextTest extends AnyWordSpec } private[this] def assertSameJson(jsonA: String, jsonB: String): Unit = { - assert(SimpleResultSerde.deserialize(jsonA) == - SimpleResultSerde.deserialize(jsonB)) + assert(SimpleResultSerde.deserialize(jsonA).toSet.sameElements(SimpleResultSerde.deserialize(jsonB).toSet)) + // assert(SimpleResultSerde.deserialize(jsonA) == + // SimpleResultSerde.deserialize(jsonB)) } } diff --git a/src/test/scala/com/amazon/deequ/checks/CheckTest.scala b/src/test/scala/com/amazon/deequ/checks/CheckTest.scala index 43657d7ce..096e330b8 100644 --- a/src/test/scala/com/amazon/deequ/checks/CheckTest.scala +++ b/src/test/scala/com/amazon/deequ/checks/CheckTest.scala @@ -67,18 +67,39 @@ class CheckTest extends AnyWordSpec with Matchers with SparkContextSpec with Fix val check3 = Check(CheckLevel.Warning, "group-2-W") .hasCompleteness("att2", _ > 0.8) // 0.75 + val check4 = Check(CheckLevel.Error, "group-3") + .isComplete("att2", None) // 1.0 with filter + .where("att2 is NOT NULL") + .hasCompleteness("att2", _ == 1.0, None) // 1.0 with filter + .where("att2 is NOT NULL") + val context = runChecks(getDfCompleteAndInCompleteColumns(sparkSession), - check1, check2, check3) + check1, check2, check3, check4) context.metricMap.foreach { println } assertEvaluatesTo(check1, context, CheckStatus.Success) assertEvaluatesTo(check2, context, CheckStatus.Error) assertEvaluatesTo(check3, context, CheckStatus.Warning) + assertEvaluatesTo(check4, context, CheckStatus.Success) assert(check1.getRowLevelConstraintColumnNames() == Seq("Completeness-att1", "Completeness-att1")) assert(check2.getRowLevelConstraintColumnNames() == Seq("Completeness-att2")) assert(check3.getRowLevelConstraintColumnNames() == Seq("Completeness-att2")) + assert(check4.getRowLevelConstraintColumnNames() == Seq("Completeness-att2", "Completeness-att2")) + } + + "return the correct check status for completeness with where filter" in withSparkSession { sparkSession => + + val check = Check(CheckLevel.Error, "group-3") + .hasCompleteness("ZipCode", _ > 0.6, None) // 1.0 with filter + .where("City is NOT NULL") + + val context = runChecks(getDfForWhereClause(sparkSession), check) + + assertEvaluatesTo(check, context, CheckStatus.Success) + + assert(check.getRowLevelConstraintColumnNames() == Seq("Completeness-ZipCode")) } "return the correct check status for combined completeness" in @@ -169,7 +190,6 @@ class CheckTest extends AnyWordSpec with Matchers with 
SparkContextSpec with Fix assert(constraintStatuses.head == ConstraintStatus.Success) assert(constraintStatuses(1) == ConstraintStatus.Success) assert(constraintStatuses(2) == ConstraintStatus.Success) - assert(constraintStatuses(3) == ConstraintStatus.Failure) assert(constraintStatuses(4) == ConstraintStatus.Failure) } @@ -520,6 +540,14 @@ class CheckTest extends AnyWordSpec with Matchers with SparkContextSpec with Fix assertEvaluatesTo(numericRangeCheck9, numericRangeResults, CheckStatus.Success) } + "correctly evaluate range constraints when values have single quote in string" in withSparkSession { sparkSession => + val rangeCheck = Check(CheckLevel.Error, "a") + .isContainedIn("att2", Array("can't", "help", "but", "wouldn't")) + + val rangeResults = runChecks(getDfWithDistinctValuesQuotes(sparkSession), rangeCheck) + assertEvaluatesTo(rangeCheck, rangeResults, CheckStatus.Success) + } + "return the correct check status for histogram constraints" in withSparkSession { sparkSession => diff --git a/src/test/scala/com/amazon/deequ/constraints/AnalysisBasedConstraintTest.scala b/src/test/scala/com/amazon/deequ/constraints/AnalysisBasedConstraintTest.scala index f8188165c..a7efbe180 100644 --- a/src/test/scala/com/amazon/deequ/constraints/AnalysisBasedConstraintTest.scala +++ b/src/test/scala/com/amazon/deequ/constraints/AnalysisBasedConstraintTest.scala @@ -58,7 +58,8 @@ class AnalysisBasedConstraintTest extends WordSpec with Matchers with SparkConte override def calculate( data: DataFrame, stateLoader: Option[StateLoader], - statePersister: Option[StatePersister]) + statePersister: Option[StatePersister], + filterCondition: Option[String]) : DoubleMetric = { val value: Try[Double] = Try { require(data.columns.contains(column), s"Missing column $column") @@ -67,11 +68,10 @@ class AnalysisBasedConstraintTest extends WordSpec with Matchers with SparkConte DoubleMetric(Entity.Column, "sample", column, value) } - override def computeStateFrom(data: DataFrame): Option[NumMatches] = { + override def computeStateFrom(data: DataFrame, filterCondition: Option[String] = None): Option[NumMatches] = { throw new NotImplementedError() } - override def computeMetricFrom(state: Option[NumMatches]): DoubleMetric = { throw new NotImplementedError() } diff --git a/src/test/scala/com/amazon/deequ/repository/AnalysisResultSerdeTest.scala b/src/test/scala/com/amazon/deequ/repository/AnalysisResultSerdeTest.scala index 6f1fa1874..05f4d47bd 100644 --- a/src/test/scala/com/amazon/deequ/repository/AnalysisResultSerdeTest.scala +++ b/src/test/scala/com/amazon/deequ/repository/AnalysisResultSerdeTest.scala @@ -363,7 +363,7 @@ class SimpleResultSerdeTest extends WordSpec with Matchers with SparkContextSpec .stripMargin.replaceAll("\n", "") // ordering of map entries is not guaranteed, so comparing strings is not an option - assert(SimpleResultSerde.deserialize(sucessMetricsResultJson) == - SimpleResultSerde.deserialize(expected)) + assert(SimpleResultSerde.deserialize(sucessMetricsResultJson).toSet.sameElements( + SimpleResultSerde.deserialize(expected).toSet)) } } diff --git a/src/test/scala/com/amazon/deequ/repository/AnalysisResultTest.scala b/src/test/scala/com/amazon/deequ/repository/AnalysisResultTest.scala index 97d7a3c49..d4ce97fcb 100644 --- a/src/test/scala/com/amazon/deequ/repository/AnalysisResultTest.scala +++ b/src/test/scala/com/amazon/deequ/repository/AnalysisResultTest.scala @@ -344,7 +344,8 @@ class AnalysisResultTest extends AnyWordSpec } private[this] def assertSameJson(jsonA: String, jsonB: 
String): Unit = { - assert(SimpleResultSerde.deserialize(jsonA) == - SimpleResultSerde.deserialize(jsonB)) + assert(SimpleResultSerde.deserialize(jsonA).toSet.sameElements(SimpleResultSerde.deserialize(jsonB).toSet)) +// assert(SimpleResultSerde.deserialize(jsonA) == +// SimpleResultSerde.deserialize(jsonB)) } } diff --git a/src/test/scala/com/amazon/deequ/repository/MetricsRepositoryMultipleResultsLoaderTest.scala b/src/test/scala/com/amazon/deequ/repository/MetricsRepositoryMultipleResultsLoaderTest.scala index 6e61b9385..592f27b0e 100644 --- a/src/test/scala/com/amazon/deequ/repository/MetricsRepositoryMultipleResultsLoaderTest.scala +++ b/src/test/scala/com/amazon/deequ/repository/MetricsRepositoryMultipleResultsLoaderTest.scala @@ -264,7 +264,8 @@ class MetricsRepositoryMultipleResultsLoaderTest extends AnyWordSpec with Matche } private[this] def assertSameJson(jsonA: String, jsonB: String): Unit = { - assert(SimpleResultSerde.deserialize(jsonA) == - SimpleResultSerde.deserialize(jsonB)) + assert(SimpleResultSerde.deserialize(jsonA).toSet.sameElements(SimpleResultSerde.deserialize(jsonB).toSet)) + // assert(SimpleResultSerde.deserialize(jsonA) == + // SimpleResultSerde.deserialize(jsonB)) } } diff --git a/src/test/scala/com/amazon/deequ/suggestions/ConstraintSuggestionResultTest.scala b/src/test/scala/com/amazon/deequ/suggestions/ConstraintSuggestionResultTest.scala index 6a98bf3c6..9a82903e8 100644 --- a/src/test/scala/com/amazon/deequ/suggestions/ConstraintSuggestionResultTest.scala +++ b/src/test/scala/com/amazon/deequ/suggestions/ConstraintSuggestionResultTest.scala @@ -212,7 +212,7 @@ class ConstraintSuggestionResultTest extends WordSpec with Matchers with SparkCo """{ | "constraint_suggestions": [ | { - | "constraint_name": "CompletenessConstraint(Completeness(att2,None))", + | "constraint_name": "CompletenessConstraint(Completeness(att2,None,None))", | "column_name": "att2", | "current_value": "Completeness: 1.0", | "description": "'att2' is not null", @@ -222,7 +222,7 @@ class ConstraintSuggestionResultTest extends WordSpec with Matchers with SparkCo | "code_for_constraint": ".isComplete(\"att2\")" | }, | { - | "constraint_name": "CompletenessConstraint(Completeness(att1,None))", + | "constraint_name": "CompletenessConstraint(Completeness(att1,None,None))", | "column_name": "att1", | "current_value": "Completeness: 1.0", | "description": "'att1' is not null", @@ -232,7 +232,7 @@ class ConstraintSuggestionResultTest extends WordSpec with Matchers with SparkCo | "code_for_constraint": ".isComplete(\"att1\")" | }, | { - | "constraint_name": "CompletenessConstraint(Completeness(item,None))", + | "constraint_name": "CompletenessConstraint(Completeness(item,None,None))", | "column_name": "item", | "current_value": "Completeness: 1.0", | "description": "'item' is not null", @@ -265,7 +265,7 @@ class ConstraintSuggestionResultTest extends WordSpec with Matchers with SparkCo | "code_for_constraint": ".isNonNegative(\"item\")" | }, | { - | "constraint_name": "UniquenessConstraint(Uniqueness(List(item),None))", + | "constraint_name": "UniquenessConstraint(Uniqueness(List(item),None,None))", | "column_name": "item", | "current_value": "ApproxDistinctness: 1.0", | "description": "'item' is unique", @@ -294,7 +294,7 @@ class ConstraintSuggestionResultTest extends WordSpec with Matchers with SparkCo """{ | "constraint_suggestions": [ | { - | "constraint_name": "CompletenessConstraint(Completeness(att2,None))", + | "constraint_name": "CompletenessConstraint(Completeness(att2,None,None))", | 
"column_name": "att2", | "current_value": "Completeness: 1.0", | "description": "\u0027att2\u0027 is not null", @@ -305,7 +305,7 @@ class ConstraintSuggestionResultTest extends WordSpec with Matchers with SparkCo | "constraint_result_on_test_set": "Failure" | }, | { - | "constraint_name": "CompletenessConstraint(Completeness(att1,None))", + | "constraint_name": "CompletenessConstraint(Completeness(att1,None,None))", | "column_name": "att1", | "current_value": "Completeness: 1.0", | "description": "\u0027att1\u0027 is not null", @@ -316,7 +316,7 @@ class ConstraintSuggestionResultTest extends WordSpec with Matchers with SparkCo | "constraint_result_on_test_set": "Failure" | }, | { - | "constraint_name": "CompletenessConstraint(Completeness(item,None))", + | "constraint_name": "CompletenessConstraint(Completeness(item,None,None))", | "column_name": "item", | "current_value": "Completeness: 1.0", | "description": "\u0027item\u0027 is not null", @@ -352,7 +352,7 @@ class ConstraintSuggestionResultTest extends WordSpec with Matchers with SparkCo | "constraint_result_on_test_set": "Failure" | }, | { - | "constraint_name": "UniquenessConstraint(Uniqueness(List(item),None))", + | "constraint_name": "UniquenessConstraint(Uniqueness(List(item),None,None))", | "column_name": "item", | "current_value": "ApproxDistinctness: 1.0", | "description": "\u0027item\u0027 is unique", @@ -381,7 +381,7 @@ class ConstraintSuggestionResultTest extends WordSpec with Matchers with SparkCo """{ | "constraint_suggestions": [ | { - | "constraint_name": "CompletenessConstraint(Completeness(att2,None))", + | "constraint_name": "CompletenessConstraint(Completeness(att2,None,None))", | "column_name": "att2", | "current_value": "Completeness: 1.0", | "description": "\u0027att2\u0027 is not null", @@ -392,7 +392,7 @@ class ConstraintSuggestionResultTest extends WordSpec with Matchers with SparkCo | "constraint_result_on_test_set": "Unknown" | }, | { - | "constraint_name": "CompletenessConstraint(Completeness(att1,None))", + | "constraint_name": "CompletenessConstraint(Completeness(att1,None,None))", | "column_name": "att1", | "current_value": "Completeness: 1.0", | "description": "\u0027att1\u0027 is not null", @@ -403,7 +403,7 @@ class ConstraintSuggestionResultTest extends WordSpec with Matchers with SparkCo | "constraint_result_on_test_set": "Unknown" | }, | { - | "constraint_name": "CompletenessConstraint(Completeness(item,None))", + | "constraint_name": "CompletenessConstraint(Completeness(item,None,None))", | "column_name": "item", | "current_value": "Completeness: 1.0", | "description": "\u0027item\u0027 is not null", @@ -439,7 +439,7 @@ class ConstraintSuggestionResultTest extends WordSpec with Matchers with SparkCo | "constraint_result_on_test_set": "Unknown" | }, | { - | "constraint_name": "UniquenessConstraint(Uniqueness(List(item),None))", + | "constraint_name": "UniquenessConstraint(Uniqueness(List(item),None,None))", | "column_name": "item", | "current_value": "ApproxDistinctness: 1.0", | "description": "\u0027item\u0027 is unique", @@ -471,7 +471,7 @@ class ConstraintSuggestionResultTest extends WordSpec with Matchers with SparkCo """{ | "constraint_suggestions": [ | { - | "constraint_name": "CompletenessConstraint(Completeness(`item.one`,None))", + | "constraint_name": "CompletenessConstraint(Completeness(`item.one`,None,None))", | "column_name": "`item.one`", | "current_value": "Completeness: 1.0", | "description": "'`item.one`' is not null", @@ -504,7 +504,7 @@ class ConstraintSuggestionResultTest extends 
WordSpec with Matchers with SparkCo | "code_for_constraint": ".isNonNegative(\"`item.one`\")" | }, | { - | "constraint_name": "UniquenessConstraint(Uniqueness(List(`item.one`),None))", + | "constraint_name": "UniquenessConstraint(Uniqueness(List(`item.one`),None,None))", | "column_name": "`item.one`", | "current_value": "ApproxDistinctness: 1.0", | "description": "'`item.one`' is unique", @@ -515,7 +515,7 @@ class ConstraintSuggestionResultTest extends WordSpec with Matchers with SparkCo | "code_for_constraint": ".isUnique(\"`item.one`\")" | }, | { - | "constraint_name": "CompletenessConstraint(Completeness(att2,None))", + | "constraint_name": "CompletenessConstraint(Completeness(att2,None,None))", | "column_name": "att2", | "current_value": "Completeness: 1.0", | "description": "'att2' is not null", @@ -525,7 +525,7 @@ class ConstraintSuggestionResultTest extends WordSpec with Matchers with SparkCo | "code_for_constraint": ".isComplete(\"att2\")" | }, | { - | "constraint_name": "CompletenessConstraint(Completeness(att1,None))", + | "constraint_name": "CompletenessConstraint(Completeness(att1,None,None))", | "column_name": "att1", | "current_value": "Completeness: 1.0", | "description": "'att1' is not null", diff --git a/src/test/scala/com/amazon/deequ/utils/FixtureSupport.scala b/src/test/scala/com/amazon/deequ/utils/FixtureSupport.scala index 9b6ad9d4e..601134a53 100644 --- a/src/test/scala/com/amazon/deequ/utils/FixtureSupport.scala +++ b/src/test/scala/com/amazon/deequ/utils/FixtureSupport.scala @@ -338,6 +338,19 @@ trait FixtureSupport { .toDF("att1", "att2") } + def getDfWithDistinctValuesQuotes(sparkSession: SparkSession): DataFrame = { + import sparkSession.implicits._ + + Seq( + ("a", null, "Already Has "), + ("a", null, " Can't Proceed"), + (null, "can't", "Already Has "), + ("b", "help", " Can't Proceed"), + ("b", "but", "Already Has "), + ("c", "wouldn't", " Can't Proceed")) + .toDF("att1", "att2", "reason") + } + def getDfWithConditionallyUninformativeColumns(sparkSession: SparkSession): DataFrame = { import sparkSession.implicits._ Seq( @@ -409,6 +422,17 @@ trait FixtureSupport { ).toDF("item.one", "att1", "att2") } + def getDfForWhereClause(sparkSession: SparkSession): DataFrame = { + import sparkSession.implicits._ + + Seq( + ("Acme", "90210", "CA", "Los Angeles"), + ("Acme", "90211", "CA", "Los Angeles"), + ("Robocorp", null, "NJ", null), + ("Robocorp", null, "NY", "New York") + ).toDF("Company", "ZipCode", "State", "City") + } + def getDfCompleteAndInCompleteColumnsWithPeriod(sparkSession: SparkSession): DataFrame = { import sparkSession.implicits._ From 185ce014ac0ae11a59bb90cfe1e59a48039f4f67 Mon Sep 17 00:00:00 2001 From: Edward Cho <114528615+eycho-am@users.noreply.github.com> Date: Wed, 21 Feb 2024 10:49:02 -0500 Subject: [PATCH 05/24] Skip SparkTableMetricsRepositoryTest iceberg test when SupportsRowLevelOperations is not available (#536) --- .../SparkTableMetricsRepositoryTest.scala | 35 ++++++++++--------- 1 file changed, 19 insertions(+), 16 deletions(-) diff --git a/src/test/scala/com/amazon/deequ/repository/sparktable/SparkTableMetricsRepositoryTest.scala b/src/test/scala/com/amazon/deequ/repository/sparktable/SparkTableMetricsRepositoryTest.scala index 667b5b502..8e0d0aac9 100644 --- a/src/test/scala/com/amazon/deequ/repository/sparktable/SparkTableMetricsRepositoryTest.scala +++ b/src/test/scala/com/amazon/deequ/repository/sparktable/SparkTableMetricsRepositoryTest.scala @@ -101,21 +101,24 @@ class SparkTableMetricsRepositoryTest extends AnyWordSpec } "save 
and load to iceberg a single metric" in withSparkSessionIcebergCatalog { spark => { - val resultKey = ResultKey(System.currentTimeMillis(), Map("tag" -> "value")) - val metric = DoubleMetric(Entity.Column, "m1", "", Try(100)) - val context = AnalyzerContext(Map(analyzer -> metric)) - - val repository = new SparkTableMetricsRepository(spark, "local.metrics_table") - // Save the metric - repository.save(resultKey, context) - - // Load the metric - val loadedContext = repository.loadByKey(resultKey) - - assert(loadedContext.isDefined) - assert(loadedContext.get.metric(analyzer).contains(metric)) - } - - } + // The SupportsRowLevelOperations class is available from spark 3.3 + // We should skip this test for lower spark versions + val className = "org.apache.spark.sql.connector.catalog.SupportsRowLevelOperations" + if (Try(Class.forName(className)).isSuccess) { + val resultKey = ResultKey(System.currentTimeMillis(), Map("tag" -> "value")) + val metric = DoubleMetric(Entity.Column, "m1", "", Try(100)) + val context = AnalyzerContext(Map(analyzer -> metric)) + + val repository = new SparkTableMetricsRepository(spark, "local.metrics_table") + // Save the metric + repository.save(resultKey, context) + + // Load the metric + val loadedContext = repository.loadByKey(resultKey) + + assert(loadedContext.isDefined) + assert(loadedContext.get.metric(analyzer).contains(metric)) + } + } } } } From e48f97aab7cc0db957eb7aeabcafed405aa10624 Mon Sep 17 00:00:00 2001 From: Edward Cho <114528615+eycho-am@users.noreply.github.com> Date: Wed, 21 Feb 2024 13:29:16 -0500 Subject: [PATCH 06/24] Feature: Add Row Level Result Treatment Options for Miminum and Maximum (#535) * Address comments on PR #532 * Add filtered row-level result support for Minimum, Maximum, Compliance, PatternMatch, MinLength, MaxLength analyzers * Refactored criterion for MinLength and MaxLength analyzers to separate rowLevelResults logic --- .../amazon/deequ/VerificationRunBuilder.scala | 2 +- .../com/amazon/deequ/analyzers/Analyzer.scala | 32 ++- .../amazon/deequ/analyzers/Completeness.scala | 11 +- .../amazon/deequ/analyzers/Compliance.scala | 19 +- .../deequ/analyzers/GroupingAnalyzers.scala | 3 +- .../amazon/deequ/analyzers/MaxLength.scala | 33 ++- .../com/amazon/deequ/analyzers/Maximum.scala | 18 +- .../amazon/deequ/analyzers/MinLength.scala | 33 ++- .../com/amazon/deequ/analyzers/Minimum.scala | 23 +- .../amazon/deequ/analyzers/PatternMatch.scala | 26 +- .../deequ/analyzers/UniqueValueRatio.scala | 10 +- .../amazon/deequ/analyzers/Uniqueness.scala | 10 +- .../scala/com/amazon/deequ/checks/Check.scala | 64 +++-- .../amazon/deequ/constraints/Constraint.scala | 24 +- .../amazon/deequ/VerificationSuiteTest.scala | 243 ++++++++++++++++-- .../deequ/analyzers/CompletenessTest.scala | 2 +- .../deequ/analyzers/ComplianceTest.scala | 159 +++++++++++- .../deequ/analyzers/MaxLengthTest.scala | 78 ++++++ .../amazon/deequ/analyzers/MaximumTest.scala | 31 +++ .../deequ/analyzers/MinLengthTest.scala | 87 ++++++- .../amazon/deequ/analyzers/MinimumTest.scala | 33 +++ .../deequ/analyzers/PatternMatchTest.scala | 55 +++- .../deequ/analyzers/UniquenessTest.scala | 4 +- .../runners/AnalysisRunnerTests.scala | 7 +- .../runners/AnalyzerContextTest.scala | 2 - .../ConstraintSuggestionResultTest.scala | 8 +- .../amazon/deequ/utils/FixtureSupport.scala | 45 ++-- 27 files changed, 924 insertions(+), 138 deletions(-) diff --git a/src/main/scala/com/amazon/deequ/VerificationRunBuilder.scala b/src/main/scala/com/amazon/deequ/VerificationRunBuilder.scala index 
929b2319b..f34b7f6ee 100644 --- a/src/main/scala/com/amazon/deequ/VerificationRunBuilder.scala +++ b/src/main/scala/com/amazon/deequ/VerificationRunBuilder.scala @@ -25,7 +25,7 @@ import com.amazon.deequ.repository._ import org.apache.spark.sql.{DataFrame, SparkSession} /** A class to build a VerificationRun using a fluent API */ -class VerificationRunBuilder(val data: DataFrame) { +class VerificationRunBuilder(val data: DataFrame) { protected var requiredAnalyzers: Seq[Analyzer[_, Metric[_]]] = Seq.empty diff --git a/src/main/scala/com/amazon/deequ/analyzers/Analyzer.scala b/src/main/scala/com/amazon/deequ/analyzers/Analyzer.scala index bc241fe72..dd5fb07e9 100644 --- a/src/main/scala/com/amazon/deequ/analyzers/Analyzer.scala +++ b/src/main/scala/com/amazon/deequ/analyzers/Analyzer.scala @@ -17,7 +17,7 @@ package com.amazon.deequ.analyzers import com.amazon.deequ.analyzers.Analyzers._ -import com.amazon.deequ.analyzers.FilteredRow.FilteredRow +import com.amazon.deequ.analyzers.FilteredRowOutcome.FilteredRowOutcome import com.amazon.deequ.analyzers.NullBehavior.NullBehavior import com.amazon.deequ.analyzers.runners._ import com.amazon.deequ.metrics.DoubleMetric @@ -172,6 +172,12 @@ trait Analyzer[S <: State[_], +M <: Metric[_]] extends Serializable { private[deequ] def copyStateTo(source: StateLoader, target: StatePersister): Unit = { source.load[S](this).foreach { state => target.persist(this, state) } } + + private[deequ] def getRowLevelFilterTreatment(analyzerOptions: Option[AnalyzerOptions]): FilteredRowOutcome = { + analyzerOptions + .map { options => options.filteredRow } + .getOrElse(FilteredRowOutcome.TRUE) + } } /** An analyzer that runs a set of aggregation functions over the data, @@ -257,15 +263,19 @@ case class NumMatchesAndCount(numMatches: Long, count: Long, override val fullCo } case class AnalyzerOptions(nullBehavior: NullBehavior = NullBehavior.Ignore, - filteredRow: FilteredRow = FilteredRow.TRUE) + filteredRow: FilteredRowOutcome = FilteredRowOutcome.TRUE) object NullBehavior extends Enumeration { type NullBehavior = Value val Ignore, EmptyString, Fail = Value } -object FilteredRow extends Enumeration { - type FilteredRow = Value +object FilteredRowOutcome extends Enumeration { + type FilteredRowOutcome = Value val NULL, TRUE = Value + + implicit class FilteredRowOutcomeOps(value: FilteredRowOutcome) { + def getExpression: Column = expr(value.toString) + } } /** Base class for analyzers that compute ratios of matching predicates */ @@ -484,6 +494,12 @@ private[deequ] object Analyzers { .getOrElse(selection) } + def conditionSelectionGivenColumn(selection: Column, where: Option[Column], replaceWith: Boolean): Column = { + where + .map { condition => when(condition, replaceWith).otherwise(selection) } + .getOrElse(selection) + } + def conditionalSelection(selection: Column, where: Option[String], replaceWith: Double): Column = { conditionSelectionGivenColumn(selection, where.map(expr), replaceWith) } @@ -500,12 +516,12 @@ private[deequ] object Analyzers { def conditionalSelectionFilteredFromColumns( selection: Column, conditionColumn: Option[Column], - filterTreatment: String) + filterTreatment: FilteredRowOutcome) : Column = { conditionColumn - .map { condition => { - when(not(condition), expr(filterTreatment)).when(condition, selection) - } } + .map { condition => + when(not(condition), filterTreatment.getExpression).when(condition, selection) + } .getOrElse(selection) } diff --git a/src/main/scala/com/amazon/deequ/analyzers/Completeness.scala 
b/src/main/scala/com/amazon/deequ/analyzers/Completeness.scala index 399cbb06a..3a262d7cc 100644 --- a/src/main/scala/com/amazon/deequ/analyzers/Completeness.scala +++ b/src/main/scala/com/amazon/deequ/analyzers/Completeness.scala @@ -20,7 +20,6 @@ import com.amazon.deequ.analyzers.Preconditions.{hasColumn, isNotNested} import org.apache.spark.sql.functions.sum import org.apache.spark.sql.types.{IntegerType, StructType} import Analyzers._ -import com.amazon.deequ.analyzers.FilteredRow.FilteredRow import com.google.common.annotations.VisibleForTesting import org.apache.spark.sql.functions.col import org.apache.spark.sql.functions.expr @@ -54,15 +53,9 @@ case class Completeness(column: String, where: Option[String] = None, @VisibleForTesting // required by some tests that compare analyzer results to an expected state private[deequ] def criterion: Column = conditionalSelection(column, where).isNotNull - @VisibleForTesting private[deequ] def rowLevelResults: Column = { val whereCondition = where.map { expression => expr(expression)} - conditionalSelectionFilteredFromColumns(col(column).isNotNull, whereCondition, getRowLevelFilterTreatment.toString) - } - - private def getRowLevelFilterTreatment: FilteredRow = { - analyzerOptions - .map { options => options.filteredRow } - .getOrElse(FilteredRow.TRUE) + conditionalSelectionFilteredFromColumns( + col(column).isNotNull, whereCondition, getRowLevelFilterTreatment(analyzerOptions)) } } diff --git a/src/main/scala/com/amazon/deequ/analyzers/Compliance.scala b/src/main/scala/com/amazon/deequ/analyzers/Compliance.scala index ec242fe6c..0edf01970 100644 --- a/src/main/scala/com/amazon/deequ/analyzers/Compliance.scala +++ b/src/main/scala/com/amazon/deequ/analyzers/Compliance.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.functions._ import Analyzers._ import com.amazon.deequ.analyzers.Preconditions.hasColumn import com.google.common.annotations.VisibleForTesting +import org.apache.spark.sql.types.DoubleType /** * Compliance is a measure of the fraction of rows that complies with the given column constraint. @@ -40,14 +41,15 @@ import com.google.common.annotations.VisibleForTesting case class Compliance(instance: String, predicate: String, where: Option[String] = None, - columns: List[String] = List.empty[String]) + columns: List[String] = List.empty[String], + analyzerOptions: Option[AnalyzerOptions] = None) extends StandardScanShareableAnalyzer[NumMatchesAndCount]("Compliance", instance) with FilterableAnalyzer { override def fromAggregationResult(result: Row, offset: Int): Option[NumMatchesAndCount] = { ifNoNullsIn(result, offset, howMany = 2) { _ => - NumMatchesAndCount(result.getLong(offset), result.getLong(offset + 1), Some(criterion)) + NumMatchesAndCount(result.getLong(offset), result.getLong(offset + 1), Some(rowLevelResults)) } } @@ -65,6 +67,19 @@ case class Compliance(instance: String, conditionalSelection(expr(predicate), where).cast(IntegerType) } + private def rowLevelResults: Column = { + val filteredRowOutcome = getRowLevelFilterTreatment(analyzerOptions) + val whereNotCondition = where.map { expression => not(expr(expression)) } + + filteredRowOutcome match { + case FilteredRowOutcome.TRUE => + conditionSelectionGivenColumn(expr(predicate), whereNotCondition, replaceWith = true).cast(IntegerType) + case _ => + // The default behavior when using filtering for rows is to treat them as nulls. No special treatment needed. 
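        // Editor's note, not part of the patch: a minimal sketch of how the two outcomes
        // above differ for a filtered Compliance analyzer. Column and filter names are
        // assumptions for illustration only:
        //   val markedTrue = Compliance("att1 > 2", "att1 > 2", where = Some("att1 < 4"))
        //   // rows failing "att1 < 4" surface as true in the row-level column (default)
        //   val markedNull = Compliance("att1 > 2", "att1 > 2", where = Some("att1 < 4"),
        //     analyzerOptions = Some(AnalyzerOptions(filteredRow = FilteredRowOutcome.NULL)))
        //   // the same filtered rows are left as null instead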
+ criterion + } + } + override protected def additionalPreconditions(): Seq[StructType => Unit] = columns.map(hasColumn) } diff --git a/src/main/scala/com/amazon/deequ/analyzers/GroupingAnalyzers.scala b/src/main/scala/com/amazon/deequ/analyzers/GroupingAnalyzers.scala index 30bd89621..c830d0189 100644 --- a/src/main/scala/com/amazon/deequ/analyzers/GroupingAnalyzers.scala +++ b/src/main/scala/com/amazon/deequ/analyzers/GroupingAnalyzers.scala @@ -93,9 +93,8 @@ object FrequencyBasedAnalyzer { val fullColumn: Column = { val window = Window.partitionBy(columnsToGroupBy: _*) where.map { - condition => { + condition => count(when(expr(condition), UNIQUENESS_ID)).over(window) - } }.getOrElse(count(UNIQUENESS_ID).over(window)) } diff --git a/src/main/scala/com/amazon/deequ/analyzers/MaxLength.scala b/src/main/scala/com/amazon/deequ/analyzers/MaxLength.scala index 47ed71a69..19c9ca9b7 100644 --- a/src/main/scala/com/amazon/deequ/analyzers/MaxLength.scala +++ b/src/main/scala/com/amazon/deequ/analyzers/MaxLength.scala @@ -23,8 +23,10 @@ import com.amazon.deequ.analyzers.Preconditions.isString import org.apache.spark.sql.Column import org.apache.spark.sql.Row import org.apache.spark.sql.functions.col +import org.apache.spark.sql.functions.expr import org.apache.spark.sql.functions.length import org.apache.spark.sql.functions.max +import org.apache.spark.sql.functions.not import org.apache.spark.sql.types.DoubleType import org.apache.spark.sql.types.StructType @@ -33,12 +35,12 @@ case class MaxLength(column: String, where: Option[String] = None, analyzerOptio with FilterableAnalyzer { override def aggregationFunctions(): Seq[Column] = { - max(criterion(getNullBehavior)) :: Nil + max(criterion) :: Nil } override def fromAggregationResult(result: Row, offset: Int): Option[MaxState] = { ifNoNullsIn(result, offset) { _ => - MaxState(result.getDouble(offset), Some(criterion(getNullBehavior))) + MaxState(result.getDouble(offset), Some(rowLevelResults)) } } @@ -48,15 +50,34 @@ case class MaxLength(column: String, where: Option[String] = None, analyzerOptio override def filterCondition: Option[String] = where - private def criterion(nullBehavior: NullBehavior): Column = { + private[deequ] def criterion: Column = { + transformColForNullBehavior + } + + private[deequ] def rowLevelResults: Column = { + transformColForFilteredRow(criterion) + } + + private def transformColForFilteredRow(col: Column): Column = { + val whereNotCondition = where.map { expression => not(expr(expression)) } + getRowLevelFilterTreatment(analyzerOptions) match { + case FilteredRowOutcome.TRUE => + conditionSelectionGivenColumn(col, whereNotCondition, replaceWith = Double.MinValue) + case _ => + conditionSelectionGivenColumn(col, whereNotCondition, replaceWith = null) + } + } + + private def transformColForNullBehavior: Column = { val isNullCheck = col(column).isNull - nullBehavior match { + val colLengths: Column = length(conditionalSelection(column, where)).cast(DoubleType) + getNullBehavior match { case NullBehavior.Fail => - val colLengths: Column = length(conditionalSelection(column, where)).cast(DoubleType) conditionSelectionGivenColumn(colLengths, Option(isNullCheck), replaceWith = Double.MaxValue) case NullBehavior.EmptyString => length(conditionSelectionGivenColumn(col(column), Option(isNullCheck), replaceWith = "")).cast(DoubleType) - case _ => length(conditionalSelection(column, where)).cast(DoubleType) + case _ => + colLengths } } diff --git a/src/main/scala/com/amazon/deequ/analyzers/Maximum.scala 
b/src/main/scala/com/amazon/deequ/analyzers/Maximum.scala index 24a1ae965..c5cc33f94 100644 --- a/src/main/scala/com/amazon/deequ/analyzers/Maximum.scala +++ b/src/main/scala/com/amazon/deequ/analyzers/Maximum.scala @@ -23,6 +23,8 @@ import org.apache.spark.sql.types.{DoubleType, StructType} import Analyzers._ import com.amazon.deequ.metrics.FullColumn import com.google.common.annotations.VisibleForTesting +import org.apache.spark.sql.functions.expr +import org.apache.spark.sql.functions.not case class MaxState(maxValue: Double, override val fullColumn: Option[Column] = None) extends DoubleValuedState[MaxState] with FullColumn { @@ -36,7 +38,7 @@ case class MaxState(maxValue: Double, override val fullColumn: Option[Column] = } } -case class Maximum(column: String, where: Option[String] = None) +case class Maximum(column: String, where: Option[String] = None, analyzerOptions: Option[AnalyzerOptions] = None) extends StandardScanShareableAnalyzer[MaxState]("Maximum", column) with FilterableAnalyzer { @@ -47,7 +49,7 @@ case class Maximum(column: String, where: Option[String] = None) override def fromAggregationResult(result: Row, offset: Int): Option[MaxState] = { ifNoNullsIn(result, offset) { _ => - MaxState(result.getDouble(offset), Some(criterion)) + MaxState(result.getDouble(offset), Some(rowLevelResults)) } } @@ -60,5 +62,17 @@ case class Maximum(column: String, where: Option[String] = None) @VisibleForTesting private def criterion: Column = conditionalSelection(column, where).cast(DoubleType) + private[deequ] def rowLevelResults: Column = { + val filteredRowOutcome = getRowLevelFilterTreatment(analyzerOptions) + val whereNotCondition = where.map { expression => not(expr(expression)) } + + filteredRowOutcome match { + case FilteredRowOutcome.TRUE => + conditionSelectionGivenColumn(col(column), whereNotCondition, replaceWith = Double.MinValue).cast(DoubleType) + case _ => + criterion + } + } + } diff --git a/src/main/scala/com/amazon/deequ/analyzers/MinLength.scala b/src/main/scala/com/amazon/deequ/analyzers/MinLength.scala index b63c4b4be..c155cca94 100644 --- a/src/main/scala/com/amazon/deequ/analyzers/MinLength.scala +++ b/src/main/scala/com/amazon/deequ/analyzers/MinLength.scala @@ -23,8 +23,10 @@ import com.amazon.deequ.analyzers.Preconditions.isString import org.apache.spark.sql.Column import org.apache.spark.sql.Row import org.apache.spark.sql.functions.col +import org.apache.spark.sql.functions.expr import org.apache.spark.sql.functions.length import org.apache.spark.sql.functions.min +import org.apache.spark.sql.functions.not import org.apache.spark.sql.types.DoubleType import org.apache.spark.sql.types.StructType @@ -33,12 +35,12 @@ case class MinLength(column: String, where: Option[String] = None, analyzerOptio with FilterableAnalyzer { override def aggregationFunctions(): Seq[Column] = { - min(criterion(getNullBehavior)) :: Nil + min(criterion) :: Nil } override def fromAggregationResult(result: Row, offset: Int): Option[MinState] = { ifNoNullsIn(result, offset) { _ => - MinState(result.getDouble(offset), Some(criterion(getNullBehavior))) + MinState(result.getDouble(offset), Some(rowLevelResults)) } } @@ -48,15 +50,34 @@ case class MinLength(column: String, where: Option[String] = None, analyzerOptio override def filterCondition: Option[String] = where - private[deequ] def criterion(nullBehavior: NullBehavior): Column = { + private[deequ] def criterion: Column = { + transformColForNullBehavior + } + + private[deequ] def rowLevelResults: Column = { + 
transformColForFilteredRow(criterion) + } + + private def transformColForFilteredRow(col: Column): Column = { + val whereNotCondition = where.map { expression => not(expr(expression)) } + getRowLevelFilterTreatment(analyzerOptions) match { + case FilteredRowOutcome.TRUE => + conditionSelectionGivenColumn(col, whereNotCondition, replaceWith = Double.MaxValue) + case _ => + conditionSelectionGivenColumn(col, whereNotCondition, replaceWith = null) + } + } + + private def transformColForNullBehavior: Column = { val isNullCheck = col(column).isNull - nullBehavior match { + val colLengths: Column = length(conditionalSelection(column, where)).cast(DoubleType) + getNullBehavior match { case NullBehavior.Fail => - val colLengths: Column = length(conditionalSelection(column, where)).cast(DoubleType) conditionSelectionGivenColumn(colLengths, Option(isNullCheck), replaceWith = Double.MinValue) case NullBehavior.EmptyString => length(conditionSelectionGivenColumn(col(column), Option(isNullCheck), replaceWith = "")).cast(DoubleType) - case _ => length(conditionalSelection(column, where)).cast(DoubleType) + case _ => + colLengths } } diff --git a/src/main/scala/com/amazon/deequ/analyzers/Minimum.scala b/src/main/scala/com/amazon/deequ/analyzers/Minimum.scala index feac13f88..18640dc12 100644 --- a/src/main/scala/com/amazon/deequ/analyzers/Minimum.scala +++ b/src/main/scala/com/amazon/deequ/analyzers/Minimum.scala @@ -23,6 +23,8 @@ import org.apache.spark.sql.types.{DoubleType, StructType} import Analyzers._ import com.amazon.deequ.metrics.FullColumn import com.google.common.annotations.VisibleForTesting +import org.apache.spark.sql.functions.expr +import org.apache.spark.sql.functions.not case class MinState(minValue: Double, override val fullColumn: Option[Column] = None) extends DoubleValuedState[MinState] with FullColumn { @@ -36,7 +38,7 @@ case class MinState(minValue: Double, override val fullColumn: Option[Column] = } } -case class Minimum(column: String, where: Option[String] = None) +case class Minimum(column: String, where: Option[String] = None, analyzerOptions: Option[AnalyzerOptions] = None) extends StandardScanShareableAnalyzer[MinState]("Minimum", column) with FilterableAnalyzer { @@ -45,9 +47,8 @@ case class Minimum(column: String, where: Option[String] = None) } override def fromAggregationResult(result: Row, offset: Int): Option[MinState] = { - ifNoNullsIn(result, offset) { _ => - MinState(result.getDouble(offset), Some(criterion)) + MinState(result.getDouble(offset), Some(rowLevelResults)) } } @@ -58,5 +59,19 @@ case class Minimum(column: String, where: Option[String] = None) override def filterCondition: Option[String] = where @VisibleForTesting - private def criterion: Column = conditionalSelection(column, where).cast(DoubleType) + private def criterion: Column = { + conditionalSelection(column, where).cast(DoubleType) + } + + private[deequ] def rowLevelResults: Column = { + val filteredRowOutcome = getRowLevelFilterTreatment(analyzerOptions) + val whereNotCondition = where.map { expression => not(expr(expression)) } + + filteredRowOutcome match { + case FilteredRowOutcome.TRUE => + conditionSelectionGivenColumn(col(column), whereNotCondition, replaceWith = Double.MaxValue).cast(DoubleType) + case _ => + criterion + } + } } diff --git a/src/main/scala/com/amazon/deequ/analyzers/PatternMatch.scala b/src/main/scala/com/amazon/deequ/analyzers/PatternMatch.scala index 47fb08737..eb62f9675 100644 --- a/src/main/scala/com/amazon/deequ/analyzers/PatternMatch.scala +++ 
b/src/main/scala/com/amazon/deequ/analyzers/PatternMatch.scala @@ -19,6 +19,8 @@ package com.amazon.deequ.analyzers import com.amazon.deequ.analyzers.Analyzers._ import com.amazon.deequ.analyzers.Preconditions.{hasColumn, isString} import com.google.common.annotations.VisibleForTesting +import org.apache.spark.sql.functions.expr +import org.apache.spark.sql.functions.not import org.apache.spark.sql.{Column, Row} import org.apache.spark.sql.functions.{col, lit, regexp_extract, sum, when} import org.apache.spark.sql.types.{BooleanType, IntegerType, StructType} @@ -36,13 +38,14 @@ import scala.util.matching.Regex * @param pattern The regular expression to check for * @param where Additional filter to apply before the analyzer is run. */ -case class PatternMatch(column: String, pattern: Regex, where: Option[String] = None) +case class PatternMatch(column: String, pattern: Regex, where: Option[String] = None, + analyzerOptions: Option[AnalyzerOptions] = None) extends StandardScanShareableAnalyzer[NumMatchesAndCount]("PatternMatch", column) with FilterableAnalyzer { override def fromAggregationResult(result: Row, offset: Int): Option[NumMatchesAndCount] = { ifNoNullsIn(result, offset, howMany = 2) { _ => - NumMatchesAndCount(result.getLong(offset), result.getLong(offset + 1), Some(criterion.cast(BooleanType))) + NumMatchesAndCount(result.getLong(offset), result.getLong(offset + 1), Some(rowLevelResults.cast(BooleanType))) } } @@ -77,12 +80,25 @@ case class PatternMatch(column: String, pattern: Regex, where: Option[String] = @VisibleForTesting // required by some tests that compare analyzer results to an expected state private[deequ] def criterion: Column = { - val expression = when(regexp_extract(col(column), pattern.toString(), 0) =!= lit(""), 1) - .otherwise(0) - conditionalSelection(expression, where).cast(IntegerType) + conditionalSelection(getPatternMatchExpression, where).cast(IntegerType) } + private[deequ] def rowLevelResults: Column = { + val filteredRowOutcome = getRowLevelFilterTreatment(analyzerOptions) + val whereNotCondition = where.map { expression => not(expr(expression)) } + filteredRowOutcome match { + case FilteredRowOutcome.TRUE => + conditionSelectionGivenColumn(getPatternMatchExpression, whereNotCondition, replaceWith = 1).cast(IntegerType) + case _ => + // The default behavior when using filtering for rows is to treat them as nulls. No special treatment needed. 
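        // Editor's note, not part of the patch: a minimal sketch of the analogous behavior
        // for a filtered PatternMatch. Column name, filter and pattern are assumptions for
        // illustration only:
        //   val matcher = PatternMatch("name", "\\w+".r, where = Some("id < 10"),
        //     analyzerOptions = Some(AnalyzerOptions(filteredRow = FilteredRowOutcome.NULL)))
        //   // rows failing "id < 10" stay null in the row-level column; with the default
        //   // FilteredRowOutcome.TRUE they would be reported as matching (1) instead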
+ criterion + } + } + + private def getPatternMatchExpression: Column = { + when(regexp_extract(col(column), pattern.toString(), 0) =!= lit(""), 1).otherwise(0) + } } object Patterns { diff --git a/src/main/scala/com/amazon/deequ/analyzers/UniqueValueRatio.scala b/src/main/scala/com/amazon/deequ/analyzers/UniqueValueRatio.scala index c2fce1f14..02b682b9d 100644 --- a/src/main/scala/com/amazon/deequ/analyzers/UniqueValueRatio.scala +++ b/src/main/scala/com/amazon/deequ/analyzers/UniqueValueRatio.scala @@ -17,7 +17,7 @@ package com.amazon.deequ.analyzers import com.amazon.deequ.analyzers.Analyzers.COUNT_COL -import com.amazon.deequ.analyzers.FilteredRow.FilteredRow +import com.amazon.deequ.analyzers.FilteredRowOutcome.FilteredRowOutcome import com.amazon.deequ.metrics.DoubleMetric import org.apache.spark.sql.functions.expr import org.apache.spark.sql.functions.not @@ -43,7 +43,7 @@ case class UniqueValueRatio(columns: Seq[String], where: Option[String] = None, rowLevelColumn => { conditionColumn.map { condition => { - when(not(condition), expr(getRowLevelFilterTreatment.toString)) + when(not(condition), getRowLevelFilterTreatment(analyzerOptions).getExpression) .when(rowLevelColumn.equalTo(1), true).otherwise(false) } }.getOrElse(when(rowLevelColumn.equalTo(1), true).otherwise(false)) @@ -53,12 +53,6 @@ case class UniqueValueRatio(columns: Seq[String], where: Option[String] = None, } override def filterCondition: Option[String] = where - - private def getRowLevelFilterTreatment: FilteredRow = { - analyzerOptions - .map { options => options.filteredRow } - .getOrElse(FilteredRow.TRUE) - } } object UniqueValueRatio { diff --git a/src/main/scala/com/amazon/deequ/analyzers/Uniqueness.scala b/src/main/scala/com/amazon/deequ/analyzers/Uniqueness.scala index 78ba4c418..b46b6d324 100644 --- a/src/main/scala/com/amazon/deequ/analyzers/Uniqueness.scala +++ b/src/main/scala/com/amazon/deequ/analyzers/Uniqueness.scala @@ -17,7 +17,7 @@ package com.amazon.deequ.analyzers import com.amazon.deequ.analyzers.Analyzers.COUNT_COL -import com.amazon.deequ.analyzers.FilteredRow.FilteredRow +import com.amazon.deequ.analyzers.FilteredRowOutcome.FilteredRowOutcome import com.amazon.deequ.metrics.DoubleMetric import com.google.common.annotations.VisibleForTesting import org.apache.spark.sql.Column @@ -47,7 +47,7 @@ case class Uniqueness(columns: Seq[String], where: Option[String] = None, rowLevelColumn => { conditionColumn.map { condition => { - when(not(condition), expr(getRowLevelFilterTreatment.toString)) + when(not(condition), getRowLevelFilterTreatment(analyzerOptions).getExpression) .when(rowLevelColumn.equalTo(1), true).otherwise(false) } }.getOrElse(when(rowLevelColumn.equalTo(1), true).otherwise(false)) @@ -57,12 +57,6 @@ case class Uniqueness(columns: Seq[String], where: Option[String] = None, } override def filterCondition: Option[String] = where - - private def getRowLevelFilterTreatment: FilteredRow = { - analyzerOptions - .map { options => options.filteredRow } - .getOrElse(FilteredRow.TRUE) - } } object Uniqueness { diff --git a/src/main/scala/com/amazon/deequ/checks/Check.scala b/src/main/scala/com/amazon/deequ/checks/Check.scala index bdae62ab7..ccfd9badc 100644 --- a/src/main/scala/com/amazon/deequ/checks/Check.scala +++ b/src/main/scala/com/amazon/deequ/checks/Check.scala @@ -716,15 +716,17 @@ case class Check( * @param column Column to run the assertion on * @param assertion Function that receives a double input parameter and returns a boolean * @param hint A hint to provide additional context 
why a constraint could have failed + * @param analyzerOptions Options to configure analyzer behavior (NullTreatment, FilteredRow) * @return */ def hasMin( column: String, assertion: Double => Boolean, - hint: Option[String] = None) + hint: Option[String] = None, + analyzerOptions: Option[AnalyzerOptions] = None) : CheckWithLastConstraintFilterable = { - addFilterableConstraint { filter => minConstraint(column, assertion, filter, hint) } + addFilterableConstraint { filter => minConstraint(column, assertion, filter, hint, analyzerOptions) } } /** @@ -733,15 +735,17 @@ case class Check( * @param column Column to run the assertion on * @param assertion Function that receives a double input parameter and returns a boolean * @param hint A hint to provide additional context why a constraint could have failed + * @param analyzerOptions Options to configure analyzer behavior (NullTreatment, FilteredRow) * @return */ def hasMax( column: String, assertion: Double => Boolean, - hint: Option[String] = None) + hint: Option[String] = None, + analyzerOptions: Option[AnalyzerOptions] = None) : CheckWithLastConstraintFilterable = { - addFilterableConstraint { filter => maxConstraint(column, assertion, filter, hint) } + addFilterableConstraint { filter => maxConstraint(column, assertion, filter, hint, analyzerOptions) } } /** @@ -845,6 +849,7 @@ case class Check( * name the metrics for the analysis being done. * @param assertion Function that receives a double input parameter and returns a boolean * @param hint A hint to provide additional context why a constraint could have failed + * @param analyzerOptions Options to configure analyzer behavior (NullTreatment, FilteredRow) * @return */ def satisfies( @@ -852,11 +857,12 @@ case class Check( constraintName: String, assertion: Double => Boolean = Check.IsOne, hint: Option[String] = None, - columns: List[String] = List.empty[String]) + columns: List[String] = List.empty[String], + analyzerOptions: Option[AnalyzerOptions] = None) : CheckWithLastConstraintFilterable = { addFilterableConstraint { filter => - complianceConstraint(constraintName, columnCondition, assertion, filter, hint, columns) + complianceConstraint(constraintName, columnCondition, assertion, filter, hint, columns, analyzerOptions) } } @@ -868,6 +874,7 @@ case class Check( * @param pattern The columns values will be checked for a match against this pattern. 
* @param assertion Function that receives a double input parameter and returns a boolean * @param hint A hint to provide additional context why a constraint could have failed + * @param analyzerOptions Options to configure analyzer behavior (NullTreatment, FilteredRow) * @return */ def hasPattern( @@ -875,11 +882,12 @@ case class Check( pattern: Regex, assertion: Double => Boolean = Check.IsOne, name: Option[String] = None, - hint: Option[String] = None) + hint: Option[String] = None, + analyzerOptions: Option[AnalyzerOptions] = None) : CheckWithLastConstraintFilterable = { addFilterableConstraint { filter => - Constraint.patternMatchConstraint(column, pattern, assertion, filter, name, hint) + Constraint.patternMatchConstraint(column, pattern, assertion, filter, name, hint, analyzerOptions) } } @@ -1118,8 +1126,7 @@ case class Check( allowedValues: Array[String]) : CheckWithLastConstraintFilterable = { - - isContainedIn(column, allowedValues, Check.IsOne, None) + isContainedIn(column, allowedValues, Check.IsOne, None, None) } // We can't use default values here as you can't combine default values and overloading in Scala @@ -1137,7 +1144,7 @@ case class Check( hint: Option[String]) : CheckWithLastConstraintFilterable = { - isContainedIn(column, allowedValues, Check.IsOne, hint) + isContainedIn(column, allowedValues, Check.IsOne, hint, None) } // We can't use default values here as you can't combine default values and overloading in Scala @@ -1155,8 +1162,27 @@ case class Check( assertion: Double => Boolean) : CheckWithLastConstraintFilterable = { + isContainedIn(column, allowedValues, assertion, None, None) + } + + // We can't use default values here as you can't combine default values and overloading in Scala + /** + * Asserts that every non-null value in a column is contained in a set of predefined values + * + * @param column Column to run the assertion on + * @param allowedValues Allowed values for the column + * @param assertion Function that receives a double input parameter and returns a boolean + * @param hint A hint to provide additional context why a constraint could have failed + * @return + */ + def isContainedIn( + column: String, + allowedValues: Array[String], + assertion: Double => Boolean, + hint: Option[String]) + : CheckWithLastConstraintFilterable = { - isContainedIn(column, allowedValues, assertion, None) + isContainedIn(column, allowedValues, assertion, hint, None) } // We can't use default values here as you can't combine default values and overloading in Scala @@ -1167,23 +1193,24 @@ case class Check( * @param allowedValues Allowed values for the column * @param assertion Function that receives a double input parameter and returns a boolean * @param hint A hint to provide additional context why a constraint could have failed + * @param analyzerOptions Options to configure analyzer behavior (NullTreatment, FilteredRow) * @return */ def isContainedIn( column: String, allowedValues: Array[String], assertion: Double => Boolean, - hint: Option[String]) + hint: Option[String], + analyzerOptions: Option[AnalyzerOptions]) : CheckWithLastConstraintFilterable = { - val valueList = allowedValues .map { _.replaceAll("'", "\\\\\'") } .mkString("'", "','", "'") val predicate = s"`$column` IS NULL OR `$column` IN ($valueList)" satisfies(predicate, s"$column contained in ${allowedValues.mkString(",")}", - assertion, hint, List(column)) + assertion, hint, List(column), analyzerOptions) } /** @@ -1195,6 +1222,7 @@ case class Check( * @param includeLowerBound is a value equal to the 
lower bound allows? * @param includeUpperBound is a value equal to the upper bound allowed? * @param hint A hint to provide additional context why a constraint could have failed + * @param analyzerOptions Options to configure analyzer behavior (NullTreatment, FilteredRow) * @return */ def isContainedIn( @@ -1203,7 +1231,8 @@ case class Check( upperBound: Double, includeLowerBound: Boolean = true, includeUpperBound: Boolean = true, - hint: Option[String] = None) + hint: Option[String] = None, + analyzerOptions: Option[AnalyzerOptions] = None) : CheckWithLastConstraintFilterable = { val leftOperand = if (includeLowerBound) ">=" else ">" @@ -1212,7 +1241,8 @@ case class Check( val predicate = s"`$column` IS NULL OR " + s"(`$column` $leftOperand $lowerBound AND `$column` $rightOperand $upperBound)" - satisfies(predicate, s"$column between $lowerBound and $upperBound", hint = hint, columns = List(column)) + satisfies(predicate, s"$column between $lowerBound and $upperBound", hint = hint, + columns = List(column), analyzerOptions = analyzerOptions) } /** diff --git a/src/main/scala/com/amazon/deequ/constraints/Constraint.scala b/src/main/scala/com/amazon/deequ/constraints/Constraint.scala index b9a15901b..fec0842f7 100644 --- a/src/main/scala/com/amazon/deequ/constraints/Constraint.scala +++ b/src/main/scala/com/amazon/deequ/constraints/Constraint.scala @@ -371,6 +371,7 @@ object Constraint { * metrics for the analysis being done. * @param column Data frame column which is a combination of expression and the column name * @param hint A hint to provide additional context why a constraint could have failed + * @param analyzerOptions Options to configure analyzer behavior (NullTreatment, FilteredRow) */ def complianceConstraint( name: String, @@ -378,10 +379,11 @@ object Constraint { assertion: Double => Boolean, where: Option[String] = None, hint: Option[String] = None, - columns: List[String] = List.empty[String]) + columns: List[String] = List.empty[String], + analyzerOptions: Option[AnalyzerOptions] = None) : Constraint = { - val compliance = Compliance(name, column, where, columns) + val compliance = Compliance(name, column, where, columns, analyzerOptions) fromAnalyzer(compliance, assertion, hint) } @@ -406,6 +408,7 @@ object Constraint { * @param pattern The regex pattern to check compliance for * @param column Data frame column which is a combination of expression and the column name * @param hint A hint to provide additional context why a constraint could have failed + * @param analyzerOptions Options to configure analyzer behavior (NullTreatment, FilteredRow) */ def patternMatchConstraint( column: String, @@ -413,10 +416,11 @@ object Constraint { assertion: Double => Boolean, where: Option[String] = None, name: Option[String] = None, - hint: Option[String] = None) + hint: Option[String] = None, + analyzerOptions: Option[AnalyzerOptions] = None) : Constraint = { - val patternMatch = PatternMatch(column, pattern, where) + val patternMatch = PatternMatch(column, pattern, where, analyzerOptions) fromAnalyzer(patternMatch, pattern, assertion, name, hint) } @@ -637,16 +641,18 @@ object Constraint { * @param column Column to run the assertion on * @param assertion Function that receives a double input parameter and returns a boolean * @param hint A hint to provide additional context why a constraint could have failed + * @param analyzerOptions Options to configure analyzer behavior (NullTreatment, FilteredRow) * */ def minConstraint( column: String, assertion: Double => Boolean, where: 
Option[String] = None, - hint: Option[String] = None) + hint: Option[String] = None, + analyzerOptions: Option[AnalyzerOptions] = None) : Constraint = { - val minimum = Minimum(column, where) + val minimum = Minimum(column, where, analyzerOptions) fromAnalyzer(minimum, assertion, hint) } @@ -670,15 +676,17 @@ object Constraint { * @param column Column to run the assertion on * @param assertion Function that receives a double input parameter and returns a boolean * @param hint A hint to provide additional context why a constraint could have failed + * @param analyzerOptions Options to configure analyzer behavior (NullTreatment, FilteredRow) */ def maxConstraint( column: String, assertion: Double => Boolean, where: Option[String] = None, - hint: Option[String] = None) + hint: Option[String] = None, + analyzerOptions: Option[AnalyzerOptions] = None) : Constraint = { - val maximum = Maximum(column, where) + val maximum = Maximum(column, where, analyzerOptions) fromAnalyzer(maximum, assertion, hint) } diff --git a/src/test/scala/com/amazon/deequ/VerificationSuiteTest.scala b/src/test/scala/com/amazon/deequ/VerificationSuiteTest.scala index 7588ee914..1fb8ab74d 100644 --- a/src/test/scala/com/amazon/deequ/VerificationSuiteTest.scala +++ b/src/test/scala/com/amazon/deequ/VerificationSuiteTest.scala @@ -305,7 +305,7 @@ class VerificationSuiteTest extends WordSpec with Matchers with SparkContextSpec } "generate a result that contains row-level results with true for filtered rows" in withSparkSession { session => - val data = getDfCompleteAndInCompleteColumns(session) + val data = getDfCompleteAndInCompleteColumnsWithIntId(session) val completeness = new Check(CheckLevel.Error, "rule1") .hasCompleteness("att2", _ > 0.7, None) @@ -315,15 +315,31 @@ class VerificationSuiteTest extends WordSpec with Matchers with SparkContextSpec val uniquenessWhere = new Check(CheckLevel.Error, "rule3") .isUnique("att1") .where("item < 3") + val min = new Check(CheckLevel.Error, "rule4") + .hasMin("item", _ > 3, None) + .where("item > 3") + val max = new Check(CheckLevel.Error, "rule5") + .hasMax("item", _ < 4, None) + .where("item < 4") + val patternMatch = new Check(CheckLevel.Error, "rule6") + .hasPattern("att2", """(^f)""".r) + .where("item < 4") + val expectedColumn1 = completeness.description val expectedColumn2 = uniqueness.description val expectedColumn3 = uniquenessWhere.description + val expectedColumn4 = min.description + val expectedColumn5 = max.description + val expectedColumn6 = patternMatch.description val suite = new VerificationSuite().onData(data) .addCheck(completeness) .addCheck(uniqueness) .addCheck(uniquenessWhere) + .addCheck(min) + .addCheck(max) + .addCheck(patternMatch) val result: VerificationResult = suite.run() @@ -332,24 +348,38 @@ class VerificationSuiteTest extends WordSpec with Matchers with SparkContextSpec val resultData = VerificationResult.rowLevelResultsAsDataFrame(session, result, data).orderBy("item") resultData.show(false) val expectedColumns: Set[String] = - data.columns.toSet + expectedColumn1 + expectedColumn2 + expectedColumn3 + data.columns.toSet + expectedColumn1 + expectedColumn2 + expectedColumn3 + + expectedColumn4 + expectedColumn5 + expectedColumn6 assert(resultData.columns.toSet == expectedColumns) + // filtered rows 2,5 (where att1 = "a") val rowLevel1 = resultData.select(expectedColumn1).collect().map(r => r.getAs[Any](0)) assert(Seq(true, true, false, true, true, true).sameElements(rowLevel1)) val rowLevel2 = resultData.select(expectedColumn2).collect().map(r => 
r.getAs[Any](0)) assert(Seq(false, false, false, false, false, false).sameElements(rowLevel2)) + // filtered rows 3,4,5,6 (where item < 3) val rowLevel3 = resultData.select(expectedColumn3).collect().map(r => r.getAs[Any](0)) assert(Seq(true, true, true, true, true, true).sameElements(rowLevel3)) + // filtered rows 1, 2, 3 (where item > 3) + val minRowLevel = resultData.select(expectedColumn4).collect().map(r => r.getAs[Any](0)) + assert(Seq(true, true, true, true, true, true).sameElements(minRowLevel)) + + // filtered rows 4, 5, 6 (where item < 4) + val maxRowLevel = resultData.select(expectedColumn5).collect().map(r => r.getAs[Any](0)) + assert(Seq(true, true, true, true, true, true).sameElements(maxRowLevel)) + + // filtered rows 4, 5, 6 (where item < 4) + val rowLevel6 = resultData.select(expectedColumn6).collect().map(r => r.getAs[Any](0)) + assert(Seq(true, false, false, true, true, true).sameElements(rowLevel6)) } "generate a result that contains row-level results with null for filtered rows" in withSparkSession { session => - val data = getDfCompleteAndInCompleteColumns(session) + val data = getDfCompleteAndInCompleteColumnsWithIntId(session) - val analyzerOptions = Option(AnalyzerOptions(filteredRow = FilteredRow.NULL)) + val analyzerOptions = Option(AnalyzerOptions(filteredRow = FilteredRowOutcome.NULL)) val completeness = new Check(CheckLevel.Error, "rule1") .hasCompleteness("att2", _ > 0.7, None, analyzerOptions) @@ -359,14 +389,30 @@ class VerificationSuiteTest extends WordSpec with Matchers with SparkContextSpec val uniquenessWhere = new Check(CheckLevel.Error, "rule3") .isUnique("att1", None, analyzerOptions) .where("item < 3") + val min = new Check(CheckLevel.Error, "rule4") + .hasMin("item", _ > 3, None, analyzerOptions) + .where("item > 3") + val max = new Check(CheckLevel.Error, "rule5") + .hasMax("item", _ < 4, None, analyzerOptions) + .where("item < 4") + val patternMatch = new Check(CheckLevel.Error, "rule6") + .hasPattern("att2", """(^f)""".r, analyzerOptions = analyzerOptions) + .where("item < 4") + val expectedColumn1 = completeness.description val expectedColumn2 = uniqueness.description val expectedColumn3 = uniquenessWhere.description + val expectedColumn4 = min.description + val expectedColumn5 = max.description + val expectedColumn6 = patternMatch.description val suite = new VerificationSuite().onData(data) .addCheck(completeness) .addCheck(uniqueness) .addCheck(uniquenessWhere) + .addCheck(min) + .addCheck(max) + .addCheck(patternMatch) val result: VerificationResult = suite.run() @@ -375,7 +421,8 @@ class VerificationSuiteTest extends WordSpec with Matchers with SparkContextSpec val resultData = VerificationResult.rowLevelResultsAsDataFrame(session, result, data).orderBy("item") resultData.show(false) val expectedColumns: Set[String] = - data.columns.toSet + expectedColumn1 + expectedColumn2 + expectedColumn3 + data.columns.toSet + expectedColumn1 + expectedColumn2 + expectedColumn3 + + expectedColumn4 + expectedColumn5 + expectedColumn6 assert(resultData.columns.toSet == expectedColumns) val rowLevel1 = resultData.select(expectedColumn1).collect().map(r => r.getAs[Any](0)) @@ -384,9 +431,92 @@ class VerificationSuiteTest extends WordSpec with Matchers with SparkContextSpec val rowLevel2 = resultData.select(expectedColumn2).collect().map(r => r.getAs[Any](0)) assert(Seq(false, false, false, false, false, false).sameElements(rowLevel2)) + // filtered rows 3,4,5,6 (where item < 3) val rowLevel3 = resultData.select(expectedColumn3).collect().map(r => 
r.getAs[Any](0)) assert(Seq(true, true, null, null, null, null).sameElements(rowLevel3)) + // filtered rows 1, 2, 3 (where item > 3) + val rowLevel4 = resultData.select(expectedColumn4).collect().map(r => r.getAs[Any](0)) + assert(Seq(null, null, null, true, true, true).sameElements(rowLevel4)) + + // filtered rows 4, 5, 6 (where item < 4) + val rowLevel5 = resultData.select(expectedColumn5).collect().map(r => r.getAs[Any](0)) + assert(Seq(true, true, true, null, null, null).sameElements(rowLevel5)) + + // filtered rows 4, 5, 6 (where item < 4) + val rowLevel6 = resultData.select(expectedColumn6).collect().map(r => r.getAs[Any](0)) + assert(Seq(true, false, false, null, null, null).sameElements(rowLevel6)) + } + + "generate a result that contains compliance row-level results " in withSparkSession { session => + val data = getDfWithNumericValues(session) + val analyzerOptions = Option(AnalyzerOptions(filteredRow = FilteredRowOutcome.NULL)) + + val complianceRange = new Check(CheckLevel.Error, "rule1") + .isContainedIn("attNull", 0, 6, false, false) + val complianceFilteredRange = new Check(CheckLevel.Error, "rule2") + .isContainedIn("attNull", 0, 6, false, false) + .where("att1 < 4") + val complianceFilteredRangeNull = new Check(CheckLevel.Error, "rule3") + .isContainedIn("attNull", 0, 6, false, false, + analyzerOptions = analyzerOptions) + .where("att1 < 4") + val complianceInArray = new Check(CheckLevel.Error, "rule4") + .isContainedIn("att2", Array("5", "6", "7")) + val complianceInArrayFiltered = new Check(CheckLevel.Error, "rule5") + .isContainedIn("att2", Array("5", "6", "7")) + .where("att1 > 3") + val complianceInArrayFilteredNull = new Check(CheckLevel.Error, "rule6") + .isContainedIn("att2", Array("5", "6", "7"), Check.IsOne, None, analyzerOptions) + .where("att1 > 3") + + val expectedColumn1 = complianceRange.description + val expectedColumn2 = complianceFilteredRange.description + val expectedColumn3 = complianceFilteredRangeNull.description + val expectedColumn4 = complianceInArray.description + val expectedColumn5 = complianceInArrayFiltered.description + val expectedColumn6 = complianceInArrayFilteredNull.description + + val suite = new VerificationSuite().onData(data) + .addCheck(complianceRange) + .addCheck(complianceFilteredRange) + .addCheck(complianceFilteredRangeNull) + .addCheck(complianceInArray) + .addCheck(complianceInArrayFiltered) + .addCheck(complianceInArrayFilteredNull) + + val result: VerificationResult = suite.run() + + assert(result.status == CheckStatus.Error) + + val resultData = VerificationResult.rowLevelResultsAsDataFrame(session, result, data).orderBy("item") + resultData.show(false) + val expectedColumns: Set[String] = + data.columns.toSet + expectedColumn1 + expectedColumn2 + expectedColumn3 + + expectedColumn4 + expectedColumn5 + expectedColumn6 + assert(resultData.columns.toSet == expectedColumns) + + val rowLevel1 = resultData.select(expectedColumn1).collect().map(r => r.getAs[Any](0)) + assert(Seq(true, true, true, true, false, false).sameElements(rowLevel1)) + + // filtered rows 4, 5, 6 (where att1 < 4) as true + val rowLevel2 = resultData.select(expectedColumn2).collect().map(r => r.getAs[Any](0)) + assert(Seq(true, true, true, true, true, true).sameElements(rowLevel2)) + + // filtered rows 4, 5, 6 (where att1 < 4) as null + val rowLevel3 = resultData.select(expectedColumn3).collect().map(r => r.getAs[Any](0)) + assert(Seq(true, true, true, null, null, null).sameElements(rowLevel3)) + + val rowLevel4 = 
resultData.select(expectedColumn4).collect().map(r => r.getAs[Any](0)) + assert(Seq(false, false, false, true, true, true).sameElements(rowLevel4)) + + // filtered rows 1,2,3 (where att1 > 3) as true + val rowLevel5 = resultData.select(expectedColumn5).collect().map(r => r.getAs[Any](0)) + assert(Seq(true, true, true, true, true, true).sameElements(rowLevel5)) + + // filtered rows 1,2,3 (where att1 > 3) as null + val rowLevel6 = resultData.select(expectedColumn6).collect().map(r => r.getAs[Any](0)) + assert(Seq(null, null, null, true, true, true).sameElements(rowLevel6)) } "generate a result that contains row-level results for null column values" in withSparkSession { session => @@ -422,20 +552,16 @@ class VerificationSuiteTest extends WordSpec with Matchers with SparkContextSpec data.columns.toSet + expectedColumn1 + expectedColumn2 + expectedColumn3 + expectedColumn4 assert(resultData.columns.toSet == expectedColumns) - val rowLevel1 = resultData.select(expectedColumn1).collect().map(r => - if (r == null) null else r.getAs[Boolean](0)) + val rowLevel1 = resultData.select(expectedColumn1).collect().map(r => r.getAs[Any](0)) assert(Seq(false, null, true, true, null, true).sameElements(rowLevel1)) - val rowLevel2 = resultData.select(expectedColumn2).collect().map(r => - if (r == null) null else r.getAs[Boolean](0)) + val rowLevel2 = resultData.select(expectedColumn2).collect().map(r => r.getAs[Any](0)) assert(Seq(true, null, true, false, null, false).sameElements(rowLevel2)) - val rowLevel3 = resultData.select(expectedColumn3).collect().map(r => - if (r == null) null else r.getAs[Boolean](0)) + val rowLevel3 = resultData.select(expectedColumn3).collect().map(r => r.getAs[Any](0)) assert(Seq(true, true, false, true, false, true).sameElements(rowLevel3)) - val rowLevel4 = resultData.select(expectedColumn4).collect().map(r => - if (r == null) null else r.getAs[Boolean](0)) + val rowLevel4 = resultData.select(expectedColumn4).collect().map(r => r.getAs[Any](0)) assert(Seq(false, null, false, true, null, true).sameElements(rowLevel4)) } @@ -446,12 +572,37 @@ class VerificationSuiteTest extends WordSpec with Matchers with SparkContextSpec .hasMinLength("att2", _ >= 1, analyzerOptions = Option(AnalyzerOptions(NullBehavior.Fail))) val maxLength = new Check(CheckLevel.Error, "rule2") .hasMaxLength("att2", _ <= 1, analyzerOptions = Option(AnalyzerOptions(NullBehavior.Fail))) + // filtered rows as null + val minLengthFilterNull = new Check(CheckLevel.Error, "rule3") + .hasMinLength("att2", _ >= 1, + analyzerOptions = Option(AnalyzerOptions(NullBehavior.Fail, FilteredRowOutcome.NULL))) + .where("val1 < 5") + val maxLengthFilterNull = new Check(CheckLevel.Error, "rule4") + .hasMaxLength("att2", _ <= 1, + analyzerOptions = Option(AnalyzerOptions(NullBehavior.Fail, FilteredRowOutcome.NULL))) + .where("val1 < 5") + val minLengthFilterTrue = new Check(CheckLevel.Error, "rule5") + .hasMinLength("att2", _ >= 1, + analyzerOptions = Option(AnalyzerOptions(NullBehavior.Fail, FilteredRowOutcome.TRUE))) + .where("val1 < 5") + val maxLengthFilterTrue = new Check(CheckLevel.Error, "rule6") + .hasMaxLength("att2", _ <= 1, + analyzerOptions = Option(AnalyzerOptions(NullBehavior.Fail, FilteredRowOutcome.TRUE))) + .where("val1 < 5") val expectedColumn1 = minLength.description val expectedColumn2 = maxLength.description + val expectedColumn3 = minLengthFilterNull.description + val expectedColumn4 = maxLengthFilterNull.description + val expectedColumn5 = minLengthFilterTrue.description + val expectedColumn6 = 
maxLengthFilterTrue.description val suite = new VerificationSuite().onData(data) .addCheck(minLength) .addCheck(maxLength) + .addCheck(minLengthFilterNull) + .addCheck(maxLengthFilterNull) + .addCheck(minLengthFilterTrue) + .addCheck(maxLengthFilterTrue) val result: VerificationResult = suite.run() @@ -461,7 +612,8 @@ class VerificationSuiteTest extends WordSpec with Matchers with SparkContextSpec resultData.show() val expectedColumns: Set[String] = - data.columns.toSet + expectedColumn1 + expectedColumn2 + data.columns.toSet + expectedColumn1 + expectedColumn2 + expectedColumn3 + + expectedColumn4 + expectedColumn5 + expectedColumn6 assert(resultData.columns.toSet == expectedColumns) val rowLevel1 = resultData.select(expectedColumn1).collect().map(r => r.getBoolean(0)) @@ -469,6 +621,22 @@ class VerificationSuiteTest extends WordSpec with Matchers with SparkContextSpec val rowLevel2 = resultData.select(expectedColumn2).collect().map(r => r.getBoolean(0)) assert(Seq(true, true, false, true, false, true).sameElements(rowLevel2)) + + // filtered last two rows where(val1 < 5) + val rowLevel3 = resultData.select(expectedColumn3).collect().map(r => r.getAs[Any](0)) + assert(Seq(true, true, false, true, null, null).sameElements(rowLevel3)) + + // filtered last two rows where(val1 < 5) + val rowLevel4 = resultData.select(expectedColumn4).collect().map(r => r.getAs[Any](0)) + assert(Seq(true, true, false, true, null, null).sameElements(rowLevel4)) + + // filtered last two rows where(val1 < 5) + val rowLevel5 = resultData.select(expectedColumn5).collect().map(r => r.getAs[Any](0)) + assert(Seq(true, true, false, true, true, true).sameElements(rowLevel5)) + + // filtered last two rows where(val1 < 5) + val rowLevel6 = resultData.select(expectedColumn6).collect().map(r => r.getAs[Any](0)) + assert(Seq(true, true, false, true, true, true).sameElements(rowLevel6)) } "generate a result that contains length row-level results with nullBehavior empty" in withSparkSession { session => @@ -480,12 +648,38 @@ class VerificationSuiteTest extends WordSpec with Matchers with SparkContextSpec // nulls should succeed since length 0 is < 2 val maxLength = new Check(CheckLevel.Error, "rule2") .hasMaxLength("att2", _ < 2, analyzerOptions = Option(AnalyzerOptions(NullBehavior.EmptyString))) + // filtered rows as null + val minLengthFilterNull = new Check(CheckLevel.Error, "rule3") + .hasMinLength("att2", _ >= 1, + analyzerOptions = Option(AnalyzerOptions(NullBehavior.EmptyString, FilteredRowOutcome.NULL))) + .where("val1 < 5") + val maxLengthFilterNull = new Check(CheckLevel.Error, "rule4") + .hasMaxLength("att2", _ < 2, + analyzerOptions = Option(AnalyzerOptions(NullBehavior.EmptyString, FilteredRowOutcome.NULL))) + .where("val1 < 5") + val minLengthFilterTrue = new Check(CheckLevel.Error, "rule5") + .hasMinLength("att2", _ >= 1, + analyzerOptions = Option(AnalyzerOptions(NullBehavior.EmptyString, FilteredRowOutcome.TRUE))) + .where("val1 < 5") + val maxLengthFilterTrue = new Check(CheckLevel.Error, "rule6") + .hasMaxLength("att2", _ < 2, + analyzerOptions = Option(AnalyzerOptions(NullBehavior.EmptyString, FilteredRowOutcome.TRUE))) + .where("val1 < 5") + val expectedColumn1 = minLength.description val expectedColumn2 = maxLength.description + val expectedColumn3 = minLengthFilterNull.description + val expectedColumn4 = maxLengthFilterNull.description + val expectedColumn5 = minLengthFilterTrue.description + val expectedColumn6 = maxLengthFilterTrue.description val suite = new VerificationSuite().onData(data) 
.addCheck(minLength) .addCheck(maxLength) + .addCheck(minLengthFilterNull) + .addCheck(maxLengthFilterNull) + .addCheck(minLengthFilterTrue) + .addCheck(maxLengthFilterTrue) val result: VerificationResult = suite.run() @@ -495,7 +689,8 @@ class VerificationSuiteTest extends WordSpec with Matchers with SparkContextSpec resultData.show() val expectedColumns: Set[String] = - data.columns.toSet + expectedColumn1 + expectedColumn2 + data.columns.toSet + expectedColumn1 + expectedColumn2 + expectedColumn3 + + expectedColumn4 + expectedColumn5 + expectedColumn6 assert(resultData.columns.toSet == expectedColumns) @@ -504,6 +699,22 @@ class VerificationSuiteTest extends WordSpec with Matchers with SparkContextSpec val rowLevel2 = resultData.select(expectedColumn2).collect().map(r => r.getBoolean(0)) assert(Seq(true, true, true, true, true, true).sameElements(rowLevel2)) + + // filtered last two rows where(val1 < 5) + val rowLevel3 = resultData.select(expectedColumn3).collect().map(r => r.getAs[Any](0)) + assert(Seq(true, true, false, true, null, null).sameElements(rowLevel3)) + + // filtered last two rows where(val1 < 5) + val rowLevel4 = resultData.select(expectedColumn4).collect().map(r => r.getAs[Any](0)) + assert(Seq(true, true, true, true, null, null).sameElements(rowLevel4)) + + // filtered last two rows where(val1 < 5) + val rowLevel5 = resultData.select(expectedColumn5).collect().map(r => r.getAs[Any](0)) + assert(Seq(true, true, false, true, true, true).sameElements(rowLevel5)) + + // filtered last two rows where(val1 < 5) + val rowLevel6 = resultData.select(expectedColumn6).collect().map(r => r.getAs[Any](0)) + assert(Seq(true, true, true, true, true, true).sameElements(rowLevel6)) } "accept analysis config for mandatory analysis" in withSparkSession { sparkSession => @@ -1124,7 +1335,7 @@ class VerificationSuiteTest extends WordSpec with Matchers with SparkContextSpec checkFailedResultStringType.constraintResults.map(_.message) shouldBe List(Some("Empty state for analyzer Compliance(name between 1.0 and 3.0,`name`" + " IS NULL OR (`name` >= 1.0 AND `name` <= 3.0)," + - "None,List(name)), all input values were NULL.")) + "None,List(name),None), all input values were NULL.")) assert(checkFailedResultStringType.status == CheckStatus.Error) val checkFailedCompletenessResult = verificationResult.checkResults(complianceCheckThatShouldFailCompleteness) diff --git a/src/test/scala/com/amazon/deequ/analyzers/CompletenessTest.scala b/src/test/scala/com/amazon/deequ/analyzers/CompletenessTest.scala index 54e26f867..b5b0d5094 100644 --- a/src/test/scala/com/amazon/deequ/analyzers/CompletenessTest.scala +++ b/src/test/scala/com/amazon/deequ/analyzers/CompletenessTest.scala @@ -46,7 +46,7 @@ class CompletenessTest extends AnyWordSpec with Matchers with SparkContextSpec w // Explicitly setting RowLevelFilterTreatment for test purposes, this should be set at the VerificationRunBuilder val completenessAtt2 = Completeness("att2", Option("att1 = \"a\""), - Option(AnalyzerOptions(filteredRow = FilteredRow.NULL))) + Option(AnalyzerOptions(filteredRow = FilteredRowOutcome.NULL))) val state = completenessAtt2.computeStateFrom(data) val metric: DoubleMetric with FullColumn = completenessAtt2.computeMetricFrom(state) diff --git a/src/test/scala/com/amazon/deequ/analyzers/ComplianceTest.scala b/src/test/scala/com/amazon/deequ/analyzers/ComplianceTest.scala index c572a4bd8..5aa4033ba 100644 --- a/src/test/scala/com/amazon/deequ/analyzers/ComplianceTest.scala +++ 
b/src/test/scala/com/amazon/deequ/analyzers/ComplianceTest.scala @@ -35,7 +35,8 @@ class ComplianceTest extends AnyWordSpec with Matchers with SparkContextSpec wit val state = att1Compliance.computeStateFrom(data) val metric: DoubleMetric with FullColumn = att1Compliance.computeMetricFrom(state) - data.withColumn("new", metric.fullColumn.get).collect().map(_.getAs[Int]("new")) shouldBe Seq(0, 0, 0, 1, 1, 1) + data.withColumn("new", metric.fullColumn.get).collect().map(_.getAs[Int]("new") + ) shouldBe Seq(0, 0, 0, 1, 1, 1) } "return row-level results for null columns" in withSparkSession { session => @@ -49,6 +50,162 @@ class ComplianceTest extends AnyWordSpec with Matchers with SparkContextSpec wit data.withColumn("new", metric.fullColumn.get).collect().map(r => if (r == null) null else r.getAs[Int]("new")) shouldBe Seq(null, null, null, 1, 1, 1) } + + "return row-level results filtered with null" in withSparkSession { session => + + val data = getDfWithNumericValues(session) + + val att1Compliance = Compliance("rule1", "att1 > 4", where = Option("att2 != 0"), + analyzerOptions = Option(AnalyzerOptions(filteredRow = FilteredRowOutcome.NULL))) + val state = att1Compliance.computeStateFrom(data) + val metric: DoubleMetric with FullColumn = att1Compliance.computeMetricFrom(state) + + data.withColumn("new", metric.fullColumn.get).collect().map(r => + if (r == null) null else r.getAs[Int]("new")) shouldBe Seq(null, null, null, 0, 1, 1) + } + + "return row-level results filtered with true" in withSparkSession { session => + + val data = getDfWithNumericValues(session) + + val att1Compliance = Compliance("rule1", "att1 > 4", where = Option("att2 != 0"), + analyzerOptions = Option(AnalyzerOptions(filteredRow = FilteredRowOutcome.TRUE))) + val state = att1Compliance.computeStateFrom(data) + val metric: DoubleMetric with FullColumn = att1Compliance.computeMetricFrom(state) + + data.withColumn("new", metric.fullColumn.get).collect().map(r => + if (r == null) null else r.getAs[Int]("new")) shouldBe Seq(1, 1, 1, 0, 1, 1) + } + + "return row-level results for compliance in bounds" in withSparkSession { session => + val column = "att1" + val leftOperand = ">=" + val rightOperand = "<=" + val lowerBound = 2 + val upperBound = 5 + val predicate = s"`$column` IS NULL OR " + + s"(`$column` $leftOperand $lowerBound AND `$column` $rightOperand $upperBound)" + + val data = getDfWithNumericValues(session) + + val att1Compliance = Compliance(predicate, s"$column between $lowerBound and $upperBound", columns = List("att3")) + val state = att1Compliance.computeStateFrom(data) + val metric: DoubleMetric with FullColumn = att1Compliance.computeMetricFrom(state) + + data.withColumn("new", metric.fullColumn.get).collect().map(r => + if (r == null) null else r.getAs[Int]("new")) shouldBe Seq(0, 1, 1, 1, 1, 0) + } + + "return row-level results for compliance in bounds filtered as null" in withSparkSession { session => + val column = "att1" + val leftOperand = ">=" + val rightOperand = "<=" + val lowerBound = 2 + val upperBound = 5 + val predicate = s"`$column` IS NULL OR " + + s"(`$column` $leftOperand $lowerBound AND `$column` $rightOperand $upperBound)" + + val data = getDfWithNumericValues(session) + + val att1Compliance = Compliance(predicate, s"$column between $lowerBound and $upperBound", + where = Option("att1 < 4"), columns = List("att3"), + analyzerOptions = Option(AnalyzerOptions(filteredRow = FilteredRowOutcome.NULL))) + val state = att1Compliance.computeStateFrom(data) + val metric: DoubleMetric with 
FullColumn = att1Compliance.computeMetricFrom(state) + + data.withColumn("new", metric.fullColumn.get).collect().map(r => + if (r == null) null else r.getAs[Int]("new")) shouldBe Seq(0, 1, 1, null, null, null) + } + + "return row-level results for compliance in bounds filtered as true" in withSparkSession { session => + val column = "att1" + val leftOperand = ">=" + val rightOperand = "<=" + val lowerBound = 2 + val upperBound = 5 + val predicate = s"`$column` IS NULL OR " + + s"(`$column` $leftOperand $lowerBound AND `$column` $rightOperand $upperBound)" + + val data = getDfWithNumericValues(session) + + val att1Compliance = Compliance(s"$column between $lowerBound and $upperBound", predicate, + where = Option("att1 < 4"), columns = List("att3"), + analyzerOptions = Option(AnalyzerOptions(filteredRow = FilteredRowOutcome.TRUE))) + val state = att1Compliance.computeStateFrom(data) + val metric: DoubleMetric with FullColumn = att1Compliance.computeMetricFrom(state) + + data.withColumn("new", metric.fullColumn.get).collect().map(r => + if (r == null) null else r.getAs[Int]("new")) shouldBe Seq(0, 1, 1, 1, 1, 1) + } + + "return row-level results for compliance in array" in withSparkSession { session => + val column = "att1" + val allowedValues = Array("3", "4", "5") + val valueList = allowedValues + .map { + _.replaceAll("'", "\\\\\'") + } + .mkString("'", "','", "'") + + val predicate = s"`$column` IS NULL OR `$column` IN ($valueList)" + + val data = getDfWithNumericValues(session) + + val att1Compliance = Compliance(s"$column contained in ${allowedValues.mkString(",")}", predicate, + columns = List("att3")) + val state = att1Compliance.computeStateFrom(data) + val metric: DoubleMetric with FullColumn = att1Compliance.computeMetricFrom(state) + + data.withColumn("new", metric.fullColumn.get).collect().map(r => + if (r == null) null else r.getAs[Int]("new")) shouldBe Seq(0, 0, 1, 1, 1, 0) + } + + "return row-level results for compliance in array filtered as null" in withSparkSession { session => + val column = "att1" + val allowedValues = Array("3", "4", "5") + val valueList = allowedValues + .map { + _.replaceAll("'", "\\\\\'") + } + .mkString("'", "','", "'") + + val predicate = s"`$column` IS NULL OR `$column` IN ($valueList)" + + val data = getDfWithNumericValues(session) + + val att1Compliance = Compliance(s"$column contained in ${allowedValues.mkString(",")}", predicate, + where = Option("att1 < 5"), columns = List("att3"), + analyzerOptions = Option(AnalyzerOptions(filteredRow = FilteredRowOutcome.NULL))) + val state = att1Compliance.computeStateFrom(data) + val metric: DoubleMetric with FullColumn = att1Compliance.computeMetricFrom(state) + + data.withColumn("new", metric.fullColumn.get).collect().map(r => + if (r == null) null else r.getAs[Int]("new")) shouldBe Seq(0, 0, 1, 1, null, null) + } + + "return row-level results for compliance in array filtered as true" in withSparkSession { session => + val column = "att1" + val allowedValues = Array("3", "4", "5") + val valueList = allowedValues + .map { + _.replaceAll("'", "\\\\\'") + } + .mkString("'", "','", "'") + + val predicate = s"`$column` IS NULL OR `$column` IN ($valueList)" + + val data = getDfWithNumericValues(session) + + + val att1Compliance = Compliance(s"$column contained in ${allowedValues.mkString(",")}", predicate, + where = Option("att1 < 5"), columns = List("att3"), + analyzerOptions = Option(AnalyzerOptions(filteredRow = FilteredRowOutcome.TRUE))) + val state = att1Compliance.computeStateFrom(data) + val metric: 
DoubleMetric with FullColumn = att1Compliance.computeMetricFrom(state) + + data.withColumn("new", metric.fullColumn.get).collect().map(r => + if (r == null) null else r.getAs[Int]("new")) shouldBe Seq(0, 0, 1, 1, 1, 1) + } } } diff --git a/src/test/scala/com/amazon/deequ/analyzers/MaxLengthTest.scala b/src/test/scala/com/amazon/deequ/analyzers/MaxLengthTest.scala index f1fec85a4..456bd4c67 100644 --- a/src/test/scala/com/amazon/deequ/analyzers/MaxLengthTest.scala +++ b/src/test/scala/com/amazon/deequ/analyzers/MaxLengthTest.scala @@ -74,6 +74,84 @@ class MaxLengthTest extends AnyWordSpec with Matchers with SparkContextSpec with data.withColumn("new", metric.fullColumn.get) .collect().map(_.getAs[Double]("new")) shouldBe Seq(0.0, 0.0, 0.0, 0.0, 0.0, 0.0) } + + "return row-level results with NullBehavior fail and filtered as true" in withSparkSession { session => + + val data = getEmptyColumnDataDf(session) + + val addressLength = MaxLength("att3", Option("id < 4"), Option(AnalyzerOptions(NullBehavior.Fail))) + val state: Option[MaxState] = addressLength.computeStateFrom(data) + val metric: DoubleMetric with FullColumn = addressLength.computeMetricFrom(state) + + data.withColumn("new", metric.fullColumn.get) + .collect().map(_.getAs[Any]("new")) shouldBe + Seq(1.0, 1.0, Double.MaxValue, 1.0, Double.MinValue, Double.MinValue) + } + + "return row-level results with NullBehavior fail and filtered as null" in withSparkSession { session => + + val data = getEmptyColumnDataDf(session) + + val addressLength = MaxLength("att3", Option("id < 4"), + Option(AnalyzerOptions(NullBehavior.Fail, FilteredRowOutcome.NULL))) + val state: Option[MaxState] = addressLength.computeStateFrom(data) + val metric: DoubleMetric with FullColumn = addressLength.computeMetricFrom(state) + + data.withColumn("new", metric.fullColumn.get) + .collect().map(_.getAs[Any]("new")) shouldBe Seq(1.0, 1.0, Double.MaxValue, 1.0, null, null) + } + + "return row-level results with NullBehavior empty and filtered as true" in withSparkSession { session => + + val data = getEmptyColumnDataDf(session) + + val addressLength = MaxLength("att3", Option("id < 4"), Option(AnalyzerOptions(NullBehavior.EmptyString))) + val state: Option[MaxState] = addressLength.computeStateFrom(data) + val metric: DoubleMetric with FullColumn = addressLength.computeMetricFrom(state) + + data.withColumn("new", metric.fullColumn.get) + .collect().map(_.getAs[Any]("new")) shouldBe + Seq(1.0, 1.0, 0.0, 1.0, Double.MinValue, Double.MinValue) + } + + "return row-level results with NullBehavior empty and filtered as null" in withSparkSession { session => + + val data = getEmptyColumnDataDf(session) + + val addressLength = MaxLength("att3", Option("id < 4"), + Option(AnalyzerOptions(NullBehavior.EmptyString, FilteredRowOutcome.NULL))) + val state: Option[MaxState] = addressLength.computeStateFrom(data) + val metric: DoubleMetric with FullColumn = addressLength.computeMetricFrom(state) + + data.withColumn("new", metric.fullColumn.get) + .collect().map(_.getAs[Any]("new")) shouldBe Seq(1.0, 1.0, 0.0, 1.0, null, null) + } + + "return row-level results with NullBehavior ignore and filtered as true" in withSparkSession { session => + + val data = getEmptyColumnDataDf(session) + + val addressLength = MaxLength("att3", Option("id < 4"), Option(AnalyzerOptions(NullBehavior.Ignore))) + val state: Option[MaxState] = addressLength.computeStateFrom(data) + val metric: DoubleMetric with FullColumn = addressLength.computeMetricFrom(state) + + data.withColumn("new", 
metric.fullColumn.get) + .collect().map(_.getAs[Any]("new")) shouldBe + Seq(1.0, 1.0, null, 1.0, Double.MinValue, Double.MinValue) + } + + "return row-level results with NullBehavior ignore and filtered as null" in withSparkSession { session => + + val data = getEmptyColumnDataDf(session) + + val addressLength = MaxLength("att3", Option("id < 4"), + Option(AnalyzerOptions(NullBehavior.Ignore, FilteredRowOutcome.NULL))) + val state: Option[MaxState] = addressLength.computeStateFrom(data) + val metric: DoubleMetric with FullColumn = addressLength.computeMetricFrom(state) + + data.withColumn("new", metric.fullColumn.get) + .collect().map(_.getAs[Any]("new")) shouldBe Seq(1.0, 1.0, null, 1.0, null, null) + } } } diff --git a/src/test/scala/com/amazon/deequ/analyzers/MaximumTest.scala b/src/test/scala/com/amazon/deequ/analyzers/MaximumTest.scala index d88b4b532..6ac90f735 100644 --- a/src/test/scala/com/amazon/deequ/analyzers/MaximumTest.scala +++ b/src/test/scala/com/amazon/deequ/analyzers/MaximumTest.scala @@ -51,5 +51,36 @@ class MaximumTest extends AnyWordSpec with Matchers with SparkContextSpec with F if (r == null) null else r.getAs[Double]("new")) shouldBe Seq(null, null, null, 5.0, 6.0, 7.0) } + + "return row-level results for columns with where clause filtered as true" in withSparkSession { session => + + val data = getDfWithNumericValues(session) + + val att1Maximum = Maximum("att1", Option("item < 4")) + val state: Option[MaxState] = att1Maximum.computeStateFrom(data, Option("item < 4")) + val metric: DoubleMetric with FullColumn = att1Maximum.computeMetricFrom(state) + + val result = data.withColumn("new", metric.fullColumn.get) + result.show(false) + result.collect().map(r => + if (r == null) null else r.getAs[Double]("new")) shouldBe + Seq(1.0, 2.0, 3.0, Double.MinValue, Double.MinValue, Double.MinValue) + } + + "return row-level results for columns with where clause filtered as null" in withSparkSession { session => + + val data = getDfWithNumericValues(session) + + val att1Maximum = Maximum("att1", Option("item < 4"), + Option(AnalyzerOptions(filteredRow = FilteredRowOutcome.NULL))) + val state: Option[MaxState] = att1Maximum.computeStateFrom(data, Option("item < 4")) + val metric: DoubleMetric with FullColumn = att1Maximum.computeMetricFrom(state) + + val result = data.withColumn("new", metric.fullColumn.get) + result.show(false) + result.collect().map(r => + if (r == null) null else r.getAs[Double]("new")) shouldBe + Seq(1.0, 2.0, 3.0, null, null, null) + } } } diff --git a/src/test/scala/com/amazon/deequ/analyzers/MinLengthTest.scala b/src/test/scala/com/amazon/deequ/analyzers/MinLengthTest.scala index 84228e7e7..0f88e377f 100644 --- a/src/test/scala/com/amazon/deequ/analyzers/MinLengthTest.scala +++ b/src/test/scala/com/amazon/deequ/analyzers/MinLengthTest.scala @@ -56,13 +56,13 @@ class MinLengthTest extends AnyWordSpec with Matchers with SparkContextSpec with val data = getEmptyColumnDataDf(session) // It's null in two rows - val addressLength = MinLength("att3") + val addressLength = MinLength("att3", None, Option(AnalyzerOptions(NullBehavior.Fail))) val state: Option[MinState] = addressLength.computeStateFrom(data) val metric: DoubleMetric with FullColumn = addressLength.computeMetricFrom(state) data.withColumn("new", metric.fullColumn.get) - .collect().map( r => if (r == null) null else r.getAs[Double]("new") - ) shouldBe Seq(1.0, 1.0, null, 1.0, null, 1.0) + .collect().map(_.getAs[Double]("new") + ) shouldBe Seq(1.0, 1.0, Double.MinValue, 1.0, Double.MinValue, 1.0) } 
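An illustrative aside on the expectations above (not part of the patch): the Double.MinValue / Double.MaxValue entries are row-level placeholders, chosen so that the check's later assertion resolves the intended way for filtered and null rows. A minimal sketch of that encoding for MaxLength, with a hypothetical helper name:

    // Hypothetical sketch of the sentinel encoding seen in the MaxLength/MinLength tests above.
    // For MaxLength the row-level assertion has the shape "length <= bound": a row excluded by
    // the where filter under FilteredRowOutcome.TRUE is encoded as Double.MinValue (always
    // passes), while a null value under NullBehavior.Fail is encoded as Double.MaxValue
    // (always fails). MinLength ("length >= bound") swaps the two sentinels.
    def maxLengthRowValue(length: Option[Double], matchesWhere: Boolean): Double =
      if (!matchesWhere) Double.MinValue       // FilteredRowOutcome.TRUE placeholder
      else length.getOrElse(Double.MaxValue)   // NullBehavior.Fail placeholder for null values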
"return row-level results for null columns with NullBehavior empty option" in withSparkSession { session => @@ -70,7 +70,7 @@ class MinLengthTest extends AnyWordSpec with Matchers with SparkContextSpec with val data = getEmptyColumnDataDf(session) // It's null in two rows - val addressLength = MinLength("att3") + val addressLength = MinLength("att3", None, Option(AnalyzerOptions(NullBehavior.EmptyString))) val state: Option[MinState] = addressLength.computeStateFrom(data) val metric: DoubleMetric with FullColumn = addressLength.computeMetricFrom(state) @@ -89,5 +89,84 @@ class MinLengthTest extends AnyWordSpec with Matchers with SparkContextSpec with data.withColumn("new", metric.fullColumn.get) .collect().map(_.getAs[Double]("new")) shouldBe Seq(0.0, 0.0, 0.0, 0.0, 0.0, 0.0) } + + "return row-level results with NullBehavior fail and filtered as true" in withSparkSession { session => + + val data = getEmptyColumnDataDf(session) + + val addressLength = MinLength("att3", Option("id < 4"), Option(AnalyzerOptions(NullBehavior.Fail))) + val state: Option[MinState] = addressLength.computeStateFrom(data) + val metric: DoubleMetric with FullColumn = addressLength.computeMetricFrom(state) + + data.withColumn("new", metric.fullColumn.get) + .collect().map(_.getAs[Any]("new")) shouldBe + Seq(1.0, 1.0, Double.MinValue, 1.0, Double.MaxValue, Double.MaxValue) + } + + "return row-level results with NullBehavior fail and filtered as null" in withSparkSession { session => + + val data = getEmptyColumnDataDf(session) + + val addressLength = MinLength("att3", Option("id < 4"), + Option(AnalyzerOptions(NullBehavior.Fail, FilteredRowOutcome.NULL))) + val state: Option[MinState] = addressLength.computeStateFrom(data) + val metric: DoubleMetric with FullColumn = addressLength.computeMetricFrom(state) + + data.withColumn("new", metric.fullColumn.get) + .collect().map(_.getAs[Any]("new")) shouldBe + Seq(1.0, 1.0, Double.MinValue, 1.0, null, null) + } + + "return row-level results with NullBehavior empty and filtered as true" in withSparkSession { session => + + val data = getEmptyColumnDataDf(session) + + val addressLength = MinLength("att3", Option("id < 4"), Option(AnalyzerOptions(NullBehavior.EmptyString))) + val state: Option[MinState] = addressLength.computeStateFrom(data) + val metric: DoubleMetric with FullColumn = addressLength.computeMetricFrom(state) + + data.withColumn("new", metric.fullColumn.get) + .collect().map(_.getAs[Any]("new")) shouldBe + Seq(1.0, 1.0, 0.0, 1.0, Double.MaxValue, Double.MaxValue) + } + + "return row-level results with NullBehavior empty and filtered as null" in withSparkSession { session => + + val data = getEmptyColumnDataDf(session) + + val addressLength = MinLength("att3", Option("id < 4"), + Option(AnalyzerOptions(NullBehavior.EmptyString, FilteredRowOutcome.NULL))) + val state: Option[MinState] = addressLength.computeStateFrom(data) + val metric: DoubleMetric with FullColumn = addressLength.computeMetricFrom(state) + + data.withColumn("new", metric.fullColumn.get) + .collect().map(_.getAs[Any]("new")) shouldBe Seq(1.0, 1.0, 0.0, 1.0, null, null) + } + + "return row-level results NullBehavior ignore and filtered as true" in withSparkSession { session => + + val data = getEmptyColumnDataDf(session) + + val addressLength = MinLength("att3", Option("id < 4"), Option(AnalyzerOptions(NullBehavior.Ignore))) + val state: Option[MinState] = addressLength.computeStateFrom(data) + val metric: DoubleMetric with FullColumn = addressLength.computeMetricFrom(state) + + 
data.withColumn("new", metric.fullColumn.get) + .collect().map(_.getAs[Any]("new")) shouldBe + Seq(1.0, 1.0, null, 1.0, Double.MaxValue, Double.MaxValue) + } + + "return row-level results with NullBehavior ignore and filtered as null" in withSparkSession { session => + + val data = getEmptyColumnDataDf(session) + + val addressLength = MinLength("att3", Option("id < 4"), + Option(AnalyzerOptions(NullBehavior.Ignore, FilteredRowOutcome.NULL))) + val state: Option[MinState] = addressLength.computeStateFrom(data) + val metric: DoubleMetric with FullColumn = addressLength.computeMetricFrom(state) + + data.withColumn("new", metric.fullColumn.get) + .collect().map(_.getAs[Any]("new")) shouldBe Seq(1.0, 1.0, null, 1.0, null, null) + } } } diff --git a/src/test/scala/com/amazon/deequ/analyzers/MinimumTest.scala b/src/test/scala/com/amazon/deequ/analyzers/MinimumTest.scala index 6d495aa0f..435542e8c 100644 --- a/src/test/scala/com/amazon/deequ/analyzers/MinimumTest.scala +++ b/src/test/scala/com/amazon/deequ/analyzers/MinimumTest.scala @@ -52,6 +52,39 @@ class MinimumTest extends AnyWordSpec with Matchers with SparkContextSpec with F if (r == null) null else r.getAs[Double]("new")) shouldBe Seq(null, null, null, 5.0, 6.0, 7.0) } + + "return row-level results for columns with where clause filtered as true" in withSparkSession { session => + + val data = getDfWithNumericValues(session) + + val att1Minimum = Minimum("att1", Option("item < 4")) + val state: Option[MinState] = att1Minimum.computeStateFrom(data, Option("item < 4")) + print(state) + val metric: DoubleMetric with FullColumn = att1Minimum.computeMetricFrom(state) + + val result = data.withColumn("new", metric.fullColumn.get) + result.show(false) + result.collect().map(r => + if (r == null) null else r.getAs[Double]("new")) shouldBe + Seq(1.0, 2.0, 3.0, Double.MaxValue, Double.MaxValue, Double.MaxValue) + } + + "return row-level results for columns with where clause filtered as null" in withSparkSession { session => + + val data = getDfWithNumericValues(session) + + val att1Minimum = Minimum("att1", Option("item < 4"), + Option(AnalyzerOptions(filteredRow = FilteredRowOutcome.NULL))) + val state: Option[MinState] = att1Minimum.computeStateFrom(data, Option("item < 4")) + print(state) + val metric: DoubleMetric with FullColumn = att1Minimum.computeMetricFrom(state) + + val result = data.withColumn("new", metric.fullColumn.get) + result.show(false) + result.collect().map(r => + if (r == null) null else r.getAs[Double]("new")) shouldBe + Seq(1.0, 2.0, 3.0, null, null, null) + } } } diff --git a/src/test/scala/com/amazon/deequ/analyzers/PatternMatchTest.scala b/src/test/scala/com/amazon/deequ/analyzers/PatternMatchTest.scala index e01235597..94d439674 100644 --- a/src/test/scala/com/amazon/deequ/analyzers/PatternMatchTest.scala +++ b/src/test/scala/com/amazon/deequ/analyzers/PatternMatchTest.scala @@ -33,10 +33,35 @@ class PatternMatchTest extends AnyWordSpec with Matchers with SparkContextSpec w val state = patternMatchCountry.computeStateFrom(data) val metric: DoubleMetric with FullColumn = patternMatchCountry.computeMetricFrom(state) - data.withColumn("new", metric.fullColumn.get).collect().map(_.getAs[Boolean]("new")) shouldBe + data.withColumn("new", metric.fullColumn.get).collect().map(_.getAs[Any]("new")) shouldBe Seq(true, true, true, true, true, true, true, true) } + "return row-level results for non-null columns starts with digit" in withSparkSession { session => + + val data = getDfWithStringColumns(session) + + val patternMatchCountry 
= PatternMatch("Address Line 1", """(^[0-4])""".r) + val state = patternMatchCountry.computeStateFrom(data) + val metric: DoubleMetric with FullColumn = patternMatchCountry.computeMetricFrom(state) + + data.withColumn("new", metric.fullColumn.get).collect().map(_.getAs[Any]("new")) shouldBe + Seq(false, false, true, true, false, false, true, true) + } + + "return row-level results for non-null columns starts with digit filtered as true" in withSparkSession { session => + + val data = getDfWithStringColumns(session) + + val patternMatchCountry = PatternMatch("Address Line 1", """(^[0-4])""".r, where = Option("id < 5"), + analyzerOptions = Option(AnalyzerOptions(filteredRow = FilteredRowOutcome.TRUE))) + val state = patternMatchCountry.computeStateFrom(data) + val metric: DoubleMetric with FullColumn = patternMatchCountry.computeMetricFrom(state) + + data.withColumn("new", metric.fullColumn.get).collect().map(_.getAs[Any]("new")) shouldBe + Seq(false, false, true, true, false, true, true, true) + } + "return row-level results for columns with nulls" in withSparkSession { session => val data = getDfWithStringColumns(session) @@ -45,8 +70,34 @@ class PatternMatchTest extends AnyWordSpec with Matchers with SparkContextSpec w val state = patternMatchCountry.computeStateFrom(data) val metric: DoubleMetric with FullColumn = patternMatchCountry.computeMetricFrom(state) - data.withColumn("new", metric.fullColumn.get).collect().map(_.getAs[Boolean]("new")) shouldBe + data.withColumn("new", metric.fullColumn.get).collect().map(_.getAs[Any]("new")) shouldBe Seq(true, true, true, true, false, true, true, false) } + + "return row-level results for columns with nulls filtered as true" in withSparkSession { session => + + val data = getDfWithStringColumns(session) + + val patternMatchCountry = PatternMatch("Address Line 2", """\w""".r, where = Option("id < 5"), + analyzerOptions = Option(AnalyzerOptions(filteredRow = FilteredRowOutcome.TRUE))) + val state = patternMatchCountry.computeStateFrom(data) + val metric: DoubleMetric with FullColumn = patternMatchCountry.computeMetricFrom(state) + + data.withColumn("new", metric.fullColumn.get).collect().map(_.getAs[Any]("new")) shouldBe + Seq(true, true, true, true, false, true, true, true) + } + + "return row-level results for columns with nulls filtered as null" in withSparkSession { session => + + val data = getDfWithStringColumns(session) + + val patternMatchCountry = PatternMatch("Address Line 2", """\w""".r, where = Option("id < 5"), + analyzerOptions = Option(AnalyzerOptions(filteredRow = FilteredRowOutcome.NULL))) + val state = patternMatchCountry.computeStateFrom(data) + val metric: DoubleMetric with FullColumn = patternMatchCountry.computeMetricFrom(state) + + data.withColumn("new", metric.fullColumn.get).collect().map(_.getAs[Any]("new")) shouldBe + Seq(true, true, true, true, false, null, null, null) + } } } diff --git a/src/test/scala/com/amazon/deequ/analyzers/UniquenessTest.scala b/src/test/scala/com/amazon/deequ/analyzers/UniquenessTest.scala index d50995b55..4aea6bb27 100644 --- a/src/test/scala/com/amazon/deequ/analyzers/UniquenessTest.scala +++ b/src/test/scala/com/amazon/deequ/analyzers/UniquenessTest.scala @@ -123,7 +123,7 @@ class UniquenessTest extends AnyWordSpec with Matchers with SparkContextSpec wit val data = getDfWithUniqueColumns(session) val addressLength = Uniqueness(Seq("onlyUniqueWithOtherNonUnique"), Option("unique < 4"), - Option(AnalyzerOptions(filteredRow = FilteredRow.NULL))) + Option(AnalyzerOptions(filteredRow = 
FilteredRowOutcome.NULL))) val state: Option[FrequenciesAndNumRows] = addressLength.computeStateFrom(data, Option("unique < 4")) val metric: DoubleMetric with FullColumn = addressLength.computeMetricFrom(state) @@ -139,7 +139,7 @@ class UniquenessTest extends AnyWordSpec with Matchers with SparkContextSpec wit val data = getDfWithUniqueColumns(session) val addressLength = Uniqueness(Seq("halfUniqueCombinedWithNonUnique", "nonUnique"), Option("unique > 2"), - Option(AnalyzerOptions(filteredRow = FilteredRow.NULL))) + Option(AnalyzerOptions(filteredRow = FilteredRowOutcome.NULL))) val state: Option[FrequenciesAndNumRows] = addressLength.computeStateFrom(data, Option("unique > 2")) val metric: DoubleMetric with FullColumn = addressLength.computeMetricFrom(state) diff --git a/src/test/scala/com/amazon/deequ/analyzers/runners/AnalysisRunnerTests.scala b/src/test/scala/com/amazon/deequ/analyzers/runners/AnalysisRunnerTests.scala index ce9bda69b..193dbaebe 100644 --- a/src/test/scala/com/amazon/deequ/analyzers/runners/AnalysisRunnerTests.scala +++ b/src/test/scala/com/amazon/deequ/analyzers/runners/AnalysisRunnerTests.scala @@ -204,10 +204,9 @@ class AnalysisRunnerTests extends AnyWordSpec // Used to be tested with the above line, but adding filters changed the order of the results. assert(separateResults.asInstanceOf[Set[DoubleMetric]].size == runnerResults.asInstanceOf[Set[DoubleMetric]].size) - separateResults.asInstanceOf[Set[DoubleMetric]].foreach( result => { - assert(runnerResults.toString.contains(result.toString)) - } - ) + separateResults.asInstanceOf[Set[DoubleMetric]].foreach( result => + assert(runnerResults.toString.contains(result.toString)) + ) } "reuse existing results" in diff --git a/src/test/scala/com/amazon/deequ/analyzers/runners/AnalyzerContextTest.scala b/src/test/scala/com/amazon/deequ/analyzers/runners/AnalyzerContextTest.scala index 9133d5ae4..3054d141c 100644 --- a/src/test/scala/com/amazon/deequ/analyzers/runners/AnalyzerContextTest.scala +++ b/src/test/scala/com/amazon/deequ/analyzers/runners/AnalyzerContextTest.scala @@ -146,7 +146,5 @@ class AnalyzerContextTest extends AnyWordSpec private[this] def assertSameJson(jsonA: String, jsonB: String): Unit = { assert(SimpleResultSerde.deserialize(jsonA).toSet.sameElements(SimpleResultSerde.deserialize(jsonB).toSet)) - // assert(SimpleResultSerde.deserialize(jsonA) == - // SimpleResultSerde.deserialize(jsonB)) } } diff --git a/src/test/scala/com/amazon/deequ/suggestions/ConstraintSuggestionResultTest.scala b/src/test/scala/com/amazon/deequ/suggestions/ConstraintSuggestionResultTest.scala index 9a82903e8..0cd76c8de 100644 --- a/src/test/scala/com/amazon/deequ/suggestions/ConstraintSuggestionResultTest.scala +++ b/src/test/scala/com/amazon/deequ/suggestions/ConstraintSuggestionResultTest.scala @@ -255,7 +255,7 @@ class ConstraintSuggestionResultTest extends WordSpec with Matchers with SparkCo | }, | { | "constraint_name": "ComplianceConstraint(Compliance(\u0027item\u0027 has no - | negative values,item \u003e\u003d 0,None,List(item)))", + | negative values,item \u003e\u003d 0,None,List(item),None))", | "column_name": "item", | "current_value": "Minimum: 1.0", | "description": "\u0027item\u0027 has no negative values", @@ -341,7 +341,7 @@ class ConstraintSuggestionResultTest extends WordSpec with Matchers with SparkCo | }, | { | "constraint_name": "ComplianceConstraint(Compliance(\u0027item\u0027 has no - | negative values,item \u003e\u003d 0,None,List(item)))", + | negative values,item \u003e\u003d 0,None,List(item),None))", 
| "column_name": "item", | "current_value": "Minimum: 1.0", | "description": "\u0027item\u0027 has no negative values", @@ -428,7 +428,7 @@ class ConstraintSuggestionResultTest extends WordSpec with Matchers with SparkCo | }, | { | "constraint_name": "ComplianceConstraint(Compliance(\u0027item\u0027 has no - | negative values,item \u003e\u003d 0,None,List(item)))", + | negative values,item \u003e\u003d 0,None,List(item),None))", | "column_name": "item", | "current_value": "Minimum: 1.0", | "description": "\u0027item\u0027 has no negative values", @@ -494,7 +494,7 @@ class ConstraintSuggestionResultTest extends WordSpec with Matchers with SparkCo | }, | { | "constraint_name": "ComplianceConstraint(Compliance(\u0027`item.one`\u0027 has no - | negative values,`item.one` \u003e\u003d 0,None,List(`item.one`)))", + | negative values,`item.one` \u003e\u003d 0,None,List(`item.one`),None))", | "column_name": "`item.one`", | "current_value": "Minimum: 1.0", | "description": "\u0027`item.one`\u0027 has no negative values", diff --git a/src/test/scala/com/amazon/deequ/utils/FixtureSupport.scala b/src/test/scala/com/amazon/deequ/utils/FixtureSupport.scala index 601134a53..5c56ed4b0 100644 --- a/src/test/scala/com/amazon/deequ/utils/FixtureSupport.scala +++ b/src/test/scala/com/amazon/deequ/utils/FixtureSupport.scala @@ -32,13 +32,13 @@ trait FixtureSupport { import sparkSession.implicits._ Seq( - ("", "a", "f"), - ("", "b", "d"), - ("", "a", null), - ("", "a", "f"), - ("", "b", null), - ("", "a", "f") - ).toDF("att1", "att2", "att3") + (0, "", "a", "f"), + (1, "", "b", "d"), + (2, "", "a", null), + (3, "", "a", "f"), + (4, "", "b", null), + (5, "", "a", "f") + ).toDF("id", "att1", "att2", "att3") } def getDfEmpty(sparkSession: SparkSession): DataFrame = { @@ -159,6 +159,19 @@ trait FixtureSupport { ).toDF("item", "att1", "att2") } + def getDfCompleteAndInCompleteColumnsWithIntId(sparkSession: SparkSession): DataFrame = { + import sparkSession.implicits._ + + Seq( + (1, "a", "f"), + (2, "b", "d"), + (3, "a", null), + (4, "a", "f"), + (5, "b", null), + (6, "a", "f") + ).toDF("item", "att1", "att2") + } + def getDfCompleteAndInCompleteColumnsWithSpacesInNames(sparkSession: SparkSession): DataFrame = { import sparkSession.implicits._ @@ -399,16 +412,16 @@ trait FixtureSupport { import sparkSession.implicits._ Seq( - ("India", "Xavier House, 2nd Floor", "St. Peter Colony, Perry Road", "Bandra (West)"), - ("India", "503 Godavari", "Sir Pochkhanwala Road", "Worli"), - ("India", "4/4 Seema Society", "N Dutta Road, Four Bungalows", "Andheri"), - ("India", "1001D Abhishek Apartments", "Juhu Versova Road", "Andheri"), - ("India", "95, Hill Road", null, null), - ("India", "90 Cuffe Parade", "Taj President Hotel", "Cuffe Parade"), - ("India", "4, Seven PM", "Sir Pochkhanwala Rd", "Worli"), - ("India", "1453 Sahar Road", null, null) + (0, "India", "Xavier House, 2nd Floor", "St. 
Peter Colony, Perry Road", "Bandra (West)"), + (1, "India", "503 Godavari", "Sir Pochkhanwala Road", "Worli"), + (2, "India", "4/4 Seema Society", "N Dutta Road, Four Bungalows", "Andheri"), + (3, "India", "1001D Abhishek Apartments", "Juhu Versova Road", "Andheri"), + (4, "India", "95, Hill Road", null, null), + (5, "India", "90 Cuffe Parade", "Taj President Hotel", "Cuffe Parade"), + (6, "India", "4, Seven PM", "Sir Pochkhanwala Rd", "Worli"), + (7, "India", "1453 Sahar Road", null, null) ) - .toDF("Country", "Address Line 1", "Address Line 2", "Address Line 3") + .toDF("id", "Country", "Address Line 1", "Address Line 2", "Address Line 3") } def getDfWithPeriodInName(sparkSession: SparkSession): DataFrame = { From efeec970f66c2c54dc77a3644b7a7e8bbe5f9255 Mon Sep 17 00:00:00 2001 From: Edward Cho <114528615+eycho-am@users.noreply.github.com> Date: Fri, 23 Feb 2024 12:21:33 -0500 Subject: [PATCH 07/24] Add analyzerOption to add filteredRowOutcome for isPrimaryKey Check (#537) * Add analyzerOption to add filteredRowOutcome for isPrimaryKey Check * Add analyzerOption to add filteredRowOutcome for hasUniqueValueRatio Check --- .../scala/com/amazon/deequ/checks/Check.scala | 27 ++++++++++-- .../amazon/deequ/constraints/Constraint.scala | 6 ++- .../amazon/deequ/VerificationSuiteTest.scala | 41 +++++++++++++++++-- 3 files changed, 65 insertions(+), 9 deletions(-) diff --git a/src/main/scala/com/amazon/deequ/checks/Check.scala b/src/main/scala/com/amazon/deequ/checks/Check.scala index ccfd9badc..c76187e7d 100644 --- a/src/main/scala/com/amazon/deequ/checks/Check.scala +++ b/src/main/scala/com/amazon/deequ/checks/Check.scala @@ -253,10 +253,27 @@ case class Check( * @param hint A hint to provide additional context why a constraint could have failed * @return */ - def isPrimaryKey(column: String, hint: Option[String], columns: String*) + def isPrimaryKey(column: String, hint: Option[String], + analyzerOptions: Option[AnalyzerOptions], columns: String*) : CheckWithLastConstraintFilterable = { addFilterableConstraint { filter => - uniquenessConstraint(column :: columns.toList, Check.IsOne, filter, hint) } + uniquenessConstraint(column :: columns.toList, Check.IsOne, filter, hint, analyzerOptions) } + } + + /** + * Creates a constraint that asserts on a column(s) primary key characteristics. + * Currently only checks uniqueness, but reserved for primary key checks if there is another + * assertion to run on primary key columns. + * + * @param column Columns to run the assertion on + * @param hint A hint to provide additional context why a constraint could have failed + * @return + */ + def isPrimaryKey(column: String, hint: Option[String], columns: String*) + : CheckWithLastConstraintFilterable = { + addFilterableConstraint { filter => + uniquenessConstraint(column :: columns.toList, Check.IsOne, filter, hint) + } } /** @@ -377,16 +394,18 @@ case class Check( * @param assertion Function that receives a double input parameter and returns a boolean. * Refers to the fraction of distinct values. 
* @param hint A hint to provide additional context why a constraint could have failed + * @param analyzerOptions Options to configure analyzer behavior (NullTreatment, FilteredRow) * @return */ def hasUniqueValueRatio( columns: Seq[String], assertion: Double => Boolean, - hint: Option[String] = None) + hint: Option[String] = None, + analyzerOptions: Option[AnalyzerOptions] = None) : CheckWithLastConstraintFilterable = { addFilterableConstraint { filter => - uniqueValueRatioConstraint(columns, assertion, filter, hint) } + uniqueValueRatioConstraint(columns, assertion, filter, hint, analyzerOptions) } } /** diff --git a/src/main/scala/com/amazon/deequ/constraints/Constraint.scala b/src/main/scala/com/amazon/deequ/constraints/Constraint.scala index fec0842f7..46a12332b 100644 --- a/src/main/scala/com/amazon/deequ/constraints/Constraint.scala +++ b/src/main/scala/com/amazon/deequ/constraints/Constraint.scala @@ -339,15 +339,17 @@ object Constraint { * (since the metric is double metric) and returns a boolean * @param where Additional filter to apply before the analyzer is run. * @param hint A hint to provide additional context why a constraint could have failed + * @param analyzerOptions Options to configure analyzer behavior (NullTreatment, FilteredRow) */ def uniqueValueRatioConstraint( columns: Seq[String], assertion: Double => Boolean, where: Option[String] = None, - hint: Option[String] = None) + hint: Option[String] = None, + analyzerOptions: Option[AnalyzerOptions] = None) : Constraint = { - val uniqueValueRatio = UniqueValueRatio(columns, where) + val uniqueValueRatio = UniqueValueRatio(columns, where, analyzerOptions) fromAnalyzer(uniqueValueRatio, assertion, hint) } diff --git a/src/test/scala/com/amazon/deequ/VerificationSuiteTest.scala b/src/test/scala/com/amazon/deequ/VerificationSuiteTest.scala index 1fb8ab74d..e02ab0cf4 100644 --- a/src/test/scala/com/amazon/deequ/VerificationSuiteTest.scala +++ b/src/test/scala/com/amazon/deequ/VerificationSuiteTest.scala @@ -324,6 +324,12 @@ class VerificationSuiteTest extends WordSpec with Matchers with SparkContextSpec val patternMatch = new Check(CheckLevel.Error, "rule6") .hasPattern("att2", """(^f)""".r) .where("item < 4") + val isPrimaryKey = new Check(CheckLevel.Error, "rule7") + .isPrimaryKey("item") + .where("item < 3") + val uniqueValueRatio = new Check(CheckLevel.Error, "rule8") + .hasUniqueValueRatio(Seq("att1"), _ >= 0.5) + .where("item < 4") val expectedColumn1 = completeness.description val expectedColumn2 = uniqueness.description @@ -331,7 +337,8 @@ class VerificationSuiteTest extends WordSpec with Matchers with SparkContextSpec val expectedColumn4 = min.description val expectedColumn5 = max.description val expectedColumn6 = patternMatch.description - + val expectedColumn7 = isPrimaryKey.description + val expectedColumn8 = uniqueValueRatio.description val suite = new VerificationSuite().onData(data) .addCheck(completeness) @@ -340,6 +347,8 @@ class VerificationSuiteTest extends WordSpec with Matchers with SparkContextSpec .addCheck(min) .addCheck(max) .addCheck(patternMatch) + .addCheck(isPrimaryKey) + .addCheck(uniqueValueRatio) val result: VerificationResult = suite.run() @@ -349,7 +358,7 @@ class VerificationSuiteTest extends WordSpec with Matchers with SparkContextSpec resultData.show(false) val expectedColumns: Set[String] = data.columns.toSet + expectedColumn1 + expectedColumn2 + expectedColumn3 + - expectedColumn4 + expectedColumn5 + expectedColumn6 + expectedColumn4 + expectedColumn5 + expectedColumn6 + expectedColumn7 
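// --- Illustrative usage sketch (annotation, not part of the patch) ---
// A minimal, hedged example of the two check methods extended in this commit, assuming
// the AnalyzerOptions / FilteredRowOutcome types and package locations shown in this
// change set. The column names ("id", "att1") and the where clause are placeholders.
import com.amazon.deequ.VerificationSuite
import com.amazon.deequ.analyzers.{AnalyzerOptions, FilteredRowOutcome}
import com.amazon.deequ.checks.{Check, CheckLevel}
import org.apache.spark.sql.DataFrame

object FilteredRowOutcomeUsageSketch {
  def run(data: DataFrame) = {
    // Rows excluded by the where clause get a null row-level outcome instead of true.
    val markFilteredAsNull = Option(AnalyzerOptions(filteredRow = FilteredRowOutcome.NULL))

    val primaryKeyCheck = new Check(CheckLevel.Error, "primary key on in-scope rows")
      .isPrimaryKey("id", None, markFilteredAsNull) // new overload taking analyzer options
      .where("id IS NOT NULL")

    val uniqueRatioCheck = new Check(CheckLevel.Error, "unique value ratio on in-scope rows")
      .hasUniqueValueRatio(Seq("att1"), _ >= 0.5, analyzerOptions = markFilteredAsNull)
      .where("id IS NOT NULL")

    new VerificationSuite()
      .onData(data)
      .addCheck(primaryKeyCheck)
      .addCheck(uniqueRatioCheck)
      .run()
  }
}
// --- end of sketch ---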
+ expectedColumn8 assert(resultData.columns.toSet == expectedColumns) // filtered rows 2,5 (where att1 = "a") @@ -374,6 +383,14 @@ class VerificationSuiteTest extends WordSpec with Matchers with SparkContextSpec // filtered rows 4, 5, 6 (where item < 4) val rowLevel6 = resultData.select(expectedColumn6).collect().map(r => r.getAs[Any](0)) assert(Seq(true, false, false, true, true, true).sameElements(rowLevel6)) + + // filtered rows 4, 5, 6 (where item < 4) + val rowLevel7 = resultData.select(expectedColumn7).collect().map(r => r.getAs[Any](0)) + assert(Seq(true, true, true, true, true, true).sameElements(rowLevel7)) + + // filtered rows 4, 5, 6 (where item < 4) row 1 and 3 are the same -> not unique + val rowLevel8 = resultData.select(expectedColumn8).collect().map(r => r.getAs[Any](0)) + assert(Seq(false, true, false, true, true, true).sameElements(rowLevel8)) } "generate a result that contains row-level results with null for filtered rows" in withSparkSession { session => @@ -398,6 +415,12 @@ class VerificationSuiteTest extends WordSpec with Matchers with SparkContextSpec val patternMatch = new Check(CheckLevel.Error, "rule6") .hasPattern("att2", """(^f)""".r, analyzerOptions = analyzerOptions) .where("item < 4") + val isPrimaryKey = new Check(CheckLevel.Error, "rule7") + .isPrimaryKey("item", None, analyzerOptions = analyzerOptions) + .where("item < 4") + val uniqueValueRatio = new Check(CheckLevel.Error, "rule8") + .hasUniqueValueRatio(Seq("att1"), _ >= 0.5, analyzerOptions = analyzerOptions) + .where("item < 4") val expectedColumn1 = completeness.description val expectedColumn2 = uniqueness.description @@ -405,6 +428,8 @@ class VerificationSuiteTest extends WordSpec with Matchers with SparkContextSpec val expectedColumn4 = min.description val expectedColumn5 = max.description val expectedColumn6 = patternMatch.description + val expectedColumn7 = isPrimaryKey.description + val expectedColumn8 = uniqueValueRatio.description val suite = new VerificationSuite().onData(data) .addCheck(completeness) @@ -413,6 +438,8 @@ class VerificationSuiteTest extends WordSpec with Matchers with SparkContextSpec .addCheck(min) .addCheck(max) .addCheck(patternMatch) + .addCheck(isPrimaryKey) + .addCheck(uniqueValueRatio) val result: VerificationResult = suite.run() @@ -422,7 +449,7 @@ class VerificationSuiteTest extends WordSpec with Matchers with SparkContextSpec resultData.show(false) val expectedColumns: Set[String] = data.columns.toSet + expectedColumn1 + expectedColumn2 + expectedColumn3 + - expectedColumn4 + expectedColumn5 + expectedColumn6 + expectedColumn4 + expectedColumn5 + expectedColumn6 + expectedColumn7 + expectedColumn8 assert(resultData.columns.toSet == expectedColumns) val rowLevel1 = resultData.select(expectedColumn1).collect().map(r => r.getAs[Any](0)) @@ -446,6 +473,14 @@ class VerificationSuiteTest extends WordSpec with Matchers with SparkContextSpec // filtered rows 4, 5, 6 (where item < 4) val rowLevel6 = resultData.select(expectedColumn6).collect().map(r => r.getAs[Any](0)) assert(Seq(true, false, false, null, null, null).sameElements(rowLevel6)) + + // filtered rows 4, 5, 6 (where item < 4) + val rowLevel7 = resultData.select(expectedColumn7).collect().map(r => r.getAs[Any](0)) + assert(Seq(true, true, true, null, null, null).sameElements(rowLevel7)) + + // filtered rows 4, 5, 6 (where item < 4) row 1 and 3 are the same -> not unique + val rowLevel8 = resultData.select(expectedColumn8).collect().map(r => r.getAs[Any](0)) + assert(Seq(false, true, false, null, null, 
null).sameElements(rowLevel8)) } "generate a result that contains compliance row-level results " in withSparkSession { session => From 9fa50964ae1c0ca83b97c0c3c115e06920466168 Mon Sep 17 00:00:00 2001 From: Edward Cho <114528615+eycho-am@users.noreply.github.com> Date: Mon, 26 Feb 2024 10:08:27 -0500 Subject: [PATCH 08/24] Fix bug in MinLength and MaxLength analyzers where given the NullBehavior.EmptyString option, the where filter wasn't properly applied (#538) --- .../amazon/deequ/analyzers/MaxLength.scala | 3 +- .../amazon/deequ/analyzers/MinLength.scala | 3 +- .../amazon/deequ/VerificationSuiteTest.scala | 41 +++++++++++++++++++ .../com/amazon/deequ/checks/CheckTest.scala | 15 +++++++ 4 files changed, 60 insertions(+), 2 deletions(-) diff --git a/src/main/scala/com/amazon/deequ/analyzers/MaxLength.scala b/src/main/scala/com/amazon/deequ/analyzers/MaxLength.scala index 19c9ca9b7..3b55d4fa6 100644 --- a/src/main/scala/com/amazon/deequ/analyzers/MaxLength.scala +++ b/src/main/scala/com/amazon/deequ/analyzers/MaxLength.scala @@ -75,7 +75,8 @@ case class MaxLength(column: String, where: Option[String] = None, analyzerOptio case NullBehavior.Fail => conditionSelectionGivenColumn(colLengths, Option(isNullCheck), replaceWith = Double.MaxValue) case NullBehavior.EmptyString => - length(conditionSelectionGivenColumn(col(column), Option(isNullCheck), replaceWith = "")).cast(DoubleType) + // Empty String is 0 length string + conditionSelectionGivenColumn(colLengths, Option(isNullCheck), replaceWith = 0.0).cast(DoubleType) case _ => colLengths } diff --git a/src/main/scala/com/amazon/deequ/analyzers/MinLength.scala b/src/main/scala/com/amazon/deequ/analyzers/MinLength.scala index c155cca94..a6627d2d2 100644 --- a/src/main/scala/com/amazon/deequ/analyzers/MinLength.scala +++ b/src/main/scala/com/amazon/deequ/analyzers/MinLength.scala @@ -75,7 +75,8 @@ case class MinLength(column: String, where: Option[String] = None, analyzerOptio case NullBehavior.Fail => conditionSelectionGivenColumn(colLengths, Option(isNullCheck), replaceWith = Double.MinValue) case NullBehavior.EmptyString => - length(conditionSelectionGivenColumn(col(column), Option(isNullCheck), replaceWith = "")).cast(DoubleType) + // Empty String is 0 length string + conditionSelectionGivenColumn(colLengths, Option(isNullCheck), replaceWith = 0.0).cast(DoubleType) case _ => colLengths } diff --git a/src/test/scala/com/amazon/deequ/VerificationSuiteTest.scala b/src/test/scala/com/amazon/deequ/VerificationSuiteTest.scala index e02ab0cf4..580be29b2 100644 --- a/src/test/scala/com/amazon/deequ/VerificationSuiteTest.scala +++ b/src/test/scala/com/amazon/deequ/VerificationSuiteTest.scala @@ -600,6 +600,47 @@ class VerificationSuiteTest extends WordSpec with Matchers with SparkContextSpec assert(Seq(false, null, false, true, null, true).sameElements(rowLevel4)) } + "confirm that minLength and maxLength properly filters with nullBehavior empty" in withSparkSession { session => + val data = getDfCompleteAndInCompleteColumnsAndVarLengthStrings(session) + + val minLength = new Check(CheckLevel.Error, "rule1") + .hasMinLength("item", _ > 3, + analyzerOptions = Option(AnalyzerOptions(NullBehavior.EmptyString, FilteredRowOutcome.NULL))) + .where("val1 > 3") + val maxLength = new Check(CheckLevel.Error, "rule2") + .hasMaxLength("item", _ <= 3, + analyzerOptions = Option(AnalyzerOptions(NullBehavior.EmptyString, FilteredRowOutcome.NULL))) + .where("val1 < 4") + + val expectedColumn1 = minLength.description + val expectedColumn2 = maxLength.description + + 
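// --- Illustrative sketch (annotation, not part of the patch) ---
// Plain-Spark sketch of the semantics targeted by this fix, using made-up data: under
// NullBehavior.EmptyString a null value is scored as a zero-length string, while rows
// removed by a where clause are handled separately by the filtered-row treatment rather
// than being rewritten to "" (which is what previously bypassed the filter).
import org.apache.spark.sql.{Column, SparkSession}
import org.apache.spark.sql.functions.{col, length, lit, when}
import org.apache.spark.sql.types.DoubleType

object EmptyStringLengthSketch {
  // Null values count as length 0.0; non-null values keep their real length.
  def lengthTreatingNullAsEmpty(column: String): Column =
    when(col(column).isNull, lit(0.0)).otherwise(length(col(column)).cast(DoubleType))

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("length-sketch").getOrCreate()
    import spark.implicits._
    val df = Seq(Option("abc"), None, Option("a")).toDF("item")
    // Expected lengths: 3.0, 0.0 (null treated as empty), 1.0
    df.select(lengthTreatingNullAsEmpty("item").as("len")).show()
    spark.stop()
  }
}
// --- end of sketch ---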
val suite = new VerificationSuite().onData(data) + .addCheck(minLength) + .addCheck(maxLength) + + val result: VerificationResult = suite.run() + + val resultData = VerificationResult.rowLevelResultsAsDataFrame(session, result, data) + + resultData.show(false) + + val expectedColumns: Set[String] = + data.columns.toSet + expectedColumn1 + expectedColumn2 + assert(resultData.columns.toSet == expectedColumns) + + // Unfiltered rows are all true - overall result should be Success + assert(result.status == CheckStatus.Success) + + // minLength > 3 would fail for the first three rows (length 1,2,3) + val rowLevel1 = resultData.select(expectedColumn1).collect().map(r => r.getAs[Any](0)) + assert(Seq(null, null, null, true, true, true).sameElements(rowLevel1)) + + // maxLength <= 3 would fail for the last three rows (length 4,5,6) + val rowLevel2 = resultData.select(expectedColumn2).collect().map(r => r.getAs[Any](0)) + assert(Seq(true, true, true, null, null, null).sameElements(rowLevel2)) + } + "generate a result that contains length row-level results with nullBehavior fail" in withSparkSession { session => val data = getDfCompleteAndInCompleteColumnsAndVarLengthStrings(session) diff --git a/src/test/scala/com/amazon/deequ/checks/CheckTest.scala b/src/test/scala/com/amazon/deequ/checks/CheckTest.scala index 096e330b8..e768d0f37 100644 --- a/src/test/scala/com/amazon/deequ/checks/CheckTest.scala +++ b/src/test/scala/com/amazon/deequ/checks/CheckTest.scala @@ -702,6 +702,21 @@ class CheckTest extends AnyWordSpec with Matchers with SparkContextSpec with Fix assertSuccess(baseCheck.hasMaxLength("att1", _ == 4.0), context) } + "yield correct results for minimum and maximum length stats with where clause" in + withSparkSession { sparkSession => + val emptyNulLBehavior = Option(AnalyzerOptions(NullBehavior.EmptyString)) + val baseCheck = Check(CheckLevel.Error, description = "a description") + val df = getDfCompleteAndInCompleteColumnsAndVarLengthStrings(sparkSession) + val context = AnalysisRunner.onData(df) + .addAnalyzers(Seq(MinLength("item", Option("val1 > 3"), emptyNulLBehavior), + MaxLength("item", Option("val1 <= 3"), emptyNulLBehavior))).run() + + assertSuccess(baseCheck.hasMinLength("item", _ >= 4.0, analyzerOptions = emptyNulLBehavior) + .where("val1 > 3"), context) // 1 without where clause + assertSuccess(baseCheck.hasMaxLength("item", _ <= 3.0, analyzerOptions = emptyNulLBehavior) + .where("val1 <= 3"), context) // 6 without where clause + } + "work on regular expression patterns for E-Mails" in withSparkSession { sparkSession => val col = "some" val df = dataFrameWithColumn(col, StringType, sparkSession, Row("someone@somewhere.org"), From c7fa635483ee3318d957ee751e48c3dbc752be64 Mon Sep 17 00:00:00 2001 From: rdsharma26 <65777064+rdsharma26@users.noreply.github.com> Date: Fri, 8 Mar 2024 13:05:35 -0500 Subject: [PATCH 09/24] [Min/Max] Apply filtered row behavior at the row level evaluation (#543) * [Min/Max] Apply filtered row behavior at the row level evaluation - This changes from applying the behavior at the analyzer level. It allows us to prevent the usage of MinValue/MaxValue as placeholder values for filtered rows. * Improved the separation of null rows, based on their source - Whether the outcome for a row is null because of being filtered out or due to the target column being null, is now stored in the outcome column itself. 
- We could have reused the placeholder value to find out if a row was originally filtered out, but that would not work if the actual value in the row was the same originally. * Mark filtered rows as true We recently fixed the outcome of filtered rows and made them default to true instead of false, which was a bug earlier. This change maintains that behavior. * Added null behavior - empty string to match block Not having it can cause match error. --- .../com/amazon/deequ/analyzers/Analyzer.scala | 48 +++-- .../com/amazon/deequ/analyzers/Maximum.scala | 25 +-- .../com/amazon/deequ/analyzers/Minimum.scala | 24 +-- .../amazon/deequ/constraints/Constraint.scala | 61 +++++- .../amazon/deequ/VerificationSuiteTest.scala | 195 +++++++++++++++++- .../amazon/deequ/analyzers/MaximumTest.scala | 51 ++--- .../amazon/deequ/analyzers/MinimumTest.scala | 58 ++---- 7 files changed, 320 insertions(+), 142 deletions(-) diff --git a/src/main/scala/com/amazon/deequ/analyzers/Analyzer.scala b/src/main/scala/com/amazon/deequ/analyzers/Analyzer.scala index dd5fb07e9..bc05adb51 100644 --- a/src/main/scala/com/amazon/deequ/analyzers/Analyzer.scala +++ b/src/main/scala/com/amazon/deequ/analyzers/Analyzer.scala @@ -262,8 +262,13 @@ case class NumMatchesAndCount(numMatches: Long, count: Long, override val fullCo } } +sealed trait RowLevelStatusSource { def name: String } +case object InScopeData extends RowLevelStatusSource { val name = "InScopeData" } +case object FilteredData extends RowLevelStatusSource { val name = "FilteredData" } + case class AnalyzerOptions(nullBehavior: NullBehavior = NullBehavior.Ignore, filteredRow: FilteredRowOutcome = FilteredRowOutcome.TRUE) + object NullBehavior extends Enumeration { type NullBehavior = Value val Ignore, EmptyString, Fail = Value @@ -478,34 +483,34 @@ private[deequ] object Analyzers { if (columns.size == 1) Entity.Column else Entity.Multicolumn } - def conditionalSelection(selection: String, where: Option[String]): Column = { - conditionalSelection(col(selection), where) + def conditionalSelection(selection: String, condition: Option[String]): Column = { + conditionalSelection(col(selection), condition) } - def conditionSelectionGivenColumn(selection: Column, where: Option[Column], replaceWith: Double): Column = { - where + def conditionSelectionGivenColumn(selection: Column, condition: Option[Column], replaceWith: Double): Column = { + condition .map { condition => when(condition, replaceWith).otherwise(selection) } .getOrElse(selection) } - def conditionSelectionGivenColumn(selection: Column, where: Option[Column], replaceWith: String): Column = { - where + def conditionSelectionGivenColumn(selection: Column, condition: Option[Column], replaceWith: String): Column = { + condition .map { condition => when(condition, replaceWith).otherwise(selection) } .getOrElse(selection) } - def conditionSelectionGivenColumn(selection: Column, where: Option[Column], replaceWith: Boolean): Column = { - where + def conditionSelectionGivenColumn(selection: Column, condition: Option[Column], replaceWith: Boolean): Column = { + condition .map { condition => when(condition, replaceWith).otherwise(selection) } .getOrElse(selection) } - def conditionalSelection(selection: Column, where: Option[String], replaceWith: Double): Column = { - conditionSelectionGivenColumn(selection, where.map(expr), replaceWith) + def conditionalSelection(selection: Column, condition: Option[String], replaceWith: Double): Column = { + conditionSelectionGivenColumn(selection, condition.map(expr), replaceWith) } - def 
conditionalSelection(selection: Column, where: Option[String], replaceWith: String): Column = { - conditionSelectionGivenColumn(selection, where.map(expr), replaceWith) + def conditionalSelection(selection: Column, condition: Option[String], replaceWith: String): Column = { + conditionSelectionGivenColumn(selection, condition.map(expr), replaceWith) } def conditionalSelection(selection: Column, condition: Option[String]): Column = { @@ -513,11 +518,20 @@ private[deequ] object Analyzers { conditionalSelectionFromColumns(selection, conditionColumn) } - def conditionalSelectionFilteredFromColumns( - selection: Column, - conditionColumn: Option[Column], - filterTreatment: FilteredRowOutcome) - : Column = { + def conditionalSelectionWithAugmentedOutcome(selection: Column, + condition: Option[String], + replaceWith: Double): Column = { + val origSelection = array(lit(InScopeData.name).as("source"), selection.as("selection")) + val filteredSelection = array(lit(FilteredData.name).as("source"), lit(replaceWith).as("selection")) + + condition + .map { cond => when(not(expr(cond)), filteredSelection).otherwise(origSelection) } + .getOrElse(origSelection) + } + + def conditionalSelectionFilteredFromColumns(selection: Column, + conditionColumn: Option[Column], + filterTreatment: FilteredRowOutcome): Column = { conditionColumn .map { condition => when(not(condition), filterTreatment.getExpression).when(condition, selection) diff --git a/src/main/scala/com/amazon/deequ/analyzers/Maximum.scala b/src/main/scala/com/amazon/deequ/analyzers/Maximum.scala index c5cc33f94..abeee6d98 100644 --- a/src/main/scala/com/amazon/deequ/analyzers/Maximum.scala +++ b/src/main/scala/com/amazon/deequ/analyzers/Maximum.scala @@ -18,13 +18,11 @@ package com.amazon.deequ.analyzers import com.amazon.deequ.analyzers.Preconditions.{hasColumn, isNumeric} import org.apache.spark.sql.{Column, Row} -import org.apache.spark.sql.functions.{col, max} +import org.apache.spark.sql.functions.{col, element_at, max} import org.apache.spark.sql.types.{DoubleType, StructType} import Analyzers._ import com.amazon.deequ.metrics.FullColumn import com.google.common.annotations.VisibleForTesting -import org.apache.spark.sql.functions.expr -import org.apache.spark.sql.functions.not case class MaxState(maxValue: Double, override val fullColumn: Option[Column] = None) extends DoubleValuedState[MaxState] with FullColumn { @@ -43,13 +41,12 @@ case class Maximum(column: String, where: Option[String] = None, analyzerOptions with FilterableAnalyzer { override def aggregationFunctions(): Seq[Column] = { - max(criterion) :: Nil + max(element_at(criterion, 2).cast(DoubleType)) :: Nil } override def fromAggregationResult(result: Row, offset: Int): Option[MaxState] = { - ifNoNullsIn(result, offset) { _ => - MaxState(result.getDouble(offset), Some(rowLevelResults)) + MaxState(result.getDouble(offset), Some(criterion)) } } @@ -60,19 +57,5 @@ case class Maximum(column: String, where: Option[String] = None, analyzerOptions override def filterCondition: Option[String] = where @VisibleForTesting - private def criterion: Column = conditionalSelection(column, where).cast(DoubleType) - - private[deequ] def rowLevelResults: Column = { - val filteredRowOutcome = getRowLevelFilterTreatment(analyzerOptions) - val whereNotCondition = where.map { expression => not(expr(expression)) } - - filteredRowOutcome match { - case FilteredRowOutcome.TRUE => - conditionSelectionGivenColumn(col(column), whereNotCondition, replaceWith = Double.MinValue).cast(DoubleType) - case _ => - 
criterion - } - } - + private def criterion: Column = conditionalSelectionWithAugmentedOutcome(col(column), where, Double.MinValue) } - diff --git a/src/main/scala/com/amazon/deequ/analyzers/Minimum.scala b/src/main/scala/com/amazon/deequ/analyzers/Minimum.scala index 18640dc12..b17507fc5 100644 --- a/src/main/scala/com/amazon/deequ/analyzers/Minimum.scala +++ b/src/main/scala/com/amazon/deequ/analyzers/Minimum.scala @@ -18,13 +18,11 @@ package com.amazon.deequ.analyzers import com.amazon.deequ.analyzers.Preconditions.{hasColumn, isNumeric} import org.apache.spark.sql.{Column, Row} -import org.apache.spark.sql.functions.{col, min} +import org.apache.spark.sql.functions.{col, element_at, min} import org.apache.spark.sql.types.{DoubleType, StructType} import Analyzers._ import com.amazon.deequ.metrics.FullColumn import com.google.common.annotations.VisibleForTesting -import org.apache.spark.sql.functions.expr -import org.apache.spark.sql.functions.not case class MinState(minValue: Double, override val fullColumn: Option[Column] = None) extends DoubleValuedState[MinState] with FullColumn { @@ -43,12 +41,12 @@ case class Minimum(column: String, where: Option[String] = None, analyzerOptions with FilterableAnalyzer { override def aggregationFunctions(): Seq[Column] = { - min(criterion) :: Nil + min(element_at(criterion, 2).cast(DoubleType)) :: Nil } override def fromAggregationResult(result: Row, offset: Int): Option[MinState] = { ifNoNullsIn(result, offset) { _ => - MinState(result.getDouble(offset), Some(rowLevelResults)) + MinState(result.getDouble(offset), Some(criterion)) } } @@ -59,19 +57,5 @@ case class Minimum(column: String, where: Option[String] = None, analyzerOptions override def filterCondition: Option[String] = where @VisibleForTesting - private def criterion: Column = { - conditionalSelection(column, where).cast(DoubleType) - } - - private[deequ] def rowLevelResults: Column = { - val filteredRowOutcome = getRowLevelFilterTreatment(analyzerOptions) - val whereNotCondition = where.map { expression => not(expr(expression)) } - - filteredRowOutcome match { - case FilteredRowOutcome.TRUE => - conditionSelectionGivenColumn(col(column), whereNotCondition, replaceWith = Double.MaxValue).cast(DoubleType) - case _ => - criterion - } - } + private def criterion: Column = conditionalSelectionWithAugmentedOutcome(col(column), where, Double.MaxValue) } diff --git a/src/main/scala/com/amazon/deequ/constraints/Constraint.scala b/src/main/scala/com/amazon/deequ/constraints/Constraint.scala index 46a12332b..ecf804b6d 100644 --- a/src/main/scala/com/amazon/deequ/constraints/Constraint.scala +++ b/src/main/scala/com/amazon/deequ/constraints/Constraint.scala @@ -664,7 +664,9 @@ object Constraint { val constraint = AnalysisBasedConstraint[MinState, Double, Double](minimum, assertion, hint = hint) - val sparkAssertion = org.apache.spark.sql.functions.udf(assertion) + val updatedAssertion = getUpdatedRowLevelAssertion(assertion, minimum.analyzerOptions) + val sparkAssertion = org.apache.spark.sql.functions.udf(updatedAssertion) + new RowLevelAssertedConstraint( constraint, s"MinimumConstraint($minimum)", @@ -698,7 +700,9 @@ object Constraint { val constraint = AnalysisBasedConstraint[MaxState, Double, Double](maximum, assertion, hint = hint) - val sparkAssertion = org.apache.spark.sql.functions.udf(assertion) + val updatedAssertion = getUpdatedRowLevelAssertion(assertion, maximum.analyzerOptions) + val sparkAssertion = org.apache.spark.sql.functions.udf(updatedAssertion) + new RowLevelAssertedConstraint( 
constraint, s"MaximumConstraint($maximum)", @@ -951,6 +955,59 @@ object Constraint { .getOrElse(0.0) } + + /* + * This function is used by Min/Max constraints and it creates a new assertion based on the provided assertion. + * Each value in the outcome column is an array of 2 elements. + * - The first element is a string that denotes whether the row is the filtered dataset or not. + * - The second element is the actual value of the constraint's target column. + * The result of the final assertion is one of 3 states: true, false or null. + * These values can be tuned using the analyzer options. + * Null outcome allows the consumer to decide how to treat filtered rows or rows that were originally null. + */ + private[this] def getUpdatedRowLevelAssertion(assertion: Double => Boolean, + analyzerOptions: Option[AnalyzerOptions]) + : Seq[String] => java.lang.Boolean = { + (d: Seq[String]) => { + val (scope, value) = (d.head, Option(d.last).map(_.toDouble)) + + def inScopeRowOutcome(value: Option[Double]): java.lang.Boolean = { + if (value.isDefined) { + // If value is defined, run it through the assertion. + assertion(value.get) + } else { + // If value is not defined (value is null), apply NullBehavior. + analyzerOptions match { + case Some(opts) => + opts.nullBehavior match { + case NullBehavior.Fail => false + case NullBehavior.Ignore | NullBehavior.EmptyString => null + } + case None => null + } + } + } + + def filteredRowOutcome: java.lang.Boolean = { + analyzerOptions match { + case Some(opts) => + opts.filteredRow match { + case FilteredRowOutcome.TRUE => true + case FilteredRowOutcome.NULL => null + } + // https://github.com/awslabs/deequ/issues/530 + // Filtered rows should be marked as true by default. + // They can be set to null using the FilteredRowOutcome option. 
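// --- Illustrative sketch (annotation, not part of the patch) ---
// Standalone simplification of the mapping implemented by the updated row-level
// assertion: the first array element records whether the row was in scope or filtered
// out, the second carries the value. The two boolean flags below stand in for the
// NullBehavior / FilteredRowOutcome options and are placeholders, not deequ API.
object RowLevelOutcomeSketch {
  def outcome(row: Seq[String],
              assertion: Double => Boolean,
              nullValuesFail: Boolean,
              filteredRowsAreNull: Boolean): java.lang.Boolean = {
    val (source, value) = (row.head, Option(row.last).map(_.toDouble))
    source match {
      case "FilteredData" => if (filteredRowsAreNull) null else true
      case "InScopeData" => value match {
        case Some(v) => assertion(v)                       // run the user assertion
        case None => if (nullValuesFail) false else null   // NullBehavior.Fail vs Ignore
      }
    }
  }

  def main(args: Array[String]): Unit = {
    val isSix: Double => Boolean = _ == 6.0
    println(outcome(Seq("InScopeData", "6.0"), isSix, nullValuesFail = false, filteredRowsAreNull = true))  // true
    println(outcome(Seq("InScopeData", null), isSix, nullValuesFail = false, filteredRowsAreNull = true))   // null
    println(outcome(Seq("FilteredData", null), isSix, nullValuesFail = false, filteredRowsAreNull = false)) // true
  }
}
// --- end of sketch ---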
+ case None => true + } + } + + scope match { + case FilteredData.name => filteredRowOutcome + case InScopeData.name => inScopeRowOutcome(value) + } + } + } } /** diff --git a/src/test/scala/com/amazon/deequ/VerificationSuiteTest.scala b/src/test/scala/com/amazon/deequ/VerificationSuiteTest.scala index 580be29b2..b1fd4596e 100644 --- a/src/test/scala/com/amazon/deequ/VerificationSuiteTest.scala +++ b/src/test/scala/com/amazon/deequ/VerificationSuiteTest.scala @@ -41,13 +41,10 @@ import org.scalamock.scalatest.MockFactory import org.scalatest.Matchers import org.scalatest.WordSpec - - class VerificationSuiteTest extends WordSpec with Matchers with SparkContextSpec with FixtureSupport with MockFactory { "Verification Suite" should { - "return the correct verification status regardless of the order of checks" in withSparkSession { sparkSession => @@ -1741,6 +1738,189 @@ class VerificationSuiteTest extends WordSpec with Matchers with SparkContextSpec } } + "Verification Suite with == based Min/Max checks and filtered row behavior" should { + val col1 = "att1" + val col2 = "att2" + val col3 = "att3" + + val check1Description = "equality-check-1" + val check2Description = "equality-check-2" + val check3Description = "equality-check-3" + + val check1WhereClause = "att1 > 3" + val check2WhereClause = "att2 > 4" + val check3WhereClause = "att3 = 0" + + def mkEqualityCheck1(analyzerOptions: AnalyzerOptions): Check = new Check(CheckLevel.Error, check1Description) + .hasMin(col1, _ == 4, analyzerOptions = Some(analyzerOptions)).where(check1WhereClause) + .hasMax(col1, _ == 4, analyzerOptions = Some(analyzerOptions)).where(check1WhereClause) + + def mkEqualityCheck2(analyzerOptions: AnalyzerOptions): Check = new Check(CheckLevel.Error, check2Description) + .hasMin(col2, _ == 7, analyzerOptions = Some(analyzerOptions)).where(check2WhereClause) + .hasMax(col2, _ == 7, analyzerOptions = Some(analyzerOptions)).where(check2WhereClause) + + def mkEqualityCheck3(analyzerOptions: AnalyzerOptions): Check = new Check(CheckLevel.Error, check3Description) + .hasMin(col3, _ == 0, analyzerOptions = Some(analyzerOptions)).where(check3WhereClause) + .hasMax(col3, _ == 0, analyzerOptions = Some(analyzerOptions)).where(check3WhereClause) + + def getRowLevelResults(df: DataFrame): Seq[java.lang.Boolean] = + df.collect().map { r => r.getAs[java.lang.Boolean](0) }.toSeq + + def assertCheckResults(verificationResult: VerificationResult): Unit = { + val passResult = verificationResult.checkResults + + val equalityCheck1Result = passResult.values.find(_.check.description == check1Description) + val equalityCheck2Result = passResult.values.find(_.check.description == check2Description) + val equalityCheck3Result = passResult.values.find(_.check.description == check3Description) + + assert(equalityCheck1Result.isDefined && equalityCheck1Result.get.status == CheckStatus.Error) + assert(equalityCheck2Result.isDefined && equalityCheck2Result.get.status == CheckStatus.Error) + assert(equalityCheck3Result.isDefined && equalityCheck3Result.get.status == CheckStatus.Success) + } + + def assertRowLevelResults(rowLevelResults: DataFrame, + analyzerOptions: AnalyzerOptions): Unit = { + val equalityCheck1Results = getRowLevelResults(rowLevelResults.select(check1Description)) + val equalityCheck2Results = getRowLevelResults(rowLevelResults.select(check2Description)) + val equalityCheck3Results = getRowLevelResults(rowLevelResults.select(check3Description)) + + val filteredOutcome: java.lang.Boolean = analyzerOptions.filteredRow match { + 
case FilteredRowOutcome.TRUE => true + case FilteredRowOutcome.NULL => null + } + + assert(equalityCheck1Results == Seq(filteredOutcome, filteredOutcome, filteredOutcome, true, false, false)) + assert(equalityCheck2Results == Seq(filteredOutcome, filteredOutcome, filteredOutcome, false, false, true)) + assert(equalityCheck3Results == Seq(true, true, true, filteredOutcome, filteredOutcome, filteredOutcome)) + } + + def assertMetrics(metricsDF: DataFrame): Unit = { + val metricsMap = getMetricsAsMap(metricsDF) + assert(metricsMap(s"$col1|Minimum (where: $check1WhereClause)") == 4.0) + assert(metricsMap(s"$col1|Maximum (where: $check1WhereClause)") == 6.0) + assert(metricsMap(s"$col2|Minimum (where: $check2WhereClause)") == 5.0) + assert(metricsMap(s"$col2|Maximum (where: $check2WhereClause)") == 7.0) + assert(metricsMap(s"$col3|Minimum (where: $check3WhereClause)") == 0.0) + assert(metricsMap(s"$col3|Maximum (where: $check3WhereClause)") == 0.0) + } + + "mark filtered rows as null" in withSparkSession { + sparkSession => + val df = getDfWithNumericValues(sparkSession) + val analyzerOptions = AnalyzerOptions(filteredRow = FilteredRowOutcome.NULL) + + val equalityCheck1 = mkEqualityCheck1(analyzerOptions) + val equalityCheck2 = mkEqualityCheck2(analyzerOptions) + val equalityCheck3 = mkEqualityCheck3(analyzerOptions) + + val verificationResult = VerificationSuite() + .onData(df) + .addChecks(Seq(equalityCheck1, equalityCheck2, equalityCheck3)) + .run() + + val rowLevelResultsDF = VerificationResult.rowLevelResultsAsDataFrame(sparkSession, verificationResult, df) + val metricsDF = VerificationResult.successMetricsAsDataFrame(sparkSession, verificationResult) + + assertCheckResults(verificationResult) + assertRowLevelResults(rowLevelResultsDF, analyzerOptions) + assertMetrics(metricsDF) + } + + "mark filtered rows as true" in withSparkSession { + sparkSession => + val df = getDfWithNumericValues(sparkSession) + val analyzerOptions = AnalyzerOptions(filteredRow = FilteredRowOutcome.TRUE) + + val equalityCheck1 = mkEqualityCheck1(analyzerOptions) + val equalityCheck2 = mkEqualityCheck2(analyzerOptions) + val equalityCheck3 = mkEqualityCheck3(analyzerOptions) + + val verificationResult = VerificationSuite() + .onData(df) + .addChecks(Seq(equalityCheck1, equalityCheck2, equalityCheck3)) + .run() + + val rowLevelResultsDF = VerificationResult.rowLevelResultsAsDataFrame(sparkSession, verificationResult, df) + val metricsDF = VerificationResult.successMetricsAsDataFrame(sparkSession, verificationResult) + + assertCheckResults(verificationResult) + assertRowLevelResults(rowLevelResultsDF, analyzerOptions) + assertMetrics(metricsDF) + } + } + + "Verification Suite with == based Min/Max checks and null row behavior" should { + val col = "attNull" + val checkDescription = "equality-check" + def mkEqualityCheck(analyzerOptions: AnalyzerOptions): Check = new Check(CheckLevel.Error, checkDescription) + .hasMin(col, _ == 6, analyzerOptions = Some(analyzerOptions)) + .hasMax(col, _ == 6, analyzerOptions = Some(analyzerOptions)) + + def assertCheckResults(verificationResult: VerificationResult, checkStatus: CheckStatus.Value): Unit = { + val passResult = verificationResult.checkResults + val equalityCheckResult = passResult.values.find(_.check.description == checkDescription) + assert(equalityCheckResult.isDefined && equalityCheckResult.get.status == checkStatus) + } + + def getRowLevelResults(df: DataFrame): Seq[java.lang.Boolean] = + df.collect().map { r => r.getAs[java.lang.Boolean](0) }.toSeq + + def 
assertRowLevelResults(rowLevelResults: DataFrame, + analyzerOptions: AnalyzerOptions): Unit = { + val equalityCheckResults = getRowLevelResults(rowLevelResults.select(checkDescription)) + val nullOutcome: java.lang.Boolean = analyzerOptions.nullBehavior match { + case NullBehavior.Fail => false + case NullBehavior.Ignore => null + } + + assert(equalityCheckResults == Seq(nullOutcome, nullOutcome, nullOutcome, false, true, false)) + } + + def assertMetrics(metricsDF: DataFrame): Unit = { + val metricsMap = getMetricsAsMap(metricsDF) + assert(metricsMap(s"$col|Minimum") == 5.0) + assert(metricsMap(s"$col|Maximum") == 7.0) + } + + "keep non-filtered null rows as null" in withSparkSession { + sparkSession => + val df = getDfWithNumericValues(sparkSession) + val analyzerOptions = AnalyzerOptions(nullBehavior = NullBehavior.Ignore) + val verificationResult = VerificationSuite() + .onData(df) + .addChecks(Seq(mkEqualityCheck(analyzerOptions))) + .run() + + val passResult = verificationResult.checkResults + assertCheckResults(verificationResult, CheckStatus.Error) + + val rowLevelResultsDF = VerificationResult.rowLevelResultsAsDataFrame(sparkSession, verificationResult, df) + assertRowLevelResults(rowLevelResultsDF, analyzerOptions) + + val metricsDF = VerificationResult.successMetricsAsDataFrame(sparkSession, verificationResult) + assertMetrics(metricsDF) + } + + "mark non-filtered null rows as false" in withSparkSession { + sparkSession => + val df = getDfWithNumericValues(sparkSession) + val analyzerOptions = AnalyzerOptions(nullBehavior = NullBehavior.Fail) + val verificationResult = VerificationSuite() + .onData(df) + .addChecks(Seq(mkEqualityCheck(analyzerOptions))) + .run() + + val passResult = verificationResult.checkResults + assertCheckResults(verificationResult, CheckStatus.Error) + + val rowLevelResultsDF = VerificationResult.rowLevelResultsAsDataFrame(sparkSession, verificationResult, df) + assertRowLevelResults(rowLevelResultsDF, analyzerOptions) + + val metricsDF = VerificationResult.successMetricsAsDataFrame(sparkSession, verificationResult) + assertMetrics(metricsDF) + } + } + /** Run anomaly detection using a repository with some previous analysis results for testing */ private[this] def evaluateWithRepositoryWithHistory(test: MetricsRepository => Unit): Unit = { @@ -1765,4 +1945,13 @@ class VerificationSuiteTest extends WordSpec with Matchers with SparkContextSpec private[this] def assertSameRows(dataframeA: DataFrame, dataframeB: DataFrame): Unit = { assert(dataframeA.collect().toSet == dataframeB.collect().toSet) } + + private[this] def getMetricsAsMap(metricsDF: DataFrame): Map[String, Double] = { + metricsDF.collect().map { r => + val colName = r.getAs[String]("instance") + val metricName = r.getAs[String]("name") + val metricValue = r.getAs[Double]("value") + s"$colName|$metricName" -> metricValue + }.toMap + } } diff --git a/src/test/scala/com/amazon/deequ/analyzers/MaximumTest.scala b/src/test/scala/com/amazon/deequ/analyzers/MaximumTest.scala index 6ac90f735..1d13a8dfe 100644 --- a/src/test/scala/com/amazon/deequ/analyzers/MaximumTest.scala +++ b/src/test/scala/com/amazon/deequ/analyzers/MaximumTest.scala @@ -21,10 +21,20 @@ import com.amazon.deequ.SparkContextSpec import com.amazon.deequ.metrics.DoubleMetric import com.amazon.deequ.metrics.FullColumn import com.amazon.deequ.utils.FixtureSupport +import org.apache.spark.sql.Column +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.Row +import org.apache.spark.sql.functions.element_at +import 
org.apache.spark.sql.types.DoubleType import org.scalatest.matchers.should.Matchers import org.scalatest.wordspec.AnyWordSpec class MaximumTest extends AnyWordSpec with Matchers with SparkContextSpec with FixtureSupport { + private val tempColName = "new" + + private def getValuesDF(df: DataFrame, outcomeColumn: Column): Seq[Row] = { + df.withColumn(tempColName, element_at(outcomeColumn, 2).cast(DoubleType)).collect() + } "Max" should { "return row-level results for columns" in withSparkSession { session => @@ -35,8 +45,8 @@ class MaximumTest extends AnyWordSpec with Matchers with SparkContextSpec with F val state: Option[MaxState] = att1Maximum.computeStateFrom(data) val metric: DoubleMetric with FullColumn = att1Maximum.computeMetricFrom(state) - data.withColumn("new", metric.fullColumn.get).collect().map(_.getAs[Double]("new")) shouldBe - Seq(1.0, 2.0, 3.0, 4.0, 5.0, 6.0) + val values = getValuesDF(data, metric.fullColumn.get).map(_.getAs[Double](tempColName)) + values shouldBe Seq(1.0, 2.0, 3.0, 4.0, 5.0, 6.0) } "return row-level results for columns with null" in withSparkSession { session => @@ -47,40 +57,9 @@ class MaximumTest extends AnyWordSpec with Matchers with SparkContextSpec with F val state: Option[MaxState] = att1Maximum.computeStateFrom(data) val metric: DoubleMetric with FullColumn = att1Maximum.computeMetricFrom(state) - data.withColumn("new", metric.fullColumn.get).collect().map(r => - if (r == null) null else r.getAs[Double]("new")) shouldBe - Seq(null, null, null, 5.0, 6.0, 7.0) - } - - "return row-level results for columns with where clause filtered as true" in withSparkSession { session => - - val data = getDfWithNumericValues(session) - - val att1Maximum = Maximum("att1", Option("item < 4")) - val state: Option[MaxState] = att1Maximum.computeStateFrom(data, Option("item < 4")) - val metric: DoubleMetric with FullColumn = att1Maximum.computeMetricFrom(state) - - val result = data.withColumn("new", metric.fullColumn.get) - result.show(false) - result.collect().map(r => - if (r == null) null else r.getAs[Double]("new")) shouldBe - Seq(1.0, 2.0, 3.0, Double.MinValue, Double.MinValue, Double.MinValue) - } - - "return row-level results for columns with where clause filtered as null" in withSparkSession { session => - - val data = getDfWithNumericValues(session) - - val att1Maximum = Maximum("att1", Option("item < 4"), - Option(AnalyzerOptions(filteredRow = FilteredRowOutcome.NULL))) - val state: Option[MaxState] = att1Maximum.computeStateFrom(data, Option("item < 4")) - val metric: DoubleMetric with FullColumn = att1Maximum.computeMetricFrom(state) - - val result = data.withColumn("new", metric.fullColumn.get) - result.show(false) - result.collect().map(r => - if (r == null) null else r.getAs[Double]("new")) shouldBe - Seq(1.0, 2.0, 3.0, null, null, null) + val values = getValuesDF(data, metric.fullColumn.get) + .map(r => if (r == null) null else r.getAs[Double](tempColName)) + values shouldBe Seq(null, null, null, 5.0, 6.0, 7.0) } } } diff --git a/src/test/scala/com/amazon/deequ/analyzers/MinimumTest.scala b/src/test/scala/com/amazon/deequ/analyzers/MinimumTest.scala index 435542e8c..8d1d2dd63 100644 --- a/src/test/scala/com/amazon/deequ/analyzers/MinimumTest.scala +++ b/src/test/scala/com/amazon/deequ/analyzers/MinimumTest.scala @@ -14,77 +14,49 @@ * */ - package com.amazon.deequ.analyzers import com.amazon.deequ.SparkContextSpec import com.amazon.deequ.metrics.DoubleMetric import com.amazon.deequ.metrics.FullColumn import com.amazon.deequ.utils.FixtureSupport +import 
org.apache.spark.sql.Column +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.Row +import org.apache.spark.sql.functions.element_at +import org.apache.spark.sql.types.DoubleType import org.scalatest.matchers.should.Matchers import org.scalatest.wordspec.AnyWordSpec class MinimumTest extends AnyWordSpec with Matchers with SparkContextSpec with FixtureSupport { + private val tempColName = "new" + + private def getValuesDF(df: DataFrame, outcomeColumn: Column): Seq[Row] = { + df.withColumn(tempColName, element_at(outcomeColumn, 2).cast(DoubleType)).collect() + } "Min" should { "return row-level results for columns" in withSparkSession { session => - val data = getDfWithNumericValues(session) val att1Minimum = Minimum("att1") val state: Option[MinState] = att1Minimum.computeStateFrom(data) val metric: DoubleMetric with FullColumn = att1Minimum.computeMetricFrom(state) - - data.withColumn("new", metric.fullColumn.get).collect().map(_.getAs[Double]("new")) shouldBe - Seq(1.0, 2.0, 3.0, 4.0, 5.0, 6.0) + val values = getValuesDF(data, metric.fullColumn.get).map(_.getAs[Double](tempColName)) + values shouldBe Seq(1.0, 2.0, 3.0, 4.0, 5.0, 6.0) } "return row-level results for columns with null" in withSparkSession { session => - val data = getDfWithNumericValues(session) val att1Minimum = Minimum("attNull") val state: Option[MinState] = att1Minimum.computeStateFrom(data) val metric: DoubleMetric with FullColumn = att1Minimum.computeMetricFrom(state) - data.withColumn("new", metric.fullColumn.get).collect().map(r => - if (r == null) null else r.getAs[Double]("new")) shouldBe - Seq(null, null, null, 5.0, 6.0, 7.0) - } - - "return row-level results for columns with where clause filtered as true" in withSparkSession { session => - - val data = getDfWithNumericValues(session) - - val att1Minimum = Minimum("att1", Option("item < 4")) - val state: Option[MinState] = att1Minimum.computeStateFrom(data, Option("item < 4")) - print(state) - val metric: DoubleMetric with FullColumn = att1Minimum.computeMetricFrom(state) - - val result = data.withColumn("new", metric.fullColumn.get) - result.show(false) - result.collect().map(r => - if (r == null) null else r.getAs[Double]("new")) shouldBe - Seq(1.0, 2.0, 3.0, Double.MaxValue, Double.MaxValue, Double.MaxValue) - } - - "return row-level results for columns with where clause filtered as null" in withSparkSession { session => - - val data = getDfWithNumericValues(session) - - val att1Minimum = Minimum("att1", Option("item < 4"), - Option(AnalyzerOptions(filteredRow = FilteredRowOutcome.NULL))) - val state: Option[MinState] = att1Minimum.computeStateFrom(data, Option("item < 4")) - print(state) - val metric: DoubleMetric with FullColumn = att1Minimum.computeMetricFrom(state) - - val result = data.withColumn("new", metric.fullColumn.get) - result.show(false) - result.collect().map(r => - if (r == null) null else r.getAs[Double]("new")) shouldBe - Seq(1.0, 2.0, 3.0, null, null, null) + val values = getValuesDF(data, metric.fullColumn.get) + .map(r => if (r == null) null else r.getAs[Double](tempColName)) + values shouldBe Seq(null, null, null, 5.0, 6.0, 7.0) } } - } From abb54bcf3e6675a13593943291f66693af740209 Mon Sep 17 00:00:00 2001 From: rdsharma26 <65777064+rdsharma26@users.noreply.github.com> Date: Sun, 10 Mar 2024 17:08:37 -0400 Subject: [PATCH 10/24] [MinLength/MaxLength] Apply filtered row behavior at the row level evaluation (#547) * [MinLength/MaxLength] Apply filtered row behavior at the row level evaluation - For certain scenarios, the 
filtered row behavior for MinLength and MaxLength was not working correctly. - For example, when using both minLength and maxLength constraints in a single check, and with both using == as an assertion. This was resulting in the row level outcome of the filtered rows to be false. This was because we were replacing values for filtered rows for Min to MaxValue and for Max to MinValue. But a number could not equal both at the same time. - Updated the logic of the row level assertion to MinLength/MaxLength to match what was done for Min/Max. --- .../com/amazon/deequ/analyzers/Analyzer.scala | 11 +- .../amazon/deequ/analyzers/MaxLength.scala | 47 ++-- .../com/amazon/deequ/analyzers/Maximum.scala | 20 +- .../amazon/deequ/analyzers/MinLength.scala | 47 ++-- .../com/amazon/deequ/analyzers/Minimum.scala | 20 +- .../amazon/deequ/constraints/Constraint.scala | 61 +++-- .../amazon/deequ/VerificationSuiteTest.scala | 208 +++++++++++++++++- .../deequ/analyzers/MaxLengthTest.scala | 114 ++-------- .../amazon/deequ/analyzers/MaximumTest.scala | 2 - .../deequ/analyzers/MinLengthTest.scala | 117 ++-------- .../amazon/deequ/utils/FixtureSupport.scala | 10 +- 11 files changed, 367 insertions(+), 290 deletions(-) diff --git a/src/main/scala/com/amazon/deequ/analyzers/Analyzer.scala b/src/main/scala/com/amazon/deequ/analyzers/Analyzer.scala index bc05adb51..9367f31e1 100644 --- a/src/main/scala/com/amazon/deequ/analyzers/Analyzer.scala +++ b/src/main/scala/com/amazon/deequ/analyzers/Analyzer.scala @@ -519,13 +519,16 @@ private[deequ] object Analyzers { } def conditionalSelectionWithAugmentedOutcome(selection: Column, - condition: Option[String], - replaceWith: Double): Column = { + condition: Option[String]): Column = { val origSelection = array(lit(InScopeData.name).as("source"), selection.as("selection")) - val filteredSelection = array(lit(FilteredData.name).as("source"), lit(replaceWith).as("selection")) + + // The 2nd value in the array is set to null, but it can be set to anything. + // The value is not used to evaluate the row level outcome for filtered rows (to true/null). + // That decision is made using the 1st value which is set to "FilteredData" here. 
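// --- Illustrative sketch (annotation, not part of the patch) ---
// Plain-Spark sketch of the augmented outcome column this commit describes, with made-up
// column names and data: in-scope rows carry ["InScopeData", <value>], filtered rows
// carry ["FilteredData", null], so no MinValue/MaxValue placeholder can collide with
// real data when an == based assertion is evaluated. This is a simplification, not the
// deequ implementation itself.
import org.apache.spark.sql.{Column, SparkSession}
import org.apache.spark.sql.functions.{array, col, element_at, expr, lit, max, when}
import org.apache.spark.sql.types.DoubleType

object AugmentedOutcomeSketch {
  def augmented(selection: Column, where: Option[String]): Column = {
    val inScope = array(lit("InScopeData"), selection.cast("string"))
    val filtered = array(lit("FilteredData"), lit(null).cast("string"))
    where.map(cond => when(expr(cond), inScope).otherwise(filtered)).getOrElse(inScope)
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("outcome-sketch").getOrCreate()
    import spark.implicits._
    val df = Seq((1, 10.0), (2, 20.0), (3, 30.0)).toDF("id", "value")
    val outcome = augmented(col("value"), Some("id < 3"))
    // First element drives the row-level true/false/null decision; second feeds the metric.
    df.select(element_at(outcome, 1).as("source"),
              element_at(outcome, 2).cast(DoubleType).as("value")).show()
    // The aggregate only sees the second element; filtered rows contribute null.
    df.select(max(element_at(outcome, 2).cast(DoubleType))).show() // 20.0
    spark.stop()
  }
}
// --- end of sketch ---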
+ val filteredSelection = array(lit(FilteredData.name).as("source"), lit(null).as("selection")) condition - .map { cond => when(not(expr(cond)), filteredSelection).otherwise(origSelection) } + .map { cond => when(expr(cond), origSelection).otherwise(filteredSelection) } .getOrElse(origSelection) } diff --git a/src/main/scala/com/amazon/deequ/analyzers/MaxLength.scala b/src/main/scala/com/amazon/deequ/analyzers/MaxLength.scala index 3b55d4fa6..141d92fb5 100644 --- a/src/main/scala/com/amazon/deequ/analyzers/MaxLength.scala +++ b/src/main/scala/com/amazon/deequ/analyzers/MaxLength.scala @@ -23,10 +23,11 @@ import com.amazon.deequ.analyzers.Preconditions.isString import org.apache.spark.sql.Column import org.apache.spark.sql.Row import org.apache.spark.sql.functions.col -import org.apache.spark.sql.functions.expr +import org.apache.spark.sql.functions.element_at import org.apache.spark.sql.functions.length +import org.apache.spark.sql.functions.lit import org.apache.spark.sql.functions.max -import org.apache.spark.sql.functions.not +import org.apache.spark.sql.functions.when import org.apache.spark.sql.types.DoubleType import org.apache.spark.sql.types.StructType @@ -35,12 +36,15 @@ case class MaxLength(column: String, where: Option[String] = None, analyzerOptio with FilterableAnalyzer { override def aggregationFunctions(): Seq[Column] = { - max(criterion) :: Nil + // The criterion returns a column where each row contains an array of 2 elements. + // The first element of the array is a string that indicates if the row is "in scope" or "filtered" out. + // The second element is the value used for calculating the metric. We use "element_at" to extract it. + max(element_at(criterion, 2).cast(DoubleType)) :: Nil } override def fromAggregationResult(result: Row, offset: Int): Option[MaxState] = { ifNoNullsIn(result, offset) { _ => - MaxState(result.getDouble(offset), Some(rowLevelResults)) + MaxState(result.getDouble(offset), Some(criterion)) } } @@ -51,35 +55,16 @@ case class MaxLength(column: String, where: Option[String] = None, analyzerOptio override def filterCondition: Option[String] = where private[deequ] def criterion: Column = { - transformColForNullBehavior - } - - private[deequ] def rowLevelResults: Column = { - transformColForFilteredRow(criterion) - } - - private def transformColForFilteredRow(col: Column): Column = { - val whereNotCondition = where.map { expression => not(expr(expression)) } - getRowLevelFilterTreatment(analyzerOptions) match { - case FilteredRowOutcome.TRUE => - conditionSelectionGivenColumn(col, whereNotCondition, replaceWith = Double.MinValue) - case _ => - conditionSelectionGivenColumn(col, whereNotCondition, replaceWith = null) - } - } - - private def transformColForNullBehavior: Column = { val isNullCheck = col(column).isNull - val colLengths: Column = length(conditionalSelection(column, where)).cast(DoubleType) - getNullBehavior match { - case NullBehavior.Fail => - conditionSelectionGivenColumn(colLengths, Option(isNullCheck), replaceWith = Double.MaxValue) - case NullBehavior.EmptyString => - // Empty String is 0 length string - conditionSelectionGivenColumn(colLengths, Option(isNullCheck), replaceWith = 0.0).cast(DoubleType) - case _ => - colLengths + val colLength = length(col(column)).cast(DoubleType) + val updatedColumn = getNullBehavior match { + case NullBehavior.Fail => when(isNullCheck, Double.MaxValue).otherwise(colLength) + // Empty String is 0 length string + case NullBehavior.EmptyString => when(isNullCheck, lit(0.0)).otherwise(colLength) + case 
NullBehavior.Ignore => colLength } + + conditionalSelectionWithAugmentedOutcome(updatedColumn, where) } private def getNullBehavior: NullBehavior = { diff --git a/src/main/scala/com/amazon/deequ/analyzers/Maximum.scala b/src/main/scala/com/amazon/deequ/analyzers/Maximum.scala index abeee6d98..1e52a7ae4 100644 --- a/src/main/scala/com/amazon/deequ/analyzers/Maximum.scala +++ b/src/main/scala/com/amazon/deequ/analyzers/Maximum.scala @@ -16,13 +16,18 @@ package com.amazon.deequ.analyzers -import com.amazon.deequ.analyzers.Preconditions.{hasColumn, isNumeric} -import org.apache.spark.sql.{Column, Row} -import org.apache.spark.sql.functions.{col, element_at, max} -import org.apache.spark.sql.types.{DoubleType, StructType} -import Analyzers._ +import com.amazon.deequ.analyzers.Analyzers._ +import com.amazon.deequ.analyzers.Preconditions.hasColumn +import com.amazon.deequ.analyzers.Preconditions.isNumeric import com.amazon.deequ.metrics.FullColumn import com.google.common.annotations.VisibleForTesting +import org.apache.spark.sql.functions.col +import org.apache.spark.sql.functions.element_at +import org.apache.spark.sql.functions.max +import org.apache.spark.sql.types.DoubleType +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.Column +import org.apache.spark.sql.Row case class MaxState(maxValue: Double, override val fullColumn: Option[Column] = None) extends DoubleValuedState[MaxState] with FullColumn { @@ -41,6 +46,9 @@ case class Maximum(column: String, where: Option[String] = None, analyzerOptions with FilterableAnalyzer { override def aggregationFunctions(): Seq[Column] = { + // The criterion returns a column where each row contains an array of 2 elements. + // The first element of the array is a string that indicates if the row is "in scope" or "filtered" out. + // The second element is the value used for calculating the metric. We use "element_at" to extract it. max(element_at(criterion, 2).cast(DoubleType)) :: Nil } @@ -57,5 +65,5 @@ case class Maximum(column: String, where: Option[String] = None, analyzerOptions override def filterCondition: Option[String] = where @VisibleForTesting - private def criterion: Column = conditionalSelectionWithAugmentedOutcome(col(column), where, Double.MinValue) + private def criterion: Column = conditionalSelectionWithAugmentedOutcome(col(column), where) } diff --git a/src/main/scala/com/amazon/deequ/analyzers/MinLength.scala b/src/main/scala/com/amazon/deequ/analyzers/MinLength.scala index a6627d2d2..ddc4497b2 100644 --- a/src/main/scala/com/amazon/deequ/analyzers/MinLength.scala +++ b/src/main/scala/com/amazon/deequ/analyzers/MinLength.scala @@ -23,10 +23,11 @@ import com.amazon.deequ.analyzers.Preconditions.isString import org.apache.spark.sql.Column import org.apache.spark.sql.Row import org.apache.spark.sql.functions.col -import org.apache.spark.sql.functions.expr +import org.apache.spark.sql.functions.element_at import org.apache.spark.sql.functions.length +import org.apache.spark.sql.functions.lit import org.apache.spark.sql.functions.min -import org.apache.spark.sql.functions.not +import org.apache.spark.sql.functions.when import org.apache.spark.sql.types.DoubleType import org.apache.spark.sql.types.StructType @@ -35,12 +36,15 @@ case class MinLength(column: String, where: Option[String] = None, analyzerOptio with FilterableAnalyzer { override def aggregationFunctions(): Seq[Column] = { - min(criterion) :: Nil + // The criterion returns a column where each row contains an array of 2 elements. 
+ // The first element of the array is a string that indicates if the row is "in scope" or "filtered" out. + // The second element is the value used for calculating the metric. We use "element_at" to extract it. + min(element_at(criterion, 2).cast(DoubleType)) :: Nil } override def fromAggregationResult(result: Row, offset: Int): Option[MinState] = { ifNoNullsIn(result, offset) { _ => - MinState(result.getDouble(offset), Some(rowLevelResults)) + MinState(result.getDouble(offset), Some(criterion)) } } @@ -51,35 +55,16 @@ case class MinLength(column: String, where: Option[String] = None, analyzerOptio override def filterCondition: Option[String] = where private[deequ] def criterion: Column = { - transformColForNullBehavior - } - - private[deequ] def rowLevelResults: Column = { - transformColForFilteredRow(criterion) - } - - private def transformColForFilteredRow(col: Column): Column = { - val whereNotCondition = where.map { expression => not(expr(expression)) } - getRowLevelFilterTreatment(analyzerOptions) match { - case FilteredRowOutcome.TRUE => - conditionSelectionGivenColumn(col, whereNotCondition, replaceWith = Double.MaxValue) - case _ => - conditionSelectionGivenColumn(col, whereNotCondition, replaceWith = null) - } - } - - private def transformColForNullBehavior: Column = { val isNullCheck = col(column).isNull - val colLengths: Column = length(conditionalSelection(column, where)).cast(DoubleType) - getNullBehavior match { - case NullBehavior.Fail => - conditionSelectionGivenColumn(colLengths, Option(isNullCheck), replaceWith = Double.MinValue) - case NullBehavior.EmptyString => - // Empty String is 0 length string - conditionSelectionGivenColumn(colLengths, Option(isNullCheck), replaceWith = 0.0).cast(DoubleType) - case _ => - colLengths + val colLength = length(col(column)).cast(DoubleType) + val updatedColumn = getNullBehavior match { + case NullBehavior.Fail => when(isNullCheck, Double.MinValue).otherwise(colLength) + // Empty String is 0 length string + case NullBehavior.EmptyString => when(isNullCheck, lit(0.0)).otherwise(colLength) + case NullBehavior.Ignore => colLength } + + conditionalSelectionWithAugmentedOutcome(updatedColumn, where) } private def getNullBehavior: NullBehavior = { diff --git a/src/main/scala/com/amazon/deequ/analyzers/Minimum.scala b/src/main/scala/com/amazon/deequ/analyzers/Minimum.scala index b17507fc5..701ae0f05 100644 --- a/src/main/scala/com/amazon/deequ/analyzers/Minimum.scala +++ b/src/main/scala/com/amazon/deequ/analyzers/Minimum.scala @@ -16,13 +16,18 @@ package com.amazon.deequ.analyzers -import com.amazon.deequ.analyzers.Preconditions.{hasColumn, isNumeric} -import org.apache.spark.sql.{Column, Row} -import org.apache.spark.sql.functions.{col, element_at, min} -import org.apache.spark.sql.types.{DoubleType, StructType} -import Analyzers._ +import com.amazon.deequ.analyzers.Analyzers._ +import com.amazon.deequ.analyzers.Preconditions.hasColumn +import com.amazon.deequ.analyzers.Preconditions.isNumeric import com.amazon.deequ.metrics.FullColumn import com.google.common.annotations.VisibleForTesting +import org.apache.spark.sql.functions.col +import org.apache.spark.sql.functions.element_at +import org.apache.spark.sql.functions.min +import org.apache.spark.sql.types.DoubleType +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.Column +import org.apache.spark.sql.Row case class MinState(minValue: Double, override val fullColumn: Option[Column] = None) extends DoubleValuedState[MinState] with FullColumn { @@ -41,6 +46,9 @@ 
case class Minimum(column: String, where: Option[String] = None, analyzerOptions with FilterableAnalyzer { override def aggregationFunctions(): Seq[Column] = { + // The criterion returns a column where each row contains an array of 2 elements. + // The first element of the array is a string that indicates if the row is "in scope" or "filtered" out. + // The second element is the value used for calculating the metric. We use "element_at" to extract it. min(element_at(criterion, 2).cast(DoubleType)) :: Nil } @@ -57,5 +65,5 @@ case class Minimum(column: String, where: Option[String] = None, analyzerOptions override def filterCondition: Option[String] = where @VisibleForTesting - private def criterion: Column = conditionalSelectionWithAugmentedOutcome(col(column), where, Double.MaxValue) + private def criterion: Column = conditionalSelectionWithAugmentedOutcome(col(column), where) } diff --git a/src/main/scala/com/amazon/deequ/constraints/Constraint.scala b/src/main/scala/com/amazon/deequ/constraints/Constraint.scala index ecf804b6d..79695c878 100644 --- a/src/main/scala/com/amazon/deequ/constraints/Constraint.scala +++ b/src/main/scala/com/amazon/deequ/constraints/Constraint.scala @@ -593,7 +593,8 @@ object Constraint { val constraint = AnalysisBasedConstraint[MaxState, Double, Double](maxLength, assertion, hint = hint) - val sparkAssertion = org.apache.spark.sql.functions.udf(assertion) + val updatedAssertion = getUpdatedRowLevelAssertionForLengthConstraint(assertion, maxLength.analyzerOptions) + val sparkAssertion = org.apache.spark.sql.functions.udf(updatedAssertion) new RowLevelAssertedConstraint( constraint, @@ -628,7 +629,8 @@ object Constraint { val constraint = AnalysisBasedConstraint[MinState, Double, Double](minLength, assertion, hint = hint) - val sparkAssertion = org.apache.spark.sql.functions.udf(assertion) + val updatedAssertion = getUpdatedRowLevelAssertionForLengthConstraint(assertion, minLength.analyzerOptions) + val sparkAssertion = org.apache.spark.sql.functions.udf(updatedAssertion) new RowLevelAssertedConstraint( constraint, @@ -988,26 +990,57 @@ object Constraint { } } - def filteredRowOutcome: java.lang.Boolean = { - analyzerOptions match { - case Some(opts) => - opts.filteredRow match { - case FilteredRowOutcome.TRUE => true - case FilteredRowOutcome.NULL => null - } - // https://github.com/awslabs/deequ/issues/530 - // Filtered rows should be marked as true by default. - // They can be set to null using the FilteredRowOutcome option. - case None => true + scope match { + case FilteredData.name => filteredRowOutcome(analyzerOptions) + case InScopeData.name => inScopeRowOutcome(value) + } + } + } + + private[this] def getUpdatedRowLevelAssertionForLengthConstraint(assertion: Double => Boolean, + analyzerOptions: Option[AnalyzerOptions]) + : Seq[String] => java.lang.Boolean = { + (d: Seq[String]) => { + val (scope, value) = (d.head, Option(d.last).map(_.toDouble)) + + def inScopeRowOutcome(value: Option[Double]): java.lang.Boolean = { + if (value.isDefined) { + // If value is defined, run it through the assertion. + assertion(value.get) + } else { + // If value is not defined (value is null), apply NullBehavior. 
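+          // EmptyString evaluates the missing value as a zero-length string, Fail marks the row as failing,
+          // and Ignore leaves the row-level outcome null.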
+ analyzerOptions match { + case Some(opts) => + opts.nullBehavior match { + case NullBehavior.EmptyString => assertion(0.0) + case NullBehavior.Fail => false + case NullBehavior.Ignore => null + } + case None => null + } } } scope match { - case FilteredData.name => filteredRowOutcome + case FilteredData.name => filteredRowOutcome(analyzerOptions) case InScopeData.name => inScopeRowOutcome(value) } } } + + private def filteredRowOutcome(analyzerOptions: Option[AnalyzerOptions]): java.lang.Boolean = { + analyzerOptions match { + case Some(opts) => + opts.filteredRow match { + case FilteredRowOutcome.TRUE => true + case FilteredRowOutcome.NULL => null + } + // https://github.com/awslabs/deequ/issues/530 + // Filtered rows should be marked as true by default. + // They can be set to null using the FilteredRowOutcome option. + case None => true + } + } } /** diff --git a/src/test/scala/com/amazon/deequ/VerificationSuiteTest.scala b/src/test/scala/com/amazon/deequ/VerificationSuiteTest.scala index b1fd4596e..bdcddde71 100644 --- a/src/test/scala/com/amazon/deequ/VerificationSuiteTest.scala +++ b/src/test/scala/com/amazon/deequ/VerificationSuiteTest.scala @@ -1891,7 +1891,6 @@ class VerificationSuiteTest extends WordSpec with Matchers with SparkContextSpec .addChecks(Seq(mkEqualityCheck(analyzerOptions))) .run() - val passResult = verificationResult.checkResults assertCheckResults(verificationResult, CheckStatus.Error) val rowLevelResultsDF = VerificationResult.rowLevelResultsAsDataFrame(sparkSession, verificationResult, df) @@ -1910,7 +1909,6 @@ class VerificationSuiteTest extends WordSpec with Matchers with SparkContextSpec .addChecks(Seq(mkEqualityCheck(analyzerOptions))) .run() - val passResult = verificationResult.checkResults assertCheckResults(verificationResult, CheckStatus.Error) val rowLevelResultsDF = VerificationResult.rowLevelResultsAsDataFrame(sparkSession, verificationResult, df) @@ -1921,6 +1919,212 @@ class VerificationSuiteTest extends WordSpec with Matchers with SparkContextSpec } } + "Verification Suite with ==/!= based MinLength/MaxLength checks and filtered row behavior" should { + val col1 = "Company" + val col2 = "ZipCode" + val col3 = "City" + + val check1Description = "length-equality-check-1" + val check2Description = "length-equality-check-2" + val check3Description = "length-equality-check-3" + + val check1WhereClause = "ID > 2" + val check2WhereClause = "ID in (1, 2, 3)" + val check3WhereClause = "ID <= 2" + + def mkLengthCheck1(analyzerOptions: AnalyzerOptions): Check = new Check(CheckLevel.Error, check1Description) + .hasMinLength(col1, _ == 8, analyzerOptions = Some(analyzerOptions)).where(check1WhereClause) + .hasMaxLength(col1, _ == 8, analyzerOptions = Some(analyzerOptions)).where(check1WhereClause) + + def mkLengthCheck2(analyzerOptions: AnalyzerOptions): Check = new Check(CheckLevel.Error, check2Description) + .hasMinLength(col2, _ == 4, analyzerOptions = Some(analyzerOptions)).where(check2WhereClause) + .hasMaxLength(col2, _ == 4, analyzerOptions = Some(analyzerOptions)).where(check2WhereClause) + + def mkLengthCheck3(analyzerOptions: AnalyzerOptions): Check = new Check(CheckLevel.Error, check3Description) + .hasMinLength(col3, _ != 0, analyzerOptions = Some(analyzerOptions)).where(check3WhereClause) + .hasMaxLength(col3, _ != 0, analyzerOptions = Some(analyzerOptions)).where(check3WhereClause) + + def getRowLevelResults(df: DataFrame): Seq[java.lang.Boolean] = + df.collect().map { r => r.getAs[java.lang.Boolean](0) }.toSeq + + def 
assertCheckResults(verificationResult: VerificationResult): Unit = { + val passResult = verificationResult.checkResults + + val equalityCheck1Result = passResult.values.find(_.check.description == check1Description) + val equalityCheck2Result = passResult.values.find(_.check.description == check2Description) + val equalityCheck3Result = passResult.values.find(_.check.description == check3Description) + + assert(equalityCheck1Result.isDefined && equalityCheck1Result.get.status == CheckStatus.Success) + assert(equalityCheck2Result.isDefined && equalityCheck2Result.get.status == CheckStatus.Error) + assert(equalityCheck3Result.isDefined && equalityCheck3Result.get.status == CheckStatus.Success) + } + + def assertRowLevelResults(rowLevelResults: DataFrame, + analyzerOptions: AnalyzerOptions): Unit = { + val equalityCheck1Results = getRowLevelResults(rowLevelResults.select(check1Description)) + val equalityCheck2Results = getRowLevelResults(rowLevelResults.select(check2Description)) + val equalityCheck3Results = getRowLevelResults(rowLevelResults.select(check3Description)) + + val filteredOutcome: java.lang.Boolean = analyzerOptions.filteredRow match { + case FilteredRowOutcome.TRUE => true + case FilteredRowOutcome.NULL => null + } + + assert(equalityCheck1Results == Seq(filteredOutcome, filteredOutcome, true, true)) + assert(equalityCheck2Results == Seq(false, false, false, filteredOutcome)) + assert(equalityCheck3Results == Seq(true, true, filteredOutcome, filteredOutcome)) + } + + def assertMetrics(metricsDF: DataFrame): Unit = { + val metricsMap = getMetricsAsMap(metricsDF) + assert(metricsMap(s"$col1|MinLength (where: $check1WhereClause)") == 8.0) + assert(metricsMap(s"$col1|MaxLength (where: $check1WhereClause)") == 8.0) + assert(metricsMap(s"$col2|MinLength (where: $check2WhereClause)") == 0.0) + assert(metricsMap(s"$col2|MaxLength (where: $check2WhereClause)") == 5.0) + assert(metricsMap(s"$col3|MinLength (where: $check3WhereClause)") == 11.0) + assert(metricsMap(s"$col3|MaxLength (where: $check3WhereClause)") == 11.0) + } + + "mark filtered rows as null" in withSparkSession { + sparkSession => + val df = getDfForWhereClause(sparkSession) + val analyzerOptions = AnalyzerOptions( + nullBehavior = NullBehavior.EmptyString, filteredRow = FilteredRowOutcome.NULL + ) + + val equalityCheck1 = mkLengthCheck1(analyzerOptions) + val equalityCheck2 = mkLengthCheck2(analyzerOptions) + val equalityCheck3 = mkLengthCheck3(analyzerOptions) + + val verificationResult = VerificationSuite() + .onData(df) + .addChecks(Seq(equalityCheck1, equalityCheck2, equalityCheck3)) + .run() + + val rowLevelResultsDF = VerificationResult.rowLevelResultsAsDataFrame(sparkSession, verificationResult, df) + val metricsDF = VerificationResult.successMetricsAsDataFrame(sparkSession, verificationResult) + + assertCheckResults(verificationResult) + assertRowLevelResults(rowLevelResultsDF, analyzerOptions) + assertMetrics(metricsDF) + } + + "mark filtered rows as true" in withSparkSession { + sparkSession => + val df = getDfForWhereClause(sparkSession) + val analyzerOptions = AnalyzerOptions( + nullBehavior = NullBehavior.EmptyString, filteredRow = FilteredRowOutcome.TRUE + ) + + val equalityCheck1 = mkLengthCheck1(analyzerOptions) + val equalityCheck2 = mkLengthCheck2(analyzerOptions) + val equalityCheck3 = mkLengthCheck3(analyzerOptions) + + val verificationResult = VerificationSuite() + .onData(df) + .addChecks(Seq(equalityCheck1, equalityCheck2, equalityCheck3)) + .run() + + val rowLevelResultsDF = 
VerificationResult.rowLevelResultsAsDataFrame(sparkSession, verificationResult, df) + val metricsDF = VerificationResult.successMetricsAsDataFrame(sparkSession, verificationResult) + + assertCheckResults(verificationResult) + assertRowLevelResults(rowLevelResultsDF, analyzerOptions) + assertMetrics(metricsDF) + } + } + + "Verification Suite with ==/!= based MinLength/MaxLength checks and null row behavior" should { + val col = "City" + val checkDescription = "length-check" + val assertion = (d: Double) => d >= 0.0 && d <= 8.0 + + def mkLengthCheck(analyzerOptions: AnalyzerOptions): Check = new Check(CheckLevel.Error, checkDescription) + .hasMinLength(col, assertion, analyzerOptions = Some(analyzerOptions)) + .hasMaxLength(col, assertion, analyzerOptions = Some(analyzerOptions)) + + def assertCheckResults(verificationResult: VerificationResult, checkStatus: CheckStatus.Value): Unit = { + val passResult = verificationResult.checkResults + val equalityCheckResult = passResult.values.find(_.check.description == checkDescription) + assert(equalityCheckResult.isDefined && equalityCheckResult.get.status == checkStatus) + } + + def getRowLevelResults(df: DataFrame): Seq[java.lang.Boolean] = + df.collect().map { r => r.getAs[java.lang.Boolean](0) }.toSeq + + def assertRowLevelResults(rowLevelResults: DataFrame, + analyzerOptions: AnalyzerOptions): Unit = { + val equalityCheckResults = getRowLevelResults(rowLevelResults.select(checkDescription)) + val nullOutcome: java.lang.Boolean = analyzerOptions.nullBehavior match { + case NullBehavior.Fail => false + case NullBehavior.Ignore => null + case NullBehavior.EmptyString => true + } + + assert(equalityCheckResults == Seq(false, false, nullOutcome, true)) + } + + def assertMetrics(metricsDF: DataFrame, minLength: Double, maxLength: Double): Unit = { + val metricsMap = getMetricsAsMap(metricsDF) + assert(metricsMap(s"$col|MinLength") == minLength) + assert(metricsMap(s"$col|MaxLength") == maxLength) + } + + "keep non-filtered null rows as null" in withSparkSession { + sparkSession => + val df = getDfForWhereClause(sparkSession) + val analyzerOptions = AnalyzerOptions(nullBehavior = NullBehavior.Ignore) + val verificationResult = VerificationSuite() + .onData(df) + .addCheck(mkLengthCheck(analyzerOptions)) + .run() + + assertCheckResults(verificationResult, CheckStatus.Error) + + val rowLevelResultsDF = VerificationResult.rowLevelResultsAsDataFrame(sparkSession, verificationResult, df) + assertRowLevelResults(rowLevelResultsDF, analyzerOptions) + + val metricsDF = VerificationResult.successMetricsAsDataFrame(sparkSession, verificationResult) + assertMetrics(metricsDF, 8.0, 11.0) + } + + "mark non-filtered null rows as false" in withSparkSession { + sparkSession => + val df = getDfForWhereClause(sparkSession) + val analyzerOptions = AnalyzerOptions(nullBehavior = NullBehavior.Fail) + val verificationResult = VerificationSuite() + .onData(df) + .addCheck(mkLengthCheck(analyzerOptions)) + .run() + + assertCheckResults(verificationResult, CheckStatus.Error) + + val rowLevelResultsDF = VerificationResult.rowLevelResultsAsDataFrame(sparkSession, verificationResult, df) + assertRowLevelResults(rowLevelResultsDF, analyzerOptions) + + val metricsDF = VerificationResult.successMetricsAsDataFrame(sparkSession, verificationResult) + assertMetrics(metricsDF, Double.MinValue, Double.MaxValue) + } + + "mark non-filtered null rows as empty string" in withSparkSession { + sparkSession => + val df = getDfForWhereClause(sparkSession) + val analyzerOptions = 
AnalyzerOptions(nullBehavior = NullBehavior.EmptyString) + val verificationResult = VerificationSuite() + .onData(df) + .addCheck(mkLengthCheck(analyzerOptions)) + .run() + + assertCheckResults(verificationResult, CheckStatus.Error) + + val rowLevelResultsDF = VerificationResult.rowLevelResultsAsDataFrame(sparkSession, verificationResult, df) + assertRowLevelResults(rowLevelResultsDF, analyzerOptions) + + val metricsDF = VerificationResult.successMetricsAsDataFrame(sparkSession, verificationResult) + assertMetrics(metricsDF, 0.0, 11.0) + } + } + /** Run anomaly detection using a repository with some previous analysis results for testing */ private[this] def evaluateWithRepositoryWithHistory(test: MetricsRepository => Unit): Unit = { diff --git a/src/test/scala/com/amazon/deequ/analyzers/MaxLengthTest.scala b/src/test/scala/com/amazon/deequ/analyzers/MaxLengthTest.scala index 456bd4c67..fd302a4da 100644 --- a/src/test/scala/com/amazon/deequ/analyzers/MaxLengthTest.scala +++ b/src/test/scala/com/amazon/deequ/analyzers/MaxLengthTest.scala @@ -13,145 +13,73 @@ * permissions and limitations under the License. * */ + package com.amazon.deequ.analyzers import com.amazon.deequ.SparkContextSpec import com.amazon.deequ.metrics.DoubleMetric import com.amazon.deequ.metrics.FullColumn import com.amazon.deequ.utils.FixtureSupport +import org.apache.spark.sql.Column +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.Row +import org.apache.spark.sql.functions.element_at +import org.apache.spark.sql.types.DoubleType import org.scalatest.matchers.should.Matchers import org.scalatest.wordspec.AnyWordSpec class MaxLengthTest extends AnyWordSpec with Matchers with SparkContextSpec with FixtureSupport { + private val tempColName = "new" + + private def getValuesDF(df: DataFrame, outcomeColumn: Column): Seq[Row] = { + df.withColumn(tempColName, element_at(outcomeColumn, 2).cast(DoubleType)).collect() + } "MaxLength" should { "return row-level results for non-null columns" in withSparkSession { session => - val data = getDfWithStringColumns(session) val countryLength = MaxLength("Country") // It's "India" in every row val state: Option[MaxState] = countryLength.computeStateFrom(data) val metric: DoubleMetric with FullColumn = countryLength.computeMetricFrom(state) - data.withColumn("new", metric.fullColumn.get) - .collect().map(_.getAs[Double]("new")) shouldBe Seq(5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0) + val values = getValuesDF(data, metric.fullColumn.get).map(_.getAs[Double](tempColName)) + values shouldBe Seq(5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0) } "return row-level results for null columns" in withSparkSession { session => - val data = getEmptyColumnDataDf(session) val addressLength = MaxLength("att3") // It's null in two rows val state: Option[MaxState] = addressLength.computeStateFrom(data) val metric: DoubleMetric with FullColumn = addressLength.computeMetricFrom(state) - data.withColumn("new", metric.fullColumn.get) - .collect().map(_.getAs[Double]("new")) shouldBe Seq(1.0, 1.0, 0.0, 1.0, 0.0, 1.0) + val values = getValuesDF(data, metric.fullColumn.get) + .map(r => if (r == null) null else r.getAs[Double](tempColName)) + values shouldBe Seq(1.0, 1.0, null, 1.0, null, 1.0) } "return row-level results for null columns with NullBehavior fail option" in withSparkSession { session => - val data = getEmptyColumnDataDf(session) // It's null in two rows - val addressLength = MaxLength("att3") + val addressLength = MaxLength("att3", None, Option(AnalyzerOptions(NullBehavior.Fail))) val state: 
Option[MaxState] = addressLength.computeStateFrom(data) val metric: DoubleMetric with FullColumn = addressLength.computeMetricFrom(state) - data.withColumn("new", metric.fullColumn.get) - .collect().map(r => if (r == null) null else r.getAs[Double]("new")) shouldBe - Seq(1.0, 1.0, null, 1.0, null, 1.0) + val values = getValuesDF(data, metric.fullColumn.get).map(_.getAs[Double](tempColName)) + values shouldBe Seq(1.0, 1.0, Double.MaxValue, 1.0, Double.MaxValue, 1.0) } "return row-level results for blank strings" in withSparkSession { session => - val data = getEmptyColumnDataDf(session) val addressLength = MaxLength("att1") // It's empty strings val state: Option[MaxState] = addressLength.computeStateFrom(data) val metric: DoubleMetric with FullColumn = addressLength.computeMetricFrom(state) - data.withColumn("new", metric.fullColumn.get) - .collect().map(_.getAs[Double]("new")) shouldBe Seq(0.0, 0.0, 0.0, 0.0, 0.0, 0.0) - } - - "return row-level results with NullBehavior fail and filtered as true" in withSparkSession { session => - - val data = getEmptyColumnDataDf(session) - - val addressLength = MaxLength("att3", Option("id < 4"), Option(AnalyzerOptions(NullBehavior.Fail))) - val state: Option[MaxState] = addressLength.computeStateFrom(data) - val metric: DoubleMetric with FullColumn = addressLength.computeMetricFrom(state) - - data.withColumn("new", metric.fullColumn.get) - .collect().map(_.getAs[Any]("new")) shouldBe - Seq(1.0, 1.0, Double.MaxValue, 1.0, Double.MinValue, Double.MinValue) - } - - "return row-level results with NullBehavior fail and filtered as null" in withSparkSession { session => - - val data = getEmptyColumnDataDf(session) - - val addressLength = MaxLength("att3", Option("id < 4"), - Option(AnalyzerOptions(NullBehavior.Fail, FilteredRowOutcome.NULL))) - val state: Option[MaxState] = addressLength.computeStateFrom(data) - val metric: DoubleMetric with FullColumn = addressLength.computeMetricFrom(state) - - data.withColumn("new", metric.fullColumn.get) - .collect().map(_.getAs[Any]("new")) shouldBe Seq(1.0, 1.0, Double.MaxValue, 1.0, null, null) - } - - "return row-level results with NullBehavior empty and filtered as true" in withSparkSession { session => - - val data = getEmptyColumnDataDf(session) - - val addressLength = MaxLength("att3", Option("id < 4"), Option(AnalyzerOptions(NullBehavior.EmptyString))) - val state: Option[MaxState] = addressLength.computeStateFrom(data) - val metric: DoubleMetric with FullColumn = addressLength.computeMetricFrom(state) - - data.withColumn("new", metric.fullColumn.get) - .collect().map(_.getAs[Any]("new")) shouldBe - Seq(1.0, 1.0, 0.0, 1.0, Double.MinValue, Double.MinValue) - } - - "return row-level results with NullBehavior empty and filtered as null" in withSparkSession { session => - - val data = getEmptyColumnDataDf(session) - - val addressLength = MaxLength("att3", Option("id < 4"), - Option(AnalyzerOptions(NullBehavior.EmptyString, FilteredRowOutcome.NULL))) - val state: Option[MaxState] = addressLength.computeStateFrom(data) - val metric: DoubleMetric with FullColumn = addressLength.computeMetricFrom(state) - - data.withColumn("new", metric.fullColumn.get) - .collect().map(_.getAs[Any]("new")) shouldBe Seq(1.0, 1.0, 0.0, 1.0, null, null) - } - - "return row-level results with NullBehavior ignore and filtered as true" in withSparkSession { session => - - val data = getEmptyColumnDataDf(session) - - val addressLength = MaxLength("att3", Option("id < 4"), Option(AnalyzerOptions(NullBehavior.Ignore))) - val state: 
Option[MaxState] = addressLength.computeStateFrom(data) - val metric: DoubleMetric with FullColumn = addressLength.computeMetricFrom(state) - - data.withColumn("new", metric.fullColumn.get) - .collect().map(_.getAs[Any]("new")) shouldBe - Seq(1.0, 1.0, null, 1.0, Double.MinValue, Double.MinValue) - } - - "return row-level results with NullBehavior ignore and filtered as null" in withSparkSession { session => - - val data = getEmptyColumnDataDf(session) - - val addressLength = MaxLength("att3", Option("id < 4"), - Option(AnalyzerOptions(NullBehavior.Ignore, FilteredRowOutcome.NULL))) - val state: Option[MaxState] = addressLength.computeStateFrom(data) - val metric: DoubleMetric with FullColumn = addressLength.computeMetricFrom(state) - - data.withColumn("new", metric.fullColumn.get) - .collect().map(_.getAs[Any]("new")) shouldBe Seq(1.0, 1.0, null, 1.0, null, null) + val values = getValuesDF(data, metric.fullColumn.get).map(_.getAs[Double](tempColName)) + values shouldBe Seq(0.0, 0.0, 0.0, 0.0, 0.0, 0.0) } } - } diff --git a/src/test/scala/com/amazon/deequ/analyzers/MaximumTest.scala b/src/test/scala/com/amazon/deequ/analyzers/MaximumTest.scala index 1d13a8dfe..983e6bca4 100644 --- a/src/test/scala/com/amazon/deequ/analyzers/MaximumTest.scala +++ b/src/test/scala/com/amazon/deequ/analyzers/MaximumTest.scala @@ -14,7 +14,6 @@ * */ - package com.amazon.deequ.analyzers import com.amazon.deequ.SparkContextSpec @@ -50,7 +49,6 @@ class MaximumTest extends AnyWordSpec with Matchers with SparkContextSpec with F } "return row-level results for columns with null" in withSparkSession { session => - val data = getDfWithNumericValues(session) val att1Maximum = Maximum("attNull") diff --git a/src/test/scala/com/amazon/deequ/analyzers/MinLengthTest.scala b/src/test/scala/com/amazon/deequ/analyzers/MinLengthTest.scala index 0f88e377f..23e995741 100644 --- a/src/test/scala/com/amazon/deequ/analyzers/MinLengthTest.scala +++ b/src/test/scala/com/amazon/deequ/analyzers/MinLengthTest.scala @@ -14,45 +14,52 @@ * */ - package com.amazon.deequ.analyzers import com.amazon.deequ.SparkContextSpec import com.amazon.deequ.metrics.DoubleMetric import com.amazon.deequ.metrics.FullColumn import com.amazon.deequ.utils.FixtureSupport +import org.apache.spark.sql.Column +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.Row +import org.apache.spark.sql.functions.element_at +import org.apache.spark.sql.types.DoubleType import org.scalatest.matchers.should.Matchers import org.scalatest.wordspec.AnyWordSpec class MinLengthTest extends AnyWordSpec with Matchers with SparkContextSpec with FixtureSupport { + private val tempColName = "new" + + private def getValuesDF(df: DataFrame, outcomeColumn: Column): Seq[Row] = { + df.withColumn(tempColName, element_at(outcomeColumn, 2).cast(DoubleType)).collect() + } "MinLength" should { "return row-level results for non-null columns" in withSparkSession { session => - val data = getDfWithStringColumns(session) val countryLength = MinLength("Country") // It's "India" in every row val state: Option[MinState] = countryLength.computeStateFrom(data) val metric: DoubleMetric with FullColumn = countryLength.computeMetricFrom(state) - data.withColumn("new", metric.fullColumn.get) - .collect().map(_.getAs[Double]("new")) shouldBe Seq(5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0) + val values = getValuesDF(data, metric.fullColumn.get).map(_.getAs[Double](tempColName)) + values shouldBe Seq(5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0) } "return row-level results for null columns" in 
withSparkSession { session => - val data = getEmptyColumnDataDf(session) val addressLength = MinLength("att3") // It's null in two rows val state: Option[MinState] = addressLength.computeStateFrom(data) val metric: DoubleMetric with FullColumn = addressLength.computeMetricFrom(state) - data.withColumn("new", metric.fullColumn.get) - .collect().map(_.getAs[Double]("new")) shouldBe Seq(1.0, 1.0, 0.0, 1.0, 0.0, 1.0) + val values = getValuesDF(data, metric.fullColumn.get) + .map(r => if (r == null) null else r.getAs[Double](tempColName)) + values shouldBe Seq(1.0, 1.0, null, 1.0, null, 1.0) } "return row-level results for null columns with NullBehavior fail option" in withSparkSession { session => - val data = getEmptyColumnDataDf(session) // It's null in two rows @@ -60,13 +67,11 @@ class MinLengthTest extends AnyWordSpec with Matchers with SparkContextSpec with val state: Option[MinState] = addressLength.computeStateFrom(data) val metric: DoubleMetric with FullColumn = addressLength.computeMetricFrom(state) - data.withColumn("new", metric.fullColumn.get) - .collect().map(_.getAs[Double]("new") - ) shouldBe Seq(1.0, 1.0, Double.MinValue, 1.0, Double.MinValue, 1.0) + val values = getValuesDF(data, metric.fullColumn.get).map(_.getAs[Double](tempColName)) + values shouldBe Seq(1.0, 1.0, Double.MinValue, 1.0, Double.MinValue, 1.0) } "return row-level results for null columns with NullBehavior empty option" in withSparkSession { session => - val data = getEmptyColumnDataDf(session) // It's null in two rows @@ -74,99 +79,19 @@ class MinLengthTest extends AnyWordSpec with Matchers with SparkContextSpec with val state: Option[MinState] = addressLength.computeStateFrom(data) val metric: DoubleMetric with FullColumn = addressLength.computeMetricFrom(state) - data.withColumn("new", metric.fullColumn.get) - .collect().map(_.getAs[Double]("new")) shouldBe Seq(1.0, 1.0, 0.0, 1.0, 0.0, 1.0) + val values = getValuesDF(data, metric.fullColumn.get).map(_.getAs[Double](tempColName)) + values shouldBe Seq(1.0, 1.0, 0.0, 1.0, 0.0, 1.0) } "return row-level results for blank strings" in withSparkSession { session => - val data = getEmptyColumnDataDf(session) val addressLength = MinLength("att1") // It's empty strings val state: Option[MinState] = addressLength.computeStateFrom(data) val metric: DoubleMetric with FullColumn = addressLength.computeMetricFrom(state) - data.withColumn("new", metric.fullColumn.get) - .collect().map(_.getAs[Double]("new")) shouldBe Seq(0.0, 0.0, 0.0, 0.0, 0.0, 0.0) - } - - "return row-level results with NullBehavior fail and filtered as true" in withSparkSession { session => - - val data = getEmptyColumnDataDf(session) - - val addressLength = MinLength("att3", Option("id < 4"), Option(AnalyzerOptions(NullBehavior.Fail))) - val state: Option[MinState] = addressLength.computeStateFrom(data) - val metric: DoubleMetric with FullColumn = addressLength.computeMetricFrom(state) - - data.withColumn("new", metric.fullColumn.get) - .collect().map(_.getAs[Any]("new")) shouldBe - Seq(1.0, 1.0, Double.MinValue, 1.0, Double.MaxValue, Double.MaxValue) - } - - "return row-level results with NullBehavior fail and filtered as null" in withSparkSession { session => - - val data = getEmptyColumnDataDf(session) - - val addressLength = MinLength("att3", Option("id < 4"), - Option(AnalyzerOptions(NullBehavior.Fail, FilteredRowOutcome.NULL))) - val state: Option[MinState] = addressLength.computeStateFrom(data) - val metric: DoubleMetric with FullColumn = addressLength.computeMetricFrom(state) - - 
data.withColumn("new", metric.fullColumn.get) - .collect().map(_.getAs[Any]("new")) shouldBe - Seq(1.0, 1.0, Double.MinValue, 1.0, null, null) - } - - "return row-level results with NullBehavior empty and filtered as true" in withSparkSession { session => - - val data = getEmptyColumnDataDf(session) - - val addressLength = MinLength("att3", Option("id < 4"), Option(AnalyzerOptions(NullBehavior.EmptyString))) - val state: Option[MinState] = addressLength.computeStateFrom(data) - val metric: DoubleMetric with FullColumn = addressLength.computeMetricFrom(state) - - data.withColumn("new", metric.fullColumn.get) - .collect().map(_.getAs[Any]("new")) shouldBe - Seq(1.0, 1.0, 0.0, 1.0, Double.MaxValue, Double.MaxValue) - } - - "return row-level results with NullBehavior empty and filtered as null" in withSparkSession { session => - - val data = getEmptyColumnDataDf(session) - - val addressLength = MinLength("att3", Option("id < 4"), - Option(AnalyzerOptions(NullBehavior.EmptyString, FilteredRowOutcome.NULL))) - val state: Option[MinState] = addressLength.computeStateFrom(data) - val metric: DoubleMetric with FullColumn = addressLength.computeMetricFrom(state) - - data.withColumn("new", metric.fullColumn.get) - .collect().map(_.getAs[Any]("new")) shouldBe Seq(1.0, 1.0, 0.0, 1.0, null, null) - } - - "return row-level results NullBehavior ignore and filtered as true" in withSparkSession { session => - - val data = getEmptyColumnDataDf(session) - - val addressLength = MinLength("att3", Option("id < 4"), Option(AnalyzerOptions(NullBehavior.Ignore))) - val state: Option[MinState] = addressLength.computeStateFrom(data) - val metric: DoubleMetric with FullColumn = addressLength.computeMetricFrom(state) - - data.withColumn("new", metric.fullColumn.get) - .collect().map(_.getAs[Any]("new")) shouldBe - Seq(1.0, 1.0, null, 1.0, Double.MaxValue, Double.MaxValue) - } - - "return row-level results with NullBehavior ignore and filtered as null" in withSparkSession { session => - - val data = getEmptyColumnDataDf(session) - - val addressLength = MinLength("att3", Option("id < 4"), - Option(AnalyzerOptions(NullBehavior.Ignore, FilteredRowOutcome.NULL))) - val state: Option[MinState] = addressLength.computeStateFrom(data) - val metric: DoubleMetric with FullColumn = addressLength.computeMetricFrom(state) - - data.withColumn("new", metric.fullColumn.get) - .collect().map(_.getAs[Any]("new")) shouldBe Seq(1.0, 1.0, null, 1.0, null, null) + val values = getValuesDF(data, metric.fullColumn.get).map(_.getAs[Double](tempColName)) + values shouldBe Seq(0.0, 0.0, 0.0, 0.0, 0.0, 0.0) } } } diff --git a/src/test/scala/com/amazon/deequ/utils/FixtureSupport.scala b/src/test/scala/com/amazon/deequ/utils/FixtureSupport.scala index 5c56ed4b0..3a0866d2f 100644 --- a/src/test/scala/com/amazon/deequ/utils/FixtureSupport.scala +++ b/src/test/scala/com/amazon/deequ/utils/FixtureSupport.scala @@ -439,11 +439,11 @@ trait FixtureSupport { import sparkSession.implicits._ Seq( - ("Acme", "90210", "CA", "Los Angeles"), - ("Acme", "90211", "CA", "Los Angeles"), - ("Robocorp", null, "NJ", null), - ("Robocorp", null, "NY", "New York") - ).toDF("Company", "ZipCode", "State", "City") + (1, "Acme", "90210", "CA", "Los Angeles"), + (2, "Acme", "90211", "CA", "Los Angeles"), + (3, "Robocorp", null, "NJ", null), + (4, "Robocorp", null, "NY", "New York") + ).toDF("ID", "Company", "ZipCode", "State", "City") } def getDfCompleteAndInCompleteColumnsWithPeriod(sparkSession: SparkSession): DataFrame = { From c2e862f08817d8f30f060554d3cfaf14ae5ddab4 Mon 
Sep 17 00:00:00 2001 From: Hubert Date: Fri, 1 Nov 2024 19:24:09 -0400 Subject: [PATCH 11/24] fix merge conflicts --- .../seasonal/HoltWinters.scala | 43 ++++++------ .../seasonal/HoltWintersTest.scala | 67 +++++++++++++++++++ 2 files changed, 89 insertions(+), 21 deletions(-) diff --git a/src/main/scala/com/amazon/deequ/anomalydetection/seasonal/HoltWinters.scala b/src/main/scala/com/amazon/deequ/anomalydetection/seasonal/HoltWinters.scala index ec7db67e4..3d837235a 100644 --- a/src/main/scala/com/amazon/deequ/anomalydetection/seasonal/HoltWinters.scala +++ b/src/main/scala/com/amazon/deequ/anomalydetection/seasonal/HoltWinters.scala @@ -25,11 +25,11 @@ import collection.mutable.ListBuffer object HoltWinters { object SeriesSeasonality extends Enumeration { - val Weekly, Yearly: Value = Value + val Daily, Weekly, Yearly: Value = Value } object MetricInterval extends Enumeration { - val Daily, Monthly: Value = Value + val Hourly, Daily, Monthly: Value = Value } private[seasonal] case class ModelResults( @@ -48,29 +48,30 @@ object HoltWinters { } -/** - * Detects anomalies based on additive Holt-Winters model. The methods has two - * parameters, one for the metric frequency, as in how often the metric of interest - * is computed (e.g. daily) and one for the expected metric seasonality which - * defines the longest cycle in series. This quantity is also referred to as periodicity. - * - * For example, if a metric is produced daily and repeats itself every Monday, then the - * model should be created with a Daily metric interval and a Weekly seasonality parameter. - * - * @param metricsInterval: How often a metric is available - * @param seasonality: Cycle length (or periodicity) of the metric - */ -class HoltWinters( - metricsInterval: HoltWinters.MetricInterval.Value, - seasonality: HoltWinters.SeriesSeasonality.Value) +class HoltWinters(seriesPeriodicity: Int) extends AnomalyDetectionStrategy with AnomalyDetectionStrategyWithExtendedResults { import HoltWinters._ - private val seriesPeriodicity = seasonality -> metricsInterval match { - case (SeriesSeasonality.Weekly, MetricInterval.Daily) => 7 - case (SeriesSeasonality.Yearly, MetricInterval.Monthly) => 12 - } + /** + * Detects anomalies based on additive Holt-Winters model. The methods has two + * parameters, one for the metric frequency, as in how often the metric of interest + * is computed (e.g. daily) and one for the expected metric seasonality which + * defines the longest cycle in series. This quantity is also referred to as periodicity. + * + * For example, if a metric is produced daily and repeats itself every Monday, then the + * model should be created with a Daily metric interval and a Weekly seasonality parameter. 
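+   * Similarly, hourly metrics with a daily cycle correspond to a periodicity of 24, and monthly
+   * metrics with a yearly cycle to a periodicity of 12; the periodicity can also be supplied
+   * directly through the primary constructor, e.g. new HoltWinters(24).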
+ * + * @param metricsInterval : How often a metric is available + * @param seasonality : Cycle length (or periodicity) of the metric + */ + def this(metricsInterval: HoltWinters.MetricInterval.Value, + seasonality: HoltWinters.SeriesSeasonality.Value) = + this(seasonality -> metricsInterval match { + case (HoltWinters.SeriesSeasonality.Daily, HoltWinters.MetricInterval.Hourly) => 24 + case (HoltWinters.SeriesSeasonality.Weekly, HoltWinters.MetricInterval.Daily) => 7 + case (HoltWinters.SeriesSeasonality.Yearly, HoltWinters.MetricInterval.Monthly) => 12 + }) /** * Triple exponential smoothing with additive trend and seasonality diff --git a/src/test/scala/com/amazon/deequ/anomalydetection/seasonal/HoltWintersTest.scala b/src/test/scala/com/amazon/deequ/anomalydetection/seasonal/HoltWintersTest.scala index 8d8140366..af1854e07 100644 --- a/src/test/scala/com/amazon/deequ/anomalydetection/seasonal/HoltWintersTest.scala +++ b/src/test/scala/com/amazon/deequ/anomalydetection/seasonal/HoltWintersTest.scala @@ -206,8 +206,75 @@ class HoltWintersTest extends AnyWordSpec with Matchers { anomalies should have size 3 } + + "work on hourly data with daily seasonality" in { + // https://www.kaggle.com/datasets/fedesoriano/traffic-prediction-dataset + val hourlyTrafficData = Vector[Double]( + 15, 13, 10, 7, 9, 6, 9, 8, 11, 12, 15, 17, 16, 15, 16, 12, 12, 16, 17, 20, 17, 19, 20, 15, + 14, 12, 14, 12, 12, 11, 13, 14, 12, 22, 32, 31, 35, 26, 34, 30, 27, 27, 24, 26, 29, 32, 30, 27, + 21, 18, 19, 13, 11, 11, 11, 14, 15, 29, 33, 32, 32, 29, 27, 26, 28, 26, 25, 29, 26, 24, 25, 20, + 18, 18, 13, 13, 10, 12, 13, 11, 13, 22, 26, 27, 31, 24, 23, 26, 26, 24, 23, 25, 26, 24, 26, 24, + 19, 20, 18, 13, 13, 9, 12, 12, 15, 16, 23, 24, 25, 24, 26, 22, 20, 20, 22, 26, 22, 21, 21, 21, + 16, 18, 19, 14, 12, 13, 14, 14, 13, 20, 22, 26, 26, 21, 23, 23, 19, 19, 20, 24, 18, 19, 16, 17, + 16, 16, 10, 9, 8, 7, 9, 8, 12, 13, 17, 14, 14, 14, 14, 11, 15, 13, 12, 17, 18, 17, 16, 15, 13 + ) + + val strategy = new HoltWinters( + HoltWinters.MetricInterval.Hourly, + HoltWinters.SeriesSeasonality.Daily) + + val nDaysTrain = 6 + val nDaysTest = 1 + val trainSize = nDaysTrain * 24 + val testSize = nDaysTest * 24 + val nTotal = trainSize + testSize + + val anomalies = strategy.detect( + hourlyTrafficData.take(nTotal), + trainSize -> nTotal + ) + + anomalies should have size 2 + } + + "work on monthly data with yearly seasonality using custom seriesPeriodicity" in { + // https://datamarket.com/data/set/22ox/monthly-milk-production-pounds-per-cow-jan-62-dec-75 + val monthlyMilkProduction = Vector[Double]( + 589, 561, 640, 656, 727, 697, 640, 599, 568, 577, 553, 582, + 600, 566, 653, 673, 742, 716, 660, 617, 583, 587, 565, 598, + 628, 618, 688, 705, 770, 736, 678, 639, 604, 611, 594, 634, + 658, 622, 709, 722, 782, 756, 702, 653, 615, 621, 602, 635, + 677, 635, 736, 755, 811, 798, 735, 697, 661, 667, 645, 688, + 713, 667, 762, 784, 837, 817, 767, 722, 681, 687, 660, 698, + 717, 696, 775, 796, 858, 826, 783, 740, 701, 706, 677, 711, + 734, 690, 785, 805, 871, 845, 801, 764, 725, 723, 690, 734, + 750, 707, 807, 824, 886, 859, 819, 783, 740, 747, 711, 751, + 804, 756, 860, 878, 942, 913, 869, 834, 790, 800, 763, 800, + 826, 799, 890, 900, 961, 935, 894, 855, 809, 810, 766, 805, + 821, 773, 883, 898, 957, 924, 881, 837, 784, 791, 760, 802, + 828, 778, 889, 902, 969, 947, 908, 867, 815, 812, 773, 813, + 834, 782, 892, 903, 966, 937, 896, 858, 817, 827, 797, 843 + ) + + val strategy = new HoltWinters(12) + + val nYearsTrain = 3 + val nYearsTest = 1 
+ val trainSize = nYearsTrain * 12 + val testSize = nYearsTest * 12 + val nTotal = trainSize + testSize + + val anomalies = strategy.detect( + monthlyMilkProduction.take(nTotal), + trainSize -> nTotal + ) + + anomalies should have size 7 + } + } + "Additive Holt-Winters with Extended Results" should { val twoWeeksOfData = setupData() From 6538ef376e1881e0f6b4b6f42c901cc3bdac1ad6 Mon Sep 17 00:00:00 2001 From: rdsharma26 <65777064+rdsharma26@users.noreply.github.com> Date: Wed, 3 Apr 2024 10:51:25 -0400 Subject: [PATCH 12/24] Fix for satisfies row level results bug (#553) - The satisfies constraint was incorrectly using the provided assertion to evaluate the row level outcomes. The assertion should only be used to evaluate the final outcome. - As part of this change, we have updated the row level results to return a true/false. The cast to an integer happens as part of the aggregation result. - Added a test to verify the row level results using checks made up of different assertions. --- .../amazon/deequ/analyzers/Compliance.scala | 15 ++-- .../amazon/deequ/constraints/Constraint.scala | 6 +- .../amazon/deequ/VerificationSuiteTest.scala | 70 +++++++++++++++++++ .../deequ/analyzers/ComplianceTest.scala | 30 +++----- 4 files changed, 87 insertions(+), 34 deletions(-) diff --git a/src/main/scala/com/amazon/deequ/analyzers/Compliance.scala b/src/main/scala/com/amazon/deequ/analyzers/Compliance.scala index 0edf01970..247a02c14 100644 --- a/src/main/scala/com/amazon/deequ/analyzers/Compliance.scala +++ b/src/main/scala/com/amazon/deequ/analyzers/Compliance.scala @@ -22,7 +22,6 @@ import org.apache.spark.sql.functions._ import Analyzers._ import com.amazon.deequ.analyzers.Preconditions.hasColumn import com.google.common.annotations.VisibleForTesting -import org.apache.spark.sql.types.DoubleType /** * Compliance is a measure of the fraction of rows that complies with the given column constraint. @@ -43,29 +42,23 @@ case class Compliance(instance: String, where: Option[String] = None, columns: List[String] = List.empty[String], analyzerOptions: Option[AnalyzerOptions] = None) - extends StandardScanShareableAnalyzer[NumMatchesAndCount]("Compliance", instance) - with FilterableAnalyzer { + extends StandardScanShareableAnalyzer[NumMatchesAndCount]("Compliance", instance) with FilterableAnalyzer { override def fromAggregationResult(result: Row, offset: Int): Option[NumMatchesAndCount] = { - ifNoNullsIn(result, offset, howMany = 2) { _ => NumMatchesAndCount(result.getLong(offset), result.getLong(offset + 1), Some(rowLevelResults)) } } override def aggregationFunctions(): Seq[Column] = { - - val summation = sum(criterion) - + val summation = sum(criterion.cast(IntegerType)) summation :: conditionalCount(where) :: Nil } override def filterCondition: Option[String] = where @VisibleForTesting - private def criterion: Column = { - conditionalSelection(expr(predicate), where).cast(IntegerType) - } + private def criterion: Column = conditionalSelection(expr(predicate), where) private def rowLevelResults: Column = { val filteredRowOutcome = getRowLevelFilterTreatment(analyzerOptions) @@ -73,7 +66,7 @@ case class Compliance(instance: String, filteredRowOutcome match { case FilteredRowOutcome.TRUE => - conditionSelectionGivenColumn(expr(predicate), whereNotCondition, replaceWith = true).cast(IntegerType) + conditionSelectionGivenColumn(expr(predicate), whereNotCondition, replaceWith = true) case _ => // The default behavior when using filtering for rows is to treat them as nulls. No special treatment needed. 
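        // criterion already evaluates to null for rows that do not satisfy the where clause,
        // so it can be returned unchanged here.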
criterion diff --git a/src/main/scala/com/amazon/deequ/constraints/Constraint.scala b/src/main/scala/com/amazon/deequ/constraints/Constraint.scala index 79695c878..c0523474c 100644 --- a/src/main/scala/com/amazon/deequ/constraints/Constraint.scala +++ b/src/main/scala/com/amazon/deequ/constraints/Constraint.scala @@ -394,12 +394,10 @@ object Constraint { val constraint = AnalysisBasedConstraint[NumMatchesAndCount, Double, Double]( compliance, assertion, hint = hint) - val sparkAssertion = org.apache.spark.sql.functions.udf(assertion) - new RowLevelAssertedConstraint( + new RowLevelConstraint( constraint, s"ComplianceConstraint($compliance)", - s"ColumnsCompliance-${compliance.predicate}", - sparkAssertion) + s"ColumnsCompliance-${compliance.predicate}") } /** diff --git a/src/test/scala/com/amazon/deequ/VerificationSuiteTest.scala b/src/test/scala/com/amazon/deequ/VerificationSuiteTest.scala index bdcddde71..2e618bae2 100644 --- a/src/test/scala/com/amazon/deequ/VerificationSuiteTest.scala +++ b/src/test/scala/com/amazon/deequ/VerificationSuiteTest.scala @@ -26,6 +26,7 @@ import com.amazon.deequ.constraints.Constraint import com.amazon.deequ.io.DfsUtils import com.amazon.deequ.metrics.DoubleMetric import com.amazon.deequ.metrics.Entity +import com.amazon.deequ.metrics.Metric import com.amazon.deequ.repository.MetricsRepository import com.amazon.deequ.repository.ResultKey import com.amazon.deequ.repository.memory.InMemoryMetricsRepository @@ -2125,6 +2126,75 @@ class VerificationSuiteTest extends WordSpec with Matchers with SparkContextSpec } } + "Verification Suite's Row Level Results" should { + "yield correct results for satisfies check" in withSparkSession { sparkSession => + import sparkSession.implicits._ + val df = Seq( + (1, "blue"), + (2, "green"), + (3, "blue"), + (4, "red"), + (5, "purple") + ).toDF("id", "color") + + val columnCondition = "color in ('blue')" + val whereClause = "id <= 3" + + case class CheckConfig(checkName: String, + assertion: Double => Boolean, + checkStatus: CheckStatus.Value, + whereClause: Option[String] = None) + + val success = CheckStatus.Success + val error = CheckStatus.Error + + val checkConfigs = Seq( + // Without where clause: Expected compliance metric for full dataset for given condition is 0.4 + CheckConfig("check with >", (d: Double) => d > 0.5, error), + CheckConfig("check with >=", (d: Double) => d >= 0.35, success), + CheckConfig("check with <", (d: Double) => d < 0.3, error), + CheckConfig("check with <=", (d: Double) => d <= 0.4, success), + CheckConfig("check with =", (d: Double) => d == 0.4, success), + CheckConfig("check with > / <", (d: Double) => d > 0.0 && d < 0.5, success), + CheckConfig("check with >= / <=", (d: Double) => d >= 0.41 && d <= 1.1, error), + + // With where Clause: Expected compliance metric for full dataset for given condition with where clause is 0.67 + CheckConfig("check w/ where and with >", (d: Double) => d > 0.7, error, Some(whereClause)), + CheckConfig("check w/ where and with >=", (d: Double) => d >= 0.66, success, Some(whereClause)), + CheckConfig("check w/ where and with <", (d: Double) => d < 0.6, error, Some(whereClause)), + CheckConfig("check w/ where and with <=", (d: Double) => d <= 0.67, success, Some(whereClause)), + CheckConfig("check w/ where and with =", (d: Double) => d == 0.66, error, Some(whereClause)), + CheckConfig("check w/ where and with > / <", (d: Double) => d > 0.0 && d < 0.5, error, Some(whereClause)), + CheckConfig("check w/ where and with >= / <=", (d: Double) => d >= 0.41 && d <= 1.1, 
success, Some(whereClause)) + ) + + val checks = checkConfigs.map { checkConfig => + val constraintName = s"Constraint for check: ${checkConfig.checkName}" + val check = Check(CheckLevel.Error, checkConfig.checkName) + .satisfies(columnCondition, constraintName, checkConfig.assertion) + checkConfig.whereClause.map(check.where).getOrElse(check) + } + + val verificationResult = VerificationSuite().onData(df).addChecks(checks).run() + val actualResults = verificationResult.checkResults.map { case (c, r) => c.description -> r.status } + val expectedResults = checkConfigs.map { c => c.checkName -> c.checkStatus}.toMap + assert(actualResults == expectedResults) + + verificationResult.metrics.values.foreach { metric => + val metricValue = metric.asInstanceOf[Metric[Double]].value.toOption.getOrElse(0.0) + if (metric.instance.contains("where")) assert(math.abs(metricValue - 0.66) < 0.1) + else assert(metricValue == 0.4) + } + + val rowLevelResults = VerificationResult.rowLevelResultsAsDataFrame(sparkSession, verificationResult, df) + checkConfigs.foreach { checkConfig => + val results = rowLevelResults.select(checkConfig.checkName).collect().map { r => r.getAs[Boolean](0)}.toSeq + if (checkConfig.whereClause.isDefined) assert(results == Seq(true, false, true, true, true)) + else assert(results == Seq(true, false, true, false, false)) + } + } + } + /** Run anomaly detection using a repository with some previous analysis results for testing */ private[this] def evaluateWithRepositoryWithHistory(test: MetricsRepository => Unit): Unit = { diff --git a/src/test/scala/com/amazon/deequ/analyzers/ComplianceTest.scala b/src/test/scala/com/amazon/deequ/analyzers/ComplianceTest.scala index 5aa4033ba..54fc225f3 100644 --- a/src/test/scala/com/amazon/deequ/analyzers/ComplianceTest.scala +++ b/src/test/scala/com/amazon/deequ/analyzers/ComplianceTest.scala @@ -14,7 +14,6 @@ * */ - package com.amazon.deequ.analyzers import com.amazon.deequ.SparkContextSpec @@ -25,22 +24,19 @@ import org.scalatest.matchers.should.Matchers import org.scalatest.wordspec.AnyWordSpec class ComplianceTest extends AnyWordSpec with Matchers with SparkContextSpec with FixtureSupport { - "Compliance" should { "return row-level results for columns" in withSparkSession { session => - val data = getDfWithNumericValues(session) val att1Compliance = Compliance("rule1", "att1 > 3", columns = List("att1")) val state = att1Compliance.computeStateFrom(data) val metric: DoubleMetric with FullColumn = att1Compliance.computeMetricFrom(state) - data.withColumn("new", metric.fullColumn.get).collect().map(_.getAs[Int]("new") - ) shouldBe Seq(0, 0, 0, 1, 1, 1) + data.withColumn("new", metric.fullColumn.get).collect().map(_.getAs[Boolean]("new") + ) shouldBe Seq(false, false, false, true, true, true) } "return row-level results for null columns" in withSparkSession { session => - val data = getDfWithNumericValues(session) val att1Compliance = Compliance("rule1", "attNull > 3", columns = List("att1")) @@ -48,11 +44,10 @@ class ComplianceTest extends AnyWordSpec with Matchers with SparkContextSpec wit val metric: DoubleMetric with FullColumn = att1Compliance.computeMetricFrom(state) data.withColumn("new", metric.fullColumn.get).collect().map(r => - if (r == null) null else r.getAs[Int]("new")) shouldBe Seq(null, null, null, 1, 1, 1) + if (r == null) null else r.getAs[Boolean]("new")) shouldBe Seq(null, null, null, true, true, true) } "return row-level results filtered with null" in withSparkSession { session => - val data = getDfWithNumericValues(session) 
val att1Compliance = Compliance("rule1", "att1 > 4", where = Option("att2 != 0"), @@ -61,11 +56,10 @@ class ComplianceTest extends AnyWordSpec with Matchers with SparkContextSpec wit val metric: DoubleMetric with FullColumn = att1Compliance.computeMetricFrom(state) data.withColumn("new", metric.fullColumn.get).collect().map(r => - if (r == null) null else r.getAs[Int]("new")) shouldBe Seq(null, null, null, 0, 1, 1) + if (r == null) null else r.getAs[Boolean]("new")) shouldBe Seq(null, null, null, false, true, true) } "return row-level results filtered with true" in withSparkSession { session => - val data = getDfWithNumericValues(session) val att1Compliance = Compliance("rule1", "att1 > 4", where = Option("att2 != 0"), @@ -74,7 +68,7 @@ class ComplianceTest extends AnyWordSpec with Matchers with SparkContextSpec wit val metric: DoubleMetric with FullColumn = att1Compliance.computeMetricFrom(state) data.withColumn("new", metric.fullColumn.get).collect().map(r => - if (r == null) null else r.getAs[Int]("new")) shouldBe Seq(1, 1, 1, 0, 1, 1) + if (r == null) null else r.getAs[Boolean]("new")) shouldBe Seq(true, true, true, false, true, true) } "return row-level results for compliance in bounds" in withSparkSession { session => @@ -93,7 +87,7 @@ class ComplianceTest extends AnyWordSpec with Matchers with SparkContextSpec wit val metric: DoubleMetric with FullColumn = att1Compliance.computeMetricFrom(state) data.withColumn("new", metric.fullColumn.get).collect().map(r => - if (r == null) null else r.getAs[Int]("new")) shouldBe Seq(0, 1, 1, 1, 1, 0) + if (r == null) null else r.getAs[Boolean]("new")) shouldBe Seq(false, true, true, true, true, false) } "return row-level results for compliance in bounds filtered as null" in withSparkSession { session => @@ -114,7 +108,7 @@ class ComplianceTest extends AnyWordSpec with Matchers with SparkContextSpec wit val metric: DoubleMetric with FullColumn = att1Compliance.computeMetricFrom(state) data.withColumn("new", metric.fullColumn.get).collect().map(r => - if (r == null) null else r.getAs[Int]("new")) shouldBe Seq(0, 1, 1, null, null, null) + if (r == null) null else r.getAs[Boolean]("new")) shouldBe Seq(false, true, true, null, null, null) } "return row-level results for compliance in bounds filtered as true" in withSparkSession { session => @@ -135,7 +129,7 @@ class ComplianceTest extends AnyWordSpec with Matchers with SparkContextSpec wit val metric: DoubleMetric with FullColumn = att1Compliance.computeMetricFrom(state) data.withColumn("new", metric.fullColumn.get).collect().map(r => - if (r == null) null else r.getAs[Int]("new")) shouldBe Seq(0, 1, 1, 1, 1, 1) + if (r == null) null else r.getAs[Boolean]("new")) shouldBe Seq(false, true, true, true, true, true) } "return row-level results for compliance in array" in withSparkSession { session => @@ -157,7 +151,7 @@ class ComplianceTest extends AnyWordSpec with Matchers with SparkContextSpec wit val metric: DoubleMetric with FullColumn = att1Compliance.computeMetricFrom(state) data.withColumn("new", metric.fullColumn.get).collect().map(r => - if (r == null) null else r.getAs[Int]("new")) shouldBe Seq(0, 0, 1, 1, 1, 0) + if (r == null) null else r.getAs[Boolean]("new")) shouldBe Seq(false, false, true, true, true, false) } "return row-level results for compliance in array filtered as null" in withSparkSession { session => @@ -180,7 +174,7 @@ class ComplianceTest extends AnyWordSpec with Matchers with SparkContextSpec wit val metric: DoubleMetric with FullColumn = att1Compliance.computeMetricFrom(state) 
data.withColumn("new", metric.fullColumn.get).collect().map(r => - if (r == null) null else r.getAs[Int]("new")) shouldBe Seq(0, 0, 1, 1, null, null) + if (r == null) null else r.getAs[Boolean]("new")) shouldBe Seq(false, false, true, true, null, null) } "return row-level results for compliance in array filtered as true" in withSparkSession { session => @@ -196,7 +190,6 @@ class ComplianceTest extends AnyWordSpec with Matchers with SparkContextSpec wit val data = getDfWithNumericValues(session) - val att1Compliance = Compliance(s"$column contained in ${allowedValues.mkString(",")}", predicate, where = Option("att1 < 5"), columns = List("att3"), analyzerOptions = Option(AnalyzerOptions(filteredRow = FilteredRowOutcome.TRUE))) @@ -204,8 +197,7 @@ class ComplianceTest extends AnyWordSpec with Matchers with SparkContextSpec wit val metric: DoubleMetric with FullColumn = att1Compliance.computeMetricFrom(state) data.withColumn("new", metric.fullColumn.get).collect().map(r => - if (r == null) null else r.getAs[Int]("new")) shouldBe Seq(0, 0, 1, 1, 1, 1) + if (r == null) null else r.getAs[Boolean]("new")) shouldBe Seq(false, false, true, true, true, true) } } - } From b69f2b865156b53ced9e7044b730f5b9a74e9040 Mon Sep 17 00:00:00 2001 From: scott-gunn <164559105+scott-gunn@users.noreply.github.com> Date: Thu, 11 Apr 2024 12:40:04 -0400 Subject: [PATCH 13/24] New analyzer, RatioOfSums (#552) * Added RatioOfSums analyzer and tests * Unit test for divide by zero and code cleanup. * More detailed Scaladoc * Fixed docs to include Double.NegativeInfinity * Add copyright to new file --- .../amazon/deequ/analyzers/RatioOfSums.scala | 92 +++++++++++++++++++ .../repository/AnalysisResultSerde.scala | 12 +++ .../deequ/analyzers/AnalyzerTests.scala | 17 ++++ .../repository/AnalysisResultSerdeTest.scala | 2 + 4 files changed, 123 insertions(+) create mode 100644 src/main/scala/com/amazon/deequ/analyzers/RatioOfSums.scala diff --git a/src/main/scala/com/amazon/deequ/analyzers/RatioOfSums.scala b/src/main/scala/com/amazon/deequ/analyzers/RatioOfSums.scala new file mode 100644 index 000000000..593d358d4 --- /dev/null +++ b/src/main/scala/com/amazon/deequ/analyzers/RatioOfSums.scala @@ -0,0 +1,92 @@ +/** + * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not + * use this file except in compliance with the License. A copy of the License + * is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. 
+ * + */ + +package com.amazon.deequ.analyzers + +import com.amazon.deequ.analyzers.Preconditions.{hasColumn, isNumeric} +import com.amazon.deequ.metrics.Entity +import org.apache.spark.sql.DeequFunctions.stateful_corr +import org.apache.spark.sql.{Column, Row} +import org.apache.spark.sql.types.DoubleType +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types.StructType +import Analyzers._ + +import com.amazon.deequ.metrics.Entity +import com.amazon.deequ.repository.AnalysisResultSerde + +case class RatioOfSumsState( + numerator: Double, + denominator: Double +) extends DoubleValuedState[RatioOfSumsState] { + + override def sum(other: RatioOfSumsState): RatioOfSumsState = { + RatioOfSumsState(numerator + other.numerator, denominator + other.denominator) + } + + override def metricValue(): Double = { + numerator / denominator + } +} + +/** Sums up 2 columns and then divides the final values as a Double. The columns + * can contain a mix of positive and negative numbers. Dividing by zero is allowed + * and will result in a value of Double.PositiveInfinity or Double.NegativeInfinity. + * + * @param numerator + * First input column for computation + * @param denominator + * Second input column for computation + */ +case class RatioOfSums( + numerator: String, + denominator: String, + where: Option[String] = None +) extends StandardScanShareableAnalyzer[RatioOfSumsState]( + "RatioOfSums", + s"$numerator,$denominator", + Entity.Multicolumn + ) + with FilterableAnalyzer { + + override def aggregationFunctions(): Seq[Column] = { + val firstSelection = conditionalSelection(numerator, where) + val secondSelection = conditionalSelection(denominator, where) + sum(firstSelection).cast(DoubleType) :: sum(secondSelection).cast(DoubleType) :: Nil + } + + override def fromAggregationResult( + result: Row, + offset: Int + ): Option[RatioOfSumsState] = { + if (result.isNullAt(offset)) { + None + } else { + Some( + RatioOfSumsState( + result.getDouble(0), + result.getDouble(1) + ) + ) + } + } + + override protected def additionalPreconditions(): Seq[StructType => Unit] = { + hasColumn(numerator) :: isNumeric(numerator) :: hasColumn(denominator) :: isNumeric(denominator) :: Nil + } + + override def filterCondition: Option[String] = where +} diff --git a/src/main/scala/com/amazon/deequ/repository/AnalysisResultSerde.scala b/src/main/scala/com/amazon/deequ/repository/AnalysisResultSerde.scala index e9bb4f7df..eb0db5361 100644 --- a/src/main/scala/com/amazon/deequ/repository/AnalysisResultSerde.scala +++ b/src/main/scala/com/amazon/deequ/repository/AnalysisResultSerde.scala @@ -256,6 +256,12 @@ private[deequ] object AnalyzerSerializer result.addProperty(COLUMN_FIELD, sum.column) result.addProperty(WHERE_FIELD, sum.where.orNull) + case ratioOfSums: RatioOfSums => + result.addProperty(ANALYZER_NAME_FIELD, "RatioOfSums") + result.addProperty("numerator", ratioOfSums.numerator) + result.addProperty("denominator", ratioOfSums.denominator) + result.addProperty(WHERE_FIELD, ratioOfSums.where.orNull) + case mean: Mean => result.addProperty(ANALYZER_NAME_FIELD, "Mean") result.addProperty(COLUMN_FIELD, mean.column) @@ -412,6 +418,12 @@ private[deequ] object AnalyzerDeserializer json.get(COLUMN_FIELD).getAsString, getOptionalWhereParam(json)) + case "RatioOfSums" => + RatioOfSums( + json.get("numerator").getAsString, + json.get("denominator").getAsString, + getOptionalWhereParam(json)) + case "Mean" => Mean( json.get(COLUMN_FIELD).getAsString, diff --git 
a/src/test/scala/com/amazon/deequ/analyzers/AnalyzerTests.scala b/src/test/scala/com/amazon/deequ/analyzers/AnalyzerTests.scala index 1c0b28d1a..be5bdc5a6 100644 --- a/src/test/scala/com/amazon/deequ/analyzers/AnalyzerTests.scala +++ b/src/test/scala/com/amazon/deequ/analyzers/AnalyzerTests.scala @@ -847,6 +847,23 @@ class AnalyzerTests extends AnyWordSpec with Matchers with SparkContextSpec with analyzer.calculate(df).value shouldBe Success(2.0 / 8.0) assert(analyzer.calculate(df).fullColumn.isDefined) } + + "compute ratio of sums correctly for numeric data" in withSparkSession { sparkSession => + val df = getDfWithNumericValues(sparkSession) + RatioOfSums("att1", "att2").calculate(df).value shouldBe Success(21.0 / 18.0) + } + + "fail to compute ratio of sums for non numeric type" in withSparkSession { sparkSession => + val df = getDfFull(sparkSession) + assert(RatioOfSums("att1", "att2").calculate(df).value.isFailure) + } + + "divide by zero" in withSparkSession { sparkSession => + val df = getDfWithNumericValues(sparkSession) + val testVal = RatioOfSums("att1", "att2", Some("item IN ('1', '2')")).calculate(df) + assert(testVal.value.isSuccess) + assert(testVal.value.toOption.get.isInfinite) + } } } diff --git a/src/test/scala/com/amazon/deequ/repository/AnalysisResultSerdeTest.scala b/src/test/scala/com/amazon/deequ/repository/AnalysisResultSerdeTest.scala index 05f4d47bd..1000ff8e6 100644 --- a/src/test/scala/com/amazon/deequ/repository/AnalysisResultSerdeTest.scala +++ b/src/test/scala/com/amazon/deequ/repository/AnalysisResultSerdeTest.scala @@ -76,6 +76,8 @@ class AnalysisResultSerdeTest extends FlatSpec with Matchers { DoubleMetric(Entity.Column, "Completeness", "ColumnA", Success(5.0)), Sum("ColumnA") -> DoubleMetric(Entity.Column, "Completeness", "ColumnA", Success(5.0)), + RatioOfSums("ColumnA", "ColumnB") -> + DoubleMetric(Entity.Column, "RatioOfSums", "ColumnA", Success(5.0)), StandardDeviation("ColumnA") -> DoubleMetric(Entity.Column, "Completeness", "ColumnA", Success(5.0)), DataType("ColumnA") -> From efd33f0600852c47b1b20ec2d37448e82c2e23c7 Mon Sep 17 00:00:00 2001 From: Yannis Mentekidis Date: Mon, 15 Apr 2024 14:19:23 -0400 Subject: [PATCH 14/24] Column Count Analyzer and Check (#555) * Fix flaky KLL test * Move CustomSql state to CustomSql analyzer * Implement new Analyzer to count columns * Improve documentation, remove unused parameter, replace if/else with map --------- Co-authored-by: Yannis Mentekidis --- .../amazon/deequ/analyzers/ColumnCount.scala | 63 +++++++++++++++++++ .../amazon/deequ/analyzers/CustomSql.scala | 11 ++++ .../com/amazon/deequ/analyzers/Size.scala | 11 ---- .../scala/com/amazon/deequ/checks/Check.scala | 7 +++ .../amazon/deequ/constraints/Constraint.scala | 12 ++++ .../amazon/deequ/KLL/KLLDistanceTest.scala | 5 +- .../amazon/deequ/VerificationSuiteTest.scala | 1 + .../deequ/analyzers/ColumnCountTest.scala | 45 +++++++++++++ 8 files changed, 142 insertions(+), 13 deletions(-) create mode 100644 src/main/scala/com/amazon/deequ/analyzers/ColumnCount.scala create mode 100644 src/test/scala/com/amazon/deequ/analyzers/ColumnCountTest.scala diff --git a/src/main/scala/com/amazon/deequ/analyzers/ColumnCount.scala b/src/main/scala/com/amazon/deequ/analyzers/ColumnCount.scala new file mode 100644 index 000000000..9eff89b6d --- /dev/null +++ b/src/main/scala/com/amazon/deequ/analyzers/ColumnCount.scala @@ -0,0 +1,63 @@ +/* + * Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not + * use this file except in compliance with the License. A copy of the License + * is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, + * either express or implied. See the License for the specific language governing + * permissions and limitations under the License. + * + * + */ + +package com.amazon.deequ.analyzers + +import com.amazon.deequ.metrics.DoubleMetric +import com.amazon.deequ.metrics.Entity +import org.apache.spark.sql.DataFrame + +case class ColumnCount() extends Analyzer[NumMatches, DoubleMetric] { + + val name = "ColumnCount" + val instance = "*" + val entity = Entity.Dataset + + /** + * Compute the state (sufficient statistics) from the data + * + * @param data the input dataframe + * @return the number of columns in the input + */ + override def computeStateFrom(data: DataFrame, filterCondition: Option[String]): Option[NumMatches] = { + if (filterCondition.isDefined) { + throw new IllegalArgumentException("ColumnCount does not accept a filter condition") + } else { + val numColumns = data.columns.size + Some(NumMatches(numColumns)) + } + } + + /** + * Compute the metric from the state (sufficient statistics) + * + * @param state the computed state from [[computeStateFrom]] + * @return a double metric indicating the number of columns for this analyzer + */ + override def computeMetricFrom(state: Option[NumMatches]): DoubleMetric = { + state + .map(v => Analyzers.metricFromValue(v.metricValue(), name, instance, entity)) + .getOrElse(Analyzers.metricFromEmpty(this, name, instance, entity)) + } + + /** + * Compute the metric from a failure - reports the exception thrown while trying to count columns + */ + override private[deequ] def toFailureMetric(failure: Exception): DoubleMetric = { + Analyzers.metricFromFailure(failure, name, instance, entity) + } +} diff --git a/src/main/scala/com/amazon/deequ/analyzers/CustomSql.scala b/src/main/scala/com/amazon/deequ/analyzers/CustomSql.scala index e07e2d11f..edd4f8e97 100644 --- a/src/main/scala/com/amazon/deequ/analyzers/CustomSql.scala +++ b/src/main/scala/com/amazon/deequ/analyzers/CustomSql.scala @@ -26,6 +26,17 @@ import scala.util.Failure import scala.util.Success import scala.util.Try +case class CustomSqlState(stateOrError: Either[Double, String]) extends DoubleValuedState[CustomSqlState] { + lazy val state = stateOrError.left.get + lazy val error = stateOrError.right.get + + override def sum(other: CustomSqlState): CustomSqlState = { + CustomSqlState(Left(state + other.state)) + } + + override def metricValue(): Double = state +} + case class CustomSql(expression: String) extends Analyzer[CustomSqlState, DoubleMetric] { /** * Compute the state (sufficient statistics) from the data diff --git a/src/main/scala/com/amazon/deequ/analyzers/Size.scala b/src/main/scala/com/amazon/deequ/analyzers/Size.scala index a5080084a..c56083abe 100644 --- a/src/main/scala/com/amazon/deequ/analyzers/Size.scala +++ b/src/main/scala/com/amazon/deequ/analyzers/Size.scala @@ -20,17 +20,6 @@ import com.amazon.deequ.metrics.Entity import org.apache.spark.sql.{Column, Row} import Analyzers._ -case class CustomSqlState(stateOrError: Either[Double, String]) extends DoubleValuedState[CustomSqlState] { - lazy val state = stateOrError.left.get - lazy val error = stateOrError.right.get - - override def sum(other: 
CustomSqlState): CustomSqlState = { - CustomSqlState(Left(state + other.state)) - } - - override def metricValue(): Double = state -} - case class NumMatches(numMatches: Long) extends DoubleValuedState[NumMatches] { override def sum(other: NumMatches): NumMatches = { diff --git a/src/main/scala/com/amazon/deequ/checks/Check.scala b/src/main/scala/com/amazon/deequ/checks/Check.scala index c76187e7d..8d4ffa1fb 100644 --- a/src/main/scala/com/amazon/deequ/checks/Check.scala +++ b/src/main/scala/com/amazon/deequ/checks/Check.scala @@ -124,6 +124,13 @@ case class Check( addFilterableConstraint { filter => Constraint.sizeConstraint(assertion, filter, hint) } } + def hasColumnCount(assertion: Long => Boolean, hint: Option[String] = None) + : CheckWithLastConstraintFilterable = { + addFilterableConstraint { + filter => Constraint.columnCountConstraint(assertion, hint) + } + } + /** * Creates a constraint that asserts on a column completion. * diff --git a/src/main/scala/com/amazon/deequ/constraints/Constraint.scala b/src/main/scala/com/amazon/deequ/constraints/Constraint.scala index c0523474c..413e384ca 100644 --- a/src/main/scala/com/amazon/deequ/constraints/Constraint.scala +++ b/src/main/scala/com/amazon/deequ/constraints/Constraint.scala @@ -141,6 +141,18 @@ object Constraint { new NamedConstraint(constraint, s"SizeConstraint($size)") } + def columnCountConstraint(assertion: Long => Boolean, hint: Option[String] = None): Constraint = { + val colCount = ColumnCount() + fromAnalyzer(colCount, assertion, hint) + } + + + def fromAnalyzer(colCount: ColumnCount, assertion: Long => Boolean, hint: Option[String]): Constraint = { + val constraint = AnalysisBasedConstraint[NumMatches, Double, Long](colCount, assertion, Some(_.toLong), hint) + + new NamedConstraint(constraint, name = s"ColumnCountConstraint($colCount)") + } + /** * Runs Histogram analysis on the given column and executes the assertion * diff --git a/src/test/scala/com/amazon/deequ/KLL/KLLDistanceTest.scala b/src/test/scala/com/amazon/deequ/KLL/KLLDistanceTest.scala index 20017fa71..728ce866c 100644 --- a/src/test/scala/com/amazon/deequ/KLL/KLLDistanceTest.scala +++ b/src/test/scala/com/amazon/deequ/KLL/KLLDistanceTest.scala @@ -22,7 +22,8 @@ import com.amazon.deequ.analyzers.{Distance, QuantileNonSample} import com.amazon.deequ.metrics.BucketValue import com.amazon.deequ.utils.FixtureSupport import org.scalatest.WordSpec -import com.amazon.deequ.metrics.{BucketValue} +import com.amazon.deequ.metrics.BucketValue +import org.scalactic.Tolerance.convertNumericToPlusOrMinusWrapper class KLLDistanceTest extends WordSpec with SparkContextSpec with FixtureSupport{ @@ -88,7 +89,7 @@ class KLLDistanceTest extends WordSpec with SparkContextSpec val sample2 = scala.collection.mutable.Map( "a" -> 22L, "b" -> 20L, "c" -> 25L, "d" -> 12L, "e" -> 13L, "f" -> 15L) val distance = Distance.categoricalDistance(sample1, sample2, method = LInfinityMethod(alpha = Some(0.003))) - assert(distance == 0.2726338046550349) + assert(distance === 0.2726338046550349 +- 1E-14) } "Categorial distance should compute correct linf_robust with different alpha value .1" in { diff --git a/src/test/scala/com/amazon/deequ/VerificationSuiteTest.scala b/src/test/scala/com/amazon/deequ/VerificationSuiteTest.scala index 2e618bae2..f410ea821 100644 --- a/src/test/scala/com/amazon/deequ/VerificationSuiteTest.scala +++ b/src/test/scala/com/amazon/deequ/VerificationSuiteTest.scala @@ -61,6 +61,7 @@ class VerificationSuiteTest extends WordSpec with Matchers with SparkContextSpec val 
checkToSucceed = Check(CheckLevel.Error, "group-1") .isComplete("att1") + .hasColumnCount(_ == 3) .hasCompleteness("att1", _ == 1.0) val checkToErrorOut = Check(CheckLevel.Error, "group-2-E") diff --git a/src/test/scala/com/amazon/deequ/analyzers/ColumnCountTest.scala b/src/test/scala/com/amazon/deequ/analyzers/ColumnCountTest.scala new file mode 100644 index 000000000..00df2758c --- /dev/null +++ b/src/test/scala/com/amazon/deequ/analyzers/ColumnCountTest.scala @@ -0,0 +1,45 @@ +/* + * Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not + * use this file except in compliance with the License. A copy of the License + * is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, + * either express or implied. See the License for the specific language governing + * permissions and limitations under the License. + * + * + */ + +package com.amazon.deequ.analyzers + +import com.amazon.deequ.SparkContextSpec +import com.amazon.deequ.utils.FixtureSupport +import org.apache.spark.sql.Row +import org.apache.spark.sql.types.StructType +import org.scalatest.matchers.should.Matchers +import org.scalatest.wordspec.AnyWordSpec + +import scala.util.Failure +import scala.util.Success + +class ColumnCountTest extends AnyWordSpec with Matchers with SparkContextSpec with FixtureSupport { + "ColumnCount" should { + "return column count for a dataset" in withSparkSession { session => + val data = getDfWithStringColumns(session) + val colCount = ColumnCount() + + val state = colCount.computeStateFrom(data) + state.isDefined shouldBe true + state.get.metricValue() shouldBe 5.0 + + val metric = colCount.computeMetricFrom(state) + metric.fullColumn shouldBe None + metric.value shouldBe Success(5.0) + } + } +} From a4a8aa6d7b1f7cbaa2ac8e4af3adb8e9d6e27a0a Mon Sep 17 00:00:00 2001 From: zeotuan <48720253+zeotuan@users.noreply.github.com> Date: Thu, 18 Apr 2024 00:04:00 +1000 Subject: [PATCH 15/24] Update breeze to match spark 3.5 breeze version (#545) --- pom.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pom.xml b/pom.xml index ccf2acef5..0652f5008 100644 --- a/pom.xml +++ b/pom.xml @@ -103,7 +103,7 @@ org.scalanlp breeze_${scala.major.version} - 0.13.2 + 2.1.0 From 572d776b85c9044411d1c46bbb7038f0370ad4fa Mon Sep 17 00:00:00 2001 From: zeotuan <48720253+zeotuan@users.noreply.github.com> Date: Tue, 7 May 2024 04:26:51 +1000 Subject: [PATCH 16/24] Configurable RetainCompletenessRule (#564) * Configurable RetainCompletenessRule * Add doc string * Add default completeness const --- .../rules/RetainCompletenessRule.scala | 17 ++++++++++--- .../rules/ConstraintRulesTest.scala | 25 +++++++++++++++++++ 2 files changed, 39 insertions(+), 3 deletions(-) diff --git a/src/main/scala/com/amazon/deequ/suggestions/rules/RetainCompletenessRule.scala b/src/main/scala/com/amazon/deequ/suggestions/rules/RetainCompletenessRule.scala index 67ae61f92..9f995a112 100644 --- a/src/main/scala/com/amazon/deequ/suggestions/rules/RetainCompletenessRule.scala +++ b/src/main/scala/com/amazon/deequ/suggestions/rules/RetainCompletenessRule.scala @@ -20,17 +20,23 @@ import com.amazon.deequ.constraints.Constraint.completenessConstraint import com.amazon.deequ.profiles.ColumnProfile import com.amazon.deequ.suggestions.CommonConstraintSuggestion import 
com.amazon.deequ.suggestions.ConstraintSuggestion +import com.amazon.deequ.suggestions.rules.RetainCompletenessRule._ import scala.math.BigDecimal.RoundingMode /** * If a column is incomplete in the sample, we model its completeness as a binomial variable, * estimate a confidence interval and use this to define a lower bound for the completeness + * + * @param minCompleteness : minimum completeness threshold to determine if rule should be applied + * @param maxCompleteness : maximum completeness threshold to determine if rule should be applied */ -case class RetainCompletenessRule() extends ConstraintRule[ColumnProfile] { - +case class RetainCompletenessRule( + minCompleteness: Double = defaultMinCompleteness, + maxCompleteness: Double = defaultMaxCompleteness +) extends ConstraintRule[ColumnProfile] { override def shouldBeApplied(profile: ColumnProfile, numRecords: Long): Boolean = { - profile.completeness > 0.2 && profile.completeness < 1.0 + profile.completeness > minCompleteness && profile.completeness < maxCompleteness } override def candidate(profile: ColumnProfile, numRecords: Long): ConstraintSuggestion = { @@ -65,3 +71,8 @@ case class RetainCompletenessRule() extends ConstraintRule[ColumnProfile] { "we model its completeness as a binomial variable, estimate a confidence interval " + "and use this to define a lower bound for the completeness" } + +object RetainCompletenessRule { + private val defaultMinCompleteness: Double = 0.2 + private val defaultMaxCompleteness: Double = 1.0 +} diff --git a/src/test/scala/com/amazon/deequ/suggestions/rules/ConstraintRulesTest.scala b/src/test/scala/com/amazon/deequ/suggestions/rules/ConstraintRulesTest.scala index 075247932..701a5d983 100644 --- a/src/test/scala/com/amazon/deequ/suggestions/rules/ConstraintRulesTest.scala +++ b/src/test/scala/com/amazon/deequ/suggestions/rules/ConstraintRulesTest.scala @@ -130,9 +130,14 @@ class ConstraintRulesTest extends WordSpec with FixtureSupport with SparkContext "be applied correctly" in { val complete = StandardColumnProfile("col1", 1.0, 100, String, false, Map.empty, None) + val tenPercent = StandardColumnProfile("col1", 0.1, 100, String, false, Map.empty, None) val incomplete = StandardColumnProfile("col1", .25, 100, String, false, Map.empty, None) assert(!RetainCompletenessRule().shouldBeApplied(complete, 1000)) + assert(!RetainCompletenessRule(0.05, 0.9).shouldBeApplied(complete, 1000)) + assert(RetainCompletenessRule(0.05, 0.9).shouldBeApplied(tenPercent, 1000)) + assert(RetainCompletenessRule(0.0).shouldBeApplied(tenPercent, 1000)) + assert(RetainCompletenessRule(0.0).shouldBeApplied(incomplete, 1000)) assert(RetainCompletenessRule().shouldBeApplied(incomplete, 1000)) } @@ -183,6 +188,26 @@ class ConstraintRulesTest extends WordSpec with FixtureSupport with SparkContext assert(metricResult.value.isSuccess) } + + "return evaluable constraint candidates with custom min/max completeness" in + withSparkSession { session => + + val dfWithColumnCandidate = getDfFull(session) + + val fakeColumnProfile = getFakeColumnProfileWithNameAndCompleteness("att1", 0.5) + + val check = Check(CheckLevel.Warning, "some") + .addConstraint(RetainCompletenessRule(0.4, 0.6).candidate(fakeColumnProfile, 100).constraint) + + val verificationResult = VerificationSuite() + .onData(dfWithColumnCandidate) + .addCheck(check) + .run() + + val metricResult = verificationResult.metrics.head._2 + + assert(metricResult.value.isSuccess) + } } "UniqueIfApproximatelyUniqueRule" should { From dc9ba7e322920748d26a83e3fda347f56c8c4899 Mon 
Sep 17 00:00:00 2001 From: tylermcdaniel0 <144386264+tylermcdaniel0@users.noreply.github.com> Date: Fri, 24 May 2024 11:06:44 -0400 Subject: [PATCH 17/24] Optional specification of instance name in CustomSQL analyzer metric. (#569) Co-authored-by: Tyler Mcdaniel --- .../amazon/deequ/analyzers/CustomSql.scala | 14 +++++++++----- .../deequ/analyzers/CustomSqlTest.scala | 19 ++++++++++++++++++- 2 files changed, 27 insertions(+), 6 deletions(-) diff --git a/src/main/scala/com/amazon/deequ/analyzers/CustomSql.scala b/src/main/scala/com/amazon/deequ/analyzers/CustomSql.scala index edd4f8e97..8e2e351b9 100644 --- a/src/main/scala/com/amazon/deequ/analyzers/CustomSql.scala +++ b/src/main/scala/com/amazon/deequ/analyzers/CustomSql.scala @@ -37,7 +37,7 @@ case class CustomSqlState(stateOrError: Either[Double, String]) extends DoubleVa override def metricValue(): Double = state } -case class CustomSql(expression: String) extends Analyzer[CustomSqlState, DoubleMetric] { +case class CustomSql(expression: String, disambiguator: String = "*") extends Analyzer[CustomSqlState, DoubleMetric] { /** * Compute the state (sufficient statistics) from the data * @@ -76,15 +76,19 @@ case class CustomSql(expression: String) extends Analyzer[CustomSqlState, Double state match { // The returned state may case Some(theState) => theState.stateOrError match { - case Left(value) => DoubleMetric(Entity.Dataset, "CustomSQL", "*", Success(value)) - case Right(error) => DoubleMetric(Entity.Dataset, "CustomSQL", "*", Failure(new RuntimeException(error))) + case Left(value) => DoubleMetric(Entity.Dataset, "CustomSQL", disambiguator, + Success(value)) + case Right(error) => DoubleMetric(Entity.Dataset, "CustomSQL", disambiguator, + Failure(new RuntimeException(error))) } case None => - DoubleMetric(Entity.Dataset, "CustomSQL", "*", Failure(new RuntimeException("CustomSql Failed To Run"))) + DoubleMetric(Entity.Dataset, "CustomSQL", disambiguator, + Failure(new RuntimeException("CustomSql Failed To Run"))) } } override private[deequ] def toFailureMetric(failure: Exception) = { - DoubleMetric(Entity.Dataset, "CustomSQL", "*", Failure(new RuntimeException("CustomSql Failed To Run"))) + DoubleMetric(Entity.Dataset, "CustomSQL", disambiguator, + Failure(new RuntimeException("CustomSql Failed To Run"))) } } diff --git a/src/test/scala/com/amazon/deequ/analyzers/CustomSqlTest.scala b/src/test/scala/com/amazon/deequ/analyzers/CustomSqlTest.scala index 7e6e96c30..e6e23c40b 100644 --- a/src/test/scala/com/amazon/deequ/analyzers/CustomSqlTest.scala +++ b/src/test/scala/com/amazon/deequ/analyzers/CustomSqlTest.scala @@ -5,7 +5,7 @@ * use this file except in compliance with the License. A copy of the License * is located at * - * http://aws.amazon.com/apache2.0/ + * http://aws.amazon.com/apache2.0/ * * or in the "license" file accompanying this file. 
This file is distributed on * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either @@ -17,6 +17,7 @@ package com.amazon.deequ.analyzers import com.amazon.deequ.SparkContextSpec import com.amazon.deequ.metrics.DoubleMetric +import com.amazon.deequ.metrics.Entity import com.amazon.deequ.utils.FixtureSupport import org.scalatest.matchers.should.Matchers import org.scalatest.wordspec.AnyWordSpec @@ -84,5 +85,21 @@ class CustomSqlTest extends AnyWordSpec with Matchers with SparkContextSpec with case Failure(exception) => exception.getMessage should include("foo") } } + + "apply metric disambiguation string to returned metric" in withSparkSession { session => + val data = getDfWithStringColumns(session) + data.createOrReplaceTempView("primary") + + val disambiguator = "statement1" + val sql = CustomSql("SELECT COUNT(*) FROM primary WHERE `Address Line 2` IS NOT NULL", disambiguator) + val state = sql.computeStateFrom(data) + val metric: DoubleMetric = sql.computeMetricFrom(state) + + metric.value.isSuccess shouldBe true + metric.value.get shouldBe 6.0 + metric.name shouldBe "CustomSQL" + metric.entity shouldBe Entity.Dataset + metric.instance shouldBe "statement1" + } } } From 2a02afe59d5addf509facd7904194ab6de5b7d96 Mon Sep 17 00:00:00 2001 From: zeotuan <48720253+zeotuan@users.noreply.github.com> Date: Sat, 25 May 2024 02:08:22 +1000 Subject: [PATCH 18/24] Adding Wilson Score Confidence Interval Strategy (#567) * Configurable RetainCompletenessRule * Add doc string * Add default completeness const * Add ConfidenceIntervalStrategy * Add Separate Wilson and Wald Interval Test * Add License information, Fix formatting * Add License information * formatting fix * Update documentation * Make WaldInterval the default strategy for now * Formatting import to per line * Separate group import to per line import --- .../ConstraintSuggestionExample.scala | 6 ++ .../examples/constraint_suggestion_example.md | 13 +++ .../FractionalCategoricalRangeRule.scala | 12 +-- .../rules/RetainCompletenessRule.scala | 19 ++-- .../interval/ConfidenceIntervalStrategy.scala | 55 +++++++++++ .../rules/interval/WaldIntervalStrategy.scala | 47 +++++++++ .../WilsonScoreIntervalStrategy.scala | 47 +++++++++ .../rules/ConstraintRulesTest.scala | 95 ++++++++++++------- .../rules/interval/IntervalStrategyTest.scala | 59 ++++++++++++ 9 files changed, 299 insertions(+), 54 deletions(-) create mode 100644 src/main/scala/com/amazon/deequ/suggestions/rules/interval/ConfidenceIntervalStrategy.scala create mode 100644 src/main/scala/com/amazon/deequ/suggestions/rules/interval/WaldIntervalStrategy.scala create mode 100644 src/main/scala/com/amazon/deequ/suggestions/rules/interval/WilsonScoreIntervalStrategy.scala create mode 100644 src/test/scala/com/amazon/deequ/suggestions/rules/interval/IntervalStrategyTest.scala diff --git a/src/main/scala/com/amazon/deequ/examples/ConstraintSuggestionExample.scala b/src/main/scala/com/amazon/deequ/examples/ConstraintSuggestionExample.scala index 8aa0fb6c5..fc8f458bf 100644 --- a/src/main/scala/com/amazon/deequ/examples/ConstraintSuggestionExample.scala +++ b/src/main/scala/com/amazon/deequ/examples/ConstraintSuggestionExample.scala @@ -17,6 +17,8 @@ package com.amazon.deequ.examples import com.amazon.deequ.examples.ExampleUtils.withSpark +import com.amazon.deequ.suggestions.rules.RetainCompletenessRule +import com.amazon.deequ.suggestions.rules.interval.WilsonScoreIntervalStrategy import com.amazon.deequ.suggestions.{ConstraintSuggestionRunner, Rules} private[examples] object 
ConstraintSuggestionExample extends App { @@ -51,6 +53,10 @@ private[examples] object ConstraintSuggestionExample extends App { val suggestionResult = ConstraintSuggestionRunner() .onData(data) .addConstraintRules(Rules.EXTENDED) + // We can also add our own constraint and customize constraint parameters + .addConstraintRule( + RetainCompletenessRule(intervalStrategy = WilsonScoreIntervalStrategy()) + ) .run() // We can now investigate the constraints that deequ suggested. We get a textual description diff --git a/src/main/scala/com/amazon/deequ/examples/constraint_suggestion_example.md b/src/main/scala/com/amazon/deequ/examples/constraint_suggestion_example.md index df159a9c9..472f63c7d 100644 --- a/src/main/scala/com/amazon/deequ/examples/constraint_suggestion_example.md +++ b/src/main/scala/com/amazon/deequ/examples/constraint_suggestion_example.md @@ -43,6 +43,17 @@ val suggestionResult = ConstraintSuggestionRunner() .run() ``` +Alternatively, we also support customizing and adding individual constraint rule using `addConstraintRule()` +```scala +val suggestionResult = ConstraintSuggestionRunner() + .onData(data) + + .addConstraintRule( + RetainCompletenessRule(intervalStrategy = WilsonScoreIntervalStrategy()) + ) + .run() +``` + We can now investigate the constraints that deequ suggested. We get a textual description and the corresponding scala code for each suggested constraint. Note that the constraint suggestion is based on heuristic rules and assumes that the data it is shown is 'static' and correct, which might often not be the case in the real world. Therefore the suggestions should always be manually reviewed before being applied in real deployments. ```scala suggestionResult.constraintSuggestions.foreach { case (column, suggestions) => @@ -92,3 +103,5 @@ The corresponding scala code is .isContainedIn("status", Array("DELAYED", "UNKNO Currently, we leave it up to the user to decide whether they want to apply the suggested constraints or not, and provide the corresponding Scala code for convenience. For larger datasets, it makes sense to evaluate the suggested constraints on some held-out portion of the data to see whether they hold or not. You can test this by adding an invocation of `.useTrainTestSplitWithTestsetRatio(0.1)` to the `ConstraintSuggestionRunner`. With this configuration, it would compute constraint suggestions on 90% of the data and evaluate the suggested constraints on the remaining 10%. Finally, we would also like to note that the constraint suggestion code provides access to the underlying [column profiles](https://github.com/awslabs/deequ/blob/master/src/main/scala/com/amazon/deequ/examples/data_profiling_example.md) that it computed via `suggestionResult.columnProfiles`. + +An [executable and extended version of this example](https://github.com/awslabs/deequ/blob/master/src/main/scala/com/amazon/deequ/examples/.scala) is part of our code base. 
diff --git a/src/main/scala/com/amazon/deequ/suggestions/rules/FractionalCategoricalRangeRule.scala b/src/main/scala/com/amazon/deequ/suggestions/rules/FractionalCategoricalRangeRule.scala index 55e410f33..f9dd192e8 100644 --- a/src/main/scala/com/amazon/deequ/suggestions/rules/FractionalCategoricalRangeRule.scala +++ b/src/main/scala/com/amazon/deequ/suggestions/rules/FractionalCategoricalRangeRule.scala @@ -23,16 +23,17 @@ import com.amazon.deequ.metrics.DistributionValue import com.amazon.deequ.profiles.ColumnProfile import com.amazon.deequ.suggestions.ConstraintSuggestion import com.amazon.deequ.suggestions.ConstraintSuggestionWithValue +import com.amazon.deequ.suggestions.rules.interval.ConfidenceIntervalStrategy.defaultIntervalStrategy +import com.amazon.deequ.suggestions.rules.interval.ConfidenceIntervalStrategy import org.apache.commons.lang3.StringEscapeUtils -import scala.math.BigDecimal.RoundingMode - /** If we see a categorical range for most values in a column, we suggest an IS IN (...) * constraint that should hold for most values */ case class FractionalCategoricalRangeRule( targetDataCoverageFraction: Double = 0.9, categorySorter: Array[(String, DistributionValue)] => Array[(String, DistributionValue)] = - categories => categories.sortBy({ case (_, value) => value.absolute }).reverse + categories => categories.sortBy({ case (_, value) => value.absolute }).reverse, + intervalStrategy: ConfidenceIntervalStrategy = defaultIntervalStrategy ) extends ConstraintRule[ColumnProfile] { override def shouldBeApplied(profile: ColumnProfile, numRecords: Long): Boolean = { @@ -79,11 +80,8 @@ case class FractionalCategoricalRangeRule( val p = ratioSums val n = numRecords - val z = 1.96 - // TODO this needs to be more robust for p's close to 0 or 1 - val targetCompliance = BigDecimal(p - z * math.sqrt(p * (1 - p) / n)) - .setScale(2, RoundingMode.DOWN).toDouble + val targetCompliance = intervalStrategy.calculateTargetConfidenceInterval(p, n).lowerBound val description = s"'${profile.column}' has value range $categoriesSql for at least " + s"${targetCompliance * 100}% of values" diff --git a/src/main/scala/com/amazon/deequ/suggestions/rules/RetainCompletenessRule.scala b/src/main/scala/com/amazon/deequ/suggestions/rules/RetainCompletenessRule.scala index 9f995a112..be5bd101f 100644 --- a/src/main/scala/com/amazon/deequ/suggestions/rules/RetainCompletenessRule.scala +++ b/src/main/scala/com/amazon/deequ/suggestions/rules/RetainCompletenessRule.scala @@ -21,8 +21,8 @@ import com.amazon.deequ.profiles.ColumnProfile import com.amazon.deequ.suggestions.CommonConstraintSuggestion import com.amazon.deequ.suggestions.ConstraintSuggestion import com.amazon.deequ.suggestions.rules.RetainCompletenessRule._ - -import scala.math.BigDecimal.RoundingMode +import com.amazon.deequ.suggestions.rules.interval.ConfidenceIntervalStrategy.defaultIntervalStrategy +import com.amazon.deequ.suggestions.rules.interval.ConfidenceIntervalStrategy /** * If a column is incomplete in the sample, we model its completeness as a binomial variable, @@ -33,21 +33,18 @@ import scala.math.BigDecimal.RoundingMode */ case class RetainCompletenessRule( minCompleteness: Double = defaultMinCompleteness, - maxCompleteness: Double = defaultMaxCompleteness + maxCompleteness: Double = defaultMaxCompleteness, + intervalStrategy: ConfidenceIntervalStrategy = defaultIntervalStrategy ) extends ConstraintRule[ColumnProfile] { override def shouldBeApplied(profile: ColumnProfile, numRecords: Long): Boolean = { profile.completeness > 
minCompleteness && profile.completeness < maxCompleteness } override def candidate(profile: ColumnProfile, numRecords: Long): ConstraintSuggestion = { - - val p = profile.completeness - val n = numRecords - val z = 1.96 - - // TODO this needs to be more robust for p's close to 0 or 1 - val targetCompleteness = BigDecimal(p - z * math.sqrt(p * (1 - p) / n)) - .setScale(2, RoundingMode.DOWN).toDouble + val targetCompleteness = intervalStrategy.calculateTargetConfidenceInterval( + profile.completeness, + numRecords + ).lowerBound val constraint = completenessConstraint(profile.column, _ >= targetCompleteness) diff --git a/src/main/scala/com/amazon/deequ/suggestions/rules/interval/ConfidenceIntervalStrategy.scala b/src/main/scala/com/amazon/deequ/suggestions/rules/interval/ConfidenceIntervalStrategy.scala new file mode 100644 index 000000000..0c12e03a5 --- /dev/null +++ b/src/main/scala/com/amazon/deequ/suggestions/rules/interval/ConfidenceIntervalStrategy.scala @@ -0,0 +1,55 @@ +/** + * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not + * use this file except in compliance with the License. A copy of the License + * is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + * + */ + +package com.amazon.deequ.suggestions.rules.interval + +import breeze.stats.distributions.{Gaussian, Rand} +import com.amazon.deequ.suggestions.rules.interval.ConfidenceIntervalStrategy._ + +/** + * Strategy for calculate confidence interval + * */ +trait ConfidenceIntervalStrategy { + + /** + * Generated confidence interval interval + * @param pHat sample of the population that share a trait + * @param numRecords overall number of records + * @param confidence confidence level of method used to estimate the interval. + * @return + */ + def calculateTargetConfidenceInterval( + pHat: Double, + numRecords: Long, + confidence: Double = defaultConfidence + ): ConfidenceInterval + + def validateInput(pHat: Double, confidence: Double): Unit = { + require(0.0 <= pHat && pHat <= 1.0, "pHat must be between 0.0 and 1.0") + require(0.0 <= confidence && confidence <= 1.0, "confidence must be between 0.0 and 1.0") + } + + def calculateZScore(confidence: Double): Double = Gaussian(0, 1)(Rand).inverseCdf(1 - ((1.0 - confidence)/ 2.0)) +} + +object ConfidenceIntervalStrategy { + val defaultConfidence = 0.95 + val defaultIntervalStrategy: ConfidenceIntervalStrategy = WaldIntervalStrategy() + + case class ConfidenceInterval(lowerBound: Double, upperBound: Double) +} + + diff --git a/src/main/scala/com/amazon/deequ/suggestions/rules/interval/WaldIntervalStrategy.scala b/src/main/scala/com/amazon/deequ/suggestions/rules/interval/WaldIntervalStrategy.scala new file mode 100644 index 000000000..154d8ebfe --- /dev/null +++ b/src/main/scala/com/amazon/deequ/suggestions/rules/interval/WaldIntervalStrategy.scala @@ -0,0 +1,47 @@ +/** + * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not + * use this file except in compliance with the License. 
A copy of the License + * is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + * + */ + +package com.amazon.deequ.suggestions.rules.interval + +import com.amazon.deequ.suggestions.rules.interval.ConfidenceIntervalStrategy.ConfidenceInterval +import com.amazon.deequ.suggestions.rules.interval.ConfidenceIntervalStrategy.defaultConfidence + +import scala.math.BigDecimal.RoundingMode + +/** + * Implements the Wald Interval method for creating a binomial proportion confidence interval. Provided for backwards + * compatibility. using [[WaldIntervalStrategy]] for calculating confidence interval can be problematic when dealing + * with small sample sizes or proportions close to 0 or 1. It also have poorer coverage and might produce confidence + * limit outside the range of [0,1] + * @see + * Normal approximation interval (Wikipedia) + */ +@deprecated("WilsonScoreIntervalStrategy is recommended for calculating confidence interval") +case class WaldIntervalStrategy() extends ConfidenceIntervalStrategy { + def calculateTargetConfidenceInterval( + pHat: Double, + numRecords: Long, + confidence: Double = defaultConfidence + ): ConfidenceInterval = { + validateInput(pHat, confidence) + val successRatio = BigDecimal(pHat) + val marginOfError = BigDecimal(calculateZScore(confidence) * math.sqrt(pHat * (1 - pHat) / numRecords)) + val lowerBound = (successRatio - marginOfError).setScale(2, RoundingMode.DOWN).toDouble + val upperBound = (successRatio + marginOfError).setScale(2, RoundingMode.UP).toDouble + ConfidenceInterval(lowerBound, upperBound) + } +} diff --git a/src/main/scala/com/amazon/deequ/suggestions/rules/interval/WilsonScoreIntervalStrategy.scala b/src/main/scala/com/amazon/deequ/suggestions/rules/interval/WilsonScoreIntervalStrategy.scala new file mode 100644 index 000000000..6e8371ea5 --- /dev/null +++ b/src/main/scala/com/amazon/deequ/suggestions/rules/interval/WilsonScoreIntervalStrategy.scala @@ -0,0 +1,47 @@ +/** + * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not + * use this file except in compliance with the License. A copy of the License + * is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + * + */ + +package com.amazon.deequ.suggestions.rules.interval + +import com.amazon.deequ.suggestions.rules.interval.ConfidenceIntervalStrategy.ConfidenceInterval +import com.amazon.deequ.suggestions.rules.interval.ConfidenceIntervalStrategy.defaultConfidence + +import scala.math.BigDecimal.RoundingMode + +/** + * Using Wilson score method for creating a binomial proportion confidence interval. 
+ * + * @see + * Wilson score interval (Wikipedia) + */ +case class WilsonScoreIntervalStrategy() extends ConfidenceIntervalStrategy { + + def calculateTargetConfidenceInterval( + pHat: Double, numRecords: Long, + confidence: Double = defaultConfidence + ): ConfidenceInterval = { + validateInput(pHat, confidence) + val zScore = calculateZScore(confidence) + val zSquareOverN = math.pow(zScore, 2) / numRecords + val factor = 1.0 / (1 + zSquareOverN) + val adjustedSuccessRatio = pHat + zSquareOverN/2 + val marginOfError = zScore * math.sqrt(pHat * (1 - pHat)/numRecords + zSquareOverN/(4 * numRecords)) + val lowerBound = BigDecimal(factor * (adjustedSuccessRatio - marginOfError)).setScale(2, RoundingMode.DOWN).toDouble + val upperBound = BigDecimal(factor * (adjustedSuccessRatio + marginOfError)).setScale(2, RoundingMode.UP).toDouble + ConfidenceInterval(lowerBound, upperBound) + } +} diff --git a/src/test/scala/com/amazon/deequ/suggestions/rules/ConstraintRulesTest.scala b/src/test/scala/com/amazon/deequ/suggestions/rules/ConstraintRulesTest.scala index 701a5d983..7b56e3938 100644 --- a/src/test/scala/com/amazon/deequ/suggestions/rules/ConstraintRulesTest.scala +++ b/src/test/scala/com/amazon/deequ/suggestions/rules/ConstraintRulesTest.scala @@ -22,10 +22,14 @@ import com.amazon.deequ.checks.{Check, CheckLevel} import com.amazon.deequ.constraints.ConstrainableDataTypes import com.amazon.deequ.metrics.{Distribution, DistributionValue} import com.amazon.deequ.profiles._ +import com.amazon.deequ.suggestions.rules.interval.WaldIntervalStrategy +import com.amazon.deequ.suggestions.rules.interval.WilsonScoreIntervalStrategy import com.amazon.deequ.utils.FixtureSupport import com.amazon.deequ.{SparkContextSpec, VerificationSuite} import org.scalamock.scalatest.MockFactory +import org.scalatest.Inspectors.forAll import org.scalatest.WordSpec +import org.scalatest.prop.Tables.Table class ConstraintRulesTest extends WordSpec with FixtureSupport with SparkContextSpec with MockFactory{ @@ -132,6 +136,7 @@ class ConstraintRulesTest extends WordSpec with FixtureSupport with SparkContext val complete = StandardColumnProfile("col1", 1.0, 100, String, false, Map.empty, None) val tenPercent = StandardColumnProfile("col1", 0.1, 100, String, false, Map.empty, None) val incomplete = StandardColumnProfile("col1", .25, 100, String, false, Map.empty, None) + val waldIntervalStrategy = WaldIntervalStrategy() assert(!RetainCompletenessRule().shouldBeApplied(complete, 1000)) assert(!RetainCompletenessRule(0.05, 0.9).shouldBeApplied(complete, 1000)) @@ -139,74 +144,92 @@ class ConstraintRulesTest extends WordSpec with FixtureSupport with SparkContext assert(RetainCompletenessRule(0.0).shouldBeApplied(tenPercent, 1000)) assert(RetainCompletenessRule(0.0).shouldBeApplied(incomplete, 1000)) assert(RetainCompletenessRule().shouldBeApplied(incomplete, 1000)) + assert(!RetainCompletenessRule(intervalStrategy = waldIntervalStrategy).shouldBeApplied(complete, 1000)) + assert(!RetainCompletenessRule(0.05, 0.9, waldIntervalStrategy).shouldBeApplied(complete, 1000)) + assert(RetainCompletenessRule(0.05, 0.9, waldIntervalStrategy).shouldBeApplied(tenPercent, 1000)) } "return evaluable constraint candidates" in withSparkSession { session => + val table = Table(("strategy", "result"), (WaldIntervalStrategy(), true), (WilsonScoreIntervalStrategy(), true)) + forAll(table) { case (strategy, result) => + val dfWithColumnCandidate = getDfFull(session) - val dfWithColumnCandidate = getDfFull(session) + val fakeColumnProfile = 
getFakeColumnProfileWithNameAndCompleteness("att1", 0.5) - val fakeColumnProfile = getFakeColumnProfileWithNameAndCompleteness("att1", 0.5) + val check = Check(CheckLevel.Warning, "some") + .addConstraint( + RetainCompletenessRule(intervalStrategy = strategy).candidate(fakeColumnProfile, 100).constraint + ) - val check = Check(CheckLevel.Warning, "some") - .addConstraint(RetainCompletenessRule().candidate(fakeColumnProfile, 100).constraint) + val verificationResult = VerificationSuite() + .onData(dfWithColumnCandidate) + .addCheck(check) + .run() - val verificationResult = VerificationSuite() - .onData(dfWithColumnCandidate) - .addCheck(check) - .run() + val metricResult = verificationResult.metrics.head._2 - val metricResult = verificationResult.metrics.head._2 + assert(metricResult.value.isSuccess == result) + } - assert(metricResult.value.isSuccess) } "return working code to add constraint to check" in withSparkSession { session => + val table = Table( + ("strategy", "colCompleteness", "targetCompleteness", "result"), + (WaldIntervalStrategy(), 0.5, 0.4, true), + (WilsonScoreIntervalStrategy(), 0.4, 0.3, true) + ) + forAll(table) { case (strategy, colCompleteness, targetCompleteness, result) => - val dfWithColumnCandidate = getDfFull(session) + val dfWithColumnCandidate = getDfFull(session) - val fakeColumnProfile = getFakeColumnProfileWithNameAndCompleteness("att1", 0.5) + val fakeColumnProfile = getFakeColumnProfileWithNameAndCompleteness("att1", colCompleteness) - val codeForConstraint = RetainCompletenessRule().candidate(fakeColumnProfile, 100) - .codeForConstraint + val codeForConstraint = RetainCompletenessRule(intervalStrategy = strategy).candidate(fakeColumnProfile, 100) + .codeForConstraint - val expectedCodeForConstraint = """.hasCompleteness("att1", _ >= 0.4, - | Some("It should be above 0.4!"))""".stripMargin.replaceAll("\n", "") + val expectedCodeForConstraint = s""".hasCompleteness("att1", _ >= $targetCompleteness, + | Some("It should be above $targetCompleteness!"))""".stripMargin.replaceAll("\n", "") - assert(expectedCodeForConstraint == codeForConstraint) + assert(expectedCodeForConstraint == codeForConstraint) - val check = Check(CheckLevel.Warning, "some") - .hasCompleteness("att1", _ >= 0.4, Some("It should be above 0.4!")) + val check = Check(CheckLevel.Warning, "some") + .hasCompleteness("att1", _ >= targetCompleteness, Some(s"It should be above $targetCompleteness")) - val verificationResult = VerificationSuite() - .onData(dfWithColumnCandidate) - .addCheck(check) - .run() + val verificationResult = VerificationSuite() + .onData(dfWithColumnCandidate) + .addCheck(check) + .run() - val metricResult = verificationResult.metrics.head._2 + val metricResult = verificationResult.metrics.head._2 + + assert(metricResult.value.isSuccess == result) + } - assert(metricResult.value.isSuccess) } "return evaluable constraint candidates with custom min/max completeness" in withSparkSession { session => + val table = Table(("strategy", "result"), (WaldIntervalStrategy(), true), (WilsonScoreIntervalStrategy(), true)) + forAll(table) { case (strategy, result) => + val dfWithColumnCandidate = getDfFull(session) - val dfWithColumnCandidate = getDfFull(session) - - val fakeColumnProfile = getFakeColumnProfileWithNameAndCompleteness("att1", 0.5) + val fakeColumnProfile = getFakeColumnProfileWithNameAndCompleteness("att1", 0.5) - val check = Check(CheckLevel.Warning, "some") - .addConstraint(RetainCompletenessRule(0.4, 0.6).candidate(fakeColumnProfile, 100).constraint) + val check = 
Check(CheckLevel.Warning, "some") + .addConstraint(RetainCompletenessRule(0.4, 0.6, strategy).candidate(fakeColumnProfile, 100).constraint) - val verificationResult = VerificationSuite() - .onData(dfWithColumnCandidate) - .addCheck(check) - .run() + val verificationResult = VerificationSuite() + .onData(dfWithColumnCandidate) + .addCheck(check) + .run() - val metricResult = verificationResult.metrics.head._2 + val metricResult = verificationResult.metrics.head._2 - assert(metricResult.value.isSuccess) + assert(metricResult.value.isSuccess == result) + } } } diff --git a/src/test/scala/com/amazon/deequ/suggestions/rules/interval/IntervalStrategyTest.scala b/src/test/scala/com/amazon/deequ/suggestions/rules/interval/IntervalStrategyTest.scala new file mode 100644 index 000000000..54e6cd1e1 --- /dev/null +++ b/src/test/scala/com/amazon/deequ/suggestions/rules/interval/IntervalStrategyTest.scala @@ -0,0 +1,59 @@ +/** + * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not + * use this file except in compliance with the License. A copy of the License + * is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + * + */ + +package com.amazon.deequ.suggestions.rules.interval + +import com.amazon.deequ.SparkContextSpec +import com.amazon.deequ.suggestions.rules.interval.ConfidenceIntervalStrategy.ConfidenceInterval +import com.amazon.deequ.utils.FixtureSupport +import org.scalamock.scalatest.MockFactory +import org.scalatest.Inspectors.forAll +import org.scalatest.prop.Tables.Table +import org.scalatest.wordspec.AnyWordSpec + +class IntervalStrategyTest extends AnyWordSpec with FixtureSupport with SparkContextSpec + with MockFactory { + + "ConfidenceIntervalStrategy" should { + "be calculated correctly" in { + val waldStrategy = WaldIntervalStrategy() + val wilsonStrategy = WilsonScoreIntervalStrategy() + + val table = Table( + ("strategy", "pHat", "numRecord", "lowerBound", "upperBound"), + (waldStrategy, 1.0, 20L, 1.0, 1.0), + (waldStrategy, 0.5, 100L, 0.4, 0.6), + (waldStrategy, 0.4, 100L, 0.3, 0.5), + (waldStrategy, 0.6, 100L, 0.5, 0.7), + (waldStrategy, 0.9, 100L, 0.84, 0.96), + (waldStrategy, 1.0, 100L, 1.0, 1.0), + + (wilsonStrategy, 0.01, 20L, 0.00, 0.18), + (wilsonStrategy, 1.0, 20L, 0.83, 1.0), + (wilsonStrategy, 0.5, 100L, 0.4, 0.6), + (wilsonStrategy, 0.4, 100L, 0.3, 0.5), + (wilsonStrategy, 0.6, 100L, 0.5, 0.7), + (wilsonStrategy, 0.9, 100L, 0.82, 0.95), + (wilsonStrategy, 1.0, 100L, 0.96, 1.0) + ) + + forAll(table) { case (strategy, pHat, numRecords, lowerBound, upperBound) => + val actualInterval = strategy.calculateTargetConfidenceInterval(pHat, numRecords) + assert(actualInterval == ConfidenceInterval(lowerBound, upperBound)) + } + } + } +} From ee26d1c447309ab6d1bfc0e7ce648729eccda91d Mon Sep 17 00:00:00 2001 From: Joshua Zexter <67130377+joshuazexter@users.noreply.github.com> Date: Wed, 31 Jul 2024 13:35:10 -0400 Subject: [PATCH 19/24] CustomAggregator (#572) * Add support for EntityTypes dqdl rule * Add support for Conditional Aggregation Analyzer --------- Co-authored-by: Joshua Zexter --- .../deequ/analyzers/CustomAggregator.scala | 69 +++++ .../com/amazon/deequ/metrics/Metric.scala | 17 ++ 
.../analyzers/CustomAggregatorTest.scala | 244 ++++++++++++++++++ 3 files changed, 330 insertions(+) create mode 100644 src/main/scala/com/amazon/deequ/analyzers/CustomAggregator.scala create mode 100644 src/test/scala/com/amazon/deequ/analyzers/CustomAggregatorTest.scala diff --git a/src/main/scala/com/amazon/deequ/analyzers/CustomAggregator.scala b/src/main/scala/com/amazon/deequ/analyzers/CustomAggregator.scala new file mode 100644 index 000000000..d82c09312 --- /dev/null +++ b/src/main/scala/com/amazon/deequ/analyzers/CustomAggregator.scala @@ -0,0 +1,69 @@ +/** + * Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"). You may not + * use this file except in compliance with the License. A copy of the License + * is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + * + */ +package com.amazon.deequ.analyzers + +import com.amazon.deequ.metrics.AttributeDoubleMetric +import com.amazon.deequ.metrics.Entity +import org.apache.spark.sql.DataFrame + +import scala.util.Failure +import scala.util.Success +import scala.util.Try + +// Define a custom state to hold aggregation results +case class AggregatedMetricState(counts: Map[String, Int], total: Int) + extends DoubleValuedState[AggregatedMetricState] { + + override def sum(other: AggregatedMetricState): AggregatedMetricState = { + val combinedCounts = counts ++ other + .counts + .map { case (k, v) => k -> (v + counts.getOrElse(k, 0)) } + AggregatedMetricState(combinedCounts, total + other.total) + } + + override def metricValue(): Double = counts.values.sum.toDouble / total +} + +// Define the analyzer +case class CustomAggregator(aggregatorFunc: DataFrame => AggregatedMetricState, + metricName: String, + instance: String = "Dataset") + extends Analyzer[AggregatedMetricState, AttributeDoubleMetric] { + + override def computeStateFrom(data: DataFrame, filterCondition: Option[String] = None) + : Option[AggregatedMetricState] = { + Try(aggregatorFunc(data)) match { + case Success(state) => Some(state) + case Failure(_) => None + } + } + + override def computeMetricFrom(state: Option[AggregatedMetricState]): AttributeDoubleMetric = { + state match { + case Some(detState) => + val metrics = detState.counts.map { case (key, count) => + key -> (count.toDouble / detState.total) + } + AttributeDoubleMetric(Entity.Column, metricName, instance, Success(metrics)) + case None => + AttributeDoubleMetric(Entity.Column, metricName, instance, + Failure(new RuntimeException("Metric computation failed"))) + } + } + + override private[deequ] def toFailureMetric(failure: Exception): AttributeDoubleMetric = { + AttributeDoubleMetric(Entity.Column, metricName, instance, Failure(failure)) + } +} diff --git a/src/main/scala/com/amazon/deequ/metrics/Metric.scala b/src/main/scala/com/amazon/deequ/metrics/Metric.scala index 30225e246..307b278d1 100644 --- a/src/main/scala/com/amazon/deequ/metrics/Metric.scala +++ b/src/main/scala/com/amazon/deequ/metrics/Metric.scala @@ -89,3 +89,20 @@ case class KeyedDoubleMetric( } } } + +case class AttributeDoubleMetric( + entity: Entity.Value, + name: String, + instance: String, + value: Try[Map[String, Double]]) + extends Metric[Map[String, Double]] { + + 
override def flatten(): Seq[DoubleMetric] = { + value match { + case Success(valuesMap) => valuesMap.map { case (key, metricValue) => + DoubleMetric(entity, s"$name.$key", instance, Success(metricValue)) + }.toSeq + case Failure(ex) => Seq(DoubleMetric(entity, name, instance, Failure(ex))) + } + } +} diff --git a/src/test/scala/com/amazon/deequ/analyzers/CustomAggregatorTest.scala b/src/test/scala/com/amazon/deequ/analyzers/CustomAggregatorTest.scala new file mode 100644 index 000000000..4f21cc64e --- /dev/null +++ b/src/test/scala/com/amazon/deequ/analyzers/CustomAggregatorTest.scala @@ -0,0 +1,244 @@ +/** + * Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not + * use this file except in compliance with the License. A copy of the License + * is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + * + */ +package com.amazon.deequ.analyzers + +import com.amazon.deequ.SparkContextSpec +import com.amazon.deequ.utils.FixtureSupport +import org.scalatest.matchers.should.Matchers +import org.scalatest.wordspec.AnyWordSpec +import com.amazon.deequ.analyzers._ +import com.amazon.deequ.metrics.AttributeDoubleMetric +import com.amazon.deequ.profiles.ColumnProfilerRunner +import com.amazon.deequ.utils.FixtureSupport +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.functions.{sum, count} +import scala.util.Failure +import scala.util.Success +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.DataFrame +import com.amazon.deequ.metrics.AttributeDoubleMetric +import com.amazon.deequ.profiles.NumericColumnProfile + +class CustomAggregatorTest + extends AnyWordSpec with Matchers with SparkContextSpec with FixtureSupport { + + "CustomAggregatorTest" should { + + """Example use: return correct counts + |for product sales in different categories""".stripMargin in withSparkSession + { session => + val data = getDfWithIdColumn(session) + val mockLambda: DataFrame => AggregatedMetricState = _ => + AggregatedMetricState(Map("ProductA" -> 50, "ProductB" -> 45), 100) + + val analyzer = CustomAggregator(mockLambda, "ProductSales", "category") + + val state = analyzer.computeStateFrom(data) + val metric: AttributeDoubleMetric = analyzer.computeMetricFrom(state) + + metric.value.isSuccess shouldBe true + metric.value.get should contain ("ProductA" -> 0.5) + metric.value.get should contain ("ProductB" -> 0.45) + } + + "handle scenarios with no data points effectively" in withSparkSession { session => + val data = getDfWithIdColumn(session) + val mockLambda: DataFrame => AggregatedMetricState = _ => + AggregatedMetricState(Map.empty[String, Int], 100) + + val analyzer = CustomAggregator(mockLambda, "WebsiteTraffic", "page") + + val state = analyzer.computeStateFrom(data) + val metric: AttributeDoubleMetric = analyzer.computeMetricFrom(state) + + metric.value.isSuccess shouldBe true + metric.value.get shouldBe empty + } + + "return a failure metric when the lambda function fails" in withSparkSession { session => + val data = getDfWithIdColumn(session) + val failingLambda: DataFrame => AggregatedMetricState = + _ => throw new RuntimeException("Test failure") + + val analyzer = 
CustomAggregator(failingLambda, "ProductSales", "category") + + val state = analyzer.computeStateFrom(data) + val metric = analyzer.computeMetricFrom(state) + + metric.value.isFailure shouldBe true + metric.value match { + case Success(_) => fail("Should have failed due to lambda function failure") + case Failure(exception) => exception.getMessage shouldBe "Metric computation failed" + } + } + + "return a failure metric if there are no rows in DataFrame" in withSparkSession { session => + val emptyData = session.createDataFrame( + session.sparkContext.emptyRDD[org.apache.spark.sql.Row], + getDfWithIdColumn(session).schema) + val mockLambda: DataFrame => AggregatedMetricState = df => + if (df.isEmpty) throw new RuntimeException("No data to analyze") + else AggregatedMetricState(Map("ProductA" -> 0, "ProductB" -> 0), 0) + + val analyzer = CustomAggregator(mockLambda, + "ProductSales", + "category") + + val state = analyzer.computeStateFrom(emptyData) + val metric = analyzer.computeMetricFrom(state) + + metric.value.isFailure shouldBe true + metric.value match { + case Success(_) => fail("Should have failed due to no data") + case Failure(exception) => exception.getMessage should include("Metric computation failed") + } + } + } + + "Combined Analysis with CustomAggregator and ColumnProfilerRunner" should { + "provide aggregated data and column profiles" in withSparkSession { session => + import session.implicits._ + + // Define the dataset + val rawData = Seq( + ("thingA", "13.0", "IN_TRANSIT", "true"), + ("thingA", "5", "DELAYED", "false"), + ("thingB", null, "DELAYED", null), + ("thingC", null, "IN_TRANSIT", "false"), + ("thingD", "1.0", "DELAYED", "true"), + ("thingC", "7.0", "UNKNOWN", null), + ("thingC", "20", "UNKNOWN", null), + ("thingE", "20", "DELAYED", "false") + ).toDF("productName", "totalNumber", "status", "valuable") + + val statusCountLambda: DataFrame => AggregatedMetricState = df => + AggregatedMetricState(df.groupBy("status").count().rdd + .map(r => r.getString(0) -> r.getLong(1).toInt).collect().toMap, df.count().toInt) + + val statusAnalyzer = CustomAggregator(statusCountLambda, "ProductStatus") + val statusMetric = statusAnalyzer.computeMetricFrom(statusAnalyzer.computeStateFrom(rawData)) + + val result = ColumnProfilerRunner().onData(rawData).run() + + statusMetric.value.isSuccess shouldBe true + statusMetric.value.get("IN_TRANSIT") shouldBe 0.25 + statusMetric.value.get("DELAYED") shouldBe 0.5 + + val totalNumberProfile = result.profiles("totalNumber").asInstanceOf[NumericColumnProfile] + totalNumberProfile.completeness shouldBe 0.75 + totalNumberProfile.dataType shouldBe DataTypeInstances.Fractional + + result.profiles.foreach { case (colName, profile) => + println(s"Column '$colName': completeness: ${profile.completeness}, " + + s"approximate number of distinct values: ${profile.approximateNumDistinctValues}") + } + } + } + + "accurately compute percentage occurrences and total engagements for content types" in withSparkSession { session => + val data = getContentEngagementDataFrame(session) + val contentEngagementLambda: DataFrame => AggregatedMetricState = df => { + + // Calculate the total engagements for each content type + val counts = df + .groupBy("content_type") + .agg( + (sum("views") + sum("likes") + sum("shares")).cast("int").alias("totalEngagements") + ) + .collect() + .map(row => + row.getString(0) -> row.getInt(1) + ) + .toMap + val totalEngagements = counts.values.sum + AggregatedMetricState(counts, totalEngagements) + } + + val analyzer = 
CustomAggregator(contentEngagementLambda, "ContentEngagement", "AllTypes") + + val state = analyzer.computeStateFrom(data) + val metric = analyzer.computeMetricFrom(state) + + metric.value.isSuccess shouldBe true + // Counts: Map(Video -> 5300, Article -> 1170) + // total engagement: 6470 + (metric.value.get("Video") * 100).toInt shouldBe 81 + (metric.value.get("Article") * 100).toInt shouldBe 18 + println(metric.value.get) + } + + "accurately compute total aggregated resources for cloud services" in withSparkSession { session => + val data = getResourceUtilizationDataFrame(session) + val resourceUtilizationLambda: DataFrame => AggregatedMetricState = df => { + val counts = df.groupBy("service_type") + .agg( + (sum("cpu_hours") + sum("memory_gbs") + sum("storage_gbs")).cast("int").alias("totalResources") + ) + .collect() + .map(row => + row.getString(0) -> row.getInt(1) + ) + .toMap + val totalResources = counts.values.sum + AggregatedMetricState(counts, totalResources) + } + val analyzer = CustomAggregator(resourceUtilizationLambda, "ResourceUtilization", "CloudServices") + + val state = analyzer.computeStateFrom(data) + val metric = analyzer.computeMetricFrom(state) + + metric.value.isSuccess shouldBe true + println("Resource Utilization Metrics: " + metric.value.get) +// Resource Utilization Metrics: Map(Compute -> 0.5076142131979695, + // Database -> 0.27918781725888325, + // Storage -> 0.2131979695431472) + (metric.value.get("Compute") * 100).toInt shouldBe 50 // Expected percentage for Compute + (metric.value.get("Database") * 100).toInt shouldBe 27 // Expected percentage for Database + (metric.value.get("Storage") * 100).toInt shouldBe 21 // 430 CPU + 175 Memory + 140 Storage from mock data + } + + def getDfWithIdColumn(session: SparkSession): DataFrame = { + import session.implicits._ + Seq( + ("ProductA", "North"), + ("ProductA", "South"), + ("ProductB", "East"), + ("ProductA", "West") + ).toDF("product", "region") + } + + def getContentEngagementDataFrame(session: SparkSession): DataFrame = { + import session.implicits._ + Seq( + ("Video", 1000, 150, 300), + ("Article", 500, 100, 150), + ("Video", 1500, 200, 450), + ("Article", 300, 50, 70), + ("Video", 1200, 180, 320) + ).toDF("content_type", "views", "likes", "shares") + } + + def getResourceUtilizationDataFrame(session: SparkSession): DataFrame = { + import session.implicits._ + Seq( + ("Compute", 400, 120, 150), + ("Storage", 100, 30, 500), + ("Database", 200, 80, 100), + ("Compute", 450, 130, 250), + ("Database", 230, 95, 120) + ).toDF("service_type", "cpu_hours", "memory_gbs", "storage_gbs") + } +} From 97f7a3ede2ed36cc3e5ea562d285774dfb5aea30 Mon Sep 17 00:00:00 2001 From: bojackli <478378663@qq.com> Date: Thu, 29 Aug 2024 22:52:25 +0800 Subject: [PATCH 20/24] fix typo (#574) --- .../com/amazon/deequ/suggestions/ConstraintSuggestion.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/main/scala/com/amazon/deequ/suggestions/ConstraintSuggestion.scala b/src/main/scala/com/amazon/deequ/suggestions/ConstraintSuggestion.scala index 57ac9aeaa..96b7ab9b2 100644 --- a/src/main/scala/com/amazon/deequ/suggestions/ConstraintSuggestion.scala +++ b/src/main/scala/com/amazon/deequ/suggestions/ConstraintSuggestion.scala @@ -52,7 +52,7 @@ case class ConstraintSuggestionWithValue[T]( object ConstraintSuggestions { - private[this] val CONSTRANT_SUGGESTIONS_FIELD = "constraint_suggestions" + private[this] val CONSTRAINT_SUGGESTIONS_FIELD = "constraint_suggestions" private[suggestions] def 
toJson(constraintSuggestions: Seq[ConstraintSuggestion]): String = { @@ -68,7 +68,7 @@ object ConstraintSuggestions { constraintsJson.add(constraintJson) } - json.add(CONSTRANT_SUGGESTIONS_FIELD, constraintsJson) + json.add(CONSTRAINT_SUGGESTIONS_FIELD, constraintsJson) val gson = new GsonBuilder() .setPrettyPrinting() @@ -109,7 +109,7 @@ object ConstraintSuggestions { constraintEvaluations.add(constraintEvaluation) } - json.add(CONSTRANT_SUGGESTIONS_FIELD, constraintEvaluations) + json.add(CONSTRAINT_SUGGESTIONS_FIELD, constraintEvaluations) val gson = new GsonBuilder() .setPrettyPrinting() From 9d92d94c12056ea0bfe8cdba1c5c7d8a1a5b2db7 Mon Sep 17 00:00:00 2001 From: Josh <5685731+marcantony@users.noreply.github.com> Date: Sat, 31 Aug 2024 12:40:11 -0400 Subject: [PATCH 21/24] Fix performance of building row-level results (#577) * Generate row-level results with withColumns Iteratively using withColumn (singular) causes performance issues when iterating over a large sequence of columns. * Add back UNIQUENESS_ID --- src/main/scala/com/amazon/deequ/VerificationResult.scala | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/main/scala/com/amazon/deequ/VerificationResult.scala b/src/main/scala/com/amazon/deequ/VerificationResult.scala index 6390db821..418a622e6 100644 --- a/src/main/scala/com/amazon/deequ/VerificationResult.scala +++ b/src/main/scala/com/amazon/deequ/VerificationResult.scala @@ -98,9 +98,7 @@ object VerificationResult { val columnNamesToMetrics: Map[String, Column] = verificationResultToColumn(verificationResult) val dataWithID = data.withColumn(UNIQUENESS_ID, monotonically_increasing_id()) - columnNamesToMetrics.foldLeft(dataWithID)( - (dataWithID, newColumn: (String, Column)) => - dataWithID.withColumn(newColumn._1, newColumn._2)).drop(UNIQUENESS_ID) + dataWithID.withColumns(columnNamesToMetrics).drop(UNIQUENESS_ID) } def checkResultsAsJson(verificationResult: VerificationResult, From fdebce5bbc84197e0f53261cee27abcdb48977aa Mon Sep 17 00:00:00 2001 From: Hubert Date: Mon, 4 Nov 2024 09:17:20 -0500 Subject: [PATCH 22/24] updating anomaly check bounds to not have defaults and require inputs for the bound value and isThresholdInclusive, also adding anomaly detection with extended results README, and adding anomaly detection test with 2 anomaly checks on the same suite --- .../anomalydetection/BaseChangeStrategy.scala | 3 +- .../BatchNormalStrategy.scala | 3 +- .../ExtendedDetectionResult.scala | 48 ++++----- .../OnlineNormalStrategy.scala | 4 +- .../SimpleThresholdStrategy.scala | 4 +- .../seasonal/HoltWinters.scala | 14 ++- .../scala/com/amazon/deequ/checks/Check.scala | 24 +++-- .../AnomalyExtendedResultsConstraint.scala | 8 +- ...yDetectionWithExtendedResultsExample.scala | 7 +- .../examples/anomaly_detection_example.md | 2 + ...detection_with_extended_results_example.md | 75 +++++++++++++ .../amazon/deequ/VerificationSuiteTest.scala | 86 +++++++++++++-- .../AbsoluteChangeStrategyTest.scala | 94 ++++++++-------- .../AnomalyDetectorTest.scala | 54 ++++++---- .../BatchNormalStrategyTest.scala | 56 +++++----- .../OnlineNormalStrategyTest.scala | 91 ++++++++++------ .../RateOfChangeStrategyTest.scala | 5 +- .../RelativeRateOfChangeStrategyTest.scala | 88 ++++++++------- .../SimpleThresholdStrategyTest.scala | 22 ++-- .../com/amazon/deequ/checks/CheckTest.scala | 102 +++++++++--------- ...AnomalyExtendedResultsConstraintTest.scala | 85 +++++++++------ .../deequ/constraints/ConstraintsTest.scala | 46 ++++---- 22 files changed, 584 insertions(+), 337 
deletions(-) create mode 100644 src/main/scala/com/amazon/deequ/examples/anomaly_detection_with_extended_results_example.md diff --git a/src/main/scala/com/amazon/deequ/anomalydetection/BaseChangeStrategy.scala b/src/main/scala/com/amazon/deequ/anomalydetection/BaseChangeStrategy.scala index 0ac353223..2d0cf3948 100644 --- a/src/main/scala/com/amazon/deequ/anomalydetection/BaseChangeStrategy.scala +++ b/src/main/scala/com/amazon/deequ/anomalydetection/BaseChangeStrategy.scala @@ -131,7 +131,8 @@ trait BaseChangeStrategy (None, false) } (outputSequenceIndex, AnomalyDetectionDataPoint(value, change, - Threshold(lowerBound = Bound(lowerBound), upperBound = Bound(upperBound)), isAnomaly, 1.0, detail)) + BoundedRange(lowerBound = Bound(lowerBound, inclusive = true), + upperBound = Bound(upperBound, inclusive = true)), isAnomaly, 1.0, detail)) } } } diff --git a/src/main/scala/com/amazon/deequ/anomalydetection/BatchNormalStrategy.scala b/src/main/scala/com/amazon/deequ/anomalydetection/BatchNormalStrategy.scala index 7d4bb6304..41a7bad43 100644 --- a/src/main/scala/com/amazon/deequ/anomalydetection/BatchNormalStrategy.scala +++ b/src/main/scala/com/amazon/deequ/anomalydetection/BatchNormalStrategy.scala @@ -113,7 +113,8 @@ case class BatchNormalStrategy( (None, false) } (index, AnomalyDetectionDataPoint(value, value, - Threshold(lowerBound = Bound(lowerBound), upperBound = Bound(upperBound)), isAnomaly, 1.0, detail)) + BoundedRange(lowerBound = Bound(lowerBound, inclusive = true), + upperBound = Bound(upperBound, inclusive = true)), isAnomaly, 1.0, detail)) } } diff --git a/src/main/scala/com/amazon/deequ/anomalydetection/ExtendedDetectionResult.scala b/src/main/scala/com/amazon/deequ/anomalydetection/ExtendedDetectionResult.scala index 7e024b2cf..c966d738a 100644 --- a/src/main/scala/com/amazon/deequ/anomalydetection/ExtendedDetectionResult.scala +++ b/src/main/scala/com/amazon/deequ/anomalydetection/ExtendedDetectionResult.scala @@ -26,26 +26,26 @@ package com.amazon.deequ.anomalydetection * Anomaly Detection Data Point class * This class is different from the Anomaly Class in that this class * wraps around all data points, not just anomalies, and provides extended results including - * if the data point is an anomaly, and the thresholds used in the anomaly calculation. + * if the data point is an anomaly, and the range with bounds used in the anomaly calculation. * * @param dataMetricValue The metric value that is the data point. * @param anomalyMetricValue The metric value that is being used in the anomaly calculation. * This usually aligns with dataMetricValue but not always, * like in a rate of change strategy where the rate of change is the anomaly metric * which may not equal the actual data point value. - * @param anomalyThreshold The thresholds used in the anomaly check, the anomalyMetricValue is - * compared to this threshold. + * @param anomalyCheckRange The range of bounds used in the anomaly check, the anomalyMetricValue is + * compared to this range. * @param isAnomaly If the data point is an anomaly. * @param confidence Confidence of anomaly detection. * @param detail Detailed error message. 
*/ class AnomalyDetectionDataPoint( - val dataMetricValue: Double, - val anomalyMetricValue: Double, - val anomalyThreshold: Threshold, - val isAnomaly: Boolean, - val confidence: Double, - val detail: Option[String]) + val dataMetricValue: Double, + val anomalyMetricValue: Double, + val anomalyCheckRange: BoundedRange, + val isAnomaly: Boolean, + val confidence: Double, + val detail: Option[String]) { def canEqual(that: Any): Boolean = { @@ -64,7 +64,7 @@ class AnomalyDetectionDataPoint( case anomaly: AnomalyDetectionDataPoint => anomaly.dataMetricValue == dataMetricValue && anomaly.anomalyMetricValue == anomalyMetricValue && - anomaly.anomalyThreshold == anomalyThreshold && + anomaly.anomalyCheckRange == anomalyCheckRange && anomaly.isAnomaly == isAnomaly && anomaly.confidence == confidence case _ => false @@ -76,7 +76,7 @@ class AnomalyDetectionDataPoint( var result = 1 result = prime * result + dataMetricValue.hashCode() result = prime * result + anomalyMetricValue.hashCode() - result = prime * result + anomalyThreshold.hashCode() + result = prime * result + anomalyCheckRange.hashCode() result = prime * result + isAnomaly.hashCode() result = prime * result + confidence.hashCode() result @@ -86,21 +86,21 @@ class AnomalyDetectionDataPoint( object AnomalyDetectionDataPoint { def apply(dataMetricValue: Double, anomalyMetricValue: Double, - anomalyThreshold: Threshold = Threshold(), isAnomaly: Boolean = false, + anomalyCheckRange: BoundedRange, isAnomaly: Boolean, confidence: Double, detail: Option[String] = None ): AnomalyDetectionDataPoint = { - new AnomalyDetectionDataPoint(dataMetricValue, anomalyMetricValue, anomalyThreshold, isAnomaly, confidence, detail) + new AnomalyDetectionDataPoint(dataMetricValue, anomalyMetricValue, anomalyCheckRange, isAnomaly, confidence, detail) } } /** - * Threshold class - * Defines threshold for the anomaly detection, defaults to inclusive bounds of Double.Min and Double.Max. + * BoundedRange class + * Defines range for the anomaly detection. * @param upperBound The upper bound or threshold. * @param lowerBound The lower bound or threshold. */ -case class Threshold(lowerBound: Bound = Bound(Double.MinValue), upperBound: Bound = Bound(Double.MaxValue)) +case class BoundedRange(lowerBound: Bound, upperBound: Bound) /** * Bound Class @@ -108,7 +108,7 @@ case class Threshold(lowerBound: Bound = Bound(Double.MinValue), upperBound: Bou * @param value The value of the bound as a Double. * @param inclusive Boolean indicating if the Bound is inclusive or not. */ -case class Bound(value: Double, inclusive: Boolean = true) +case class Bound(value: Double, inclusive: Boolean) @@ -123,23 +123,21 @@ case class ExtendedDetectionResult(anomalyDetectionDataPointSequence: /** * AnomalyDetectionExtendedResult Class - * This class contains anomaly detection extended results through a Sequence of AnomalyDetectionDataPoints. + * This class contains anomaly detection extended results through an AnomalyDetectionDataPoint. * This is currently an optional field in the ConstraintResult class that is exposed to users. * * Currently, anomaly detection only runs on "newest" data point (referring to the dataframe being - * run on by the verification suite) and not multiple data points, so the returned sequence will contain + * run on by the verification suite) and not multiple data points, so this will contain that * one AnomalyDetectionDataPoint. 
- * In the future, if we allow the anomaly check to detect multiple points, the returned sequence - * may be more than one AnomalyDetectionDataPoints. - * @param anomalyDetectionDataPoints Sequence of AnomalyDetectionDataPoints. + * @param anomalyDetectionDataPoint AnomalyDetectionDataPoint of newest data point generated from check. */ -case class AnomalyDetectionExtendedResult(anomalyDetectionDataPoints: Seq[AnomalyDetectionDataPoint]) +case class AnomalyDetectionExtendedResult(anomalyDetectionDataPoint: AnomalyDetectionDataPoint) /** * AnomalyDetectionAssertionResult Class * This class is returned by the assertion function Check.isNewestPointNonAnomalousWithExtendedResults. - * @param hasNoAnomaly Boolean indicating if there was no anomaly detected. + * @param hasAnomaly Boolean indicating if there was an anomaly detected. * @param anomalyDetectionExtendedResult AnomalyDetectionExtendedResults class. */ -case class AnomalyDetectionAssertionResult(hasNoAnomaly: Boolean, +case class AnomalyDetectionAssertionResult(hasAnomaly: Boolean, anomalyDetectionExtendedResult: AnomalyDetectionExtendedResult) diff --git a/src/main/scala/com/amazon/deequ/anomalydetection/OnlineNormalStrategy.scala b/src/main/scala/com/amazon/deequ/anomalydetection/OnlineNormalStrategy.scala index 3955eae16..aa9c91276 100644 --- a/src/main/scala/com/amazon/deequ/anomalydetection/OnlineNormalStrategy.scala +++ b/src/main/scala/com/amazon/deequ/anomalydetection/OnlineNormalStrategy.scala @@ -173,8 +173,8 @@ case class OnlineNormalStrategy( val value = dataSeries(index) (index, AnomalyDetectionDataPoint(value, value, - Threshold(lowerBound = Bound(lowerBound), upperBound = Bound(upperBound)), - calcRes.isAnomaly, 1.0, detail)) + BoundedRange(lowerBound = Bound(lowerBound, inclusive = true), + upperBound = Bound(upperBound, inclusive = true)), calcRes.isAnomaly, 1.0, detail)) } } } diff --git a/src/main/scala/com/amazon/deequ/anomalydetection/SimpleThresholdStrategy.scala b/src/main/scala/com/amazon/deequ/anomalydetection/SimpleThresholdStrategy.scala index 03d30c7c7..5e5fe72e8 100644 --- a/src/main/scala/com/amazon/deequ/anomalydetection/SimpleThresholdStrategy.scala +++ b/src/main/scala/com/amazon/deequ/anomalydetection/SimpleThresholdStrategy.scala @@ -78,8 +78,8 @@ case class SimpleThresholdStrategy( } (index, AnomalyDetectionDataPoint(value, value, - Threshold(lowerBound = Bound(lowerBound), upperBound = Bound(upperBound)), - isAnomaly, 1.0, detail)) + BoundedRange(lowerBound = Bound(lowerBound, inclusive = true), + upperBound = Bound(upperBound, inclusive = true)), isAnomaly, 1.0, detail)) } } } diff --git a/src/main/scala/com/amazon/deequ/anomalydetection/seasonal/HoltWinters.scala b/src/main/scala/com/amazon/deequ/anomalydetection/seasonal/HoltWinters.scala index 3d837235a..082911b1c 100644 --- a/src/main/scala/com/amazon/deequ/anomalydetection/seasonal/HoltWinters.scala +++ b/src/main/scala/com/amazon/deequ/anomalydetection/seasonal/HoltWinters.scala @@ -17,8 +17,15 @@ package com.amazon.deequ.anomalydetection.seasonal import breeze.linalg.DenseVector -import breeze.optimize.{ApproximateGradientFunction, DiffFunction, LBFGSB} -import com.amazon.deequ.anomalydetection.{Anomaly, AnomalyDetectionDataPoint, AnomalyDetectionStrategy, AnomalyDetectionStrategyWithExtendedResults, Threshold, Bound} +import breeze.optimize.ApproximateGradientFunction +import breeze.optimize.DiffFunction +import breeze.optimize.LBFGSB +import com.amazon.deequ.anomalydetection.Anomaly +import 
com.amazon.deequ.anomalydetection.AnomalyDetectionDataPoint +import com.amazon.deequ.anomalydetection.AnomalyDetectionStrategy +import com.amazon.deequ.anomalydetection.AnomalyDetectionStrategyWithExtendedResults +import com.amazon.deequ.anomalydetection.BoundedRange +import com.amazon.deequ.anomalydetection.Bound import collection.mutable.ListBuffer @@ -202,7 +209,8 @@ class HoltWinters(seriesPeriodicity: Int) detectionIndex + startIndex -> AnomalyDetectionDataPoint( dataMetricValue = inputValue, anomalyMetricValue = anomalyMetricValue, - anomalyThreshold = Threshold(upperBound = Bound(upperBound)), + anomalyCheckRange = BoundedRange(lowerBound = Bound(Double.MinValue, inclusive = true), + upperBound = Bound(upperBound, inclusive = true)), isAnomaly = isAnomaly, confidence = 1.0, detail = detail diff --git a/src/main/scala/com/amazon/deequ/checks/Check.scala b/src/main/scala/com/amazon/deequ/checks/Check.scala index 8d4ffa1fb..446c2022d 100644 --- a/src/main/scala/com/amazon/deequ/checks/Check.scala +++ b/src/main/scala/com/amazon/deequ/checks/Check.scala @@ -25,7 +25,15 @@ import com.amazon.deequ.analyzers.Histogram import com.amazon.deequ.analyzers.KLLParameters import com.amazon.deequ.analyzers.Patterns import com.amazon.deequ.analyzers.State -import com.amazon.deequ.anomalydetection.{AnomalyDetectionAssertionResult, AnomalyDetectionExtendedResult, ExtendedDetectionResult, AnomalyDetectionStrategy, AnomalyDetectionStrategyWithExtendedResults, AnomalyDetector, AnomalyDetectorWithExtendedResults, DataPoint, HistoryUtils} +import com.amazon.deequ.anomalydetection.AnomalyDetectionAssertionResult +import com.amazon.deequ.anomalydetection.AnomalyDetectionExtendedResult +import com.amazon.deequ.anomalydetection.ExtendedDetectionResult +import com.amazon.deequ.anomalydetection.AnomalyDetectionStrategy +import com.amazon.deequ.anomalydetection.AnomalyDetectionStrategyWithExtendedResults +import com.amazon.deequ.anomalydetection.AnomalyDetector +import com.amazon.deequ.anomalydetection.AnomalyDetectorWithExtendedResults +import com.amazon.deequ.anomalydetection.DataPoint +import com.amazon.deequ.anomalydetection.HistoryUtils import com.amazon.deequ.checks.ColumnCondition.isAnyNotNull import com.amazon.deequ.checks.ColumnCondition.isEachNotNull import com.amazon.deequ.constraints.Constraint._ @@ -1487,21 +1495,21 @@ object Check { */ private[deequ] def getNewestPointAnomalyResults(extendedDetectionResult: ExtendedDetectionResult): AnomalyDetectionAssertionResult = { - val (hasNoAnomaly, anomalyDetectionExtendedResults): (Boolean, AnomalyDetectionExtendedResult) = { + val (hasAnomaly, anomalyDetectionExtendedResults): (Boolean, AnomalyDetectionExtendedResult) = { - // Based on upstream code, this anomaly detection data point sequence should never be empty + // Based on upstream code, this anomaly detection data point sequence should never be empty. require(extendedDetectionResult.anomalyDetectionDataPointSequence != Nil, "anomalyDetectionDataPoints from AnomalyDetectionExtendedResult cannot be empty") - // get the last anomaly detection data point of sequence (there should only be one element for now) - // and check the isAnomaly boolean, also return the last anomaly detection data point - // wrapped in the anomaly detection extended result class + // Get the last anomaly detection data point of sequence (there should only be one element for now). + // Check the isAnomaly boolean, also return the last anomaly detection data point + // wrapped in the anomaly detection extended result class. 
extendedDetectionResult.anomalyDetectionDataPointSequence match {
        case _ :+ lastAnomalyDataPointPair =>
-          (!lastAnomalyDataPointPair._2.isAnomaly, AnomalyDetectionExtendedResult(Seq(lastAnomalyDataPointPair._2)))
+          (lastAnomalyDataPointPair._2.isAnomaly, AnomalyDetectionExtendedResult(lastAnomalyDataPointPair._2))
       }
     }
     AnomalyDetectionAssertionResult(
-      hasNoAnomaly = hasNoAnomaly, anomalyDetectionExtendedResult = anomalyDetectionExtendedResults)
+      hasAnomaly = hasAnomaly, anomalyDetectionExtendedResult = anomalyDetectionExtendedResults)
   }
 }
diff --git a/src/main/scala/com/amazon/deequ/constraints/AnomalyExtendedResultsConstraint.scala b/src/main/scala/com/amazon/deequ/constraints/AnomalyExtendedResultsConstraint.scala
index c55736ddd..3305aaa4f 100644
--- a/src/main/scala/com/amazon/deequ/constraints/AnomalyExtendedResultsConstraint.scala
+++ b/src/main/scala/com/amazon/deequ/constraints/AnomalyExtendedResultsConstraint.scala
@@ -16,12 +16,14 @@

 package com.amazon.deequ.constraints

-import com.amazon.deequ.analyzers.{Analyzer, State}
+import com.amazon.deequ.analyzers.Analyzer
+import com.amazon.deequ.analyzers.State
 import com.amazon.deequ.anomalydetection.AnomalyDetectionAssertionResult
 import com.amazon.deequ.metrics.Metric
 import org.apache.spark.sql.DataFrame

-import scala.util.{Failure, Success}
+import scala.util.Failure
+import scala.util.Success

 /**
  * Case class for anomaly with extended results constraints that provides unified way to access
@@ -76,7 +78,7 @@ private[deequ] case class AnomalyExtendedResultsConstraint[S <: State[S], M, V](
     val assertOn = runPickerOnMetric(metricValue)
     val anomalyAssertionResult = runAssertion(assertOn)

-    if (anomalyAssertionResult.hasNoAnomaly) {
+    if (!anomalyAssertionResult.hasAnomaly) {
       ConstraintResult(this, ConstraintStatus.Success, metric = Some(metric),
         anomalyDetectionExtendedResultOption = Some(anomalyAssertionResult.anomalyDetectionExtendedResult))
     } else {
diff --git a/src/main/scala/com/amazon/deequ/examples/AnomalyDetectionWithExtendedResultsExample.scala b/src/main/scala/com/amazon/deequ/examples/AnomalyDetectionWithExtendedResultsExample.scala
index dd73b006b..6666b9171 100644
--- a/src/main/scala/com/amazon/deequ/examples/AnomalyDetectionWithExtendedResultsExample.scala
+++ b/src/main/scala/com/amazon/deequ/examples/AnomalyDetectionWithExtendedResultsExample.scala
@@ -20,7 +20,8 @@ import com.amazon.deequ.VerificationSuite
 import com.amazon.deequ.analyzers.Size
 import com.amazon.deequ.anomalydetection.RelativeRateOfChangeStrategy
 import com.amazon.deequ.checks.CheckStatus._
-import com.amazon.deequ.examples.ExampleUtils.{itemsAsDataframe, withSpark}
+import com.amazon.deequ.examples.ExampleUtils.itemsAsDataframe
+import com.amazon.deequ.examples.ExampleUtils.withSpark
 import com.amazon.deequ.repository.ResultKey
 import com.amazon.deequ.repository.memory.InMemoryMetricsRepository

@@ -82,9 +83,9 @@ private[examples] object AnomalyDetectionWithExtendedResultsExample extends App
     if (verificationResult.status != Success) {
       println("Anomaly detected in the Size() metric!")
       val anomalyDetectionDataPoint = verificationResult.checkResults.head._2.constraintResults.
-        head.anomalyDetectionExtendedResultOption.get.anomalyDetectionDataPoints.head
+        head.anomalyDetectionExtendedResultOption.get.anomalyDetectionDataPoint
       println(s"Rate of change of ${anomalyDetectionDataPoint.anomalyMetricValue} was not in " +
-        s"${anomalyDetectionDataPoint.anomalyThreshold}")
+        s"${anomalyDetectionDataPoint.anomalyCheckRange}")

       /* Lets have a look at the actual metrics. */
       metricsRepository
diff --git a/src/main/scala/com/amazon/deequ/examples/anomaly_detection_example.md b/src/main/scala/com/amazon/deequ/examples/anomaly_detection_example.md
index 9acf7d83d..d72f5e951 100644
--- a/src/main/scala/com/amazon/deequ/examples/anomaly_detection_example.md
+++ b/src/main/scala/com/amazon/deequ/examples/anomaly_detection_example.md
@@ -1,5 +1,7 @@
 # Anomaly detection

+*After reading this page, check out [anomaly checks with extended results](https://github.com/awslabs/deequ/blob/master/src/main/scala/com/amazon/deequ/examples/anomaly_detection_with_extended_results_example.md) for how to access more details about the anomaly check, such as the upper and lower bounds used in the check. This only requires calling a different method with the same signature.*
+
 Very often, it is hard to exactly define what constraints we want to evaluate on our data. However, we often have a better understanding of how much change we expect in certain metrics of our data. Therefore, **deequ** supports anomaly detection for data quality metrics. The idea is that we regularly store the metrics of our data in a [MetricsRepository](https://github.com/awslabs/deequ/blob/master/src/main/scala/com/amazon/deequ/examples/metrics_repository_example.md). Once we do that, we can run anomaly checks that compare the current value of the metric to its values in the past and allow us to detect anomalous changes. In this simple example, we assume that we compute the size of a dataset every day and we want to ensure that it does not change drastically: the number of rows on a given day should not be more than double of what we have seen on the day before.

diff --git a/src/main/scala/com/amazon/deequ/examples/anomaly_detection_with_extended_results_example.md b/src/main/scala/com/amazon/deequ/examples/anomaly_detection_with_extended_results_example.md
new file mode 100644
index 000000000..a061f0e75
--- /dev/null
+++ b/src/main/scala/com/amazon/deequ/examples/anomaly_detection_with_extended_results_example.md
@@ -0,0 +1,75 @@
+# Anomaly detection with extended results
+
+Using the `addAnomalyCheckWithExtendedResults` method instead of the original `addAnomalyCheck` method, you can get
+more detailed results about the outcome of the anomaly check on the newly computed metric. You can get details such as:
+
+- dataMetricValue: The metric value that is the data point.
+- anomalyMetricValue: The value of the metric that is being checked, which isn't always equal to the dataMetricValue.
+- anomalyCheckRange: The range of bounds used in the anomaly check; the anomalyMetricValue is compared to this range.
+- isAnomaly: True if the anomalyMetricValue falls outside the anomalyCheckRange.
+- confidence: The confidence of the anomaly detection.
+- detail: An optional detail message.
+
+These are contained within the AnomalyDetectionDataPoint class:
+```scala
+class AnomalyDetectionDataPoint(
+    val dataMetricValue: Double,
+    val anomalyMetricValue: Double,
+    val anomalyCheckRange: BoundedRange,
+    val isAnomaly: Boolean,
+    val confidence: Double,
+    val detail: Option[String])
+
+case class BoundedRange(lowerBound: Bound, upperBound: Bound)
+
+case class Bound(value: Double, inclusive: Boolean)
+```
+
+To access the result, note that the AnomalyDetectionDataPoint is wrapped in an AnomalyDetectionExtendedResult class,
+which is exposed as an optional field of the ConstraintResult class. The ConstraintResult class contains the results
+of a constraint check.
+
+```scala
+case class ConstraintResult(
+    constraint: Constraint,
+    status: ConstraintStatus.Value,
+    message: Option[String] = None,
+    metric: Option[Metric[_]] = None,
+    anomalyDetectionExtendedResultOption: Option[AnomalyDetectionExtendedResult] = None)
+
+case class AnomalyDetectionExtendedResult(anomalyDetectionDataPoint: AnomalyDetectionDataPoint)
+```
+
+To get extended results, run your verification suite with the `addAnomalyCheckWithExtendedResults` method, which has
+the same signature as the original `addAnomalyCheck` method.
+
+```scala
+val result = VerificationSuite()
+  .onData(yesterdaysDataset)
+  .useRepository(metricsRepository)
+  .saveOrAppendResult(yesterdaysKey)
+  .addAnomalyCheckWithExtendedResults(
+    RelativeRateOfChangeStrategy(maxRateIncrease = Some(2.0)),
+    Size())
+  .run()
+
+val anomalyDetectionExtendedResult: AnomalyDetectionExtendedResult = result.checkResults.head._2.constraintResults.head
+  .anomalyDetectionExtendedResultOption
+  .getOrElse(throw new IllegalStateException("No anomaly detection extended result was found"))
+
+val anomalyDetectionDataPoint: AnomalyDetectionDataPoint = anomalyDetectionExtendedResult.anomalyDetectionDataPoint
+```
+
+You can then access the values of the anomaly detection extended results, such as the anomalyMetricValue and the
+anomalyCheckRange.
+```scala
+println(s"Anomaly check range: ${anomalyDetectionDataPoint.anomalyCheckRange}")
+println(s"Anomaly metric value: ${anomalyDetectionDataPoint.anomalyMetricValue}")
+```
+
+```
+Anomaly check range: BoundedRange(Bound(-2.0,true),Bound(2.0,true))
+Anomaly metric value: 4.5
+```
+
+An [executable version of this example with extended results](https://github.com/awslabs/deequ/blob/master/src/main/scala/com/amazon/deequ/examples/AnomalyDetectionWithExtendedResultsExample.scala) is available as part of our code base.
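+
+If you want to act on the outcome programmatically, a minimal sketch (reusing the `anomalyDetectionDataPoint` value
+from the snippet above and only the fields shown on this page) could branch on `isAnomaly` and report the violated
+range:
+
+```scala
+// Sketch only: report the check range that the newest data point violated.
+if (anomalyDetectionDataPoint.isAnomaly) {
+  val range = anomalyDetectionDataPoint.anomalyCheckRange
+  println(s"Anomalous metric value ${anomalyDetectionDataPoint.anomalyMetricValue} " +
+    s"is outside [${range.lowerBound.value}, ${range.upperBound.value}]")
+}
+```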
diff --git a/src/test/scala/com/amazon/deequ/VerificationSuiteTest.scala b/src/test/scala/com/amazon/deequ/VerificationSuiteTest.scala index f410ea821..54d9040a4 100644 --- a/src/test/scala/com/amazon/deequ/VerificationSuiteTest.scala +++ b/src/test/scala/com/amazon/deequ/VerificationSuiteTest.scala @@ -18,7 +18,10 @@ package com.amazon.deequ import com.amazon.deequ.analyzers._ import com.amazon.deequ.analyzers.runners.AnalyzerContext -import com.amazon.deequ.anomalydetection.{AbsoluteChangeStrategy, AnomalyDetectionDataPoint, AnomalyDetectionExtendedResult, Bound, Threshold} +import com.amazon.deequ.anomalydetection.AbsoluteChangeStrategy +import com.amazon.deequ.anomalydetection.AnomalyDetectionDataPoint +import com.amazon.deequ.anomalydetection.Bound +import com.amazon.deequ.anomalydetection.BoundedRange import com.amazon.deequ.checks.Check import com.amazon.deequ.checks.CheckLevel import com.amazon.deequ.checks.CheckStatus @@ -1132,23 +1135,26 @@ class VerificationSuiteTest extends WordSpec with Matchers with SparkContextSpec val checkResultsOne = verificationResultOne.checkResults.head._2.status val actualResultsOneAnomalyDetectionDataPoint = verificationResultOne.checkResults.head._2.constraintResults.head - .anomalyDetectionExtendedResultOption.get.anomalyDetectionDataPoints.head + .anomalyDetectionExtendedResultOption.get.anomalyDetectionDataPoint val expectedResultsOneAnomalyDetectionDataPoint = - AnomalyDetectionDataPoint(11.0, 7.0, Threshold(Bound(-2.0), Bound(2.0)), isAnomaly = true, 1.0) + AnomalyDetectionDataPoint(11.0, 7.0, BoundedRange(Bound(-2.0, inclusive = true), + Bound(2.0, inclusive = true)), isAnomaly = true, 1.0) val checkResultsTwo = verificationResultTwo.checkResults.head._2.status val actualResultsTwoAnomalyDetectionDataPoint = verificationResultTwo.checkResults.head._2.constraintResults.head - .anomalyDetectionExtendedResultOption.get.anomalyDetectionDataPoints.head + .anomalyDetectionExtendedResultOption.get.anomalyDetectionDataPoint val expectedResultsTwoAnomalyDetectionDataPoint = - AnomalyDetectionDataPoint(11.0, 0.0, Threshold(Bound(-7.0), Bound(7.0)), isAnomaly = false, 1.0) + AnomalyDetectionDataPoint(11.0, 0.0, BoundedRange(Bound(-7.0, inclusive = true), + Bound(7.0, inclusive = true)), isAnomaly = false, 1.0) val checkResultsThree = verificationResultThree.checkResults.head._2.status val actualResultsThreeAnomalyDetectionDataPoint = verificationResultThree.checkResults.head._2.constraintResults.head - .anomalyDetectionExtendedResultOption.get.anomalyDetectionDataPoints.head + .anomalyDetectionExtendedResultOption.get.anomalyDetectionDataPoint val expectedResultsThreeAnomalyDetectionDataPoint = - AnomalyDetectionDataPoint(11.0, 0.0, Threshold(Bound(-7.0), Bound(7.0)), isAnomaly = false, 1.0) + AnomalyDetectionDataPoint(11.0, 0.0, BoundedRange(Bound(-7.0, inclusive = true), + Bound(7.0, inclusive = true)), isAnomaly = false, 1.0) assert(checkResultsOne == CheckStatus.Warning) assert(checkResultsTwo == CheckStatus.Success) @@ -1198,16 +1204,74 @@ class VerificationSuiteTest extends WordSpec with Matchers with SparkContextSpec val checkResultsOne = verificationResultOne.checkResults.values.toSeq(1).status val actualResultsOneAnomalyDetectionDataPoint = verificationResultOne.checkResults.values.toSeq(1).constraintResults.head - .anomalyDetectionExtendedResultOption.get.anomalyDetectionDataPoints.head + .anomalyDetectionExtendedResultOption.get.anomalyDetectionDataPoint val expectedResultsOneAnomalyDetectionDataPoint = - AnomalyDetectionDataPoint(11.0, 7.0, 
Threshold(Bound(-2.0), Bound(2.0)), isAnomaly = true, 1.0) + AnomalyDetectionDataPoint(11.0, 7.0, BoundedRange(Bound(-2.0, inclusive = true), + Bound(2.0, inclusive = true)), isAnomaly = true, 1.0) val checkResultsTwo = verificationResultTwo.checkResults.head._2.status val actualResultsTwoAnomalyDetectionDataPoint = verificationResultTwo.checkResults.head._2.constraintResults.head - .anomalyDetectionExtendedResultOption.get.anomalyDetectionDataPoints.head + .anomalyDetectionExtendedResultOption.get.anomalyDetectionDataPoint val expectedResultsTwoAnomalyDetectionDataPoint = - AnomalyDetectionDataPoint(11.0, 0.0, Threshold(Bound(-7.0), Bound(7.0)), isAnomaly = false, 1.0) + AnomalyDetectionDataPoint(11.0, 0.0, BoundedRange(Bound(-7.0, inclusive = true), + Bound(7.0, inclusive = true)), isAnomaly = false, 1.0) + + assert(checkResultsOne == CheckStatus.Warning) + assert(checkResultsTwo == CheckStatus.Success) + + assert(actualResultsOneAnomalyDetectionDataPoint == expectedResultsOneAnomalyDetectionDataPoint) + assert(actualResultsTwoAnomalyDetectionDataPoint == expectedResultsTwoAnomalyDetectionDataPoint) + } + } + + + "addAnomalyCheckWithExtendedResults with two anomaly checks on the same suite should work and " + + "output extended results" in + withSparkSession { sparkSession => + evaluateWithRepositoryWithHistory { repository => + + val df = getDfWithNRows(sparkSession, 11) + val saveResultsWithKey = ResultKey(5, Map.empty) + + val analyzers = Completeness("item") :: Nil + + val verificationResultOne = VerificationSuite() + .onData(df) + .addCheck(Check(CheckLevel.Error, "group-1").hasSize(_ == 11)) + .useRepository(repository) + .addRequiredAnalyzers(analyzers) + .saveOrAppendResult(saveResultsWithKey) + .addAnomalyCheckWithExtendedResults( + AbsoluteChangeStrategy(Some(-2.0), Some(2.0)), + Size(), + Some(AnomalyCheckConfig(CheckLevel.Warning, "Anomaly check to fail")) + ) + .addAnomalyCheckWithExtendedResults( + AbsoluteChangeStrategy(Some(-7.0), Some(7.0)), + Size(), + Some(AnomalyCheckConfig(CheckLevel.Error, "Anomaly check to succeed", + Map.empty, Some(0), Some(11))) + ) + .run() + + + val checkResultsOne = verificationResultOne.checkResults.values.toSeq(1).status + val actualResultsOneAnomalyDetectionDataPoint = + verificationResultOne.checkResults.values.toSeq(1).constraintResults.head + .anomalyDetectionExtendedResultOption.get.anomalyDetectionDataPoint + val expectedResultsOneAnomalyDetectionDataPoint = + AnomalyDetectionDataPoint(11.0, 7.0, BoundedRange(Bound(-2.0, inclusive = true), + Bound(2.0, inclusive = true)), isAnomaly = true, 1.0) + + val checkResultsTwo = verificationResultOne.checkResults.values.toSeq(2).status + val actualResultsTwoAnomalyDetectionDataPoint = + verificationResultOne.checkResults.values.toSeq(2).constraintResults.head + .anomalyDetectionExtendedResultOption.get.anomalyDetectionDataPoint + val expectedResultsTwoAnomalyDetectionDataPoint = + AnomalyDetectionDataPoint(11.0, 7.0, BoundedRange(Bound(-7.0, inclusive = true), + Bound(7.0, inclusive = true)), isAnomaly = false, 1.0) + assert(checkResultsOne == CheckStatus.Warning) assert(checkResultsTwo == CheckStatus.Success) diff --git a/src/test/scala/com/amazon/deequ/anomalydetection/AbsoluteChangeStrategyTest.scala b/src/test/scala/com/amazon/deequ/anomalydetection/AbsoluteChangeStrategyTest.scala index f970f9812..1c435b546 100644 --- a/src/test/scala/com/amazon/deequ/anomalydetection/AbsoluteChangeStrategyTest.scala +++ b/src/test/scala/com/amazon/deequ/anomalydetection/AbsoluteChangeStrategyTest.scala 
@@ -17,7 +17,8 @@ package com.amazon.deequ.anomalydetection import breeze.linalg.DenseVector -import org.scalatest.{Matchers, WordSpec} +import org.scalatest.Matchers +import org.scalatest.WordSpec class AbsoluteChangeStrategyTest extends WordSpec with Matchers { @@ -158,35 +159,35 @@ class AbsoluteChangeStrategyTest extends WordSpec with Matchers { "detect all anomalies if no interval specified" in { val anomalyResult = strategy.detectWithExtendedResults(data).filter({case (_, anom) => anom.isAnomaly}) - val expectedAnomalyThreshold = Threshold(Bound(-2.0), Bound(2.0)) + val expectedAnomalyCheckRange = BoundedRange(Bound(-2.0, inclusive = true), Bound(2.0, inclusive = true)) val expectedResult: Seq[(Int, AnomalyDetectionDataPoint)] = Seq( - (20, AnomalyDetectionDataPoint(20, 19, expectedAnomalyThreshold, isAnomaly = true, 1.0)), - (21, AnomalyDetectionDataPoint(-21, -41, expectedAnomalyThreshold, isAnomaly = true, 1.0)), - (22, AnomalyDetectionDataPoint(22, 43, expectedAnomalyThreshold, isAnomaly = true, 1.0)), - (23, AnomalyDetectionDataPoint(-23, -45, expectedAnomalyThreshold, isAnomaly = true, 1.0)), - (24, AnomalyDetectionDataPoint(24, 47, expectedAnomalyThreshold, isAnomaly = true, 1.0)), - (25, AnomalyDetectionDataPoint(-25, -49, expectedAnomalyThreshold, isAnomaly = true, 1.0)), - (26, AnomalyDetectionDataPoint(26, 51, expectedAnomalyThreshold, isAnomaly = true, 1.0)), - (27, AnomalyDetectionDataPoint(-27, -53, expectedAnomalyThreshold, isAnomaly = true, 1.0)), - (28, AnomalyDetectionDataPoint(28, 55, expectedAnomalyThreshold, isAnomaly = true, 1.0)), - (29, AnomalyDetectionDataPoint(-29, -57, expectedAnomalyThreshold, isAnomaly = true, 1.0)), - (30, AnomalyDetectionDataPoint(30, 59, expectedAnomalyThreshold, isAnomaly = true, 1.0)), - (31, AnomalyDetectionDataPoint(1, -29, expectedAnomalyThreshold, isAnomaly = true, 1.0)) + (20, AnomalyDetectionDataPoint(20, 19, expectedAnomalyCheckRange, isAnomaly = true, 1.0)), + (21, AnomalyDetectionDataPoint(-21, -41, expectedAnomalyCheckRange, isAnomaly = true, 1.0)), + (22, AnomalyDetectionDataPoint(22, 43, expectedAnomalyCheckRange, isAnomaly = true, 1.0)), + (23, AnomalyDetectionDataPoint(-23, -45, expectedAnomalyCheckRange, isAnomaly = true, 1.0)), + (24, AnomalyDetectionDataPoint(24, 47, expectedAnomalyCheckRange, isAnomaly = true, 1.0)), + (25, AnomalyDetectionDataPoint(-25, -49, expectedAnomalyCheckRange, isAnomaly = true, 1.0)), + (26, AnomalyDetectionDataPoint(26, 51, expectedAnomalyCheckRange, isAnomaly = true, 1.0)), + (27, AnomalyDetectionDataPoint(-27, -53, expectedAnomalyCheckRange, isAnomaly = true, 1.0)), + (28, AnomalyDetectionDataPoint(28, 55, expectedAnomalyCheckRange, isAnomaly = true, 1.0)), + (29, AnomalyDetectionDataPoint(-29, -57, expectedAnomalyCheckRange, isAnomaly = true, 1.0)), + (30, AnomalyDetectionDataPoint(30, 59, expectedAnomalyCheckRange, isAnomaly = true, 1.0)), + (31, AnomalyDetectionDataPoint(1, -29, expectedAnomalyCheckRange, isAnomaly = true, 1.0)) ) assert(anomalyResult == expectedResult) } "only detect anomalies in interval" in { val anomalyResult = strategy.detectWithExtendedResults(data, (25, 50)).filter({case (_, anom) => anom.isAnomaly}) - val expectedAnomalyThreshold = Threshold(Bound(-2.0), Bound(2.0)) + val expectedAnomalyCheckRange = BoundedRange(Bound(-2.0, inclusive = true), Bound(2.0, inclusive = true)) val expectedResult: Seq[(Int, AnomalyDetectionDataPoint)] = Seq( - (25, AnomalyDetectionDataPoint(-25, -49, expectedAnomalyThreshold, isAnomaly = true, 1.0)), - (26, 
AnomalyDetectionDataPoint(26, 51, expectedAnomalyThreshold, isAnomaly = true, 1.0)), - (27, AnomalyDetectionDataPoint(-27, -53, expectedAnomalyThreshold, isAnomaly = true, 1.0)), - (28, AnomalyDetectionDataPoint(28, 55, expectedAnomalyThreshold, isAnomaly = true, 1.0)), - (29, AnomalyDetectionDataPoint(-29, -57, expectedAnomalyThreshold, isAnomaly = true, 1.0)), - (30, AnomalyDetectionDataPoint(30, 59, expectedAnomalyThreshold, isAnomaly = true, 1.0)), - (31, AnomalyDetectionDataPoint(1, -29, expectedAnomalyThreshold, isAnomaly = true, 1.0)) + (25, AnomalyDetectionDataPoint(-25, -49, expectedAnomalyCheckRange, isAnomaly = true, 1.0)), + (26, AnomalyDetectionDataPoint(26, 51, expectedAnomalyCheckRange, isAnomaly = true, 1.0)), + (27, AnomalyDetectionDataPoint(-27, -53, expectedAnomalyCheckRange, isAnomaly = true, 1.0)), + (28, AnomalyDetectionDataPoint(28, 55, expectedAnomalyCheckRange, isAnomaly = true, 1.0)), + (29, AnomalyDetectionDataPoint(-29, -57, expectedAnomalyCheckRange, isAnomaly = true, 1.0)), + (30, AnomalyDetectionDataPoint(30, 59, expectedAnomalyCheckRange, isAnomaly = true, 1.0)), + (31, AnomalyDetectionDataPoint(1, -29, expectedAnomalyCheckRange, isAnomaly = true, 1.0)) ) assert(anomalyResult == expectedResult) } @@ -194,15 +195,16 @@ class AbsoluteChangeStrategyTest extends WordSpec with Matchers { "ignore min rate if none is given" in { val strategy = AbsoluteChangeStrategy(None, Some(1.0)) val anomalyResult = strategy.detectWithExtendedResults(data).filter({case (_, anom) => anom.isAnomaly}) - val expectedAnomalyThreshold = Threshold(upperBound = Bound(1.0)) + val expectedAnomalyCheckRange = BoundedRange(lowerBound = Bound(Double.MinValue, inclusive = true), + upperBound = Bound(1.0, inclusive = true)) // Anomalies with positive values only val expectedResult: Seq[(Int, AnomalyDetectionDataPoint)] = Seq( - (20, AnomalyDetectionDataPoint(20, 19, expectedAnomalyThreshold, isAnomaly = true, 1.0)), - (22, AnomalyDetectionDataPoint(22, 43, expectedAnomalyThreshold, isAnomaly = true, 1.0)), - (24, AnomalyDetectionDataPoint(24, 47, expectedAnomalyThreshold, isAnomaly = true, 1.0)), - (26, AnomalyDetectionDataPoint(26, 51, expectedAnomalyThreshold, isAnomaly = true, 1.0)), - (28, AnomalyDetectionDataPoint(28, 55, expectedAnomalyThreshold, isAnomaly = true, 1.0)), - (30, AnomalyDetectionDataPoint(30, 59, expectedAnomalyThreshold, isAnomaly = true, 1.0)) + (20, AnomalyDetectionDataPoint(20, 19, expectedAnomalyCheckRange, isAnomaly = true, 1.0)), + (22, AnomalyDetectionDataPoint(22, 43, expectedAnomalyCheckRange, isAnomaly = true, 1.0)), + (24, AnomalyDetectionDataPoint(24, 47, expectedAnomalyCheckRange, isAnomaly = true, 1.0)), + (26, AnomalyDetectionDataPoint(26, 51, expectedAnomalyCheckRange, isAnomaly = true, 1.0)), + (28, AnomalyDetectionDataPoint(28, 55, expectedAnomalyCheckRange, isAnomaly = true, 1.0)), + (30, AnomalyDetectionDataPoint(30, 59, expectedAnomalyCheckRange, isAnomaly = true, 1.0)) ) assert(anomalyResult == expectedResult) @@ -211,16 +213,17 @@ class AbsoluteChangeStrategyTest extends WordSpec with Matchers { "ignore max rate if none is given" in { val strategy = AbsoluteChangeStrategy(Some(-1.0), None) val anomalyResult = strategy.detectWithExtendedResults(data).filter({case (_, anom) => anom.isAnomaly}) - val expectedAnomalyThreshold = Threshold(lowerBound = Bound(-1.0)) + val expectedAnomalyCheckRange = BoundedRange(lowerBound = Bound(-1.0, inclusive = true), + upperBound = Bound(Double.MaxValue, inclusive = true)) // Anomalies with negative values only val 
expectedResult: Seq[(Int, AnomalyDetectionDataPoint)] = Seq( - (21, AnomalyDetectionDataPoint(-21, -41, expectedAnomalyThreshold, isAnomaly = true, 1.0)), - (23, AnomalyDetectionDataPoint(-23, -45, expectedAnomalyThreshold, isAnomaly = true, 1.0)), - (25, AnomalyDetectionDataPoint(-25, -49, expectedAnomalyThreshold, isAnomaly = true, 1.0)), - (27, AnomalyDetectionDataPoint(-27, -53, expectedAnomalyThreshold, isAnomaly = true, 1.0)), - (29, AnomalyDetectionDataPoint(-29, -57, expectedAnomalyThreshold, isAnomaly = true, 1.0)), - (31, AnomalyDetectionDataPoint(1, -29, expectedAnomalyThreshold, isAnomaly = true, 1.0)) + (21, AnomalyDetectionDataPoint(-21, -41, expectedAnomalyCheckRange, isAnomaly = true, 1.0)), + (23, AnomalyDetectionDataPoint(-23, -45, expectedAnomalyCheckRange, isAnomaly = true, 1.0)), + (25, AnomalyDetectionDataPoint(-25, -49, expectedAnomalyCheckRange, isAnomaly = true, 1.0)), + (27, AnomalyDetectionDataPoint(-27, -53, expectedAnomalyCheckRange, isAnomaly = true, 1.0)), + (29, AnomalyDetectionDataPoint(-29, -57, expectedAnomalyCheckRange, isAnomaly = true, 1.0)), + (31, AnomalyDetectionDataPoint(1, -29, expectedAnomalyCheckRange, isAnomaly = true, 1.0)) ) assert(anomalyResult == expectedResult) } @@ -239,8 +242,10 @@ class AbsoluteChangeStrategyTest extends WordSpec with Matchers { val result = strategy.detectWithExtendedResults(data).filter({ case (_, anom) => anom.isAnomaly }) val expectedResult = Seq( - (4, AnomalyDetectionDataPoint(18.0, 9.0, Threshold(upperBound = Bound(8.0)), isAnomaly = true, 1.0)), - (5, AnomalyDetectionDataPoint(72.0, 42.0, Threshold(upperBound = Bound(8.0)), isAnomaly = true, 1.0)) + (4, AnomalyDetectionDataPoint(18.0, 9.0, BoundedRange(lowerBound = Bound(Double.MinValue, inclusive = true), + upperBound = Bound(8.0, inclusive = true)), isAnomaly = true, 1.0)), + (5, AnomalyDetectionDataPoint(72.0, 42.0, BoundedRange(lowerBound = Bound(Double.MinValue, inclusive = true), + upperBound = Bound(8.0, inclusive = true)), isAnomaly = true, 1.0)) ) assert(result == expectedResult) } @@ -251,7 +256,8 @@ class AbsoluteChangeStrategyTest extends WordSpec with Matchers { val result = strategy.detectWithExtendedResults(data, (5, 6)).filter({case (_, anom) => anom.isAnomaly}) val expectedResult = Seq( - (5, AnomalyDetectionDataPoint(72.0, 42.0, Threshold(upperBound = Bound(8.0)), isAnomaly = true, 1.0)) + (5, AnomalyDetectionDataPoint(72.0, 42.0, BoundedRange(lowerBound = Bound(Double.MinValue, inclusive = true), + upperBound = Bound(8.0, inclusive = true)), isAnomaly = true, 1.0)) ) assert(result == expectedResult) } @@ -261,8 +267,10 @@ class AbsoluteChangeStrategyTest extends WordSpec with Matchers { val result = strategy.detectWithExtendedResults(data).filter({case (_, anom) => anom.isAnomaly}) val expectedResult = Seq( - (2, AnomalyDetectionDataPoint(4.0, 5.0, Threshold(Bound(-2.0), Bound(2.0)), isAnomaly = true, 1.0)), - (3, AnomalyDetectionDataPoint(-7.0, -11.0, Threshold(Bound(-2.0), Bound(2.0)), isAnomaly = true, 1.0)) + (2, AnomalyDetectionDataPoint(4.0, 5.0, BoundedRange(Bound(-2.0, inclusive = true), + Bound(2.0, inclusive = true)), isAnomaly = true, 1.0)), + (3, AnomalyDetectionDataPoint(-7.0, -11.0, BoundedRange(Bound(-2.0, inclusive = true), + Bound(2.0, inclusive = true)), isAnomaly = true, 1.0)) ) assert(result == expectedResult) } @@ -292,8 +300,8 @@ class AbsoluteChangeStrategyTest extends WordSpec with Matchers { result.foreach { case (_, anom) => val value = anom.anomalyMetricValue - val upperBound = anom.anomalyThreshold.upperBound.value 
- val lowerBound = anom.anomalyThreshold.lowerBound.value + val upperBound = anom.anomalyCheckRange.upperBound.value + val lowerBound = anom.anomalyCheckRange.lowerBound.value assert(value < lowerBound || value > upperBound) } diff --git a/src/test/scala/com/amazon/deequ/anomalydetection/AnomalyDetectorTest.scala b/src/test/scala/com/amazon/deequ/anomalydetection/AnomalyDetectorTest.scala index 6068d111b..cdb87b763 100644 --- a/src/test/scala/com/amazon/deequ/anomalydetection/AnomalyDetectorTest.scala +++ b/src/test/scala/com/amazon/deequ/anomalydetection/AnomalyDetectorTest.scala @@ -17,7 +17,9 @@ package com.amazon.deequ.anomalydetection import org.scalamock.scalatest.MockFactory -import org.scalatest.{Matchers, PrivateMethodTester, WordSpec} +import org.scalatest.Matchers +import org.scalatest.PrivateMethodTester +import org.scalatest.WordSpec class AnomalyDetectorTest extends WordSpec with Matchers with MockFactory with PrivateMethodTester { @@ -111,6 +113,10 @@ class AnomalyDetectorTest extends WordSpec with Matchers with MockFactory with P val fakeAnomalyDetector = stub[AnomalyDetectionStrategyWithExtendedResults] + // This is used as a default bounded range value for anomaly detection + val defaultBoundedRange = BoundedRange(lowerBound = Bound(0.0, inclusive = true), + upperBound = Bound(1.0, inclusive = true)) + val aD = AnomalyDetectorWithExtendedResults(fakeAnomalyDetector) val data = Seq((0L, -1.0), (1L, 2.0), (2L, 3.0), (3L, 0.5)).map { case (t, v) => DataPoint[Double](t, Option(v)) @@ -121,20 +127,24 @@ class AnomalyDetectorTest extends WordSpec with Matchers with MockFactory with P DataPoint[Double](2L, None), DataPoint[Double](3L, Option(1.0))) (fakeAnomalyDetector.detectWithExtendedResults _ when(Vector(1.0, 2.0, 1.0), (0, 3))) - .returns(Seq((1, AnomalyDetectionDataPoint(2.0, 2.0, Threshold(), confidence = 1.0)))) + .returns(Seq((1, AnomalyDetectionDataPoint(2.0, 2.0, defaultBoundedRange, confidence = 1.0, + isAnomaly = true)))) val anomalyResult = aD.detectAnomaliesInHistoryWithExtendedResults(data, (0L, 4L)) - assert(anomalyResult == ExtendedDetectionResult(Seq((1L, AnomalyDetectionDataPoint(2.0, 2.0, confidence = 1.0))))) + assert(anomalyResult == ExtendedDetectionResult(Seq((1L, AnomalyDetectionDataPoint(2.0, 2.0, + defaultBoundedRange, confidence = 1.0, isAnomaly = true))))) } "only detect values in range" in { (fakeAnomalyDetector.detectWithExtendedResults _ when(Vector(-1.0, 2.0, 3.0, 0.5), (2, 4))) - .returns(Seq((2, AnomalyDetectionDataPoint(3.0, 3.0, confidence = 1.0)))) + .returns(Seq((2, AnomalyDetectionDataPoint(3.0, 3.0, defaultBoundedRange, confidence = 1.0, + isAnomaly = true)))) val anomalyResult = aD.detectAnomaliesInHistoryWithExtendedResults(data, (2L, 4L)) - assert(anomalyResult == ExtendedDetectionResult(Seq((2L, AnomalyDetectionDataPoint(3.0, 3.0, confidence = 1.0))))) + assert(anomalyResult == ExtendedDetectionResult(Seq((2L, AnomalyDetectionDataPoint(3.0, 3.0, + defaultBoundedRange, confidence = 1.0, isAnomaly = true))))) } "throw an error when intervals are not ordered" in { @@ -153,16 +163,17 @@ class AnomalyDetectorTest extends WordSpec with Matchers with MockFactory with P (fakeAnomalyDetector.detectWithExtendedResults _ when(data.map(_.metricValue.get).toVector, (0, 2))) .returns ( Seq( - (0, AnomalyDetectionDataPoint(5.0, 5.0, confidence = 1.0)), - (1, AnomalyDetectionDataPoint(5.0, 5.0, confidence = 1.0)) + (0, AnomalyDetectionDataPoint(5.0, 5.0, defaultBoundedRange, confidence = 1.0, isAnomaly = true)), + (1, AnomalyDetectionDataPoint(5.0, 
5.0, defaultBoundedRange, confidence = 1.0, isAnomaly = true)) ) ) val anomalyResult = aD.detectAnomaliesInHistoryWithExtendedResults(data, (200L, 401L)) assert(anomalyResult == ExtendedDetectionResult(Seq( - (200L, AnomalyDetectionDataPoint(5.0, 5.0, confidence = 1.0)), - (400L, AnomalyDetectionDataPoint(5.0, 5.0, confidence = 1.0))))) + (200L, AnomalyDetectionDataPoint(5.0, 5.0, defaultBoundedRange, confidence = 1.0, isAnomaly = true)), + (400L, AnomalyDetectionDataPoint(5.0, 5.0, defaultBoundedRange, confidence = 1.0, isAnomaly = true))) + )) } "treat unordered values with time gaps correctly" in { @@ -174,18 +185,18 @@ class AnomalyDetectorTest extends WordSpec with Matchers with MockFactory with P (fakeAnomalyDetector.detectWithExtendedResults _ when(Vector(0.5, -1.0, 3.0, 2.0), (0, 4))) .returns( Seq( - (1, AnomalyDetectionDataPoint(-1.0, -1.0, confidence = 1.0)), - (2, AnomalyDetectionDataPoint(3.0, 3.0, confidence = 1.0)), - (3, AnomalyDetectionDataPoint(2.0, 2.0, confidence = 1.0)) + (1, AnomalyDetectionDataPoint(-1.0, -1.0, defaultBoundedRange, confidence = 1.0, isAnomaly = true)), + (2, AnomalyDetectionDataPoint(3.0, 3.0, defaultBoundedRange, confidence = 1.0, isAnomaly = true)), + (3, AnomalyDetectionDataPoint(2.0, 2.0, defaultBoundedRange, confidence = 1.0, isAnomaly = true)) ) ) val anomalyResult = aD.detectAnomaliesInHistoryWithExtendedResults(data) assert(anomalyResult == ExtendedDetectionResult( - Seq((10L, AnomalyDetectionDataPoint(-1.0, -1.0, confidence = 1.0)), - (11L, AnomalyDetectionDataPoint(3.0, 3.0, confidence = 1.0)), - (25L, AnomalyDetectionDataPoint(2.0, 2.0, confidence = 1.0))))) + Seq((10L, AnomalyDetectionDataPoint(-1.0, -1.0, defaultBoundedRange, confidence = 1.0, isAnomaly = true)), + (11L, AnomalyDetectionDataPoint(3.0, 3.0, defaultBoundedRange, confidence = 1.0, isAnomaly = true)), + (25L, AnomalyDetectionDataPoint(2.0, 2.0, defaultBoundedRange, confidence = 1.0, isAnomaly = true))))) } "treat unordered values without time gaps correctly" in { @@ -194,16 +205,17 @@ class AnomalyDetectorTest extends WordSpec with Matchers with MockFactory with P } (fakeAnomalyDetector.detectWithExtendedResults _ when(Vector(0.5, -1.0, 3.0, 2.0), (0, 4))) - .returns(Seq((1, AnomalyDetectionDataPoint(-1.0, -1.0, confidence = 1.0)), - (2, AnomalyDetectionDataPoint(3.0, 3.0, confidence = 1.0)), - (3, AnomalyDetectionDataPoint(2.0, 2.0, confidence = 1.0)))) + .returns(Seq((1, AnomalyDetectionDataPoint(-1.0, -1.0, defaultBoundedRange, confidence = 1.0, + isAnomaly = true)), + (2, AnomalyDetectionDataPoint(3.0, 3.0, defaultBoundedRange, confidence = 1.0, isAnomaly = true)), + (3, AnomalyDetectionDataPoint(2.0, 2.0, defaultBoundedRange, confidence = 1.0, isAnomaly = true)))) val anomalyResult = aD.detectAnomaliesInHistoryWithExtendedResults(data) assert(anomalyResult == ExtendedDetectionResult(Seq( - (1L, AnomalyDetectionDataPoint(-1.0, -1.0, confidence = 1.0)), - (2L, AnomalyDetectionDataPoint(3.0, 3.0, confidence = 1.0)), - (3L, AnomalyDetectionDataPoint(2.0, 2.0, confidence = 1.0))))) + (1L, AnomalyDetectionDataPoint(-1.0, -1.0, defaultBoundedRange, confidence = 1.0, isAnomaly = true)), + (2L, AnomalyDetectionDataPoint(3.0, 3.0, defaultBoundedRange, confidence = 1.0, isAnomaly = true)), + (3L, AnomalyDetectionDataPoint(2.0, 2.0, defaultBoundedRange, confidence = 1.0, isAnomaly = true))))) } } diff --git a/src/test/scala/com/amazon/deequ/anomalydetection/BatchNormalStrategyTest.scala b/src/test/scala/com/amazon/deequ/anomalydetection/BatchNormalStrategyTest.scala index 
0575ad3f7..1634053eb 100644 --- a/src/test/scala/com/amazon/deequ/anomalydetection/BatchNormalStrategyTest.scala +++ b/src/test/scala/com/amazon/deequ/anomalydetection/BatchNormalStrategyTest.scala @@ -16,7 +16,8 @@ package com.amazon.deequ.anomalydetection -import org.scalatest.{Matchers, WordSpec} +import org.scalatest.Matchers +import org.scalatest.WordSpec import scala.util.Random @@ -117,14 +118,15 @@ class BatchNormalStrategyTest extends WordSpec with Matchers { val anomalyResult = strategy.detectWithExtendedResults(data, (25, 50)).filter({ case (_, anom) => anom.isAnomaly }) - val expectedAnomalyThreshold = Threshold(Bound(-9.280850004177061), Bound(10.639954755150061)) + val expectedAnomalyCheckRange = BoundedRange(Bound(-9.280850004177061, inclusive = true), + Bound(10.639954755150061, inclusive = true)) val expectedResult: Seq[(Int, AnomalyDetectionDataPoint)] = Seq( - (25, AnomalyDetectionDataPoint(data(25), data(25), expectedAnomalyThreshold, isAnomaly = true, 1.0)), - (26, AnomalyDetectionDataPoint(data(26), data(26), expectedAnomalyThreshold, isAnomaly = true, 1.0)), - (27, AnomalyDetectionDataPoint(data(27), data(27), expectedAnomalyThreshold, isAnomaly = true, 1.0)), - (28, AnomalyDetectionDataPoint(data(28), data(28), expectedAnomalyThreshold, isAnomaly = true, 1.0)), - (29, AnomalyDetectionDataPoint(data(29), data(29), expectedAnomalyThreshold, isAnomaly = true, 1.0)), - (30, AnomalyDetectionDataPoint(data(30), data(30), expectedAnomalyThreshold, isAnomaly = true, 1.0)) + (25, AnomalyDetectionDataPoint(data(25), data(25), expectedAnomalyCheckRange, isAnomaly = true, 1.0)), + (26, AnomalyDetectionDataPoint(data(26), data(26), expectedAnomalyCheckRange, isAnomaly = true, 1.0)), + (27, AnomalyDetectionDataPoint(data(27), data(27), expectedAnomalyCheckRange, isAnomaly = true, 1.0)), + (28, AnomalyDetectionDataPoint(data(28), data(28), expectedAnomalyCheckRange, isAnomaly = true, 1.0)), + (29, AnomalyDetectionDataPoint(data(29), data(29), expectedAnomalyCheckRange, isAnomaly = true, 1.0)), + (30, AnomalyDetectionDataPoint(data(30), data(30), expectedAnomalyCheckRange, isAnomaly = true, 1.0)) ) assert(anomalyResult == expectedResult) } @@ -134,15 +136,16 @@ class BatchNormalStrategyTest extends WordSpec with Matchers { val anomalyResult = strategy.detectWithExtendedResults(data, (20, 31)).filter({ case (_, anom) => anom.isAnomaly }) - val expectedAnomalyThreshold = Threshold(Bound(Double.NegativeInfinity), Bound(0.7781496015857838)) + val expectedAnomalyCheckRange = BoundedRange(Bound(Double.NegativeInfinity, inclusive = true), + Bound(0.7781496015857838, inclusive = true)) // Anomalies with positive values only val expectedResult: Seq[(Int, AnomalyDetectionDataPoint)] = Seq( - (20, AnomalyDetectionDataPoint(data(20), data(20), expectedAnomalyThreshold, isAnomaly = true, 1.0)), - (22, AnomalyDetectionDataPoint(data(22), data(22), expectedAnomalyThreshold, isAnomaly = true, 1.0)), - (24, AnomalyDetectionDataPoint(data(24), data(24), expectedAnomalyThreshold, isAnomaly = true, 1.0)), - (26, AnomalyDetectionDataPoint(data(26), data(26), expectedAnomalyThreshold, isAnomaly = true, 1.0)), - (28, AnomalyDetectionDataPoint(data(28), data(28), expectedAnomalyThreshold, isAnomaly = true, 1.0)), - (30, AnomalyDetectionDataPoint(data(30), data(30), expectedAnomalyThreshold, isAnomaly = true, 1.0)) + (20, AnomalyDetectionDataPoint(data(20), data(20), expectedAnomalyCheckRange, isAnomaly = true, 1.0)), + (22, AnomalyDetectionDataPoint(data(22), data(22), expectedAnomalyCheckRange, isAnomaly 
= true, 1.0)), + (24, AnomalyDetectionDataPoint(data(24), data(24), expectedAnomalyCheckRange, isAnomaly = true, 1.0)), + (26, AnomalyDetectionDataPoint(data(26), data(26), expectedAnomalyCheckRange, isAnomaly = true, 1.0)), + (28, AnomalyDetectionDataPoint(data(28), data(28), expectedAnomalyCheckRange, isAnomaly = true, 1.0)), + (30, AnomalyDetectionDataPoint(data(30), data(30), expectedAnomalyCheckRange, isAnomaly = true, 1.0)) ) assert(anomalyResult == expectedResult) } @@ -151,15 +154,16 @@ class BatchNormalStrategyTest extends WordSpec with Matchers { val strategy = BatchNormalStrategy(Some(1.0), None) val anomalyResult = strategy.detectWithExtendedResults(data, (10, 30)).filter({ case (_, anom) => anom.isAnomaly }) - val expectedAnomalyThreshold = Threshold(Bound(-5.063730045618394), Bound(Double.PositiveInfinity)) + val expectedAnomalyCheckRange = BoundedRange(Bound(-5.063730045618394, inclusive = true), + Bound(Double.PositiveInfinity, inclusive = true)) // Anomalies with negative values only val expectedResult: Seq[(Int, AnomalyDetectionDataPoint)] = Seq( - (21, AnomalyDetectionDataPoint(data(21), data(21), expectedAnomalyThreshold, isAnomaly = true, 1.0)), - (23, AnomalyDetectionDataPoint(data(23), data(23), expectedAnomalyThreshold, isAnomaly = true, 1.0)), - (25, AnomalyDetectionDataPoint(data(25), data(25), expectedAnomalyThreshold, isAnomaly = true, 1.0)), - (27, AnomalyDetectionDataPoint(data(27), data(27), expectedAnomalyThreshold, isAnomaly = true, 1.0)), - (29, AnomalyDetectionDataPoint(data(29), data(29), expectedAnomalyThreshold, isAnomaly = true, 1.0)) + (21, AnomalyDetectionDataPoint(data(21), data(21), expectedAnomalyCheckRange, isAnomaly = true, 1.0)), + (23, AnomalyDetectionDataPoint(data(23), data(23), expectedAnomalyCheckRange, isAnomaly = true, 1.0)), + (25, AnomalyDetectionDataPoint(data(25), data(25), expectedAnomalyCheckRange, isAnomaly = true, 1.0)), + (27, AnomalyDetectionDataPoint(data(27), data(27), expectedAnomalyCheckRange, isAnomaly = true, 1.0)), + (29, AnomalyDetectionDataPoint(data(29), data(29), expectedAnomalyCheckRange, isAnomaly = true, 1.0)) ) assert(anomalyResult == expectedResult) } @@ -171,8 +175,10 @@ class BatchNormalStrategyTest extends WordSpec with Matchers { strategy.detectWithExtendedResults(data, (3, 5)).filter({ case (_, anom) => anom.isAnomaly }) val expectedResult: Seq[(Int, AnomalyDetectionDataPoint)] = Seq( - (3, AnomalyDetectionDataPoint(1000, 1000, Threshold(Bound(1.0), Bound(1.0)), isAnomaly = true, 1.0)), - (4, AnomalyDetectionDataPoint(500, 500, Threshold(Bound(1.0), Bound(1.0)), isAnomaly = true, 1.0)) + (3, AnomalyDetectionDataPoint(1000, 1000, BoundedRange(Bound(1.0, inclusive = true), + Bound(1.0, inclusive = true)), isAnomaly = true, 1.0)), + (4, AnomalyDetectionDataPoint(500, 500, BoundedRange(Bound(1.0, inclusive = true), + Bound(1.0, inclusive = true)), isAnomaly = true, 1.0)) ) assert(anomalyResult == expectedResult) } @@ -209,8 +215,8 @@ class BatchNormalStrategyTest extends WordSpec with Matchers { result.foreach { case (_, anom) => val value = anom.anomalyMetricValue - val upperBound = anom.anomalyThreshold.upperBound.value - val lowerBound = anom.anomalyThreshold.lowerBound.value + val upperBound = anom.anomalyCheckRange.upperBound.value + val lowerBound = anom.anomalyCheckRange.lowerBound.value assert(value < lowerBound || value > upperBound) } diff --git a/src/test/scala/com/amazon/deequ/anomalydetection/OnlineNormalStrategyTest.scala 
b/src/test/scala/com/amazon/deequ/anomalydetection/OnlineNormalStrategyTest.scala index d9fdd4ebc..28f8ebdf7 100644 --- a/src/test/scala/com/amazon/deequ/anomalydetection/OnlineNormalStrategyTest.scala +++ b/src/test/scala/com/amazon/deequ/anomalydetection/OnlineNormalStrategyTest.scala @@ -16,7 +16,8 @@ package com.amazon.deequ.anomalydetection -import org.scalatest.{Matchers, WordSpec} +import org.scalatest.Matchers +import org.scalatest.WordSpec import breeze.stats.meanAndVariance import scala.util.Random @@ -169,27 +170,38 @@ class OnlineNormalStrategyTest extends WordSpec with Matchers { val expectedResult: Seq[(Int, AnomalyDetectionDataPoint)] = Seq( (20, AnomalyDetectionDataPoint(data(20), data(20), - Threshold(Bound(-14.868489924421404), Bound(14.255383455388895)), isAnomaly = true, 1.0)), + BoundedRange(Bound(-14.868489924421404, inclusive = true), Bound(14.255383455388895, inclusive = true)), + isAnomaly = true, 1.0)), (21, AnomalyDetectionDataPoint(data(21), data(21), - Threshold(Bound(-13.6338479733374), Bound(13.02074150430489)), isAnomaly = true, 1.0)), + BoundedRange(Bound(-13.6338479733374, inclusive = true), Bound(13.02074150430489, inclusive = true)), + isAnomaly = true, 1.0)), (22, AnomalyDetectionDataPoint(data(22), data(22), - Threshold(Bound(-16.71733585267535), Bound(16.104229383642842)), isAnomaly = true, 1.0)), + BoundedRange(Bound(-16.71733585267535, inclusive = true), Bound(16.104229383642842, inclusive = true)), + isAnomaly = true, 1.0)), (23, AnomalyDetectionDataPoint(data(23), data(23), - Threshold(Bound(-17.346915620547467), Bound(16.733809151514958)), isAnomaly = true, 1.0)), + BoundedRange(Bound(-17.346915620547467, inclusive = true), Bound(16.733809151514958, inclusive = true)), + isAnomaly = true, 1.0)), (24, AnomalyDetectionDataPoint(data(24), data(24), - Threshold(Bound(-17.496117397890874), Bound(16.883010928858365)), isAnomaly = true, 1.0)), + BoundedRange(Bound(-17.496117397890874, inclusive = true), Bound(16.883010928858365, inclusive = true)), + isAnomaly = true, 1.0)), (25, AnomalyDetectionDataPoint(data(25), data(25), - Threshold(Bound(-17.90391150851199), Bound(17.29080503947948)), isAnomaly = true, 1.0)), + BoundedRange(Bound(-17.90391150851199, inclusive = true), Bound(17.29080503947948, inclusive = true)), + isAnomaly = true, 1.0)), (26, AnomalyDetectionDataPoint(data(26), data(26), - Threshold(Bound(-17.028892797350824), Bound(16.415786328318315)), isAnomaly = true, 1.0)), + BoundedRange(Bound(-17.028892797350824, inclusive = true), Bound(16.415786328318315, inclusive = true)), + isAnomaly = true, 1.0)), (27, AnomalyDetectionDataPoint(data(27), data(27), - Threshold(Bound(-17.720100310354653), Bound(17.106993841322144)), isAnomaly = true, 1.0)), + BoundedRange(Bound(-17.720100310354653, inclusive = true), Bound(17.106993841322144, inclusive = true)), + isAnomaly = true, 1.0)), (28, AnomalyDetectionDataPoint(data(28), data(28), - Threshold(Bound(-18.23663168508628), Bound(17.62352521605377)), isAnomaly = true, 1.0)), + BoundedRange(Bound(-18.23663168508628, inclusive = true), Bound(17.62352521605377, inclusive = true)), + isAnomaly = true, 1.0)), (29, AnomalyDetectionDataPoint(data(29), data(29), - Threshold(Bound(-19.32641622778204), Bound(18.71330975874953)), isAnomaly = true, 1.0)), + BoundedRange(Bound(-19.32641622778204, inclusive = true), Bound(18.71330975874953, inclusive = true)), + isAnomaly = true, 1.0)), (30, AnomalyDetectionDataPoint(data(30), data(30), - Threshold(Bound(-18.96540323993527), Bound(18.35229677090276)), isAnomaly = 
true, 1.0)) + BoundedRange(Bound(-18.96540323993527, inclusive = true), Bound(18.35229677090276, inclusive = true)), + isAnomaly = true, 1.0)) ) assert(anomalyResult == expectedResult) } @@ -199,17 +211,23 @@ class OnlineNormalStrategyTest extends WordSpec with Matchers { val expectedResult: Seq[(Int, AnomalyDetectionDataPoint)] = Seq( (25, AnomalyDetectionDataPoint(data(25), data(25), - Threshold(Bound(-15.630116599125694), Bound(16.989221350098695)), isAnomaly = true, 1.0)), + BoundedRange(Bound(-15.630116599125694, inclusive = true), Bound(16.989221350098695, inclusive = true)), + isAnomaly = true, 1.0)), (26, AnomalyDetectionDataPoint(data(26), data(26), - Threshold(Bound(-14.963376676338362), Bound(16.322481427311363)), isAnomaly = true, 1.0)), + BoundedRange(Bound(-14.963376676338362, inclusive = true), Bound(16.322481427311363, inclusive = true)), + isAnomaly = true, 1.0)), (27, AnomalyDetectionDataPoint(data(27), data(27), - Threshold(Bound(-15.131834814393196), Bound(16.490939565366197)), isAnomaly = true, 1.0)), + BoundedRange(Bound(-15.131834814393196, inclusive = true), Bound(16.490939565366197, inclusive = true)), + isAnomaly = true, 1.0)), (28, AnomalyDetectionDataPoint(data(28), data(28), - Threshold(Bound(-14.76810451038132), Bound(16.12720926135432)), isAnomaly = true, 1.0)), + BoundedRange(Bound(-14.76810451038132, inclusive = true), Bound(16.12720926135432, inclusive = true)), + isAnomaly = true, 1.0)), (29, AnomalyDetectionDataPoint(data(29), data(29), - Threshold(Bound(-15.078145049879462), Bound(16.437249800852463)), isAnomaly = true, 1.0)), + BoundedRange(Bound(-15.078145049879462, inclusive = true), Bound(16.437249800852463, inclusive = true)), + isAnomaly = true, 1.0)), (30, AnomalyDetectionDataPoint(data(30), data(30), - Threshold(Bound(-14.540171084298914), Bound(15.899275835271913)), isAnomaly = true, 1.0)) + BoundedRange(Bound(-14.540171084298914, inclusive = true), Bound(15.899275835271913, inclusive = true)), + isAnomaly = true, 1.0)) ) assert(anomalyResult == expectedResult) } @@ -222,17 +240,23 @@ class OnlineNormalStrategyTest extends WordSpec with Matchers { // Anomalies with positive values only val expectedResult: Seq[(Int, AnomalyDetectionDataPoint)] = Seq( (20, AnomalyDetectionDataPoint(data(20), data(20), - Threshold(Bound(Double.NegativeInfinity), Bound(5.934276775443095)), isAnomaly = true, 1.0)), + BoundedRange(Bound(Double.NegativeInfinity, inclusive = true), Bound(5.934276775443095, inclusive = true)), + isAnomaly = true, 1.0)), (22, AnomalyDetectionDataPoint(data(22), data(22), - Threshold(Bound(Double.NegativeInfinity), Bound(7.979098353666404)), isAnomaly = true, 1.0)), + BoundedRange(Bound(Double.NegativeInfinity, inclusive = true), Bound(7.979098353666404, inclusive = true)), + isAnomaly = true, 1.0)), (24, AnomalyDetectionDataPoint(data(24), data(24), - Threshold(Bound(Double.NegativeInfinity), Bound(9.582136909647211)), isAnomaly = true, 1.0)), + BoundedRange(Bound(Double.NegativeInfinity, inclusive = true), Bound(9.582136909647211, inclusive = true)), + isAnomaly = true, 1.0)), (26, AnomalyDetectionDataPoint(data(26), data(26), - Threshold(Bound(Double.NegativeInfinity), Bound(10.320400087389258)), isAnomaly = true, 1.0)), + BoundedRange(Bound(Double.NegativeInfinity, inclusive = true), Bound(10.320400087389258, inclusive = true)), + isAnomaly = true, 1.0)), (28, AnomalyDetectionDataPoint(data(28), data(28), - Threshold(Bound(Double.NegativeInfinity), Bound(11.113502213504855)), isAnomaly = true, 1.0)), + 
BoundedRange(Bound(Double.NegativeInfinity, inclusive = true), Bound(11.113502213504855, inclusive = true)), + isAnomaly = true, 1.0)), (30, AnomalyDetectionDataPoint(data(30), data(30), - Threshold(Bound(Double.NegativeInfinity), Bound(11.776810456746686)), isAnomaly = true, 1.0)) + BoundedRange(Bound(Double.NegativeInfinity, inclusive = true), Bound(11.776810456746686, inclusive = true)), + isAnomaly = true, 1.0)) ) assert(anomalyResult == expectedResult) } @@ -245,15 +269,20 @@ class OnlineNormalStrategyTest extends WordSpec with Matchers { // Anomalies with negative values only val expectedResult: Seq[(Int, AnomalyDetectionDataPoint)] = Seq( (21, AnomalyDetectionDataPoint(data(21), data(21), - Threshold(Bound(-7.855820681098751), Bound(Double.PositiveInfinity)), isAnomaly = true, 1.0)), + BoundedRange(Bound(-7.855820681098751, inclusive = true), Bound(Double.PositiveInfinity, inclusive = true)), + isAnomaly = true, 1.0)), (23, AnomalyDetectionDataPoint(data(23), data(23), - Threshold(Bound(-10.14631437278386), Bound(Double.PositiveInfinity)), isAnomaly = true, 1.0)), + BoundedRange(Bound(-10.14631437278386, inclusive = true), Bound(Double.PositiveInfinity, inclusive = true)), + isAnomaly = true, 1.0)), (25, AnomalyDetectionDataPoint(data(25), data(25), - Threshold(Bound(-11.038751996286909), Bound(Double.PositiveInfinity)), isAnomaly = true, 1.0)), + BoundedRange(Bound(-11.038751996286909, inclusive = true), Bound(Double.PositiveInfinity, inclusive = true)), + isAnomaly = true, 1.0)), (27, AnomalyDetectionDataPoint(data(27), data(27), - Threshold(Bound(-11.359107787232386), Bound(Double.PositiveInfinity)), isAnomaly = true, 1.0)), + BoundedRange(Bound(-11.359107787232386, inclusive = true), Bound(Double.PositiveInfinity, inclusive = true)), + isAnomaly = true, 1.0)), (29, AnomalyDetectionDataPoint(data(29), data(29), - Threshold(Bound(-12.097995027317015), Bound(Double.PositiveInfinity)), isAnomaly = true, 1.0)) + BoundedRange(Bound(-12.097995027317015, inclusive = true), Bound(Double.PositiveInfinity, inclusive = true)), + isAnomaly = true, 1.0)) ) assert(anomalyResult == expectedResult) } @@ -290,8 +319,8 @@ class OnlineNormalStrategyTest extends WordSpec with Matchers { result.foreach { case (_, anom) => val value = anom.anomalyMetricValue - val upperBound = anom.anomalyThreshold.upperBound.value - val lowerBound = anom.anomalyThreshold.lowerBound.value + val upperBound = anom.anomalyCheckRange.upperBound.value + val lowerBound = anom.anomalyCheckRange.lowerBound.value assert(value < lowerBound || value > upperBound) } diff --git a/src/test/scala/com/amazon/deequ/anomalydetection/RateOfChangeStrategyTest.scala b/src/test/scala/com/amazon/deequ/anomalydetection/RateOfChangeStrategyTest.scala index d0e6ccba9..7c87b85ee 100644 --- a/src/test/scala/com/amazon/deequ/anomalydetection/RateOfChangeStrategyTest.scala +++ b/src/test/scala/com/amazon/deequ/anomalydetection/RateOfChangeStrategyTest.scala @@ -16,7 +16,8 @@ package com.amazon.deequ.anomalydetection -import org.scalatest.{Matchers, WordSpec} +import org.scalatest.Matchers +import org.scalatest.WordSpec /** * The tested class RateOfChangeStrategy is deprecated. 
@@ -44,7 +45,7 @@ class RateOfChangeStrategyTest extends WordSpec with Matchers { "detect all anomalies if no interval specified" in { val anomalyResult = strategy.detectWithExtendedResults(data).filter({case (_, anom) => anom.isAnomaly}) - val expectedAnomalyThreshold = Threshold(Bound(-2.0), Bound(2.0)) + val expectedAnomalyThreshold = BoundedRange(Bound(-2.0, inclusive = true), Bound(2.0, inclusive = true)) val expectedResult: Seq[(Int, AnomalyDetectionDataPoint)] = Seq( (20, AnomalyDetectionDataPoint(20, 19, expectedAnomalyThreshold, isAnomaly = true, 1.0)), (21, AnomalyDetectionDataPoint(-21, -41, expectedAnomalyThreshold, isAnomaly = true, 1.0)), diff --git a/src/test/scala/com/amazon/deequ/anomalydetection/RelativeRateOfChangeStrategyTest.scala b/src/test/scala/com/amazon/deequ/anomalydetection/RelativeRateOfChangeStrategyTest.scala index c6da5ae2b..bd09d0e97 100644 --- a/src/test/scala/com/amazon/deequ/anomalydetection/RelativeRateOfChangeStrategyTest.scala +++ b/src/test/scala/com/amazon/deequ/anomalydetection/RelativeRateOfChangeStrategyTest.scala @@ -17,7 +17,8 @@ package com.amazon.deequ.anomalydetection import breeze.linalg.DenseVector -import org.scalatest.{Matchers, WordSpec} +import org.scalatest.Matchers +import org.scalatest.WordSpec class RelativeRateOfChangeStrategyTest extends WordSpec with Matchers { @@ -151,20 +152,20 @@ class RelativeRateOfChangeStrategyTest extends WordSpec with Matchers { "detect all anomalies if no interval specified" in { val anomalyResult = strategy.detectWithExtendedResults(data).filter({case (_, anom) => anom.isAnomaly}) - val expectedAnomalyThreshold = Threshold(Bound(0.5), Bound(2.0)) + val expectedAnomalyCheckRange = BoundedRange(Bound(0.5, inclusive = true), Bound(2.0, inclusive = true)) val expectedResult: Seq[(Int, AnomalyDetectionDataPoint)] = Seq( - (20, AnomalyDetectionDataPoint(20, 20, expectedAnomalyThreshold, isAnomaly = true, 1.0)), - (21, AnomalyDetectionDataPoint(1, 0.05, expectedAnomalyThreshold, isAnomaly = true, 1.0)), - (22, AnomalyDetectionDataPoint(22, 22, expectedAnomalyThreshold, isAnomaly = true, 1.0)), - (23, AnomalyDetectionDataPoint(1, 0.045454545454545456, expectedAnomalyThreshold, isAnomaly = true, 1.0)), - (24, AnomalyDetectionDataPoint(24, 24, expectedAnomalyThreshold, isAnomaly = true, 1.0)), - (25, AnomalyDetectionDataPoint(1, 0.041666666666666664, expectedAnomalyThreshold, isAnomaly = true, 1.0)), - (26, AnomalyDetectionDataPoint(26, 26, expectedAnomalyThreshold, isAnomaly = true, 1.0)), - (27, AnomalyDetectionDataPoint(1, 0.038461538461538464, expectedAnomalyThreshold, isAnomaly = true, 1.0)), - (28, AnomalyDetectionDataPoint(28, 28, expectedAnomalyThreshold, isAnomaly = true, 1.0)), - (29, AnomalyDetectionDataPoint(1, 0.03571428571428571, expectedAnomalyThreshold, isAnomaly = true, 1.0)), - (30, AnomalyDetectionDataPoint(30, 30, expectedAnomalyThreshold, isAnomaly = true, 1.0)), - (31, AnomalyDetectionDataPoint(1, 0.03333333333333333, expectedAnomalyThreshold, isAnomaly = true, 1.0)) + (20, AnomalyDetectionDataPoint(20, 20, expectedAnomalyCheckRange, isAnomaly = true, 1.0)), + (21, AnomalyDetectionDataPoint(1, 0.05, expectedAnomalyCheckRange, isAnomaly = true, 1.0)), + (22, AnomalyDetectionDataPoint(22, 22, expectedAnomalyCheckRange, isAnomaly = true, 1.0)), + (23, AnomalyDetectionDataPoint(1, 0.045454545454545456, expectedAnomalyCheckRange, isAnomaly = true, 1.0)), + (24, AnomalyDetectionDataPoint(24, 24, expectedAnomalyCheckRange, isAnomaly = true, 1.0)), + (25, AnomalyDetectionDataPoint(1, 
0.041666666666666664, expectedAnomalyCheckRange, isAnomaly = true, 1.0)), + (26, AnomalyDetectionDataPoint(26, 26, expectedAnomalyCheckRange, isAnomaly = true, 1.0)), + (27, AnomalyDetectionDataPoint(1, 0.038461538461538464, expectedAnomalyCheckRange, isAnomaly = true, 1.0)), + (28, AnomalyDetectionDataPoint(28, 28, expectedAnomalyCheckRange, isAnomaly = true, 1.0)), + (29, AnomalyDetectionDataPoint(1, 0.03571428571428571, expectedAnomalyCheckRange, isAnomaly = true, 1.0)), + (30, AnomalyDetectionDataPoint(30, 30, expectedAnomalyCheckRange, isAnomaly = true, 1.0)), + (31, AnomalyDetectionDataPoint(1, 0.03333333333333333, expectedAnomalyCheckRange, isAnomaly = true, 1.0)) ) assert(anomalyResult == expectedResult) } @@ -172,15 +173,15 @@ class RelativeRateOfChangeStrategyTest extends WordSpec with Matchers { "only detect anomalies in interval" in { val anomalyResult = strategy.detectWithExtendedResults(data, (25, 50)).filter({case (_, anom) => anom.isAnomaly}) - val expectedAnomalyThreshold = Threshold(Bound(0.5), Bound(2.0)) + val expectedAnomalyCheckRange = BoundedRange(Bound(0.5, inclusive = true), Bound(2.0, inclusive = true)) val expectedResult: Seq[(Int, AnomalyDetectionDataPoint)] = Seq( - (25, AnomalyDetectionDataPoint(1, 0.041666666666666664, expectedAnomalyThreshold, isAnomaly = true, 1.0)), - (26, AnomalyDetectionDataPoint(26, 26, expectedAnomalyThreshold, isAnomaly = true, 1.0)), - (27, AnomalyDetectionDataPoint(1, 0.038461538461538464, expectedAnomalyThreshold, isAnomaly = true, 1.0)), - (28, AnomalyDetectionDataPoint(28, 28, expectedAnomalyThreshold, isAnomaly = true, 1.0)), - (29, AnomalyDetectionDataPoint(1, 0.03571428571428571, expectedAnomalyThreshold, isAnomaly = true, 1.0)), - (30, AnomalyDetectionDataPoint(30, 30, expectedAnomalyThreshold, isAnomaly = true, 1.0)), - (31, AnomalyDetectionDataPoint(1, 0.03333333333333333, expectedAnomalyThreshold, isAnomaly = true, 1.0)) + (25, AnomalyDetectionDataPoint(1, 0.041666666666666664, expectedAnomalyCheckRange, isAnomaly = true, 1.0)), + (26, AnomalyDetectionDataPoint(26, 26, expectedAnomalyCheckRange, isAnomaly = true, 1.0)), + (27, AnomalyDetectionDataPoint(1, 0.038461538461538464, expectedAnomalyCheckRange, isAnomaly = true, 1.0)), + (28, AnomalyDetectionDataPoint(28, 28, expectedAnomalyCheckRange, isAnomaly = true, 1.0)), + (29, AnomalyDetectionDataPoint(1, 0.03571428571428571, expectedAnomalyCheckRange, isAnomaly = true, 1.0)), + (30, AnomalyDetectionDataPoint(30, 30, expectedAnomalyCheckRange, isAnomaly = true, 1.0)), + (31, AnomalyDetectionDataPoint(1, 0.03333333333333333, expectedAnomalyCheckRange, isAnomaly = true, 1.0)) ) assert(anomalyResult == expectedResult) } @@ -190,14 +191,15 @@ class RelativeRateOfChangeStrategyTest extends WordSpec with Matchers { val anomalyResult = strategy.detectWithExtendedResults(data).filter({case (_, anom) => anom.isAnomaly}) // Anomalies with positive values only - val expectedAnomalyThreshold = Threshold(Bound(-1.7976931348623157E308), Bound(1.0)) + val expectedAnomalyCheckRange = BoundedRange(Bound(-1.7976931348623157E308, inclusive = true), + Bound(1.0, inclusive = true)) val expectedResult: Seq[(Int, AnomalyDetectionDataPoint)] = Seq( - (20, AnomalyDetectionDataPoint(20, 20, expectedAnomalyThreshold, isAnomaly = true, 1.0)), - (22, AnomalyDetectionDataPoint(22, 22, expectedAnomalyThreshold, isAnomaly = true, 1.0)), - (24, AnomalyDetectionDataPoint(24, 24, expectedAnomalyThreshold, isAnomaly = true, 1.0)), - (26, AnomalyDetectionDataPoint(26, 26, expectedAnomalyThreshold, isAnomaly = 
true, 1.0)), - (28, AnomalyDetectionDataPoint(28, 28, expectedAnomalyThreshold, isAnomaly = true, 1.0)), - (30, AnomalyDetectionDataPoint(30, 30, expectedAnomalyThreshold, isAnomaly = true, 1.0)) + (20, AnomalyDetectionDataPoint(20, 20, expectedAnomalyCheckRange, isAnomaly = true, 1.0)), + (22, AnomalyDetectionDataPoint(22, 22, expectedAnomalyCheckRange, isAnomaly = true, 1.0)), + (24, AnomalyDetectionDataPoint(24, 24, expectedAnomalyCheckRange, isAnomaly = true, 1.0)), + (26, AnomalyDetectionDataPoint(26, 26, expectedAnomalyCheckRange, isAnomaly = true, 1.0)), + (28, AnomalyDetectionDataPoint(28, 28, expectedAnomalyCheckRange, isAnomaly = true, 1.0)), + (30, AnomalyDetectionDataPoint(30, 30, expectedAnomalyCheckRange, isAnomaly = true, 1.0)) ) assert(anomalyResult == expectedResult) } @@ -207,14 +209,15 @@ class RelativeRateOfChangeStrategyTest extends WordSpec with Matchers { val anomalyResult = strategy.detectWithExtendedResults(data).filter({case (_, anom) => anom.isAnomaly}) // Anomalies with negative values only - val expectedAnomalyThreshold = Threshold(Bound(0.5), Bound(1.7976931348623157E308)) + val expectedAnomalyCheckRange = BoundedRange(Bound(0.5, inclusive = true), + Bound(1.7976931348623157E308, inclusive = true)) val expectedResult: Seq[(Int, AnomalyDetectionDataPoint)] = Seq( - (21, AnomalyDetectionDataPoint(1, 0.05, expectedAnomalyThreshold, isAnomaly = true, 1.0)), - (23, AnomalyDetectionDataPoint(1, 0.045454545454545456, expectedAnomalyThreshold, isAnomaly = true, 1.0)), - (25, AnomalyDetectionDataPoint(1, 0.041666666666666664, expectedAnomalyThreshold, isAnomaly = true, 1.0)), - (27, AnomalyDetectionDataPoint(1, 0.038461538461538464, expectedAnomalyThreshold, isAnomaly = true, 1.0)), - (29, AnomalyDetectionDataPoint(1, 0.03571428571428571, expectedAnomalyThreshold, isAnomaly = true, 1.0)), - (31, AnomalyDetectionDataPoint(1, 0.03333333333333333, expectedAnomalyThreshold, isAnomaly = true, 1.0)) + (21, AnomalyDetectionDataPoint(1, 0.05, expectedAnomalyCheckRange, isAnomaly = true, 1.0)), + (23, AnomalyDetectionDataPoint(1, 0.045454545454545456, expectedAnomalyCheckRange, isAnomaly = true, 1.0)), + (25, AnomalyDetectionDataPoint(1, 0.041666666666666664, expectedAnomalyCheckRange, isAnomaly = true, 1.0)), + (27, AnomalyDetectionDataPoint(1, 0.038461538461538464, expectedAnomalyCheckRange, isAnomaly = true, 1.0)), + (29, AnomalyDetectionDataPoint(1, 0.03571428571428571, expectedAnomalyCheckRange, isAnomaly = true, 1.0)), + (31, AnomalyDetectionDataPoint(1, 0.03333333333333333, expectedAnomalyCheckRange, isAnomaly = true, 1.0)) ) assert(anomalyResult == expectedResult) } @@ -234,9 +237,11 @@ class RelativeRateOfChangeStrategyTest extends WordSpec with Matchers { val expectedResult: Seq[(Int, AnomalyDetectionDataPoint)] = Seq( (2, AnomalyDetectionDataPoint(3, Double.PositiveInfinity, - Threshold(Bound(-1.7976931348623157E308), Bound(8.0)), isAnomaly = true, 1.0)), + BoundedRange(Bound(-1.7976931348623157E308, inclusive = true), Bound(8.0, inclusive = true)), + isAnomaly = true, 1.0)), (5, AnomalyDetectionDataPoint(72, 12, - Threshold(Bound(-1.7976931348623157E308), Bound(8.0)), isAnomaly = true, 1.0)) + BoundedRange(Bound(-1.7976931348623157E308, inclusive = true), Bound(8.0, inclusive = true)), + isAnomaly = true, 1.0)) ) assert(anomalyResult == expectedResult) } @@ -248,7 +253,8 @@ class RelativeRateOfChangeStrategyTest extends WordSpec with Matchers { val expectedResult: Seq[(Int, AnomalyDetectionDataPoint)] = Seq( (5, AnomalyDetectionDataPoint(72, 12, - 
Threshold(Bound(-1.7976931348623157E308), Bound(8.0)), isAnomaly = true, 1.0)) + BoundedRange(Bound(-1.7976931348623157E308, inclusive = true), Bound(8.0, inclusive = true)), + isAnomaly = true, 1.0)) ) assert(anomalyResult == expectedResult) } @@ -277,8 +283,8 @@ class RelativeRateOfChangeStrategyTest extends WordSpec with Matchers { result.foreach { case (_, anom) => val value = anom.anomalyMetricValue - val upperBound = anom.anomalyThreshold.upperBound.value - val lowerBound = anom.anomalyThreshold.lowerBound.value + val upperBound = anom.anomalyCheckRange.upperBound.value + val lowerBound = anom.anomalyCheckRange.lowerBound.value assert(value < lowerBound || value > upperBound) } diff --git a/src/test/scala/com/amazon/deequ/anomalydetection/SimpleThresholdStrategyTest.scala b/src/test/scala/com/amazon/deequ/anomalydetection/SimpleThresholdStrategyTest.scala index f8396c677..28d49d4c2 100644 --- a/src/test/scala/com/amazon/deequ/anomalydetection/SimpleThresholdStrategyTest.scala +++ b/src/test/scala/com/amazon/deequ/anomalydetection/SimpleThresholdStrategyTest.scala @@ -16,7 +16,8 @@ package com.amazon.deequ.anomalydetection -import org.scalatest.{Matchers, WordSpec} +import org.scalatest.Matchers +import org.scalatest.WordSpec class SimpleThresholdStrategyTest extends WordSpec with Matchers { @@ -73,10 +74,11 @@ class SimpleThresholdStrategyTest extends WordSpec with Matchers { "Simple Threshold Strategy with Extended Results" should { val (strategy, data) = setupDefaultStrategyAndData() - val expectedAnomalyThreshold = Threshold(upperBound = Bound(1.0)) + val expectedAnomalyCheckRange = BoundedRange(lowerBound = Bound(Double.MinValue, inclusive = true), + upperBound = Bound(1.0, inclusive = true)) val expectedResult = Seq( - (1, AnomalyDetectionDataPoint(2.0, 2.0, expectedAnomalyThreshold, isAnomaly = true, 1.0)), - (2, AnomalyDetectionDataPoint(3.0, 3.0, expectedAnomalyThreshold, isAnomaly = true, 1.0))) + (1, AnomalyDetectionDataPoint(2.0, 2.0, expectedAnomalyCheckRange, isAnomaly = true, 1.0)), + (2, AnomalyDetectionDataPoint(3.0, 3.0, expectedAnomalyCheckRange, isAnomaly = true, 1.0))) "detect values above threshold" in { val anomalyResult = @@ -102,11 +104,11 @@ class SimpleThresholdStrategyTest extends WordSpec with Matchers { "work with upper and lower threshold" in { val tS = SimpleThresholdStrategy(lowerBound = -0.5, upperBound = 1.0) val anomalyResult = tS.detectWithExtendedResults(data).filter({ case (_, anom) => anom.isAnomaly }) - val expectedAnomalyThreshold = Threshold(Bound(-0.5), Bound(1.0)) + val expectedAnomalyCheckRange = BoundedRange(Bound(-0.5, inclusive = true), Bound(1.0, inclusive = true)) val expectedResult = Seq( - (0, AnomalyDetectionDataPoint(-1.0, -1.0, expectedAnomalyThreshold, isAnomaly = true, 1.0)), - (1, AnomalyDetectionDataPoint(2.0, 2.0, expectedAnomalyThreshold, isAnomaly = true, 1.0)), - (2, AnomalyDetectionDataPoint(3.0, 3.0, expectedAnomalyThreshold, isAnomaly = true, 1.0))) + (0, AnomalyDetectionDataPoint(-1.0, -1.0, expectedAnomalyCheckRange, isAnomaly = true, 1.0)), + (1, AnomalyDetectionDataPoint(2.0, 2.0, expectedAnomalyCheckRange, isAnomaly = true, 1.0)), + (2, AnomalyDetectionDataPoint(3.0, 3.0, expectedAnomalyCheckRange, isAnomaly = true, 1.0))) assert(anomalyResult == expectedResult) } @@ -128,8 +130,8 @@ class SimpleThresholdStrategyTest extends WordSpec with Matchers { result.foreach { case (_, anom) => val value = anom.anomalyMetricValue - val upperBound = anom.anomalyThreshold.upperBound.value - val lowerBound = 
anom.anomalyThreshold.lowerBound.value + val upperBound = anom.anomalyCheckRange.upperBound.value + val lowerBound = anom.anomalyCheckRange.lowerBound.value assert(value < lowerBound || value > upperBound) } diff --git a/src/test/scala/com/amazon/deequ/checks/CheckTest.scala b/src/test/scala/com/amazon/deequ/checks/CheckTest.scala index e768d0f37..31c5209d7 100644 --- a/src/test/scala/com/amazon/deequ/checks/CheckTest.scala +++ b/src/test/scala/com/amazon/deequ/checks/CheckTest.scala @@ -25,6 +25,8 @@ import com.amazon.deequ.anomalydetection.AnomalyDetectionAssertionResult import com.amazon.deequ.anomalydetection.AnomalyDetectionDataPoint import com.amazon.deequ.anomalydetection.AnomalyDetectionStrategy import com.amazon.deequ.anomalydetection.AnomalyDetectionStrategyWithExtendedResults +import com.amazon.deequ.anomalydetection.Bound +import com.amazon.deequ.anomalydetection.BoundedRange import com.amazon.deequ.anomalydetection.ExtendedDetectionResult import com.amazon.deequ.checks.Check.getNewestPointAnomalyResults import com.amazon.deequ.constraints.ConstrainableDataTypes @@ -53,6 +55,10 @@ class CheckTest extends AnyWordSpec with Matchers with SparkContextSpec with Fix import CheckTest._ + // This is used as a default bounded range value for anomaly detection tests. + private[this] val defaultBoundedRange = BoundedRange(lowerBound = Bound(0.0, inclusive = true), + upperBound = Bound(1.0, inclusive = true)) + "Check" should { "return the correct check status for completeness" in withSparkSession { sparkSession => @@ -1177,38 +1183,38 @@ class CheckTest extends AnyWordSpec with Matchers with SparkContextSpec with Fix (fakeAnomalyDetector.detectWithExtendedResults _) .expects(Vector(1.0, 2.0, 3.0, 4.0, 11.0), (4, 5)) .returns(Seq( - (0, AnomalyDetectionDataPoint(1.0, 1.0, confidence = 1.0)), - (1, AnomalyDetectionDataPoint(2.0, 2.0, confidence = 1.0)), - (2, AnomalyDetectionDataPoint(3.0, 3.0, confidence = 1.0)), - (3, AnomalyDetectionDataPoint(4.0, 4.0, confidence = 1.0)), - (4, AnomalyDetectionDataPoint(11.0, 11.0, confidence = 1.0)))) + (0, AnomalyDetectionDataPoint(1.0, 1.0, defaultBoundedRange, confidence = 1.0, isAnomaly = false)), + (1, AnomalyDetectionDataPoint(2.0, 2.0, defaultBoundedRange, confidence = 1.0, isAnomaly = false)), + (2, AnomalyDetectionDataPoint(3.0, 3.0, defaultBoundedRange, confidence = 1.0, isAnomaly = false)), + (3, AnomalyDetectionDataPoint(4.0, 4.0, defaultBoundedRange, confidence = 1.0, isAnomaly = false)), + (4, AnomalyDetectionDataPoint(11.0, 11.0, defaultBoundedRange, confidence = 1.0, isAnomaly = false)))) .once() (fakeAnomalyDetector.detectWithExtendedResults _).expects(Vector(1.0, 2.0, 3.0, 4.0, 4.0), (4, 5)) .returns(Seq( - (0, AnomalyDetectionDataPoint(1.0, 1.0, confidence = 1.0)), - (1, AnomalyDetectionDataPoint(2.0, 2.0, confidence = 1.0)), - (2, AnomalyDetectionDataPoint(3.0, 3.0, confidence = 1.0)), - (3, AnomalyDetectionDataPoint(4.0, 4.0, confidence = 1.0)), - (4, AnomalyDetectionDataPoint(4.0, 4.0, isAnomaly = true, confidence = 1.0)))) + (0, AnomalyDetectionDataPoint(1.0, 1.0, defaultBoundedRange, confidence = 1.0, isAnomaly = false)), + (1, AnomalyDetectionDataPoint(2.0, 2.0, defaultBoundedRange, confidence = 1.0, isAnomaly = false)), + (2, AnomalyDetectionDataPoint(3.0, 3.0, defaultBoundedRange, confidence = 1.0, isAnomaly = false)), + (3, AnomalyDetectionDataPoint(4.0, 4.0, defaultBoundedRange, confidence = 1.0, isAnomaly = false)), + (4, AnomalyDetectionDataPoint(4.0, 4.0, defaultBoundedRange, confidence = 1.0, isAnomaly = true)))) .once() 
// Distinctness results (fakeAnomalyDetector.detectWithExtendedResults _) .expects(Vector(1.0, 2.0, 3.0, 4.0, 1), (4, 5)) .returns(Seq( - (0, AnomalyDetectionDataPoint(1.0, 1.0, confidence = 1.0)), - (1, AnomalyDetectionDataPoint(2.0, 2.0, confidence = 1.0)), - (2, AnomalyDetectionDataPoint(3.0, 3.0, confidence = 1.0)), - (3, AnomalyDetectionDataPoint(4.0, 4.0, confidence = 1.0)), - (4, AnomalyDetectionDataPoint(1.0, 1.0, confidence = 1.0)))) + (0, AnomalyDetectionDataPoint(1.0, 1.0, defaultBoundedRange, confidence = 1.0, isAnomaly = false)), + (1, AnomalyDetectionDataPoint(2.0, 2.0, defaultBoundedRange, confidence = 1.0, isAnomaly = false)), + (2, AnomalyDetectionDataPoint(3.0, 3.0, defaultBoundedRange, confidence = 1.0, isAnomaly = false)), + (3, AnomalyDetectionDataPoint(4.0, 4.0, defaultBoundedRange, confidence = 1.0, isAnomaly = false)), + (4, AnomalyDetectionDataPoint(1.0, 1.0, defaultBoundedRange, confidence = 1.0, isAnomaly = false)))) .once() (fakeAnomalyDetector.detectWithExtendedResults _) .expects(Vector(1.0, 2.0, 3.0, 4.0, 1), (4, 5)) .returns(Seq( - (0, AnomalyDetectionDataPoint(1.0, 1.0, confidence = 1.0)), - (1, AnomalyDetectionDataPoint(2.0, 2.0, confidence = 1.0)), - (2, AnomalyDetectionDataPoint(3.0, 3.0, confidence = 1.0)), - (3, AnomalyDetectionDataPoint(4.0, 4.0, confidence = 1.0)), - (4, AnomalyDetectionDataPoint(1.0, 1.0, isAnomaly = true, confidence = 1.0)))) + (0, AnomalyDetectionDataPoint(1.0, 1.0, defaultBoundedRange, confidence = 1.0, isAnomaly = false)), + (1, AnomalyDetectionDataPoint(2.0, 2.0, defaultBoundedRange, confidence = 1.0, isAnomaly = false)), + (2, AnomalyDetectionDataPoint(3.0, 3.0, defaultBoundedRange, confidence = 1.0, isAnomaly = false)), + (3, AnomalyDetectionDataPoint(4.0, 4.0, defaultBoundedRange, confidence = 1.0, isAnomaly = false)), + (4, AnomalyDetectionDataPoint(1.0, 1.0, defaultBoundedRange, confidence = 1.0, isAnomaly = true)))) .once() } @@ -1249,15 +1255,15 @@ class CheckTest extends AnyWordSpec with Matchers with SparkContextSpec with Fix (fakeAnomalyDetector.detectWithExtendedResults _) .expects(Vector(1.0, 2.0, 11.0), (2, 3)) .returns(Seq( - (0, AnomalyDetectionDataPoint(1.0, 1.0, confidence = 1.0)), - (1, AnomalyDetectionDataPoint(2.0, 2.0, confidence = 1.0)), - (2, AnomalyDetectionDataPoint(11.0, 11.0, confidence = 1.0)))) + (0, AnomalyDetectionDataPoint(1.0, 1.0, defaultBoundedRange, confidence = 1.0, isAnomaly = false)), + (1, AnomalyDetectionDataPoint(2.0, 2.0, defaultBoundedRange, confidence = 1.0, isAnomaly = false)), + (2, AnomalyDetectionDataPoint(11.0, 11.0, defaultBoundedRange, confidence = 1.0, isAnomaly = false)))) .once() (fakeAnomalyDetector.detectWithExtendedResults _).expects(Vector(1.0, 2.0, 4.0), (2, 3)) .returns(Seq( - (0, AnomalyDetectionDataPoint(1.0, 1.0, confidence = 1.0)), - (1, AnomalyDetectionDataPoint(2.0, 2.0, confidence = 1.0)), - (2, AnomalyDetectionDataPoint(4.0, 4.0, isAnomaly = true, confidence = 1.0)))) + (0, AnomalyDetectionDataPoint(1.0, 1.0, defaultBoundedRange, confidence = 1.0, isAnomaly = false)), + (1, AnomalyDetectionDataPoint(2.0, 2.0, defaultBoundedRange, confidence = 1.0, isAnomaly = false)), + (2, AnomalyDetectionDataPoint(4.0, 4.0, defaultBoundedRange, confidence = 1.0, isAnomaly = true)))) .once() } @@ -1289,15 +1295,15 @@ class CheckTest extends AnyWordSpec with Matchers with SparkContextSpec with Fix (fakeAnomalyDetector.detectWithExtendedResults _) .expects(Vector(3.0, 4.0, 11.0), (2, 3)) .returns(Seq( - (0, AnomalyDetectionDataPoint(3.0, 3.0, confidence = 1.0)), - (1, 
AnomalyDetectionDataPoint(4.0, 4.0, confidence = 1.0)), - (2, AnomalyDetectionDataPoint(11.0, 11.0, confidence = 1.0)))) + (0, AnomalyDetectionDataPoint(3.0, 3.0, defaultBoundedRange, confidence = 1.0, isAnomaly = false)), + (1, AnomalyDetectionDataPoint(4.0, 4.0, defaultBoundedRange, confidence = 1.0, isAnomaly = false)), + (2, AnomalyDetectionDataPoint(11.0, 11.0, defaultBoundedRange, confidence = 1.0, isAnomaly = false)))) .once() (fakeAnomalyDetector.detectWithExtendedResults _).expects(Vector(3.0, 4.0, 4.0), (2, 3)) .returns(Seq( - (0, AnomalyDetectionDataPoint(3.0, 3.0, confidence = 1.0)), - (1, AnomalyDetectionDataPoint(4.0, 4.0, confidence = 1.0)), - (2, AnomalyDetectionDataPoint(4.0, 4.0, isAnomaly = true, confidence = 1.0)))) + (0, AnomalyDetectionDataPoint(3.0, 3.0, defaultBoundedRange, confidence = 1.0, isAnomaly = false)), + (1, AnomalyDetectionDataPoint(4.0, 4.0, defaultBoundedRange, confidence = 1.0, isAnomaly = false)), + (2, AnomalyDetectionDataPoint(4.0, 4.0, defaultBoundedRange, confidence = 1.0, isAnomaly = true)))) .once() } @@ -1329,15 +1335,15 @@ class CheckTest extends AnyWordSpec with Matchers with SparkContextSpec with Fix (fakeAnomalyDetector.detectWithExtendedResults _) .expects(Vector(1.0, 2.0, 11.0), (2, 3)) .returns(Seq( - (0, AnomalyDetectionDataPoint(1.0, 1.0, confidence = 1.0)), - (1, AnomalyDetectionDataPoint(2.0, 2.0, confidence = 1.0)), - (2, AnomalyDetectionDataPoint(11.0, 11.0, confidence = 1.0)))) + (0, AnomalyDetectionDataPoint(1.0, 1.0, defaultBoundedRange, confidence = 1.0, isAnomaly = false)), + (1, AnomalyDetectionDataPoint(2.0, 2.0, defaultBoundedRange, confidence = 1.0, isAnomaly = false)), + (2, AnomalyDetectionDataPoint(11.0, 11.0, defaultBoundedRange, confidence = 1.0, isAnomaly = false)))) .once() (fakeAnomalyDetector.detectWithExtendedResults _).expects(Vector(1.0, 2.0, 4.0), (2, 3)) .returns(Seq( - (0, AnomalyDetectionDataPoint(1.0, 1.0, confidence = 1.0)), - (1, AnomalyDetectionDataPoint(2.0, 2.0, confidence = 1.0)), - (2, AnomalyDetectionDataPoint(4.0, 4.0, isAnomaly = true, confidence = 1.0)))) + (0, AnomalyDetectionDataPoint(1.0, 1.0, defaultBoundedRange, confidence = 1.0, isAnomaly = false)), + (1, AnomalyDetectionDataPoint(2.0, 2.0, defaultBoundedRange, confidence = 1.0, isAnomaly = false)), + (2, AnomalyDetectionDataPoint(4.0, 4.0, defaultBoundedRange, confidence = 1.0, isAnomaly = true)))) .once() } @@ -1364,26 +1370,26 @@ class CheckTest extends AnyWordSpec with Matchers with SparkContextSpec with Fix "with multiple data points" in { val anomalySequence: Seq[(Long, AnomalyDetectionDataPoint)] = Seq( - (0, AnomalyDetectionDataPoint(1.0, 1.0, confidence = 1.0)), - (1, AnomalyDetectionDataPoint(2.0, 2.0, confidence = 1.0)), - (2, AnomalyDetectionDataPoint(11.0, 11.0, isAnomaly = true, confidence = 1.0))) + (0, AnomalyDetectionDataPoint(1.0, 1.0, defaultBoundedRange, confidence = 1.0, isAnomaly = false)), + (1, AnomalyDetectionDataPoint(2.0, 2.0, defaultBoundedRange, confidence = 1.0, isAnomaly = false)), + (2, AnomalyDetectionDataPoint(11.0, 11.0, defaultBoundedRange, confidence = 1.0, isAnomaly = true))) val result: AnomalyDetectionAssertionResult = getNewestPointAnomalyResults(ExtendedDetectionResult(anomalySequence)) - assert(!result.hasNoAnomaly) - assert(result.anomalyDetectionExtendedResult.anomalyDetectionDataPoints.head == - AnomalyDetectionDataPoint(11.0, 11.0, isAnomaly = true, confidence = 1.0)) + assert(result.hasAnomaly) + assert(result.anomalyDetectionExtendedResult.anomalyDetectionDataPoint == + 
AnomalyDetectionDataPoint(11.0, 11.0, defaultBoundedRange, confidence = 1.0, isAnomaly = true)) } "getNewestPointAnomalyResults returns correct assertion result from anomaly detection data point sequence " + "with one data point" in { val anomalySequence: Seq[(Long, AnomalyDetectionDataPoint)] = Seq( - (0, AnomalyDetectionDataPoint(11.0, 11.0, confidence = 1.0))) + (0, AnomalyDetectionDataPoint(11.0, 11.0, defaultBoundedRange, confidence = 1.0, isAnomaly = false))) val result: AnomalyDetectionAssertionResult = getNewestPointAnomalyResults(ExtendedDetectionResult(anomalySequence)) - assert(result.hasNoAnomaly) - assert(result.anomalyDetectionExtendedResult.anomalyDetectionDataPoints.head == - AnomalyDetectionDataPoint(11.0, 11.0, confidence = 1.0)) + assert(!result.hasAnomaly) + assert(result.anomalyDetectionExtendedResult.anomalyDetectionDataPoint == + AnomalyDetectionDataPoint(11.0, 11.0, defaultBoundedRange, confidence = 1.0, isAnomaly = false)) } "assert getNewestPointAnomalyResults throws exception from empty anomaly detection sequence" in { diff --git a/src/test/scala/com/amazon/deequ/constraints/AnomalyExtendedResultsConstraintTest.scala b/src/test/scala/com/amazon/deequ/constraints/AnomalyExtendedResultsConstraintTest.scala index 213123a74..606b1966b 100644 --- a/src/test/scala/com/amazon/deequ/constraints/AnomalyExtendedResultsConstraintTest.scala +++ b/src/test/scala/com/amazon/deequ/constraints/AnomalyExtendedResultsConstraintTest.scala @@ -19,7 +19,11 @@ package com.amazon.deequ.constraints import com.amazon.deequ.SparkContextSpec import com.amazon.deequ.analyzers._ import com.amazon.deequ.analyzers.runners.MetricCalculationException -import com.amazon.deequ.anomalydetection.{AnomalyDetectionAssertionResult, AnomalyDetectionDataPoint, AnomalyDetectionExtendedResult} +import com.amazon.deequ.anomalydetection.AnomalyDetectionAssertionResult +import com.amazon.deequ.anomalydetection.AnomalyDetectionDataPoint +import com.amazon.deequ.anomalydetection.AnomalyDetectionExtendedResult +import com.amazon.deequ.anomalydetection.Bound +import com.amazon.deequ.anomalydetection.BoundedRange import com.amazon.deequ.constraints.ConstraintUtils.calculate import com.amazon.deequ.metrics.{DoubleMetric, Entity, Metric} import com.amazon.deequ.utils.FixtureSupport @@ -54,7 +58,8 @@ class AnomalyExtendedResultsConstraintTest extends WordSpec with Matchers with S override def calculate( data: DataFrame, stateLoader: Option[StateLoader], - statePersister: Option[StatePersister]) + statePersister: Option[StatePersister], + filterCondition: Option[String]) : DoubleMetric = { val value: Try[Double] = Try { require(data.columns.contains(column), s"Missing column $column") @@ -63,11 +68,11 @@ class AnomalyExtendedResultsConstraintTest extends WordSpec with Matchers with S DoubleMetric(Entity.Column, "sample", column, value) } - override def computeStateFrom(data: DataFrame): Option[NumMatches] = { + override def computeStateFrom(data: DataFrame, filterCondition: Option[String] = None) + : Option[NumMatches] = { throw new NotImplementedError() } - override def computeMetricFrom(state: Option[NumMatches]): DoubleMetric = { throw new NotImplementedError() } @@ -75,15 +80,20 @@ class AnomalyExtendedResultsConstraintTest extends WordSpec with Matchers with S "Anomaly extended results constraint" should { + val defaultBoundedRange = BoundedRange(lowerBound = Bound(0.0, inclusive = true), + upperBound = Bound(1.0, inclusive = true)) + "assert correctly on values if analysis is successful" in withSparkSession { 
sparkSession => val df = getDfMissing(sparkSession) // Analysis result should equal to 1.0 for an existing column - val anomalyAssertionFunctionA = (metric: Double) => { - AnomalyDetectionAssertionResult(metric == 1.0, - AnomalyDetectionExtendedResult(Seq(AnomalyDetectionDataPoint(1.0, 1.0, confidence = 1.0)))) + val anomalyAssertionFunctionA = (_: Double) => { + AnomalyDetectionAssertionResult(hasAnomaly = false, + AnomalyDetectionExtendedResult(AnomalyDetectionDataPoint(1.0, 1.0, confidence = 1.0, + anomalyCheckRange = defaultBoundedRange, isAnomaly = false)) + ) } val resultA = calculate( @@ -94,10 +104,10 @@ class AnomalyExtendedResultsConstraintTest extends WordSpec with Matchers with S assert(resultA.message.isEmpty) assert(resultA.metric.isDefined) - val anomalyAssertionFunctionB = (metric: Double) => { - AnomalyDetectionAssertionResult(metric != 1.0, - AnomalyDetectionExtendedResult( - Seq(AnomalyDetectionDataPoint(1.0, 1.0, confidence = 1.0, isAnomaly = true)))) + val anomalyAssertionFunctionB = (_: Double) => { + AnomalyDetectionAssertionResult(hasAnomaly = true, + AnomalyDetectionExtendedResult(AnomalyDetectionDataPoint(1.0, 1.0, confidence = 1.0, + anomalyCheckRange = defaultBoundedRange, isAnomaly = true))) } // Analysis result should equal to 1.0 for an existing column @@ -126,9 +136,11 @@ class AnomalyExtendedResultsConstraintTest extends WordSpec with Matchers with S val df = getDfMissing(sparkSession) - val anomalyAssertionFunctionA = (metric: Double) => { - AnomalyDetectionAssertionResult(metric == 2.0, - AnomalyDetectionExtendedResult(Seq(AnomalyDetectionDataPoint(2.0, 2.0, confidence = 1.0)))) + val anomalyAssertionFunctionA = (_: Double) => { + AnomalyDetectionAssertionResult(hasAnomaly = false, + AnomalyDetectionExtendedResult(AnomalyDetectionDataPoint(2.0, 2.0, confidence = 1.0, + anomalyCheckRange = defaultBoundedRange, isAnomaly = false)) + ) } // Analysis result should equal to 100.0 for an existing column @@ -136,10 +148,10 @@ class AnomalyExtendedResultsConstraintTest extends WordSpec with Matchers with S SampleAnalyzer("att1"), anomalyAssertionFunctionA, Some(valueDoubler)), df).status == ConstraintStatus.Success) - val anomalyAssertionFunctionB = (metric: Double) => { - AnomalyDetectionAssertionResult(metric != 2.0, - AnomalyDetectionExtendedResult( - Seq(AnomalyDetectionDataPoint(2.0, 2.0, confidence = 1.0, isAnomaly = true)))) + val anomalyAssertionFunctionB = (_: Double) => { + AnomalyDetectionAssertionResult(hasAnomaly = true, + AnomalyDetectionExtendedResult(AnomalyDetectionDataPoint(2.0, 2.0, confidence = 1.0, + anomalyCheckRange = defaultBoundedRange, isAnomaly = true))) } assert(calculate(AnomalyExtendedResultsConstraint[NumMatches, Double, Double]( @@ -164,14 +176,15 @@ class AnomalyExtendedResultsConstraintTest extends WordSpec with Matchers with S SampleAnalyzer("someMissingColumn") -> SampleAnalyzer("someMissingColumn").calculate(df) ) - val anomalyAssertionFunctionA = (metric: Double) => { - AnomalyDetectionAssertionResult(metric == 1.0, - AnomalyDetectionExtendedResult(Seq(AnomalyDetectionDataPoint(1.0, 1.0, confidence = 1.0)))) + val anomalyAssertionFunctionA = (_: Double) => { + AnomalyDetectionAssertionResult(hasAnomaly = false, + AnomalyDetectionExtendedResult(AnomalyDetectionDataPoint(1.0, 1.0, confidence = 1.0, + anomalyCheckRange = defaultBoundedRange, isAnomaly = false))) } - val anomalyAssertionFunctionB = (metric: Double) => { - AnomalyDetectionAssertionResult(metric != 1.0, - AnomalyDetectionExtendedResult( - 
Seq(AnomalyDetectionDataPoint(1.0, 1.0, confidence = 1.0, isAnomaly = true)))) + val anomalyAssertionFunctionB = (_: Double) => { + AnomalyDetectionAssertionResult(hasAnomaly = true, + AnomalyDetectionExtendedResult(AnomalyDetectionDataPoint(1.0, 1.0, confidence = 1.0, + anomalyCheckRange = defaultBoundedRange, isAnomaly = true))) } // Analysis result should equal to 1.0 for an existing column @@ -202,9 +215,10 @@ class AnomalyExtendedResultsConstraintTest extends WordSpec with Matchers with S val validResults = Map[Analyzer[_, Metric[_]], Metric[_]]( SampleAnalyzer("att1") -> SampleAnalyzer("att1").calculate(df)) - val anomalyAssertionFunction = (metric: Double) => { - AnomalyDetectionAssertionResult(metric == 2.0, - AnomalyDetectionExtendedResult(Seq(AnomalyDetectionDataPoint(2.0, 2.0, confidence = 1.0)))) + val anomalyAssertionFunction = (_: Double) => { + AnomalyDetectionAssertionResult(hasAnomaly = false, + AnomalyDetectionExtendedResult(AnomalyDetectionDataPoint(2.0, 2.0, confidence = 1.0, + anomalyCheckRange = defaultBoundedRange, isAnomaly = false))) } assert(AnomalyExtendedResultsConstraint[NumMatches, Double, Double]( @@ -224,9 +238,10 @@ class AnomalyExtendedResultsConstraintTest extends WordSpec with Matchers with S val validResults = Map[Analyzer[_, Metric[_]], Metric[_]]( SampleAnalyzer("att1") -> SampleAnalyzer("att1").calculate(df)) - val anomalyAssertionFunction = (metric: Double) => { - AnomalyDetectionAssertionResult(metric == 1.0, - AnomalyDetectionExtendedResult(Seq(AnomalyDetectionDataPoint(1.0, 1.0, confidence = 1.0)))) + val anomalyAssertionFunction = (_: Double) => { + AnomalyDetectionAssertionResult(hasAnomaly = false, + AnomalyDetectionExtendedResult(AnomalyDetectionDataPoint(1.0, 1.0, confidence = 1.0, + anomalyCheckRange = defaultBoundedRange, isAnomaly = false))) } val constraint = AnomalyExtendedResultsConstraint[NumMatches, Double, Double]( SampleAnalyzer("att1"), anomalyAssertionFunction, Some(problematicValuePicker)) @@ -260,10 +275,10 @@ class AnomalyExtendedResultsConstraintTest extends WordSpec with Matchers with S val df = getDfMissing(sparkSession) - val anomalyAssertionFunction = (metric: Double) => { - AnomalyDetectionAssertionResult(metric == 0.9, - AnomalyDetectionExtendedResult( - Seq(AnomalyDetectionDataPoint(1.0, 1.0, confidence = 1.0, isAnomaly = true)))) + val anomalyAssertionFunction = (_: Double) => { + AnomalyDetectionAssertionResult(hasAnomaly = true, + AnomalyDetectionExtendedResult(AnomalyDetectionDataPoint(1.0, 1.0, confidence = 1.0, + anomalyCheckRange = defaultBoundedRange, isAnomaly = true))) } val failingConstraint = AnomalyExtendedResultsConstraint[NumMatches, Double, Double]( diff --git a/src/test/scala/com/amazon/deequ/constraints/ConstraintsTest.scala b/src/test/scala/com/amazon/deequ/constraints/ConstraintsTest.scala index ac426ef55..cd8d91d91 100644 --- a/src/test/scala/com/amazon/deequ/constraints/ConstraintsTest.scala +++ b/src/test/scala/com/amazon/deequ/constraints/ConstraintsTest.scala @@ -18,14 +18,21 @@ package com.amazon.deequ package constraints import com.amazon.deequ.utils.FixtureSupport -import org.scalatest.{Matchers, WordSpec} +import org.scalatest.Matchers +import org.scalatest.WordSpec import ConstraintUtils.calculate -import com.amazon.deequ.analyzers.{Completeness, NumMatchesAndCount} +import com.amazon.deequ.analyzers.Completeness +import com.amazon.deequ.analyzers.NumMatchesAndCount import org.apache.spark.sql.Row -import org.apache.spark.sql.types.{DoubleType, StringType} +import 
org.apache.spark.sql.types.DoubleType +import org.apache.spark.sql.types.StringType import Constraint._ import com.amazon.deequ.SparkContextSpec -import com.amazon.deequ.anomalydetection.{AnomalyDetectionAssertionResult, AnomalyDetectionDataPoint, AnomalyDetectionExtendedResult} +import com.amazon.deequ.anomalydetection.AnomalyDetectionAssertionResult +import com.amazon.deequ.anomalydetection.AnomalyDetectionDataPoint +import com.amazon.deequ.anomalydetection.AnomalyDetectionExtendedResult +import com.amazon.deequ.anomalydetection.Bound +import com.amazon.deequ.anomalydetection.BoundedRange class ConstraintsTest extends WordSpec with Matchers with SparkContextSpec with FixtureSupport { @@ -179,31 +186,26 @@ class ConstraintsTest extends WordSpec with Matchers with SparkContextSpec with "Anomaly constraint with Extended Results" should { "assert on anomaly analyzer values" in withSparkSession { sparkSession => val df = getDfMissing(sparkSession) + val defaultBoundedRange = BoundedRange(lowerBound = Bound(0.0, inclusive = true), + upperBound = Bound(1.0, inclusive = true)) + assert(calculate(Constraint.anomalyConstraintWithExtendedResults[NumMatchesAndCount]( - Completeness("att1"), (metric: Double) => { - AnomalyDetectionAssertionResult(metric > 0.4, - AnomalyDetectionExtendedResult(Seq(AnomalyDetectionDataPoint(1.0, 1.0, confidence = 1.0)))) + Completeness("att1"), (_: Double) => { + AnomalyDetectionAssertionResult(hasAnomaly = false, + AnomalyDetectionExtendedResult(AnomalyDetectionDataPoint(1.0, 1.0, confidence = 1.0, + anomalyCheckRange = defaultBoundedRange, isAnomaly = false))) } ), df) .status == ConstraintStatus.Success) - assert(calculate(Constraint.anomalyConstraintWithExtendedResults[NumMatchesAndCount]( - Completeness("att1"), (metric: Double) => { - AnomalyDetectionAssertionResult(metric < 0.4, - AnomalyDetectionExtendedResult(Seq(AnomalyDetectionDataPoint(1.0, 1.0, confidence = 1.0)))) - }), df) - .status == ConstraintStatus.Failure) assert(calculate(Constraint.anomalyConstraintWithExtendedResults[NumMatchesAndCount]( - Completeness("att2"), (metric: Double) => { - AnomalyDetectionAssertionResult(metric > 0.7, - AnomalyDetectionExtendedResult(Seq(AnomalyDetectionDataPoint(1.0, 1.0, confidence = 1.0)))) - }), df) - .status == ConstraintStatus.Success) - assert(calculate(Constraint.anomalyConstraintWithExtendedResults[NumMatchesAndCount]( - Completeness("att2"), (metric: Double) => { - AnomalyDetectionAssertionResult(metric < 0.7, - AnomalyDetectionExtendedResult(Seq(AnomalyDetectionDataPoint(1.0, 1.0, confidence = 1.0)))) + Completeness("att1"), (_: Double) => { + AnomalyDetectionAssertionResult(hasAnomaly = true, + AnomalyDetectionExtendedResult(AnomalyDetectionDataPoint(1.0, 1.0, confidence = 1.0, + anomalyCheckRange = defaultBoundedRange, isAnomaly = true) + )) }), df) .status == ConstraintStatus.Failure) + } } } From 5da25c4d3738a5ac32f6cb2f1edeb9262bec6b2a Mon Sep 17 00:00:00 2001 From: Hubert Date: Mon, 4 Nov 2024 09:38:21 -0500 Subject: [PATCH 23/24] add accidentally removed import --- .../deequ/constraints/AnomalyExtendedResultsConstraint.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/main/scala/com/amazon/deequ/constraints/AnomalyExtendedResultsConstraint.scala b/src/main/scala/com/amazon/deequ/constraints/AnomalyExtendedResultsConstraint.scala index 3305aaa4f..03c374565 100644 --- a/src/main/scala/com/amazon/deequ/constraints/AnomalyExtendedResultsConstraint.scala +++ 
b/src/main/scala/com/amazon/deequ/constraints/AnomalyExtendedResultsConstraint.scala @@ -23,7 +23,7 @@ import com.amazon.deequ.metrics.Metric import org.apache.spark.sql.DataFrame import scala.util.Success -import scala.util.Success +import scala.util.Failure /** * Case class for anomaly with extended results constraints that provides unified way to access From 198a41fb368535deb9b01c74d349f1a010c9c14e Mon Sep 17 00:00:00 2001 From: Hubert Date: Fri, 8 Nov 2024 17:52:27 -0500 Subject: [PATCH 24/24] update readme to be more clear about the anomalyMetricValue --- .../examples/anomaly_detection_with_extended_results_example.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/main/scala/com/amazon/deequ/examples/anomaly_detection_with_extended_results_example.md b/src/main/scala/com/amazon/deequ/examples/anomaly_detection_with_extended_results_example.md index a061f0e75..6b89b5d06 100644 --- a/src/main/scala/com/amazon/deequ/examples/anomaly_detection_with_extended_results_example.md +++ b/src/main/scala/com/amazon/deequ/examples/anomaly_detection_with_extended_results_example.md @@ -4,7 +4,7 @@ Using the `addAnomalyCheckWithExtendedResults` method instead of the original `a detailed results about the anomaly detection result from the newly created metric. You can get details such as: - dataMetricValue: The metric value that is the data point. -- anomalyMetricValue: The value of the metric that is being checked, which isn't always equal to the dataMetricValue. +- anomalyMetricValue: The metric value that is being checked for the anomaly detection strategy, which isn't always equal to the dataMetricValue. - anomalyCheckRange: The range of bounds used in the anomaly check, the anomalyMetricValue is compared to this range. - isAnomaly: If the anomalyMetricValue is outside the anomalyCheckRange, this is true. - confidence: The confidence of the anomaly detection.
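
As a quick illustration of the extended results described in the README hunk above, the following sketch (not part of the patch itself) runs one of the strategies directly against a hypothetical metric series and prints the fields of each `AnomalyDetectionDataPoint`. The series, the bound values, and the object name are invented for illustration; only `SimpleThresholdStrategy`, `detectWithExtendedResults`, and the accessed fields (`anomalyMetricValue`, `anomalyCheckRange`, `isAnomaly`, `confidence`) come from the changes in this patch series.

```scala
import com.amazon.deequ.anomalydetection.SimpleThresholdStrategy

// Minimal sketch, assuming the API introduced by this patch series.
// The metric history and bounds below are made up for illustration.
object ExtendedResultsSketch {

  def main(args: Array[String]): Unit = {
    // Hypothetical metric history; the last value exceeds the upper bound.
    val series: Vector[Double] = Vector(0.2, 0.4, 0.3, 0.5, 2.5)

    // Same constructor shape as used in SimpleThresholdStrategyTest.
    val strategy = SimpleThresholdStrategy(lowerBound = -0.5, upperBound = 1.0)

    // detectWithExtendedResults returns every point, not only the anomalies,
    // so the caller can inspect the check range and confidence for each one.
    strategy.detectWithExtendedResults(series).foreach { case (index, point) =>
      val lower = point.anomalyCheckRange.lowerBound.value
      val upper = point.anomalyCheckRange.upperBound.value
      println(s"index=$index value=${point.anomalyMetricValue} " +
        s"range=[$lower, $upper] isAnomaly=${point.isAnomaly} confidence=${point.confidence}")
    }
  }
}
```

The same fields back the `addAnomalyCheckWithExtendedResults` flow described in the README above, where the newest data point of the checked metric is inspected instead of a hand-built series.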