diff --git a/discoverx/explorer.py b/discoverx/explorer.py
index 528c2e0..76c58a2 100644
--- a/discoverx/explorer.py
+++ b/discoverx/explorer.py
@@ -1,7 +1,8 @@
 import concurrent.futures
 import copy
 import re
-from typing import Optional, List
+import itertools as it
+from typing import Optional, List, Callable, Any
 from discoverx import logging
 from discoverx.common import helper
 from discoverx.discovery import Discovery
@@ -165,7 +166,7 @@ def scan(
         discover.scan(rules=rules, sample_size=sample_size, what_if=what_if)
         return discover
 
-    def map(self, f) -> list[any]:
+    def map(self, f: Callable[[TableInfo], Any]) -> list[Any]:
         """Runs a function for each table in the data explorer
 
         Args:
@@ -197,6 +198,44 @@ def map(self, f) -> list[any]:
 
         return res
 
+    def map_chunked(self, f: Callable[[TableInfo], Any], tables_per_chunk: int, **kwargs) -> list[Any]:
+        """Runs a function for each table in the data explorer, submitting the work to the thread pool in chunks
+
+        Args:
+            f (function): The function to run. The function should accept a TableInfo object as input and may return any object as output.
+            tables_per_chunk (int): The number of tables submitted to the thread pool per chunk.
+            **kwargs: Additional keyword arguments forwarded to f on each call.
+
+        Returns:
+            list[Any]: A list of the non-None results of running the function for each table
+        """
+        res = []
+        table_list = self._info_fetcher.get_tables_info(
+            self._catalogs,
+            self._schemas,
+            self._tables,
+            self._having_columns,
+            self._with_tags,
+        )
+        with concurrent.futures.ThreadPoolExecutor(max_workers=self._max_concurrency) as executor:
+            # Submit one task per table, in chunks of tables_per_chunk submissions at a time
+            table_iter = iter(table_list)
+            futures = []
+            while chunk := [
+                executor.submit(f, table_info, **kwargs) for table_info in it.islice(table_iter, tables_per_chunk)
+            ]:
+                futures.extend(chunk)
+
+            # Collect results as the tasks complete, dropping None results
+            for future in concurrent.futures.as_completed(futures):
+                result = future.result()
+                if result is not None:
+                    res.append(result)
+
+        logger.debug("Finished lakehouse map_chunked task")
+
+        return res
+
 
 class DataExplorerActions:
     def __init__(
diff --git a/tests/unit/explorer_test.py b/tests/unit/explorer_test.py
index ac6d518..1f943cd 100644
--- a/tests/unit/explorer_test.py
+++ b/tests/unit/explorer_test.py
@@ -75,6 +75,42 @@ def test_map(spark, info_fetcher):
     assert result[0].tags == None
 
 
+def test_map_chunked_1(spark, info_fetcher):
+    data_explorer = DataExplorer("*.default.tb_1", spark, info_fetcher)
+    result = data_explorer.map_chunked(lambda table_info: table_info, 10)
+    assert len(result) == 1
+    assert result[0].table == "tb_1"
+    assert result[0].schema == "default"
+    assert result[0].catalog == None
+    assert result[0].tags == None
+
+
+def test_map_chunked_2(spark, info_fetcher):
+    def check_result(result):
+        for res in result:
+            assert res.table in ["tb_1", "tb_2", "tb_all_types"]
+            if res.table == "tb_1":
+                assert res.schema == "default"
+                assert res.catalog == None
+                assert res.tags == None
+            elif res.table == "tb_2":
+                assert res.schema == "default"
+                assert res.catalog == None
+                assert res.tags == None
+            else:
+                assert res.schema == "default"
+                assert res.catalog == "hive_metastore"
+                assert res.tags == None
+
+    data_explorer = DataExplorer("*.default.*", spark, info_fetcher)
+    result = data_explorer.map_chunked(lambda table_info: table_info, 10)
+    assert len(result) == 3
+    check_result(result)
+    result2 = data_explorer.map_chunked(lambda table_info: table_info, 2)
+    assert len(result2) == 3
+    check_result(result2)
+
+
 def test_map_with_tags(spark, info_fetcher):
     data_explorer = DataExplorer("*.default.tb_1", spark, info_fetcher).with_tags()
     result = data_explorer.map(lambda table_info: table_info)
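
A minimal usage sketch of the new map_chunked API (untested; process_table and its dry_run keyword are hypothetical, and spark / info_fetcher are assumed to be wired up as in the unit tests):

    from discoverx.explorer import DataExplorer

    # Hypothetical callback: map_chunked invokes it once per table, forwarding any
    # extra keyword arguments via **kwargs; None results are dropped from the output.
    def process_table(table_info, dry_run=False):
        if dry_run:
            return None
        return f"{table_info.schema}.{table_info.table}"

    data_explorer = DataExplorer("*.default.*", spark, info_fetcher)
    # Work is submitted to the thread pool in chunks of 10 tables at a time
    names = data_explorer.map_chunked(process_table, 10, dry_run=False)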