From 6ebe98226e3345c1cb089eea4ffbdb365f15f70e Mon Sep 17 00:00:00 2001
From: Lorenzo Rubio
Date: Wed, 3 Jan 2024 10:47:19 +0100
Subject: [PATCH 1/3] map_chunked initial implementation

---
 discoverx/explorer.py | 38 ++++++++++++++++++++++++++++++++++++--
 setup.py              |  1 +
 2 files changed, 37 insertions(+), 2 deletions(-)

diff --git a/discoverx/explorer.py b/discoverx/explorer.py
index 528c2e0..0455659 100644
--- a/discoverx/explorer.py
+++ b/discoverx/explorer.py
@@ -1,7 +1,8 @@
 import concurrent.futures
 import copy
 import re
-from typing import Optional, List
+import more_itertools
+from typing import Optional, List, Callable
 from discoverx import logging
 from discoverx.common import helper
 from discoverx.discovery import Discovery
@@ -165,7 +166,7 @@ def scan(
         discover.scan(rules=rules, sample_size=sample_size, what_if=what_if)
         return discover
 
-    def map(self, f) -> list[any]:
+    def map(self, f: Callable) -> list[any]:
         """Runs a function for each table in the data explorer
 
         Args:
@@ -197,6 +198,39 @@ def map(self, f: Callable) -> list[any]:
 
         return res
 
+    def map_chunked(self, f: Callable, tables_per_chunk: int, **kwargs) -> list[any]:
+        """Runs a function for each chunk of tables in the data explorer
+
+        Args:
+            f (function): The function to run. It should accept a list of TableInfo objects (one chunk of at most tables_per_chunk tables) as input and return a list of results as output.
+
+        Returns:
+            list[any]: A flattened list of the results of running the function for each chunk
+        """
+        res = []
+        table_list = self._info_fetcher.get_tables_info(
+            self._catalogs,
+            self._schemas,
+            self._tables,
+            self._having_columns,
+            self._with_tags,
+        )
+        with concurrent.futures.ThreadPoolExecutor(max_workers=self._max_concurrency) as executor:
+            # Submit tasks to the thread pool
+            futures = [
+                executor.submit(f, table_chunk, **kwargs) for table_chunk in more_itertools.chunked(table_list, tables_per_chunk)
+            ]
+
+            # Process completed tasks
+            for future in concurrent.futures.as_completed(futures):
+                result = future.result()
+                if result is not None:
+                    res.extend(result)
+
+        logger.debug("Finished lakehouse map_chunked task")
+
+        return res
+
 
 class DataExplorerActions:
     def __init__(
diff --git a/setup.py b/setup.py
index 9233b4d..0e4d73f 100644
--- a/setup.py
+++ b/setup.py
@@ -34,6 +34,7 @@
     "delta-spark>=2.2.0",
     "pandas<2.0.0",  # From 2.0.0 onwards, pandas does not support iteritems() anymore, spark.createDataFrame will fail
     "numpy<1.24",  # From 1.24 onwards, module 'numpy' has no attribute 'bool'.
+ "more_itertools", ] TEST_REQUIREMENTS = [ From 91355c2b8f22669335e807222106c0ac5d651b8e Mon Sep 17 00:00:00 2001 From: Lorenzo Rubio Date: Wed, 3 Jan 2024 11:34:08 +0100 Subject: [PATCH 2/3] added tests for map_chunked --- tests/unit/explorer_test.py | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/tests/unit/explorer_test.py b/tests/unit/explorer_test.py index ac6d518..5776fbc 100644 --- a/tests/unit/explorer_test.py +++ b/tests/unit/explorer_test.py @@ -75,6 +75,38 @@ def test_map(spark, info_fetcher): assert result[0].tags == None +def test_map_chunked_1(spark, info_fetcher): + data_explorer = DataExplorer("*.default.tb_1", spark, info_fetcher) + result = data_explorer.map_chunked(lambda table_info: table_info, 10) + assert len(result) == 1 + assert result[0].table == "tb_1" + assert result[0].schema == "default" + assert result[0].catalog == None + assert result[0].tags == None + + +def test_map_chunked_2(spark, info_fetcher): + data_explorer = DataExplorer("*.default.*", spark, info_fetcher) + result = data_explorer.map_chunked(lambda table_info: table_info, 10) + assert len(result) == 3 + for res in result: + assert res.table in ["tb_1", "tb_2", "tb_all_types"] + if res.table == "tb_1": + assert res.schema == "default" + assert res.catalog == None + assert res.tags == None + elif res.table == "tb_2": + assert res.schema == "default" + assert res.catalog == None + assert res.tags == None + else: + assert res.schema == "default" + assert res.catalog == "hive_metastore" + assert res.tags == None + result2 = data_explorer.map_chunked(lambda table_info: table_info, 2) + assert result2 == result + + def test_map_with_tags(spark, info_fetcher): data_explorer = DataExplorer("*.default.tb_1", spark, info_fetcher).with_tags() result = data_explorer.map(lambda table_info: table_info) From b45c4edf18d57d29304c683910de43f10011ed28 Mon Sep 17 00:00:00 2001 From: Lorenzo Rubio Date: Sat, 3 Feb 2024 21:52:47 +0100 Subject: [PATCH 3/3] do not use more_itertools + improve function hints --- discoverx/explorer.py | 19 +++++++++++-------- setup.py | 1 - tests/unit/explorer_test.py | 34 +++++++++++++++++++--------------- 3 files changed, 30 insertions(+), 24 deletions(-) diff --git a/discoverx/explorer.py b/discoverx/explorer.py index 0455659..76c58a2 100644 --- a/discoverx/explorer.py +++ b/discoverx/explorer.py @@ -1,8 +1,8 @@ import concurrent.futures import copy import re -import more_itertools -from typing import Optional, List, Callable +import itertools as it +from typing import Optional, List, Callable, Any from discoverx import logging from discoverx.common import helper from discoverx.discovery import Discovery @@ -166,7 +166,7 @@ def scan( discover.scan(rules=rules, sample_size=sample_size, what_if=what_if) return discover - def map(self, f: Callable) -> list[any]: + def map(self, f: Callable[[TableInfo], Any]) -> list[Any]: """Runs a function for each table in the data explorer Args: @@ -198,7 +198,7 @@ def map(self, f: Callable) -> list[any]: return res - def map_chunked(self, f: Callable, tables_per_chunk: int, **kwargs) -> list[any]: + def map_chunked(self, f: Callable[[TableInfo, int], Any], tables_per_chunk: int, **kwargs) -> list[Any]: """Runs a function for each table in the data explorer Args: @@ -217,15 +217,18 @@ def map_chunked(self, f: Callable, tables_per_chunk: int, **kwargs) -> list[any] ) with concurrent.futures.ThreadPoolExecutor(max_workers=self._max_concurrency) as executor: # Submit tasks to the thread pool - futures = [ - 
-            futures = [
-                executor.submit(f, table_chunk, **kwargs) for table_chunk in more_itertools.chunked(table_list, tables_per_chunk)
-            ]
+            table_iter = iter(table_list)
+            futures = []
+            # One future per chunk of at most tables_per_chunk tables
+            while table_chunk := list(it.islice(table_iter, tables_per_chunk)):
+                futures.append(executor.submit(f, table_chunk, **kwargs))
 
             # Process completed tasks
             for future in concurrent.futures.as_completed(futures):
                 result = future.result()
                 if result is not None:
                     res.extend(result)
 
         logger.debug("Finished lakehouse map_chunked task")
 
diff --git a/setup.py b/setup.py
index 0e4d73f..9233b4d 100644
--- a/setup.py
+++ b/setup.py
@@ -34,7 +34,6 @@
     "delta-spark>=2.2.0",
     "pandas<2.0.0",  # From 2.0.0 onwards, pandas does not support iteritems() anymore, spark.createDataFrame will fail
     "numpy<1.24",  # From 1.24 onwards, module 'numpy' has no attribute 'bool'.
-    "more_itertools",
 ]
 
 TEST_REQUIREMENTS = [
diff --git a/tests/unit/explorer_test.py b/tests/unit/explorer_test.py
index 5776fbc..1f943cd 100644
--- a/tests/unit/explorer_test.py
+++ b/tests/unit/explorer_test.py
@@ -86,25 +86,29 @@ def test_map_chunked_1(spark, info_fetcher):
 
 
 def test_map_chunked_2(spark, info_fetcher):
+    def check_result(result):
+        for res in result:
+            assert res.table in ["tb_1", "tb_2", "tb_all_types"]
+            if res.table == "tb_1":
+                assert res.schema == "default"
+                assert res.catalog == None
+                assert res.tags == None
+            elif res.table == "tb_2":
+                assert res.schema == "default"
+                assert res.catalog == None
+                assert res.tags == None
+            else:
+                assert res.schema == "default"
+                assert res.catalog == "hive_metastore"
+                assert res.tags == None
+
     data_explorer = DataExplorer("*.default.*", spark, info_fetcher)
     result = data_explorer.map_chunked(lambda table_infos: table_infos, 10)
     assert len(result) == 3
-    for res in result:
-        assert res.table in ["tb_1", "tb_2", "tb_all_types"]
-        if res.table == "tb_1":
-            assert res.schema == "default"
-            assert res.catalog == None
-            assert res.tags == None
-        elif res.table == "tb_2":
-            assert res.schema == "default"
-            assert res.catalog == None
-            assert res.tags == None
-        else:
-            assert res.schema == "default"
-            assert res.catalog == "hive_metastore"
-            assert res.tags == None
+    check_result(result)
     result2 = data_explorer.map_chunked(lambda table_infos: table_infos, 2)
-    assert result2 == result
+    assert len(result2) == 3
+    check_result(result2)
 
 
 def test_map_with_tags(spark, info_fetcher):
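
---
Example usage of the chunked API (an illustrative sketch, not part of the patches): `DX` and `from_tables` are the existing discoverx entry points, while `collect_names` is a hypothetical callback written for this example. The callback receives one chunk, i.e. a list of up to tables_per_chunk TableInfo objects, returns a list, and map_chunked concatenates the per-chunk lists.

    from discoverx import DX

    dx = DX()

    # Hypothetical callback: collect schema-qualified names for one chunk of tables.
    # It receives a list of TableInfo objects and must return a list.
    def collect_names(table_infos):
        return [f"{t.schema}.{t.table}" for t in table_infos]

    # Runs collect_names concurrently on chunks of at most 50 tables each;
    # the result is a single flattened list of names.
    names = dx.from_tables("*.default.*").map_chunked(collect_names, tables_per_chunk=50)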