Commit

Merge remote-tracking branch 'origin/fix/cross-account-delete-table' into fix/cross-account-delete-table
brunofaustino committed Sep 5, 2023
2 parents c609ad2 + 128ff07 commit ac131b8
Showing 24 changed files with 703 additions and 115 deletions.
2 changes: 1 addition & 1 deletion .github/CODEOWNERS
@@ -1 +1 @@
* @jessedobbelaere @Jrmyy @mattiamatrix @nicor88 @svdimchenko @thenaturalist
* @jessedobbelaere @Jrmyy @mattiamatrix @nicor88 @svdimchenko
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
@@ -21,7 +21,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ['3.7', '3.8', '3.9', '3.10', '3.11']
python-version: ['3.8', '3.9', '3.10', '3.11']
steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
8 changes: 6 additions & 2 deletions README.md
@@ -1,15 +1,19 @@
<p align="center">
<img src="https://raw.githubusercontent.com/dbt-athena/dbt-athena/main/static/images/dbt-athena-long.png" />
<a href="https://pypi.org/project/dbt-athena-community/"><img src="https://badge.fury.io/py/dbt-athena-community.svg" /></a>
<a target="_blank" href="https://pypi.org/project/dlt/" style="background:none">
<img src="https://img.shields.io/pypi/pyversions/dbt-athena-community">
</a>
<a href="https://pycqa.github.io/isort/"><img src="https://img.shields.io/badge/%20imports-isort-%231674b1?style=flat&labelColor=ef8336" /></a>
<a href="https://github.com/psf/black"><img src="https://img.shields.io/badge/code%20style-black-000000.svg" /></a>
<a href="https://github.com/python/mypy"><img src="https://www.mypy-lang.org/static/mypy_badge.svg" /></a>
<a href="https://pepy.tech/project/dbt-athena-community"><img src="https://pepy.tech/badge/dbt-athena-community/month" /></a>
<a href="https://pepy.tech/project/dbt-athena-community"><img src="https://static.pepy.tech/badge/dbt-athena-community/month" /></a>
</p>

## Features

* Supports dbt version `1.5.*`
* Supports dbt version `1.6.*`
* Supports Python versions from `3.8` to `3.11`
* Supports [seeds][seeds]
* Correctly detects views and their columns
* Supports [table materialization][table]
2 changes: 1 addition & 1 deletion dbt/adapters/athena/__version__.py
@@ -1 +1 @@
version = "1.5.1"
version = "1.6.0"
40 changes: 35 additions & 5 deletions dbt/adapters/athena/connections.py
@@ -1,4 +1,6 @@
import hashlib
import json
import re
import time
from concurrent.futures.thread import ThreadPoolExecutor
from contextlib import contextmanager
@@ -132,6 +134,7 @@ def execute( # type: ignore
endpoint_url: Optional[str] = None,
cache_size: int = 0,
cache_expiration_time: int = 0,
catch_partitions_limit: bool = False,
**kwargs,
):
def inner() -> AthenaCursor:
@@ -158,7 +161,12 @@ def inner() -> AthenaCursor:
return self

retry = tenacity.Retrying(
retry=retry_if_exception(lambda _: True),
# No need to retry if TOO_MANY_OPEN_PARTITIONS occurs.
# Otherwise, Athena throws ICEBERG_FILESYSTEM_ERROR after retry,
# because not all files are removed immediately after first try to create table
retry=retry_if_exception(
lambda e: False if catch_partitions_limit and "TOO_MANY_OPEN_PARTITIONS" in str(e) else True
),
stop=stop_after_attempt(self._retry_config.attempt),
wait=wait_exponential(
multiplier=self._retry_config.attempt,
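To make the new retry rule concrete, here is a minimal, self-contained sketch using tenacity's public API; the attempt count, wait settings, and the `flaky_query` helper are illustrative, not taken from the adapter:

```python
import tenacity
from tenacity import retry_if_exception, stop_after_attempt, wait_exponential

def should_retry(exc: BaseException) -> bool:
    # Skip retries when Athena reports the open-partitions limit: retrying
    # would surface ICEBERG_FILESYSTEM_ERROR because files written by the
    # first attempt are not removed immediately.
    return "TOO_MANY_OPEN_PARTITIONS" not in str(exc)

retry = tenacity.Retrying(
    retry=retry_if_exception(should_retry),
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, max=10),
    reraise=True,
)

def flaky_query() -> None:
    raise RuntimeError("TOO_MANY_OPEN_PARTITIONS")

# retry(flaky_query)  # raises on the first attempt instead of retrying three times
```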
@@ -231,15 +239,37 @@ def open(cls, connection: Connection) -> Connection:
@classmethod
def get_response(cls, cursor: AthenaCursor) -> AthenaAdapterResponse:
code = "OK" if cursor.state == AthenaQueryExecution.STATE_SUCCEEDED else "ERROR"
rowcount, data_scanned_in_bytes = cls.process_query_stats(cursor)
return AthenaAdapterResponse(
_message=f"{code} {cursor.rowcount}",
rows_affected=cursor.rowcount,
_message=f"{code} {rowcount}",
rows_affected=rowcount,
code=code,
data_scanned_in_bytes=cursor.data_scanned_in_bytes,
data_scanned_in_bytes=data_scanned_in_bytes,
)

@staticmethod
def process_query_stats(cursor: AthenaCursor) -> Tuple[int, int]:
"""
Helper function to parse query statistics from SELECT statements.
The function looks for all statements that contain rowcount or data_scanned_in_bytes,
then strips the SELECT statement and picks the values between curly brackets.
"""
if all(map(cursor.query.__contains__, ["rowcount", "data_scanned_in_bytes"])):
try:
query_split = cursor.query.lower().split("select")[-1]
# query statistics are in the format {"rowcount":1, "data_scanned_in_bytes": 3}
# the following statement extract the content between { and }
query_stats = re.search("{(.*)}", query_split)
if query_stats:
stats = json.loads("{" + query_stats.group(1) + "}")
return stats.get("rowcount", -1), stats.get("data_scanned_in_bytes", 0)
except Exception as err:
logger.debug(f"There was an error parsing query stats {err}")
return -1, 0
return cursor.rowcount, cursor.data_scanned_in_bytes
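As an aside for reviewers, a small round-trip of the stats marker this method parses; the sample query string is made up, but the JSON shape matches what `run_query_with_partitions_limit_catching` (in impl.py below) emits:

```python
import json
import re

# Hypothetical query of the form produced by run_query_with_partitions_limit_catching
query = 'select \'{"rowcount":42,"data_scanned_in_bytes":1024}\''

tail = query.lower().split("select")[-1]  # strip the SELECT prefix
match = re.search("{(.*)}", tail)         # grab the {...} payload
stats = json.loads("{" + match.group(1) + "}") if match else {}
print(stats.get("rowcount", -1), stats.get("data_scanned_in_bytes", 0))  # 42 1024
```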

def cancel(self, connection: Connection) -> None:
connection.handle.cancel()

def add_begin_query(self) -> None:
pass
120 changes: 100 additions & 20 deletions dbt/adapters/athena/impl.py
@@ -1,6 +1,7 @@
import csv
import os
import posixpath as path
import re
import tempfile
from itertools import chain
from textwrap import dedent
@@ -19,6 +20,7 @@
TableTypeDef,
TableVersionTypeDef,
)
from pyathena.error import OperationalError

from dbt.adapters.athena import AthenaConnectionManager
from dbt.adapters.athena.column import AthenaColumn
@@ -42,7 +44,13 @@
get_table_type,
)
from dbt.adapters.athena.s3 import S3DataNaming
from dbt.adapters.athena.utils import clean_sql_comment, get_catalog_id, get_chunks
from dbt.adapters.athena.utils import (
AthenaCatalogType,
clean_sql_comment,
get_catalog_id,
get_catalog_type,
get_chunks,
)
from dbt.adapters.base import ConstraintSupport, available
from dbt.adapters.base.relation import BaseRelation, InformationSchema
from dbt.adapters.sql import SQLAdapter
@@ -413,36 +421,85 @@ def _get_one_table_for_catalog(self, table: TableTypeDef, database: str) -> List
for idx, col in enumerate(table["StorageDescriptor"]["Columns"] + table.get("PartitionKeys", []))
]

def _get_one_table_for_non_glue_catalog(
self, table: TableTypeDef, schema: str, database: str
) -> List[Dict[str, Any]]:
table_catalog = {
"table_database": database,
"table_schema": schema,
"table_name": table["Name"],
"table_type": RELATION_TYPE_MAP[table.get("TableType", "EXTERNAL_TABLE")].value,
"table_comment": table.get("Parameters", {}).get("comment", ""),
}
return [
{
**table_catalog,
**{
"column_name": col["Name"],
"column_index": idx,
"column_type": col["Type"],
"column_comment": col.get("Comment", ""),
},
}
for idx, col in enumerate(table["Columns"] + table.get("PartitionKeys", []))
]

def _get_one_catalog(
self,
information_schema: InformationSchema,
schemas: Dict[str, Optional[Set[str]]],
manifest: Manifest,
) -> agate.Table:
data_catalog = self._get_data_catalog(information_schema.path.database)
catalog_id = get_catalog_id(data_catalog)
data_catalog_type = get_catalog_type(data_catalog)

conn = self.connections.get_thread_connection()
client = conn.handle
with boto3_client_lock:
glue_client = client.session.client("glue", region_name=client.region_name, config=get_boto3_config())

catalog = []
paginator = glue_client.get_paginator("get_tables")
for schema, relations in schemas.items():
kwargs = {
"DatabaseName": schema,
"MaxResults": 100,
}
# If the catalog is `awsdatacatalog` we don't need to pass CatalogId as boto3 infers it from the account Id.
if catalog_id:
kwargs["CatalogId"] = catalog_id
if data_catalog_type == AthenaCatalogType.GLUE:
with boto3_client_lock:
glue_client = client.session.client("glue", region_name=client.region_name, config=get_boto3_config())

catalog = []
paginator = glue_client.get_paginator("get_tables")
for schema, relations in schemas.items():
kwargs = {
"DatabaseName": schema,
"MaxResults": 100,
}
# If the catalog is `awsdatacatalog` we don't need to pass CatalogId as boto3
# infers it from the account Id.
catalog_id = get_catalog_id(data_catalog)
if catalog_id:
kwargs["CatalogId"] = catalog_id

for page in paginator.paginate(**kwargs):
for table in page["TableList"]:
if relations and table["Name"] in relations:
catalog.extend(self._get_one_table_for_catalog(table, information_schema.path.database))
table = agate.Table.from_object(catalog)
else:
with boto3_client_lock:
athena_client = client.session.client(
"athena", region_name=client.region_name, config=get_boto3_config()
)

for page in paginator.paginate(**kwargs):
for table in page["TableList"]:
if relations and table["Name"] in relations:
catalog.extend(self._get_one_table_for_catalog(table, information_schema.path.database))
catalog = []
paginator = athena_client.get_paginator("list_table_metadata")
for schema, relations in schemas.items():
for page in paginator.paginate(
CatalogName=information_schema.path.database,
DatabaseName=schema,
MaxResults=50, # Limit supported by this operation
):
for table in page["TableMetadataList"]:
if relations and table["Name"] in relations:
catalog.extend(
self._get_one_table_for_non_glue_catalog(
table, schema, information_schema.path.database
)
)
table = agate.Table.from_object(catalog)

table = agate.Table.from_object(catalog)
filtered_table = self._catalog_filter_table(table, manifest)
return self._join_catalog_table_owners(filtered_table, manifest)
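A hedged sketch of the dispatch above in plain boto3 (the region, database, and catalog names are placeholders): Glue catalogs page through `get_tables`, while everything else falls back to Athena's `list_table_metadata`, whose `MaxResults` is capped at 50.

```python
from typing import List, Optional

import boto3

session = boto3.Session(region_name="eu-west-1")  # placeholder region

def glue_tables(database: str, catalog_id: Optional[str] = None) -> List[str]:
    glue = session.client("glue")
    kwargs = {"DatabaseName": database, "MaxResults": 100}
    if catalog_id:  # omitted for awsdatacatalog; boto3 infers it from the account id
        kwargs["CatalogId"] = catalog_id
    pages = glue.get_paginator("get_tables").paginate(**kwargs)
    return [t["Name"] for page in pages for t in page["TableList"]]

def non_glue_tables(catalog: str, database: str) -> List[str]:
    athena = session.client("athena")
    pages = athena.get_paginator("list_table_metadata").paginate(
        CatalogName=catalog, DatabaseName=database, MaxResults=50  # API limit
    )
    return [t["Name"] for page in pages for t in page["TableMetadataList"]]
```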

@@ -912,3 +969,26 @@ def _get_table_input(table: TableTypeDef) -> TableInputTypeDef:
returned by get_table() method.
"""
return {k: v for k, v in table.items() if k in TableInputTypeDef.__annotations__}

@available
def run_query_with_partitions_limit_catching(self, sql: str) -> str:
conn = self.connections.get_thread_connection()
cursor = conn.handle.cursor()
try:
cursor.execute(sql, catch_partitions_limit=True)
except OperationalError as e:
LOGGER.debug(f"CAUGHT EXCEPTION: {e}")
if "TOO_MANY_OPEN_PARTITIONS" in str(e):
return "TOO_MANY_OPEN_PARTITIONS"
raise e
return f'{{"rowcount":{cursor.rowcount},"data_scanned_in_bytes":{cursor.data_scanned_in_bytes}}}'
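A hypothetical caller, to show how the sentinel string is meant to be consumed; `adapter`, `insert_sql`, and `batched_sqls` are placeholders for what the materialization macros do in Jinja:

```python
SENTINEL = "TOO_MANY_OPEN_PARTITIONS"

def insert_with_fallback(adapter, insert_sql: str, batched_sqls: list) -> None:
    # First try the whole insert; the adapter swallows the partitions-limit
    # error and returns the sentinel instead of raising.
    if adapter.run_query_with_partitions_limit_catching(insert_sql) == SENTINEL:
        # Re-run the insert split into batches that stay under Athena's
        # open-partitions limit (see get_partition_batches below).
        for sql in batched_sqls:
            adapter.run_query_with_partitions_limit_catching(sql)
```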

@available
def format_partition_keys(self, partition_keys: List[str]) -> str:
return ", ".join([self.format_one_partition_key(k) for k in partition_keys])

@available
def format_one_partition_key(self, partition_key: str) -> str:
"""Check if partition key uses Iceberg hidden partitioning"""
hidden = re.search(r"^(hour|day|month|year)\((.+)\)", partition_key.lower())
return f"date_trunc('{hidden.group(1)}', {hidden.group(2)})" if hidden else partition_key.lower()
1 change: 1 addition & 0 deletions dbt/adapters/athena/relation.py
@@ -79,6 +79,7 @@ def add(self, relation: AthenaRelation) -> None:

RELATION_TYPE_MAP = {
"EXTERNAL_TABLE": TableType.TABLE,
"EXTERNAL": TableType.TABLE, # type returned by federated query tables
"MANAGED_TABLE": TableType.TABLE,
"VIRTUAL_VIEW": TableType.VIEW,
"table": TableType.TABLE,
13 changes: 12 additions & 1 deletion dbt/adapters/athena/utils.py
@@ -1,3 +1,4 @@
from enum import Enum
from typing import Generator, List, Optional, TypeVar

from mypy_boto3_athena.type_defs import DataCatalogTypeDef
@@ -9,7 +10,17 @@ def clean_sql_comment(comment: str) -> str:


def get_catalog_id(catalog: Optional[DataCatalogTypeDef]) -> Optional[str]:
return catalog["Parameters"]["catalog-id"] if catalog else None
return catalog["Parameters"]["catalog-id"] if catalog and catalog["Type"] == AthenaCatalogType.GLUE.value else None


class AthenaCatalogType(Enum):
GLUE = "GLUE"
LAMBDA = "LAMBDA"
HIVE = "HIVE"


def get_catalog_type(catalog: Optional[DataCatalogTypeDef]) -> Optional[AthenaCatalogType]:
return AthenaCatalogType(catalog["Type"]) if catalog else None


T = TypeVar("T")
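For reference, a small sketch of how the two helpers behave for a non-Glue catalog, assuming the dict shape returned by Athena's `get_data_catalog` (the sample values are made up):

```python
from enum import Enum

class AthenaCatalogType(Enum):
    GLUE = "GLUE"
    LAMBDA = "LAMBDA"
    HIVE = "HIVE"

# Made-up response in the shape of athena.get_data_catalog()["DataCatalog"]
catalog = {"Name": "hive_external", "Type": "HIVE", "Parameters": {}}

catalog_type = AthenaCatalogType(catalog["Type"]) if catalog else None
print(catalog_type)  # AthenaCatalogType.HIVE

# With the guard above, get_catalog_id returns an id only for GLUE catalogs,
# so a HIVE or LAMBDA catalog without a "catalog-id" parameter yields None
# instead of raising KeyError.
```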
@@ -0,0 +1,52 @@
{% macro get_partition_batches(sql) -%}
{%- set partitioned_by = config.get('partitioned_by') -%}
{%- set athena_partitions_limit = config.get('partitions_limit', 100) | int -%}
{%- set partitioned_keys = adapter.format_partition_keys(partitioned_by) -%}
{% do log('PARTITIONED KEYS: ' ~ partitioned_keys) %}

{% call statement('get_partitions', fetch_result=True) %}
select distinct {{ partitioned_keys }} from ({{ sql }}) order by {{ partitioned_keys }};
{% endcall %}

{%- set table = load_result('get_partitions').table -%}
{%- set rows = table.rows -%}
{%- set partitions = {} -%}
{% do log('TOTAL PARTITIONS TO PROCESS: ' ~ rows | length) %}
{%- set partitions_batches = [] -%}

{%- for row in rows -%}
{%- set single_partition = [] -%}
{%- for col in row -%}

{%- set column_type = adapter.convert_type(table, loop.index0) -%}
{%- if column_type == 'integer' -%}
{%- set value = col | string -%}
{%- elif column_type == 'string' -%}
{%- set value = "'" + col + "'" -%}
{%- elif column_type == 'date' -%}
{%- set value = "DATE'" + col | string + "'" -%}
{%- elif column_type == 'timestamp' -%}
{%- set value = "TIMESTAMP'" + col | string + "'" -%}
{%- else -%}
{%- do exceptions.raise_compiler_error('Need to add support for column type ' + column_type) -%}
{%- endif -%}
{%- set partition_key = adapter.format_one_partition_key(partitioned_by[loop.index0]) -%}
{%- do single_partition.append(partition_key + '=' + value) -%}
{%- endfor -%}

{%- set single_partition_expression = single_partition | join(' and ') -%}

{%- set batch_number = (loop.index0 / athena_partitions_limit) | int -%}
{% if not batch_number in partitions %}
{% do partitions.update({batch_number: []}) %}
{% endif %}

{%- do partitions[batch_number].append('(' + single_partition_expression + ')') -%}
{%- if partitions[batch_number] | length == athena_partitions_limit or loop.last -%}
{%- do partitions_batches.append(partitions[batch_number] | join(' or ')) -%}
{%- endif -%}
{%- endfor -%}

{{ return(partitions_batches) }}

{%- endmacro %}
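The batching arithmetic in the macro reduces to integer division of the row index by the batch size; a Python rendering with illustrative values:

```python
limit = 100  # default partitions_limit, matching Athena's open-partitions cap
partitions = [
    f"(date_trunc('day', created_at)=DATE'2023-09-{d:02d}')" for d in range(1, 251)
]

batches: dict = {}
for i, predicate in enumerate(partitions):
    # (loop.index0 / athena_partitions_limit) | int in the macro == i // limit here
    batches.setdefault(i // limit, []).append(predicate)

# Each batch becomes one "... or ..." filter for a separate INSERT statement
batch_filters = [" or ".join(group) for group in batches.values()]
print(len(batch_filters))  # 3 batches: 100 + 100 + 50 predicates
```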