diff --git a/discoverx/discovery.py b/discoverx/discovery.py
new file mode 100644
index 0000000..767c14b
--- /dev/null
+++ b/discoverx/discovery.py
@@ -0,0 +1,301 @@
+from typing import Optional, List, Union
+
+from discoverx import logging
+from discoverx.msql import Msql
+from discoverx.scanner import TableInfo
+from discoverx.scanner import Scanner, ScanResult
+from discoverx.rules import Rules, Rule
+from pyspark.sql import DataFrame, SparkSession
+
+logger = logging.Logging()
+
+
+class Discovery:
+    """Scans a set of tables and lets you search, select, and delete rows by semantic class."""
+
+    COLUMNS_TABLE_NAME = "system.information_schema.columns"
+    MAX_WORKERS = 10
+
+    def __init__(
+        self,
+        spark: SparkSession,
+        catalogs: str,
+        schemas: str,
+        tables: str,
+        table_info_list: List[TableInfo],
+        custom_rules: Optional[List[Rule]] = None,
+        locale: str = None,
+    ):
+        self.spark = spark
+        self._catalogs = catalogs
+        self._schemas = schemas
+        self._tables = tables
+        self._table_info_list = table_info_list
+
+        self.scanner: Optional[Scanner] = None
+        self._scan_result: Optional[ScanResult] = None
+        self.rules: Optional[Rules] = Rules(custom_rules=custom_rules, locale=locale)
+
+    def _msql(self, msql: str, what_if: bool = False, min_score: Optional[float] = None):
+        logger.debug(f"Executing sql template: {msql}")
+
+        msql_builder = Msql(msql)
+
+        # check if classification is available
+        # Check for more specific exception
+        classification_result_pdf = self._scan_result.get_classes(min_score)
+        sql_rows = msql_builder.build(classification_result_pdf)
+
+        if what_if:
+            logger.friendly(f"SQL that would be executed:")
+
+            for sql_row in sql_rows:
+                logger.friendly(sql_row.sql)
+
+            return None
+        else:
+            logger.debug(f"Executing SQL:\n{sql_rows}")
+            return msql_builder.execute_sql_rows(sql_rows, self.spark)
+
+    def scan(
+        self,
+        rules="*",
+        sample_size=10000,
+        what_if: bool = False,
+    ):
+
+        self.scanner = Scanner(
+            self.spark,
+            self.rules,
+            catalogs=self._catalogs,
+            schemas=self._schemas,
+            tables=self._tables,
+            table_list=self._table_info_list,
+            rule_filter=rules,
+            sample_size=sample_size,
+            what_if=what_if,
+            columns_table_name=self.COLUMNS_TABLE_NAME,
+            max_workers=self.MAX_WORKERS,
+        )
+
+        self._scan_result = self.scanner.scan()
+        logger.friendlyHTML(self.scanner.summary_html)
+
+    def _check_scan_result(self):
+        if self._scan_result is None:
+            raise Exception("You first need to scan your lakehouse using Scanner.scan()")
+
+    @property
+    def scan_result(self):
+        """Returns the scan results as a pandas DataFrame
+
+        Raises:
+            Exception: If the scan has not been run
+        """
+        self._check_scan_result()
+
+        return self._scan_result.df
+
+    def search(
+        self,
+        search_term: str,
+        from_tables: str = "*.*.*",
+        by_class: Optional[str] = None,
+        min_score: Optional[float] = None,
+    ):
+        """Searches your lakehouse for columns matching the given search term
+
+        Args:
+            search_term (str): The search term to be used to search for columns.
+            from_tables (str, optional): The tables to be searched in format
+                "catalog.schema.table", use "*" as a wildcard. Defaults to "*.*.*".
+            by_class (str, optional): The class to be used to search for columns.
+                Defaults to None.
+            min_score (float, optional): Defines the classification score or frequency
+                threshold for columns to be considered during the scan. Defaults to None
+                which means that all columns where at least one record matched the
+                respective rule during the scan will be included. Has to be either None
+                or between 0 and 1.
+
+        Raises:
+            ValueError: If search_term is not provided
+            ValueError: If the search_term type is not valid
+            ValueError: If the by_class type is not valid
+
+        Returns:
+            DataFrame: A dataframe containing the results of the search
+        """
+
+        Msql.validate_from_components(from_tables)
+
+        if search_term is None:
+            raise ValueError("search_term has not been provided.")
+
+        if not isinstance(search_term, str):
+            raise ValueError(f"The search_term type {type(search_term)} is not valid. Please use a string type.")
+
+        if by_class is None:
+            # Trying to infer the class by the search term
+            logger.friendly(
+                "You did not provide any class to be searched. "
+                "We will try to auto-detect matching rules for the given search term"
+            )
+            search_matching_rules = self.rules.match_search_term(search_term)
+            if len(search_matching_rules) == 0:
+                raise ValueError(
+                    f"Could not infer any class for the given search term. Please specify the by_class parameter."
+                )
+            elif len(search_matching_rules) > 1:
+                raise ValueError(
+                    f"Multiple classes {search_matching_rules} match the given search term ({search_term}). Please specify the class to search in with the by_class parameter."
+                )
+            else:
+                by_class = search_matching_rules[0]
+                logger.friendly(f"Discoverx will search your lakehouse using the class {by_class}")
+        elif isinstance(by_class, str):
+            search_matching_rules = [by_class]
+        else:
+            raise ValueError(f"The provided by_class {by_class} must be of string type.")
+
+        sql_filter = f"`[{search_matching_rules[0]}]` = '{search_term}'"
+        select_statement = (
+            "named_struct("
+            + ", ".join(
+                [
+                    f"'{rule_name}', named_struct('column_name', '[{rule_name}]', 'value', `[{rule_name}]`)"
+                    for rule_name in search_matching_rules
+                ]
+            )
+            + ") AS search_result"
+        )
+
+        where_statement = f"WHERE {sql_filter}"
+
+        return self._msql(
+            f"SELECT {select_statement}, to_json(struct(*)) AS row_content FROM {from_tables} {where_statement}",
+            min_score=min_score,
+        )
+
+    def select_by_classes(
+        self,
+        from_tables: str = "*.*.*",
+        by_classes: Optional[Union[List[str], str]] = None,
+        min_score: Optional[float] = None,
+    ):
+        """Selects all columns in the lakehouse from tables that match ALL the given classes
+
+        Args:
+            from_tables (str, optional): The tables to be selected in format
+                "catalog.schema.table", use "*" as a wildcard. Defaults to "*.*.*".
+            by_classes (Union[List[str], str], optional): The classes to be used to
+                search for columns. Defaults to None.
+            min_score (float, optional): Defines the classification score or frequency
+                threshold for columns to be considered during the scan. Defaults to None
+                which means that all columns where at least one record matched the
+                respective rule during the scan will be included. Has to be either None
+                or between 0 and 1.
+
+        Raises:
+            ValueError: If the by_classes type is not valid
+
+        Returns:
+            DataFrame: A dataframe containing the UNION ALL results of the select"""
+
+        Msql.validate_from_components(from_tables)
+
+        if isinstance(by_classes, str):
+            by_classes = [by_classes]
+        elif isinstance(by_classes, list) and all(isinstance(elem, str) for elem in by_classes):
+            by_classes = by_classes
+        else:
+            raise ValueError(
+                f"The provided by_classes {by_classes} have the wrong type. Please provide"
+                f" either a str or List[str]."
+            )
+
+        from_statement = (
+            "named_struct("
+            + ", ".join(
+                [
+                    f"'{class_name}', named_struct('column_name', '[{class_name}]', 'value', `[{class_name}]`)"
+                    for class_name in by_classes
+                ]
+            )
+            + ") AS classified_columns"
+        )
+
+        return self._msql(
+            f"SELECT {from_statement}, to_json(struct(*)) AS row_content FROM {from_tables}", min_score=min_score
+        )
+
+    def delete_by_class(
+        self,
+        from_tables="*.*.*",
+        by_class: str = None,
+        values: Optional[Union[List[str], str]] = None,
+        yes_i_am_sure: bool = False,
+        min_score: Optional[float] = None,
+    ):
+        """Deletes all rows in the lakehouse that match any of the provided values in a column classified with the given class
+
+        Args:
+            from_tables (str, optional): The tables to delete from in format
+                "catalog.schema.table", use "*" as a wildcard. Defaults to "*.*.*".
+            by_class (str, optional): The class to be used to search for columns.
+                Defaults to None.
+            values (Union[List[str], str], optional): The values to be deleted.
+                Defaults to None.
+            yes_i_am_sure (bool, optional): Whether you are sure that you want to delete
+                the data. If False prints the SQL statements instead of executing them. Defaults to False.
+            min_score (float, optional): Defines the classification score or frequency
+                threshold for columns to be considered during the scan. Defaults to None
+                which means that all columns where at least one record matched the
+                respective rule during the scan will be included. Has to be either None
+                or between 0 and 1.
+
+        Raises:
+            ValueError: If the from_tables is not valid
+            ValueError: If the by_class is not valid
+            ValueError: If the values is not valid
+        """
+
+        Msql.validate_from_components(from_tables)
+
+        if (by_class is None) or (not isinstance(by_class, str)):
+            raise ValueError(f"Please provide a class to identify the columns to be matched on the provided values.")
+
+        if values is None:
+            raise ValueError(
+                f"Please specify the values to be deleted. You can either provide a list of values or a single value."
+            )
+        elif isinstance(values, str):
+            value_string = f"'{values}'"
+        elif isinstance(values, list) and all(isinstance(elem, str) for elem in values):
+            value_string = "'" + "', '".join(values) + "'"
+        else:
+            raise ValueError(
+                f"The provided values {values} have the wrong type. Please provide" f" either a str or List[str]."
+            )
+
+        if not yes_i_am_sure:
+            logger.friendly(
+                f"Please confirm that you want to delete the following values from the table {from_tables} using the class {by_class}: {values}"
+            )
+            logger.friendly(
+                f"If you are sure, please run the same command again but set the parameter yes_i_am_sure to True."
+            )
+
+        delete_result = self._msql(
+            f"DELETE FROM {from_tables} WHERE `[{by_class}]` IN ({value_string})",
+            what_if=(not yes_i_am_sure),
+            min_score=min_score,
+        )
+
+        if delete_result is not None:
+            delete_result = delete_result.toPandas()
+            logger.friendlyHTML(f"The affected tables are {delete_result.to_html()}")
+
+    def save_scan(self):
+        """Method to save scan result"""
+        # TODO:
+        pass
diff --git a/discoverx/explorer.py b/discoverx/explorer.py
index 67417df..98bedcf 100644
--- a/discoverx/explorer.py
+++ b/discoverx/explorer.py
@@ -1,9 +1,12 @@
 import concurrent.futures
 import copy
 import re
-import pandas as pd
+from typing import Optional, List
+
 from discoverx import logging
 from discoverx.common import helper
+from discoverx.discovery import Discovery
+from discoverx.rules import Rule
 from discoverx.scanner import ColumnInfo, TableInfo
 from functools import reduce
 from pyspark.sql import DataFrame, SparkSession
@@ -178,6 +181,31 @@ def unpivot_string_columns(self, sample_size=None) -> "DataExplorerActions":
 
         return self.apply_sql(sql_query_template)
 
+    def scan(
+        self,
+        rules="*",
+        sample_size=10000,
+        what_if: bool = False,
+        custom_rules: Optional[List[Rule]] = None,
+        locale: str = None,
+    ):
+        discover = Discovery(
+            self._spark,
+            self._catalogs,
+            self._schemas,
+            self._tables,
+            self._info_fetcher.get_tables_info(self._catalogs, self._schemas, self._tables, self._having_columns),
+            custom_rules=custom_rules,
+            locale=locale,
+        )
+        discover.scan(rules=rules, sample_size=sample_size, what_if=what_if)
+        return discover
+
+    def from_scan(self):
+        """Method to load from saved scan result and return discover object"""
+        # TODO
+        pass
+
 
 class DataExplorerActions:
     def __init__(
diff --git a/discoverx/scanner.py b/discoverx/scanner.py
index 06e8a3a..fb65a98 100644
--- a/discoverx/scanner.py
+++ b/discoverx/scanner.py
@@ -158,6 +158,7 @@ def __init__(
         catalogs: str = "*",
         schemas: str = "*",
         tables: str = "*",
+        table_list: Optional[List[TableInfo]] = None,
         rule_filter: str = "*",
         sample_size: int = 1000,
         what_if: bool = False,
@@ -169,6 +170,7 @@
         self.catalogs = catalogs
         self.schemas = schemas
         self.tables = tables
+        self.table_list = table_list
         self.rules_filter = rule_filter
         self.sample_size = sample_size
         self.what_if = what_if
@@ -197,38 +199,11 @@ def _get_list_of_tables(self) -> List[TableInfo]:
         ]
         return filtered_tables
 
-    def _get_table_list_sql(self):
-        """
-        Returns a SQL expression which returns a list of columns matching
-        the specified filters
-
-        Returns:
-            string: The SQL expression
-        """
-
-        catalog_sql = f"""AND regexp_like(table_catalog, "^{self.catalogs.replace("*", ".*")}$")"""
-        schema_sql = f"""AND regexp_like(table_schema, "^{self.schemas.replace("*", ".*")}$")"""
-        table_sql = f"""AND regexp_like(table_name, "^{self.tables.replace("*", ".*")}$")"""
-
-        sql = f"""
-        SELECT
-            table_catalog,
-            table_schema,
-            table_name,
-            collect_list(struct(column_name, data_type, partition_index)) as table_columns
-        FROM {self.columns_table_name}
-        WHERE
-            table_schema != "information_schema"
-            {catalog_sql if self.catalogs != "*" else ""}
-            {schema_sql if self.schemas != "*" else ""}
-            {table_sql if self.tables != "*" else ""}
-        GROUP BY table_catalog, table_schema, table_name
-        """
-
-        return strip_margin(sql)
-
     def _resolve_scan_content(self) -> ScanContent:
-        table_list = self._get_list_of_tables()
+        if self.table_list:
+            table_list = self.table_list
+        else:
+            table_list = self._get_list_of_tables()
         catalogs = set(map(lambda x: x.catalog, table_list))
         schemas = set(map(lambda x: f"{x.catalog}.{x.schema}", table_list))
diff --git a/tests/unit/discovery_test.py b/tests/unit/discovery_test.py
new file mode 100644
index 0000000..1e7f349
--- /dev/null
+++ b/tests/unit/discovery_test.py
@@ -0,0 +1,102 @@
+import pytest
+from discoverx.explorer import DataExplorer, InfoFetcher
+
+
+@pytest.fixture()
+def info_fetcher(spark):
+    return InfoFetcher(spark=spark, columns_table_name="default.columns_mock")
+
+
+@pytest.fixture(name="discover_ip")
+def scan_ip_in_tb1(spark, info_fetcher):
+    data_explorer = DataExplorer("*.*.tb_1", spark, info_fetcher)
+    discover = data_explorer.scan(rules="ip_*")
+    return discover
+
+
+def test_discover_scan_msql(discover_ip):
+    result = discover_ip._msql("SELECT [ip_v4] as ip FROM *.*.*").collect()
+    assert {row.ip for row in result} == {"1.2.3.4", "3.4.5.60"}
+
+    # test what-if
+    try:
+        _ = discover_ip._msql("SELECT [ip_v4] as ip FROM *.*.*", what_if=True)
+    except Exception as e:
+        pytest.fail(f"Test failed with exception {e}")
+
+
+def test_discover_search(discover_ip):
+    # search a specific term and auto-detect matching classes/rules
+    result = discover_ip.search("1.2.3.4").collect()
+    assert result[0].table_name == "tb_1"
+    assert result[0].search_result.ip_v4.column_name == "ip"
+
+    # specify catalog, schema and table
+    result_classes_namespace = discover_ip.search("1.2.3.4", by_class="ip_v4", from_tables="*.default.tb_*")
+    assert {row.search_result.ip_v4.value for row in result_classes_namespace.collect()} == {"1.2.3.4"}
+
+    with pytest.raises(ValueError) as no_search_term_error:
+        discover_ip.search(None)
+    assert no_search_term_error.value.args[0] == "search_term has not been provided."
+
+    with pytest.raises(ValueError) as no_inferred_class_error:
+        discover_ip.search("###")
+    assert (
+        no_inferred_class_error.value.args[0]
+        == "Could not infer any class for the given search term. Please specify the by_class parameter."
+    )
+
+    with pytest.raises(ValueError) as single_bool:
+        discover_ip.search("", by_class=True)
+    assert single_bool.value.args[0] == "The provided by_class True must be of string type."
+
+
+def test_discover_select_by_class(discover_ip):
+    # select columns classified with the given class, passed either as a string or as a list
+    result = discover_ip.select_by_classes(from_tables="*.default.tb_*", by_classes="ip_v4").collect()
+    assert result[0].table_name == "tb_1"
+    assert result[0].classified_columns.ip_v4.column_name == "ip"
+
+    result = discover_ip.select_by_classes(from_tables="*.default.tb_*", by_classes=["ip_v4"]).collect()
+    assert result[0].table_name == "tb_1"
+    assert result[0].classified_columns.ip_v4.column_name == "ip"
+
+    with pytest.raises(ValueError):
+        discover_ip.select_by_classes(from_tables="*.default.tb_*")
+
+    with pytest.raises(ValueError):
+        discover_ip.select_by_classes(from_tables="*.default.tb_*", by_classes=[1, 3, "ip"])
+
+    with pytest.raises(ValueError):
+        discover_ip.select_by_classes(from_tables="*.default.tb_*", by_classes=True)
+
+    with pytest.raises(ValueError):
+        discover_ip.select_by_classes(from_tables="invalid from", by_classes="email")
+
+
+def test_discover_delete_by_class(spark, discover_ip):
+    # without yes_i_am_sure=True the delete only prints the SQL, so the table is unchanged
+    discover_ip.delete_by_class(from_tables="*.default.tb_*", by_class="ip_v4", values="9.9.9.9")
+    assert {row.ip for row in spark.sql("select * from tb_1").collect()} == {"1.2.3.4", "3.4.5.60"}
+
+    discover_ip.delete_by_class(from_tables="*.default.tb_*", by_class="ip_v4", values="1.2.3.4", yes_i_am_sure=True)
+    assert {row.ip for row in spark.sql("select * from tb_1").collect()} == {"3.4.5.60"}
+
+    with pytest.raises(ValueError):
+        discover_ip.delete_by_class(from_tables="*.default.tb_*", by_class="x")
+
+    with pytest.raises(ValueError):
+        discover_ip.delete_by_class(from_tables="*.default.tb_*", values="x")
+
+    with pytest.raises(ValueError):
+        discover_ip.delete_by_class(from_tables="*.default.tb_*", by_class=["ip"], values="x")
+
+    with pytest.raises(ValueError):
+        discover_ip.delete_by_class(from_tables="*.default.tb_*", by_class=True, values="x")
+
+    with pytest.raises(ValueError):
+        discover_ip.delete_by_class(from_tables="invalid from", by_class="email", values="x")
+
+
+def test_discover_scan_result(discover_ip):
+    assert not discover_ip.scan_result.empty
diff --git a/tests/unit/explorer_test.py b/tests/unit/explorer_test.py
index 23f54f5..9c41741 100644
--- a/tests/unit/explorer_test.py
+++ b/tests/unit/explorer_test.py
@@ -1,6 +1,4 @@
 import pytest
-from unittest.mock import Mock, patch
-from pyspark.sql import SparkSession
 from discoverx.explorer import DataExplorer, DataExplorerActions, InfoFetcher, TableInfo
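Example usage (an illustrative sketch): assuming an active Spark session named `spark` and tables matching "*.default.tb_*" that contain IP-like columns, as in the unit tests above, the Discovery flow introduced by this change can be driven roughly as follows:

    from discoverx.explorer import DataExplorer, InfoFetcher

    info_fetcher = InfoFetcher(spark=spark, columns_table_name="system.information_schema.columns")
    explorer = DataExplorer("*.default.tb_*", spark, info_fetcher)

    # DataExplorer.scan() builds a Discovery object, runs the scan, and returns it
    discovery = explorer.scan(rules="ip_*", sample_size=10000)

    # Scan results as a pandas DataFrame
    print(discovery.scan_result)

    # Search for a value; the class is auto-detected from the rules matching the term
    matches = discovery.search("1.2.3.4", from_tables="*.default.tb_*")

    # Select rows from tables that have a column classified as ip_v4
    classified = discovery.select_by_classes(from_tables="*.default.tb_*", by_classes="ip_v4")

    # Dry run by default; pass yes_i_am_sure=True to actually execute the DELETE
    discovery.delete_by_class(from_tables="*.default.tb_*", by_class="ip_v4", values="9.9.9.9")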