Commit 7834b8e
Fixed some of the unit tests
edurdevic committed Sep 26, 2023
1 parent 73b62b0 commit 7834b8e
Showing 7 changed files with 35 additions and 21 deletions.
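
A note on the recurring pattern in the diffs below: the TableInfo construction in msql.py and scanner.py, and every TableInfo(...) call in the tests, gain a trailing empty list, while conftest.py adds a new table_tags fixture. This suggests the constructor picked up a new final field for table tags. A minimal sketch of the assumed shape, with field names inferred from the call sites rather than copied from the repository:

from dataclasses import dataclass, field
from typing import Optional

@dataclass
class TableInfo:
    # Positional order inferred from calls such as
    # TableInfo("catalog", "prod_db1", "tb1", columns, []).
    catalog: Optional[str]
    schema: str
    table: str
    columns: list                              # ColumnInfo entries
    tags: list = field(default_factory=list)   # assumed new field behind the added "[]"

The call sites touched by this commit all pass the new argument explicitly as an empty list, so the existing tests only satisfy the wider signature and keep their previous behaviour.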
1 change: 1 addition & 0 deletions discoverx/msql.py
@@ -97,6 +97,7 @@ def build(self, classified_result_pdf) -> list[SQLRow]:
row[1],
row[2],
[ColumnInfo(col[0], "", None, col[1]) for col in row[3]], # col name # TODO # TODO # Classes
+ [],
)
for _, row in df.iterrows()
if fnmatch(row[0], self.catalogs) and fnmatch(row[1], self.schemas) and fnmatch(row[2], self.tables)
1 change: 1 addition & 0 deletions discoverx/scanner.py
@@ -199,6 +199,7 @@ def _get_list_of_tables(self) -> List[TableInfo]:
ColumnInfo(col["column_name"], col["data_type"], col["partition_index"], [])
for col in row["table_columns"]
],
+ [],
)
for row in rows
]
13 changes: 12 additions & 1 deletion tests/unit/conftest.py
@@ -158,6 +158,17 @@ def sample_datasets(spark: SparkSession, request):
f"CREATE TABLE IF NOT EXISTS default.columns USING delta LOCATION '{warehouse_dir}/columns' AS SELECT * FROM view_columns_mock"
)

+ # table_tags
+ test_file_path = module_path.parent / "data/table_tags.csv"
+ (
+     spark.read.option("header", True)
+     .schema("catalog_name string,schema_name string,table_name string,tag_name string,tag_value string")
+     .csv(str(test_file_path.resolve()))
+ ).createOrReplaceTempView("table_tags_mock")
+ spark.sql(
+     f"CREATE TABLE IF NOT EXISTS default.table_tags USING delta LOCATION '{warehouse_dir}/table_tags' AS SELECT * FROM table_tags_mock"
+ )
+
logging.info("Sample datasets created")

yield None
@@ -213,7 +224,7 @@ def monkeymodule():
@pytest.fixture(autouse=True, scope="module")
def mock_uc_functionality(monkeymodule):
# apply the monkeypatch for the columns_table_name
- monkeymodule.setattr(DX, "COLUMNS_TABLE_NAME", "default.columns_mock")
+ monkeymodule.setattr(DX, "COLUMNS_TABLE_NAME", "default.columns")
monkeymodule.setattr(DX, "INFORMATION_SCHEMA", "default")

# mock classifier method _get_classification_table_from_delta as we don't
4 changes: 4 additions & 0 deletions tests/unit/data/table_tags.csv
@@ -0,0 +1,4 @@
+ catalog_name,schema_name,table_name,tag_name,tag_value
+ hive_metastore,default,tb_all_types,int_col,my_int_tag
+ ,default,tb_1,pk,
+ ,default,tb_1,pii,true
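
For reference, a quick way to sanity-check the new fixture from a test. This is a sketch, not part of the commit; it assumes the spark fixture from conftest.py above and the three data rows in table_tags.csv:

def test_table_tags_mock_is_populated(spark):
    # default.table_tags is created by sample_datasets from data/table_tags.csv,
    # so the mock should expose exactly the three tag rows defined in that file.
    rows = spark.sql(
        "SELECT catalog_name, schema_name, table_name, tag_name, tag_value FROM default.table_tags"
    ).collect()
    assert len(rows) == 3
    assert {r.table_name for r in rows} == {"tb_all_types", "tb_1"}
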
4 changes: 2 additions & 2 deletions tests/unit/explorer_test.py
@@ -8,12 +8,12 @@
# sample_table_info = TableInfo("catalog1", "schema1", "table1", [])
@pytest.fixture()
def info_fetcher(spark):
- return InfoFetcher(spark=spark, columns_table_name="default.columns_mock")
+ return InfoFetcher(spark=spark, information_schema="default")


@pytest.fixture()
def sample_table_info():
- return TableInfo("catalog1", "schema1", "table1", [])
+ return TableInfo("catalog1", "schema1", "table1", [], [])


def test_validate_from_components():
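
The info_fetcher fixture change above is the main API shift in this commit: InfoFetcher is now constructed with an information_schema instead of a fully qualified columns_table_name. A plausible reading, assuming the fetcher simply prefixes its metadata tables with that schema (an assumption based on this diff, not on the library source), is that "default" resolves to the same mock tables created in conftest.py. The helper below is hypothetical and only illustrates that mapping:

def resolve_metadata_tables(information_schema: str) -> dict[str, str]:
    # Hypothetical helper: maps an information schema to the table names that
    # the test fixtures create (default.columns and default.table_tags).
    return {
        "columns": f"{information_schema}.columns",
        "table_tags": f"{information_schema}.table_tags",
    }

assert resolve_metadata_tables("default") == {
    "columns": "default.columns",
    "table_tags": "default.table_tags",
}
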
2 changes: 1 addition & 1 deletion tests/unit/msql_test.py
@@ -39,7 +39,7 @@ def classification_df(spark) -> pd.DataFrame:
ColumnInfo("email_2", "string", None, ["dx_email"]),
ColumnInfo("date", "string", 1, ["dx_date_partition"]),
]
- table_info = TableInfo("catalog", "prod_db1", "tb1", columns)
+ table_info = TableInfo("catalog", "prod_db1", "tb1", columns, [])


def test_msql_extracts_command():
31 changes: 14 additions & 17 deletions tests/unit/scanner_test.py
@@ -36,6 +36,7 @@ def test_get_table_list(spark):
ColumnInfo("interval_col", "INTERVAL", None, []),
ColumnInfo("str_part_col", "STRING", 1, []),
],
+ [],
)
]

@@ -48,7 +49,7 @@ def test_get_table_list(spark):
tables="*_all_types",
rule_filter="*",
sample_size=100,
- columns_table_name="default.columns_mock",
+ columns_table_name="default.columns",
)
actual = scanner._get_list_of_tables()

@@ -124,13 +125,11 @@ def test_get_table_list(spark):
)
def test_generate_sql(spark, rules_input, expected):
columns = [ColumnInfo("id", "number", False, []), ColumnInfo("name", "string", False, [])]
- table_info = TableInfo("meta", "db", "tb", columns)
+ table_info = TableInfo("meta", "db", "tb", columns, [])
rules = rules_input

rules = Rules(custom_rules=rules)
- scanner = Scanner(
- spark, rules=rules, rule_filter="any_*", sample_size=100, columns_table_name="default.columns_mock"
- )
+ scanner = Scanner(spark, rules=rules, rule_filter="any_*", sample_size=100, columns_table_name="default.columns")

actual = scanner._rule_matching_sql(table_info)
logging.info("Generated SQL is: \n%s", actual)
@@ -144,16 +143,14 @@ def test_sql_runs(spark):
ColumnInfo("ip", "string", None, []),
ColumnInfo("description", "string", None, []),
]
- table_info = TableInfo(None, "default", "tb_1", columns)
+ table_info = TableInfo(None, "default", "tb_1", columns, [])
rules = [
RegexRule(name="any_word", description="Any word", definition=r"\w+"),
RegexRule(name="any_number", description="Any number", definition=r"\d+"),
]

rules = Rules(custom_rules=rules)
- scanner = Scanner(
- spark, rules=rules, rule_filter="any_*", sample_size=100, columns_table_name="default.columns_mock"
- )
+ scanner = Scanner(spark, rules=rules, rule_filter="any_*", sample_size=100, columns_table_name="default.columns")
actual = scanner._rule_matching_sql(table_info)

logging.info("Generated SQL is: \n%s", actual)
@@ -181,7 +178,7 @@ def test_scan_custom_rules(spark: SparkSession):
ColumnInfo("ip", "string", False, []),
ColumnInfo("description", "string", False, []),
]
- table_list = [TableInfo(None, "default", "tb_1", columns)]
+ table_list = [TableInfo(None, "default", "tb_1", columns, [])]
rules = [
RegexRule(name="any_word", description="Any word", definition=r"^\w*$"),
RegexRule(name="any_number", description="Any number", definition=r"^\d*$"),
@@ -194,7 +191,7 @@ def test_scan_custom_rules(spark: SparkSession):
tables="tb_1",
rule_filter="any_*",
sample_size=100,
- columns_table_name="default.columns_mock",
+ columns_table_name="default.columns",
)
scanner.scan()

@@ -217,7 +214,7 @@ def test_scan(spark: SparkSession):
)

rules = Rules()
- scanner = Scanner(spark, rules=rules, tables="tb_1", rule_filter="ip_*", columns_table_name="default.columns_mock")
+ scanner = Scanner(spark, rules=rules, tables="tb_1", rule_filter="ip_*", columns_table_name="default.columns")
scanner.scan()

assert scanner.scan_result.df.equals(expected)
@@ -226,7 +223,7 @@ def test_save_scan(spark: SparkSession):
def test_save_scan(spark: SparkSession):
# save scan result
rules = Rules()
- scanner = Scanner(spark, rules=rules, tables="tb_1", rule_filter="ip_*", columns_table_name="default.columns_mock")
+ scanner = Scanner(spark, rules=rules, tables="tb_1", rule_filter="ip_*", columns_table_name="default.columns")
scanner.scan()
scan_table_name = "_discoverx.scan_result_test"
scanner.scan_result.save(scan_table_name=scan_table_name)
@@ -261,18 +258,18 @@ def test_save_scan(spark: SparkSession):
def test_scan_non_existing_table_returns_none(spark: SparkSession):
rules = Rules()

- scanner = Scanner(spark, rules=rules, tables="tb_1", rule_filter="ip_*", columns_table_name="default.columns_mock")
- result = scanner.scan_table(TableInfo("", "", "tb_non_existing", []))
+ scanner = Scanner(spark, rules=rules, tables="tb_1", rule_filter="ip_*", columns_table_name="default.columns")
+ result = scanner.scan_table(TableInfo("", "", "tb_non_existing", [], []))

assert result is None


def test_scan_whatif_returns_none(spark: SparkSession):
rules = Rules()
scanner = Scanner(
- spark, rules=rules, tables="tb_1", rule_filter="ip_*", columns_table_name="default.columns_mock", what_if=True
+ spark, rules=rules, tables="tb_1", rule_filter="ip_*", columns_table_name="default.columns", what_if=True
)
- result = scanner.scan_table(TableInfo(None, "default", "tb_1", []))
+ result = scanner.scan_table(TableInfo(None, "default", "tb_1", [], []))

assert result is None

