From 6e89feb4c655d6333ddc68757c9870ac7c8df9e1 Mon Sep 17 00:00:00 2001
From: jochen
Date: Wed, 29 May 2024 20:09:54 +0200
Subject: [PATCH] Test data contract against dataframes or views

Resolves #175
---
 CHANGELOG.md                               |  3 +
 README.md                                  | 36 +++++++++++
 datacontract/data_contract.py              |  7 ++-
 .../engines/soda/check_soda_execute.py     |  9 +++
 tests/fixtures/dataframe/datacontract.yaml | 23 +++++++
 tests/test_test_dataframe.py               | 62 +++++++++++++++++++
 6 files changed, 138 insertions(+), 2 deletions(-)
 create mode 100644 tests/fixtures/dataframe/datacontract.yaml
 create mode 100644 tests/test_test_dataframe.py

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 3aba5c79..a5347f45 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,6 +7,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

 ## [Unreleased]

+### Added
+- Test data contract against dataframes / temporary views (#175)
+
 ## [0.10.6] - 2024-05-29

 ### Fixed
diff --git a/README.md b/README.md
index 03cd82a2..34822d25 100644
--- a/README.md
+++ b/README.md
@@ -278,6 +278,7 @@ Supported server types:
 - [azure](#azure)
 - [databricks](#databricks)
 - [databricks (programmatic)](#databricks-programmatic)
+- [dataframe (programmatic)](#dataframe-programmatic)
 - [snowflake](#snowflake)
 - [kafka](#kafka)
 - [postgres](#postgres)
@@ -459,6 +460,41 @@
 run = data_contract.test()
 run.result
 ```

+### Dataframe (programmatic)
+
+Works with Spark DataFrames.
+DataFrames need to be registered as named temporary views.
+Multiple temporary views are supported if your data contract contains multiple models.
+
+Testing DataFrames is useful for validating your datasets in a pipeline before writing them to a data source.
+
+#### Example
+
+datacontract.yaml
+```yaml
+servers:
+  production:
+    type: dataframe
+models:
+  my_table: # corresponds to a temporary view
+    type: table
+    fields: ...
+```
+
+Example code
+```python
+from datacontract.data_contract import DataContract
+
+df.createOrReplaceTempView("my_table")
+
+data_contract = DataContract(
+    data_contract_file="datacontract.yaml",
+    spark=spark,
+)
+run = data_contract.test()
+assert run.result == "passed"
+```
+
 ### Snowflake

diff --git a/datacontract/data_contract.py b/datacontract/data_contract.py
index 2a019384..fa38fe4f 100644
--- a/datacontract/data_contract.py
+++ b/datacontract/data_contract.py
@@ -218,10 +218,13 @@ def test(self) -> Run:
             run.outputPortId = server.outputPortId
             run.server = server_name

-            # 5. check server is supported type
-            # 6. check server credentials are complete
+            # TODO check server is supported type for nicer error messages
+
+            # TODO check server credentials are complete for nicer error messages
+
             if server.format == "json" and server.type != "kafka":
                 check_jsonschema(run, data_contract, server)
+
             check_soda_execute(run, data_contract, server, self._spark, tmp_dir)

         except DataContractException as e:
diff --git a/datacontract/engines/soda/check_soda_execute.py b/datacontract/engines/soda/check_soda_execute.py
index 1d2e4a8a..0e5bce30 100644
--- a/datacontract/engines/soda/check_soda_execute.py
+++ b/datacontract/engines/soda/check_soda_execute.py
@@ -64,6 +64,15 @@ def check_soda_execute(
             soda_configuration_str = to_databricks_soda_configuration(server)
             scan.add_configuration_yaml_str(soda_configuration_str)
             scan.set_data_source_name(server.type)
+    elif server.type == "dataframe":
+        if spark is None:
+            run.log_warn("Server type dataframe only works with the Python library and requires a Spark session, "
+                         "please provide one with the DataContract class")
+            return
+        else:
+            logging.info("Use Spark to connect to data source")
+            scan.add_spark_session(spark, data_source_name="datacontract-cli")
+            scan.set_data_source_name("datacontract-cli")
     elif server.type == "kafka":
         if spark is None:
             spark = create_spark_session(tmp_dir)
diff --git a/tests/fixtures/dataframe/datacontract.yaml b/tests/fixtures/dataframe/datacontract.yaml
new file mode 100644
index 00000000..041d867c
--- /dev/null
+++ b/tests/fixtures/dataframe/datacontract.yaml
@@ -0,0 +1,23 @@
+dataContractSpecification: 0.9.3
+id: dataframetest
+info:
+  title: dataframetest
+  version: 0.0.1
+  owner: my-domain-team
+servers:
+  unittest:
+    type: dataframe
+models:
+  my_table:
+    type: table
+    fields:
+      field_one:
+        type: varchar
+        required: true
+        unique: true
+        pattern: "[A-Za-z]{2}-\\d{3}-[A-Za-z]{2}$"
+      field_two:
+        type: int
+        minimum: 10
+      field_three:
+        type: timestamp
diff --git a/tests/test_test_dataframe.py b/tests/test_test_dataframe.py
new file mode 100644
index 00000000..3bd837a4
--- /dev/null
+++ b/tests/test_test_dataframe.py
@@ -0,0 +1,62 @@
+import logging
+from pathlib import Path
+
+from dotenv import load_dotenv
+from pyspark.sql import SparkSession, Row
+from pyspark.sql.types import StructType, StructField, StringType, IntegerType, TimestampType
+from datetime import datetime
+
+from datacontract.data_contract import DataContract
+
+# logging.basicConfig(level=logging.INFO, force=True)
+
+datacontract = "fixtures/dataframe/datacontract.yaml"
+
+load_dotenv(override=True)
+
+
+def test_test_dataframe(tmp_path: Path):
+    spark = _create_spark_session(tmp_dir=str(tmp_path))
+    _prepare_dataframe(spark)
+    data_contract = DataContract(
+        data_contract_file=datacontract,
+        spark=spark,
+    )
+
+    run = data_contract.test()
+
+    print(run.pretty())
+    assert run.result == "passed"
+    assert all(check.result == "passed" for check in run.checks)
+
+
+def _prepare_dataframe(spark):
+    schema = StructType([
+        StructField("field_one", StringType(), nullable=False),
+        StructField("field_two", IntegerType(), nullable=True),
+        StructField("field_three", TimestampType(), nullable=True)
+    ])
+    data = [
+        Row(field_one="AB-123-CD", field_two=15,
+            field_three=datetime.strptime("2024-01-01 12:00:00", "%Y-%m-%d %H:%M:%S")),
+        Row(field_one="XY-456-ZZ", field_two=20,
+            field_three=datetime.strptime("2024-02-01 12:00:00", "%Y-%m-%d %H:%M:%S"))
+    ]
+    # Create DataFrame
+    df = spark.createDataFrame(data, schema=schema)
+    # Create temporary view
+    # Name must match the model name in the data contract
+    df.createOrReplaceTempView("my_table")
+
+
+def _create_spark_session(tmp_dir: str) -> SparkSession:
+    """Create and configure a Spark session."""
+    spark = (
+        SparkSession.builder.appName("datacontract")
+        .config("spark.sql.warehouse.dir", f"{tmp_dir}/spark-warehouse")
+        .config("spark.streaming.stopGracefullyOnShutdown", "true")
+        .getOrCreate()
+    )
+    spark.sparkContext.setLogLevel("WARN")
+    print(f"Using PySpark version {spark.version}")
+    return spark
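
The pipeline use case described in the README addition above can be sketched roughly as follows: validate a DataFrame against the contract and only publish it if the test passes. This is a minimal sketch, not part of the patch itself; the `orders_df` DataFrame, the storage paths, and the gating logic are illustrative assumptions, while the `DataContract(...)`, `test()`, `run.result`, and `run.pretty()` calls follow the API introduced here.

```python
from pyspark.sql import SparkSession

from datacontract.data_contract import DataContract

spark = SparkSession.builder.appName("datacontract-pipeline").getOrCreate()

# Hypothetical pipeline step that produces the dataset to validate.
orders_df = spark.read.parquet("s3://my-bucket/staging/orders")

# Register the DataFrame under the model name defined in datacontract.yaml.
orders_df.createOrReplaceTempView("my_table")

data_contract = DataContract(
    data_contract_file="datacontract.yaml",
    spark=spark,
)
run = data_contract.test()

if run.result == "passed":
    # Publish only after the contract checks have passed.
    orders_df.write.mode("overwrite").parquet("s3://my-bucket/prod/orders")
else:
    print(run.pretty())
    raise RuntimeError("Data contract test failed; dataset not published")
```

If the contract defined more than one model, each corresponding DataFrame would be registered as its own temporary view before calling `test()`.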