diff --git a/discoverx/classification.py b/discoverx/classification.py index 5c32632..f80628b 100644 --- a/discoverx/classification.py +++ b/discoverx/classification.py @@ -43,6 +43,7 @@ def above_threshold(self): "database": "table_schema", "table": "table_name", "column": "column_name", + "type": "data_type", "rule_name": "tag_name", } ) @@ -77,7 +78,7 @@ def aggregate_updates(pdf): return pd.DataFrame(output) - self.classification_result = pd.concat([classification_result, current_tags]).groupby(["table_catalog", "table_schema", "table_name", "column_name"], dropna=False, group_keys=True).apply(aggregate_updates).reset_index().drop(columns=["level_4"]) + self.classification_result = pd.concat([classification_result, current_tags]).groupby(["table_catalog", "table_schema", "table_name", "column_name", "data_type"], dropna=False, group_keys=True).apply(aggregate_updates).reset_index().drop(columns=["level_5"]) # when testing we don't have a 3-level namespace but we need # to make sure we get None instead of NaN self.classification_result.table_catalog = self.classification_result.table_catalog.astype(object) @@ -94,7 +95,7 @@ def _get_classification_table_from_delta(self): self.spark.sql(f"CREATE DATABASE IF NOT EXISTS {catalog + '.' + schema}") self.spark.sql( f""" - CREATE TABLE IF NOT EXISTS {self.classification_table_name} (table_catalog string, table_schema string, table_name string, column_name string, tag_name string, effective_timestamp timestamp, current boolean, end_timestamp timestamp) + CREATE TABLE IF NOT EXISTS {self.classification_table_name} (table_catalog string, table_schema string, table_name string, column_name string, data_type string, tag_name string, effective_timestamp timestamp, current boolean, end_timestamp timestamp) """ ) logger.friendly(f"The classification table {self.classification_table_name} has been created.") @@ -154,8 +155,7 @@ def _stage_updates(self, input_classification_pdf: pd.DataFrame): classification_pdf["to_be_set"] = classification_pdf.apply(lambda x: list(set(x["Tags to be published"]) - set(x["Current Tags"])), axis=1) classification_pdf["to_be_kept"] = classification_pdf.apply(lambda x: list(set(x["Tags to be published"]) & set(x["Current Tags"])), axis=1) - self.staged_updates = pd.melt(classification_pdf, id_vars=["table_catalog", "table_schema", "table_name", "column_name"], value_vars=["to_be_unset", "to_be_set", "to_be_kept"], var_name="action", value_name="tag_name").explode("tag_name").dropna(subset=["tag_name"]).reset_index(drop=True) - + self.staged_updates = pd.melt(classification_pdf, id_vars=["table_catalog", "table_schema", "table_name", "column_name", "data_type"], value_vars=["to_be_unset", "to_be_set", "to_be_kept"], var_name="action", value_name="tag_name").explode("tag_name").dropna(subset=["tag_name"]).reset_index(drop=True) def inspect(self): self.inspection_tool = InspectionTool(self.classification_result, self.publish) @@ -169,14 +169,14 @@ def publish(self, publish_uc_tags: bool): staged_updates_df = self.spark.createDataFrame( self.staged_updates, - "table_catalog: string, table_schema: string, table_name: string, column_name: string, action: string, tag_name: string", + "table_catalog: string, table_schema: string, table_name: string, column_name: string, data_type: string, action: string, tag_name: string", ).withColumn("effective_timestamp", func.current_timestamp()) # merge using scd-typ2 logger.friendly(f"Update classification table {self.classification_table_name}") 
self.classification_table.alias("target").merge( staged_updates_df.alias("source"), - "target.table_catalog <=> source.table_catalog AND target.table_schema = source.table_schema AND target.table_name = source.table_name AND target.column_name = source.column_name AND target.tag_name = source.tag_name AND target.current = true", + "target.table_catalog <=> source.table_catalog AND target.table_schema = source.table_schema AND target.table_name = source.table_name AND target.column_name = source.column_name AND target.data_type = source.data_type AND target.tag_name = source.tag_name AND target.current = true", ).whenMatchedUpdate( condition = "source.action = 'to_be_unset'", set={"current": "false", "end_timestamp": "source.effective_timestamp"} @@ -186,6 +186,7 @@ def publish(self, publish_uc_tags: bool): "table_schema": "source.table_schema", "table_name": "source.table_name", "column_name": "source.column_name", + "data_type": "source.data_type", "tag_name": "source.tag_name", "effective_timestamp": "source.effective_timestamp", "current": "true", diff --git a/discoverx/dx.py b/discoverx/dx.py index d23eee6..d9e5ef6 100644 --- a/discoverx/dx.py +++ b/discoverx/dx.py @@ -323,6 +323,7 @@ def _msql(self, msql: str, what_if: bool = False): func.col("table_schema").alias("database"), func.col("table_name").alias("table"), func.col("column_name").alias("column"), + "data_type", "tag_name", ).toPandas() ) diff --git a/discoverx/msql.py b/discoverx/msql.py index 9342c98..193b680 100644 --- a/discoverx/msql.py +++ b/discoverx/msql.py @@ -66,6 +66,15 @@ def compile_msql(self, table_info: TableInfo) -> list[SQLRow]: temp_sql = msql for tagged_col in tagged_cols: temp_sql = temp_sql.replace(f"[{tagged_col.tag}]", tagged_col.name) + # TODO: Can we avoid "replacing strings" for the different types in the future? This is due to the generation of MSQL. Maybe we should rather generate SQL directly from the search method... 
+ if tagged_col.data_type == "array": + # return a string of the array as value to be able to union later + temp_sql = re.sub("(.*\'value\', )([^)]+)(\).*)", f"\g<1> array_join({tagged_col.name}, ', ') \g<3>", temp_sql) + # modify the WHERE condition to work with arrays + split_cond_sql = temp_sql.split("WHERE") + if len(split_cond_sql) > 1: + temp_sql = split_cond_sql[0] + "WHERE " + f"array_contains({tagged_col.name},{split_cond_sql[1].split('=')[1]})" + sql_statements.append(SQLRow(table_info.catalog, table_info.database, table_info.table, temp_sql)) return sql_statements @@ -77,9 +86,9 @@ def build(self, classified_result_pdf) -> list[SQLRow]: classified_cols = classified_result_pdf.copy() classified_cols = classified_cols[classified_cols['tag_name'].isin(self.tags)] - classified_cols = classified_cols.groupby(['catalog', 'database', 'table', 'column']).aggregate(lambda x: list(x))[['tag_name']].reset_index() + classified_cols = classified_cols.groupby(['catalog', 'database', 'table', 'column', 'data_type']).aggregate(lambda x: list(x))[['tag_name']].reset_index() - classified_cols['col_tags'] = classified_cols[['column', 'tag_name']].apply(tuple, axis=1) + classified_cols['col_tags'] = classified_cols[['column', 'data_type', 'tag_name']].apply(tuple, axis=1) df = classified_cols.groupby(['catalog', 'database', 'table']).aggregate(lambda x: list(x))[['col_tags']].reset_index() # Filter tables by matching filter @@ -91,9 +100,9 @@ def build(self, classified_result_pdf) -> list[SQLRow]: [ ColumnInfo( col[0], # col name - "", # TODO + col[1], # data type None, # TODO - col[1] # Tags + col[2] # Tags ) for col in row[3] ] ) for _, row in df.iterrows() if fnmatch(row[0], self.catalogs) and fnmatch(row[1], self.databases) and fnmatch(row[2], self.tables)] diff --git a/discoverx/scanner.py b/discoverx/scanner.py index caf177e..8c60796 100644 --- a/discoverx/scanner.py +++ b/discoverx/scanner.py @@ -1,6 +1,9 @@ from dataclasses import dataclass import pandas as pd from pyspark.sql import SparkSession +from pyspark.sql.types import * +from pyspark.sql.types import _parse_datatype_string +from pyspark.sql.utils import ParseException from typing import Optional, List, Set from discoverx.common.helper import strip_margin, format_regex @@ -26,12 +29,13 @@ class TableInfo: columns: list[ColumnInfo] def get_columns_by_tag(self, tag: str): - return [TaggedColumn(col.name, tag) for col in self.columns if tag in col.tags] + return [TaggedColumn(col.name, col.data_type, tag) for col in self.columns if tag in col.tags] @dataclass class TaggedColumn: name: str + data_type: str tag: str @@ -60,14 +64,13 @@ class ScanResult: @property def n_scanned_columns(self) -> int: - return len( - self.df[["catalog", "database", "table", "column"]].drop_duplicates() - ) + return len(self.df[["catalog", "database", "table", "column"]].drop_duplicates()) class Scanner: COLUMNS_TABLE_NAME = "system.information_schema.columns" + COMPLEX_TYPES = {StructType, ArrayType} def __init__( self, @@ -103,9 +106,7 @@ def _get_list_of_tables(self) -> List[TableInfo]: row["table_schema"], row["table_name"], [ - ColumnInfo( - col["column_name"], col["data_type"], col["partition_index"], [] - ) + ColumnInfo(col["column_name"], col["data_type"], col["partition_index"], []) for col in row["table_columns"] ], ) @@ -124,9 +125,7 @@ def _get_table_list_sql(self): catalog_sql = f"""AND regexp_like(table_catalog, "^{self.catalogs.replace("*", ".*")}$")""" database_sql = f"""AND regexp_like(table_schema, "^{self.databases.replace("*", 
".*")}$")""" - table_sql = ( - f"""AND regexp_like(table_name, "^{self.tables.replace("*", ".*")}$")""" - ) + table_sql = f"""AND regexp_like(table_name, "^{self.tables.replace("*", ".*")}$")""" sql = f""" SELECT @@ -154,9 +153,7 @@ def _resolve_scan_content(self) -> ScanContent: def scan(self): - logger.friendly( - """Ok, I'm going to scan your lakehouse for data that matches your rules.""" - ) + logger.friendly("""Ok, I'm going to scan your lakehouse for data that matches your rules.""") text = f""" This is what you asked for: @@ -196,18 +193,50 @@ def scan(self): # Execute SQL and append result dfs.append(self.spark.sql(sql).toPandas()) except Exception as e: - logger.error( - f"Error while scanning table '{table.catalog}.{table.database}.{table.table}': {e}" - ) + logger.error(f"Error while scanning table '{table.catalog}.{table.database}.{table.table}': {e}") continue logger.debug("Finished lakehouse scanning task") if dfs: - self.scan_result = ScanResult(df=pd.concat(dfs)) + self.scan_result = ScanResult(df=pd.concat(dfs).reset_index(drop=True)) else: self.scan_result = ScanResult(df=pd.DataFrame()) + @staticmethod + def backtick_col_name(col_name: str) -> str: + col_name_splitted = col_name.split(".") + return ".".join(["`" + col + "`" for col in col_name_splitted]) + + def recursive_flatten_complex_type(self, col_name, schema, column_list): + if type(schema) in self.COMPLEX_TYPES: + iterable = schema + elif type(schema) is StructField: + iterable = schema.dataType + elif schema == StringType(): + column_list.append({"col_name": col_name, "type": "string"}) + return column_list + else: + return column_list + + if type(iterable) is StructType: + for field in iterable: + if type(field.dataType) == StringType: + column_list.append( + {"col_name": self.backtick_col_name(col_name + "." + field.name), "type": "string"} + ) + elif type(field.dataType) in self.COMPLEX_TYPES: + column_list = self.recursive_flatten_complex_type(col_name + "." 
+ field.name, field, column_list) + elif type(iterable) is MapType: + if type(iterable.valueType) not in self.COMPLEX_TYPES: + column_list.append({"col_name": self.backtick_col_name(col_name), "type": "map_values"}) + if type(iterable.keyType) not in self.COMPLEX_TYPES: + column_list.append({"col_name": self.backtick_col_name(col_name), "type": "map_keys"}) + elif type(iterable) is ArrayType: + column_list.append({"col_name": self.backtick_col_name(col_name), "type": "array"}) + + return column_list + def _rule_matching_sql(self, table_info: TableInfo): """ Given a table and a set of rules this method will return a @@ -223,53 +252,116 @@ def _rule_matching_sql(self, table_info: TableInfo): """ expressions = [r for r in self.rule_list if r.type == RuleTypes.REGEX] - cols = [c for c in table_info.columns if c.data_type.lower() == "string"] + expr_pdf = pd.DataFrame([{"rule_name": r.name, "rule_definition": r.definition, "key": 0} for r in expressions]) + column_list = [] + for col in table_info.columns: + try: + data_type = _parse_datatype_string(col.data_type) + except ParseException: + data_type = None - if not cols: - raise Exception( - f"There are no columns of type string to be scanned in {table_info.table}" - ) + if data_type: + self.recursive_flatten_complex_type(col.name, data_type, column_list) + + if len(column_list) == 0: + raise Exception(f"There are no columns with supported types to be scanned in {table_info.table}") if not expressions: raise Exception(f"There are no rules to scan for.") + string_cols = [col for col in column_list if col["type"] == "string"] + + sql_list = [] + if len(string_cols) > 0: + sql_list.append(self.string_col_sql(string_cols, expressions, table_info)) + + array_cols = [col for col in column_list if col["type"] == "array"] + if len(array_cols) > 0: + sql_list.append(self.array_col_sql(array_cols, expressions, table_info)) + + all_sql = "\nUNION ALL \n".join(sql_list) + return all_sql + + def string_col_sql(self, cols: List, expressions: List, table_info: TableInfo) -> str: catalog_str = f"{table_info.catalog}." 
if table_info.catalog else "" matching_columns = [ - f"INT(regexp_like(value, '{format_regex(r.definition)}')) AS `{r.name}`" - for r in expressions + f"INT(regexp_like(value, '{format_regex(r.definition)}')) AS `{r.name}`" for r in expressions ] matching_string = ",\n ".join(matching_columns) - unpivot_expressions = ", ".join( - [f"'{r.name}', `{r.name}`" for r in expressions] - ) - unpivot_columns = ", ".join([f"'{c.name}', `{c.name}`" for c in cols]) + unpivot_expressions = ", ".join([f"'{r.name}', `{r.name}`" for r in expressions]) + unpivot_columns = ", ".join([f"'{c['col_name']}', '{c['type']}', {c['col_name']}" for c in cols]) sql = f""" - SELECT - '{table_info.catalog}' as catalog, - '{table_info.database}' as database, - '{table_info.table}' as table, - column, - rule_name, - (sum(value) / count(value)) as frequency - FROM - ( - SELECT column, stack({len(expressions)}, {unpivot_expressions}) as (rule_name, value) - FROM - ( - SELECT - column, - {matching_string} - FROM ( - SELECT - stack({len(cols)}, {unpivot_columns}) AS (column, value) - FROM {catalog_str}{table_info.database}.{table_info.table} - TABLESAMPLE ({self.sample_size} ROWS) + SELECT + '{table_info.catalog}' as catalog, + '{table_info.database}' as database, + '{table_info.table}' as table, + column, + type, + rule_name, + (sum(value) / count(value)) as frequency + FROM + ( + SELECT column, type, stack({len(expressions)}, {unpivot_expressions}) as (rule_name, value) + FROM + ( + SELECT + column, + type, + {matching_string} + FROM ( + SELECT + stack({len(cols)}, {unpivot_columns}) AS (column, type, value) + FROM {catalog_str}{table_info.database}.{table_info.table} + TABLESAMPLE ({self.sample_size} ROWS) + ) + ) ) - ) - ) - GROUP BY catalog, database, table, column, rule_name - """ + GROUP BY catalog, database, table, column, type, rule_name + """ + return strip_margin(sql) + + def array_col_sql(self, cols: List, expressions: List, table_info: TableInfo) -> str: + catalog_str = f"{table_info.catalog}." 
if table_info.catalog else "" + matching_columns_sum = [ + f"size(filter(value, x -> x rlike '{r.definition}')) AS `{r.name}_sum`" for r in expressions + ] + matching_columns_count = [ + f"size(value) AS `{r.name}_count`" for r in expressions + ] + matching_columns = matching_columns_sum + matching_columns_count + matching_string = ",\n ".join(matching_columns) + + unpivot_expressions = ", ".join([f"'{r.name}', `{r.name}_sum`, `{r.name}_count`" for r in expressions]) + unpivot_columns = ", ".join([f"'{c['col_name']}', '{c['type']}', {c['col_name']}" for c in cols]) + sql = f""" + SELECT + '{table_info.catalog}' as catalog, + '{table_info.database}' as database, + '{table_info.table}' as table, + column, + type, + rule_name, + (sum(value_sum) / sum(value_count)) as frequency + FROM + ( + SELECT column, type, stack({len(expressions)}, {unpivot_expressions}) as (rule_name, value_sum, value_count) + FROM + ( + SELECT + column, + type, + {matching_string} + FROM ( + SELECT + stack({len(cols)}, {unpivot_columns}) AS (column, type, value) + FROM {catalog_str}{table_info.database}.{table_info.table} + TABLESAMPLE ({self.sample_size} ROWS) + ) + ) + ) + GROUP BY catalog, database, table, column, type, rule_name + """ return strip_margin(sql) diff --git a/tests/unit/classification_test.py b/tests/unit/classification_test.py index d4aee39..25dfd25 100644 --- a/tests/unit/classification_test.py +++ b/tests/unit/classification_test.py @@ -1,7 +1,6 @@ import pandas as pd from pandas.testing import assert_frame_equal import pytest -import numpy as np from discoverx.dx import DX from discoverx.dx import Scanner @@ -32,6 +31,7 @@ def test_classifier(spark): ], "table": ["tb_1", "tb_1", "tb_1", "tb_1", "tb_1", "tb_1", "tb_1", "tb_1", "tb_1"], "column": ["ip", "ip", "ip", "mac", "mac", "mac", "description", "description", "description"], + "type": ["string", "string", "string", "string", "string", "string", "string", "string", "string"], "rule_name": ["ip_v4", "ip_v6", "mac", "ip_v4", "ip_v6", "mac", "ip_v4", "ip_v6", "mac"], "frequency": [1.0, 0.0, 0.0, 0.0, 0.0, 0.97, 0.0, 0.0, 0.0], } @@ -48,6 +48,7 @@ def test_classifier(spark): "table_schema": ["default", "default"], "table_name": ["tb_1", "tb_1"], "column_name": ["ip", "mac"], + "data_type": ["string", "string"], "Current Tags": [[], []], "Detected Tags": [["ip_v4"], ["mac"]], "Tags to be published": [["ip_v4"], ["mac"]], @@ -74,6 +75,7 @@ def test_merging_scan_results(spark, mock_current_time): "database": ["default", "default", "default", "default", "default", "default"], "table": ["tb_1", "tb_1", "tb_1", "tb_1", "tb_1", "tb_1"], "column": ["ip", "ip", "mac", "mac", "description", "description"], + "type": ["string", "string", "string", "string", "string", "string"], "rule_name": ["ip_v4", "ip_v6", "ip_v4", "ip_v6", "ip_v4", "ip_v6"], "frequency": [1.0, 0.0, 0.0, 0.0, 0.0, 0.0], } @@ -90,6 +92,7 @@ def test_merging_scan_results(spark, mock_current_time): "table_schema": ["default"], "table_name": ["tb_1"], "column_name": ["ip"], + "data_type": ["string"], "tag_name": ["ip_v4"], "effective_timestamp": [pd.Timestamp(2023, 1, 1, 0)], "current": [True], @@ -112,6 +115,7 @@ def test_merging_scan_results(spark, mock_current_time): "table_schema": ["default"], "table_name": ["tb_1"], "column_name": ["ip"], + "data_type": ["string"], "tag_name": ["ip_v4"], "effective_timestamp": [pd.Timestamp(2023, 1, 1, 0)], "current": [True], @@ -129,6 +133,7 @@ def test_merging_scan_results(spark, mock_current_time): "database": ["default", "default", "default", 
"default", "default", "default", "default", "default"], "table": ["tb_1", "tb_1", "tb_1", "tb_1", "tb_1", "tb_1", "tb_1", "tb_1"], "column": ["ip", "ip", "ip6", "ip6", "mac", "mac", "description", "description"], + "type": ["string", "string", "string", "string", "string", "string", "string", "string"], "rule_name": ["ip_v4", "ip_v6", "ip_v4", "ip_v6", "ip_v4", "ip_v6", "ip_v4", "ip_v6"], "frequency": [1.0, 0.0, 0.0, 0.97, 0.0, 0.0, 0.0, 0.0], } @@ -147,6 +152,7 @@ def test_merging_scan_results(spark, mock_current_time): "table_schema": ["default", "default"], "table_name": ["tb_1", "tb_1"], "column_name": ["ip", "ip6"], + "data_type": ["string", "string"], "tag_name": ["ip_v4", "ip_v6"], "effective_timestamp": [current_time, current_time], "current": [True, True], @@ -169,6 +175,7 @@ def test_merging_scan_results(spark, mock_current_time): "database": ["default", "default", "default", "default", "default", "default", "default", "default"], "table": ["tb_1", "tb_1", "tb_1", "tb_1", "tb_1", "tb_1", "tb_1", "tb_1"], "column": ["ip", "ip", "ip6", "ip6", "mac", "mac", "description", "description"], + "type": ["string", "string", "string", "string", "string", "string", "string", "string"], "rule_name": ["ip_v4", "ip_v6", "ip_v4", "ip_v6", "ip_v4", "ip_v6", "ip_v4", "ip_v6"], "frequency": [0.7, 0.0, 0.0, 0.97, 0.0, 0.0, 0.0, 0.0], } @@ -186,6 +193,7 @@ def test_merging_scan_results(spark, mock_current_time): "table_schema": ["default", "default"], "table_name": ["tb_1", "tb_1"], "column_name": ["ip", "ip6"], + "data_type": ["string", "string"], "tag_name": ["ip_v4", "ip_v6"], "effective_timestamp": [current_time, current_time], "current": [True, True], @@ -235,6 +243,20 @@ def test_merging_scan_results(spark, mock_current_time): "description", "description", ], + "type": [ + "string", + "string", + "string", + "string", + "string", + "string", + "string", + "string", + "string", + "string", + "string", + "string", + ], "rule_name": [ "ip_v4", "ip_v6", @@ -272,6 +294,7 @@ def test_merging_scan_results(spark, mock_current_time): "table_schema": ["default", "default", "default", "default"], "table_name": ["tb_1", "tb_1", "tb_1", "tb_2"], "column_name": ["ip", "ip6", "ip6", "mac"], + "data_type": ["string", "string", "string", "string"], "tag_name": ["ip_v4", "ip_v6", "pii", "mac"], "effective_timestamp": [current_time, current_time, current_time, current_time], "current": [False, True, True, True], diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 86000dd..27c678f 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -8,13 +8,12 @@ import tempfile from dataclasses import dataclass from pathlib import Path -from typing import Iterator -from unittest.mock import patch import mlflow import pytest from delta import configure_spark_with_delta_pip from pyspark.sql import SparkSession +from pyspark.sql.types import * from discoverx.classification import DeltaTable from discoverx.dx import Classifier from discoverx.classification import func @@ -139,6 +138,40 @@ def sample_datasets(spark: SparkSession, request): ).createOrReplaceTempView("view_tb_1") spark.sql(f"CREATE TABLE IF NOT EXISTS default.tb_1 USING delta LOCATION '{warehouse_dir}/tb_1' AS SELECT * FROM view_tb_1 ") + # tb_2 + test_file_tb2_path = module_path.parent / "data/tb_2.json" + schema_json_example = ( + StructType() + .add( + "customer", + StructType() + .add("name", StringType(), True) + .add("id", IntegerType(), True) + .add( + "contact", + StructType() + .add( + "address", + StructType() + .add("street", 
StringType(), True) + .add("town", StringType(), True) + .add("postal_number", StringType(), True) + .add("country", StringType(), True) + .add("ips_used", ArrayType(StringType()), True), + True, + ) + .add("email", StringType()), + ) + .add("products_owned", ArrayType(StringType()), True) + .add("interactions", MapType(StringType(), StringType())), + True, + ) + .add("active", BooleanType(), True) + .add("categories", MapType(StringType(), StringType())) +) + spark.read.schema(schema_json_example).json(str(test_file_tb2_path.resolve())).createOrReplaceTempView("view_tb_2") + spark.sql( + f"CREATE TABLE IF NOT EXISTS default.tb_2 USING delta LOCATION '{warehouse_dir}/tb_2' AS SELECT * FROM view_tb_2 ") # columns_mock test_file_path = module_path.parent / "data/columns_mock.csv" (spark @@ -157,6 +190,7 @@ def sample_datasets(spark: SparkSession, request): logging.info("Test session finished, removing sample datasets") spark.sql("DROP TABLE IF EXISTS default.tb_1") + spark.sql("DROP TABLE IF EXISTS default.tb_2") spark.sql("DROP TABLE IF EXISTS default.columns_mock") if Path(warehouse_dir).exists(): shutil.rmtree(warehouse_dir) @@ -220,7 +254,7 @@ def get_classification_table_mock(self): self.spark.sql(f"CREATE DATABASE IF NOT EXISTS {schema}") self.spark.sql( f""" - CREATE TABLE IF NOT EXISTS {schema + '.' + table} (table_catalog string, table_schema string, table_name string, column_name string, tag_name string, effective_timestamp timestamp, current boolean, end_timestamp timestamp) USING DELTA + CREATE TABLE IF NOT EXISTS {schema + '.' + table} (table_catalog string, table_schema string, table_name string, column_name string, data_type string, tag_name string, effective_timestamp timestamp, current boolean, end_timestamp timestamp) USING DELTA """ ) return DeltaTable.forName(self.spark, self.classification_table_name) diff --git a/tests/unit/data/columns_mock.csv b/tests/unit/data/columns_mock.csv index b4d5dbe..6980d41 100644 --- a/tests/unit/data/columns_mock.csv +++ b/tests/unit/data/columns_mock.csv @@ -23,3 +23,6 @@ hive_metastore,default,tb_all_types,str_part_col,STRING,1 ,default,tb_1,ip,STRING, ,default,tb_1,mac,STRING, ,default,tb_1,description,STRING, +,default,tb_2,active,BOOLEAN, +,default,tb_2,categories,"map<string,string>", +,default,tb_2,customer,"struct<name:string,id:int,contact:struct<address:struct<street:string,town:string,postal_number:string,country:string,ips_used:array<string>>,email:string>,products_owned:array<string>>", diff --git a/tests/unit/data/tb_2.json b/tests/unit/data/tb_2.json new file mode 100644 index 0000000..384b73f --- /dev/null +++ b/tests/unit/data/tb_2.json @@ -0,0 +1,3 @@ +{"customer": {"name": "AAA BBBB", "id": 1, "contact": {"address": {"street": "AAA street 11", "town": "AAA town", "postal_number": "111333", "country": "AAA country", "ips_used": []}, "email": "aaa.bbb@aaa.com"}, "products_owned": ["product1", "product2", "product10"], "interactions": {"service": "test aaa", "shop": "test shop aaa"}}, "active": true, "categories": {"cat1": "D"}} +{"customer": {"name": "BBB CCCC", "id": 2, "contact": {"address": {"street": "BBB street 12", "town": "BBB town", "postal_number": "111233", "country": "BBB country", "ips_used": ["102.2.1.1", "103.3.1.1"]}, "email": "bbb.ccc@bbb.com"}, "products_owned": ["product1", "product10"], "interactions": {"service": "test bbb", "request": "test r bbb"}}, "active": false, "categories": {"cat1": "A", "cat2": "B", "cat3": "C"}} +{"customer": {"name": "CCC DDDD", "id": 3, "contact": {"address": {"street": "CCC street 13", "town": "CCC town", "postal_number": "111244", "country": "CCC country", "ips_used": ["102.1.1.1", "1.2.3.4", "104.1.1.1"]}, "email": 
"ccc.ddd@ccc.com"}, "products_owned": ["product11"], "interactions": {}}, "active": true, "categories": {"cat1": "A", "cat2": "A"}} \ No newline at end of file diff --git a/tests/unit/dx_test.py b/tests/unit/dx_test.py index c3b17f1..ef47221 100644 --- a/tests/unit/dx_test.py +++ b/tests/unit/dx_test.py @@ -7,9 +7,9 @@ @pytest.fixture(scope="module", name="dx_ip") -def scan_ip_in_tb1(spark, mock_uc_functionality): +def scan_ip_in_tb(spark, mock_uc_functionality): dx = DX(spark=spark, classification_table_name="_discoverx.tags") - dx.scan(from_tables="*.*.tb_1", rules="ip_*") + dx.scan(from_tables="*.*.tb_*", rules="ip_*") dx.publish() yield dx @@ -50,87 +50,108 @@ def test_scan_and_msql(spark, dx_ip): except Exception as e: pytest.fail(f"Test failed with exception {e}") + def test_search(spark, dx_ip: DX): # search a specific term and auto-detect matching tags/rules result = dx_ip.search("1.2.3.4").collect() - assert result[0].table == 'tb_1' - assert result[0].search_result.ip_v4.column == 'ip' + assert result[0].table == "tb_1" + assert result[0].search_result.ip_v4.column == "ip" + assert result[1].table == "tb_2" + assert result[1].search_result.ip_v4.column == "`customer`.`contact`.`address`.`ips_used`" # search all records for specific tag - result_tags_only = dx_ip.search(by_tags='ip_v4') - assert {row.search_result.ip_v4.value for row in result_tags_only.collect()} == {"1.2.3.4", "3.4.5.60"} + result_tags_only = dx_ip.search(by_tags="ip_v4") + assert {row.search_result.ip_v4.value for row in result_tags_only.collect()} == { + "", + "1.2.3.4", + "102.1.1.1, 1.2.3.4, 104.1.1.1", + "102.2.1.1, 103.3.1.1", + "3.4.5.60", + } # specify catalog, database and table - result_tags_namespace = dx_ip.search(by_tags='ip_v4', from_tables="*.default.tb_*") + result_tags_namespace = dx_ip.search(by_tags="ip_v4", from_tables="*.default.tb_1") assert {row.search_result.ip_v4.value for row in result_tags_namespace.collect()} == {"1.2.3.4", "3.4.5.60"} # search specific term for list of specified tags - result_term_tag = dx_ip.search(search_term="3.4.5.60", by_tags=['ip_v4']).collect() - assert result_term_tag[0].table == 'tb_1' + result_term_tag = dx_ip.search(search_term="3.4.5.60", by_tags=["ip_v4"]).collect() + assert result_term_tag[0].table == "tb_1" assert result_term_tag[0].search_result.ip_v4.value == "3.4.5.60" with pytest.raises(ValueError) as no_tags_no_terms_error: dx_ip.search() - assert no_tags_no_terms_error.value.args[0] == "Neither search_term nor by_tags have been provided. At least one of them need to be specified." + assert ( + no_tags_no_terms_error.value.args[0] + == "Neither search_term nor by_tags have been provided. At least one of them need to be specified." + ) with pytest.raises(ValueError) as list_with_ints: - dx_ip.search(by_tags=[1, 3, 'ip']) - assert list_with_ints.value.args[0] == "The provided by_tags [1, 3, 'ip'] have the wrong type. Please provide either a str or List[str]." + dx_ip.search(by_tags=[1, 3, "ip"]) + assert ( + list_with_ints.value.args[0] + == "The provided by_tags [1, 3, 'ip'] have the wrong type. Please provide either a str or List[str]." + ) with pytest.raises(ValueError) as single_bool: dx_ip.search(by_tags=True) - assert single_bool.value.args[0] == "The provided by_tags True have the wrong type. Please provide either a str or List[str]." + assert ( + single_bool.value.args[0] + == "The provided by_tags True have the wrong type. Please provide either a str or List[str]." 
+ ) def test_select_by_tag(spark, dx_ip): # search a specific term and auto-detect matching tags/rules result = dx_ip.select_by_tags(from_tables="*.default.tb_*", by_tags="ip_v4").collect() - assert result[0].table == 'tb_1' - assert result[0].tagged_columns.ip_v4.column == 'ip' + assert result[0].table == "tb_1" + assert result[0].tagged_columns.ip_v4.column == "ip" result = dx_ip.select_by_tags(from_tables="*.default.tb_*", by_tags=["ip_v4"]).collect() - assert result[0].table == 'tb_1' - assert result[0].tagged_columns.ip_v4.column == 'ip' + assert result[0].table == "tb_1" + assert result[0].tagged_columns.ip_v4.column == "ip" with pytest.raises(ValueError): dx_ip.select_by_tags(from_tables="*.default.tb_*") - + with pytest.raises(ValueError): - dx_ip.select_by_tags(from_tables="*.default.tb_*", by_tags=[1, 3, 'ip']) - + dx_ip.select_by_tags(from_tables="*.default.tb_*", by_tags=[1, 3, "ip"]) + with pytest.raises(ValueError): dx_ip.select_by_tags(from_tables="*.default.tb_*", by_tags=True) with pytest.raises(ValueError): dx_ip.select_by_tags(from_tables="invalid from", by_tags="email") - + + # @pytest.mark.skip(reason="Delete is only working with v2 tables. Needs investigation") def test_delete_by_tag(spark, dx_ip): # search a specific term and auto-detect matching tags/rules result = dx_ip.delete_by_tag(from_tables="*.default.tb_*", by_tag="ip_v4", values="9.9.9.9") - assert result is None # Nothing should be executed + assert result is None # Nothing should be executed - result = dx_ip.delete_by_tag(from_tables="*.default.tb_*", by_tag="ip_v4", values="9.9.9.9", yes_i_am_sure=True).collect() - assert result[0].table == 'tb_1' + result = dx_ip.delete_by_tag( + from_tables="*.default.tb_*", by_tag="ip_v4", values="9.9.9.9", yes_i_am_sure=True + ).collect() + assert result[0].table == "tb_1" with pytest.raises(ValueError): dx_ip.delete_by_tag(from_tables="*.default.tb_*", by_tag="x") with pytest.raises(ValueError): dx_ip.delete_by_tag(from_tables="*.default.tb_*", values="x") - + with pytest.raises(ValueError): - dx_ip.delete_by_tag(from_tables="*.default.tb_*", by_tag=['ip'], values="x") - + dx_ip.delete_by_tag(from_tables="*.default.tb_*", by_tag=["ip"], values="x") + with pytest.raises(ValueError): dx_ip.delete_by_tag(from_tables="*.default.tb_*", by_tag=True, values="x") with pytest.raises(ValueError): dx_ip.delete_by_tag(from_tables="invalid from", by_tag="email", values="x") - + # test multiple tags def test_search_multiple(spark, mock_uc_functionality): @@ -140,8 +161,8 @@ def test_search_multiple(spark, mock_uc_functionality): # search a specific term and auto-detect matching tags/rules result = dx.search(by_tags=["ip_v4", "mac"]) - assert result.collect()[0].table == 'tb_1' - assert result.collect()[0].search_result.ip_v4.column == 'ip' - assert result.collect()[0].search_result.mac.column == 'mac' + assert result.collect()[0].table == "tb_1" + assert result.collect()[0].search_result.ip_v4.column == "ip" + assert result.collect()[0].search_result.mac.column == "mac" spark.sql("DROP TABLE IF EXISTS _discoverx.tags") diff --git a/tests/unit/scanner_test.py b/tests/unit/scanner_test.py index 73d660b..8f3fb86 100644 --- a/tests/unit/scanner_test.py +++ b/tests/unit/scanner_test.py @@ -194,20 +194,36 @@ def test_scan_custom_rules(spark: SparkSession): def test_scan(spark: SparkSession): expected = pd.DataFrame( [ - ["None", "default", "tb_1", "ip", "ip_v4", 1.0], - ["None", "default", "tb_1", "ip", "ip_v6", 0.0], - ["None", "default", "tb_1", "mac", "ip_v4", 0.0], - ["None", 
"default", "tb_1", "mac", "ip_v6", 0.0], - ["None", "default", "tb_1", "description", "ip_v4", 0.0], - ["None", "default", "tb_1", "description", "ip_v6", 0.0], + ["None", "default", "tb_1", "ip", "string", "ip_v4", 1.0], + ["None", "default", "tb_1", "ip", "string", "ip_v6", 0.0], + ["None", "default", "tb_1", "mac", "string", "ip_v4", 0.0], + ["None", "default", "tb_1", "mac", "string", "ip_v6", 0.0], + ["None", "default", "tb_1", "description", "string", "ip_v4", 0.0], + ["None", "default", "tb_1", "description", "string", "ip_v6", 0.0], + ["None", "default", "tb_2", "`customer`.`name`", "string", "ip_v4", 0.0], + ["None", "default", "tb_2", "`customer`.`name`", "string", "ip_v6", 0.0], + ["None", "default", "tb_2", "`customer`.`contact`.`address`.`street`", "string", "ip_v4", 0.0], + ["None", "default", "tb_2", "`customer`.`contact`.`address`.`street`", "string", "ip_v6", 0.0], + ["None", "default", "tb_2", "`customer`.`contact`.`address`.`town`", "string", "ip_v4", 0.0], + ["None", "default", "tb_2", "`customer`.`contact`.`address`.`town`", "string", "ip_v6", 0.0], + ["None", "default", "tb_2", "`customer`.`contact`.`address`.`postal_number`", "string", "ip_v4", 0.0], + ["None", "default", "tb_2", "`customer`.`contact`.`address`.`postal_number`", "string", "ip_v6", 0.0], + ["None", "default", "tb_2", "`customer`.`contact`.`address`.`country`", "string", "ip_v4", 0.0], + ["None", "default", "tb_2", "`customer`.`contact`.`address`.`country`", "string", "ip_v6", 0.0], + ["None", "default", "tb_2", "`customer`.`contact`.`email`", "string", "ip_v4", 0.0], + ["None", "default", "tb_2", "`customer`.`contact`.`email`", "string", "ip_v6", 0.0], + ["None", "default", "tb_2", "`customer`.`contact`.`address`.`ips_used`", "array", "ip_v4", 1.0], + ["None", "default", "tb_2", "`customer`.`contact`.`address`.`ips_used`", "array", "ip_v6", 0.0], + ["None", "default", "tb_2", "`customer`.`products_owned`", "array", "ip_v4", 0.0], + ["None", "default", "tb_2", "`customer`.`products_owned`", "array", "ip_v6", 0.0], ], - columns=["catalog", "database", "table", "column", "rule_name", "frequency"], + columns=["catalog", "database", "table", "column", "type", "rule_name", "frequency"], ) rules = Rules() MockedScanner = Scanner MockedScanner.COLUMNS_TABLE_NAME = "default.columns_mock" - scanner = MockedScanner(spark, rules=rules, tables="tb_1", rule_filter="ip_*") + scanner = MockedScanner(spark, rules=rules, tables="tb_*", rule_filter="ip_*") scanner.scan() assert scanner.scan_result.df.equals(expected)