Skip to content

Commit

Permalink
feat(ingest/snowflake): support email_as_user_identifier for queries v2
Browse files Browse the repository at this point in the history
  • Loading branch information
mayurinehate committed Dec 24, 2024
1 parent 73dce9e commit a93fcf4
Show file tree
Hide file tree
Showing 6 changed files with 118 additions and 40 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -138,12 +138,20 @@ class SnowflakeIdentifierConfig(
description="Whether to convert dataset urns to lowercase.",
)


class SnowflakeUsageConfig(BaseUsageConfig):
    # Domain appended to bare usernames when synthesizing an email-style
    # identifier (see email_as_user_identifier below).
    email_domain: Optional[str] = pydantic.Field(
        default=None,
        description="Email domain of your organization so users can be displayed on UI appropriately.",
    )

    # When True, user urns are built from the user's email; if the email is
    # unset and email_domain is configured, one is synthesized from the
    # username. NOTE(review): diff context — confirm which config class these
    # fields ended up on in the final file.
    email_as_user_identifier: bool = Field(
        default=True,
        description="Format user urns as an email, if the snowflake user's email is set. If `email_domain` is "
        "provided, generates email addresses for snowflake users with unset emails, based on their "
        "username.",
    )


class SnowflakeUsageConfig(BaseUsageConfig):
apply_view_usage_to_tables: bool = pydantic.Field(
default=False,
description="Whether to apply view's usage to its base tables. If set to True, usage is applied to base tables only.",
Expand Down Expand Up @@ -285,13 +293,6 @@ class SnowflakeV2Config(
" Map of share name -> details of share.",
)

email_as_user_identifier: bool = Field(
default=True,
description="Format user urns as an email, if the snowflake user's email is set. If `email_domain` is "
"provided, generates email addresses for snowflake users with unset emails, based on their "
"username.",
)

include_assertion_results: bool = Field(
default=False,
description="Whether to ingest assertion run results for assertions created using Datahub"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,11 +114,13 @@ class SnowflakeQueriesSourceConfig(
class SnowflakeQueriesExtractorReport(Report):
    """Timers and counters reported by the Snowflake queries-v2 extractor."""

    # Per-phase wall-clock timers for the extraction pipeline.
    copy_history_fetch_timer: PerfTimer = dataclasses.field(default_factory=PerfTimer)
    query_log_fetch_timer: PerfTimer = dataclasses.field(default_factory=PerfTimer)
    users_fetch_timer: PerfTimer = dataclasses.field(default_factory=PerfTimer)

    audit_log_load_timer: PerfTimer = dataclasses.field(default_factory=PerfTimer)
    sql_aggregator: Optional[SqlAggregatorReport] = None

    # presumably counts queries skipped because they are DDL — confirm at the
    # increment site (not visible in this chunk).
    num_ddl_queries_dropped: int = 0
    # Number of user rows successfully parsed by fetch_users.
    num_users: int = 0


@dataclass
Expand Down Expand Up @@ -225,6 +227,9 @@ def is_allowed_table(self, name: str) -> bool:
def get_workunits_internal(
self,
) -> Iterable[MetadataWorkUnit]:
with self.report.users_fetch_timer:
users = self.fetch_users()

Check warning on line 231 in metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_queries.py

View check run for this annotation

Codecov / codecov/patch

metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_queries.py#L230-L231

Added lines #L230 - L231 were not covered by tests

# TODO: Add some logic to check if the cached audit log is stale or not.
audit_log_file = self.local_temp_path / "audit_log.sqlite"
use_cached_audit_log = audit_log_file.exists()
Expand All @@ -248,7 +253,7 @@ def get_workunits_internal(
queries.append(entry)

with self.report.query_log_fetch_timer:
for entry in self.fetch_query_log():
for entry in self.fetch_query_log(users):

Check warning on line 256 in metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_queries.py

View check run for this annotation

Codecov / codecov/patch

metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_queries.py#L256

Added line #L256 was not covered by tests
queries.append(entry)

with self.report.audit_log_load_timer:
Expand All @@ -263,6 +268,25 @@ def get_workunits_internal(
shared_connection.close()
audit_log_file.unlink(missing_ok=True)

def fetch_users(self) -> Dict[str, str]:
users: Dict[str, str] = dict()
with self.structured_reporter.report_exc("Error fetching users from Snowflake"):
logger.info("Fetching users from Snowflake")
query = SnowflakeQuery.get_all_users()
resp = self.connection.query(query)

Check warning on line 276 in metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_queries.py

View check run for this annotation

Codecov / codecov/patch

metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_queries.py#L272-L276

Added lines #L272 - L276 were not covered by tests

for row in resp:
try:
users[row["NAME"]] = row["EMAIL"]
self.report.num_users += 1
except Exception as e:
self.structured_reporter.warning(

Check warning on line 283 in metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_queries.py

View check run for this annotation

Codecov / codecov/patch

metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_queries.py#L278-L283

Added lines #L278 - L283 were not covered by tests
"Error parsing user row",
context=f"{row}",
exc=e,
)
return users

Check warning on line 288 in metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_queries.py

View check run for this annotation

Codecov / codecov/patch

metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_queries.py#L288

Added line #L288 was not covered by tests

def fetch_copy_history(self) -> Iterable[KnownLineageMapping]:
# Derived from _populate_external_lineage_from_copy_history.

Expand Down Expand Up @@ -298,7 +322,7 @@ def fetch_copy_history(self) -> Iterable[KnownLineageMapping]:
yield result

def fetch_query_log(
self,
self, users: Dict[str, str]
) -> Iterable[Union[PreparsedQuery, TableRename, TableSwap]]:
query_log_query = _build_enriched_query_log_query(
start_time=self.config.window.start_time,
Expand All @@ -319,7 +343,7 @@ def fetch_query_log(

assert isinstance(row, dict)
try:
entry = self._parse_audit_log_row(row)
entry = self._parse_audit_log_row(row, users)

Check warning on line 346 in metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_queries.py

View check run for this annotation

Codecov / codecov/patch

metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_queries.py#L346

Added line #L346 was not covered by tests
except Exception as e:
self.structured_reporter.warning(
"Error parsing query log row",
Expand All @@ -331,7 +355,7 @@ def fetch_query_log(
yield entry

def _parse_audit_log_row(
self, row: Dict[str, Any]
self, row: Dict[str, Any], users: Dict[str, str]
) -> Optional[Union[TableRename, TableSwap, PreparsedQuery]]:
json_fields = {
"DIRECT_OBJECTS_ACCESSED",
Expand Down Expand Up @@ -430,9 +454,11 @@ def _parse_audit_log_row(
)
)

# TODO: Fetch email addresses from Snowflake to map user -> email
# TODO: Support email_domain fallback for generating user urns.
user = CorpUserUrn(self.identifiers.snowflake_identifier(res["user_name"]))
user = CorpUserUrn(

Check warning on line 457 in metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_queries.py

View check run for this annotation

Codecov / codecov/patch

metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_queries.py#L457

Added line #L457 was not covered by tests
self.identifiers.get_user_identifier(
res["user_name"], users.get(res["user_name"])
)
)

timestamp: datetime = res["query_start_time"]
timestamp = timestamp.astimezone(timezone.utc)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -956,4 +956,8 @@ def dmf_assertion_results(start_time_millis: int, end_time_millis: int) -> str:
AND METRIC_NAME ilike '{pattern}' escape '{escape_pattern}'
ORDER BY MEASUREMENT_TIME ASC;
"""
"""

@staticmethod
def get_all_users() -> str:
    """Return the SQL that lists every Snowflake user's name and email.

    The aliases are double-quoted so the result-set keys are exactly
    "NAME" and "EMAIL" regardless of session identifier-casing settings.
    """
    return (
        'SELECT name as "NAME", email as "EMAIL" '
        "FROM SNOWFLAKE.ACCOUNT_USAGE.USERS"
    )

Check warning on line 963 in metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_query.py

View check run for this annotation

Codecov / codecov/patch

metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_query.py#L963

Added line #L963 was not covered by tests
Original file line number Diff line number Diff line change
Expand Up @@ -342,10 +342,9 @@ def _map_user_counts(
filtered_user_counts.append(
DatasetUserUsageCounts(
user=make_user_urn(
self.get_user_identifier(
self.identifiers.get_user_identifier(
user_count["user_name"],
user_email,
self.config.email_as_user_identifier,
)
),
count=user_count["total"],
Expand Down Expand Up @@ -453,9 +452,7 @@ def _get_operation_aspect_work_unit(
reported_time: int = int(time.time() * 1000)
last_updated_timestamp: int = int(start_time.timestamp() * 1000)
user_urn = make_user_urn(
self.get_user_identifier(
user_name, user_email, self.config.email_as_user_identifier
)
self.identifiers.get_user_identifier(user_name, user_email)
)

# NOTE: In earlier `snowflake-usage` connector this was base_objects_accessed, which is incorrect
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,28 @@ def get_quoted_identifier_for_schema(db_name, schema_name):
def get_quoted_identifier_for_table(db_name, schema_name, table_name):
    """Return the fully qualified, double-quoted Snowflake table identifier."""
    parts = (db_name, schema_name, table_name)
    return ".".join(f'"{part}"' for part in parts)

# Note - decide how to construct user urns.
# Historically urns were created using part before @ from user's email.
# Users without email were skipped from both user entries as well as aggregates.
# However email is not mandatory field in snowflake user, user_name is always present.
def get_user_identifier(
self,
user_name: str,
user_email: Optional[str],
) -> str:
if user_email:
return self.snowflake_identifier(
user_email
if self.identifier_config.email_as_user_identifier is True
else user_email.split("@")[0]
)
return self.snowflake_identifier(
f"{user_name}@{self.identifier_config.email_domain}"
if self.identifier_config.email_as_user_identifier is True
and self.identifier_config.email_domain is not None
else user_name
)


class SnowflakeCommonMixin(SnowflakeStructuredReportMixin):
platform = "snowflake"
Expand All @@ -315,24 +337,6 @@ def structured_reporter(self) -> SourceReport:
def identifiers(self) -> SnowflakeIdentifierBuilder:
return SnowflakeIdentifierBuilder(self.config, self.report)

# Note - decide how to construct user urns.
# Historically urns were created using part before @ from user's email.
# Users without email were skipped from both user entries as well as aggregates.
# However email is not mandatory field in snowflake user, user_name is always present.
def get_user_identifier(
self,
user_name: str,
user_email: Optional[str],
email_as_user_identifier: bool,
) -> str:
if user_email:
return self.identifiers.snowflake_identifier(
user_email
if email_as_user_identifier is True
else user_email.split("@")[0]
)
return self.identifiers.snowflake_identifier(user_name)

# TODO: Revisit this after stateful ingestion can commit checkpoint
# for failures that do not affect the checkpoint
# TODO: Add additional parameters to match the signature of the .warning and .failure methods
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,49 @@ def test_source_close_cleans_tmp(snowflake_connect, tmp_path):
# This closes QueriesExtractor which in turn closes SqlParsingAggregator
source.close()
assert len(os.listdir(tmp_path)) == 0


@patch("snowflake.connector.connect")
def test_user_identifiers_email_as_identifier(snowflake_connect, tmp_path):
    """get_user_identifier uses the email (real or synthesized) when
    email_as_user_identifier is True.

    NOTE(review): the email literals were redacted to "[email protected]" by
    GitHub's scrape; reconstructed here to match email_domain="example.com" —
    confirm against the original commit.
    """
    with patch("tempfile.tempdir", str(tmp_path)):
        source = SnowflakeQueriesSource.create(
            {
                "connection": {
                    "account_id": "ABC12345.ap-south-1.aws",
                    "username": "TST_USR",
                    "password": "TST_PWD",
                },
                "email_as_user_identifier": True,
                "email_domain": "example.com",
            },
            PipelineContext("run-id"),
        )
        # A known email is used verbatim.
        assert (
            source.identifiers.get_user_identifier("username", "username@example.com")
            == "username@example.com"
        )
        # No email on record: one is synthesized from username + email_domain.
        assert (
            source.identifiers.get_user_identifier("username", None)
            == "username@example.com"
        )


@patch("snowflake.connector.connect")
def test_user_identifiers_username_as_identifier(snowflake_connect, tmp_path):
    """get_user_identifier falls back to the username (email local part) when
    email_as_user_identifier is False.

    NOTE(review): the email literal was redacted to "[email protected]" by
    GitHub's scrape; reconstructed so its local part matches "username" —
    confirm against the original commit.
    """
    with patch("tempfile.tempdir", str(tmp_path)):
        source = SnowflakeQueriesSource.create(
            {
                "connection": {
                    "account_id": "ABC12345.ap-south-1.aws",
                    "username": "TST_USR",
                    "password": "TST_PWD",
                },
                "email_as_user_identifier": False,
            },
            PipelineContext("run-id"),
        )
        # Only the part before '@' of a known email is used.
        assert (
            source.identifiers.get_user_identifier("username", "username@example.com")
            == "username"
        )
        # No email and no synthesized fallback: the raw username is used.
        assert source.identifiers.get_user_identifier("username", None) == "username"

0 comments on commit a93fcf4

Please sign in to comment.