diff --git a/discoverx/explorer.py b/discoverx/explorer.py
index c31464b..861f719 100644
--- a/discoverx/explorer.py
+++ b/discoverx/explorer.py
@@ -148,29 +148,14 @@ def scan(
         return discover
 
     def map(self, f) -> list[any]:
-        res = []
-        table_list = self._info_fetcher.get_tables_info(
-            self._catalogs,
-            self._schemas,
-            self._tables,
-            self._having_columns,
-            self._with_tags,
-        )
-        with concurrent.futures.ThreadPoolExecutor(max_workers=self._max_concurrency) as executor:
-            # Submit tasks to the thread pool
-            futures = [executor.submit(f, table_info) for table_info in table_list]
+        """Runs a function for each table in the data explorer
 
-            # Process completed tasks
-            for future in concurrent.futures.as_completed(futures):
-                result = future.result()
-                if result is not None:
-                    res.append(result)
-
-        logger.debug("Finished lakehouse map task")
-
-        return res
+        Args:
+            f (function): The function to run. The function should accept a TableInfo object as input and return any object as output.
 
-    def map(self, f) -> list[any]:
+        Returns:
+            list[any]: A list of the results of running the function for each table
+        """
         res = []
         table_list = self._info_fetcher.get_tables_info(
             self._catalogs,
diff --git a/discoverx/msql.py b/discoverx/msql.py
index 43a093c..50cee25 100644
--- a/discoverx/msql.py
+++ b/discoverx/msql.py
@@ -97,7 +97,7 @@ def build(self, classified_result_pdf) -> list[SQLRow]:
                 row[1],
                 row[2],
                 [ColumnInfo(col[0], "", None, col[1]) for col in row[3]],  # col name  # TODO  # TODO  # Classes
-                [],
+                None,
             )
             for _, row in df.iterrows()
             if fnmatch(row[0], self.catalogs) and fnmatch(row[1], self.schemas) and fnmatch(row[2], self.tables)
diff --git a/discoverx/scanner.py b/discoverx/scanner.py
index 659e2e3..f13c528 100644
--- a/discoverx/scanner.py
+++ b/discoverx/scanner.py
@@ -170,7 +170,7 @@ def _get_list_of_tables(self) -> List[TableInfo]:
                     ColumnInfo(col["column_name"], col["data_type"], col["partition_index"], [])
                     for col in row["table_columns"]
                 ],
-                [],
+                None,
            )
            for row in rows
        ]
diff --git a/tests/unit/explorer_test.py b/tests/unit/explorer_test.py
index f22d4ae..8475599 100644
--- a/tests/unit/explorer_test.py
+++ b/tests/unit/explorer_test.py
@@ -11,13 +11,17 @@ def info_fetcher(spark):
 
 @pytest.fixture()
 def sample_table_info():
-    return TableInfo("catalog1", "schema1", "table1", [], [])
+    return TableInfo("catalog1", "schema1", "table1", [], None)
 
 
 def test_validate_from_components():
     with pytest.raises(ValueError):
         DataExplorer.validate_from_components("invalid_format")
-    assert DataExplorer.validate_from_components("catalog1.schema1.table1") == ("catalog1", "schema1", "table1")
+    assert DataExplorer.validate_from_components("catalog1.schema1.table1") == (
+        "catalog1",
+        "schema1",
+        "table1",
+    )
 
 
 def test_build_sql(sample_table_info):
diff --git a/tests/unit/msql_test.py b/tests/unit/msql_test.py
index 76e814d..c0b64c1 100644
--- a/tests/unit/msql_test.py
+++ b/tests/unit/msql_test.py
@@ -24,12 +24,24 @@ def classification_df(spark) -> pd.DataFrame:
             ["c", "db", "tb2", "email_3", "dx_email"],
             ["c", "db", "tb2", "date", "dx_date_partition"],
             ["c", "db2", "tb3", "email_4", "dx_email"],
-            ["c", "db", "tb1", "description", "any_number"],  # any_number not in the class list
+            [
+                "c",
+                "db",
+                "tb1",
+                "description",
+                "any_number",
+            ],  # any_number not in the class list
             ["m_c", "db", "tb1", "email_3", "dx_email"],  # catalog does not match
             ["c", "m_db", "tb1", "email_4", "dx_email"],  # schema does not match
             ["c", "db", "m_tb1", "email_5", "dx_email"],  # table does not match
         ],
-        columns=["table_catalog", "table_schema", "table_name", "column_name", "class_name"],
+        columns=[
+            "table_catalog",
+            "table_schema",
+            "table_name",
+            "column_name",
+            "class_name",
+        ],
     )
 
 
@@ -39,7 +51,7 @@ def classification_df(spark) -> pd.DataFrame:
     ColumnInfo("email_2", "string", None, ["dx_email"]),
     ColumnInfo("date", "string", 1, ["dx_date_partition"]),
 ]
-table_info = TableInfo("catalog", "prod_db1", "tb1", columns, [])
+table_info = TableInfo("catalog", "prod_db1", "tb1", columns, None)
 
 
 def test_msql_extracts_command():
@@ -69,7 +81,12 @@ def test_msql_validates_command():
 
 def test_msql_replace_from_clausole():
     msql = "SELECT [dx_pii] AS dx_pii FROM *.*.*"
-    expected = SQLRow("catalog", "prod_db1", "tb1", "SELECT email_1 AS dx_pii FROM catalog.prod_db1.tb1")
+    expected = SQLRow(
+        "catalog",
+        "prod_db1",
+        "tb1",
+        "SELECT email_1 AS dx_pii FROM catalog.prod_db1.tb1",
+    )
 
     actual = Msql(msql).compile_msql(table_info)
     assert len(actual) == 1
@@ -91,8 +108,18 @@ def test_msql_select_repeated_class():
 
     actual = Msql(msql).compile_msql(table_info)
     assert len(actual) == 2
-    assert actual[0] == SQLRow("catalog", "prod_db1", "tb1", "SELECT email_1 AS email FROM catalog.prod_db1.tb1")
-    assert actual[1] == SQLRow("catalog", "prod_db1", "tb1", "SELECT email_2 AS email FROM catalog.prod_db1.tb1")
+    assert actual[0] == SQLRow(
+        "catalog",
+        "prod_db1",
+        "tb1",
+        "SELECT email_1 AS email FROM catalog.prod_db1.tb1",
+    )
+    assert actual[1] == SQLRow(
+        "catalog",
+        "prod_db1",
+        "tb1",
+        "SELECT email_2 AS email FROM catalog.prod_db1.tb1",
+    )
 
 
 def test_msql_select_multi_class():
@@ -197,8 +224,18 @@ def test_msql_delete_command():
 
     actual = Msql(msql).compile_msql(table_info)
     assert len(actual) == 2
-    assert actual[0] == SQLRow("catalog", "prod_db1", "tb1", "DELETE FROM catalog.prod_db1.tb1 WHERE email_1 = 'a@b.c'")
-    assert actual[1] == SQLRow("catalog", "prod_db1", "tb1", "DELETE FROM catalog.prod_db1.tb1 WHERE email_2 = 'a@b.c'")
+    assert actual[0] == SQLRow(
+        "catalog",
+        "prod_db1",
+        "tb1",
+        "DELETE FROM catalog.prod_db1.tb1 WHERE email_1 = 'a@b.c'",
+    )
+    assert actual[1] == SQLRow(
+        "catalog",
+        "prod_db1",
+        "tb1",
+        "DELETE FROM catalog.prod_db1.tb1 WHERE email_2 = 'a@b.c'",
+    )
 
 
 def test_execute_sql_rows(spark):
@@ -214,7 +251,12 @@ def test_execute_sql_rows_should_not_fail(spark):
     msql = Msql("SELECT description FROM *.*.* ")
     sql_rows = [
         SQLRow(None, "default", "tb_1", "SELECT description FROM default.tb_1"),
-        SQLRow(None, "default", "non_existent_table", "SELECT description FROM default.non_existent_table"),
+        SQLRow(
+            None,
+            "default",
+            "non_existent_table",
+            "SELECT description FROM default.non_existent_table",
+        ),
     ]
     df = msql.execute_sql_rows(sqls=sql_rows, spark=spark)
     assert df.count() == 2
@@ -234,7 +276,12 @@ def test_execute_sql_should_fail_for_no_successful_queries(spark):
     msql = Msql("SELECT description FROM *.*.* ")
     sql_rows = [
         SQLRow(None, "default", "tb_1", "SELECT non_existent_column FROM default.tb_1"),  # Column does not exist
-        SQLRow(None, "default", "non_existent_table_2", "SELECT description FROM default.non_existent_table_2"),
+        SQLRow(
+            None,
+            "default",
+            "non_existent_table_2",
+            "SELECT description FROM default.non_existent_table_2",
+        ),
     ]
     with pytest.raises(ValueError):
         df = msql.execute_sql_rows(sqls=sql_rows, spark=spark)
diff --git a/tests/unit/scanner_test.py b/tests/unit/scanner_test.py
index d7e26b2..f469d45 100644
--- a/tests/unit/scanner_test.py
+++ b/tests/unit/scanner_test.py
@@ -37,7 +37,7 @@ def test_get_table_list(spark):
                 ColumnInfo("interval_col", "INTERVAL", None, []),
                 ColumnInfo("str_part_col", "STRING", 1, []),
             ],
-            [],
+            None,
         )
     ]
 
@@ -132,7 +132,7 @@ def test_generate_sql(spark, rules_input, expected):
         ColumnInfo("id", "number", False, []),
         ColumnInfo("name", "string", False, []),
     ]
-    table_info = TableInfo("meta", "db", "tb", columns, [])
+    table_info = TableInfo("meta", "db", "tb", columns, None)
     rules = rules_input
     rules = Rules(custom_rules=rules)
 
@@ -156,7 +156,7 @@ def test_sql_runs(spark):
         ColumnInfo("ip", "string", None, []),
         ColumnInfo("description", "string", None, []),
     ]
-    table_info = TableInfo(None, "default", "tb_1", columns, [])
+    table_info = TableInfo(None, "default", "tb_1", columns, None)
     rules = [
         RegexRule(name="any_word", description="Any word", definition=r"\w+"),
         RegexRule(name="any_number", description="Any number", definition=r"\d+"),
@@ -204,7 +204,7 @@ def test_scan_custom_rules(spark: SparkSession):
         ColumnInfo("ip", "string", False, []),
         ColumnInfo("description", "string", False, []),
     ]
-    table_list = [TableInfo(None, "default", "tb_1", columns, [])]
+    table_list = [TableInfo(None, "default", "tb_1", columns, None)]
     rules = [
         RegexRule(name="any_word", description="Any word", definition=r"^\w*$"),
         RegexRule(name="any_number", description="Any number", definition=r"^\d*$"),
@@ -317,7 +317,7 @@ def test_scan_non_existing_table_returns_none(spark: SparkSession):
         rule_filter="ip_*",
         information_schema="default",
     )
-    result = scanner.scan_table(TableInfo("", "", "tb_non_existing", [], []))
+    result = scanner.scan_table(TableInfo("", "", "tb_non_existing", [], None))
 
     assert result is None
 
@@ -332,7 +332,7 @@ def test_scan_whatif_returns_none(spark: SparkSession):
         information_schema="default",
         what_if=True,
     )
-    result = scanner.scan_table(TableInfo(None, "default", "tb_1", [], []))
+    result = scanner.scan_table(TableInfo(None, "default", "tb_1", [], None))
 
     assert result is None