From 7834b8ee7022f2a7ea66a17ad155c48d9ab969c1 Mon Sep 17 00:00:00 2001 From: Erni Durdevic Date: Tue, 26 Sep 2023 11:47:38 +0200 Subject: [PATCH] Fixed some of the unit tests --- discoverx/msql.py | 1 + discoverx/scanner.py | 1 + tests/unit/conftest.py | 13 ++++++++++++- tests/unit/data/table_tags.csv | 4 ++++ tests/unit/explorer_test.py | 4 ++-- tests/unit/msql_test.py | 2 +- tests/unit/scanner_test.py | 31 ++++++++++++++----------------- 7 files changed, 35 insertions(+), 21 deletions(-) create mode 100644 tests/unit/data/table_tags.csv diff --git a/discoverx/msql.py b/discoverx/msql.py index d844b49..0fc20e6 100644 --- a/discoverx/msql.py +++ b/discoverx/msql.py @@ -97,6 +97,7 @@ def build(self, classified_result_pdf) -> list[SQLRow]: row[1], row[2], [ColumnInfo(col[0], "", None, col[1]) for col in row[3]], # col name # TODO # TODO # Classes + [], ) for _, row in df.iterrows() if fnmatch(row[0], self.catalogs) and fnmatch(row[1], self.schemas) and fnmatch(row[2], self.tables) diff --git a/discoverx/scanner.py b/discoverx/scanner.py index 325bdf8..5742550 100644 --- a/discoverx/scanner.py +++ b/discoverx/scanner.py @@ -199,6 +199,7 @@ def _get_list_of_tables(self) -> List[TableInfo]: ColumnInfo(col["column_name"], col["data_type"], col["partition_index"], []) for col in row["table_columns"] ], + [], ) for row in rows ] diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index ca89f94..7da4c4e 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -158,6 +158,17 @@ def sample_datasets(spark: SparkSession, request): f"CREATE TABLE IF NOT EXISTS default.columns USING delta LOCATION '{warehouse_dir}/columns' AS SELECT * FROM view_columns_mock" ) + # table_tags + test_file_path = module_path.parent / "data/table_tags.csv" + ( + spark.read.option("header", True) + .schema("catalog_name string,schema_name string,table_name string,tag_name string,tag_value string") + .csv(str(test_file_path.resolve())) + ).createOrReplaceTempView("table_tags_mock") + spark.sql( + f"CREATE TABLE IF NOT EXISTS default.table_tags USING delta LOCATION '{warehouse_dir}/table_tags' AS SELECT * FROM table_tags_mock" + ) + logging.info("Sample datasets created") yield None @@ -213,7 +224,7 @@ def monkeymodule(): @pytest.fixture(autouse=True, scope="module") def mock_uc_functionality(monkeymodule): # apply the monkeypatch for the columns_table_name - monkeymodule.setattr(DX, "COLUMNS_TABLE_NAME", "default.columns_mock") + monkeymodule.setattr(DX, "COLUMNS_TABLE_NAME", "default.columns") monkeymodule.setattr(DX, "INFORMATION_SCHEMA", "default") # mock classifier method _get_classification_table_from_delta as we don't diff --git a/tests/unit/data/table_tags.csv b/tests/unit/data/table_tags.csv new file mode 100644 index 0000000..501db03 --- /dev/null +++ b/tests/unit/data/table_tags.csv @@ -0,0 +1,4 @@ +catalog_name,schema_name,table_name,tag_name,tag_value +hive_metastore,default,tb_all_types,int_col,my_int_tag +,default,tb_1,pk, +,default,tb_1,pii,true \ No newline at end of file diff --git a/tests/unit/explorer_test.py b/tests/unit/explorer_test.py index 23f54f5..9776a35 100644 --- a/tests/unit/explorer_test.py +++ b/tests/unit/explorer_test.py @@ -8,12 +8,12 @@ # sample_table_info = TableInfo("catalog1", "schema1", "table1", []) @pytest.fixture() def info_fetcher(spark): - return InfoFetcher(spark=spark, columns_table_name="default.columns_mock") + return InfoFetcher(spark=spark, information_schema="default") @pytest.fixture() def sample_table_info(): - return TableInfo("catalog1", "schema1", "table1", []) + return TableInfo("catalog1", "schema1", "table1", [], []) def test_validate_from_components(): diff --git a/tests/unit/msql_test.py b/tests/unit/msql_test.py index 60604d8..5d45fc6 100644 --- a/tests/unit/msql_test.py +++ b/tests/unit/msql_test.py @@ -39,7 +39,7 @@ def classification_df(spark) -> pd.DataFrame: ColumnInfo("email_2", "string", None, ["dx_email"]), ColumnInfo("date", "string", 1, ["dx_date_partition"]), ] -table_info = TableInfo("catalog", "prod_db1", "tb1", columns) +table_info = TableInfo("catalog", "prod_db1", "tb1", columns, []) def test_msql_extracts_command(): diff --git a/tests/unit/scanner_test.py b/tests/unit/scanner_test.py index 681051d..2371524 100644 --- a/tests/unit/scanner_test.py +++ b/tests/unit/scanner_test.py @@ -36,6 +36,7 @@ def test_get_table_list(spark): ColumnInfo("interval_col", "INTERVAL", None, []), ColumnInfo("str_part_col", "STRING", 1, []), ], + [], ) ] @@ -48,7 +49,7 @@ def test_get_table_list(spark): tables="*_all_types", rule_filter="*", sample_size=100, - columns_table_name="default.columns_mock", + columns_table_name="default.columns", ) actual = scanner._get_list_of_tables() @@ -124,13 +125,11 @@ def test_get_table_list(spark): ) def test_generate_sql(spark, rules_input, expected): columns = [ColumnInfo("id", "number", False, []), ColumnInfo("name", "string", False, [])] - table_info = TableInfo("meta", "db", "tb", columns) + table_info = TableInfo("meta", "db", "tb", columns, []) rules = rules_input rules = Rules(custom_rules=rules) - scanner = Scanner( - spark, rules=rules, rule_filter="any_*", sample_size=100, columns_table_name="default.columns_mock" - ) + scanner = Scanner(spark, rules=rules, rule_filter="any_*", sample_size=100, columns_table_name="default.columns") actual = scanner._rule_matching_sql(table_info) logging.info("Generated SQL is: \n%s", actual) @@ -144,16 +143,14 @@ def test_sql_runs(spark): ColumnInfo("ip", "string", None, []), ColumnInfo("description", "string", None, []), ] - table_info = TableInfo(None, "default", "tb_1", columns) + table_info = TableInfo(None, "default", "tb_1", columns, []) rules = [ RegexRule(name="any_word", description="Any word", definition=r"\w+"), RegexRule(name="any_number", description="Any number", definition=r"\d+"), ] rules = Rules(custom_rules=rules) - scanner = Scanner( - spark, rules=rules, rule_filter="any_*", sample_size=100, columns_table_name="default.columns_mock" - ) + scanner = Scanner(spark, rules=rules, rule_filter="any_*", sample_size=100, columns_table_name="default.columns") actual = scanner._rule_matching_sql(table_info) logging.info("Generated SQL is: \n%s", actual) @@ -181,7 +178,7 @@ def test_scan_custom_rules(spark: SparkSession): ColumnInfo("ip", "string", False, []), ColumnInfo("description", "string", False, []), ] - table_list = [TableInfo(None, "default", "tb_1", columns)] + table_list = [TableInfo(None, "default", "tb_1", columns, [])] rules = [ RegexRule(name="any_word", description="Any word", definition=r"^\w*$"), RegexRule(name="any_number", description="Any number", definition=r"^\d*$"), @@ -194,7 +191,7 @@ def test_scan_custom_rules(spark: SparkSession): tables="tb_1", rule_filter="any_*", sample_size=100, - columns_table_name="default.columns_mock", + columns_table_name="default.columns", ) scanner.scan() @@ -217,7 +214,7 @@ def test_scan(spark: SparkSession): ) rules = Rules() - scanner = Scanner(spark, rules=rules, tables="tb_1", rule_filter="ip_*", columns_table_name="default.columns_mock") + scanner = Scanner(spark, rules=rules, tables="tb_1", rule_filter="ip_*", columns_table_name="default.columns") scanner.scan() assert scanner.scan_result.df.equals(expected) @@ -226,7 +223,7 @@ def test_scan(spark: SparkSession): def test_save_scan(spark: SparkSession): # save scan result rules = Rules() - scanner = Scanner(spark, rules=rules, tables="tb_1", rule_filter="ip_*", columns_table_name="default.columns_mock") + scanner = Scanner(spark, rules=rules, tables="tb_1", rule_filter="ip_*", columns_table_name="default.columns") scanner.scan() scan_table_name = "_discoverx.scan_result_test" scanner.scan_result.save(scan_table_name=scan_table_name) @@ -261,8 +258,8 @@ def test_save_scan(spark: SparkSession): def test_scan_non_existing_table_returns_none(spark: SparkSession): rules = Rules() - scanner = Scanner(spark, rules=rules, tables="tb_1", rule_filter="ip_*", columns_table_name="default.columns_mock") - result = scanner.scan_table(TableInfo("", "", "tb_non_existing", [])) + scanner = Scanner(spark, rules=rules, tables="tb_1", rule_filter="ip_*", columns_table_name="default.columns") + result = scanner.scan_table(TableInfo("", "", "tb_non_existing", [], [])) assert result is None @@ -270,9 +267,9 @@ def test_scan_non_existing_table_returns_none(spark: SparkSession): def test_scan_whatif_returns_none(spark: SparkSession): rules = Rules() scanner = Scanner( - spark, rules=rules, tables="tb_1", rule_filter="ip_*", columns_table_name="default.columns_mock", what_if=True + spark, rules=rules, tables="tb_1", rule_filter="ip_*", columns_table_name="default.columns", what_if=True ) - result = scanner.scan_table(TableInfo(None, "default", "tb_1", [])) + result = scanner.scan_table(TableInfo(None, "default", "tb_1", [], [])) assert result is None