Dx 32 nested types #17
base: master
Changes from all commits: c6d2ebe, 5d3deb1, dd0dafe, 264210c, 7739396, 524af43, 4485038
@@ -43,6 +43,7 @@ def above_threshold(self):
 "database": "table_schema",
 "table": "table_name",
 "column": "column_name",
+"type": "data_type",
 "rule_name": "tag_name",
 }
 )
@@ -77,7 +78,7 @@ def aggregate_updates(pdf):

 return pd.DataFrame(output)

-self.classification_result = pd.concat([classification_result, current_tags]).groupby(["table_catalog", "table_schema", "table_name", "column_name"], dropna=False, group_keys=True).apply(aggregate_updates).reset_index().drop(columns=["level_4"])
+self.classification_result = pd.concat([classification_result, current_tags]).groupby(["table_catalog", "table_schema", "table_name", "column_name", "data_type"], dropna=False, group_keys=True).apply(aggregate_updates).reset_index().drop(columns=["level_5"])
 # when testing we don't have a 3-level namespace but we need
 # to make sure we get None instead of NaN
 self.classification_result.table_catalog = self.classification_result.table_catalog.astype(object)
@@ -94,7 +95,7 @@ def _get_classification_table_from_delta(self):
 self.spark.sql(f"CREATE DATABASE IF NOT EXISTS {catalog + '.' + schema}")
 self.spark.sql(
 f"""
-CREATE TABLE IF NOT EXISTS {self.classification_table_name} (table_catalog string, table_schema string, table_name string, column_name string, tag_name string, effective_timestamp timestamp, current boolean, end_timestamp timestamp)
+CREATE TABLE IF NOT EXISTS {self.classification_table_name} (table_catalog string, table_schema string, table_name string, column_name string, data_type string, tag_name string, effective_timestamp timestamp, current boolean, end_timestamp timestamp)
 """
 )
 logger.friendly(f"The classification table {self.classification_table_name} has been created.")
@@ -154,8 +155,7 @@ def _stage_updates(self, input_classification_pdf: pd.DataFrame):
 classification_pdf["to_be_set"] = classification_pdf.apply(lambda x: list(set(x["Tags to be published"]) - set(x["Current Tags"])), axis=1)
 classification_pdf["to_be_kept"] = classification_pdf.apply(lambda x: list(set(x["Tags to be published"]) & set(x["Current Tags"])), axis=1)

-self.staged_updates = pd.melt(classification_pdf, id_vars=["table_catalog", "table_schema", "table_name", "column_name"], value_vars=["to_be_unset", "to_be_set", "to_be_kept"], var_name="action", value_name="tag_name").explode("tag_name").dropna(subset=["tag_name"]).reset_index(drop=True)
+self.staged_updates = pd.melt(classification_pdf, id_vars=["table_catalog", "table_schema", "table_name", "column_name", "data_type"], value_vars=["to_be_unset", "to_be_set", "to_be_kept"], var_name="action", value_name="tag_name").explode("tag_name").dropna(subset=["tag_name"]).reset_index(drop=True)

 def inspect(self):
 self.inspection_tool = InspectionTool(self.classification_result, self.publish)
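To make the staging step above concrete, here is a minimal pandas sketch of how melt plus explode turns the to_be_unset / to_be_set / to_be_kept lists into one row per column, action, and tag; the table, column, and tag values are invented for illustration.

import pandas as pd

# Hypothetical classification row: one column with a tag list per action.
classification_pdf = pd.DataFrame([{
    "table_catalog": "main", "table_schema": "hr", "table_name": "employees",
    "column_name": "email", "data_type": "string",
    "to_be_unset": ["pii_name"], "to_be_set": ["pii_email"], "to_be_kept": [],
}])

staged_updates = (
    pd.melt(
        classification_pdf,
        id_vars=["table_catalog", "table_schema", "table_name", "column_name", "data_type"],
        value_vars=["to_be_unset", "to_be_set", "to_be_kept"],
        var_name="action",
        value_name="tag_name",
    )
    .explode("tag_name")          # one row per individual tag
    .dropna(subset=["tag_name"])  # empty lists explode to NaN and are dropped
    .reset_index(drop=True)
)
print(staged_updates[["column_name", "action", "tag_name"]])
# rows: (email, to_be_unset, pii_name), (email, to_be_set, pii_email)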
@@ -169,14 +169,14 @@ def publish(self, publish_uc_tags: bool):

 staged_updates_df = self.spark.createDataFrame(
 self.staged_updates,
-"table_catalog: string, table_schema: string, table_name: string, column_name: string, action: string, tag_name: string",
+"table_catalog: string, table_schema: string, table_name: string, column_name: string, data_type: string, action: string, tag_name: string",
 ).withColumn("effective_timestamp", func.current_timestamp())
 # merge using scd-typ2
 logger.friendly(f"Update classification table {self.classification_table_name}")

 self.classification_table.alias("target").merge(
 staged_updates_df.alias("source"),
-"target.table_catalog <=> source.table_catalog AND target.table_schema = source.table_schema AND target.table_name = source.table_name AND target.column_name = source.column_name AND target.tag_name = source.tag_name AND target.current = true",
+"target.table_catalog <=> source.table_catalog AND target.table_schema = source.table_schema AND target.table_name = source.table_name AND target.column_name = source.column_name AND target.data_type = source.data_type AND target.tag_name = source.tag_name AND target.current = true",

Review comment: Why do we need to match on data_type?

 ).whenMatchedUpdate(
 condition = "source.action = 'to_be_unset'",
 set={"current": "false", "end_timestamp": "source.effective_timestamp"}
@@ -186,6 +186,7 @@ def publish(self, publish_uc_tags: bool):
 "table_schema": "source.table_schema",
 "table_name": "source.table_name",
 "column_name": "source.column_name",
+"data_type": "source.data_type",
 "tag_name": "source.tag_name",
 "effective_timestamp": "source.effective_timestamp",
 "current": "true",
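For context on the merge keys above, here is a hedged sketch of the SCD type-2 pattern this code follows, written against the Delta Lake Python API; the publish_scd2 wrapper is hypothetical, and the conditions are copied from the diff, so details may differ from the actual publish() method.

from delta.tables import DeltaTable

def publish_scd2(spark, classification_table_name, staged_updates_df):
    """Hypothetical wrapper: close out unset tags, insert newly set ones, keep history."""
    target = DeltaTable.forName(spark, classification_table_name)
    (
        target.alias("target")
        .merge(
            staged_updates_df.alias("source"),
            "target.table_catalog <=> source.table_catalog "
            "AND target.table_schema = source.table_schema "
            "AND target.table_name = source.table_name "
            "AND target.column_name = source.column_name "
            "AND target.data_type = source.data_type "
            "AND target.tag_name = source.tag_name "
            "AND target.current = true",
        )
        # A tag staged as 'to_be_unset' closes the current row instead of deleting it.
        .whenMatchedUpdate(
            condition="source.action = 'to_be_unset'",
            set={"current": "false", "end_timestamp": "source.effective_timestamp"},
        )
        # Tags with no current row get a fresh 'current' row.
        .whenNotMatchedInsert(
            values={
                "table_catalog": "source.table_catalog",
                "table_schema": "source.table_schema",
                "table_name": "source.table_name",
                "column_name": "source.column_name",
                "data_type": "source.data_type",
                "tag_name": "source.tag_name",
                "effective_timestamp": "source.effective_timestamp",
                "current": "true",
            }
        )
        .execute()
    )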
@@ -66,6 +66,15 @@ def compile_msql(self, table_info: TableInfo) -> list[SQLRow]:
 temp_sql = msql
 for tagged_col in tagged_cols:
 temp_sql = temp_sql.replace(f"[{tagged_col.tag}]", tagged_col.name)
+# TODO: Can we avoid "replacing strings" for the different types in the future? This is due to the generation of MSQL. Maybe we should rather generate SQL directly from the search method...

Review comment: Yes, we should do that instead.

+if tagged_col.data_type == "array":
+# return a string of the array as value to be able to union later
+temp_sql = re.sub("(.*\'value\', )([^)]+)(\).*)", f"\g<1> array_join({tagged_col.name}, ', ') \g<3>", temp_sql)
+# modify the WHERE condition to work with arrays
+split_cond_sql = temp_sql.split("WHERE")
+if len(split_cond_sql) > 1:
+temp_sql = split_cond_sql[0] + "WHERE " + f"array_contains({tagged_col.name},{split_cond_sql[1].split('=')[1]})"

 sql_statements.append(SQLRow(table_info.catalog, table_info.database, table_info.table, temp_sql))

 return sql_statements
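To make the array handling above concrete, a small sketch of what the re.sub and the WHERE rewrite do to a generated statement; the input SQL string and column name are invented, and the real MSQL output may look different.

import re

col = "emails"  # hypothetical array<string> column
temp_sql = (
    "SELECT to_json(named_struct('value', emails)) AS value "
    "FROM dev.db.tab WHERE emails = 'a@b.com'"
)

# Replace the raw array reference with a joined string so rows can be UNIONed later.
temp_sql = re.sub(r"(.*'value', )([^)]+)(\).*)", rf"\g<1> array_join({col}, ', ') \g<3>", temp_sql)

# Rewrite the equality predicate into an array_contains() check.
split_cond_sql = temp_sql.split("WHERE")
if len(split_cond_sql) > 1:
    temp_sql = split_cond_sql[0] + "WHERE " + f"array_contains({col},{split_cond_sql[1].split('=')[1]})"

print(temp_sql)
# SELECT to_json(named_struct('value', array_join(emails, ', ') )) AS value
# FROM dev.db.tab WHERE array_contains(emails, 'a@b.com')   (whitespace approximate)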
@@ -77,9 +86,9 @@ def build(self, classified_result_pdf) -> list[SQLRow]:

 classified_cols = classified_result_pdf.copy()
 classified_cols = classified_cols[classified_cols['tag_name'].isin(self.tags)]
-classified_cols = classified_cols.groupby(['catalog', 'database', 'table', 'column']).aggregate(lambda x: list(x))[['tag_name']].reset_index()
+classified_cols = classified_cols.groupby(['catalog', 'database', 'table', 'column', 'data_type']).aggregate(lambda x: list(x))[['tag_name']].reset_index()

-classified_cols['col_tags'] = classified_cols[['column', 'tag_name']].apply(tuple, axis=1)
+classified_cols['col_tags'] = classified_cols[['column', 'data_type', 'tag_name']].apply(tuple, axis=1)
 df = classified_cols.groupby(['catalog', 'database', 'table']).aggregate(lambda x: list(x))[['col_tags']].reset_index()

 # Filter tables by matching filter

@@ -91,9 +100,9 @@ def build(self, classified_result_pdf) -> list[SQLRow]:
 [
 ColumnInfo(
 col[0], # col name
-"", # TODO
+col[1], # data type
 None, # TODO
-col[1] # Tags
+col[2] # Tags
 ) for col in row[3]
 ]
 ) for _, row in df.iterrows() if fnmatch(row[0], self.catalogs) and fnmatch(row[1], self.databases) and fnmatch(row[2], self.tables)]
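As a quick illustration of the grouping above, a minimal pandas sketch (with invented classification rows) of how the col_tags tuples now carry the data type alongside the column name and tag list.

import pandas as pd

# Invented classification result rows.
classified_cols = pd.DataFrame([
    {"catalog": "main", "database": "hr", "table": "employees",
     "column": "email", "data_type": "string", "tag_name": "pii_email"},
    {"catalog": "main", "database": "hr", "table": "employees",
     "column": "phones", "data_type": "array", "tag_name": "pii_phone"},
])

classified_cols = (
    classified_cols
    .groupby(["catalog", "database", "table", "column", "data_type"])
    .aggregate(lambda x: list(x))[["tag_name"]]
    .reset_index()
)
# Each col_tags entry is (column, data_type, [tags]); the data type is what later
# lets the SQL builder pick string vs. array handling.
classified_cols["col_tags"] = classified_cols[["column", "data_type", "tag_name"]].apply(tuple, axis=1)
print(classified_cols["col_tags"].tolist())
# [('email', 'string', ['pii_email']), ('phones', 'array', ['pii_phone'])]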
@@ -1,6 +1,9 @@
 from dataclasses import dataclass
 import pandas as pd
 from pyspark.sql import SparkSession
+from pyspark.sql.types import *
+from pyspark.sql.types import _parse_datatype_string
+from pyspark.sql.utils import ParseException
 from typing import Optional, List, Set

 from discoverx.common.helper import strip_margin, format_regex
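The new imports support parsing the information_schema data_type strings into Spark type objects; a quick hedged sketch of that pattern (note that _parse_datatype_string is a private pyspark helper and needs an active SparkSession, and the DDL strings below are invented).

from pyspark.sql import SparkSession
from pyspark.sql.types import _parse_datatype_string
from pyspark.sql.utils import ParseException

spark = SparkSession.builder.master("local[1]").getOrCreate()  # the parser uses the active session

for ddl in ["string", "array<string>", "struct<name:string,emails:array<string>>", "array<"]:
    try:
        data_type = _parse_datatype_string(ddl)
    except ParseException:
        data_type = None  # unparsable types are skipped by the scanner
    print(ddl, "->", data_type)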
@@ -26,12 +29,13 @@ class TableInfo:
 columns: list[ColumnInfo]

 def get_columns_by_tag(self, tag: str):
-return [TaggedColumn(col.name, tag) for col in self.columns if tag in col.tags]
+return [TaggedColumn(col.name, col.data_type, tag) for col in self.columns if tag in col.tags]


 @dataclass
 class TaggedColumn:
 name: str
+data_type: str

Review comment: We also need the

 tag: str
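A small sketch of how the updated dataclasses fit together; the ColumnInfo and TableInfo field lists are inferred from other hunks in this diff, and the sample values are invented.

from dataclasses import dataclass
from typing import List, Optional

@dataclass
class ColumnInfo:
    name: str
    data_type: str
    partition_index: Optional[int]
    tags: List[str]

@dataclass
class TaggedColumn:
    name: str
    data_type: str
    tag: str

@dataclass
class TableInfo:
    catalog: Optional[str]
    database: str
    table: str
    columns: List[ColumnInfo]

    def get_columns_by_tag(self, tag: str):
        # data_type now travels with the tagged column, so the SQL builder
        # can treat string and array columns differently.
        return [TaggedColumn(col.name, col.data_type, tag) for col in self.columns if tag in col.tags]

table = TableInfo(
    "main", "hr", "employees",
    [ColumnInfo("email", "string", None, ["pii_email"]),
     ColumnInfo("phones", "array<string>", None, ["pii_phone"])],
)
print(table.get_columns_by_tag("pii_email"))
# [TaggedColumn(name='email', data_type='string', tag='pii_email')]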
@@ -60,14 +64,13 @@ class ScanResult:

 @property
 def n_scanned_columns(self) -> int:
-return len(
-self.df[["catalog", "database", "table", "column"]].drop_duplicates()
-)
+return len(self.df[["catalog", "database", "table", "column"]].drop_duplicates())


 class Scanner:

 COLUMNS_TABLE_NAME = "system.information_schema.columns"
+COMPLEX_TYPES = {StructType, ArrayType}

 def __init__(
 self,
@@ -103,9 +106,7 @@ def _get_list_of_tables(self) -> List[TableInfo]:
 row["table_schema"],
 row["table_name"],
 [
-ColumnInfo(
-col["column_name"], col["data_type"], col["partition_index"], []
-)
+ColumnInfo(col["column_name"], col["data_type"], col["partition_index"], [])
 for col in row["table_columns"]
 ],
 )
@@ -124,9 +125,7 @@ def _get_table_list_sql(self):

 catalog_sql = f"""AND regexp_like(table_catalog, "^{self.catalogs.replace("*", ".*")}$")"""
 database_sql = f"""AND regexp_like(table_schema, "^{self.databases.replace("*", ".*")}$")"""
-table_sql = (
-f"""AND regexp_like(table_name, "^{self.tables.replace("*", ".*")}$")"""
-)
+table_sql = f"""AND regexp_like(table_name, "^{self.tables.replace("*", ".*")}$")"""

 sql = f"""
 SELECT
@@ -154,9 +153,7 @@ def _resolve_scan_content(self) -> ScanContent:

 def scan(self):

-logger.friendly(
-"""Ok, I'm going to scan your lakehouse for data that matches your rules."""
-)
+logger.friendly("""Ok, I'm going to scan your lakehouse for data that matches your rules.""")
 text = f"""
 This is what you asked for:
@@ -196,18 +193,50 @@ def scan(self):
 # Execute SQL and append result
 dfs.append(self.spark.sql(sql).toPandas())
 except Exception as e:
-logger.error(
-f"Error while scanning table '{table.catalog}.{table.database}.{table.table}': {e}"
-)
+logger.error(f"Error while scanning table '{table.catalog}.{table.database}.{table.table}': {e}")
 continue

 logger.debug("Finished lakehouse scanning task")

 if dfs:
-self.scan_result = ScanResult(df=pd.concat(dfs))
+self.scan_result = ScanResult(df=pd.concat(dfs).reset_index(drop=True))
 else:
 self.scan_result = ScanResult(df=pd.DataFrame())

+@staticmethod
+def backtick_col_name(col_name: str) -> str:
+col_name_splitted = col_name.split(".")
+return ".".join(["`" + col + "`" for col in col_name_splitted])

+def recursive_flatten_complex_type(self, col_name, schema, column_list):

Review comment: Oh, wow! This was more complex than I expected.

+if type(schema) in self.COMPLEX_TYPES:
+iterable = schema
+elif type(schema) is StructField:
+iterable = schema.dataType
+elif schema == StringType():
+column_list.append({"col_name": col_name, "type": "string"})
+return column_list
+else:
+return column_list

+if type(iterable) is StructType:
+for field in iterable:
+if type(field.dataType) == StringType:
+column_list.append(
+{"col_name": self.backtick_col_name(col_name + "." + field.name), "type": "string"}
+)
+elif type(field.dataType) in self.COMPLEX_TYPES:
+column_list = self.recursive_flatten_complex_type(col_name + "." + field.name, field, column_list)

Review comment: I think you should be appending to column_list instead of replacing column_list. Otherwise you overwrite previously appended string types.

+elif type(iterable) is MapType:
+if type(iterable.valueType) not in self.COMPLEX_TYPES:
+column_list.append({"col_name": self.backtick_col_name(col_name), "type": "map_values"})
+if type(iterable.keyType) not in self.COMPLEX_TYPES:
+column_list.append({"col_name": self.backtick_col_name(col_name), "type": "map_keys"})
+elif type(iterable) is ArrayType:
+column_list.append({"col_name": self.backtick_col_name(col_name), "type": "array"})

+return column_list
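As a usage illustration for the two helpers above, a standalone sketch that mirrors them as plain functions; the nested customer schema is invented, and only pyspark's type classes are needed, not a running cluster.

from pyspark.sql.types import ArrayType, MapType, StringType, StructField, StructType

COMPLEX_TYPES = {StructType, ArrayType}

def backtick_col_name(col_name: str) -> str:
    return ".".join(["`" + col + "`" for col in col_name.split(".")])

def flatten(col_name, schema, column_list):
    # Mirrors recursive_flatten_complex_type: collect scannable leaf columns
    # together with how they should be scanned (string / array / map_keys / map_values).
    if type(schema) in COMPLEX_TYPES:
        iterable = schema
    elif type(schema) is StructField:
        iterable = schema.dataType
    elif schema == StringType():
        column_list.append({"col_name": col_name, "type": "string"})
        return column_list
    else:
        return column_list

    if type(iterable) is StructType:
        for field in iterable:
            if type(field.dataType) == StringType:
                column_list.append({"col_name": backtick_col_name(col_name + "." + field.name), "type": "string"})
            elif type(field.dataType) in COMPLEX_TYPES:
                column_list = flatten(col_name + "." + field.name, field, column_list)
    elif type(iterable) is MapType:
        if type(iterable.valueType) not in COMPLEX_TYPES:
            column_list.append({"col_name": backtick_col_name(col_name), "type": "map_values"})
        if type(iterable.keyType) not in COMPLEX_TYPES:
            column_list.append({"col_name": backtick_col_name(col_name), "type": "map_keys"})
    elif type(iterable) is ArrayType:
        column_list.append({"col_name": backtick_col_name(col_name), "type": "array"})

    return column_list

schema = StructType([
    StructField("name", StringType()),
    StructField("emails", ArrayType(StringType())),
])
print(flatten("customer", schema, []))
# [{'col_name': '`customer`.`name`', 'type': 'string'},
#  {'col_name': '`customer`.`emails`', 'type': 'array'}]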
 def _rule_matching_sql(self, table_info: TableInfo):
 """
 Given a table and a set of rules this method will return a
@@ -223,53 +252,116 @@ def _rule_matching_sql(self, table_info: TableInfo):
 """

 expressions = [r for r in self.rule_list if r.type == RuleTypes.REGEX]
-cols = [c for c in table_info.columns if c.data_type.lower() == "string"]
 expr_pdf = pd.DataFrame([{"rule_name": r.name, "rule_definition": r.definition, "key": 0} for r in expressions])
+column_list = []
+for col in table_info.columns:
+try:
+data_type = _parse_datatype_string(col.data_type)
+except ParseException:
+data_type = None

-if not cols:
-raise Exception(
-f"There are no columns of type string to be scanned in {table_info.table}"
-)
+if data_type:
+self.recursive_flatten_complex_type(col.name, data_type, column_list)

+if len(column_list) == 0:
+raise Exception(f"There are no columns with supported types to be scanned in {table_info.table}")

 if not expressions:
 raise Exception(f"There are no rules to scan for.")

+string_cols = [col for col in column_list if col["type"] == "string"]

+sql_list = []
+if len(string_cols) > 0:
+sql_list.append(self.string_col_sql(string_cols, expressions, table_info))

+array_cols = [col for col in column_list if col["type"] == "array"]
+if len(array_cols) > 0:
+sql_list.append(self.array_col_sql(array_cols, expressions, table_info))

+all_sql = "\nUNION ALL \n".join(sql_list)
+return all_sql
+def string_col_sql(self, cols: List, expressions: List, table_info: TableInfo) -> str:
 catalog_str = f"{table_info.catalog}." if table_info.catalog else ""
 matching_columns = [
-f"INT(regexp_like(value, '{format_regex(r.definition)}')) AS `{r.name}`"
-for r in expressions
+f"INT(regexp_like(value, '{format_regex(r.definition)}')) AS `{r.name}`" for r in expressions
 ]
 matching_string = ",\n ".join(matching_columns)

-unpivot_expressions = ", ".join(
-[f"'{r.name}', `{r.name}`" for r in expressions]
-)
-unpivot_columns = ", ".join([f"'{c.name}', `{c.name}`" for c in cols])
+unpivot_expressions = ", ".join([f"'{r.name}', `{r.name}`" for r in expressions])
+unpivot_columns = ", ".join([f"'{c['col_name']}', '{c['type']}', {c['col_name']}" for c in cols])

 sql = f"""
-SELECT
-'{table_info.catalog}' as catalog,
-'{table_info.database}' as database,
-'{table_info.table}' as table,
-column,
-rule_name,
-(sum(value) / count(value)) as frequency
-FROM
-(
-SELECT column, stack({len(expressions)}, {unpivot_expressions}) as (rule_name, value)
-FROM
-(
-SELECT
-column,
-{matching_string}
-FROM (
-SELECT
-stack({len(cols)}, {unpivot_columns}) AS (column, value)
-FROM {catalog_str}{table_info.database}.{table_info.table}
-TABLESAMPLE ({self.sample_size} ROWS)
-)
-)
-)
-GROUP BY catalog, database, table, column, rule_name
-"""
+SELECT
+'{table_info.catalog}' as catalog,
+'{table_info.database}' as database,
+'{table_info.table}' as table,
+column,
+type,
+rule_name,
+(sum(value) / count(value)) as frequency
+FROM
+(
+SELECT column, type, stack({len(expressions)}, {unpivot_expressions}) as (rule_name, value)
+FROM
+(
+SELECT
+column,
+type,
+{matching_string}
+FROM (
+SELECT
+stack({len(cols)}, {unpivot_columns}) AS (column, type, value)
+FROM {catalog_str}{table_info.database}.{table_info.table}
+TABLESAMPLE ({self.sample_size} ROWS)
+)
+)
+)
+GROUP BY catalog, database, table, column, type, rule_name
+"""
 return strip_margin(sql)
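The stack() unpivot is the core trick in the generated query above; a tiny sketch (with invented, already-backticked column names) of the fragment it expands to for two string columns.

# Hypothetical flattened columns as produced by recursive_flatten_complex_type.
cols = [
    {"col_name": "`email`", "type": "string"},
    {"col_name": "`customer`.`name`", "type": "string"},
]
unpivot_columns = ", ".join(f"'{c['col_name']}', '{c['type']}', {c['col_name']}" for c in cols)
print(f"stack({len(cols)}, {unpivot_columns}) AS (column, type, value)")
# stack(2, '`email`', 'string', `email`, '`customer`.`name`', 'string', `customer`.`name`) AS (column, type, value)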
+def array_col_sql(self, cols: List, expressions: List, table_info: TableInfo) -> str:
+catalog_str = f"{table_info.catalog}." if table_info.catalog else ""
+matching_columns_sum = [
+f"size(filter(value, x -> x rlike '{r.definition}')) AS `{r.name}_sum`" for r in expressions
+]
+matching_columns_count = [
+f"size(value) AS `{r.name}_count`" for r in expressions
+]
+matching_columns = matching_columns_sum + matching_columns_count
+matching_string = ",\n ".join(matching_columns)

+unpivot_expressions = ", ".join([f"'{r.name}', `{r.name}_sum`, `{r.name}_count`" for r in expressions])
+unpivot_columns = ", ".join([f"'{c['col_name']}', '{c['type']}', {c['col_name']}" for c in cols])

+sql = f"""
+SELECT
+'{table_info.catalog}' as catalog,
+'{table_info.database}' as database,
+'{table_info.table}' as table,
+column,
+type,
+rule_name,
+(sum(value_sum) / sum(value_count)) as frequency
+FROM
+(
+SELECT column, type, stack({len(expressions)}, {unpivot_expressions}) as (rule_name, value_sum, value_count)
+FROM
+(
+SELECT
+column,
+type,
+{matching_string}
+FROM (
+SELECT
+stack({len(cols)}, {unpivot_columns}) AS (column, type, value)
+FROM {catalog_str}{table_info.database}.{table_info.table}
+TABLESAMPLE ({self.sample_size} ROWS)
+)
+)
+)
+GROUP BY catalog, database, table, column, type, rule_name
+"""
+return strip_margin(sql)
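To sanity-check the array matching expressions above, a hedged PySpark sketch with toy data and an invented email-like rule; it needs a local SparkSession, and the regex is illustrative only.

from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[1]").getOrCreate()

df = spark.createDataFrame(
    [(["a@b.com", "not-an-email"],), (["c@d.org"],)],
    "value: array<string>",
)

hits = df.selectExpr(
    "size(filter(value, x -> x rlike '^[^@]+@[^@]+$')) AS value_sum",  # matching elements per row
    "size(value) AS value_count",                                      # total elements per row
)
hits.show()

# frequency = sum(value_sum) / sum(value_count): here 2 matching elements out of 3.
hits.selectExpr("sum(value_sum) / sum(value_count) AS frequency").show()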
Review comment: I renamed the columns from the source SQL query, so you don't need this rename any more.