Commit e96fe2e
Added changes to filter data on data source formats
souravg-db2 committed Dec 28, 2023
1 parent 38e7f3e commit e96fe2e
Showing 4 changed files with 31 additions and 4 deletions.
discoverx/dx.py (2 additions, 0 deletions)
@@ -117,6 +117,7 @@ def scan(
rules="*",
sample_size=10000,
what_if: bool = False,
+ data_source_formats: list[str] = ["DELTA"],
):
"""Scans the lakehouse for columns matching the given rules
@@ -139,6 +140,7 @@ def scan(
what_if=what_if,
information_schema=self.INFORMATION_SCHEMA,
max_workers=self.MAX_WORKERS,
+ data_source_formats=data_source_formats,
)

self._scan_result = self.scanner.scan()
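With this change, scan() accepts an optional data_source_formats argument defaulting to ["DELTA"], so existing callers keep their Delta-only behaviour. A minimal usage sketch, not part of this commit: the DX import and the from_tables parameter are assumed from the project's documented API, and the catalog name is hypothetical; only data_source_formats comes from the diff above.

from discoverx import DX  # assumed entry point

dx = DX()

# Default: only Delta tables are scanned.
dx.scan(from_tables="dev_catalog.*.*")

# Explicitly widen the scan to other data source formats.
dx.scan(from_tables="dev_catalog.*.*", data_source_formats=["DELTA", "PARQUET"])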
discoverx/explorer.py (15 additions, 1 deletion)
@@ -32,6 +32,7 @@ def __init__(self, from_tables, spark: SparkSession, info_fetcher: InfoFetcher)
self._sql_query_template = None
self._max_concurrency = 10
self._with_tags = False
+ self._data_source_formats = ["DELTA"]

@staticmethod
def validate_from_components(from_tables: str):
@@ -54,6 +55,7 @@ def __deepcopy__(self, memo)
new_obj._sql_query_template = copy.deepcopy(self._sql_query_template)
new_obj._max_concurrency = copy.deepcopy(self._max_concurrency)
new_obj._with_tags = copy.deepcopy(self._with_tags)
+ new_obj._data_source_formats = copy.deepcopy(self._data_source_formats)

new_obj._spark = self._spark
new_obj._info_fetcher = self._info_fetcher
@@ -70,6 +72,12 @@ def having_columns(self, *columns) -> "DataExplorer":
new_obj._having_columns.extend(columns)
return new_obj

+ def with_data_source_formats(self, data_source_formats: list[str] = ["DELTA"]) -> "DataExplorer":
+     """Sets the data source formats to filter tables by"""
+     new_obj = copy.deepcopy(self)
+     new_obj._data_source_formats = data_source_formats
+     return new_obj

def with_concurrency(self, max_concurrency) -> "DataExplorer":
"""Sets the maximum number of concurrent queries to run"""
new_obj = copy.deepcopy(self)
@@ -140,7 +148,13 @@ def scan(
self._catalogs,
self._schemas,
self._tables,
- self._info_fetcher.get_tables_info(self._catalogs, self._schemas, self._tables, self._having_columns),
+ self._info_fetcher.get_tables_info(
+     catalogs=self._catalogs,
+     schemas=self._schemas,
+     tables=self._tables,
+     columns=self._having_columns,
+     data_source_formats=self._data_source_formats,
+ ),
custom_rules=custom_rules,
locale=locale,
)
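The explorer gains a matching builder, with_data_source_formats(), alongside with_concurrency(). A sketch of the fluent chain, again assuming the documented DX entry point and a hypothetical table pattern:

from discoverx import DX  # assumed entry point

dx = DX()

# Each with_* call deep-copies the explorer (see __deepcopy__ above),
# so chaining never mutates the original object.
explorer = (
    dx.from_tables("dev_catalog.sales.*")  # from_tables() assumed from the documented API
    .with_data_source_formats(["DELTA", "PARQUET"])
)
explorer.scan()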
discoverx/scanner.py (5 additions, 1 deletion)
@@ -140,6 +140,7 @@ def __init__(
what_if: bool = False,
information_schema: str = "",
max_workers: int = 10,
+ data_source_formats: list[str] = ["DELTA"],
):
self.spark = spark
self.rules = rules
@@ -152,6 +153,7 @@ def __init__(
self.what_if = what_if
self.information_schema = information_schema
self.max_workers = max_workers
+ self.data_source_formats = data_source_formats

self.content: ScanContent = self._resolve_scan_content()
self.rule_list = self.rules.get_rules(rule_filter=self.rules_filter)
@@ -211,7 +213,9 @@ def _resolve_scan_content(self) -> ScanContent:
table_list = self.table_list
else:
info_fetcher = InfoFetcher(self.spark, information_schema=self.information_schema)
- table_list = info_fetcher.get_tables_info(self.catalogs, self.schemas, self.tables)
+ table_list = info_fetcher.get_tables_info(
+     self.catalogs, self.schemas, self.tables, data_source_formats=self.data_source_formats
+ )
catalogs = set(map(lambda x: x.catalog, table_list))
schemas = set(map(lambda x: f"{x.catalog}.{x.schema}", table_list))

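Because get_tables_info() declares columns and with_tags ahead of data_source_formats (see table_info.py below), passing the formats by keyword keeps the call unambiguous. A minimal sketch of calling the fetcher directly; the module path and signatures are taken from this diff, while the information-schema name is hypothetical and spark is assumed to be the active SparkSession (as in a Databricks notebook):

from discoverx.table_info import InfoFetcher

fetcher = InfoFetcher(spark, information_schema="system.information_schema")
tables = fetcher.get_tables_info(
    catalogs="*",
    schemas="*",
    tables="*",
    data_source_formats=["DELTA", "PARQUET"],  # applied as a SQL filter, see below
)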
discoverx/table_info.py (9 additions, 2 deletions)
@@ -111,9 +111,10 @@ def get_tables_info(
tables: str,
columns: list[str] = [],
with_tags=False,
+ data_source_formats: list[str] = ["DELTA"],
) -> list[TableInfo]:
# Filter tables by matching filter
- table_list_sql = self._get_table_list_sql(catalogs, schemas, tables, columns, with_tags)
+ table_list_sql = self._get_table_list_sql(catalogs, schemas, tables, columns, with_tags, data_source_formats)

filtered_tables = self.spark.sql(table_list_sql).collect()

@@ -129,6 +130,7 @@ def _get_table_list_sql(
tables: str,
columns: list[str] = [],
with_tags=False,
+ data_source_formats: list[str] = ["DELTA"],
) -> str:
"""
Returns a SQL expression which returns a list of columns matching
@@ -160,6 +162,7 @@ def _get_table_list_sql(
if columns:
match_any_col = "|".join([f'({c.replace("*", ".*")})' for c in columns])
columns_sql = f"""AND regexp_like(column_name, "^{match_any_col}$")"""
+ data_source_formats_values = ",".join("'{0}'".format(f) for f in data_source_formats)

with_column_info_sql = f"""
WITH all_user_tbl_list AS (
@@ -184,7 +187,11 @@ def _get_table_list_sql(
FROM {self.information_schema}.tables
WHERE
table_schema != "information_schema"
- and table_type != "VIEW"
{catalog_sql if catalogs != "*" else ""}
{schema_sql if schemas != "*" else ""}
{table_sql if tables != "*" else ""}
+ and table_type in ('MANAGED','EXTERNAL','MANAGED_SHALLOW_CLONE','EXTERNAL_SHALLOW_CLONE')
+ and data_source_format in ({data_source_formats_values})
),
filtered_tbl_list AS (
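The data_source_formats_values expression renders the Python list as a quoted, comma-separated SQL list, which the final predicate interpolates. Evaluated standalone:

# The same join expression as in _get_table_list_sql, run in isolation.
formats = ["DELTA", "PARQUET"]
values = ",".join("'{0}'".format(f) for f in formats)
print(values)  # 'DELTA','PARQUET'

# The generated WHERE clause then ends with:
#   and table_type in ('MANAGED','EXTERNAL','MANAGED_SHALLOW_CLONE','EXTERNAL_SHALLOW_CLONE')
#   and data_source_format in ('DELTA','PARQUET')

Narrowing table_type to managed and external tables (and their shallow clones), rather than merely excluding views, pairs naturally with the new filter, since other table types generally do not report a usable data_source_format.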
