diff --git a/datacontract/data_contract.py b/datacontract/data_contract.py
index 3ef90cf0..806810db 100644
--- a/datacontract/data_contract.py
+++ b/datacontract/data_contract.py
@@ -4,6 +4,7 @@ import typing
 
 import yaml
+from pyspark.sql import SparkSession
 
 from datacontract.breaking.breaking import models_breaking_changes, \
     quality_breaking_changes
 
@@ -53,25 +54,6 @@ from datacontract.model.run import Run, Check
 
 
-def _determine_sql_server_type(data_contract, sql_server_type):
-    if sql_server_type == "auto":
-        if data_contract.servers is None or len(data_contract.servers) == 0:
-            raise RuntimeError("Export with server_type='auto' requires servers in the data contract.")
-
-        server_types = set([server.type for server in data_contract.servers.values()])
-        if "snowflake" in server_types:
-            return "snowflake"
-        elif "postgres" in server_types:
-            return "postgres"
-        elif "databricks" in server_types:
-            return "databricks"
-        else:
-            # default to snowflake dialect
-            return "snowflake"
-    else:
-        return sql_server_type
-
-
 class DataContract:
     def __init__(
         self,
         data_contract_file: str = None,
@@ -83,7 +65,7 @@ def __init__(
         examples: bool = False,
         publish_url: str = None,
         publish_to_opentelemetry: bool = False,
-        spark: str = None,
+        spark: SparkSession = None,
         inline_definitions: bool = False,
     ):
         self._data_contract_file = data_contract_file
@@ -385,13 +367,13 @@ def export(self, export_format, model: str = "all", rdf_base: str = None, sql_se
         if export_format == "terraform":
             return to_terraform(data_contract)
         if export_format == "sql":
-            server_type = _determine_sql_server_type(data_contract, sql_server_type)
+            server_type = self._determine_sql_server_type(data_contract, sql_server_type)
             return to_sql_ddl(data_contract, server_type=server_type)
         if export_format == "sql-query":
             if data_contract.models is None:
                 raise RuntimeError(f"Export to {export_format} requires models in the data contract.")
 
-            server_type = _determine_sql_server_type(data_contract, sql_server_type)
+            server_type = self._determine_sql_server_type(data_contract, sql_server_type)
 
             model_names = list(data_contract.models.keys())
 
@@ -443,6 +425,24 @@ def export(self, export_format, model: str = "all", rdf_base: str = None, sql_se
             print(f"Export format {export_format} not supported.")
             return ""
 
+    def _determine_sql_server_type(self, data_contract, sql_server_type):
+        if sql_server_type == "auto":
+            if data_contract.servers is None or len(data_contract.servers) == 0:
+                raise RuntimeError("Export with server_type='auto' requires servers in the data contract.")
+
+            server_types = set([server.type for server in data_contract.servers.values()])
+            if "snowflake" in server_types:
+                return "snowflake"
+            elif "postgres" in server_types:
+                return "postgres"
+            elif "databricks" in server_types:
+                return "databricks"
+            else:
+                # default to snowflake dialect
+                return "snowflake"
+        else:
+            return sql_server_type
+
     def _get_examples_server(self, data_contract, run, tmp_dir):
         run.log_info(f"Copying examples to files in temporary directory {tmp_dir}")
         format = "json"
diff --git a/datacontract/engines/soda/check_soda_execute.py b/datacontract/engines/soda/check_soda_execute.py
index 16299988..23651a9c 100644
--- a/datacontract/engines/soda/check_soda_execute.py
+++ b/datacontract/engines/soda/check_soda_execute.py
@@ -1,5 +1,6 @@
 import logging
 
+from pyspark.sql import SparkSession
 from soda.scan import Scan
 
 from datacontract.engines.soda.connections.bigquery import \
@@ -19,7 +20,7 @@ from datacontract.model.run import Run, Check, Log
 
 
-def check_soda_execute(run: Run, data_contract: DataContractSpecification, server: Server, spark, tmp_dir):
+def check_soda_execute(run: Run, data_contract: DataContractSpecification, server: Server, spark: SparkSession, tmp_dir):
     if data_contract is None:
         run.log_warn("Cannot run engine soda-core, as data contract is invalid")
         return
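
Note for reviewers: a minimal usage sketch of the retyped spark parameter, assuming a contract file at the placeholder path "datacontract.yaml"; the app name is likewise a placeholder:

    from pyspark.sql import SparkSession

    from datacontract.data_contract import DataContract

    # Reuse an existing Spark session (for example, the one a Databricks
    # notebook provides) instead of letting the engine create its own.
    spark = SparkSession.builder.appName("datacontract-test").getOrCreate()

    data_contract = DataContract(
        data_contract_file="datacontract.yaml",  # placeholder path
        spark=spark,  # now typed as SparkSession rather than str
    )
    run = data_contract.test()
    print(run.result)

The session is handed through to check_soda_execute, presumably so the Soda scan can read DataFrames already registered as temp views on the caller's session.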