Commit

Added docs for map function and refactored
edurdevic committed Oct 15, 2023
1 parent 19a8e8e commit 12a311c
Showing 6 changed files with 77 additions and 41 deletions.
27 changes: 6 additions & 21 deletions discoverx/explorer.py
@@ -148,29 +148,14 @@ def scan(
        return discover

    def map(self, f) -> list[any]:
        res = []
        table_list = self._info_fetcher.get_tables_info(
            self._catalogs,
            self._schemas,
            self._tables,
            self._having_columns,
            self._with_tags,
        )
        with concurrent.futures.ThreadPoolExecutor(max_workers=self._max_concurrency) as executor:
            # Submit tasks to the thread pool
            futures = [executor.submit(f, table_info) for table_info in table_list]

            # Process completed tasks
            for future in concurrent.futures.as_completed(futures):
                result = future.result()
                if result is not None:
                    res.append(result)

        logger.debug("Finished lakehouse map task")

        return res
        """Runs a function for each table in the data explorer

        Args:
            f (function): The function to run. The function should accept a TableInfo object as input and return any object as output.

        Returns:
            list[any]: A list of the results of running the function for each table
        """
        res = []
        table_list = self._info_fetcher.get_tables_info(
            self._catalogs,
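For context, here is a minimal usage sketch of the map API documented above. It assumes the DX entry point, a from_tables selector, and TableInfo attributes named catalog, schema and table; none of those names appear in this diff, so treat them as placeholders.

from discoverx import DX

dx = DX()


def table_full_name(table_info):
    # The callable receives a TableInfo and may return any value;
    # None results are dropped from the returned list.
    return f"{table_info.catalog}.{table_info.schema}.{table_info.table}"


# map() submits one call per matching table to a thread pool
# (bounded by max_concurrency) and collects the non-None results.
names = dx.from_tables("catalog1.schema1.*").map(table_full_name)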
2 changes: 1 addition & 1 deletion discoverx/msql.py
@@ -97,7 +97,7 @@ def build(self, classified_result_pdf) -> list[SQLRow]:
row[1],
row[2],
[ColumnInfo(col[0], "", None, col[1]) for col in row[3]], # col name # TODO # TODO # Classes
[],
None,
)
for _, row in df.iterrows()
if fnmatch(row[0], self.catalogs) and fnmatch(row[1], self.schemas) and fnmatch(row[2], self.tables)
2 changes: 1 addition & 1 deletion discoverx/scanner.py
@@ -170,7 +170,7 @@ def _get_list_of_tables(self) -> List[TableInfo]:
ColumnInfo(col["column_name"], col["data_type"], col["partition_index"], [])
for col in row["table_columns"]
],
[],
None,
)
for row in rows
]
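The msql.py and scanner.py hunks above, and the test updates below, all make the same signature-level change: the last argument to TableInfo (the table tags) is now None rather than an empty list, presumably so that "tags not fetched" can be told apart from "fetched and empty". A minimal sketch of the two call shapes, with the import path and the meaning of the last field assumed rather than shown in this diff:

from discoverx.table_info import TableInfo, ColumnInfo  # import path assumed

columns = [ColumnInfo("email_1", "string", None, ["dx_email"])]

# Before this commit: tags passed as an empty list
old_style = TableInfo("catalog", "prod_db1", "tb1", columns, [])

# After this commit: None signals that tag information was not fetched
new_style = TableInfo("catalog", "prod_db1", "tb1", columns, None)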
8 changes: 6 additions & 2 deletions tests/unit/explorer_test.py
@@ -11,13 +11,17 @@ def info_fetcher(spark):

@pytest.fixture()
def sample_table_info():
return TableInfo("catalog1", "schema1", "table1", [], [])
return TableInfo("catalog1", "schema1", "table1", [], None)


def test_validate_from_components():
with pytest.raises(ValueError):
DataExplorer.validate_from_components("invalid_format")
assert DataExplorer.validate_from_components("catalog1.schema1.table1") == ("catalog1", "schema1", "table1")
assert DataExplorer.validate_from_components("catalog1.schema1.table1") == (
"catalog1",
"schema1",
"table1",
)


def test_build_sql(sample_table_info):
67 changes: 57 additions & 10 deletions tests/unit/msql_test.py
@@ -24,12 +24,24 @@ def classification_df(spark) -> pd.DataFrame:
["c", "db", "tb2", "email_3", "dx_email"],
["c", "db", "tb2", "date", "dx_date_partition"],
["c", "db2", "tb3", "email_4", "dx_email"],
["c", "db", "tb1", "description", "any_number"], # any_number not in the class list
[
"c",
"db",
"tb1",
"description",
"any_number",
], # any_number not in the class list
["m_c", "db", "tb1", "email_3", "dx_email"], # catalog does not match
["c", "m_db", "tb1", "email_4", "dx_email"], # schema does not match
["c", "db", "m_tb1", "email_5", "dx_email"], # table does not match
],
columns=["table_catalog", "table_schema", "table_name", "column_name", "class_name"],
columns=[
"table_catalog",
"table_schema",
"table_name",
"column_name",
"class_name",
],
)


@@ -39,7 +51,7 @@ def classification_df(spark) -> pd.DataFrame:
ColumnInfo("email_2", "string", None, ["dx_email"]),
ColumnInfo("date", "string", 1, ["dx_date_partition"]),
]
table_info = TableInfo("catalog", "prod_db1", "tb1", columns, [])
table_info = TableInfo("catalog", "prod_db1", "tb1", columns, None)


def test_msql_extracts_command():
@@ -69,7 +81,12 @@ def test_msql_validates_command():
def test_msql_replace_from_clausole():
msql = "SELECT [dx_pii] AS dx_pii FROM *.*.*"

expected = SQLRow("catalog", "prod_db1", "tb1", "SELECT email_1 AS dx_pii FROM catalog.prod_db1.tb1")
expected = SQLRow(
"catalog",
"prod_db1",
"tb1",
"SELECT email_1 AS dx_pii FROM catalog.prod_db1.tb1",
)

actual = Msql(msql).compile_msql(table_info)
assert len(actual) == 1
@@ -91,8 +108,18 @@ def test_msql_select_repeated_class():

actual = Msql(msql).compile_msql(table_info)
assert len(actual) == 2
assert actual[0] == SQLRow("catalog", "prod_db1", "tb1", "SELECT email_1 AS email FROM catalog.prod_db1.tb1")
assert actual[1] == SQLRow("catalog", "prod_db1", "tb1", "SELECT email_2 AS email FROM catalog.prod_db1.tb1")
assert actual[0] == SQLRow(
"catalog",
"prod_db1",
"tb1",
"SELECT email_1 AS email FROM catalog.prod_db1.tb1",
)
assert actual[1] == SQLRow(
"catalog",
"prod_db1",
"tb1",
"SELECT email_2 AS email FROM catalog.prod_db1.tb1",
)


def test_msql_select_multi_class():
@@ -197,8 +224,18 @@ def test_msql_delete_command():

actual = Msql(msql).compile_msql(table_info)
assert len(actual) == 2
assert actual[0] == SQLRow("catalog", "prod_db1", "tb1", "DELETE FROM catalog.prod_db1.tb1 WHERE email_1 = '[email protected]'")
assert actual[1] == SQLRow("catalog", "prod_db1", "tb1", "DELETE FROM catalog.prod_db1.tb1 WHERE email_2 = '[email protected]'")
assert actual[0] == SQLRow(
"catalog",
"prod_db1",
"tb1",
"DELETE FROM catalog.prod_db1.tb1 WHERE email_1 = '[email protected]'",
)
assert actual[1] == SQLRow(
"catalog",
"prod_db1",
"tb1",
"DELETE FROM catalog.prod_db1.tb1 WHERE email_2 = '[email protected]'",
)


def test_execute_sql_rows(spark):
@@ -214,7 +251,12 @@ def test_execute_sql_rows_should_not_fail(spark):
msql = Msql("SELECT description FROM *.*.* ")
sql_rows = [
SQLRow(None, "default", "tb_1", "SELECT description FROM default.tb_1"),
SQLRow(None, "default", "non_existent_table", "SELECT description FROM default.non_existent_table"),
SQLRow(
None,
"default",
"non_existent_table",
"SELECT description FROM default.non_existent_table",
),
]
df = msql.execute_sql_rows(sqls=sql_rows, spark=spark)
assert df.count() == 2
@@ -234,7 +276,12 @@ def test_execute_sql_should_fail_for_no_successful_queries(spark):
msql = Msql("SELECT description FROM *.*.* ")
sql_rows = [
SQLRow(None, "default", "tb_1", "SELECT non_existent_column FROM default.tb_1"), # Column does not exist
SQLRow(None, "default", "non_existent_table_2", "SELECT description FROM default.non_existent_table_2"),
SQLRow(
None,
"default",
"non_existent_table_2",
"SELECT description FROM default.non_existent_table_2",
),
]
with pytest.raises(ValueError):
df = msql.execute_sql_rows(sqls=sql_rows, spark=spark)
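Taken together, these tests cover the path from an MSQL template to per-table SQL: compile_msql expands a class reference such as [dx_email] into one SQLRow per classified column, and execute_sql_rows runs the generated statements, tolerating individual failures as long as at least one query succeeds. A rough sketch of that flow, reusing the table_info fixture defined at the top of this module:

from discoverx.msql import Msql

# tb1 has email_1 and email_2 classified as dx_email (see table_info above)
sql_rows = Msql("SELECT [dx_email] AS email FROM *.*.*").compile_msql(table_info)

for sql_row in sql_rows:
    # Expected statements, one per classified column:
    #   SELECT email_1 AS email FROM catalog.prod_db1.tb1
    #   SELECT email_2 AS email FROM catalog.prod_db1.tb1
    print(sql_row)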
12 changes: 6 additions & 6 deletions tests/unit/scanner_test.py
@@ -37,7 +37,7 @@ def test_get_table_list(spark):
ColumnInfo("interval_col", "INTERVAL", None, []),
ColumnInfo("str_part_col", "STRING", 1, []),
],
[],
None,
)
]

@@ -132,7 +132,7 @@ def test_generate_sql(spark, rules_input, expected):
ColumnInfo("id", "number", False, []),
ColumnInfo("name", "string", False, []),
]
table_info = TableInfo("meta", "db", "tb", columns, [])
table_info = TableInfo("meta", "db", "tb", columns, None)
rules = rules_input

rules = Rules(custom_rules=rules)
@@ -156,7 +156,7 @@ def test_sql_runs(spark):
ColumnInfo("ip", "string", None, []),
ColumnInfo("description", "string", None, []),
]
table_info = TableInfo(None, "default", "tb_1", columns, [])
table_info = TableInfo(None, "default", "tb_1", columns, None)
rules = [
RegexRule(name="any_word", description="Any word", definition=r"\w+"),
RegexRule(name="any_number", description="Any number", definition=r"\d+"),
@@ -204,7 +204,7 @@ def test_scan_custom_rules(spark: SparkSession):
ColumnInfo("ip", "string", False, []),
ColumnInfo("description", "string", False, []),
]
table_list = [TableInfo(None, "default", "tb_1", columns, [])]
table_list = [TableInfo(None, "default", "tb_1", columns, None)]
rules = [
RegexRule(name="any_word", description="Any word", definition=r"^\w*$"),
RegexRule(name="any_number", description="Any number", definition=r"^\d*$"),
@@ -317,7 +317,7 @@ def test_scan_non_existing_table_returns_none(spark: SparkSession):
rule_filter="ip_*",
information_schema="default",
)
result = scanner.scan_table(TableInfo("", "", "tb_non_existing", [], []))
result = scanner.scan_table(TableInfo("", "", "tb_non_existing", [], None))

assert result is None

@@ -332,7 +332,7 @@ def test_scan_whatif_returns_none(spark: SparkSession):
information_schema="default",
what_if=True,
)
result = scanner.scan_table(TableInfo(None, "default", "tb_1", [], []))
result = scanner.scan_table(TableInfo(None, "default", "tb_1", [], None))

assert result is None

