From 54ee2ee5469df0df34b3900d8bc9f57ade1a0c06 Mon Sep 17 00:00:00 2001 From: Jan Seeger Date: Thu, 28 Mar 2024 09:35:11 +0100 Subject: [PATCH] Implementation of pydantic model export. (#117) Implements Pydantic model export (#109). Also replaced uses of print() in `cli.py` with `console.print` to not interpret markup on export. --- datacontract/cli.py | 37 +++--- datacontract/data_contract.py | 4 +- datacontract/export/pydantic_converter.py | 140 ++++++++++++++++++++++ tests/test_export_pydantic.py | 115 ++++++++++++++++++ 4 files changed, 278 insertions(+), 18 deletions(-) create mode 100644 datacontract/export/pydantic_converter.py create mode 100644 tests/test_export_pydantic.py diff --git a/datacontract/cli.py b/datacontract/cli.py index 44d2d349..c32d680a 100644 --- a/datacontract/cli.py +++ b/datacontract/cli.py @@ -5,7 +5,7 @@ import typer from click import Context from rich import box -from rich import print +from rich.console import Console from rich.table import Table from typer.core import TyperGroup from typing_extensions import Annotated @@ -14,6 +14,7 @@ from datacontract.init.download_datacontract_file import \ download_datacontract_file, FileExistsException +console = Console() class OrderedCommands(TyperGroup): def list_commands(self, ctx: Context) -> Iterable[str]: @@ -29,7 +30,7 @@ def list_commands(self, ctx: Context) -> Iterable[str]: def version_callback(value: bool): if value: - print(metadata.version("datacontract-cli")) + console.print(metadata.version("datacontract-cli")) raise typer.Exit() @@ -66,10 +67,10 @@ def init( try: download_datacontract_file(location, template, overwrite) except FileExistsException: - print("File already exists, use --overwrite to overwrite") + console.print("File already exists, use --overwrite to overwrite") raise typer.Exit(code=1) else: - print("📄 data contract written to " + location) + console.print("📄 data contract written to " + location) @app.command() @@ -120,7 +121,7 @@ def test( """ Run schema and quality tests on configured servers. """ - print(f"Testing {location}") + console.print(f"Testing {location}") if server == "all": server = None run = DataContract( @@ -138,6 +139,7 @@ def test( class ExportFormat(str, Enum): jsonschema = "jsonschema" + pydantic_model = "pydantic-model" sodacl = "sodacl" dbt = "dbt" dbt_sources = "dbt-sources" @@ -181,7 +183,7 @@ def export( ] = "datacontract.yaml", ): """ - Convert data contract to a specific format. Prints to stdout. + Convert data contract to a specific format. console.prints to stdout. """ # TODO exception handling result = DataContract(data_contract_file=location, server=server).export( @@ -190,7 +192,8 @@ def export( rdf_base=rdf_base, sql_server_type=sql_server_type, ) - print(result) + # Don't interpret console markup in output. + console.print(result, markup=False) class ImportFormat(str, Enum): @@ -207,7 +210,7 @@ def import_( Create a data contract from the given source file. Prints to stdout. """ result = DataContract().import_from_source(format, source) - print(result.to_yaml()) + console.print(result.to_yaml()) @app.command() @@ -224,7 +227,7 @@ def breaking( DataContract(data_contract_file=location_new, inline_definitions=True) ) - print(result.breaking_str()) + console.print(result.breaking_str()) if not result.passed_checks(): raise typer.Exit(code=1) @@ -244,7 +247,7 @@ def changelog( DataContract(data_contract_file=location_new, inline_definitions=True) ) - print(result.changelog_str()) + console.print(result.changelog_str()) @app.command() @@ -261,21 +264,21 @@ def diff( DataContract(data_contract_file=location_new, inline_definitions=True) ) - print(result.changelog_str()) + console.print(result.changelog_str()) def _handle_result(run): _print_table(run) if run.result == "passed": - print( + console.print( f"🟢 data contract is valid. Run {len(run.checks)} checks. Took {(run.timestampEnd - run.timestampStart).total_seconds()} seconds." ) else: - print("🔴 data contract is invalid, found the following errors:") + console.print("🔴 data contract is invalid, found the following errors:") i = 1 for check in run.checks: if check.result != "passed": - print(str(++i) + ") " + check.reason) + console.print(str(++i) + ") " + check.reason) raise typer.Exit(code=1) @@ -287,7 +290,7 @@ def _print_table(run): table.add_column("Details", max_width=50) for check in run.checks: table.add_row(with_markup(check.result), check.name, to_field(run, check), check.reason) - print(table) + console.print(table) def to_field(run, check): @@ -301,9 +304,9 @@ def to_field(run, check): def _print_logs(run): - print("\nLogs:") + console.print("\nLogs:") for log in run.logs: - print(log.timestamp.strftime("%y-%m-%d %H:%M:%S"), log.level.ljust(5), log.message) + console.print(log.timestamp.strftime("%y-%m-%d %H:%M:%S"), log.level.ljust(5), log.message) def with_markup(result): diff --git a/datacontract/data_contract.py b/datacontract/data_contract.py index e9719315..3ef90cf0 100644 --- a/datacontract/data_contract.py +++ b/datacontract/data_contract.py @@ -22,6 +22,7 @@ from datacontract.export.jsonschema_converter import to_jsonschema_json from datacontract.export.odcs_converter import to_odcs_yaml from datacontract.export.protobuf_converter import to_protobuf +from datacontract.export.pydantic_converter import to_pydantic_model_str from datacontract.export.rdf_converter import to_rdf_n3 from datacontract.export.sodacl_converter import to_sodacl_yaml from datacontract.export.sql_converter import to_sql_ddl, to_sql_query @@ -436,7 +437,8 @@ def export(self, export_format, model: str = "all", rdf_base: str = None, sql_se ) return to_great_expectations(data_contract, model_name) - + if export_format == "pydantic-model": + return to_pydantic_model_str(data_contract) else: print(f"Export format {export_format} not supported.") return "" diff --git a/datacontract/export/pydantic_converter.py b/datacontract/export/pydantic_converter.py new file mode 100644 index 00000000..75d7dd2e --- /dev/null +++ b/datacontract/export/pydantic_converter.py @@ -0,0 +1,140 @@ +import datacontract.model.data_contract_specification as spec +import typing +import ast + +def to_pydantic_model_str(contract: spec.DataContractSpecification) -> str: + classdefs = [generate_model_class(model_name, model) for (model_name, model) in contract.models.items()] + documentation = [ast.Expr(ast.Constant(contract.info.description))] if ( + contract.info and contract.info.description) else [] + result = ast.Module(body=[ + ast.Import( + names=[ast.Name("datetime", ctx=ast.Load()), + ast.Name("typing", ctx=ast.Load()), + ast.Name("pydantic", ctx=ast.Load())]), + *documentation, + *classdefs], + type_ignores=[]) + return ast.unparse(result) + +def optional_of(node) -> ast.Subscript: + return ast.Subscript( + value=ast.Attribute( + ast.Name(id="typing", ctx=ast.Load()), + attr="Optional", + ctx=ast.Load()), + slice=node) + +def list_of(node) -> ast.Subscript: + return ast.Subscript( + value=ast.Name(id="list", ctx=ast.Load()), + slice=node) + +def product_of(nodes: list[typing.Any]) -> ast.Subscript: + return ast.Subscript( + value=ast.Attribute( + value=ast.Name(id="typing", ctx=ast.Load()), + attr="Product", + ctx=ast.Load()), + slice=ast.Tuple(nodes, ctx=ast.Load()) + ) + + +type_annotation_type = typing.Union[ast.Name, ast.Attribute, ast.Constant, ast.Subscript] + +def constant_field_annotation(field_name: str, field: spec.Field)\ + -> tuple[type_annotation_type, + typing.Optional[ast.ClassDef]]: + match field.type: + case "string"|"text"|"varchar": + return (ast.Name("str", ctx=ast.Load()), None) + case "number", "decimal", "numeric": + # Either integer or float in specification, + # so we use float. + return (ast.Name("float", ctx=ast.Load()), None) + case "int" | "integer" | "long" | "bigint": + return (ast.Name("int", ctx=ast.Load()), None) + case "float" | "double": + return (ast.Name("float", ctx=ast.Load()), None) + case "boolean": + return (ast.Name("bool", ctx=ast.Load()), None) + case "timestamp" | "timestamp_tz" | "timestamp_ntz": + return (ast.Attribute( + value=ast.Name(id="datetime", ctx=ast.Load()), + attr="datetime"), None) + case "date": + return (ast.Attribute( + value=ast.Name(id="datetime", ctx=ast.Load()), + attr="date"), None) + case "bytes": + return (ast.Name("bytes", ctx=ast.Load()), None) + case "null": + return (ast.Constant("None"), None) + case "array": + (annotated_type, new_class) = type_annotation(field_name, field.items) + return (list_of(annotated_type), new_class) + case "object" | "record" | "struct": + classdef = generate_field_class(field_name.capitalize(), field) + return (ast.Name(field_name.capitalize(), ctx=ast.Load()), classdef) + case _: + raise RuntimeError(f"Unsupported field type {field.type}.") + + +def type_annotation(field_name: str, field: spec.Field) -> tuple[type_annotation_type, typing.Optional[ast.ClassDef]]: + if field.required: + return constant_field_annotation(field_name, field) + else: + (annotated_type, new_classes) = constant_field_annotation(field_name, field) + return (optional_of(annotated_type), new_classes) + +def is_simple_field(field: spec.Field) -> bool: + return field.type not in set(["object", "record", "struct"]) + +def field_definitions(fields: dict[str, spec.Field]) ->\ + tuple[list[ast.Expr], + list[ast.ClassDef]]: + annotations = [] + classes = [] + for (field_name, field) in fields.items(): + (ann, new_class) = type_annotation(field_name, field) + annotations.append( + ast.AnnAssign( + target=ast.Name(id=field_name, ctx=ast.Store()), + annotation=ann, + simple=1)) + if field.description and is_simple_field(field): + annotations.append( + ast.Expr(ast.Constant(field.description))) + if new_class: + classes.append(new_class) + return (annotations, classes) + +def generate_field_class(field_name: str, field: spec.Field) -> ast.ClassDef: + assert(field.type in set(["object", "record", "struct"])) + (annotated_type, new_classes) = field_definitions(field.fields) + documentation = [ast.Expr(ast.Constant(field.description))] if field.description else [] + return ast.ClassDef( + name=field_name, + bases=[ast.Attribute(value=ast.Name(id="pydantic", ctx=ast.Load()), + attr="BaseModel", + ctx=ast.Load())], + body=[ + *documentation, + *new_classes, + *annotated_type + ], + keywords=[], + decorator_list=[]) + + +def generate_model_class(name: str, model_definition: spec.Model) -> ast.ClassDef: + (field_assignments, nested_classes) = field_definitions(model_definition.fields) + documentation = [ast.Expr(ast.Constant(model_definition.description))] if model_definition.description else [] + result = ast.ClassDef( + name=name.capitalize(), + bases=[ast.Attribute(value=ast.Name(id="pydantic", ctx=ast.Load()), + attr="BaseModel", + ctx=ast.Load())], + body=[*documentation, *nested_classes, *field_assignments], + keywords=[], + decorator_list=[]) + return result diff --git a/tests/test_export_pydantic.py b/tests/test_export_pydantic.py new file mode 100644 index 00000000..445c0fd8 --- /dev/null +++ b/tests/test_export_pydantic.py @@ -0,0 +1,115 @@ +import datacontract.model.data_contract_specification as spec +import datacontract.export.pydantic_converter as conv +from textwrap import dedent +import ast + + +# These tests would be easier if AST nodes were comparable. +# Current string comparisons are very brittle. +def test_simple_model_export(): + m = spec.Model( + fields={ + "f": spec.Field( + type="string")}) + ast_class = conv.generate_model_class("Test", m) + assert ast.unparse(ast_class) == dedent( + """ + class Test(pydantic.BaseModel): + f: typing.Optional[str] + """).strip() + +def test_array_model_export(): + m = spec.Model( + fields={ + "f": spec.Field( + type="array", + items=spec.Field( + type="string", + required=True))}) + ast_class = conv.generate_model_class("Test", m) + assert ast.unparse(ast_class) == dedent( + """ + class Test(pydantic.BaseModel): + f: typing.Optional[list[str]] + """).strip() + +def test_object_model_export(): + m = spec.Model( + fields={ + "f": spec.Field( + type="object", + fields={ + "f1": spec.Field( + type="string", + required=True)})}) + ast_class = conv.generate_model_class("Test", m) + assert ast.unparse(ast_class) == dedent( + """ + class Test(pydantic.BaseModel): + + class F(pydantic.BaseModel): + f1: str + f: typing.Optional[F] + """).strip() + +def test_model_documentation_export(): + m = spec.Model( + description="A test model", + fields={ + "f": spec.Field( + type="object", + description="A test field", + fields={ + "f1": spec.Field( + type="string", + required=True)})}) + ast_class = conv.generate_model_class("Test", m) + assert ast.unparse(ast_class) == dedent( + """ + class Test(pydantic.BaseModel): + \"""A test model\""" + + class F(pydantic.BaseModel): + \"""A test field\""" + f1: str + f: typing.Optional[F] + """).strip() + +def test_model_field_description_export(): + m = spec.Model( + fields={ + "f": spec.Field( + type="object", + fields={ + "f1": spec.Field( + type="string", + description="A test field", + required=True)})}) + ast_class = conv.generate_model_class("Test", m) + assert ast.unparse(ast_class) == dedent( + """ + class Test(pydantic.BaseModel): + + class F(pydantic.BaseModel): + f1: str + 'A test field' + f: typing.Optional[F] + """).strip() + +def test_model_description_export(): + m = spec.DataContractSpecification( + info=spec.Info(description="Contract description"), + models={"test_model": + spec.Model( + fields={ + "f": spec.Field( + type="string")})}) + result = conv.to_pydantic_model_str(m) + assert result == dedent( + """ + import datetime, typing, pydantic + 'Contract description' + + class Test_model(pydantic.BaseModel): + f: typing.Optional[str] + """).strip()