diff --git a/README.md b/README.md index 327e376a..b8556e04 100644 --- a/README.md +++ b/README.md @@ -647,15 +647,15 @@ data products, find the true domain owner of a field attribute) ### import ``` - Usage: datacontract import [OPTIONS] - - Create a data contract from the given source file. Prints to stdout. - -╭─ Options ───────────────────────────────────────────────────────────────────────────────────────────────────────╮ -│ * --format [sql|avro] The format of the source file. [default: None] [required] │ -│ * --source TEXT The path to the file that should be imported. [default: None] [required] │ -│ --help Show this message and exit. │ -╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯ + Usage: datacontract import [OPTIONS] + + Create a data contract from the given source location. Prints to stdout. + +╭─ Options ────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮ +│ * --format [sql|avro|glue] The format of the source file. [default: None] [required] │ +│ * --source TEXT The path to the file or Glue Database that should be imported. [default: None] [required] │ +│ --help Show this message and exit. 
│ +╰──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯ ``` Example: @@ -670,6 +670,7 @@ Available import options: |--------------------|------------------------------------------------|---------| | `sql` | Import from SQL DDL | ✅ | | `avro` | Import from AVRO schemas | ✅ | +| `glue` | Import from AWS Glue DataCatalog | ✅ | | `protobuf` | Import from Protobuf schemas | TBD | | `jsonschema` | Import from JSON Schemas | TBD | | `bigquery` | Import from BigQuery Schemas | TBD | diff --git a/datacontract/cli.py b/datacontract/cli.py index 0c8937d2..8fc06bc7 100644 --- a/datacontract/cli.py +++ b/datacontract/cli.py @@ -204,15 +204,16 @@ def export( class ImportFormat(str, Enum): sql = "sql" avro = "avro" + glue = "glue" @app.command(name="import") def import_( format: Annotated[ImportFormat, typer.Option(help="The format of the source file.")], - source: Annotated[str, typer.Option(help="The path to the file that should be imported.")], + source: Annotated[str, typer.Option(help="The path to the file or Glue Database that should be imported.")], ): """ - Create a data contract from the given source file. Prints to stdout. + Create a data contract from the given source location. Prints to stdout. 
""" result = DataContract().import_from_source(format, source) console.print(result.to_yaml()) diff --git a/datacontract/data_contract.py b/datacontract/data_contract.py index 25d059dc..dbddbbc0 100644 --- a/datacontract/data_contract.py +++ b/datacontract/data_contract.py @@ -27,6 +27,7 @@ from datacontract.export.terraform_converter import to_terraform from datacontract.imports.avro_importer import import_avro from datacontract.imports.sql_importer import import_sql +from datacontract.imports.glue_importer import import_glue from datacontract.integration.publish_datamesh_manager import publish_datamesh_manager from datacontract.integration.publish_opentelemetry import publish_opentelemetry from datacontract.lint import resolve @@ -476,6 +477,8 @@ def import_from_source(self, format: str, source: str) -> DataContractSpecificat data_contract_specification = import_sql(data_contract_specification, format, source) elif format == "avro": data_contract_specification = import_avro(data_contract_specification, source) + elif format == "glue": + data_contract_specification = import_glue(data_contract_specification, source) else: print(f"Import format {format} not supported.") diff --git a/datacontract/imports/glue_importer.py b/datacontract/imports/glue_importer.py new file mode 100644 index 00000000..df2438d7 --- /dev/null +++ b/datacontract/imports/glue_importer.py @@ -0,0 +1,183 @@ +import boto3 +from typing import List + +from datacontract.model.data_contract_specification import ( + DataContractSpecification, + Model, + Field, + Server, +) + + +def get_glue_database(datebase_name: str): + """Get the details Glue database. + + Args: + database_name (str): glue database to request. 
+ + Returns: + set: catalogid and locationUri + """ + + glue = boto3.client("glue") + try: + response = glue.get_database(Name=datebase_name) + except glue.exceptions.EntityNotFoundException: + print(f"Database not found {datebase_name}.") + return (None, None) + except Exception as e: + # todo catch all + print(f"Error: {e}") + return (None, None) + + return (response["Database"]["CatalogId"], response["Database"].get("LocationUri", "None")) + + +def get_glue_tables(database_name: str) -> List[str]: + """Get the list of tables in a Glue database. + + Args: + database_name (str): glue database to request. + + Returns: + List[string]: List of table names + """ + + glue = boto3.client("glue") + + # Set the paginator + paginator = glue.get_paginator("get_tables") + + # Initialize an empty list to store the table names + table_names = [] + try: + # Paginate through the tables + for page in paginator.paginate(DatabaseName=database_name, PaginationConfig={"PageSize": 100}): + # Add the tables from the current page to the list + table_names.extend([table["Name"] for table in page["TableList"] if "Name" in table]) + except glue.exceptions.EntityNotFoundException: + print(f"Database {database_name} not found.") + return [] + except Exception as e: + # todo catch all + print(f"Error: {e}") + return [] + + return table_names + + +def get_glue_table_schema(database_name: str, table_name: str): + """Get the schema of a Glue table. + + Args: + database_name (str): Glue database name. + table_name (str): Glue table name. 
+ + Returns: + dict: Table schema + """ + + glue = boto3.client("glue") + + # Get the table schema + try: + response = glue.get_table(DatabaseName=database_name, Name=table_name) + except glue.exceptions.EntityNotFoundException: + print(f"Table {table_name} not found in database {database_name}.") + return {} + except Exception as e: + # todo catch all + print(f"Error: {e}") + return {} + + table_schema = response["Table"]["StorageDescriptor"]["Columns"] + + # when using hive partition keys, the schema is stored in the PartitionKeys field + if response["Table"].get("PartitionKeys") is not None: + for pk in response["Table"]["PartitionKeys"]: + table_schema.append( + { + "Name": pk["Name"], + "Type": pk["Type"], + "Hive": True, + "Comment": "Partition Key", + } + ) + + return table_schema + + +def import_glue(data_contract_specification: DataContractSpecification, source: str): + """Import the schema of a Glue database.""" + + catalogid, location_uri = get_glue_database(source) + + # something went wrong + if catalogid is None: + return data_contract_specification + + tables = get_glue_tables(source) + + data_contract_specification.servers = { + "production": Server(type="glue", account=catalogid, database=source, location=location_uri), + } + + for table_name in tables: + if data_contract_specification.models is None: + data_contract_specification.models = {} + + table_schema = get_glue_table_schema(source, table_name) + + fields = {} + for column in table_schema: + field = Field() + field.type = map_type_from_sql(column["Type"]) + + # hive partitons are required, but are not primary keys + if column.get("Hive"): + field.required = True + + field.description = column.get("Comment") + + fields[column["Name"]] = field + + data_contract_specification.models[table_name] = Model( + type="table", + fields=fields, + ) + + return data_contract_specification + + +def map_type_from_sql(sql_type: str): + if sql_type is None: + return None + + if 
sql_type.lower().startswith("varchar"): + return "varchar" + if sql_type.lower().startswith("string"): + return "string" + if sql_type.lower().startswith("text"): + return "text" + elif sql_type.lower().startswith("byte"): + return "byte" + elif sql_type.lower().startswith("short"): + return "short" + elif sql_type.lower().startswith("integer"): + return "integer" + elif sql_type.lower().startswith("long"): + return "long" + elif sql_type.lower().startswith("bigint"): + return "long" + elif sql_type.lower().startswith("float"): + return "float" + elif sql_type.lower().startswith("double"): + return "double" + elif sql_type.lower().startswith("boolean"): + return "boolean" + elif sql_type.lower().startswith("timestamp"): + return "timestamp" + elif sql_type.lower().startswith("date"): + return "date" + else: + return "variant" diff --git a/pyproject.toml b/pyproject.toml index abd5e2e8..c809173b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,7 +38,9 @@ dependencies = [ "avro==1.11.3", "opentelemetry-exporter-otlp-proto-grpc~=1.16.0", "opentelemetry-exporter-otlp-proto-http~=1.16.0", - "deltalake~=0.17.0" + "deltalake~=0.17.0", + "boto3<1.34.70,>=1.34.41", + "botocore<1.34.70,>=1.34.41" ] [project.optional-dependencies] @@ -47,6 +49,7 @@ dev = [ "ruff", "pytest", "pytest-xdist", + "moto", # testcontainers 4.x have issues with Kafka on arm # https://github.com/testcontainers/testcontainers-python/issues/450 "testcontainers<4.0.0", diff --git a/tests/fixtures/glue/datacontract.yaml b/tests/fixtures/glue/datacontract.yaml new file mode 100644 index 00000000..5696fd49 --- /dev/null +++ b/tests/fixtures/glue/datacontract.yaml @@ -0,0 +1,25 @@ +dataContractSpecification: 0.9.3 +id: my-data-contract-id +info: + title: My Data Contract + version: 0.0.1 +servers: + production: + account: '123456789012' + database: test_database + location: s3://test_bucket/testdb + type: glue +models: + test_table: + type: table + fields: + field_one: + type: string + field_two: + 
type: integer + field_three: + type: timestamp + part_one: + description: Partition Key + required: True + type: string diff --git a/tests/test_import_glue.py b/tests/test_import_glue.py new file mode 100644 index 00000000..0420a904 --- /dev/null +++ b/tests/test_import_glue.py @@ -0,0 +1,98 @@ +import boto3 +from typer.testing import CliRunner +import logging +import yaml +from moto import mock_aws +import pytest + +from datacontract.cli import app +from datacontract.data_contract import DataContract + +logging.basicConfig(level=logging.INFO, force=True) + +db_name = "test_database" +table_name = "test_table" + + +@pytest.fixture(scope="function") +def aws_credentials(monkeypatch): + """Mocked AWS Credentials for moto.""" + monkeypatch.setenv("AWS_ACCESS_KEY_ID", "testing") + monkeypatch.setenv("AWS_ACCESS_KEY_ID", "testing") + monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "testing") + monkeypatch.setenv("AWS_SECURITY_TOKEN", "testing") + monkeypatch.setenv("AWS_SESSION_TOKEN", "testing") + monkeypatch.setenv("AWS_DEFAULT_REGION", "us-east-1") + + +@pytest.fixture(scope="function") +def setup_mock_glue(aws_credentials): + with mock_aws(): + client = boto3.client("glue") + + client.create_database( + DatabaseInput={ + "Name": db_name, + "LocationUri": "s3://test_bucket/testdb", + }, + ) + + client.create_table( + DatabaseName=db_name, + TableInput={ + "Name": table_name, + "StorageDescriptor": { + "Columns": [ + { + "Name": "field_one", + "Type": "string", + }, + { + "Name": "field_two", + "Type": "integer", + }, + { + "Name": "field_three", + "Type": "timestamp", + }, + ] + }, + "PartitionKeys": [ + { + "Name": "part_one", + "Type": "string", + }, + ], + }, + ) + # everything after the yield will run after the fixture is used + yield client + + +@mock_aws +def test_cli(setup_mock_glue): + runner = CliRunner() + result = runner.invoke( + app, + [ + "import", + "--format", + "glue", + "--source", + "test_database", + ], + ) + assert result.exit_code == 0 + + 
@mock_aws
def test_import_glue_schema(setup_mock_glue):
    """The imported Glue schema must round-trip to the expected contract fixture."""
    imported = DataContract().import_from_source("glue", "test_database")
    actual_yaml = imported.to_yaml()

    with open("fixtures/glue/datacontract.yaml") as expected_file:
        expected = expected_file.read()

    print("Result", actual_yaml)
    assert yaml.safe_load(actual_yaml) == yaml.safe_load(expected)
    # Disable linters so we don't get "missing description" warnings
    assert DataContract(data_contract_str=expected).lint(enabled_linters=set()).has_passed()