forked from datacontract/datacontract-cli
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Import from Glue DataCatalog feature (datacontract#166)
* Import from Glue Datacatalog feature * Updates for import command * Move moto test to fixtures * Adding Server object in the glue import function * Set mock environment with monkeypatch * remove unused import os
- Loading branch information
Showing
7 changed files
with
326 additions
and
12 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,183 @@ | ||
import boto3 | ||
from typing import List | ||
|
||
from datacontract.model.data_contract_specification import ( | ||
DataContractSpecification, | ||
Model, | ||
Field, | ||
Server, | ||
) | ||
|
||
|
||
def get_glue_database(datebase_name: str):
    """Look up a Glue database and return its catalog id and location.

    Args:
        datebase_name (str): Name of the Glue database to request. The
            parameter keeps its historical spelling so existing keyword
            callers continue to work.

    Returns:
        tuple: ``(catalog_id, location_uri)``. ``location_uri`` is ``None``
        when the database has no ``LocationUri``. Both elements are ``None``
        when the database is not found or the request fails.
    """
    glue = boto3.client("glue")
    try:
        response = glue.get_database(Name=datebase_name)
    except glue.exceptions.EntityNotFoundException:
        print(f"Database not found {datebase_name}.")
        return (None, None)
    except Exception as e:
        # Best-effort lookup: report any other failure and signal it to the
        # caller through the (None, None) sentinel instead of raising.
        print(f"Error: {e}")
        return (None, None)

    database = response["Database"]
    # A missing LocationUri is reported as a real None; the previous default
    # was the literal string "None", which leaked into the generated contract.
    return (database["CatalogId"], database.get("LocationUri"))
|
||
|
||
def get_glue_tables(database_name: str) -> List[str]:
    """Collect the names of all tables in a Glue database.

    Args:
        database_name (str): Glue database to list.

    Returns:
        List[str]: Table names from every result page, or an empty list when
        the database is not found or any other error occurs.
    """
    client = boto3.client("glue")
    table_paginator = client.get_paginator("get_tables")

    collected = []
    try:
        # Walk the result pages 100 tables at a time.
        pages = table_paginator.paginate(
            DatabaseName=database_name,
            PaginationConfig={"PageSize": 100},
        )
        for page in pages:
            # Entries without a "Name" key are skipped.
            names_on_page = [entry["Name"] for entry in page["TableList"] if "Name" in entry]
            collected.extend(names_on_page)
    except client.exceptions.EntityNotFoundException:
        print(f"Database {database_name} not found.")
        return []
    except Exception as e:
        # Any other failure is reported and treated as "no tables".
        print(f"Error: {e}")
        return []

    return collected
|
||
|
||
def get_glue_table_schema(database_name: str, table_name: str) -> List[dict]:
    """Get the column schema of a Glue table, including partition keys.

    Args:
        database_name (str): Glue database name.
        table_name (str): Glue table name.

    Returns:
        List[dict]: One dict per column as returned by Glue, followed by one
        dict per partition key. Partition-key entries carry ``"Hive": True``
        and the comment ``"Partition Key"``. An empty list is returned when
        the table is not found, the request fails, or the table has neither
        columns nor partition keys.
    """
    glue = boto3.client("glue")

    try:
        response = glue.get_table(DatabaseName=database_name, Name=table_name)
    except glue.exceptions.EntityNotFoundException:
        print(f"Table {table_name} not found in database {database_name}.")
        return []
    except Exception as e:
        # Best-effort: report any other failure and return an empty schema.
        print(f"Error: {e}")
        return []

    table = response["Table"]

    # StorageDescriptor and Columns are not guaranteed to be present, so fall
    # back to an empty schema instead of raising KeyError. The list is copied
    # so that appending partition keys does not mutate the response object.
    storage_descriptor = table.get("StorageDescriptor") or {}
    table_schema = list(storage_descriptor.get("Columns") or [])

    # when using hive partition keys, the schema is stored in the PartitionKeys field
    for pk in table.get("PartitionKeys") or []:
        table_schema.append(
            {
                "Name": pk["Name"],
                "Type": pk["Type"],
                "Hive": True,
                "Comment": "Partition Key",
            }
        )

    return table_schema
|
||
|
||
def import_glue(data_contract_specification: DataContractSpecification, source: str):
    """Import the schema of a Glue database into a data contract specification.

    Args:
        data_contract_specification (DataContractSpecification): Specification
            to populate; it is modified in place and also returned.
        source (str): Name of the Glue database to import.

    Returns:
        DataContractSpecification: The same specification object. When the
        database lookup fails it is returned unchanged. Otherwise its
        ``servers`` is replaced by a single ``production`` Glue server and one
        ``table`` model is added (or overwritten) per Glue table.
    """
    catalogid, location_uri = get_glue_database(source)

    # something went wrong
    if catalogid is None:
        return data_contract_specification

    tables = get_glue_tables(source)

    data_contract_specification.servers = {
        "production": Server(type="glue", account=catalogid, database=source, location=location_uri),
    }

    # Create the models mapping once, and only when there is something to add,
    # so a database without tables leaves ``models`` untouched.
    if tables and data_contract_specification.models is None:
        data_contract_specification.models = {}

    for table_name in tables:
        table_schema = get_glue_table_schema(source, table_name)

        fields = {}
        for column in table_schema:
            field = Field()
            # A column may come without a "Type"; map_type_from_sql maps a
            # missing type to None, so use .get() rather than raising KeyError.
            field.type = map_type_from_sql(column.get("Type"))

            # hive partitions are required, but are not primary keys
            if column.get("Hive"):
                field.required = True

            field.description = column.get("Comment")

            fields[column["Name"]] = field

        data_contract_specification.models[table_name] = Model(
            type="table",
            fields=fields,
        )

    return data_contract_specification
|
||
|
||
def map_type_from_sql(sql_type: str):
    """Map a SQL/Glue column type name to a data contract field type.

    Matching is case-insensitive and prefix-based, so parameterised types
    such as ``VARCHAR(255)`` resolve to their base type.

    Args:
        sql_type (str): Source type name, or ``None``.

    Returns:
        str | None: ``None`` when ``sql_type`` is ``None``; otherwise the
        mapped type name, or ``"variant"`` when no rule matches.
    """
    if sql_type is None:
        return None

    lowered = sql_type.lower()

    # Hive/Glue spell the 32-bit integer type "int". Compare the bare type
    # name (anything before an optional "(") exactly, so unrelated types that
    # merely start with "int" (e.g. "interval") are not captured.
    if lowered.split("(", 1)[0] == "int":
        return "integer"

    # (prefix, mapped type) pairs, checked in order.
    prefix_to_type = (
        ("varchar", "varchar"),
        ("string", "string"),
        ("text", "text"),
        ("byte", "byte"),
        ("short", "short"),
        ("integer", "integer"),
        ("long", "long"),
        ("bigint", "long"),
        ("float", "float"),
        ("double", "double"),
        ("boolean", "boolean"),
        ("timestamp", "timestamp"),
        ("date", "date"),
    )
    for prefix, mapped_type in prefix_to_type:
        if lowered.startswith(prefix):
            return mapped_type

    return "variant"
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
dataContractSpecification: 0.9.3
id: my-data-contract-id
info:
  title: My Data Contract
  version: 0.0.1
servers:
  production:
    account: '123456789012'
    database: test_database
    location: s3://test_bucket/testdb
    type: glue
models:
  test_table:
    type: table
    fields:
      field_one:
        type: string
      field_two:
        type: integer
      field_three:
        type: timestamp
      part_one:
        description: Partition Key
        required: True
        type: string
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
import boto3 | ||
from typer.testing import CliRunner | ||
import logging | ||
import yaml | ||
from moto import mock_aws | ||
import pytest | ||
|
||
from datacontract.cli import app | ||
from datacontract.data_contract import DataContract | ||
|
||
logging.basicConfig(level=logging.INFO, force=True) | ||
|
||
db_name = "test_database" | ||
table_name = "test_table" | ||
|
||
|
||
@pytest.fixture(scope="function")
def aws_credentials(monkeypatch):
    """Mocked AWS Credentials for moto.

    Sets placeholder credential and region environment variables through
    ``monkeypatch``.
    """
    # Each variable is set exactly once (the access key id was previously
    # set twice with the same value).
    monkeypatch.setenv("AWS_ACCESS_KEY_ID", "testing")
    monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "testing")
    monkeypatch.setenv("AWS_SECURITY_TOKEN", "testing")
    monkeypatch.setenv("AWS_SESSION_TOKEN", "testing")
    monkeypatch.setenv("AWS_DEFAULT_REGION", "us-east-1")
|
||
|
||
@pytest.fixture(scope="function")
def setup_mock_glue(aws_credentials):
    """Create a mocked Glue database with one partitioned table.

    Inside a ``mock_aws()`` context this creates the ``db_name`` database and
    the ``table_name`` table (three regular columns plus one partition key),
    then yields the Glue client used to create them.
    """
    regular_columns = [
        {"Name": "field_one", "Type": "string"},
        {"Name": "field_two", "Type": "integer"},
        {"Name": "field_three", "Type": "timestamp"},
    ]
    partition_keys = [
        {"Name": "part_one", "Type": "string"},
    ]
    database_input = {
        "Name": db_name,
        "LocationUri": "s3://test_bucket/testdb",
    }
    table_input = {
        "Name": table_name,
        "StorageDescriptor": {"Columns": regular_columns},
        "PartitionKeys": partition_keys,
    }

    with mock_aws():
        glue_client = boto3.client("glue")
        glue_client.create_database(DatabaseInput=database_input)
        glue_client.create_table(DatabaseName=db_name, TableInput=table_input)
        # The yield sits inside the `with` block; anything placed after it
        # would run once the test using this fixture has finished.
        yield glue_client
|
||
|
||
@mock_aws
def test_cli(setup_mock_glue):
    """The `import --format glue` CLI command exits with code 0 for the mocked database."""
    cli_args = ["import", "--format", "glue", "--source", "test_database"]
    outcome = CliRunner().invoke(app, cli_args)
    assert outcome.exit_code == 0
|
||
|
||
@mock_aws
def test_import_glue_schema(setup_mock_glue):
    """Importing the mocked Glue database yields the expected data contract fixture."""
    imported = DataContract().import_from_source("glue", "test_database")

    with open("fixtures/glue/datacontract.yaml") as fixture_file:
        expected_yaml = fixture_file.read()

    print("Result", imported.to_yaml())

    # Compare parsed documents so formatting and key order do not matter.
    actual_document = yaml.safe_load(imported.to_yaml())
    expected_document = yaml.safe_load(expected_yaml)
    assert actual_document == expected_document

    # Disable linters so we don't get "missing description" warnings
    lint_run = DataContract(data_contract_str=expected_yaml).lint(enabled_linters=set())
    assert lint_run.has_passed()