De-duplicated table info fetching code

edurdevic committed Sep 29, 2023
1 parent bbfe878 commit 5f42ee1
Showing 7 changed files with 130 additions and 219 deletions.
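The commit consolidates the two near-identical copies of the table-info fetching logic (previously maintained separately in explorer.py and scanner.py) into a single InfoFetcher class in the new discoverx/table_info.py, which both modules now import. A minimal usage sketch of the shared path (assuming a Databricks session where system.information_schema.columns is readable; the filter values below are illustrative):

from pyspark.sql import SparkSession
from discoverx.table_info import InfoFetcher

spark = SparkSession.builder.getOrCreate()

# One shared fetcher instead of the two duplicated implementations
# that explorer.py and scanner.py each carried before this commit.
fetcher = InfoFetcher(spark, columns_table_name="system.information_schema.columns")

# "*" wildcards are translated into regexes; a ValueError is raised if nothing matches.
tables = fetcher.get_tables_info(catalogs="*", schemas="default", tables="*")
for t in tables:
    print(t.catalog, t.schema, t.table, [c.name for c in t.columns])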
4 changes: 2 additions & 2 deletions discoverx/discovery.py
@@ -2,10 +2,10 @@

from discoverx import logging
from discoverx.msql import Msql
from discoverx.scanner import TableInfo
from discoverx.table_info import TableInfo
from discoverx.scanner import Scanner, ScanResult
from discoverx.rules import Rules, Rule
from pyspark.sql import DataFrame, SparkSession
from pyspark.sql import SparkSession

logger = logging.Logging()

94 changes: 2 additions & 92 deletions discoverx/explorer.py
@@ -2,108 +2,18 @@
import copy
import re
from typing import Optional, List

from discoverx import logging
from discoverx.common import helper
from discoverx.discovery import Discovery
from discoverx.rules import Rule
from discoverx.scanner import ColumnInfo, TableInfo
from discoverx.table_info import InfoFetcher, TableInfo
from functools import reduce
from pyspark.sql import DataFrame, SparkSession
from pyspark.sql.functions import lit
from pyspark.sql.types import Row

logger = logging.Logging()


class InfoFetcher:
    def __init__(self, spark, columns_table_name="system.information_schema.columns") -> None:
        self.columns_table_name = columns_table_name
        self.spark = spark

    def _to_info_list(self, info_rows: list[Row]) -> list[TableInfo]:
        filtered_tables = [
            TableInfo(
                row["table_catalog"],
                row["table_schema"],
                row["table_name"],
                [
                    ColumnInfo(col["column_name"], col["data_type"], col["partition_index"], [])
                    for col in row["table_columns"]
                ],
            )
            for row in info_rows
        ]
        return filtered_tables

    def get_tables_info(self, catalogs: str, schemas: str, tables: str, columns: list[str] = []) -> list[TableInfo]:
        # Filter tables by matching filter
        table_list_sql = self._get_table_list_sql(catalogs, schemas, tables, columns)

        filtered_tables = self.spark.sql(table_list_sql).collect()

        if len(filtered_tables) == 0:
            raise ValueError(f"No tables found matching filter: {catalogs}.{schemas}.{tables}")

        return self._to_info_list(filtered_tables)

    def _get_table_list_sql(self, catalogs: str, schemas: str, tables: str, columns: list[str] = []) -> str:
        """
        Returns a SQL expression which returns a list of columns matching
        the specified filters
        Returns:
            string: The SQL expression
        """

        if "*" in catalogs:
            catalog_sql = f"""AND regexp_like(table_catalog, "^{catalogs.replace("*", ".*")}$")"""
        else:
            catalog_sql = f"""AND table_catalog = "{catalogs}" """

        if "*" in schemas:
            schema_sql = f"""AND regexp_like(table_schema, "^{schemas.replace("*", ".*")}$")"""
        else:
            schema_sql = f"""AND table_schema = "{schemas}" """

        if "*" in tables:
            table_sql = f"""AND regexp_like(table_name, "^{tables.replace("*", ".*")}$")"""
        else:
            table_sql = f"""AND table_name = "{tables}" """

        if columns:
            match_any_col = "|".join([f'({c.replace("*", ".*")})' for c in columns])
            columns_sql = f"""AND regexp_like(column_name, "^{match_any_col}$")"""

        sql = f"""
        WITH tb_list AS (
            SELECT DISTINCT
                table_catalog,
                table_schema,
                table_name
            FROM {self.columns_table_name}
            WHERE
                table_schema != "information_schema"
                {catalog_sql if catalogs != "*" else ""}
                {schema_sql if schemas != "*" else ""}
                {table_sql if tables != "*" else ""}
                {columns_sql if columns else ""}
        )
        SELECT
            info_schema.table_catalog,
            info_schema.table_schema,
            info_schema.table_name,
            collect_list(struct(column_name, data_type, partition_index)) as table_columns
        FROM {self.columns_table_name} info_schema
        INNER JOIN tb_list ON (
            info_schema.table_catalog <=> tb_list.table_catalog AND
            info_schema.table_schema = tb_list.table_schema AND
            info_schema.table_name = tb_list.table_name)
        GROUP BY info_schema.table_catalog, info_schema.table_schema, info_schema.table_name
        """

        return helper.strip_margin(sql)


class DataExplorer:
2 changes: 1 addition & 1 deletion discoverx/msql.py
@@ -2,7 +2,7 @@
from dataclasses import dataclass
from functools import reduce
from discoverx import logging
from discoverx.scanner import ColumnInfo, TableInfo
from discoverx.table_info import ColumnInfo, TableInfo
from discoverx.common.helper import strip_margin
from fnmatch import fnmatch
from pyspark.sql.functions import lit
77 changes: 3 additions & 74 deletions discoverx/scanner.py
@@ -9,36 +9,12 @@

from discoverx.common.helper import strip_margin, format_regex
from discoverx import logging
from discoverx.table_info import InfoFetcher, TableInfo
from discoverx.rules import Rules, RuleTypes

logger = logging.Logging()


@dataclass
class ColumnInfo:
    name: str
    data_type: str
    partition_index: int
    classes: list[str]


@dataclass
class TableInfo:
    catalog: Optional[str]
    schema: str
    table: str
    columns: list[ColumnInfo]

    def get_columns_by_class(self, class_name: str):
        return [ClassifiedColumn(col.name, class_name) for col in self.columns if class_name in col.classes]


@dataclass
class ClassifiedColumn:
    name: str
    class_name: str


@dataclass
class ScanContent:
    table_list: List[TableInfo]
@@ -181,59 +157,12 @@ def __init__(
        self.rule_list = self.rules.get_rules(rule_filter=self.rules_filter)
        self.scan_result: Optional[ScanResult] = None

    def _get_list_of_tables(self) -> List[TableInfo]:
        table_list_sql = self._get_table_list_sql()

        rows = self.spark.sql(table_list_sql).collect()
        filtered_tables = [
            TableInfo(
                row["table_catalog"],
                row["table_schema"],
                row["table_name"],
                [
                    ColumnInfo(col["column_name"], col["data_type"], col["partition_index"], [])
                    for col in row["table_columns"]
                ],
            )
            for row in rows
        ]
        return filtered_tables

    def _get_table_list_sql(self):
        """
        Returns a SQL expression which returns a list of columns matching
        the specified filters
        Returns:
            string: The SQL expression
        """

        catalog_sql = f"""AND regexp_like(table_catalog, "^{self.catalogs.replace("*", ".*")}$")"""
        schema_sql = f"""AND regexp_like(table_schema, "^{self.schemas.replace("*", ".*")}$")"""
        table_sql = f"""AND regexp_like(table_name, "^{self.tables.replace("*", ".*")}$")"""

        sql = f"""
        SELECT
            table_catalog,
            table_schema,
            table_name,
            collect_list(struct(column_name, data_type, partition_index)) as table_columns
        FROM {self.columns_table_name}
        WHERE
            table_schema != "information_schema"
            {catalog_sql if self.catalogs != "*" else ""}
            {schema_sql if self.schemas != "*" else ""}
            {table_sql if self.tables != "*" else ""}
        GROUP BY table_catalog, table_schema, table_name
        """

        return strip_margin(sql)

    def _resolve_scan_content(self) -> ScanContent:
        if self.table_list:
            table_list = self.table_list
        else:
            table_list = self._get_list_of_tables()
            info_fetcher = InfoFetcher(self.spark, columns_table_name=self.columns_table_name)
            table_list = info_fetcher.get_tables_info(self.catalogs, self.schemas, self.tables)
        catalogs = set(map(lambda x: x.catalog, table_list))
        schemas = set(map(lambda x: f"{x.catalog}.{x.schema}", table_list))

119 changes: 119 additions & 0 deletions discoverx/table_info.py
@@ -0,0 +1,119 @@
from typing import Optional
from discoverx.common import helper
from pyspark.sql.types import Row
from dataclasses import dataclass


@dataclass
class ColumnInfo:
    name: str
    data_type: str
    partition_index: int
    classes: list[str]


@dataclass
class TableInfo:
    catalog: Optional[str]
    schema: str
    table: str
    columns: list[ColumnInfo]

    def get_columns_by_class(self, class_name: str):
        return [ClassifiedColumn(col.name, class_name) for col in self.columns if class_name in col.classes]


@dataclass
class ClassifiedColumn:
    name: str
    class_name: str


class InfoFetcher:
    def __init__(self, spark, columns_table_name="system.information_schema.columns") -> None:
        self.columns_table_name = columns_table_name
        self.spark = spark

    def _to_info_list(self, info_rows: list[Row]) -> list[TableInfo]:
        filtered_tables = [
            TableInfo(
                row["table_catalog"],
                row["table_schema"],
                row["table_name"],
                [
                    ColumnInfo(col["column_name"], col["data_type"], col["partition_index"], [])
                    for col in row["table_columns"]
                ],
            )
            for row in info_rows
        ]
        return filtered_tables

    def get_tables_info(self, catalogs: str, schemas: str, tables: str, columns: list[str] = []) -> list[TableInfo]:
        # Filter tables by matching filter
        table_list_sql = self._get_table_list_sql(catalogs, schemas, tables, columns)

        filtered_tables = self.spark.sql(table_list_sql).collect()

        if len(filtered_tables) == 0:
            raise ValueError(f"No tables found matching filter: {catalogs}.{schemas}.{tables}")

(Codecov / codecov/patch: added line #L59 in discoverx/table_info.py was not covered by tests.)

        return self._to_info_list(filtered_tables)

    def _get_table_list_sql(self, catalogs: str, schemas: str, tables: str, columns: list[str] = []) -> str:
        """
        Returns a SQL expression which returns a list of columns matching
        the specified filters
        Returns:
            string: The SQL expression
        """

        if "*" in catalogs:
            catalog_sql = f"""AND regexp_like(table_catalog, "^{catalogs.replace("*", ".*")}$")"""
        else:
            catalog_sql = f"""AND table_catalog = "{catalogs}" """

(Codecov / codecov/patch: added line #L75 in discoverx/table_info.py was not covered by tests.)

        if "*" in schemas:
            schema_sql = f"""AND regexp_like(table_schema, "^{schemas.replace("*", ".*")}$")"""
        else:
            schema_sql = f"""AND table_schema = "{schemas}" """

(Codecov / codecov/patch: added line #L80 in discoverx/table_info.py was not covered by tests.)

        if "*" in tables:
            table_sql = f"""AND regexp_like(table_name, "^{tables.replace("*", ".*")}$")"""
        else:
            table_sql = f"""AND table_name = "{tables}" """

        if columns:
            match_any_col = "|".join([f'({c.replace("*", ".*")})' for c in columns])
            columns_sql = f"""AND regexp_like(column_name, "^{match_any_col}$")"""

        sql = f"""
        WITH tb_list AS (
            SELECT DISTINCT
                table_catalog,
                table_schema,
                table_name
            FROM {self.columns_table_name}
            WHERE
                table_schema != "information_schema"
                {catalog_sql if catalogs != "*" else ""}
                {schema_sql if schemas != "*" else ""}
                {table_sql if tables != "*" else ""}
                {columns_sql if columns else ""}
        )
        SELECT
            info_schema.table_catalog,
            info_schema.table_schema,
            info_schema.table_name,
            collect_list(struct(column_name, data_type, partition_index)) as table_columns
        FROM {self.columns_table_name} info_schema
        INNER JOIN tb_list ON (
            info_schema.table_catalog <=> tb_list.table_catalog AND
            info_schema.table_schema = tb_list.table_schema AND
            info_schema.table_name = tb_list.table_name)
        GROUP BY info_schema.table_catalog, info_schema.table_schema, info_schema.table_name
        """

        return helper.strip_margin(sql)
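A detail worth noting in the join above: table_catalog is compared with the null-safe equality operator <=>, so rows still join when the catalog is NULL on both sides (consistent with TableInfo.catalog being Optional[str]). A quick way to inspect the SQL the fetcher generates, shown as a sketch only since _get_table_list_sql is a private helper (assumes a live SparkSession named spark; the filter values are illustrative):

from discoverx.table_info import InfoFetcher

fetcher = InfoFetcher(spark)

# Exact names become equality predicates; names containing "*" become regexes;
# a bare "*" drops the predicate for that level entirely.
print(fetcher._get_table_list_sql(
    catalogs="main",        # -> AND table_catalog = "main"
    schemas="sales_*",      # -> AND regexp_like(table_schema, "^sales_.*$")
    tables="*",             # -> no table_name predicate
    columns=["*email*"],    # -> AND regexp_like(column_name, "^(.*email.*)$")
))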
2 changes: 1 addition & 1 deletion tests/unit/msql_test.py
@@ -4,7 +4,7 @@
import pandas as pd
import pytest
from discoverx.common.helper import strip_margin
from discoverx.scanner import ColumnInfo, TableInfo
from discoverx.table_info import ColumnInfo, TableInfo
from discoverx.msql import Msql, SQLRow
from discoverx.scanner import ScanResult

