diff --git a/datacontract/cli.py b/datacontract/cli.py index d152b96d..461cf755 100644 --- a/datacontract/cli.py +++ b/datacontract/cli.py @@ -133,6 +133,9 @@ class ExportFormat(str, Enum): def export( format: Annotated[ExportFormat, typer.Option(help="The export format.")], server: Annotated[str, typer.Option(help="The server name to export.")] = None, + model: Annotated[str, typer.Option(help="Use the key of the model in the data contract yaml file " + "to refer to a model, e.g., `orders`, or `all` for all " + "models (default).")] = "all", rdf_base: Annotated[Optional[str], typer.Option(help="[rdf] The base URI used to generate the RDF graph.", rich_help_panel="RDF Options")] = None, location: Annotated[ str, typer.Argument(help="The location (url or path) of the data contract yaml.")] = "datacontract.yaml", @@ -141,7 +144,7 @@ def export( Convert data contract to a specific format. Prints to stdout. """ # TODO exception handling - result = DataContract(data_contract_file=location, server=server).export(format, rdf_base) + result = DataContract(data_contract_file=location, server=server).export(export_format=format, model=model,rdf_base=rdf_base) print(result) diff --git a/datacontract/data_contract.py b/datacontract/data_contract.py index b59647b1..8dc7f337 100644 --- a/datacontract/data_contract.py +++ b/datacontract/data_contract.py @@ -17,7 +17,7 @@ from datacontract.export.jsonschema_converter import to_jsonschema, to_jsonschema_json from datacontract.export.odcs_converter import to_odcs_yaml from datacontract.export.protobuf_converter import to_protobuf -from datacontract.export.rdf_converter import to_rdf, to_rdf_n3 +from datacontract.export.rdf_converter import to_rdf_n3 from datacontract.export.sodacl_converter import to_sodacl_yaml from datacontract.export.terraform_converter import to_terraform from datacontract.imports.sql_importer import import_sql @@ -246,15 +246,28 @@ def get_data_contract_specification(self) -> DataContractSpecification: inline_definitions=self._inline_definitions, ) - def export(self, export_format, rdf_base: str = None) -> str: + def export(self, export_format, model: str = "all", rdf_base: str = None) -> str: data_contract = resolve.resolve_data_contract(self._data_contract_file, self._data_contract_str, self._data_contract) if export_format == "jsonschema": - if data_contract.models is None or len(data_contract.models.items()) != 1: - print(f"Export to {export_format} currently only works with exactly one model in the data contract.") - return "" - model_name, model = next(iter(data_contract.models.items())) - return to_jsonschema_json(model_name, model) + if data_contract.models is None: + raise RuntimeError( f"Export to {export_format} requires models in the data contract.") + + model_names = list(data_contract.models.keys()) + + if model == "all": + if len(data_contract.models.items()) != 1: + raise RuntimeError( f"Export to {export_format} is model specific. Specify the model via --model $MODEL_NAME. Available models: {model_names}") + + model_name, model_value = next(iter(data_contract.models.items())) + return to_jsonschema_json(model_name, model_value) + else: + model_name = model + model_value = data_contract.models.get(model_name) + if model_value is None: + raise RuntimeError( f"Model {model_name} not found in the data contract. Available models: {model_names}") + + return to_jsonschema_json(model_name, model_value) if export_format == "sodacl": return to_sodacl_yaml(data_contract) if export_format == "dbt": @@ -262,7 +275,24 @@ def export(self, export_format, rdf_base: str = None) -> str: if export_format == "dbt-sources": return to_dbt_sources_yaml(data_contract, self._server) if export_format == "dbt-staging-sql": - return to_dbt_staging_sql(data_contract) + if data_contract.models is None: + raise RuntimeError(f"Export to {export_format} requires models in the data contract.") + + model_names = list(data_contract.models.keys()) + + if model == "all": + if len(data_contract.models.items()) != 1: + raise RuntimeError(f"Export to {export_format} is model specific. Specify the model via --model $MODEL_NAME. Available models: {model_names}") + + model_name, model_value = next(iter(data_contract.models.items())) + return to_dbt_staging_sql(data_contract, model_name, model_value) + else: + model_name = model + model_value = data_contract.models.get(model_name) + if model_value is None: + raise RuntimeError(f"Model {model_name} not found in the data contract. Available models: {model_names}") + + return to_dbt_staging_sql(data_contract, model_name, model_value) if export_format == "odcs": return to_odcs_yaml(data_contract) if export_format == "rdf": @@ -270,11 +300,24 @@ def export(self, export_format, rdf_base: str = None) -> str: if export_format == "protobuf": return to_protobuf(data_contract) if export_format == "avro": - if data_contract.models is None or len(data_contract.models.items()) != 1: - print(f"Export to {export_format} currently only works with exactly one model in the data contract.") - return "" - model_name, model = next(iter(data_contract.models.items())) - return to_avro_schema_json(model_name, model) + if data_contract.models is None: + raise RuntimeError(f"Export to {export_format} requires models in the data contract.") + + model_names = list(data_contract.models.keys()) + + if model == "all": + if len(data_contract.models.items()) != 1: + raise RuntimeError(f"Export to {export_format} is model specific. Specify the model via --model $MODEL_NAME. Available models: {model_names}") + + model_name, model_value = next(iter(data_contract.models.items())) + return to_avro_schema_json(model_name, model_value) + else: + model_name = model + model_value = data_contract.models.get(model_name) + if model_value is None: + raise RuntimeError(f"Model {model_name} not found in the data contract. Available models: {model_names}") + + return to_avro_schema_json(model_name, model_value) if export_format == "terraform": return to_terraform(data_contract) else: diff --git a/datacontract/export/dbt_converter.py b/datacontract/export/dbt_converter.py index d0ab99a3..c50a400e 100644 --- a/datacontract/export/dbt_converter.py +++ b/datacontract/export/dbt_converter.py @@ -21,15 +21,14 @@ def to_dbt_models_yaml(data_contract_spec: DataContractSpecification): return yaml.dump(dbt, indent=2, sort_keys=False, allow_unicode=True) -def to_dbt_staging_sql(data_contract_spec: DataContractSpecification): +def to_dbt_staging_sql(data_contract_spec: DataContractSpecification, model_name: str, model_value: Model) -> str: if data_contract_spec.models is None or len(data_contract_spec.models.items()) != 1: print(f"Export to dbt-staging-sql currently only works with exactly one model in the data contract.") return "" id = data_contract_spec.id - model_name, model = next(iter(data_contract_spec.models.items())) columns = [] - for field_name, field in model.fields.items(): + for field_name, field in model_value.fields.items(): # TODO escape SQL reserved key words, probably dependent on server type columns.append(field_name) return f""" diff --git a/tests/test_export_dbt_staging_sql.py b/tests/test_export_dbt_staging_sql.py index 589f5d57..f419c4fe 100644 --- a/tests/test_export_dbt_staging_sql.py +++ b/tests/test_export_dbt_staging_sql.py @@ -32,7 +32,7 @@ def test_to_dbt_staging(): from {{ source('orders-unit-test', 'orders') }} """ - result = to_dbt_staging_sql(data_contract) + result = to_dbt_staging_sql(data_contract, "orders", data_contract.models.get("orders")) assert yaml.safe_load(result) == yaml.safe_load(expected) diff --git a/tests/test_schema.py b/tests/test_schema.py index 653ca734..bc34814c 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -15,40 +15,38 @@ def test_schema(): info: title: Orders Latest version: 1.0.0 -schema: - type: json-schema - specification: - orders: - description: One record per order. Includes cancelled and deleted orders. - type: object - properties: - order_id: - type: string - description: Primary key of the orders table - order_timestamp: - type: string - format: date-time - description: The business timestamp in UTC when the order was successfully registered in the source system and the payment was successful. - order_total: - type: integer - description: Total amount of the order in the smallest monetary unit (e.g., cents). - line_items: - type: object - properties: - lines_item_id: - type: string - description: Primary key of the lines_item_id table - order_id: - type: string - description: Foreign key to the orders table - sku: - type: string - description: The purchased article number""") +models: + orders: + description: One record per order. Includes cancelled and deleted orders. + type: object + fields: + order_id: + type: string + description: Primary key of the orders table + order_timestamp: + type: string + format: date-time + description: The business timestamp in UTC when the order was successfully registered in the source system and the payment was successful. + order_total: + type: integer + description: Total amount of the order in the smallest monetary unit (e.g., cents). + line_items: + type: object + fields: + lines_item_id: + type: string + description: Primary key of the lines_item_id table + order_id: + type: string + description: Foreign key to the orders table + sku: + type: string + description: The purchased article number""") data_contract.lint() data_contract.test() data_contract.export(export_format="odcs") data_contract.export(export_format="dbt-model") data_contract.export(export_format="dbt-source") - data_contract.export(export_format="dbt-staging-sql") - data_contract.export(export_format="jsonschema") + data_contract.export(export_format="dbt-staging-sql", model="orders") + data_contract.export(export_format="jsonschema", model="orders")