Adding save and load methods to the new classification API #85

Draft · wants to merge 1 commit into base: master
discoverx/discovery.py — 14 additions, 0 deletions

@@ -57,6 +57,20 @@ def _msql(self, msql: str, what_if: bool = False, min_score: Optional[float] = None):
        logger.debug(f"Executing SQL:\n{sql_rows}")
        return msql_builder.execute_sql_rows(sql_rows, self.spark)

    def save(self, full_table_name: str):
        """Saves the scan results to the lakehouse

        Args:
            full_table_name (str): The full table name to be
                used to save the scan results.

        Raises:
            Exception: If the scan has not been run
        """
        self._check_scan_result()
        # save classes
        self._scan_result.save(full_table_name)

    def scan(
        self,
        rules="*",
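A minimal usage sketch for the new Discovery.save, assuming dx.from_tables(...).scan() returns a Discovery instance as in the explorer API; all catalog, schema, and table names here are illustrative:

from discoverx import DX

dx = DX()

# Assumption: scanning through the explorer API yields a Discovery object
discovery = dx.from_tables("prod_catalog.*.*").scan(rules="*")

# Persist the classification results to a Delta table; the scanned scope
# (catalogs, schemas, tables) is stored alongside as table properties
discovery.save("prod_catalog.discoverx.scan_results")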
discoverx/dx.py — 25 additions, 0 deletions

@@ -2,6 +2,7 @@
from pyspark.sql import SparkSession
from typing import List, Optional, Union
from discoverx import logging
from discoverx.discovery import Discovery
from discoverx.explorer import DataExplorer, InfoFetcher
from discoverx.msql import Msql
from discoverx.rules import Rules, Rule
@@ -185,6 +186,30 @@ def load(self, full_table_name: str):
        self._scan_result = ScanResult(df=pd.DataFrame(), spark=self.spark)
        self._scan_result.load(full_table_name)

    def load_classification(self, full_table_name: str):
        """Loads previously saved classification results from a table

        Args:
            full_table_name (str): The full table name to be
                used to load the classification results.

        Raises:
            Exception: If the table to be loaded does not exist
        """
        scan_result = ScanResult(df=pd.DataFrame(), spark=self.spark)
        scan_result.load_classification(full_table_name)
        discover = Discovery(
            self.spark,
            scan_result.catalogs,
            scan_result.schemas,
            scan_result.tables,
            InfoFetcher(self.spark, self.COLUMNS_TABLE_NAME).get_tables_info(
                scan_result.catalogs, scan_result.schemas, scan_result.tables, columns=[]
            ),  # TODO: Need to add column support, i.e. include having_columns functionality
        )
        discover._scan_result = scan_result  # TODO: Move the loading of the scan result into Discovery or ScanResult class
        return discover


    def search(
        self,
        search_term: str,
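A sketch of the corresponding load path, assuming the table was written by Discovery.save as above; the name is illustrative. Because the scanned scope travels with the table as TBLPROPERTIES, the caller does not need to re-specify catalogs, schemas, or tables:

from discoverx import DX

dx = DX()

# Rebuilds a Discovery whose _scan_result, catalogs, schemas, and tables
# are restored from the saved table and its properties
discovery = dx.load_classification("prod_catalog.discoverx.scan_results")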
discoverx/scanner.py — 43 additions, 0 deletions

@@ -38,6 +38,9 @@ def n_tables(self) -> int:
class ScanResult:
    df: pd.DataFrame
    spark: SparkSession
    catalogs: Optional[str] = None
    schemas: Optional[str] = None
    tables: Optional[str] = None

    @property
    def is_empty(self) -> bool:
@@ -93,16 +96,39 @@ def _create_databes_if_not_exists(self, scan_table_name: str):
        )
        logger.friendly(f"The scan result table {scan_table_name} has been created.")

        if (self.catalogs is not None) and (self.schemas is not None) and (self.tables is not None):
            self.spark.sql(
                f"ALTER TABLE {scan_table_name} SET TBLPROPERTIES (catalogs='{self.catalogs}', schemas='{self.schemas}', tables='{self.tables}')"
            )

    def _get_or_create_result_table_from_delta(self, scan_table_name: str) -> DeltaTable:
        try:
            return DeltaTable.forName(self.spark, scan_table_name)
        except Exception:
            self._create_databes_if_not_exists(scan_table_name)
            return DeltaTable.forName(self.spark, scan_table_name)

    def save_discovery(self, scan_table_name: str):
        scan_delta_table = self._get_or_create_result_table_from_delta(scan_table_name)
        # Delta table properties come back as a plain dict, so use key access
        table_props = scan_delta_table.detail().select("properties").collect()[0].properties
        try:
            delta_scanned_tables = ".".join([table_props["catalogs"], table_props["schemas"], table_props["tables"]])
        except KeyError:
            raise Exception(
                "The scan result table has been created without catalogs, schemas and tables properties "
                "using the old dx.save() method. Use dx.load() instead."
            )
        scanned_tables = ".".join([self.catalogs, self.schemas, self.tables])
        if delta_scanned_tables != scanned_tables:
            raise Exception(
                f"The scan result table has been created for tables {delta_scanned_tables}. "
                f"The current scan has been performed for tables {scanned_tables}. "
                f"Please delete the scan result table and rerun the scan."
            )

        self._merge_scan_results(scan_table_name, scan_delta_table)

    def save(self, scan_table_name: str):
        scan_delta_table = self._get_or_create_result_table_from_delta(scan_table_name)

        self._merge_scan_results(scan_table_name, scan_delta_table)

    def _merge_scan_results(self, scan_table_name: str, scan_delta_table: DeltaTable):
        scan_result_df = self.spark.createDataFrame(
            self.df,
            "table_catalog: string, table_schema: string, table_name: string, column_name: string, class_name: string, score: double",
@@ -125,6 +151,23 @@ def load(self, scan_table_name: str):
            logger.error(f"Error while reading the scan result table {scan_table_name}: {e}")
            raise e

    def load_classification(self, scan_table_name: str):
        try:
            scan_delta_table = DeltaTable.forName(self.spark, scan_table_name)
            self.df = scan_delta_table.toDF().drop("effective_timestamp").toPandas()
        except Exception as e:
            logger.error(f"Error while reading the scan result table {scan_table_name}: {e}")
            raise e

        # Delta table properties come back as a plain dict, so use key access
        table_props = scan_delta_table.detail().select("properties").collect()[0].properties
        try:
            self.catalogs = table_props["catalogs"]
            self.schemas = table_props["schemas"]
            self.tables = table_props["tables"]
        except KeyError:
            raise Exception(
                "The scan result table has been created without catalogs, schemas and tables properties "
                "using the old dx.save() method. Use dx.load() instead."
            )


class Scanner:
    def __init__(
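For reference, a sketch of the TBLPROPERTIES round trip that save_discovery and load_classification rely on, assuming an active SparkSession named spark and the illustrative table name from the earlier examples:

from delta.tables import DeltaTable

detail = DeltaTable.forName(spark, "prod_catalog.discoverx.scan_results").detail().collect()[0]
props = detail.properties  # a plain dict, e.g. {"catalogs": "prod_catalog", "schemas": "*", "tables": "*"}

# Reassemble the scanned scope exactly as load_classification does
scope = ".".join([props["catalogs"], props["schemas"], props["tables"]])
print(scope)  # e.g. prod_catalog.*.*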