review comments

VenkataKarthikP committed Jan 4, 2024
1 parent 6461856 commit 08a71ab
Showing 8 changed files with 165 additions and 99 deletions.
66 changes: 11 additions & 55 deletions src/main/scala/com/amazon/deequ/analyzers/Analyzer.scala
@@ -1,5 +1,5 @@
/**
* Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
* Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not
* use this file except in compliance with the License. A copy of the License
@@ -19,15 +19,21 @@ package com.amazon.deequ.analyzers
import com.amazon.deequ.analyzers.Analyzers._
import com.amazon.deequ.analyzers.NullBehavior.NullBehavior
import com.amazon.deequ.analyzers.runners._
import com.amazon.deequ.comparison.{DataSynchronization, DataSynchronizationFailed, DataSynchronizationSucceeded}
import com.amazon.deequ.metrics.{DoubleMetric, Entity, FullColumn, Metric}
import com.amazon.deequ.metrics.DoubleMetric
import com.amazon.deequ.metrics.Entity
import com.amazon.deequ.metrics.FullColumn
import com.amazon.deequ.metrics.Metric
import com.amazon.deequ.utilities.ColumnUtil.removeEscapeColumn
import org.apache.spark.sql.{Column, DataFrame, Row, SparkSession}
import org.apache.spark.sql.Column
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.Row
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._

import scala.language.existentials
import scala.util.{Failure, Success, Try}
import scala.util.Failure
import scala.util.Success

/**
* A state (sufficient statistic) computed from data, from which we can compute a metric.
@@ -293,56 +299,6 @@ abstract class GroupingAnalyzer[S <: State[_], +M <: Metric[_]] extends Analyzer
}
}

/**
* Data Synchronization Analyzer
*
* @param dfToCompare DataFrame to compare
* @param columnMappings columns mappings
* @param assertion assertion logic
*/
case class DataSynchronizationAnalyzer(dfToCompare: DataFrame,
columnMappings: Map[String, String],
assertion: Double => Boolean)
extends Analyzer[DataSynchronizationState, DoubleMetric] {

override def computeStateFrom(data: DataFrame): Option[DataSynchronizationState] = {

val result = DataSynchronization.columnMatch(data, dfToCompare, columnMappings, assertion)

result match {
case succeeded: DataSynchronizationSucceeded =>
Some(DataSynchronizationState(succeeded.passedCount.getOrElse(0), succeeded.totalCount.getOrElse(0)))
case failed: DataSynchronizationFailed =>
Some(DataSynchronizationState(failed.passedCount.getOrElse(0), failed.totalCount.getOrElse(0)))
case _ => None
}
}

override def computeMetricFrom(state: Option[DataSynchronizationState]): DoubleMetric = {

state match {
case Some(s) => DoubleMetric(
Entity.Dataset,
"DataSynchronization",
"",
Try(s.synchronizedDataCount.toDouble / s.dataCount.toDouble),
None
)
case None => DoubleMetric(
Entity.Dataset,
"DataSynchronization",
"",
Try(0.0),
None
)
}
}

override private[deequ] def toFailureMetric(failure: Exception) =
metricFromFailure(failure, "DataSynchronization", "", Entity.Dataset)
}


/** Helper method to check conditions on the schema of the data */
object Preconditions {

94 changes: 94 additions & 0 deletions src/main/scala/com/amazon/deequ/analyzers/DataSynchronizationAnalyzer.scala
@@ -0,0 +1,94 @@
/**
* Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not
* use this file except in compliance with the License. A copy of the License
* is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on
* an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
* express or implied. See the License for the specific language governing
* permissions and limitations under the License.
*
*/

package com.amazon.deequ.analyzers

import com.amazon.deequ.analyzers.Analyzers.metricFromFailure
import com.amazon.deequ.comparison.DataSynchronization
import com.amazon.deequ.comparison.DataSynchronizationFailed
import com.amazon.deequ.comparison.DataSynchronizationSucceeded
import com.amazon.deequ.metrics.DoubleMetric
import com.amazon.deequ.metrics.Entity
import org.apache.spark.sql.DataFrame

import scala.util.Failure
import scala.util.Try


/**
* An Analyzer for Deequ that performs a data synchronization check between two DataFrames.
* It evaluates the degree of synchronization based on specified column mappings and an assertion function.
*
* The analyzer computes a ratio of synchronized data points to the total data points, represented as a DoubleMetric.
* Refer to [[com.amazon.deequ.comparison.DataSynchronization.columnMatch]] for the underlying implementation.
*
* @param dfToCompare The DataFrame to compare against the primary DataFrame that is set up
* via [[com.amazon.deequ.VerificationSuite.onData]].
* @param columnMappings A map where each key-value pair represents a column in the primary DataFrame
* and its corresponding column in dfToCompare.
* @param assertion A function that takes a Double (the match ratio) and returns a Boolean.
* It defines the condition for successful synchronization.
*
* Usage:
* This analyzer is used by Deequ's VerificationSuite when an `isDataSynchronized` check is defined, and it can
* also be used manually.
*
* Example:
* val analyzer = DataSynchronizationAnalyzer(dfToCompare, Map("col1" -> "col2"), _ > 0.8)
* val verificationResult = VerificationSuite().onData(df).addAnalyzer(analyzer).run()
*
* // or could do something like below
* val verificationResult = VerificationSuite().onData(df).isDataSynchronized(dfToCompare, Map("col1" -> "col2"),
* _ > 0.8).run()
*
*
* The computeStateFrom method calculates the synchronization state by comparing the specified columns of the two
* DataFrames.
* The computeMetricFrom method then converts this state into a DoubleMetric representing the synchronization ratio.
*
*/
case class DataSynchronizationAnalyzer(dfToCompare: DataFrame,
columnMappings: Map[String, String],
assertion: Double => Boolean)
extends Analyzer[DataSynchronizationState, DoubleMetric] {

override def computeStateFrom(data: DataFrame): Option[DataSynchronizationState] = {

val result = DataSynchronization.columnMatch(data, dfToCompare, columnMappings, assertion)

result match {
case succeeded: DataSynchronizationSucceeded =>
Some(DataSynchronizationState(succeeded.passedCount, succeeded.totalCount))
case failed: DataSynchronizationFailed =>
Some(DataSynchronizationState(failed.passedCount.getOrElse(0), failed.totalCount.getOrElse(0)))
case _ => None
}
}

override def computeMetricFrom(state: Option[DataSynchronizationState]): DoubleMetric = {

val metric = state match {
case Some(s) => Try(s.synchronizedDataCount.toDouble / s.dataCount.toDouble)
case _ => Failure(new IllegalStateException("No state available for DataSynchronizationAnalyzer"))
}

DoubleMetric(Entity.Dataset, "DataSynchronization", "", metric, None)
}

override private[deequ] def toFailureMetric(failure: Exception) =
metricFromFailure(failure, "DataSynchronization", "", Entity.Dataset)
}
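
For reviewers who want to drive the relocated analyzer directly, a minimal sketch; primaryDf and dfToCompare are assumed to be DataFrames already in scope, and the "id" mapping and 0.9 threshold are illustrative:

import com.amazon.deequ.analyzers.DataSynchronizationAnalyzer
import com.amazon.deequ.metrics.DoubleMetric

// Map the primary frame's "id" column to dfToCompare's "id" column and
// require at least 90% of joined rows to match.
val analyzer = DataSynchronizationAnalyzer(dfToCompare, Map("id" -> "id"), _ >= 0.9)

// computeStateFrom runs DataSynchronization.columnMatch under the hood;
// computeMetricFrom converts the resulting counts into the ratio passed / total.
val state = analyzer.computeStateFrom(primaryDf)
val metric: DoubleMetric = analyzer.computeMetricFrom(state)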

30 changes: 20 additions & 10 deletions src/main/scala/com/amazon/deequ/checks/Check.scala
@@ -16,17 +16,29 @@

package com.amazon.deequ.checks

import com.amazon.deequ.analyzers.{Analyzer, AnalyzerOptions, DataSynchronizationState, DataSynchronizationAnalyzer, Histogram, KLLParameters, Patterns, State}
import com.amazon.deequ.anomalydetection.{AnomalyDetectionStrategy, AnomalyDetector, DataPoint}
import com.amazon.deequ.analyzers.runners.AnalyzerContext
import com.amazon.deequ.analyzers.Analyzer
import com.amazon.deequ.analyzers.AnalyzerOptions
import com.amazon.deequ.analyzers.DataSynchronizationAnalyzer
import com.amazon.deequ.analyzers.DataSynchronizationState
import com.amazon.deequ.analyzers.Histogram
import com.amazon.deequ.analyzers.KLLParameters
import com.amazon.deequ.analyzers.Patterns
import com.amazon.deequ.analyzers.State
import com.amazon.deequ.anomalydetection.HistoryUtils
import com.amazon.deequ.anomalydetection.AnomalyDetectionStrategy
import com.amazon.deequ.anomalydetection.AnomalyDetector
import com.amazon.deequ.anomalydetection.DataPoint
import com.amazon.deequ.checks.ColumnCondition.isAnyNotNull
import com.amazon.deequ.checks.ColumnCondition.isEachNotNull
import com.amazon.deequ.constraints.Constraint._
import com.amazon.deequ.constraints._
import com.amazon.deequ.metrics.{BucketDistribution, Distribution, Metric}
import com.amazon.deequ.metrics.BucketDistribution
import com.amazon.deequ.metrics.Distribution
import com.amazon.deequ.metrics.Metric
import com.amazon.deequ.repository.MetricsRepository
import org.apache.spark.sql.expressions.UserDefinedFunction
import com.amazon.deequ.anomalydetection.HistoryUtils
import com.amazon.deequ.checks.ColumnCondition.{isAnyNotNull, isEachNotNull}
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.expressions.UserDefinedFunction

import scala.util.matching.Regex

@@ -363,11 +375,10 @@ case class Check(
*/
def isDataSynchronized(otherDf: DataFrame, columnMappings: Map[String, String], assertion: Double => Boolean,
hint: Option[String] = None): Check = {

val dataSyncAnalyzer = DataSynchronizationAnalyzer(otherDf, columnMappings, assertion)
val constraint = DataSynchronizationConstraint(dataSyncAnalyzer, hint)
val constraint = AnalysisBasedConstraint[DataSynchronizationState, Double, Double](dataSyncAnalyzer, assertion,
hint = hint)
addConstraint(constraint)

}
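
Since the constraint is now a plain AnalysisBasedConstraint, the check composes like any other; a sketch assuming df and otherDf are DataFrames in scope (column names, threshold, and hint are illustrative):

import com.amazon.deequ.VerificationSuite
import com.amazon.deequ.checks.{Check, CheckLevel}

val check = Check(CheckLevel.Error, "data synchronization")
  .isDataSynchronized(otherDf, Map("col1" -> "col2"), _ > 0.8, Some("replica drift check"))

val result = VerificationSuite()
  .onData(df)
  .addCheck(check)
  .run()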

/**
@@ -1126,7 +1137,6 @@ case class Check(
}
.collect {
case constraint: AnalysisBasedConstraint[_, _, _] => constraint.analyzer
case constraint: DataSynchronizationConstraint => constraint.analyzer
}
.map { _.asInstanceOf[Analyzer[_, Metric[_]]] }
.toSet
src/main/scala/com/amazon/deequ/comparison/ComparisonResult.scala
@@ -23,5 +23,4 @@ case class ComparisonSucceeded(ratio: Double = 0) extends ComparisonResult

case class DataSynchronizationFailed(errorMessage: String, passedCount: Option[Long] = None,
totalCount: Option[Long] = None) extends ComparisonResult
case class DataSynchronizationSucceeded(passedCount: Option[Long] = None, totalCount: Option[Long] = None)
extends ComparisonResult
case class DataSynchronizationSucceeded(passedCount: Long, totalCount: Long) extends ComparisonResult
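
With passedCount and totalCount now plain Longs on the success case, callers no longer unwrap Options; a sketch assuming ds1, ds2, colKeyMap, and assertion are in scope:

import com.amazon.deequ.comparison.{DataSynchronization, DataSynchronizationFailed, DataSynchronizationSucceeded}

DataSynchronization.columnMatch(ds1, ds2, colKeyMap, assertion) match {
  case DataSynchronizationSucceeded(passed, total) =>
    println(s"$passed of $total rows synchronized")
  case DataSynchronizationFailed(message, passed, total) =>
    println(s"comparison failed: $message (passed=$passed, total=$total)")
  case other =>
    println(s"unexpected comparison result: $other") // other ComparisonResult subtypes
}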
src/main/scala/com/amazon/deequ/comparison/DataSynchronization.scala
@@ -1,5 +1,5 @@
/**
* Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
* Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not
* use this file except in compliance with the License. A copy of the License
@@ -16,7 +16,8 @@

package com.amazon.deequ.comparison

import org.apache.spark.sql.{Column, DataFrame}
import org.apache.spark.sql.Column
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.functions.hash
import org.apache.spark.sql.functions.lit
@@ -101,7 +102,7 @@ object DataSynchronization extends ComparisonBase {
val nonKeyColsMatch = colsDS1.forall(columnExists(ds2, _))

if (!nonKeyColsMatch) {
ComparisonFailed("Non key columns in the given data frames do not match.")
DataSynchronizationFailed("Non key columns in the given data frames do not match.")
} else {
val mergedMaps = colKeyMap ++ colsDS1.map(x => x -> x).toMap
finalAssertion(ds1, ds2, mergedMaps, assertion)
@@ -137,10 +138,10 @@
val nonKeyColumns2NotInDataset = compCols.values.filterNot(columnExists(ds2, _))

if (nonKeyColumns1NotInDataset.nonEmpty) {
ComparisonFailed(s"The following columns were not found in the first dataset: " +
DataSynchronizationFailed(s"The following columns were not found in the first dataset: " +
s"${nonKeyColumns1NotInDataset.mkString(", ")}")
} else if (nonKeyColumns2NotInDataset.nonEmpty) {
ComparisonFailed(s"The following columns were not found in the second dataset: " +
DataSynchronizationFailed(s"The following columns were not found in the second dataset: " +
s"${nonKeyColumns2NotInDataset.mkString(", ")}")
} else {
val mergedMaps = colKeyMap ++ compCols
@@ -155,23 +156,24 @@
ds2: DataFrame,
colKeyMap: Map[String, String],
optionalCompCols: Option[Map[String, String]] = None,
optionalOutcomeColumnName: Option[String] = None): Either[ComparisonFailed, DataFrame] = {
optionalOutcomeColumnName: Option[String] = None):
Either[DataSynchronizationFailed, DataFrame] = {
val columnErrors = areKeyColumnsValid(ds1, ds2, colKeyMap)
if (columnErrors.isEmpty) {
val compColsEither: Either[ComparisonFailed, Map[String, String]] = if (optionalCompCols.isDefined) {
val compColsEither: Either[DataSynchronizationFailed, Map[String, String]] = if (optionalCompCols.isDefined) {
optionalCompCols.get match {
case compCols if compCols.isEmpty => Left(ComparisonFailed("Empty column comparison map provided."))
case compCols if compCols.isEmpty => Left(DataSynchronizationFailed("Empty column comparison map provided."))
case compCols =>
val ds1CompColsNotInDataset = compCols.keys.filterNot(columnExists(ds1, _))
val ds2CompColsNotInDataset = compCols.values.filterNot(columnExists(ds2, _))
if (ds1CompColsNotInDataset.nonEmpty) {
Left(
ComparisonFailed(s"The following columns were not found in the first dataset: " +
DataSynchronizationFailed(s"The following columns were not found in the first dataset: " +
s"${ds1CompColsNotInDataset.mkString(", ")}")
)
} else if (ds2CompColsNotInDataset.nonEmpty) {
Left(
ComparisonFailed(s"The following columns were not found in the second dataset: " +
DataSynchronizationFailed(s"The following columns were not found in the second dataset: " +
s"${ds2CompColsNotInDataset.mkString(", ")}")
)
} else {
@@ -184,7 +186,7 @@ object DataSynchronization extends ComparisonBase {
val nonKeyColsMatch = ds1NonKeyCols.forall(columnExists(ds2, _))

if (!nonKeyColsMatch) {
Left(ComparisonFailed("Non key columns in the given data frames do not match."))
Left(DataSynchronizationFailed("Non key columns in the given data frames do not match."))
} else {
Right(ds1NonKeyCols.map { c => c -> c}.toMap)
}
Expand All @@ -196,11 +198,11 @@ object DataSynchronization extends ComparisonBase {
case Success(df) => Right(df)
case Failure(ex) =>
ex.printStackTrace()
Left(ComparisonFailed(s"Comparison failed due to ${ex.getCause.getClass}"))
Left(DataSynchronizationFailed(s"Comparison failed due to ${ex.getCause.getClass}"))
}
}
} else {
Left(ComparisonFailed(columnErrors.get))
Left(DataSynchronizationFailed(columnErrors.get))
}
}
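
Because columnMatchRowLevel now returns Either[DataSynchronizationFailed, DataFrame], callers get a typed failure instead of the generic ComparisonFailed; a sketch assuming ds1, ds2, and colKeyMap are in scope, with the optional parameters left at their defaults:

DataSynchronization.columnMatchRowLevel(ds1, ds2, colKeyMap) match {
  case Right(annotated) => annotated.show()           // row-level outcome column attached
  case Left(failure)    => println(failure.errorMessage)
}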

@@ -253,7 +255,7 @@
val ds2Count = ds2.count()

if (ds1Count != ds2Count) {
ComparisonFailed(s"The row counts of the two data frames do not match.")
DataSynchronizationFailed(s"The row counts of the two data frames do not match.")
} else {
val joinExpression: Column = mergedMaps
.map { case (col1, col2) => ds1(col1) === ds2(col2)}
@@ -265,7 +267,7 @@
val ratio = passedCount.toDouble / totalCount.toDouble

if (assertion(ratio)) {
DataSynchronizationSucceeded(Some(passedCount), Some(totalCount))
DataSynchronizationSucceeded(passedCount, totalCount)
} else {
DataSynchronizationFailed(s"Data Synchronization Comparison Metric Value: $ratio does not meet the constraint" +
s"requirement.", Some(passedCount), Some(totalCount))