From e96fe2eee59385c4ae5eefe320f5c7ca760945aa Mon Sep 17 00:00:00 2001
From: souravg-db
Date: Thu, 28 Dec 2023 11:30:49 +0000
Subject: [PATCH] Add option to filter scanned tables by data source format

---
 discoverx/dx.py         |  2 ++
 discoverx/explorer.py   | 16 +++++++++++++++-
 discoverx/scanner.py    |  6 +++++-
 discoverx/table_info.py | 11 +++++++++--
 4 files changed, 31 insertions(+), 4 deletions(-)

diff --git a/discoverx/dx.py b/discoverx/dx.py
index 01d02cc..2eaa97c 100644
--- a/discoverx/dx.py
+++ b/discoverx/dx.py
@@ -117,6 +117,7 @@ def scan(
         rules="*",
         sample_size=10000,
         what_if: bool = False,
+        data_source_formats: list[str] = ["DELTA"],
     ):
         """Scans the lakehouse for columns matching the given rules
@@ -139,6 +140,7 @@
             what_if=what_if,
             information_schema=self.INFORMATION_SCHEMA,
             max_workers=self.MAX_WORKERS,
+            data_source_formats=data_source_formats,
         )

         self._scan_result = self.scanner.scan()
diff --git a/discoverx/explorer.py b/discoverx/explorer.py
index 706ce5d..b4e4bc9 100644
--- a/discoverx/explorer.py
+++ b/discoverx/explorer.py
@@ -32,6 +32,7 @@ def __init__(self, from_tables, spark: SparkSession, info_fetcher: InfoFetcher)
         self._sql_query_template = None
         self._max_concurrency = 10
         self._with_tags = False
+        self._data_source_formats = ["DELTA"]

     @staticmethod
     def validate_from_components(from_tables: str):
@@ -54,6 +55,7 @@ def __deepcopy__(self, memo):
         new_obj._sql_query_template = copy.deepcopy(self._sql_query_template)
         new_obj._max_concurrency = copy.deepcopy(self._max_concurrency)
         new_obj._with_tags = copy.deepcopy(self._with_tags)
+        new_obj._data_source_formats = copy.deepcopy(self._data_source_formats)

         new_obj._spark = self._spark
         new_obj._info_fetcher = self._info_fetcher
@@ -70,6 +72,12 @@ def having_columns(self, *columns) -> "DataExplorer":
         new_obj._having_columns.extend(columns)
         return new_obj

+    def with_data_source_formats(self, data_source_formats: list[str] = ["DELTA"]) -> "DataExplorer":
+        """Sets the data source formats to filter the scanned tables on"""
+        new_obj = copy.deepcopy(self)
+        new_obj._data_source_formats = data_source_formats
+        return new_obj
+
     def with_concurrency(self, max_concurrency) -> "DataExplorer":
         """Sets the maximum number of concurrent queries to run"""
         new_obj = copy.deepcopy(self)
@@ -140,7 +148,13 @@ def scan(
             self._catalogs,
             self._schemas,
             self._tables,
-            self._info_fetcher.get_tables_info(self._catalogs, self._schemas, self._tables, self._having_columns),
+            self._info_fetcher.get_tables_info(
+                catalogs=self._catalogs,
+                schemas=self._schemas,
+                tables=self._tables,
+                columns=self._having_columns,
+                data_source_formats=self._data_source_formats,
+            ),
             custom_rules=custom_rules,
             locale=locale,
         )
diff --git a/discoverx/scanner.py b/discoverx/scanner.py
index 89f858d..59815a7 100644
--- a/discoverx/scanner.py
+++ b/discoverx/scanner.py
@@ -140,6 +140,7 @@ def __init__(
         what_if: bool = False,
         information_schema: str = "",
         max_workers: int = 10,
+        data_source_formats: list[str] = ["DELTA"],
     ):
         self.spark = spark
         self.rules = rules
@@ -152,6 +153,7 @@ def __init__(
         self.what_if = what_if
         self.information_schema = information_schema
         self.max_workers = max_workers
+        self.data_source_formats = data_source_formats

         self.content: ScanContent = self._resolve_scan_content()
         self.rule_list = self.rules.get_rules(rule_filter=self.rules_filter)
@@ -211,7 +213,9 @@ def _resolve_scan_content(self) -> ScanContent:
             table_list = self.table_list
         else:
             info_fetcher = InfoFetcher(self.spark, information_schema=self.information_schema)
-            table_list = info_fetcher.get_tables_info(self.catalogs, self.schemas, self.tables)
+            table_list = info_fetcher.get_tables_info(
+                self.catalogs, self.schemas, self.tables, data_source_formats=self.data_source_formats
+            )

         catalogs = set(map(lambda x: x.catalog, table_list))
         schemas = set(map(lambda x: f"{x.catalog}.{x.schema}", table_list))
diff --git a/discoverx/table_info.py b/discoverx/table_info.py
index 5afe31a..90f6d35 100644
--- a/discoverx/table_info.py
+++ b/discoverx/table_info.py
@@ -111,9 +111,10 @@ def get_tables_info(
         tables: str,
         columns: list[str] = [],
         with_tags=False,
+        data_source_formats: list[str] = ["DELTA"],
     ) -> list[TableInfo]:
         # Filter tables by matching filter
-        table_list_sql = self._get_table_list_sql(catalogs, schemas, tables, columns, with_tags)
+        table_list_sql = self._get_table_list_sql(catalogs, schemas, tables, columns, with_tags, data_source_formats)

         filtered_tables = self.spark.sql(table_list_sql).collect()

@@ -129,6 +130,7 @@ def _get_table_list_sql(
         tables: str,
         columns: list[str] = [],
         with_tags=False,
+        data_source_formats: list[str] = ["DELTA"],
     ) -> str:
         """
         Returns a SQL expression which returns a list of columns matching
@@ -160,6 +162,7 @@
         if columns:
             match_any_col = "|".join([f'({c.replace("*", ".*")})' for c in columns])
             columns_sql = f"""AND regexp_like(column_name, "^{match_any_col}$")"""
+        data_source_formats_values = ",".join("'{0}'".format(f) for f in data_source_formats)

         with_column_info_sql = f"""
         WITH all_user_tbl_list AS (
@@ -184,7 +187,11 @@
             FROM {self.information_schema}.tables
             WHERE
                 table_schema != "information_schema"
-                and table_type != "VIEW"
+                {catalog_sql if catalogs != "*" else ""}
+                {schema_sql if schemas != "*" else ""}
+                {table_sql if tables != "*" else ""}
+                and table_type in ('MANAGED','EXTERNAL','MANAGED_SHALLOW_CLONE','EXTERNAL_SHALLOW_CLONE')
+                and data_source_format in ({data_source_formats_values})
             ),
             filtered_tbl_list AS (
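
Example usage, sketched from the signatures in this patch: DX.scan gains a
data_source_formats parameter, and DataExplorer gains a with_data_source_formats
setter. The DX() construction and the from_tables entry point are assumed from the
existing discoverx API; the PARQUET/CSV format values are illustrative placeholders,
not values taken from the diff.

    from discoverx import DX

    dx = DX()

    # Default behaviour is unchanged: only DELTA tables are scanned.
    dx.scan(from_tables="*.*.*", rules="*")

    # Opt in to additional formats through the new DX.scan parameter...
    dx.scan(from_tables="*.*.*", rules="*", data_source_formats=["DELTA", "PARQUET"])

    # ...or through the new fluent setter on DataExplorer.
    dx.from_tables("*.*.*").with_data_source_formats(["DELTA", "CSV"]).scan()

Both paths flow into InfoFetcher._get_table_list_sql, which renders the list into the
new data_source_format in (...) predicate against the information schema, alongside
the table_type filter that replaces the old table_type != "VIEW" check.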