Import from Glue DataCatalog feature (datacontract#166)
* Import from Glue Datacatalog feature

* Updates for import command

* Move moto test to fixtures

* Adding Server object in the glue import function

* Set mock environment with monkeypatch

* remove unused import os
jverhoeks authored May 2, 2024
1 parent 0610d1a commit e9231f3
Showing 7 changed files with 326 additions and 12 deletions.
19 changes: 10 additions & 9 deletions README.md
@@ -647,15 +647,15 @@ data products, find the true domain owner of a field attribute)
### import

```
Usage: datacontract import [OPTIONS]
Create a data contract from the given source file. Prints to stdout.
╭─ Options ──────────────────────────────────────────────────────────────────────────────────╮
│ *  --format    [sql|avro]  The format of the source file. [default: None] [required]       │
│ *  --source    TEXT        The path to the file that should be imported. [default: None]   │
│                            [required]                                                      │
│    --help                  Show this message and exit.                                     │
╰─────────────────────────────────────────────────────────────────────────────────────────────╯
Usage: datacontract import [OPTIONS]
Create a data contract from the given source location. Prints to stdout.
╭─ Options ───────────────────────────────────────────────────────────────────────────────────────╮
│ *  --format    [sql|avro|glue]  The format of the source file. [default: None] [required]       │
│ *  --source    TEXT             The path to the file or Glue Database that should be imported.  │
│                                 [default: None] [required]                                      │
│    --help                       Show this message and exit.                                     │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
```
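
For the new `glue` format, the value passed to `--source` is the name of a Glue database rather than a path to a file. A minimal invocation (the database name mirrors the one used in the tests of this commit):

```
datacontract import --format glue --source test_database
```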

Example:
@@ -670,6 +670,7 @@ Available import options:
|--------------------|------------------------------------------------|---------|
| `sql`              | Import from SQL DDL                            | ✅      |
| `avro`             | Import from AVRO schemas                       | ✅      |
| `glue`             | Import from AWS Glue DataCatalog               | ✅      |
| `protobuf` | Import from Protobuf schemas | TBD |
| `jsonschema` | Import from JSON Schemas | TBD |
| `bigquery` | Import from BigQuery Schemas | TBD |
5 changes: 3 additions & 2 deletions datacontract/cli.py
@@ -204,15 +204,16 @@ def export(
class ImportFormat(str, Enum):
    sql = "sql"
    avro = "avro"
    glue = "glue"


@app.command(name="import")
def import_(
    format: Annotated[ImportFormat, typer.Option(help="The format of the source file.")],
    source: Annotated[str, typer.Option(help="The path to the file that should be imported.")],
    source: Annotated[str, typer.Option(help="The path to the file or Glue Database that should be imported.")],
):
    """
    Create a data contract from the given source file. Prints to stdout.
    Create a data contract from the given source location. Prints to stdout.
    """
    result = DataContract().import_from_source(format, source)
    console.print(result.to_yaml())
3 changes: 3 additions & 0 deletions datacontract/data_contract.py
@@ -27,6 +27,7 @@
from datacontract.export.terraform_converter import to_terraform
from datacontract.imports.avro_importer import import_avro
from datacontract.imports.sql_importer import import_sql
from datacontract.imports.glue_importer import import_glue
from datacontract.integration.publish_datamesh_manager import publish_datamesh_manager
from datacontract.integration.publish_opentelemetry import publish_opentelemetry
from datacontract.lint import resolve
@@ -476,6 +477,8 @@ def import_from_source(self, format: str, source: str) -> DataContractSpecificat
            data_contract_specification = import_sql(data_contract_specification, format, source)
        elif format == "avro":
            data_contract_specification = import_avro(data_contract_specification, source)
        elif format == "glue":
            data_contract_specification = import_glue(data_contract_specification, source)
        else:
            print(f"Import format {format} not supported.")

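With this dispatch in place, a Glue import can also be run programmatically, not just via the CLI. A minimal sketch, assuming valid AWS credentials in the environment and an existing Glue database; the name `my_database` is hypothetical:

```python
from datacontract.data_contract import DataContract

# Import every table of the Glue database "my_database" (hypothetical name)
# into a fresh data contract and print it as YAML, as the CLI does.
data_contract = DataContract().import_from_source("glue", "my_database")
print(data_contract.to_yaml())
```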
183 changes: 183 additions & 0 deletions datacontract/imports/glue_importer.py
@@ -0,0 +1,183 @@
import boto3
from typing import List

from datacontract.model.data_contract_specification import (
DataContractSpecification,
Model,
Field,
Server,
)


def get_glue_database(database_name: str):
    """Get the details of a Glue database.
    Args:
        database_name (str): Glue database to request.
    Returns:
        tuple: (CatalogId, LocationUri)
    """

    glue = boto3.client("glue")
    try:
        response = glue.get_database(Name=database_name)
    except glue.exceptions.EntityNotFoundException:
        print(f"Database {database_name} not found.")
        return (None, None)
    except Exception as e:
        # TODO: narrow this blanket exception handling
        print(f"Error: {e}")
        return (None, None)

    return (response["Database"]["CatalogId"], response["Database"].get("LocationUri", "None"))


def get_glue_tables(database_name: str) -> List[str]:
    """Get the list of tables in a Glue database.
    Args:
        database_name (str): Glue database to request.
    Returns:
        List[str]: List of table names
    """

    glue = boto3.client("glue")

    # Set the paginator
    paginator = glue.get_paginator("get_tables")

    # Initialize an empty list to store the table names
    table_names = []
    try:
        # Paginate through the tables
        for page in paginator.paginate(DatabaseName=database_name, PaginationConfig={"PageSize": 100}):
            # Add the tables from the current page to the list
            table_names.extend([table["Name"] for table in page["TableList"] if "Name" in table])
    except glue.exceptions.EntityNotFoundException:
        print(f"Database {database_name} not found.")
        return []
    except Exception as e:
        # TODO: narrow this blanket exception handling
        print(f"Error: {e}")
        return []

    return table_names


def get_glue_table_schema(database_name: str, table_name: str):
    """Get the schema of a Glue table.
    Args:
        database_name (str): Glue database name.
        table_name (str): Glue table name.
    Returns:
        list: List of column definitions (Name, Type, optional Comment)
    """

    glue = boto3.client("glue")

    # Get the table schema
    try:
        response = glue.get_table(DatabaseName=database_name, Name=table_name)
    except glue.exceptions.EntityNotFoundException:
        print(f"Table {table_name} not found in database {database_name}.")
        return []
    except Exception as e:
        # TODO: narrow this blanket exception handling
        print(f"Error: {e}")
        return []

    table_schema = response["Table"]["StorageDescriptor"]["Columns"]

    # when using Hive partition keys, the schema is stored in the PartitionKeys field
    if response["Table"].get("PartitionKeys") is not None:
        for pk in response["Table"]["PartitionKeys"]:
            table_schema.append(
                {
                    "Name": pk["Name"],
                    "Type": pk["Type"],
                    "Hive": True,
                    "Comment": "Partition Key",
                }
            )

    return table_schema


def import_glue(data_contract_specification: DataContractSpecification, source: str):
    """Import the schema of a Glue database."""

    catalogid, location_uri = get_glue_database(source)

    # something went wrong fetching the database details
    if catalogid is None:
        return data_contract_specification

    tables = get_glue_tables(source)

    data_contract_specification.servers = {
        "production": Server(type="glue", account=catalogid, database=source, location=location_uri),
    }

    for table_name in tables:
        if data_contract_specification.models is None:
            data_contract_specification.models = {}

        table_schema = get_glue_table_schema(source, table_name)

        fields = {}
        for column in table_schema:
            field = Field()
            field.type = map_type_from_sql(column["Type"])

            # Hive partitions are required, but are not primary keys
            if column.get("Hive"):
                field.required = True

            field.description = column.get("Comment")

            fields[column["Name"]] = field

        data_contract_specification.models[table_name] = Model(
            type="table",
            fields=fields,
        )

    return data_contract_specification


def map_type_from_sql(sql_type: str):
    if sql_type is None:
        return None

    sql_type = sql_type.lower()
    if sql_type.startswith("varchar"):
        return "varchar"
    elif sql_type.startswith("string"):
        return "string"
    elif sql_type.startswith("text"):
        return "text"
    elif sql_type.startswith("byte"):
        return "byte"
    elif sql_type.startswith("short"):
        return "short"
    elif sql_type.startswith("integer"):
        return "integer"
    elif sql_type.startswith("long"):
        return "long"
    elif sql_type.startswith("bigint"):
        return "long"
    elif sql_type.startswith("float"):
        return "float"
    elif sql_type.startswith("double"):
        return "double"
    elif sql_type.startswith("boolean"):
        return "boolean"
    elif sql_type.startswith("timestamp"):
        return "timestamp"
    elif sql_type.startswith("date"):
        return "date"
    else:
        return "variant"
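A short sketch of how `map_type_from_sql` resolves a few representative Glue types: prefix matching handles parameterized types, and anything unmatched falls back to `variant`:

```python
from datacontract.imports.glue_importer import map_type_from_sql

assert map_type_from_sql("varchar(255)") == "varchar"   # prefix match drops the length
assert map_type_from_sql("bigint") == "long"            # bigint is normalized to long
assert map_type_from_sql("decimal(10,2)") == "variant"  # unmapped types fall back to variant
```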
5 changes: 4 additions & 1 deletion pyproject.toml
@@ -38,7 +38,9 @@ dependencies = [
"avro==1.11.3",
"opentelemetry-exporter-otlp-proto-grpc~=1.16.0",
"opentelemetry-exporter-otlp-proto-http~=1.16.0",
"deltalake~=0.17.0"
"deltalake~=0.17.0",
"boto3<1.34.70,>=1.34.41",
"botocore<1.34.70,>=1.34.41"
]

[project.optional-dependencies]
@@ -47,6 +49,7 @@ dev = [
"ruff",
"pytest",
"pytest-xdist",
"moto",
# testcontainers 4.x have issues with Kafka on arm
# https://github.com/testcontainers/testcontainers-python/issues/450
"testcontainers<4.0.0",
25 changes: 25 additions & 0 deletions tests/fixtures/glue/datacontract.yaml
@@ -0,0 +1,25 @@
dataContractSpecification: 0.9.3
id: my-data-contract-id
info:
  title: My Data Contract
  version: 0.0.1
servers:
  production:
    account: '123456789012'
    database: test_database
    location: s3://test_bucket/testdb
    type: glue
models:
  test_table:
    type: table
    fields:
      field_one:
        type: string
      field_two:
        type: integer
      field_three:
        type: timestamp
      part_one:
        description: Partition Key
        required: True
        type: string
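
The fixture mirrors the mocked Glue table created in the tests below: three storage-descriptor columns plus the `part_one` partition key, which the importer marks as required. A quick sanity check with PyYAML (path assumed relative to the repository root):

```python
import yaml

with open("tests/fixtures/glue/datacontract.yaml") as f:
    contract = yaml.safe_load(f)

# The Hive partition key becomes a required field with a "Partition Key" description.
part_one = contract["models"]["test_table"]["fields"]["part_one"]
assert part_one["required"] is True
assert part_one["description"] == "Partition Key"
```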
98 changes: 98 additions & 0 deletions tests/test_import_glue.py
@@ -0,0 +1,98 @@
import boto3
from typer.testing import CliRunner
import logging
import yaml
from moto import mock_aws
import pytest

from datacontract.cli import app
from datacontract.data_contract import DataContract

logging.basicConfig(level=logging.INFO, force=True)

db_name = "test_database"
table_name = "test_table"


@pytest.fixture(scope="function")
def aws_credentials(monkeypatch):
    """Mocked AWS Credentials for moto."""
    monkeypatch.setenv("AWS_ACCESS_KEY_ID", "testing")
    monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "testing")
    monkeypatch.setenv("AWS_SECURITY_TOKEN", "testing")
    monkeypatch.setenv("AWS_SESSION_TOKEN", "testing")
    monkeypatch.setenv("AWS_DEFAULT_REGION", "us-east-1")


@pytest.fixture(scope="function")
def setup_mock_glue(aws_credentials):
    with mock_aws():
        client = boto3.client("glue")

        client.create_database(
            DatabaseInput={
                "Name": db_name,
                "LocationUri": "s3://test_bucket/testdb",
            },
        )

        client.create_table(
            DatabaseName=db_name,
            TableInput={
                "Name": table_name,
                "StorageDescriptor": {
                    "Columns": [
                        {
                            "Name": "field_one",
                            "Type": "string",
                        },
                        {
                            "Name": "field_two",
                            "Type": "integer",
                        },
                        {
                            "Name": "field_three",
                            "Type": "timestamp",
                        },
                    ]
                },
                "PartitionKeys": [
                    {
                        "Name": "part_one",
                        "Type": "string",
                    },
                ],
            },
        )
        # everything after the yield will run after the fixture is used
        yield client


@mock_aws
def test_cli(setup_mock_glue):
    runner = CliRunner()
    result = runner.invoke(
        app,
        [
            "import",
            "--format",
            "glue",
            "--source",
            "test_database",
        ],
    )
    assert result.exit_code == 0


@mock_aws
def test_import_glue_schema(setup_mock_glue):
    result = DataContract().import_from_source("glue", "test_database")

    with open("fixtures/glue/datacontract.yaml") as file:
        expected = file.read()

    print("Result", result.to_yaml())
    assert yaml.safe_load(result.to_yaml()) == yaml.safe_load(expected)
    # Disable linters so we don't get "missing description" warnings
    assert DataContract(data_contract_str=expected).lint(enabled_linters=set()).has_passed()
