From e033c2be2a7ec50b36218c9cbb3d5d96b153470d Mon Sep 17 00:00:00 2001 From: Edward Cho <114528615+eycho-am@users.noreply.github.com> Date: Wed, 9 Oct 2024 15:40:50 -0400 Subject: [PATCH] Add commits from master branch to release/2.0.8-spark-3.3 (#589) * Configurable RetainCompletenessRule (#564) * Configurable RetainCompletenessRule * Add doc string * Add default completeness const * Optional specification of instance name in CustomSQL analyzer metric. (#569) Co-authored-by: Tyler Mcdaniel * Adding Wilson Score Confidence Interval Strategy (#567) * Configurable RetainCompletenessRule * Add doc string * Add default completeness const * Add ConfidenceIntervalStrategy * Add Separate Wilson and Wald Interval Test * Add License information, Fix formatting * Add License information * formatting fix * Update documentation * Make WaldInterval the default strategy for now * Formatting import to per line * Separate group import to per line import * CustomAggregator (#572) * Add support for EntityTypes dqdl rule * Add support for Conditional Aggregation Analyzer --------- Co-authored-by: Joshua Zexter * fix typo (#574) * Fix performance of building row-level results (#577) * Generate row-level results with withColumns Iteratively using withColumn (singular) causes performance issues when iterating over a large sequence of columns. * Add back UNIQUENESS_ID * Replace 'withColumns' with 'select' (#582) 'withColumns' was introduced in Spark 3.3, so it won't work for Deequ's <3.3 builds. * Replace rdd with dataframe functions in Histogram analyzer (#586) Co-authored-by: Shriya Vanvari * Match Breeze version with spark 3.3 (#562) * Updated version in pom.xml to 2.0.8-spark-3.3 --------- Co-authored-by: zeotuan <48720253+zeotuan@users.noreply.github.com> Co-authored-by: tylermcdaniel0 <144386264+tylermcdaniel0@users.noreply.github.com> Co-authored-by: Tyler Mcdaniel Co-authored-by: Joshua Zexter <67130377+joshuazexter@users.noreply.github.com> Co-authored-by: Joshua Zexter Co-authored-by: bojackli <478378663@qq.com> Co-authored-by: Josh <5685731+marcantony@users.noreply.github.com> Co-authored-by: Shriya Vanvari Co-authored-by: Shriya Vanvari --- pom.xml | 4 +- .../com/amazon/deequ/VerificationResult.scala | 7 +- .../deequ/analyzers/CustomAggregator.scala | 69 +++++ .../amazon/deequ/analyzers/CustomSql.scala | 14 +- .../amazon/deequ/analyzers/Histogram.scala | 21 +- .../ConstraintSuggestionExample.scala | 6 + .../examples/constraint_suggestion_example.md | 13 + .../com/amazon/deequ/metrics/Metric.scala | 17 ++ .../suggestions/ConstraintSuggestion.scala | 6 +- .../FractionalCategoricalRangeRule.scala | 12 +- .../rules/RetainCompletenessRule.scala | 34 ++- .../interval/ConfidenceIntervalStrategy.scala | 55 ++++ .../rules/interval/WaldIntervalStrategy.scala | 47 ++++ .../WilsonScoreIntervalStrategy.scala | 47 ++++ .../analyzers/CustomAggregatorTest.scala | 244 ++++++++++++++++++ .../deequ/analyzers/CustomSqlTest.scala | 19 +- .../rules/ConstraintRulesTest.scala | 98 +++++-- .../rules/interval/IntervalStrategyTest.scala | 59 +++++ 18 files changed, 709 insertions(+), 63 deletions(-) create mode 100644 src/main/scala/com/amazon/deequ/analyzers/CustomAggregator.scala create mode 100644 src/main/scala/com/amazon/deequ/suggestions/rules/interval/ConfidenceIntervalStrategy.scala create mode 100644 src/main/scala/com/amazon/deequ/suggestions/rules/interval/WaldIntervalStrategy.scala create mode 100644 src/main/scala/com/amazon/deequ/suggestions/rules/interval/WilsonScoreIntervalStrategy.scala create mode 100644 src/test/scala/com/amazon/deequ/analyzers/CustomAggregatorTest.scala create mode 100644 src/test/scala/com/amazon/deequ/suggestions/rules/interval/IntervalStrategyTest.scala diff --git a/pom.xml b/pom.xml index 0a0b41109..817099bb5 100644 --- a/pom.xml +++ b/pom.xml @@ -6,7 +6,7 @@ com.amazon.deequ deequ - 2.0.7-spark-3.3 + 2.0.8-spark-3.3 1.8 @@ -103,7 +103,7 @@ org.scalanlp breeze_${scala.major.version} - 0.13.2 + 1.2 diff --git a/src/main/scala/com/amazon/deequ/VerificationResult.scala b/src/main/scala/com/amazon/deequ/VerificationResult.scala index 6390db821..b9b450f2d 100644 --- a/src/main/scala/com/amazon/deequ/VerificationResult.scala +++ b/src/main/scala/com/amazon/deequ/VerificationResult.scala @@ -31,7 +31,7 @@ import com.amazon.deequ.repository.SimpleResultSerde import org.apache.spark.sql.Column import org.apache.spark.sql.DataFrame import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.functions.monotonically_increasing_id +import org.apache.spark.sql.functions.{col, monotonically_increasing_id} import java.util.UUID @@ -96,11 +96,10 @@ object VerificationResult { data: DataFrame): DataFrame = { val columnNamesToMetrics: Map[String, Column] = verificationResultToColumn(verificationResult) + val columnsAliased = columnNamesToMetrics.toSeq.map { case (name, col) => col.as(name) } val dataWithID = data.withColumn(UNIQUENESS_ID, monotonically_increasing_id()) - columnNamesToMetrics.foldLeft(dataWithID)( - (dataWithID, newColumn: (String, Column)) => - dataWithID.withColumn(newColumn._1, newColumn._2)).drop(UNIQUENESS_ID) + dataWithID.select(col("*") +: columnsAliased: _*).drop(UNIQUENESS_ID) } def checkResultsAsJson(verificationResult: VerificationResult, diff --git a/src/main/scala/com/amazon/deequ/analyzers/CustomAggregator.scala b/src/main/scala/com/amazon/deequ/analyzers/CustomAggregator.scala new file mode 100644 index 000000000..d82c09312 --- /dev/null +++ b/src/main/scala/com/amazon/deequ/analyzers/CustomAggregator.scala @@ -0,0 +1,69 @@ +/** + * Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"). You may not + * use this file except in compliance with the License. A copy of the License + * is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + * + */ +package com.amazon.deequ.analyzers + +import com.amazon.deequ.metrics.AttributeDoubleMetric +import com.amazon.deequ.metrics.Entity +import org.apache.spark.sql.DataFrame + +import scala.util.Failure +import scala.util.Success +import scala.util.Try + +// Define a custom state to hold aggregation results +case class AggregatedMetricState(counts: Map[String, Int], total: Int) + extends DoubleValuedState[AggregatedMetricState] { + + override def sum(other: AggregatedMetricState): AggregatedMetricState = { + val combinedCounts = counts ++ other + .counts + .map { case (k, v) => k -> (v + counts.getOrElse(k, 0)) } + AggregatedMetricState(combinedCounts, total + other.total) + } + + override def metricValue(): Double = counts.values.sum.toDouble / total +} + +// Define the analyzer +case class CustomAggregator(aggregatorFunc: DataFrame => AggregatedMetricState, + metricName: String, + instance: String = "Dataset") + extends Analyzer[AggregatedMetricState, AttributeDoubleMetric] { + + override def computeStateFrom(data: DataFrame, filterCondition: Option[String] = None) + : Option[AggregatedMetricState] = { + Try(aggregatorFunc(data)) match { + case Success(state) => Some(state) + case Failure(_) => None + } + } + + override def computeMetricFrom(state: Option[AggregatedMetricState]): AttributeDoubleMetric = { + state match { + case Some(detState) => + val metrics = detState.counts.map { case (key, count) => + key -> (count.toDouble / detState.total) + } + AttributeDoubleMetric(Entity.Column, metricName, instance, Success(metrics)) + case None => + AttributeDoubleMetric(Entity.Column, metricName, instance, + Failure(new RuntimeException("Metric computation failed"))) + } + } + + override private[deequ] def toFailureMetric(failure: Exception): AttributeDoubleMetric = { + AttributeDoubleMetric(Entity.Column, metricName, instance, Failure(failure)) + } +} diff --git a/src/main/scala/com/amazon/deequ/analyzers/CustomSql.scala b/src/main/scala/com/amazon/deequ/analyzers/CustomSql.scala index edd4f8e97..8e2e351b9 100644 --- a/src/main/scala/com/amazon/deequ/analyzers/CustomSql.scala +++ b/src/main/scala/com/amazon/deequ/analyzers/CustomSql.scala @@ -37,7 +37,7 @@ case class CustomSqlState(stateOrError: Either[Double, String]) extends DoubleVa override def metricValue(): Double = state } -case class CustomSql(expression: String) extends Analyzer[CustomSqlState, DoubleMetric] { +case class CustomSql(expression: String, disambiguator: String = "*") extends Analyzer[CustomSqlState, DoubleMetric] { /** * Compute the state (sufficient statistics) from the data * @@ -76,15 +76,19 @@ case class CustomSql(expression: String) extends Analyzer[CustomSqlState, Double state match { // The returned state may case Some(theState) => theState.stateOrError match { - case Left(value) => DoubleMetric(Entity.Dataset, "CustomSQL", "*", Success(value)) - case Right(error) => DoubleMetric(Entity.Dataset, "CustomSQL", "*", Failure(new RuntimeException(error))) + case Left(value) => DoubleMetric(Entity.Dataset, "CustomSQL", disambiguator, + Success(value)) + case Right(error) => DoubleMetric(Entity.Dataset, "CustomSQL", disambiguator, + Failure(new RuntimeException(error))) } case None => - DoubleMetric(Entity.Dataset, "CustomSQL", "*", Failure(new RuntimeException("CustomSql Failed To Run"))) + DoubleMetric(Entity.Dataset, "CustomSQL", disambiguator, + Failure(new RuntimeException("CustomSql Failed To Run"))) } } override private[deequ] def toFailureMetric(failure: Exception) = { - DoubleMetric(Entity.Dataset, "CustomSQL", "*", Failure(new RuntimeException("CustomSql Failed To Run"))) + DoubleMetric(Entity.Dataset, "CustomSQL", disambiguator, + Failure(new RuntimeException("CustomSql Failed To Run"))) } } diff --git a/src/main/scala/com/amazon/deequ/analyzers/Histogram.scala b/src/main/scala/com/amazon/deequ/analyzers/Histogram.scala index 742b2ba68..fbdb5b2b0 100644 --- a/src/main/scala/com/amazon/deequ/analyzers/Histogram.scala +++ b/src/main/scala/com/amazon/deequ/analyzers/Histogram.scala @@ -84,11 +84,26 @@ case class Histogram( case Some(theState) => val value: Try[Distribution] = Try { - val topNRows = theState.frequencies.rdd.top(maxDetailBins)(OrderByAbsoluteCount) + val countColumnName = theState.frequencies.schema.fields + .find(field => field.dataType == LongType && field.name != column) + .map(_.name) + .getOrElse(throw new IllegalStateException(s"Count column not found in the frequencies DataFrame")) + + val topNRowsDF = theState.frequencies + .orderBy(col(countColumnName).desc) + .limit(maxDetailBins) + .collect() + val binCount = theState.frequencies.count() - val histogramDetails = topNRows - .map { case Row(discreteValue: String, absolute: Long) => + val columnName = theState.frequencies.columns + .find(_ == column) + .getOrElse(throw new IllegalStateException(s"Column $column not found")) + + val histogramDetails = topNRowsDF + .map { row => + val discreteValue = row.getAs[String](columnName) + val absolute = row.getAs[Long](countColumnName) val ratio = absolute.toDouble / theState.numRows discreteValue -> DistributionValue(absolute, ratio) } diff --git a/src/main/scala/com/amazon/deequ/examples/ConstraintSuggestionExample.scala b/src/main/scala/com/amazon/deequ/examples/ConstraintSuggestionExample.scala index 8aa0fb6c5..fc8f458bf 100644 --- a/src/main/scala/com/amazon/deequ/examples/ConstraintSuggestionExample.scala +++ b/src/main/scala/com/amazon/deequ/examples/ConstraintSuggestionExample.scala @@ -17,6 +17,8 @@ package com.amazon.deequ.examples import com.amazon.deequ.examples.ExampleUtils.withSpark +import com.amazon.deequ.suggestions.rules.RetainCompletenessRule +import com.amazon.deequ.suggestions.rules.interval.WilsonScoreIntervalStrategy import com.amazon.deequ.suggestions.{ConstraintSuggestionRunner, Rules} private[examples] object ConstraintSuggestionExample extends App { @@ -51,6 +53,10 @@ private[examples] object ConstraintSuggestionExample extends App { val suggestionResult = ConstraintSuggestionRunner() .onData(data) .addConstraintRules(Rules.EXTENDED) + // We can also add our own constraint and customize constraint parameters + .addConstraintRule( + RetainCompletenessRule(intervalStrategy = WilsonScoreIntervalStrategy()) + ) .run() // We can now investigate the constraints that deequ suggested. We get a textual description diff --git a/src/main/scala/com/amazon/deequ/examples/constraint_suggestion_example.md b/src/main/scala/com/amazon/deequ/examples/constraint_suggestion_example.md index df159a9c9..472f63c7d 100644 --- a/src/main/scala/com/amazon/deequ/examples/constraint_suggestion_example.md +++ b/src/main/scala/com/amazon/deequ/examples/constraint_suggestion_example.md @@ -43,6 +43,17 @@ val suggestionResult = ConstraintSuggestionRunner() .run() ``` +Alternatively, we also support customizing and adding individual constraint rule using `addConstraintRule()` +```scala +val suggestionResult = ConstraintSuggestionRunner() + .onData(data) + + .addConstraintRule( + RetainCompletenessRule(intervalStrategy = WilsonScoreIntervalStrategy()) + ) + .run() +``` + We can now investigate the constraints that deequ suggested. We get a textual description and the corresponding scala code for each suggested constraint. Note that the constraint suggestion is based on heuristic rules and assumes that the data it is shown is 'static' and correct, which might often not be the case in the real world. Therefore the suggestions should always be manually reviewed before being applied in real deployments. ```scala suggestionResult.constraintSuggestions.foreach { case (column, suggestions) => @@ -92,3 +103,5 @@ The corresponding scala code is .isContainedIn("status", Array("DELAYED", "UNKNO Currently, we leave it up to the user to decide whether they want to apply the suggested constraints or not, and provide the corresponding Scala code for convenience. For larger datasets, it makes sense to evaluate the suggested constraints on some held-out portion of the data to see whether they hold or not. You can test this by adding an invocation of `.useTrainTestSplitWithTestsetRatio(0.1)` to the `ConstraintSuggestionRunner`. With this configuration, it would compute constraint suggestions on 90% of the data and evaluate the suggested constraints on the remaining 10%. Finally, we would also like to note that the constraint suggestion code provides access to the underlying [column profiles](https://github.com/awslabs/deequ/blob/master/src/main/scala/com/amazon/deequ/examples/data_profiling_example.md) that it computed via `suggestionResult.columnProfiles`. + +An [executable and extended version of this example](https://github.com/awslabs/deequ/blob/master/src/main/scala/com/amazon/deequ/examples/.scala) is part of our code base. diff --git a/src/main/scala/com/amazon/deequ/metrics/Metric.scala b/src/main/scala/com/amazon/deequ/metrics/Metric.scala index 30225e246..307b278d1 100644 --- a/src/main/scala/com/amazon/deequ/metrics/Metric.scala +++ b/src/main/scala/com/amazon/deequ/metrics/Metric.scala @@ -89,3 +89,20 @@ case class KeyedDoubleMetric( } } } + +case class AttributeDoubleMetric( + entity: Entity.Value, + name: String, + instance: String, + value: Try[Map[String, Double]]) + extends Metric[Map[String, Double]] { + + override def flatten(): Seq[DoubleMetric] = { + value match { + case Success(valuesMap) => valuesMap.map { case (key, metricValue) => + DoubleMetric(entity, s"$name.$key", instance, Success(metricValue)) + }.toSeq + case Failure(ex) => Seq(DoubleMetric(entity, name, instance, Failure(ex))) + } + } +} diff --git a/src/main/scala/com/amazon/deequ/suggestions/ConstraintSuggestion.scala b/src/main/scala/com/amazon/deequ/suggestions/ConstraintSuggestion.scala index 57ac9aeaa..96b7ab9b2 100644 --- a/src/main/scala/com/amazon/deequ/suggestions/ConstraintSuggestion.scala +++ b/src/main/scala/com/amazon/deequ/suggestions/ConstraintSuggestion.scala @@ -52,7 +52,7 @@ case class ConstraintSuggestionWithValue[T]( object ConstraintSuggestions { - private[this] val CONSTRANT_SUGGESTIONS_FIELD = "constraint_suggestions" + private[this] val CONSTRAINT_SUGGESTIONS_FIELD = "constraint_suggestions" private[suggestions] def toJson(constraintSuggestions: Seq[ConstraintSuggestion]): String = { @@ -68,7 +68,7 @@ object ConstraintSuggestions { constraintsJson.add(constraintJson) } - json.add(CONSTRANT_SUGGESTIONS_FIELD, constraintsJson) + json.add(CONSTRAINT_SUGGESTIONS_FIELD, constraintsJson) val gson = new GsonBuilder() .setPrettyPrinting() @@ -109,7 +109,7 @@ object ConstraintSuggestions { constraintEvaluations.add(constraintEvaluation) } - json.add(CONSTRANT_SUGGESTIONS_FIELD, constraintEvaluations) + json.add(CONSTRAINT_SUGGESTIONS_FIELD, constraintEvaluations) val gson = new GsonBuilder() .setPrettyPrinting() diff --git a/src/main/scala/com/amazon/deequ/suggestions/rules/FractionalCategoricalRangeRule.scala b/src/main/scala/com/amazon/deequ/suggestions/rules/FractionalCategoricalRangeRule.scala index 55e410f33..f9dd192e8 100644 --- a/src/main/scala/com/amazon/deequ/suggestions/rules/FractionalCategoricalRangeRule.scala +++ b/src/main/scala/com/amazon/deequ/suggestions/rules/FractionalCategoricalRangeRule.scala @@ -23,16 +23,17 @@ import com.amazon.deequ.metrics.DistributionValue import com.amazon.deequ.profiles.ColumnProfile import com.amazon.deequ.suggestions.ConstraintSuggestion import com.amazon.deequ.suggestions.ConstraintSuggestionWithValue +import com.amazon.deequ.suggestions.rules.interval.ConfidenceIntervalStrategy.defaultIntervalStrategy +import com.amazon.deequ.suggestions.rules.interval.ConfidenceIntervalStrategy import org.apache.commons.lang3.StringEscapeUtils -import scala.math.BigDecimal.RoundingMode - /** If we see a categorical range for most values in a column, we suggest an IS IN (...) * constraint that should hold for most values */ case class FractionalCategoricalRangeRule( targetDataCoverageFraction: Double = 0.9, categorySorter: Array[(String, DistributionValue)] => Array[(String, DistributionValue)] = - categories => categories.sortBy({ case (_, value) => value.absolute }).reverse + categories => categories.sortBy({ case (_, value) => value.absolute }).reverse, + intervalStrategy: ConfidenceIntervalStrategy = defaultIntervalStrategy ) extends ConstraintRule[ColumnProfile] { override def shouldBeApplied(profile: ColumnProfile, numRecords: Long): Boolean = { @@ -79,11 +80,8 @@ case class FractionalCategoricalRangeRule( val p = ratioSums val n = numRecords - val z = 1.96 - // TODO this needs to be more robust for p's close to 0 or 1 - val targetCompliance = BigDecimal(p - z * math.sqrt(p * (1 - p) / n)) - .setScale(2, RoundingMode.DOWN).toDouble + val targetCompliance = intervalStrategy.calculateTargetConfidenceInterval(p, n).lowerBound val description = s"'${profile.column}' has value range $categoriesSql for at least " + s"${targetCompliance * 100}% of values" diff --git a/src/main/scala/com/amazon/deequ/suggestions/rules/RetainCompletenessRule.scala b/src/main/scala/com/amazon/deequ/suggestions/rules/RetainCompletenessRule.scala index 67ae61f92..be5bd101f 100644 --- a/src/main/scala/com/amazon/deequ/suggestions/rules/RetainCompletenessRule.scala +++ b/src/main/scala/com/amazon/deequ/suggestions/rules/RetainCompletenessRule.scala @@ -20,28 +20,31 @@ import com.amazon.deequ.constraints.Constraint.completenessConstraint import com.amazon.deequ.profiles.ColumnProfile import com.amazon.deequ.suggestions.CommonConstraintSuggestion import com.amazon.deequ.suggestions.ConstraintSuggestion - -import scala.math.BigDecimal.RoundingMode +import com.amazon.deequ.suggestions.rules.RetainCompletenessRule._ +import com.amazon.deequ.suggestions.rules.interval.ConfidenceIntervalStrategy.defaultIntervalStrategy +import com.amazon.deequ.suggestions.rules.interval.ConfidenceIntervalStrategy /** * If a column is incomplete in the sample, we model its completeness as a binomial variable, * estimate a confidence interval and use this to define a lower bound for the completeness + * + * @param minCompleteness : minimum completeness threshold to determine if rule should be applied + * @param maxCompleteness : maximum completeness threshold to determine if rule should be applied */ -case class RetainCompletenessRule() extends ConstraintRule[ColumnProfile] { - +case class RetainCompletenessRule( + minCompleteness: Double = defaultMinCompleteness, + maxCompleteness: Double = defaultMaxCompleteness, + intervalStrategy: ConfidenceIntervalStrategy = defaultIntervalStrategy +) extends ConstraintRule[ColumnProfile] { override def shouldBeApplied(profile: ColumnProfile, numRecords: Long): Boolean = { - profile.completeness > 0.2 && profile.completeness < 1.0 + profile.completeness > minCompleteness && profile.completeness < maxCompleteness } override def candidate(profile: ColumnProfile, numRecords: Long): ConstraintSuggestion = { - - val p = profile.completeness - val n = numRecords - val z = 1.96 - - // TODO this needs to be more robust for p's close to 0 or 1 - val targetCompleteness = BigDecimal(p - z * math.sqrt(p * (1 - p) / n)) - .setScale(2, RoundingMode.DOWN).toDouble + val targetCompleteness = intervalStrategy.calculateTargetConfidenceInterval( + profile.completeness, + numRecords + ).lowerBound val constraint = completenessConstraint(profile.column, _ >= targetCompleteness) @@ -65,3 +68,8 @@ case class RetainCompletenessRule() extends ConstraintRule[ColumnProfile] { "we model its completeness as a binomial variable, estimate a confidence interval " + "and use this to define a lower bound for the completeness" } + +object RetainCompletenessRule { + private val defaultMinCompleteness: Double = 0.2 + private val defaultMaxCompleteness: Double = 1.0 +} diff --git a/src/main/scala/com/amazon/deequ/suggestions/rules/interval/ConfidenceIntervalStrategy.scala b/src/main/scala/com/amazon/deequ/suggestions/rules/interval/ConfidenceIntervalStrategy.scala new file mode 100644 index 000000000..0c12e03a5 --- /dev/null +++ b/src/main/scala/com/amazon/deequ/suggestions/rules/interval/ConfidenceIntervalStrategy.scala @@ -0,0 +1,55 @@ +/** + * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not + * use this file except in compliance with the License. A copy of the License + * is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + * + */ + +package com.amazon.deequ.suggestions.rules.interval + +import breeze.stats.distributions.{Gaussian, Rand} +import com.amazon.deequ.suggestions.rules.interval.ConfidenceIntervalStrategy._ + +/** + * Strategy for calculate confidence interval + * */ +trait ConfidenceIntervalStrategy { + + /** + * Generated confidence interval interval + * @param pHat sample of the population that share a trait + * @param numRecords overall number of records + * @param confidence confidence level of method used to estimate the interval. + * @return + */ + def calculateTargetConfidenceInterval( + pHat: Double, + numRecords: Long, + confidence: Double = defaultConfidence + ): ConfidenceInterval + + def validateInput(pHat: Double, confidence: Double): Unit = { + require(0.0 <= pHat && pHat <= 1.0, "pHat must be between 0.0 and 1.0") + require(0.0 <= confidence && confidence <= 1.0, "confidence must be between 0.0 and 1.0") + } + + def calculateZScore(confidence: Double): Double = Gaussian(0, 1)(Rand).inverseCdf(1 - ((1.0 - confidence)/ 2.0)) +} + +object ConfidenceIntervalStrategy { + val defaultConfidence = 0.95 + val defaultIntervalStrategy: ConfidenceIntervalStrategy = WaldIntervalStrategy() + + case class ConfidenceInterval(lowerBound: Double, upperBound: Double) +} + + diff --git a/src/main/scala/com/amazon/deequ/suggestions/rules/interval/WaldIntervalStrategy.scala b/src/main/scala/com/amazon/deequ/suggestions/rules/interval/WaldIntervalStrategy.scala new file mode 100644 index 000000000..154d8ebfe --- /dev/null +++ b/src/main/scala/com/amazon/deequ/suggestions/rules/interval/WaldIntervalStrategy.scala @@ -0,0 +1,47 @@ +/** + * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not + * use this file except in compliance with the License. A copy of the License + * is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + * + */ + +package com.amazon.deequ.suggestions.rules.interval + +import com.amazon.deequ.suggestions.rules.interval.ConfidenceIntervalStrategy.ConfidenceInterval +import com.amazon.deequ.suggestions.rules.interval.ConfidenceIntervalStrategy.defaultConfidence + +import scala.math.BigDecimal.RoundingMode + +/** + * Implements the Wald Interval method for creating a binomial proportion confidence interval. Provided for backwards + * compatibility. using [[WaldIntervalStrategy]] for calculating confidence interval can be problematic when dealing + * with small sample sizes or proportions close to 0 or 1. It also have poorer coverage and might produce confidence + * limit outside the range of [0,1] + * @see + * Normal approximation interval (Wikipedia) + */ +@deprecated("WilsonScoreIntervalStrategy is recommended for calculating confidence interval") +case class WaldIntervalStrategy() extends ConfidenceIntervalStrategy { + def calculateTargetConfidenceInterval( + pHat: Double, + numRecords: Long, + confidence: Double = defaultConfidence + ): ConfidenceInterval = { + validateInput(pHat, confidence) + val successRatio = BigDecimal(pHat) + val marginOfError = BigDecimal(calculateZScore(confidence) * math.sqrt(pHat * (1 - pHat) / numRecords)) + val lowerBound = (successRatio - marginOfError).setScale(2, RoundingMode.DOWN).toDouble + val upperBound = (successRatio + marginOfError).setScale(2, RoundingMode.UP).toDouble + ConfidenceInterval(lowerBound, upperBound) + } +} diff --git a/src/main/scala/com/amazon/deequ/suggestions/rules/interval/WilsonScoreIntervalStrategy.scala b/src/main/scala/com/amazon/deequ/suggestions/rules/interval/WilsonScoreIntervalStrategy.scala new file mode 100644 index 000000000..6e8371ea5 --- /dev/null +++ b/src/main/scala/com/amazon/deequ/suggestions/rules/interval/WilsonScoreIntervalStrategy.scala @@ -0,0 +1,47 @@ +/** + * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not + * use this file except in compliance with the License. A copy of the License + * is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + * + */ + +package com.amazon.deequ.suggestions.rules.interval + +import com.amazon.deequ.suggestions.rules.interval.ConfidenceIntervalStrategy.ConfidenceInterval +import com.amazon.deequ.suggestions.rules.interval.ConfidenceIntervalStrategy.defaultConfidence + +import scala.math.BigDecimal.RoundingMode + +/** + * Using Wilson score method for creating a binomial proportion confidence interval. + * + * @see + * Wilson score interval (Wikipedia) + */ +case class WilsonScoreIntervalStrategy() extends ConfidenceIntervalStrategy { + + def calculateTargetConfidenceInterval( + pHat: Double, numRecords: Long, + confidence: Double = defaultConfidence + ): ConfidenceInterval = { + validateInput(pHat, confidence) + val zScore = calculateZScore(confidence) + val zSquareOverN = math.pow(zScore, 2) / numRecords + val factor = 1.0 / (1 + zSquareOverN) + val adjustedSuccessRatio = pHat + zSquareOverN/2 + val marginOfError = zScore * math.sqrt(pHat * (1 - pHat)/numRecords + zSquareOverN/(4 * numRecords)) + val lowerBound = BigDecimal(factor * (adjustedSuccessRatio - marginOfError)).setScale(2, RoundingMode.DOWN).toDouble + val upperBound = BigDecimal(factor * (adjustedSuccessRatio + marginOfError)).setScale(2, RoundingMode.UP).toDouble + ConfidenceInterval(lowerBound, upperBound) + } +} diff --git a/src/test/scala/com/amazon/deequ/analyzers/CustomAggregatorTest.scala b/src/test/scala/com/amazon/deequ/analyzers/CustomAggregatorTest.scala new file mode 100644 index 000000000..4f21cc64e --- /dev/null +++ b/src/test/scala/com/amazon/deequ/analyzers/CustomAggregatorTest.scala @@ -0,0 +1,244 @@ +/** + * Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not + * use this file except in compliance with the License. A copy of the License + * is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + * + */ +package com.amazon.deequ.analyzers + +import com.amazon.deequ.SparkContextSpec +import com.amazon.deequ.utils.FixtureSupport +import org.scalatest.matchers.should.Matchers +import org.scalatest.wordspec.AnyWordSpec +import com.amazon.deequ.analyzers._ +import com.amazon.deequ.metrics.AttributeDoubleMetric +import com.amazon.deequ.profiles.ColumnProfilerRunner +import com.amazon.deequ.utils.FixtureSupport +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.functions.{sum, count} +import scala.util.Failure +import scala.util.Success +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.DataFrame +import com.amazon.deequ.metrics.AttributeDoubleMetric +import com.amazon.deequ.profiles.NumericColumnProfile + +class CustomAggregatorTest + extends AnyWordSpec with Matchers with SparkContextSpec with FixtureSupport { + + "CustomAggregatorTest" should { + + """Example use: return correct counts + |for product sales in different categories""".stripMargin in withSparkSession + { session => + val data = getDfWithIdColumn(session) + val mockLambda: DataFrame => AggregatedMetricState = _ => + AggregatedMetricState(Map("ProductA" -> 50, "ProductB" -> 45), 100) + + val analyzer = CustomAggregator(mockLambda, "ProductSales", "category") + + val state = analyzer.computeStateFrom(data) + val metric: AttributeDoubleMetric = analyzer.computeMetricFrom(state) + + metric.value.isSuccess shouldBe true + metric.value.get should contain ("ProductA" -> 0.5) + metric.value.get should contain ("ProductB" -> 0.45) + } + + "handle scenarios with no data points effectively" in withSparkSession { session => + val data = getDfWithIdColumn(session) + val mockLambda: DataFrame => AggregatedMetricState = _ => + AggregatedMetricState(Map.empty[String, Int], 100) + + val analyzer = CustomAggregator(mockLambda, "WebsiteTraffic", "page") + + val state = analyzer.computeStateFrom(data) + val metric: AttributeDoubleMetric = analyzer.computeMetricFrom(state) + + metric.value.isSuccess shouldBe true + metric.value.get shouldBe empty + } + + "return a failure metric when the lambda function fails" in withSparkSession { session => + val data = getDfWithIdColumn(session) + val failingLambda: DataFrame => AggregatedMetricState = + _ => throw new RuntimeException("Test failure") + + val analyzer = CustomAggregator(failingLambda, "ProductSales", "category") + + val state = analyzer.computeStateFrom(data) + val metric = analyzer.computeMetricFrom(state) + + metric.value.isFailure shouldBe true + metric.value match { + case Success(_) => fail("Should have failed due to lambda function failure") + case Failure(exception) => exception.getMessage shouldBe "Metric computation failed" + } + } + + "return a failure metric if there are no rows in DataFrame" in withSparkSession { session => + val emptyData = session.createDataFrame( + session.sparkContext.emptyRDD[org.apache.spark.sql.Row], + getDfWithIdColumn(session).schema) + val mockLambda: DataFrame => AggregatedMetricState = df => + if (df.isEmpty) throw new RuntimeException("No data to analyze") + else AggregatedMetricState(Map("ProductA" -> 0, "ProductB" -> 0), 0) + + val analyzer = CustomAggregator(mockLambda, + "ProductSales", + "category") + + val state = analyzer.computeStateFrom(emptyData) + val metric = analyzer.computeMetricFrom(state) + + metric.value.isFailure shouldBe true + metric.value match { + case Success(_) => fail("Should have failed due to no data") + case Failure(exception) => exception.getMessage should include("Metric computation failed") + } + } + } + + "Combined Analysis with CustomAggregator and ColumnProfilerRunner" should { + "provide aggregated data and column profiles" in withSparkSession { session => + import session.implicits._ + + // Define the dataset + val rawData = Seq( + ("thingA", "13.0", "IN_TRANSIT", "true"), + ("thingA", "5", "DELAYED", "false"), + ("thingB", null, "DELAYED", null), + ("thingC", null, "IN_TRANSIT", "false"), + ("thingD", "1.0", "DELAYED", "true"), + ("thingC", "7.0", "UNKNOWN", null), + ("thingC", "20", "UNKNOWN", null), + ("thingE", "20", "DELAYED", "false") + ).toDF("productName", "totalNumber", "status", "valuable") + + val statusCountLambda: DataFrame => AggregatedMetricState = df => + AggregatedMetricState(df.groupBy("status").count().rdd + .map(r => r.getString(0) -> r.getLong(1).toInt).collect().toMap, df.count().toInt) + + val statusAnalyzer = CustomAggregator(statusCountLambda, "ProductStatus") + val statusMetric = statusAnalyzer.computeMetricFrom(statusAnalyzer.computeStateFrom(rawData)) + + val result = ColumnProfilerRunner().onData(rawData).run() + + statusMetric.value.isSuccess shouldBe true + statusMetric.value.get("IN_TRANSIT") shouldBe 0.25 + statusMetric.value.get("DELAYED") shouldBe 0.5 + + val totalNumberProfile = result.profiles("totalNumber").asInstanceOf[NumericColumnProfile] + totalNumberProfile.completeness shouldBe 0.75 + totalNumberProfile.dataType shouldBe DataTypeInstances.Fractional + + result.profiles.foreach { case (colName, profile) => + println(s"Column '$colName': completeness: ${profile.completeness}, " + + s"approximate number of distinct values: ${profile.approximateNumDistinctValues}") + } + } + } + + "accurately compute percentage occurrences and total engagements for content types" in withSparkSession { session => + val data = getContentEngagementDataFrame(session) + val contentEngagementLambda: DataFrame => AggregatedMetricState = df => { + + // Calculate the total engagements for each content type + val counts = df + .groupBy("content_type") + .agg( + (sum("views") + sum("likes") + sum("shares")).cast("int").alias("totalEngagements") + ) + .collect() + .map(row => + row.getString(0) -> row.getInt(1) + ) + .toMap + val totalEngagements = counts.values.sum + AggregatedMetricState(counts, totalEngagements) + } + + val analyzer = CustomAggregator(contentEngagementLambda, "ContentEngagement", "AllTypes") + + val state = analyzer.computeStateFrom(data) + val metric = analyzer.computeMetricFrom(state) + + metric.value.isSuccess shouldBe true + // Counts: Map(Video -> 5300, Article -> 1170) + // total engagement: 6470 + (metric.value.get("Video") * 100).toInt shouldBe 81 + (metric.value.get("Article") * 100).toInt shouldBe 18 + println(metric.value.get) + } + + "accurately compute total aggregated resources for cloud services" in withSparkSession { session => + val data = getResourceUtilizationDataFrame(session) + val resourceUtilizationLambda: DataFrame => AggregatedMetricState = df => { + val counts = df.groupBy("service_type") + .agg( + (sum("cpu_hours") + sum("memory_gbs") + sum("storage_gbs")).cast("int").alias("totalResources") + ) + .collect() + .map(row => + row.getString(0) -> row.getInt(1) + ) + .toMap + val totalResources = counts.values.sum + AggregatedMetricState(counts, totalResources) + } + val analyzer = CustomAggregator(resourceUtilizationLambda, "ResourceUtilization", "CloudServices") + + val state = analyzer.computeStateFrom(data) + val metric = analyzer.computeMetricFrom(state) + + metric.value.isSuccess shouldBe true + println("Resource Utilization Metrics: " + metric.value.get) +// Resource Utilization Metrics: Map(Compute -> 0.5076142131979695, + // Database -> 0.27918781725888325, + // Storage -> 0.2131979695431472) + (metric.value.get("Compute") * 100).toInt shouldBe 50 // Expected percentage for Compute + (metric.value.get("Database") * 100).toInt shouldBe 27 // Expected percentage for Database + (metric.value.get("Storage") * 100).toInt shouldBe 21 // 430 CPU + 175 Memory + 140 Storage from mock data + } + + def getDfWithIdColumn(session: SparkSession): DataFrame = { + import session.implicits._ + Seq( + ("ProductA", "North"), + ("ProductA", "South"), + ("ProductB", "East"), + ("ProductA", "West") + ).toDF("product", "region") + } + + def getContentEngagementDataFrame(session: SparkSession): DataFrame = { + import session.implicits._ + Seq( + ("Video", 1000, 150, 300), + ("Article", 500, 100, 150), + ("Video", 1500, 200, 450), + ("Article", 300, 50, 70), + ("Video", 1200, 180, 320) + ).toDF("content_type", "views", "likes", "shares") + } + + def getResourceUtilizationDataFrame(session: SparkSession): DataFrame = { + import session.implicits._ + Seq( + ("Compute", 400, 120, 150), + ("Storage", 100, 30, 500), + ("Database", 200, 80, 100), + ("Compute", 450, 130, 250), + ("Database", 230, 95, 120) + ).toDF("service_type", "cpu_hours", "memory_gbs", "storage_gbs") + } +} diff --git a/src/test/scala/com/amazon/deequ/analyzers/CustomSqlTest.scala b/src/test/scala/com/amazon/deequ/analyzers/CustomSqlTest.scala index 7e6e96c30..e6e23c40b 100644 --- a/src/test/scala/com/amazon/deequ/analyzers/CustomSqlTest.scala +++ b/src/test/scala/com/amazon/deequ/analyzers/CustomSqlTest.scala @@ -5,7 +5,7 @@ * use this file except in compliance with the License. A copy of the License * is located at * - * http://aws.amazon.com/apache2.0/ + * http://aws.amazon.com/apache2.0/ * * or in the "license" file accompanying this file. This file is distributed on * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either @@ -17,6 +17,7 @@ package com.amazon.deequ.analyzers import com.amazon.deequ.SparkContextSpec import com.amazon.deequ.metrics.DoubleMetric +import com.amazon.deequ.metrics.Entity import com.amazon.deequ.utils.FixtureSupport import org.scalatest.matchers.should.Matchers import org.scalatest.wordspec.AnyWordSpec @@ -84,5 +85,21 @@ class CustomSqlTest extends AnyWordSpec with Matchers with SparkContextSpec with case Failure(exception) => exception.getMessage should include("foo") } } + + "apply metric disambiguation string to returned metric" in withSparkSession { session => + val data = getDfWithStringColumns(session) + data.createOrReplaceTempView("primary") + + val disambiguator = "statement1" + val sql = CustomSql("SELECT COUNT(*) FROM primary WHERE `Address Line 2` IS NOT NULL", disambiguator) + val state = sql.computeStateFrom(data) + val metric: DoubleMetric = sql.computeMetricFrom(state) + + metric.value.isSuccess shouldBe true + metric.value.get shouldBe 6.0 + metric.name shouldBe "CustomSQL" + metric.entity shouldBe Entity.Dataset + metric.instance shouldBe "statement1" + } } } diff --git a/src/test/scala/com/amazon/deequ/suggestions/rules/ConstraintRulesTest.scala b/src/test/scala/com/amazon/deequ/suggestions/rules/ConstraintRulesTest.scala index 075247932..7b56e3938 100644 --- a/src/test/scala/com/amazon/deequ/suggestions/rules/ConstraintRulesTest.scala +++ b/src/test/scala/com/amazon/deequ/suggestions/rules/ConstraintRulesTest.scala @@ -22,10 +22,14 @@ import com.amazon.deequ.checks.{Check, CheckLevel} import com.amazon.deequ.constraints.ConstrainableDataTypes import com.amazon.deequ.metrics.{Distribution, DistributionValue} import com.amazon.deequ.profiles._ +import com.amazon.deequ.suggestions.rules.interval.WaldIntervalStrategy +import com.amazon.deequ.suggestions.rules.interval.WilsonScoreIntervalStrategy import com.amazon.deequ.utils.FixtureSupport import com.amazon.deequ.{SparkContextSpec, VerificationSuite} import org.scalamock.scalatest.MockFactory +import org.scalatest.Inspectors.forAll import org.scalatest.WordSpec +import org.scalatest.prop.Tables.Table class ConstraintRulesTest extends WordSpec with FixtureSupport with SparkContextSpec with MockFactory{ @@ -130,59 +134,103 @@ class ConstraintRulesTest extends WordSpec with FixtureSupport with SparkContext "be applied correctly" in { val complete = StandardColumnProfile("col1", 1.0, 100, String, false, Map.empty, None) + val tenPercent = StandardColumnProfile("col1", 0.1, 100, String, false, Map.empty, None) val incomplete = StandardColumnProfile("col1", .25, 100, String, false, Map.empty, None) + val waldIntervalStrategy = WaldIntervalStrategy() assert(!RetainCompletenessRule().shouldBeApplied(complete, 1000)) + assert(!RetainCompletenessRule(0.05, 0.9).shouldBeApplied(complete, 1000)) + assert(RetainCompletenessRule(0.05, 0.9).shouldBeApplied(tenPercent, 1000)) + assert(RetainCompletenessRule(0.0).shouldBeApplied(tenPercent, 1000)) + assert(RetainCompletenessRule(0.0).shouldBeApplied(incomplete, 1000)) assert(RetainCompletenessRule().shouldBeApplied(incomplete, 1000)) + assert(!RetainCompletenessRule(intervalStrategy = waldIntervalStrategy).shouldBeApplied(complete, 1000)) + assert(!RetainCompletenessRule(0.05, 0.9, waldIntervalStrategy).shouldBeApplied(complete, 1000)) + assert(RetainCompletenessRule(0.05, 0.9, waldIntervalStrategy).shouldBeApplied(tenPercent, 1000)) } "return evaluable constraint candidates" in withSparkSession { session => + val table = Table(("strategy", "result"), (WaldIntervalStrategy(), true), (WilsonScoreIntervalStrategy(), true)) + forAll(table) { case (strategy, result) => + val dfWithColumnCandidate = getDfFull(session) - val dfWithColumnCandidate = getDfFull(session) + val fakeColumnProfile = getFakeColumnProfileWithNameAndCompleteness("att1", 0.5) - val fakeColumnProfile = getFakeColumnProfileWithNameAndCompleteness("att1", 0.5) + val check = Check(CheckLevel.Warning, "some") + .addConstraint( + RetainCompletenessRule(intervalStrategy = strategy).candidate(fakeColumnProfile, 100).constraint + ) - val check = Check(CheckLevel.Warning, "some") - .addConstraint(RetainCompletenessRule().candidate(fakeColumnProfile, 100).constraint) + val verificationResult = VerificationSuite() + .onData(dfWithColumnCandidate) + .addCheck(check) + .run() - val verificationResult = VerificationSuite() - .onData(dfWithColumnCandidate) - .addCheck(check) - .run() + val metricResult = verificationResult.metrics.head._2 - val metricResult = verificationResult.metrics.head._2 + assert(metricResult.value.isSuccess == result) + } - assert(metricResult.value.isSuccess) } "return working code to add constraint to check" in withSparkSession { session => + val table = Table( + ("strategy", "colCompleteness", "targetCompleteness", "result"), + (WaldIntervalStrategy(), 0.5, 0.4, true), + (WilsonScoreIntervalStrategy(), 0.4, 0.3, true) + ) + forAll(table) { case (strategy, colCompleteness, targetCompleteness, result) => - val dfWithColumnCandidate = getDfFull(session) + val dfWithColumnCandidate = getDfFull(session) - val fakeColumnProfile = getFakeColumnProfileWithNameAndCompleteness("att1", 0.5) + val fakeColumnProfile = getFakeColumnProfileWithNameAndCompleteness("att1", colCompleteness) - val codeForConstraint = RetainCompletenessRule().candidate(fakeColumnProfile, 100) - .codeForConstraint + val codeForConstraint = RetainCompletenessRule(intervalStrategy = strategy).candidate(fakeColumnProfile, 100) + .codeForConstraint - val expectedCodeForConstraint = """.hasCompleteness("att1", _ >= 0.4, - | Some("It should be above 0.4!"))""".stripMargin.replaceAll("\n", "") + val expectedCodeForConstraint = s""".hasCompleteness("att1", _ >= $targetCompleteness, + | Some("It should be above $targetCompleteness!"))""".stripMargin.replaceAll("\n", "") - assert(expectedCodeForConstraint == codeForConstraint) + assert(expectedCodeForConstraint == codeForConstraint) - val check = Check(CheckLevel.Warning, "some") - .hasCompleteness("att1", _ >= 0.4, Some("It should be above 0.4!")) + val check = Check(CheckLevel.Warning, "some") + .hasCompleteness("att1", _ >= targetCompleteness, Some(s"It should be above $targetCompleteness")) - val verificationResult = VerificationSuite() - .onData(dfWithColumnCandidate) - .addCheck(check) - .run() + val verificationResult = VerificationSuite() + .onData(dfWithColumnCandidate) + .addCheck(check) + .run() - val metricResult = verificationResult.metrics.head._2 + val metricResult = verificationResult.metrics.head._2 + + assert(metricResult.value.isSuccess == result) + } - assert(metricResult.value.isSuccess) } + + "return evaluable constraint candidates with custom min/max completeness" in + withSparkSession { session => + val table = Table(("strategy", "result"), (WaldIntervalStrategy(), true), (WilsonScoreIntervalStrategy(), true)) + forAll(table) { case (strategy, result) => + val dfWithColumnCandidate = getDfFull(session) + + val fakeColumnProfile = getFakeColumnProfileWithNameAndCompleteness("att1", 0.5) + + val check = Check(CheckLevel.Warning, "some") + .addConstraint(RetainCompletenessRule(0.4, 0.6, strategy).candidate(fakeColumnProfile, 100).constraint) + + val verificationResult = VerificationSuite() + .onData(dfWithColumnCandidate) + .addCheck(check) + .run() + + val metricResult = verificationResult.metrics.head._2 + + assert(metricResult.value.isSuccess == result) + } + } } "UniqueIfApproximatelyUniqueRule" should { diff --git a/src/test/scala/com/amazon/deequ/suggestions/rules/interval/IntervalStrategyTest.scala b/src/test/scala/com/amazon/deequ/suggestions/rules/interval/IntervalStrategyTest.scala new file mode 100644 index 000000000..54e6cd1e1 --- /dev/null +++ b/src/test/scala/com/amazon/deequ/suggestions/rules/interval/IntervalStrategyTest.scala @@ -0,0 +1,59 @@ +/** + * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not + * use this file except in compliance with the License. A copy of the License + * is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + * + */ + +package com.amazon.deequ.suggestions.rules.interval + +import com.amazon.deequ.SparkContextSpec +import com.amazon.deequ.suggestions.rules.interval.ConfidenceIntervalStrategy.ConfidenceInterval +import com.amazon.deequ.utils.FixtureSupport +import org.scalamock.scalatest.MockFactory +import org.scalatest.Inspectors.forAll +import org.scalatest.prop.Tables.Table +import org.scalatest.wordspec.AnyWordSpec + +class IntervalStrategyTest extends AnyWordSpec with FixtureSupport with SparkContextSpec + with MockFactory { + + "ConfidenceIntervalStrategy" should { + "be calculated correctly" in { + val waldStrategy = WaldIntervalStrategy() + val wilsonStrategy = WilsonScoreIntervalStrategy() + + val table = Table( + ("strategy", "pHat", "numRecord", "lowerBound", "upperBound"), + (waldStrategy, 1.0, 20L, 1.0, 1.0), + (waldStrategy, 0.5, 100L, 0.4, 0.6), + (waldStrategy, 0.4, 100L, 0.3, 0.5), + (waldStrategy, 0.6, 100L, 0.5, 0.7), + (waldStrategy, 0.9, 100L, 0.84, 0.96), + (waldStrategy, 1.0, 100L, 1.0, 1.0), + + (wilsonStrategy, 0.01, 20L, 0.00, 0.18), + (wilsonStrategy, 1.0, 20L, 0.83, 1.0), + (wilsonStrategy, 0.5, 100L, 0.4, 0.6), + (wilsonStrategy, 0.4, 100L, 0.3, 0.5), + (wilsonStrategy, 0.6, 100L, 0.5, 0.7), + (wilsonStrategy, 0.9, 100L, 0.82, 0.95), + (wilsonStrategy, 1.0, 100L, 0.96, 1.0) + ) + + forAll(table) { case (strategy, pHat, numRecords, lowerBound, upperBound) => + val actualInterval = strategy.calculateTargetConfidenceInterval(pHat, numRecords) + assert(actualInterval == ConfidenceInterval(lowerBound, upperBound)) + } + } + } +}