Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dx 32 nested types #17

Draft
wants to merge 7 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions discoverx/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def above_threshold(self):
"database": "table_schema",
"table": "table_name",
"column": "column_name",
"type": "data_type",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I renamed the columns from the source SQL query, so you don't need this rename any more

"rule_name": "tag_name",
}
)
Expand Down Expand Up @@ -77,7 +78,7 @@ def aggregate_updates(pdf):

return pd.DataFrame(output)

self.classification_result = pd.concat([classification_result, current_tags]).groupby(["table_catalog", "table_schema", "table_name", "column_name"], dropna=False, group_keys=True).apply(aggregate_updates).reset_index().drop(columns=["level_4"])
self.classification_result = pd.concat([classification_result, current_tags]).groupby(["table_catalog", "table_schema", "table_name", "column_name", "data_type"], dropna=False, group_keys=True).apply(aggregate_updates).reset_index().drop(columns=["level_5"])
# when testing we don't have a 3-level namespace but we need
# to make sure we get None instead of NaN
self.classification_result.table_catalog = self.classification_result.table_catalog.astype(object)
Expand All @@ -94,7 +95,7 @@ def _get_classification_table_from_delta(self):
self.spark.sql(f"CREATE DATABASE IF NOT EXISTS {catalog + '.' + schema}")
self.spark.sql(
f"""
CREATE TABLE IF NOT EXISTS {self.classification_table_name} (table_catalog string, table_schema string, table_name string, column_name string, tag_name string, effective_timestamp timestamp, current boolean, end_timestamp timestamp)
CREATE TABLE IF NOT EXISTS {self.classification_table_name} (table_catalog string, table_schema string, table_name string, column_name string, data_type string, tag_name string, effective_timestamp timestamp, current boolean, end_timestamp timestamp)
"""
)
logger.friendly(f"The classification table {self.classification_table_name} has been created.")
Expand Down Expand Up @@ -154,8 +155,7 @@ def _stage_updates(self, input_classification_pdf: pd.DataFrame):
classification_pdf["to_be_set"] = classification_pdf.apply(lambda x: list(set(x["Tags to be published"]) - set(x["Current Tags"])), axis=1)
classification_pdf["to_be_kept"] = classification_pdf.apply(lambda x: list(set(x["Tags to be published"]) & set(x["Current Tags"])), axis=1)

self.staged_updates = pd.melt(classification_pdf, id_vars=["table_catalog", "table_schema", "table_name", "column_name"], value_vars=["to_be_unset", "to_be_set", "to_be_kept"], var_name="action", value_name="tag_name").explode("tag_name").dropna(subset=["tag_name"]).reset_index(drop=True)

self.staged_updates = pd.melt(classification_pdf, id_vars=["table_catalog", "table_schema", "table_name", "column_name", "data_type"], value_vars=["to_be_unset", "to_be_set", "to_be_kept"], var_name="action", value_name="tag_name").explode("tag_name").dropna(subset=["tag_name"]).reset_index(drop=True)

def inspect(self):
self.inspection_tool = InspectionTool(self.classification_result, self.publish)
Expand All @@ -169,14 +169,14 @@ def publish(self, publish_uc_tags: bool):

staged_updates_df = self.spark.createDataFrame(
self.staged_updates,
"table_catalog: string, table_schema: string, table_name: string, column_name: string, action: string, tag_name: string",
"table_catalog: string, table_schema: string, table_name: string, column_name: string, data_type: string, action: string, tag_name: string",
).withColumn("effective_timestamp", func.current_timestamp())
# merge using scd-typ2
logger.friendly(f"Update classification table {self.classification_table_name}")

self.classification_table.alias("target").merge(
staged_updates_df.alias("source"),
"target.table_catalog <=> source.table_catalog AND target.table_schema = source.table_schema AND target.table_name = source.table_name AND target.column_name = source.column_name AND target.tag_name = source.tag_name AND target.current = true",
"target.table_catalog <=> source.table_catalog AND target.table_schema = source.table_schema AND target.table_name = source.table_name AND target.column_name = source.column_name AND target.data_type = source.data_type AND target.tag_name = source.tag_name AND target.current = true",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need to match on data_type?

).whenMatchedUpdate(
condition = "source.action = 'to_be_unset'",
set={"current": "false", "end_timestamp": "source.effective_timestamp"}
Expand All @@ -186,6 +186,7 @@ def publish(self, publish_uc_tags: bool):
"table_schema": "source.table_schema",
"table_name": "source.table_name",
"column_name": "source.column_name",
"data_type": "source.data_type",
"tag_name": "source.tag_name",
"effective_timestamp": "source.effective_timestamp",
"current": "true",
Expand Down
1 change: 1 addition & 0 deletions discoverx/dx.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,7 @@ def _msql(self, msql: str, what_if: bool = False):
func.col("table_schema").alias("database"),
func.col("table_name").alias("table"),
func.col("column_name").alias("column"),
"data_type",
"tag_name",
).toPandas()
)
Expand Down
17 changes: 13 additions & 4 deletions discoverx/msql.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,15 @@ def compile_msql(self, table_info: TableInfo) -> list[SQLRow]:
temp_sql = msql
for tagged_col in tagged_cols:
temp_sql = temp_sql.replace(f"[{tagged_col.tag}]", tagged_col.name)
# TODO: Can we avoid "replacing strings" for the different types in the future? This is due to the generation of MSQL. Maybe we should rather generate SQL directly from the search method...
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, we should do that instead.

if tagged_col.data_type == "array":
# return a string of the array as value to be able to union later
temp_sql = re.sub("(.*\'value\', )([^)]+)(\).*)", f"\g<1> array_join({tagged_col.name}, ', ') \g<3>", temp_sql)
# modify the WHERE condition to work with arrays
split_cond_sql = temp_sql.split("WHERE")
if len(split_cond_sql) > 1:
temp_sql = split_cond_sql[0] + "WHERE " + f"array_contains({tagged_col.name},{split_cond_sql[1].split('=')[1]})"

sql_statements.append(SQLRow(table_info.catalog, table_info.database, table_info.table, temp_sql))

return sql_statements
Expand All @@ -77,9 +86,9 @@ def build(self, classified_result_pdf) -> list[SQLRow]:

classified_cols = classified_result_pdf.copy()
classified_cols = classified_cols[classified_cols['tag_name'].isin(self.tags)]
classified_cols = classified_cols.groupby(['catalog', 'database', 'table', 'column']).aggregate(lambda x: list(x))[['tag_name']].reset_index()
classified_cols = classified_cols.groupby(['catalog', 'database', 'table', 'column', 'data_type']).aggregate(lambda x: list(x))[['tag_name']].reset_index()

classified_cols['col_tags'] = classified_cols[['column', 'tag_name']].apply(tuple, axis=1)
classified_cols['col_tags'] = classified_cols[['column', 'data_type', 'tag_name']].apply(tuple, axis=1)
df = classified_cols.groupby(['catalog', 'database', 'table']).aggregate(lambda x: list(x))[['col_tags']].reset_index()

# Filter tables by matching filter
Expand All @@ -91,9 +100,9 @@ def build(self, classified_result_pdf) -> list[SQLRow]:
[
ColumnInfo(
col[0], # col name
"", # TODO
col[1], # data type
None, # TODO
col[1] # Tags
col[2] # Tags
) for col in row[3]
]
) for _, row in df.iterrows() if fnmatch(row[0], self.catalogs) and fnmatch(row[1], self.databases) and fnmatch(row[2], self.tables)]
Expand Down
196 changes: 144 additions & 52 deletions discoverx/scanner.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from dataclasses import dataclass
import pandas as pd
from pyspark.sql import SparkSession
from pyspark.sql.types import *
from pyspark.sql.types import _parse_datatype_string
from pyspark.sql.utils import ParseException
from typing import Optional, List, Set

from discoverx.common.helper import strip_margin, format_regex
Expand All @@ -26,12 +29,13 @@ class TableInfo:
columns: list[ColumnInfo]

def get_columns_by_tag(self, tag: str):
return [TaggedColumn(col.name, tag) for col in self.columns if tag in col.tags]
return [TaggedColumn(col.name, col.data_type, tag) for col in self.columns if tag in col.tags]


@dataclass
class TaggedColumn:
name: str
data_type: str
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We also need the full_data_type, which contains the full definition of the composed columns

tag: str


Expand Down Expand Up @@ -60,14 +64,13 @@ class ScanResult:

@property
def n_scanned_columns(self) -> int:
return len(
self.df[["catalog", "database", "table", "column"]].drop_duplicates()
)
return len(self.df[["catalog", "database", "table", "column"]].drop_duplicates())


class Scanner:

COLUMNS_TABLE_NAME = "system.information_schema.columns"
COMPLEX_TYPES = {StructType, ArrayType}

def __init__(
self,
Expand Down Expand Up @@ -103,9 +106,7 @@ def _get_list_of_tables(self) -> List[TableInfo]:
row["table_schema"],
row["table_name"],
[
ColumnInfo(
col["column_name"], col["data_type"], col["partition_index"], []
)
ColumnInfo(col["column_name"], col["data_type"], col["partition_index"], [])
for col in row["table_columns"]
],
)
Expand All @@ -124,9 +125,7 @@ def _get_table_list_sql(self):

catalog_sql = f"""AND regexp_like(table_catalog, "^{self.catalogs.replace("*", ".*")}$")"""
database_sql = f"""AND regexp_like(table_schema, "^{self.databases.replace("*", ".*")}$")"""
table_sql = (
f"""AND regexp_like(table_name, "^{self.tables.replace("*", ".*")}$")"""
)
table_sql = f"""AND regexp_like(table_name, "^{self.tables.replace("*", ".*")}$")"""

sql = f"""
SELECT
Expand Down Expand Up @@ -154,9 +153,7 @@ def _resolve_scan_content(self) -> ScanContent:

def scan(self):

logger.friendly(
"""Ok, I'm going to scan your lakehouse for data that matches your rules."""
)
logger.friendly("""Ok, I'm going to scan your lakehouse for data that matches your rules.""")
text = f"""
This is what you asked for:

Expand Down Expand Up @@ -196,18 +193,50 @@ def scan(self):
# Execute SQL and append result
dfs.append(self.spark.sql(sql).toPandas())
except Exception as e:
logger.error(
f"Error while scanning table '{table.catalog}.{table.database}.{table.table}': {e}"
)
logger.error(f"Error while scanning table '{table.catalog}.{table.database}.{table.table}': {e}")
continue

logger.debug("Finished lakehouse scanning task")

if dfs:
self.scan_result = ScanResult(df=pd.concat(dfs))
self.scan_result = ScanResult(df=pd.concat(dfs).reset_index(drop=True))
else:
self.scan_result = ScanResult(df=pd.DataFrame())

@staticmethod
def backtick_col_name(col_name: str) -> str:
col_name_splitted = col_name.split(".")
return ".".join(["`" + col + "`" for col in col_name_splitted])

def recursive_flatten_complex_type(self, col_name, schema, column_list):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, wow! This was more complex than I expected.

if type(schema) in self.COMPLEX_TYPES:
iterable = schema
elif type(schema) is StructField:
iterable = schema.dataType
elif schema == StringType():
column_list.append({"col_name": col_name, "type": "string"})
return column_list
else:
return column_list

if type(iterable) is StructType:
for field in iterable:
if type(field.dataType) == StringType:
column_list.append(
{"col_name": self.backtick_col_name(col_name + "." + field.name), "type": "string"}
)
elif type(field.dataType) in self.COMPLEX_TYPES:
column_list = self.recursive_flatten_complex_type(col_name + "." + field.name, field, column_list)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you should be appending to column_list instead of replacing column_list. Otherwise you overwrite previously appended string types

elif type(iterable) is MapType:
if type(iterable.valueType) not in self.COMPLEX_TYPES:
column_list.append({"col_name": self.backtick_col_name(col_name), "type": "map_values"})
if type(iterable.keyType) not in self.COMPLEX_TYPES:
column_list.append({"col_name": self.backtick_col_name(col_name), "type": "map_keys"})
elif type(iterable) is ArrayType:
column_list.append({"col_name": self.backtick_col_name(col_name), "type": "array"})

return column_list

def _rule_matching_sql(self, table_info: TableInfo):
"""
Given a table and a set of rules this method will return a
Expand All @@ -223,53 +252,116 @@ def _rule_matching_sql(self, table_info: TableInfo):
"""

expressions = [r for r in self.rule_list if r.type == RuleTypes.REGEX]
cols = [c for c in table_info.columns if c.data_type.lower() == "string"]
expr_pdf = pd.DataFrame([{"rule_name": r.name, "rule_definition": r.definition, "key": 0} for r in expressions])
column_list = []
for col in table_info.columns:
try:
data_type = _parse_datatype_string(col.data_type)
except ParseException:
data_type = None

if not cols:
raise Exception(
f"There are no columns of type string to be scanned in {table_info.table}"
)
if data_type:
self.recursive_flatten_complex_type(col.name, data_type, column_list)

if len(column_list) == 0:
raise Exception(f"There are no columns with supported types to be scanned in {table_info.table}")

if not expressions:
raise Exception(f"There are no rules to scan for.")

string_cols = [col for col in column_list if col["type"] == "string"]

sql_list = []
if len(string_cols) > 0:
sql_list.append(self.string_col_sql(string_cols, expressions, table_info))

array_cols = [col for col in column_list if col["type"] == "array"]
if len(array_cols) > 0:
sql_list.append(self.array_col_sql(array_cols, expressions, table_info))

all_sql = "\nUNION ALL \n".join(sql_list)
return all_sql

def string_col_sql(self, cols: List, expressions: List, table_info: TableInfo) -> str:
catalog_str = f"{table_info.catalog}." if table_info.catalog else ""
matching_columns = [
f"INT(regexp_like(value, '{format_regex(r.definition)}')) AS `{r.name}`"
for r in expressions
f"INT(regexp_like(value, '{format_regex(r.definition)}')) AS `{r.name}`" for r in expressions
]
matching_string = ",\n ".join(matching_columns)

unpivot_expressions = ", ".join(
[f"'{r.name}', `{r.name}`" for r in expressions]
)
unpivot_columns = ", ".join([f"'{c.name}', `{c.name}`" for c in cols])
unpivot_expressions = ", ".join([f"'{r.name}', `{r.name}`" for r in expressions])
unpivot_columns = ", ".join([f"'{c['col_name']}', '{c['type']}', {c['col_name']}" for c in cols])

sql = f"""
SELECT
'{table_info.catalog}' as catalog,
'{table_info.database}' as database,
'{table_info.table}' as table,
column,
rule_name,
(sum(value) / count(value)) as frequency
FROM
(
SELECT column, stack({len(expressions)}, {unpivot_expressions}) as (rule_name, value)
FROM
(
SELECT
column,
{matching_string}
FROM (
SELECT
stack({len(cols)}, {unpivot_columns}) AS (column, value)
FROM {catalog_str}{table_info.database}.{table_info.table}
TABLESAMPLE ({self.sample_size} ROWS)
SELECT
'{table_info.catalog}' as catalog,
'{table_info.database}' as database,
'{table_info.table}' as table,
column,
type,
rule_name,
(sum(value) / count(value)) as frequency
FROM
(
SELECT column, type, stack({len(expressions)}, {unpivot_expressions}) as (rule_name, value)
FROM
(
SELECT
column,
type,
{matching_string}
FROM (
SELECT
stack({len(cols)}, {unpivot_columns}) AS (column, type, value)
FROM {catalog_str}{table_info.database}.{table_info.table}
TABLESAMPLE ({self.sample_size} ROWS)
)
)
)
)
)
GROUP BY catalog, database, table, column, rule_name
"""
GROUP BY catalog, database, table, column, type, rule_name
"""
return strip_margin(sql)

def array_col_sql(self, cols: List, expressions: List, table_info: TableInfo) -> str:
catalog_str = f"{table_info.catalog}." if table_info.catalog else ""
matching_columns_sum = [
f"size(filter(value, x -> x rlike '{r.definition}')) AS `{r.name}_sum`" for r in expressions
]
matching_columns_count = [
f"size(value) AS `{r.name}_count`" for r in expressions
]
matching_columns = matching_columns_sum + matching_columns_count
matching_string = ",\n ".join(matching_columns)

unpivot_expressions = ", ".join([f"'{r.name}', `{r.name}_sum`, `{r.name}_count`" for r in expressions])
unpivot_columns = ", ".join([f"'{c['col_name']}', '{c['type']}', {c['col_name']}" for c in cols])

sql = f"""
SELECT
'{table_info.catalog}' as catalog,
'{table_info.database}' as database,
'{table_info.table}' as table,
column,
type,
rule_name,
(sum(value_sum) / sum(value_count)) as frequency
FROM
(
SELECT column, type, stack({len(expressions)}, {unpivot_expressions}) as (rule_name, value_sum, value_count)
FROM
(
SELECT
column,
type,
{matching_string}
FROM (
SELECT
stack({len(cols)}, {unpivot_columns}) AS (column, type, value)
FROM {catalog_str}{table_info.database}.{table_info.table}
TABLESAMPLE ({self.sample_size} ROWS)
)
)
)
GROUP BY catalog, database, table, column, type, rule_name
"""
return strip_margin(sql)
Loading