Implementation of pydantic model export. (datacontract#117)
Implements Pydantic model export (datacontract#109). 
Also replaces uses of print() in `cli.py` with `console.print` so that rich markup is not interpreted in exported output.
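
For context on the markup issue: rich's `print` (and `Console.print` with its defaults) treats square-bracket sequences such as `[bold]` as console markup, so exported source containing bracketed type annotations can be mangled instead of printed literally. A minimal sketch of the difference, using a hypothetical exported string:

from rich.console import Console

console = Console()

# Hypothetical line of exported output; the brackets are Python syntax, not rich markup.
exported = "amount: typing.Optional[float]"

# With markup enabled (the default), rich would try to parse "[float]" as a markup tag
# rather than print it literally. Passing markup=False writes the string to stdout verbatim.
console.print(exported, markup=False)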
jeeger authored Mar 28, 2024
1 parent 57fd02d commit 54ee2ee
Showing 4 changed files with 278 additions and 18 deletions.
37 changes: 20 additions & 17 deletions datacontract/cli.py
@@ -5,7 +5,7 @@
import typer
from click import Context
from rich import box
from rich import print
from rich.console import Console
from rich.table import Table
from typer.core import TyperGroup
from typing_extensions import Annotated
@@ -14,6 +14,7 @@
from datacontract.init.download_datacontract_file import \
download_datacontract_file, FileExistsException

console = Console()

class OrderedCommands(TyperGroup):
def list_commands(self, ctx: Context) -> Iterable[str]:
@@ -29,7 +30,7 @@ def list_commands(self, ctx: Context) -> Iterable[str]:

def version_callback(value: bool):
if value:
print(metadata.version("datacontract-cli"))
console.print(metadata.version("datacontract-cli"))
raise typer.Exit()


@@ -66,10 +67,10 @@ def init(
try:
download_datacontract_file(location, template, overwrite)
except FileExistsException:
print("File already exists, use --overwrite to overwrite")
console.print("File already exists, use --overwrite to overwrite")
raise typer.Exit(code=1)
else:
print("📄 data contract written to " + location)
console.print("📄 data contract written to " + location)


@app.command()
@@ -120,7 +121,7 @@ def test(
"""
Run schema and quality tests on configured servers.
"""
print(f"Testing {location}")
console.print(f"Testing {location}")
if server == "all":
server = None
run = DataContract(
@@ -138,6 +139,7 @@

class ExportFormat(str, Enum):
jsonschema = "jsonschema"
pydantic_model = "pydantic-model"
sodacl = "sodacl"
dbt = "dbt"
dbt_sources = "dbt-sources"
@@ -181,7 +183,7 @@ def export(
] = "datacontract.yaml",
):
"""
Convert data contract to a specific format. Prints to stdout.
Convert data contract to a specific format. console.prints to stdout.
"""
# TODO exception handling
result = DataContract(data_contract_file=location, server=server).export(
@@ -190,7 +192,8 @@
rdf_base=rdf_base,
sql_server_type=sql_server_type,
)
print(result)
# Don't interpret console markup in output.
console.print(result, markup=False)


class ImportFormat(str, Enum):
@@ -207,7 +210,7 @@ def import_(
Create a data contract from the given source file. Prints to stdout.
"""
result = DataContract().import_from_source(format, source)
print(result.to_yaml())
console.print(result.to_yaml())


@app.command()
@@ -224,7 +227,7 @@ def breaking(
DataContract(data_contract_file=location_new, inline_definitions=True)
)

print(result.breaking_str())
console.print(result.breaking_str())

if not result.passed_checks():
raise typer.Exit(code=1)
@@ -244,7 +247,7 @@ def changelog(
DataContract(data_contract_file=location_new, inline_definitions=True)
)

print(result.changelog_str())
console.print(result.changelog_str())


@app.command()
@@ -261,21 +264,21 @@ def diff(
DataContract(data_contract_file=location_new, inline_definitions=True)
)

print(result.changelog_str())
console.print(result.changelog_str())


def _handle_result(run):
_print_table(run)
if run.result == "passed":
print(
console.print(
f"🟢 data contract is valid. Run {len(run.checks)} checks. Took {(run.timestampEnd - run.timestampStart).total_seconds()} seconds."
)
else:
print("🔴 data contract is invalid, found the following errors:")
console.print("🔴 data contract is invalid, found the following errors:")
i = 1
for check in run.checks:
if check.result != "passed":
print(str(++i) + ") " + check.reason)
console.print(str(++i) + ") " + check.reason)
raise typer.Exit(code=1)


@@ -287,7 +290,7 @@ def _print_table(run):
table.add_column("Details", max_width=50)
for check in run.checks:
table.add_row(with_markup(check.result), check.name, to_field(run, check), check.reason)
print(table)
console.print(table)


def to_field(run, check):
@@ -301,9 +304,9 @@ def to_field(run, check):


def _print_logs(run):
print("\nLogs:")
console.print("\nLogs:")
for log in run.logs:
print(log.timestamp.strftime("%y-%m-%d %H:%M:%S"), log.level.ljust(5), log.message)
console.print(log.timestamp.strftime("%y-%m-%d %H:%M:%S"), log.level.ljust(5), log.message)


def with_markup(result):
4 changes: 3 additions & 1 deletion datacontract/data_contract.py
@@ -22,6 +22,7 @@
from datacontract.export.jsonschema_converter import to_jsonschema_json
from datacontract.export.odcs_converter import to_odcs_yaml
from datacontract.export.protobuf_converter import to_protobuf
from datacontract.export.pydantic_converter import to_pydantic_model_str
from datacontract.export.rdf_converter import to_rdf_n3
from datacontract.export.sodacl_converter import to_sodacl_yaml
from datacontract.export.sql_converter import to_sql_ddl, to_sql_query
@@ -436,7 +437,8 @@ def export(self, export_format, model: str = "all", rdf_base: str = None, sql_se
)

return to_great_expectations(data_contract, model_name)

if export_format == "pydantic-model":
return to_pydantic_model_str(data_contract)
else:
print(f"Export format {export_format} not supported.")
return ""
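
With the converter wired into `DataContract.export`, the new format can be exercised programmatically. A rough sketch, assuming a `datacontract.yaml` in the working directory:

from datacontract.data_contract import DataContract

# Export the contract's models as a string of Python source defining Pydantic classes.
pydantic_source = DataContract(data_contract_file="datacontract.yaml").export("pydantic-model")

# The generated source can be written out as a module for application code to import.
with open("models.py", "w") as f:
    f.write(pydantic_source)

On the command line this presumably corresponds to something like `datacontract export --format pydantic-model datacontract.yaml`, following the existing options of the `export` command.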
140 changes: 140 additions & 0 deletions datacontract/export/pydantic_converter.py
@@ -0,0 +1,140 @@
import datacontract.model.data_contract_specification as spec
import typing
import ast

def to_pydantic_model_str(contract: spec.DataContractSpecification) -> str:
classdefs = [generate_model_class(model_name, model) for (model_name, model) in contract.models.items()]
documentation = [ast.Expr(ast.Constant(contract.info.description))] if (
contract.info and contract.info.description) else []
result = ast.Module(body=[
ast.Import(
names=[ast.Name("datetime", ctx=ast.Load()),
ast.Name("typing", ctx=ast.Load()),
ast.Name("pydantic", ctx=ast.Load())]),
*documentation,
*classdefs],
type_ignores=[])
return ast.unparse(result)

def optional_of(node) -> ast.Subscript:
return ast.Subscript(
value=ast.Attribute(
ast.Name(id="typing", ctx=ast.Load()),
attr="Optional",
ctx=ast.Load()),
slice=node)

def list_of(node) -> ast.Subscript:
return ast.Subscript(
value=ast.Name(id="list", ctx=ast.Load()),
slice=node)

def product_of(nodes: list[typing.Any]) -> ast.Subscript:
return ast.Subscript(
value=ast.Attribute(
value=ast.Name(id="typing", ctx=ast.Load()),
attr="Product",
ctx=ast.Load()),
slice=ast.Tuple(nodes, ctx=ast.Load())
)


type_annotation_type = typing.Union[ast.Name, ast.Attribute, ast.Constant, ast.Subscript]

def constant_field_annotation(field_name: str, field: spec.Field)\
-> tuple[type_annotation_type,
typing.Optional[ast.ClassDef]]:
match field.type:
case "string"|"text"|"varchar":
return (ast.Name("str", ctx=ast.Load()), None)
case "number", "decimal", "numeric":
# Either integer or float in specification,
# so we use float.
return (ast.Name("float", ctx=ast.Load()), None)
case "int" | "integer" | "long" | "bigint":
return (ast.Name("int", ctx=ast.Load()), None)
case "float" | "double":
return (ast.Name("float", ctx=ast.Load()), None)
case "boolean":
return (ast.Name("bool", ctx=ast.Load()), None)
case "timestamp" | "timestamp_tz" | "timestamp_ntz":
return (ast.Attribute(
value=ast.Name(id="datetime", ctx=ast.Load()),
attr="datetime"), None)
case "date":
return (ast.Attribute(
value=ast.Name(id="datetime", ctx=ast.Load()),
attr="date"), None)
case "bytes":
return (ast.Name("bytes", ctx=ast.Load()), None)
case "null":
return (ast.Constant("None"), None)
case "array":
(annotated_type, new_class) = type_annotation(field_name, field.items)
return (list_of(annotated_type), new_class)
case "object" | "record" | "struct":
classdef = generate_field_class(field_name.capitalize(), field)
return (ast.Name(field_name.capitalize(), ctx=ast.Load()), classdef)
case _:
raise RuntimeError(f"Unsupported field type {field.type}.")


def type_annotation(field_name: str, field: spec.Field) -> tuple[type_annotation_type, typing.Optional[ast.ClassDef]]:
if field.required:
return constant_field_annotation(field_name, field)
else:
(annotated_type, new_classes) = constant_field_annotation(field_name, field)
return (optional_of(annotated_type), new_classes)

def is_simple_field(field: spec.Field) -> bool:
return field.type not in set(["object", "record", "struct"])

def field_definitions(fields: dict[str, spec.Field]) ->\
tuple[list[ast.Expr],
list[ast.ClassDef]]:
annotations = []
classes = []
for (field_name, field) in fields.items():
(ann, new_class) = type_annotation(field_name, field)
annotations.append(
ast.AnnAssign(
target=ast.Name(id=field_name, ctx=ast.Store()),
annotation=ann,
simple=1))
if field.description and is_simple_field(field):
annotations.append(
ast.Expr(ast.Constant(field.description)))
if new_class:
classes.append(new_class)
return (annotations, classes)

def generate_field_class(field_name: str, field: spec.Field) -> ast.ClassDef:
assert(field.type in set(["object", "record", "struct"]))
(annotated_type, new_classes) = field_definitions(field.fields)
documentation = [ast.Expr(ast.Constant(field.description))] if field.description else []
return ast.ClassDef(
name=field_name,
bases=[ast.Attribute(value=ast.Name(id="pydantic", ctx=ast.Load()),
attr="BaseModel",
ctx=ast.Load())],
body=[
*documentation,
*new_classes,
*annotated_type
],
keywords=[],
decorator_list=[])


def generate_model_class(name: str, model_definition: spec.Model) -> ast.ClassDef:
(field_assignments, nested_classes) = field_definitions(model_definition.fields)
documentation = [ast.Expr(ast.Constant(model_definition.description))] if model_definition.description else []
result = ast.ClassDef(
name=name.capitalize(),
bases=[ast.Attribute(value=ast.Name(id="pydantic", ctx=ast.Load()),
attr="BaseModel",
ctx=ast.Load())],
body=[*documentation, *nested_classes, *field_assignments],
keywords=[],
decorator_list=[])
return result
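
To illustrate what `to_pydantic_model_str` emits, here is a hand-written approximation of the generated source for a hypothetical contract containing a single `orders` model with a required string field and an optional integer field; the real output comes from `ast.unparse`, so spacing and quoting may differ slightly:

import datetime, typing, pydantic

class Orders(pydantic.BaseModel):
    'Order-level data of the webshop.'
    order_id: str
    'Unique identifier of the order.'
    order_total: typing.Optional[int]

Field and model descriptions become bare string constants in the class body, and nested `object`/`record`/`struct` fields would additionally produce inner `pydantic.BaseModel` subclasses placed inside the parent class, with the field annotated by the capitalized class name.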
