Skip to content

Commit

Permalink
fix: cross account catalog_id glue client function calls (#370)
Browse files Browse the repository at this point in the history
Co-authored-by: nicor88 <[email protected]>
  • Loading branch information
brunofaustino and nicor88 authored Sep 7, 2023
1 parent 23478ef commit 905746f
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 10 deletions.
57 changes: 47 additions & 10 deletions dbt/adapters/athena/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,11 +223,15 @@ def get_glue_table(self, relation: AthenaRelation) -> Optional[GetTableResponseT
"""
conn = self.connections.get_thread_connection()
client = conn.handle

data_catalog = self._get_data_catalog(relation.database)
catalog_id = get_catalog_id(data_catalog)

with boto3_client_lock:
glue_client = client.session.client("glue", region_name=client.region_name, config=get_boto3_config())

try:
table = glue_client.get_table(DatabaseName=relation.schema, Name=relation.identifier)
table = glue_client.get_table(CatalogId=catalog_id, DatabaseName=relation.schema, Name=relation.identifier)
except ClientError as e:
if e.response["Error"]["Code"] == "EntityNotFoundException":
LOGGER.debug(f"Table {relation.render()} does not exists - Ignoring")
Expand Down Expand Up @@ -596,16 +600,25 @@ def swap_table(self, src_relation: AthenaRelation, target_relation: AthenaRelati
conn = self.connections.get_thread_connection()
client = conn.handle

data_catalog = self._get_data_catalog(src_relation.database)
src_catalog_id = get_catalog_id(data_catalog)

with boto3_client_lock:
glue_client = client.session.client("glue", region_name=client.region_name, config=get_boto3_config())

src_table = glue_client.get_table(DatabaseName=src_relation.schema, Name=src_relation.identifier).get("Table")
src_table = glue_client.get_table(
CatalogId=src_catalog_id, DatabaseName=src_relation.schema, Name=src_relation.identifier
).get("Table")

src_table_partitions = glue_client.get_partitions(
DatabaseName=src_relation.schema, TableName=src_relation.identifier
CatalogId=src_catalog_id, DatabaseName=src_relation.schema, TableName=src_relation.identifier
).get("Partitions")

data_catalog = self._get_data_catalog(src_relation.database)
target_catalog_id = get_catalog_id(data_catalog)

target_table_partitions = glue_client.get_partitions(
DatabaseName=target_relation.schema, TableName=target_relation.identifier
CatalogId=target_catalog_id, DatabaseName=target_relation.schema, TableName=target_relation.identifier
).get("Partitions")

target_table_version = {
Expand All @@ -618,7 +631,9 @@ def swap_table(self, src_relation: AthenaRelation, target_relation: AthenaRelati
}

# perform a table swap
glue_client.update_table(DatabaseName=target_relation.schema, TableInput=target_table_version)
glue_client.update_table(
CatalogId=target_catalog_id, DatabaseName=target_relation.schema, TableInput=target_table_version
)
LOGGER.debug(f"Table {target_relation.render()} swapped with the content of {src_relation.render()}")

# we delete the target table partitions in any case
Expand All @@ -627,6 +642,7 @@ def swap_table(self, src_relation: AthenaRelation, target_relation: AthenaRelati
if target_table_partitions:
for partition_batch in get_chunks(target_table_partitions, AthenaAdapter.BATCH_DELETE_PARTITION_API_LIMIT):
glue_client.batch_delete_partition(
CatalogId=target_catalog_id,
DatabaseName=target_relation.schema,
TableName=target_relation.identifier,
PartitionsToDelete=[{"Values": partition["Values"]} for partition in partition_batch],
Expand All @@ -635,6 +651,7 @@ def swap_table(self, src_relation: AthenaRelation, target_relation: AthenaRelati
if src_table_partitions:
for partition_batch in get_chunks(src_table_partitions, AthenaAdapter.BATCH_CREATE_PARTITION_API_LIMIT):
glue_client.batch_create_partition(
CatalogId=target_catalog_id,
DatabaseName=target_relation.schema,
TableName=target_relation.identifier,
PartitionInputList=[
Expand Down Expand Up @@ -676,6 +693,9 @@ def expire_glue_table_versions(
conn = self.connections.get_thread_connection()
client = conn.handle

data_catalog = self._get_data_catalog(relation.database)
catalog_id = get_catalog_id(data_catalog)

with boto3_client_lock:
glue_client = client.session.client("glue", region_name=client.region_name, config=get_boto3_config())

Expand All @@ -688,7 +708,10 @@ def expire_glue_table_versions(
location = v["Table"]["StorageDescriptor"]["Location"]
try:
glue_client.delete_table_version(
DatabaseName=relation.schema, TableName=relation.identifier, VersionId=str(version)
CatalogId=catalog_id,
DatabaseName=relation.schema,
TableName=relation.identifier,
VersionId=str(version),
)
deleted_versions.append(version)
LOGGER.debug(f"Deleted version {version} of table {relation.render()} ")
Expand Down Expand Up @@ -720,13 +743,16 @@ def persist_docs_to_glue(
conn = self.connections.get_thread_connection()
client = conn.handle

data_catalog = self._get_data_catalog(relation.database)
catalog_id = get_catalog_id(data_catalog)

with boto3_client_lock:
glue_client = client.session.client("glue", region_name=client.region_name, config=get_boto3_config())

# By default, there is no need to update Glue Table
need_udpate_table = False
# Get Table from Glue
table = glue_client.get_table(DatabaseName=relation.schema, Name=relation.name)["Table"]
table = glue_client.get_table(CatalogId=catalog_id, DatabaseName=relation.schema, Name=relation.name)["Table"]
# Prepare new version of Glue Table picking up significant fields
updated_table = self._get_table_input(table)
# Update table description
Expand Down Expand Up @@ -766,7 +792,10 @@ def persist_docs_to_glue(
# It prevents redundant schema version creating after incremental runs.
if need_udpate_table:
glue_client.update_table(
DatabaseName=relation.schema, TableInput=updated_table, SkipArchive=skip_archive_table_version
CatalogId=catalog_id,
DatabaseName=relation.schema,
TableInput=updated_table,
SkipArchive=skip_archive_table_version,
)

@available
Expand Down Expand Up @@ -797,11 +826,16 @@ def get_columns_in_relation(self, relation: AthenaRelation) -> List[AthenaColumn
conn = self.connections.get_thread_connection()
client = conn.handle

data_catalog = self._get_data_catalog(relation.database)
catalog_id = get_catalog_id(data_catalog)

with boto3_client_lock:
glue_client = client.session.client("glue", region_name=client.region_name, config=get_boto3_config())

try:
table = glue_client.get_table(DatabaseName=relation.schema, Name=relation.identifier)["Table"]
table = glue_client.get_table(CatalogId=catalog_id, DatabaseName=relation.schema, Name=relation.identifier)[
"Table"
]
except ClientError as e:
if e.response["Error"]["Code"] == "EntityNotFoundException":
LOGGER.debug("table not exist, catching the error")
Expand Down Expand Up @@ -829,11 +863,14 @@ def delete_from_glue_catalog(self, relation: AthenaRelation) -> None:
conn = self.connections.get_thread_connection()
client = conn.handle

data_catalog = self._get_data_catalog(relation.database)
catalog_id = get_catalog_id(data_catalog)

with boto3_client_lock:
glue_client = client.session.client("glue", region_name=client.region_name, config=get_boto3_config())

try:
glue_client.delete_table(DatabaseName=schema_name, Name=table_name)
glue_client.delete_table(CatalogId=catalog_id, DatabaseName=schema_name, Name=table_name)
LOGGER.debug(f"Deleted table from glue catalog: {relation.render()}")
except ClientError as e:
if e.response["Error"]["Code"] == "EntityNotFoundException":
Expand Down
22 changes: 22 additions & 0 deletions tests/unit/test_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,7 @@ def test_generate_s3_location(
@mock_glue
@mock_s3
@mock_athena
@mock_sts
def test_get_table_location(self, dbt_debug_caplog, mock_aws_service):
table_name = "test_table"
self.adapter.acquire_connection("dummy")
Expand All @@ -417,6 +418,7 @@ def test_get_table_location(self, dbt_debug_caplog, mock_aws_service):
@mock_glue
@mock_s3
@mock_athena
@mock_sts
def test_get_table_location_raise_s3_location_exception(self, dbt_debug_caplog, mock_aws_service):
table_name = "test_table"
self.adapter.acquire_connection("dummy")
Expand All @@ -438,6 +440,7 @@ def test_get_table_location_raise_s3_location_exception(self, dbt_debug_caplog,
@mock_glue
@mock_s3
@mock_athena
@mock_sts
def test_get_table_location_for_view(self, dbt_debug_caplog, mock_aws_service):
view_name = "view"
self.adapter.acquire_connection("dummy")
Expand All @@ -452,6 +455,7 @@ def test_get_table_location_for_view(self, dbt_debug_caplog, mock_aws_service):
@mock_glue
@mock_s3
@mock_athena
@mock_sts
def test_get_table_location_with_failure(self, dbt_debug_caplog, mock_aws_service):
table_name = "test_table"
self.adapter.acquire_connection("dummy")
Expand Down Expand Up @@ -500,6 +504,7 @@ def test_clean_up_partitions_will_work(self, dbt_debug_caplog, mock_aws_service)

@mock_glue
@mock_athena
@mock_sts
def test_clean_up_table_table_does_not_exist(self, dbt_debug_caplog, mock_aws_service):
mock_aws_service.create_data_catalog()
mock_aws_service.create_database()
Expand All @@ -517,6 +522,7 @@ def test_clean_up_table_table_does_not_exist(self, dbt_debug_caplog, mock_aws_se

@mock_glue
@mock_athena
@mock_sts
def test_clean_up_table_view(self, dbt_debug_caplog, mock_aws_service):
mock_aws_service.create_data_catalog()
mock_aws_service.create_database()
Expand All @@ -534,6 +540,7 @@ def test_clean_up_table_view(self, dbt_debug_caplog, mock_aws_service):
@mock_glue
@mock_s3
@mock_athena
@mock_sts
def test_clean_up_table_delete_table(self, dbt_debug_caplog, mock_aws_service):
mock_aws_service.create_data_catalog()
mock_aws_service.create_database()
Expand Down Expand Up @@ -844,6 +851,7 @@ def test_parse_s3_path(self, s3_path, expected):
@mock_athena
@mock_glue
@mock_s3
@mock_sts
def test_swap_table_with_partitions(self, mock_aws_service):
mock_aws_service.create_data_catalog()
mock_aws_service.create_database()
Expand All @@ -870,6 +878,7 @@ def test_swap_table_with_partitions(self, mock_aws_service):
@mock_athena
@mock_glue
@mock_s3
@mock_sts
def test_swap_table_without_partitions(self, mock_aws_service):
mock_aws_service.create_data_catalog()
mock_aws_service.create_database()
Expand All @@ -894,6 +903,7 @@ def test_swap_table_without_partitions(self, mock_aws_service):
@mock_athena
@mock_glue
@mock_s3
@mock_sts
def test_swap_table_with_partitions_to_one_without(self, mock_aws_service):
mock_aws_service.create_data_catalog()
mock_aws_service.create_database()
Expand Down Expand Up @@ -931,6 +941,7 @@ def test_swap_table_with_partitions_to_one_without(self, mock_aws_service):
@mock_athena
@mock_glue
@mock_s3
@mock_sts
def test_swap_table_with_no_partitions_to_one_with(self, mock_aws_service):
mock_aws_service.create_data_catalog()
mock_aws_service.create_database()
Expand Down Expand Up @@ -990,6 +1001,7 @@ def test__get_glue_table_versions_to_expire(self, mock_aws_service, dbt_debug_ca
@mock_athena
@mock_glue
@mock_s3
@mock_sts
def test_expire_glue_table_versions(self, mock_aws_service):
mock_aws_service.create_data_catalog()
mock_aws_service.create_database()
Expand Down Expand Up @@ -1101,6 +1113,7 @@ def test_get_work_group_output_location_not_enforced(self, mock_aws_service):
@mock_athena
@mock_glue
@mock_s3
@mock_sts
def test_persist_docs_to_glue_no_comment(self, mock_aws_service):
mock_aws_service.create_data_catalog()
mock_aws_service.create_database()
Expand Down Expand Up @@ -1142,6 +1155,7 @@ def test_persist_docs_to_glue_no_comment(self, mock_aws_service):
@mock_athena
@mock_glue
@mock_s3
@mock_sts
def test_persist_docs_to_glue_comment(self, mock_aws_service):
mock_aws_service.create_data_catalog()
mock_aws_service.create_database()
Expand Down Expand Up @@ -1194,6 +1208,7 @@ def test_list_schemas(self, mock_aws_service):

@mock_athena
@mock_glue
@mock_sts
def test_get_columns_in_relation(self, mock_aws_service):
mock_aws_service.create_data_catalog()
mock_aws_service.create_database()
Expand All @@ -1214,6 +1229,7 @@ def test_get_columns_in_relation(self, mock_aws_service):

@mock_athena
@mock_glue
@mock_sts
def test_get_columns_in_relation_not_found_table(self, mock_aws_service):
mock_aws_service.create_data_catalog()
mock_aws_service.create_database()
Expand All @@ -1229,6 +1245,7 @@ def test_get_columns_in_relation_not_found_table(self, mock_aws_service):

@mock_athena
@mock_glue
@mock_sts
def test_delete_from_glue_catalog(self, mock_aws_service):
mock_aws_service.create_data_catalog()
mock_aws_service.create_database()
Expand All @@ -1242,6 +1259,7 @@ def test_delete_from_glue_catalog(self, mock_aws_service):

@mock_athena
@mock_glue
@mock_sts
def test_delete_from_glue_catalog_not_found_table(self, dbt_debug_caplog, mock_aws_service):
mock_aws_service.create_data_catalog()
mock_aws_service.create_database()
Expand All @@ -1258,6 +1276,7 @@ def test_delete_from_glue_catalog_not_found_table(self, dbt_debug_caplog, mock_a
@mock_glue
@mock_s3
@mock_athena
@mock_sts
def test__get_relation_type_table(self, dbt_debug_caplog, mock_aws_service):
mock_aws_service.create_data_catalog()
mock_aws_service.create_database()
Expand All @@ -1272,6 +1291,7 @@ def test__get_relation_type_table(self, dbt_debug_caplog, mock_aws_service):
@mock_glue
@mock_s3
@mock_athena
@mock_sts
def test__get_relation_type_with_no_type(self, dbt_debug_caplog, mock_aws_service):
mock_aws_service.create_data_catalog()
mock_aws_service.create_database()
Expand All @@ -1286,6 +1306,7 @@ def test__get_relation_type_with_no_type(self, dbt_debug_caplog, mock_aws_servic
@mock_glue
@mock_s3
@mock_athena
@mock_sts
def test__get_relation_type_view(self, dbt_debug_caplog, mock_aws_service):
mock_aws_service.create_data_catalog()
mock_aws_service.create_database()
Expand All @@ -1300,6 +1321,7 @@ def test__get_relation_type_view(self, dbt_debug_caplog, mock_aws_service):
@mock_glue
@mock_s3
@mock_athena
@mock_sts
def test__get_relation_type_iceberg(self, dbt_debug_caplog, mock_aws_service):
mock_aws_service.create_data_catalog()
mock_aws_service.create_database()
Expand Down

0 comments on commit 905746f

Please sign in to comment.