diff --git a/discoverx/explorer.py b/discoverx/explorer.py index 3845fde..8e79074 100644 --- a/discoverx/explorer.py +++ b/discoverx/explorer.py @@ -32,6 +32,7 @@ def __init__(self, from_tables, spark: SparkSession, info_fetcher: InfoFetcher) self._sql_query_template = None self._max_concurrency = 10 self._with_tags = False + self._having_tags = [] @staticmethod def validate_from_components(from_tables: str): @@ -70,6 +71,19 @@ def having_columns(self, *columns) -> "DataExplorer": new_obj._having_columns.extend(columns) return new_obj + def having_tag(self, tag_name: str, tag_value: str = None) -> "DataExplorer": + """Will select tables tagged with the provided tag name and optionally value + either at table, schema, or catalog level. + + Args: + tag_name (str): Tag name + tag_value (str, optional): Tag value. Defaults to None. + """ + new_obj = copy.deepcopy(self) + new_obj._having_tags.extend(TagInfo(tag_name, tag_value)) + new_obj._with_tags = True + return new_obj + def with_concurrency(self, max_concurrency) -> "DataExplorer": """Sets the maximum number of concurrent queries to run""" new_obj = copy.deepcopy(self) @@ -140,7 +154,9 @@ def scan( self._catalogs, self._schemas, self._tables, - self._info_fetcher.get_tables_info(self._catalogs, self._schemas, self._tables, self._having_columns), + self._info_fetcher.get_tables_info( + self._catalogs, self._schemas, self._tables, self._having_columns, self._having_tags + ), custom_rules=custom_rules, locale=locale, ) @@ -163,6 +179,7 @@ def map(self, f) -> list[any]: self._tables, self._having_columns, self._with_tags, + self._having_tags, ) with concurrent.futures.ThreadPoolExecutor(max_workers=self._max_concurrency) as executor: # Submit tasks to the thread pool @@ -244,6 +261,7 @@ def _get_sql_commands(self, data_explorer: DataExplorer) -> list[tuple[str, Tabl data_explorer._tables, data_explorer._having_columns, data_explorer._with_tags, + data_explorer._having_tags, ) sql_commands = [ ( diff --git a/discoverx/table_info.py b/discoverx/table_info.py index 86e7f8f..837d982 100644 --- a/discoverx/table_info.py +++ b/discoverx/table_info.py @@ -111,6 +111,7 @@ def get_tables_info( tables: str, columns: list[str] = [], with_tags=False, + having_tags=[], ) -> list[TableInfo]: # Filter tables by matching filter table_list_sql = self._get_table_list_sql(catalogs, schemas, tables, columns, with_tags) @@ -120,7 +121,28 @@ def get_tables_info( if len(filtered_tables) == 0: raise ValueError(f"No tables found matching filter: {catalogs}.{schemas}.{tables}") - return self._to_info_list(filtered_tables) + info_list = self._to_info_list(filtered_tables) + return [info for info in info_list if InfoFetcher._contains_all_tags(info.tags, having_tags)] + + @staticmethod + def _contains_all_tags(tags_info: TagsInfo, tags: list[TagInfo]) -> bool: + if not tags_info: + return False + if not tags: + return True + + all_tags = [] + + if tags_info.catalog_tags: + all_tags.extend(tags_info.catalog_tags) + + if tags_info.schema_tags: + all_tags.extend(tags_info.schema_tags) + + if tags_info.table_tags: + all_tags.extend(tags_info.table_tags) + + return all([tag in all_tags for tag in tags]) def _get_table_list_sql( self, diff --git a/tests/unit/table_info_test.py b/tests/unit/table_info_test.py new file mode 100644 index 0000000..0bf2328 --- /dev/null +++ b/tests/unit/table_info_test.py @@ -0,0 +1,23 @@ +import pytest +from discoverx.explorer import InfoFetcher, TagsInfo, TagInfo + + +def test_validate_from_components(): + info_table = TagsInfo([], [TagInfo("a", "v1")], [], []) + info_schema = TagsInfo([], [], [TagInfo("a", "v1")], []) + info_catalog = TagsInfo([], [], [], [TagInfo("a", "v1")]) + info_no_tags = TagsInfo([], [], [], []) + + assert InfoFetcher._contains_all_tags(info_table, [TagInfo("a", "v1")]) + assert not InfoFetcher._contains_all_tags(info_table, [TagInfo("a", "v2")]) + assert not InfoFetcher._contains_all_tags(info_table, [TagInfo("b", "v1")]) + assert not InfoFetcher._contains_all_tags(info_table, [TagInfo("a", None)]) + # If no tags to check, then it should be true + assert InfoFetcher._contains_all_tags(info_table, []) + + assert InfoFetcher._contains_all_tags(info_schema, [TagInfo("a", "v1")]) + + assert InfoFetcher._contains_all_tags(info_catalog, [TagInfo("a", "v1")]) + + assert InfoFetcher._contains_all_tags(info_no_tags, []) + assert not InfoFetcher._contains_all_tags(info_no_tags, [TagInfo("a", "v1")])