Test data contract against dataframes or views
jochenchrist committed May 29, 2024
1 parent cf41b9e commit 6e89feb
Showing 6 changed files with 138 additions and 2 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -7,6 +7,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

### Added
- Test data contract against dataframes / temporary views (#175)

## [0.10.6] - 2024-05-29

### Fixed
36 changes: 36 additions & 0 deletions README.md
@@ -278,6 +278,7 @@ Supported server types:
- [azure](#azure)
- [databricks](#databricks)
- [databricks (programmatic)](#databricks-programmatic)
- [dataframe (programmatic)](#dataframe-programmatic)
- [snowflake](#snowflake)
- [kafka](#kafka)
- [postgres](#postgres)
@@ -459,6 +460,41 @@ run = data_contract.test()
run.result
```

### Dataframe (programmatic)

Works with Spark DataFrames.
DataFrames need to be registered as named temporary views.
Multiple temporary views are supported if your data contract contains multiple models (a sketch for that case follows the example below).

Testing DataFrames is useful for validating datasets in a pipeline before they are written to a data source.

#### Example

datacontract.yaml
```yaml
servers:
  production:
    type: dataframe
models:
  my_table: # corresponds to a temporary view
    type: table
    fields: ...
```

Example code
```python
from datacontract.data_contract import DataContract
df.createOrReplaceTempView("my_table")
data_contract = DataContract(
    data_contract_file="datacontract.yaml",
    spark=spark,
)
run = data_contract.test()
assert run.result == "passed"
```
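
If your data contract defines multiple models, register one temporary view per model, with view names matching the model names. A minimal sketch of that case (the model names `orders` and `customers` and the two DataFrames are hypothetical, not part of this change):

```python
from datacontract.data_contract import DataContract

# One temporary view per model; the view name must match the model name in the contract.
orders_df.createOrReplaceTempView("orders")
customers_df.createOrReplaceTempView("customers")

data_contract = DataContract(
    data_contract_file="datacontract.yaml",
    spark=spark,
)
run = data_contract.test()
assert run.result == "passed"
```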


### Snowflake

7 changes: 5 additions & 2 deletions datacontract/data_contract.py
@@ -218,10 +218,13 @@ def test(self) -> Run:
            run.outputPortId = server.outputPortId
            run.server = server_name

            # 5. check server is supported type
            # 6. check server credentials are complete
            # TODO check server is supported type for nicer error messages

            # TODO check server credentials are complete for nicer error messages

            if server.format == "json" and server.type != "kafka":
                check_jsonschema(run, data_contract, server)

            check_soda_execute(run, data_contract, server, self._spark, tmp_dir)

        except DataContractException as e:
9 changes: 9 additions & 0 deletions datacontract/engines/soda/check_soda_execute.py
@@ -64,6 +64,15 @@ def check_soda_execute(
        soda_configuration_str = to_databricks_soda_configuration(server)
        scan.add_configuration_yaml_str(soda_configuration_str)
        scan.set_data_source_name(server.type)
    elif server.type == "dataframe":
        if spark is None:
            run.log_warn("Server type dataframe only works with the Python library and requires a Spark session, "
                         "please provide one with the DataContract class")
            return
        else:
            logging.info("Use Spark to connect to data source")
            scan.add_spark_session(spark, data_source_name="datacontract-cli")
            scan.set_data_source_name("datacontract-cli")
    elif server.type == "kafka":
        if spark is None:
            spark = create_spark_session(tmp_dir)
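
For context, the `dataframe` branch above hands the caller's Spark session to Soda Core, which then resolves the contract's models against the session's temporary views. A minimal standalone sketch of that mechanism, assuming `soda-core-spark-df` is installed, `spark` is an active SparkSession with a `my_table` view, and the check shown is illustrative rather than one generated by datacontract-cli:

```python
from soda.scan import Scan

# Attach an existing SparkSession as the Soda data source,
# mirroring what the dataframe server type does internally.
scan = Scan()
scan.add_spark_session(spark, data_source_name="datacontract-cli")
scan.set_data_source_name("datacontract-cli")

# Soda resolves "my_table" against the session's temporary views.
scan.add_sodacl_yaml_str(
    """
checks for my_table:
  - row_count > 0
"""
)
scan.execute()
print(scan.get_scan_results())
```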
23 changes: 23 additions & 0 deletions tests/fixtures/dataframe/datacontract.yaml
@@ -0,0 +1,23 @@
dataContractSpecification: 0.9.3
id: dataframetest
info:
  title: dataframetest
  version: 0.0.1
  owner: my-domain-team
servers:
  unittest:
    type: dataframe
models:
  my_table:
    type: table
    fields:
      field_one:
        type: varchar
        required: true
        unique: true
        pattern: "[A-Za-z]{2}-\\d{3}-[A-Za-z]{2}$"
      field_two:
        type: int
        minimum: 10
      field_three:
        type: timestamp
62 changes: 62 additions & 0 deletions tests/test_test_dataframe.py
@@ -0,0 +1,62 @@
import logging
from pathlib import Path

from dotenv import load_dotenv
from pyspark.sql import SparkSession, Row
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, TimestampType
from datetime import datetime

from datacontract.data_contract import DataContract

# logging.basicConfig(level=logging.INFO, force=True)

datacontract = "fixtures/dataframe/datacontract.yaml"

load_dotenv(override=True)


def test_test_dataframe(tmp_path: Path):
    spark = _create_spark_session(tmp_dir=str(tmp_path))
    _prepare_dataframe(spark)
    data_contract = DataContract(
        data_contract_file=datacontract,
        spark=spark,
    )

    run = data_contract.test()

    print(run.pretty())
    assert run.result == "passed"
    assert all(check.result == "passed" for check in run.checks)


def _prepare_dataframe(spark):
    schema = StructType([
        StructField("field_one", StringType(), nullable=False),
        StructField("field_two", IntegerType(), nullable=True),
        StructField("field_three", TimestampType(), nullable=True)
    ])
    data = [
        Row(field_one="AB-123-CD", field_two=15,
            field_three=datetime.strptime("2024-01-01 12:00:00", "%Y-%m-%d %H:%M:%S")),
        Row(field_one="XY-456-ZZ", field_two=20,
            field_three=datetime.strptime("2024-02-01 12:00:00", "%Y-%m-%d %H:%M:%S"))
    ]
    # Create DataFrame
    df = spark.createDataFrame(data, schema=schema)
    # Create temporary view
    # Name must match the model name in the data contract
    df.createOrReplaceTempView("my_table")


def _create_spark_session(tmp_dir: str) -> SparkSession:
"""Create and configure a Spark session."""
spark = (
SparkSession.builder.appName("datacontract")
.config("spark.sql.warehouse.dir", f"{tmp_dir}/spark-warehouse")
.config("spark.streaming.stopGracefullyOnShutdown", "true")
.getOrCreate()
)
spark.sparkContext.setLogLevel("WARN")
print(f"Using PySpark version {spark.version}")
return spark
