Merge branch 'main' into fix/cross-account-delete-table
nicor88 authored Sep 6, 2023
2 parents ac131b8 + a9a8696 commit df389db
Showing 9 changed files with 217 additions and 24 deletions.
2 changes: 2 additions & 0 deletions README.md
@@ -82,11 +82,13 @@ A dbt profile can be configured to run against AWS Athena using the following co
| schema | Specify the schema (Athena database) to build models into (lowercase **only**) | Required | `dbt` |
| database | Specify the database (Data catalog) to build models into (lowercase **only**) | Required | `awsdatacatalog` |
| poll_interval | Interval in seconds to use for polling the status of query results in Athena | Optional | `5` |
| debug_query_state | Flag to enable debug logging of the Athena query state while polling | Optional | `false` |
| aws_access_key_id | Access key ID of the user performing requests. | Optional | `AKIAIOSFODNN7EXAMPLE` |
| aws_secret_access_key | Secret access key of the user performing requests | Optional | `wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY` |
| aws_profile_name | Profile to use from your AWS shared credentials file. | Optional | `my-profile` |
| work_group | Identifier of Athena workgroup | Optional | `my-custom-workgroup` |
| num_retries | Number of times to retry a failing query | Optional | `3` |
| lf_tags_database | Default LF tags applied to a new database when dbt creates it | Optional | `tag_key: tag_value` |

**Example profiles.yml entry:**
```yaml
11 changes: 9 additions & 2 deletions dbt/adapters/athena/connections.py
@@ -54,10 +54,14 @@ class AthenaCredentials(Credentials):
aws_access_key_id: Optional[str] = None
aws_secret_access_key: Optional[str] = None
poll_interval: float = 1.0
debug_query_state: bool = False
_ALIASES = {"catalog": "database"}
num_retries: Optional[int] = 5
s3_data_dir: Optional[str] = None
s3_data_naming: Optional[str] = "schema_table_unique"
# Unfortunately we cannot just use dict; it must be Dict, because otherwise we get the following error:
# Credentials in profile "athena", target "athena" invalid: Unable to create schema for 'dict'
lf_tags_database: Optional[Dict[str, str]] = None

@property
def type(self) -> str:
@@ -81,7 +85,8 @@ def _connection_keys(self) -> Tuple[str, ...]:
"endpoint_url",
"s3_data_dir",
"s3_data_naming",
"lf_tags",
"debug_query_state",
"lf_tags_database",
)


@@ -122,7 +127,8 @@ def __poll(self, query_id: str) -> AthenaQueryExecution:
]:
return query_execution
else:
logger.debug(f"Query state is: {query_execution.state}. Sleeping for {self._poll_interval}...")
if self.connection.cursor_kwargs.get("debug_query_state", False):
logger.debug(f"Query state is: {query_execution.state}. Sleeping for {self._poll_interval}...")
time.sleep(self._poll_interval)

def execute( # type: ignore
@@ -215,6 +221,7 @@ def open(cls, connection: Connection) -> Connection:
schema_name=creds.schema,
work_group=creds.work_group,
cursor_class=AthenaCursor,
cursor_kwargs={"debug_query_state": creds.debug_query_state},
formatter=AthenaParameterFormatter(),
poll_interval=creds.poll_interval,
session=get_boto3_session(connection),
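The net effect of the connections.py changes: the per-poll debug message is emitted only when the profile sets `debug_query_state: true`, and the flag reaches the cursor via `cursor_kwargs`. A minimal standalone sketch of that gating logic (the `fetch_state` callable and `TERMINAL_STATES` set are stand-ins for the adapter's internals, not its actual API):

```python
import logging
import time
from typing import Callable

logger = logging.getLogger(__name__)

# Hypothetical terminal states for the sketch; the adapter checks pyathena's own constants.
TERMINAL_STATES = {"SUCCEEDED", "FAILED", "CANCELLED"}


def poll_until_done(
    query_id: str,
    fetch_state: Callable[[str], str],  # stand-in for fetching the current Athena query state
    poll_interval: float = 1.0,
    debug_query_state: bool = False,
) -> str:
    """Poll until the query reaches a terminal state; log intermediate states only if asked."""
    while True:
        state = fetch_state(query_id)
        if state in TERMINAL_STATES:
            return state
        if debug_query_state:
            logger.debug(f"Query state is: {state}. Sleeping for {poll_interval}...")
        time.sleep(poll_interval)
```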
14 changes: 14 additions & 0 deletions dbt/adapters/athena/impl.py
@@ -94,6 +94,19 @@ def convert_number_type(cls, agate_table: agate.Table, col_idx: int) -> str:
def convert_datetime_type(cls, agate_table: agate.Table, col_idx: int) -> str:
return "timestamp"

@available
def add_lf_tags_to_database(self, relation: AthenaRelation) -> None:
conn = self.connections.get_thread_connection()
client = conn.handle
if lf_tags := conn.credentials.lf_tags_database:
config = LfTagsConfig(enabled=True, tags=lf_tags)
with boto3_client_lock:
lf_client = client.session.client("lakeformation", client.region_name, config=get_boto3_config())
manager = LfTagsManager(lf_client, relation, config)
manager.process_lf_tags_database()
else:
LOGGER.debug(f"Lakeformation is disabled for {relation}")

@available
def add_lf_tags(self, relation: AthenaRelation, lf_tags_config: Dict[str, Any]) -> None:
config = LfTagsConfig(**lf_tags_config)
@@ -528,6 +541,7 @@ def _get_data_catalog(self, database: str) -> Optional[DataCatalogTypeDef]:
return athena.get_data_catalog(Name=database)["DataCatalog"]
return None

@available
def list_relations_without_caching(self, schema_relation: AthenaRelation) -> List[BaseRelation]:
data_catalog = self._get_data_catalog(schema_relation.database)
catalog_id = get_catalog_id(data_catalog)
31 changes: 20 additions & 11 deletions dbt/adapters/athena/lakeformation.py
@@ -34,6 +34,14 @@ def __init__(self, lf_client: LakeFormationClient, relation: AthenaRelation, lf_
self.lf_tags = lf_tags_config.tags
self.lf_tags_columns = lf_tags_config.tags_columns

def process_lf_tags_database(self) -> None:
if self.lf_tags:
database_resource = {"Database": {"Name": self.database}}
response = self.lf_client.add_lf_tags_to_resource(
Resource=database_resource, LFTags=[{"TagKey": k, "TagValues": [v]} for k, v in self.lf_tags.items()]
)
self._parse_and_log_lf_response(response, None, self.lf_tags)

def process_lf_tags(self) -> None:
table_resource = {"Table": {"DatabaseName": self.database, "Name": self.table}}
existing_lf_tags = self.lf_client.get_resource_lf_tags(Resource=table_resource)
@@ -65,7 +73,7 @@ def _remove_lf_tags_columns(self, existing_lf_tags: GetResourceLFTagsResponseTyp
response = self.lf_client.remove_lf_tags_from_resource(
Resource=resource, LFTags=[{"TagKey": tag_key, "TagValues": [tag_value]}]
)
logger.debug(self._parse_lf_response(response, columns, {tag_key: tag_value}, "remove"))
self._parse_and_log_lf_response(response, columns, {tag_key: tag_value}, "remove")

def _apply_lf_tags_table(
self, table_resource: ResourceTypeDef, existing_lf_tags: GetResourceLFTagsResponseTypeDef
@@ -84,13 +92,13 @@ def _apply_lf_tags_table(
response = self.lf_client.remove_lf_tags_from_resource(
Resource=table_resource, LFTags=[{"TagKey": k, "TagValues": v} for k, v in to_remove.items()]
)
logger.debug(self._parse_lf_response(response, None, self.lf_tags, "remove"))
self._parse_and_log_lf_response(response, None, self.lf_tags, "remove")

if self.lf_tags:
response = self.lf_client.add_lf_tags_to_resource(
Resource=table_resource, LFTags=[{"TagKey": k, "TagValues": [v]} for k, v in self.lf_tags.items()]
)
logger.debug(self._parse_lf_response(response, None, self.lf_tags))
self._parse_and_log_lf_response(response, None, self.lf_tags)

def _apply_lf_tags_columns(self) -> None:
if self.lf_tags_columns:
@@ -103,25 +111,26 @@ def _apply_lf_tags_columns(self) -> None:
Resource=resource,
LFTags=[{"TagKey": tag_key, "TagValues": [tag_value]}],
)
logger.debug(self._parse_lf_response(response, columns, {tag_key: tag_value}))
self._parse_and_log_lf_response(response, columns, {tag_key: tag_value})

def _parse_lf_response(
def _parse_and_log_lf_response(
self,
response: Union[AddLFTagsToResourceResponseTypeDef, RemoveLFTagsFromResourceResponseTypeDef],
columns: Optional[List[str]] = None,
lf_tags: Optional[Dict[str, str]] = None,
verb: str = "add",
) -> str:
failures = response.get("Failures", [])
) -> None:
table_appendix = f".{self.table}" if self.table else ""
columns_appendix = f" for columns {columns}" if columns else ""
if failures:
base_msg = f"Failed to {verb} LF tags: {lf_tags} to {self.database}.{self.table}" + columns_appendix
resource_msg = self.database + table_appendix + columns_appendix
if failures := response.get("Failures", []):
base_msg = f"Failed to {verb} LF tags: {lf_tags} to " + resource_msg
for failure in failures:
tag = failure.get("LFTag", {}).get("TagKey")
error = failure.get("Error", {}).get("ErrorMessage")
logger.error(f"Failed to {verb} {tag} for {self.database}.{self.table}" + f" - {error}")
logger.error(f"Failed to {verb} {tag} for " + resource_msg + f" - {error}")
raise DbtRuntimeError(base_msg)
return f"Success: {verb} LF tags: {lf_tags} to {self.database}.{self.table}" + columns_appendix
logger.debug(f"Success: {verb} LF tags {lf_tags} to " + resource_msg)


class FilterConfig(BaseModel):
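For context, the request that `process_lf_tags_database` builds is a plain Lake Formation `add_lf_tags_to_resource` call against a database resource. A rough boto3 sketch, with the region, database name, and tag values chosen only for illustration:

```python
import boto3

# Example values only; in the adapter these come from the connection session
# and the profile's lf_tags_database mapping.
lf_tags = {"data_owner": "analytics"}

lf_client = boto3.client("lakeformation", region_name="eu-west-1")

response = lf_client.add_lf_tags_to_resource(
    Resource={"Database": {"Name": "dbt_sandbox"}},  # database-level resource, no table name
    LFTags=[{"TagKey": k, "TagValues": [v]} for k, v in lf_tags.items()],
)

# Per-tag problems are reported under "Failures" rather than raised as an exception,
# which is why _parse_and_log_lf_response inspects that key.
print(response.get("Failures", []))
```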
3 changes: 3 additions & 0 deletions dbt/include/athena/macros/adapters/schema.sql
@@ -2,6 +2,9 @@
{%- call statement('create_schema') -%}
create schema if not exists {{ relation.without_identifier().render_hive() }}
{% endcall %}

{{ adapter.add_lf_tags_to_database(relation) }}

{% endmacro %}


@@ -18,8 +18,13 @@
{%- set single_partition = [] -%}
{%- for col in row -%}


{%- set column_type = adapter.convert_type(table, loop.index0) -%}
{%- if column_type == 'integer' -%}
{%- set comp_func = '=' -%}
{%- if col is none -%}
{%- set value = 'null' -%}
{%- set comp_func = ' is ' -%}
{%- elif column_type == 'integer' -%}
{%- set value = col | string -%}
{%- elif column_type == 'string' -%}
{%- set value = "'" + col + "'" -%}
@@ -31,7 +36,7 @@
{%- do exceptions.raise_compiler_error('Need to add support for column type ' + column_type) -%}
{%- endif -%}
{%- set partition_key = adapter.format_one_partition_key(partitioned_by[loop.index0]) -%}
{%- do single_partition.append(partition_key + '=' + value) -%}
{%- do single_partition.append(partition_key + comp_func + value) -%}
{%- endfor -%}

{%- set single_partition_expression = single_partition | join(' and ') -%}
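The macro change above makes the delete predicate null-safe: a null partition value now renders as `<key> is null` instead of `<key>=null`, which never matches in SQL. A rough Python rendering of the per-column logic, covering only the integer and string branches shown in this hunk:

```python
from typing import Iterable, Optional, Tuple


def partition_predicate(columns: Iterable[Tuple[str, Optional[object], str]]) -> str:
    """Render the delete predicate for one partition, comparing nulls with `is null`."""
    parts = []
    for key, value, column_type in columns:
        comp_func = "="
        if value is None:
            rendered = "null"
            comp_func = " is "
        elif column_type == "integer":
            rendered = str(value)
        elif column_type == "string":
            rendered = f"'{value}'"
        else:
            raise ValueError(f"Need to add support for column type {column_type}")
        parts.append(f"{key}{comp_func}{rendered}")
    return " and ".join(parts)


# -> "id is null and date_column='2023-09-02'"
print(partition_predicate([("id", None, "integer"), ("date_column", "2023-09-02", "string")]))
```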
4 changes: 2 additions & 2 deletions dev-requirements.txt
@@ -4,8 +4,8 @@ dbt-tests-adapter~=1.6.1
flake8~=6.1
Flake8-pyproject~=1.2
isort~=5.11
moto~=4.2.0
pre-commit~=2.21
moto~=4.2.2
pre-commit~=3.4
pyparsing~=3.1.1
pytest~=7.4
pytest-cov~=4.1
139 changes: 139 additions & 0 deletions tests/functional/adapter/test_partitions.py
@@ -17,6 +17,67 @@
cross join unnest(date_array) as t2(date_column)
"""

test_single_nullable_partition_model_sql = """
with data as (
select
random() as col_1,
row_number() over() as id
from
unnest(sequence(1, 200))
)
select
col_1, id
from data
union all
select random() as col_1, NULL as id
union all
select random() as col_1, NULL as id
"""

test_nullable_partitions_model_sql = """
{{ config(
materialized='table',
format='parquet',
s3_data_naming='table',
partitioned_by=['id', 'date_column']
) }}
with data as (
select
random() as rnd,
row_number() over() as id,
cast(date_column as date) as date_column
from (
values (
sequence(from_iso8601_date('2023-01-01'), from_iso8601_date('2023-07-31'), interval '1' day)
)
) as t1(date_array)
cross join unnest(date_array) as t2(date_column)
)
select
rnd,
case when id <= 50 then null else id end as id,
date_column
from data
union all
select
random() as rnd,
NULL as id,
NULL as date_column
union all
select
random() as rnd,
NULL as id,
cast('2023-09-02' as date) as date_column
union all
select
random() as rnd,
40 as id,
NULL as date_column
"""


class TestHiveTablePartitions:
@pytest.fixture(scope="class")
@@ -125,3 +186,81 @@ def test__check_incremental_run_with_partitions(self, project):
incremental_records_count = project.run_sql(model_run_result_row_count_query, fetch="all")[0][0]

assert incremental_records_count == 212


class TestHiveNullValuedPartitions:
@pytest.fixture(scope="class")
def project_config_update(self):
return {
"models": {
"+table_type": "hive",
"+materialized": "table",
"+partitioned_by": ["id", "date_column"],
}
}

@pytest.fixture(scope="class")
def models(self):
return {
"test_nullable_partitions_model.sql": test_nullable_partitions_model_sql,
}

def test__check_run_with_partitions(self, project):
relation_name = "test_nullable_partitions_model"
model_run_result_row_count_query = f"select count(*) as records from {project.test_schema}.{relation_name}"
model_run_result_null_id_count_query = (
f"select count(*) as records from {project.test_schema}.{relation_name} where id is null"
)
model_run_result_null_date_count_query = (
f"select count(*) as records from {project.test_schema}.{relation_name} where date_column is null"
)

first_model_run = run_dbt(["run", "--select", relation_name])
first_model_run_result = first_model_run.results[0]

# check that the model ran successfully
assert first_model_run_result.status == RunStatus.Success

records_count_first_run = project.run_sql(model_run_result_row_count_query, fetch="all")[0][0]
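# expected: 212 days in the 2023-01-01..2023-07-31 sequence plus the 3 union all rows = 215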

assert records_count_first_run == 215

null_id_count_first_run = project.run_sql(model_run_result_null_id_count_query, fetch="all")[0][0]
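# expected: ids 1-50 are nulled by the case expression, plus 2 union all rows with a NULL id = 52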

assert null_id_count_first_run == 52

null_date_count_first_run = project.run_sql(model_run_result_null_date_count_query, fetch="all")[0][0]
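# expected: exactly 2 of the union all rows carry a NULL date_column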

assert null_date_count_first_run == 2


class TestHiveSingleNullValuedPartition:
@pytest.fixture(scope="class")
def project_config_update(self):
return {
"models": {
"+table_type": "hive",
"+materialized": "table",
"+partitioned_by": ["id"],
}
}

@pytest.fixture(scope="class")
def models(self):
return {
"test_single_nullable_partition_model.sql": test_single_nullable_partition_model_sql,
}

def test__check_run_with_partitions(self, project):
relation_name = "test_single_nullable_partition_model"
model_run_result_row_count_query = f"select count(*) as records from {project.test_schema}.{relation_name}"

first_model_run = run_dbt(["run", "--select", relation_name])
first_model_run_result = first_model_run.results[0]

# check that the model ran successfully
assert first_model_run_result.status == RunStatus.Success

records_count_first_run = project.run_sql(model_run_result_row_count_query, fetch="all")[0][0]
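# expected: 200 rows from the sequence plus the 2 union all rows with a NULL id = 202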

assert records_count_first_run == 202