diff --git a/piicatcher/__init__.py b/piicatcher/__init__.py
index fae4246..b5c6f9b 100644
--- a/piicatcher/__init__.py
+++ b/piicatcher/__init__.py
@@ -1,2 +1,2 @@
 # flake8: noqa
-__version__ = "0.17.5"
+__version__ = "0.18.0"
diff --git a/piicatcher/api.py b/piicatcher/api.py
index c08664e..22c18af 100644
--- a/piicatcher/api.py
+++ b/piicatcher/api.py
@@ -65,7 +65,10 @@ def scan_database(
     if incremental:
         last_task = catalog.get_latest_task("piicatcher.{}".format(source.name))
         last_run = last_task.updated_at if last_task is not None else None
-        LOGGER.debug("Last Run at {}", last_run)
+        if last_run is not None:
+            LOGGER.debug("Last Run at {}", last_run)
+        else:
+            LOGGER.debug("No last run found")
 
     try:
         scanner = DbScanner(
@@ -97,6 +100,15 @@
     else:
         deep_scan(
             catalog=catalog,
+            work_generator=column_generator(
+                catalog=catalog,
+                source=source,
+                last_run=last_run,
+                exclude_schema_regex_str=exclude_schema_regex,
+                include_schema_regex_str=include_schema_regex,
+                exclude_table_regex_str=exclude_table_regex,
+                include_table_regex_str=include_table_regex,
+            ),
             generator=data_generator(
                 catalog=catalog,
                 source=source,
diff --git a/piicatcher/dbinfo.py b/piicatcher/dbinfo.py
index 92c7f55..ce2dcaf 100644
--- a/piicatcher/dbinfo.py
+++ b/piicatcher/dbinfo.py
@@ -28,7 +28,7 @@ def get_select_query(
 
     @classmethod
     @abstractmethod
-    def get_sample_query(cls, schema_name, table_name, column_list) -> str:
+    def get_sample_query(cls, schema_name, table_name, column_list, num_rows) -> str:
         pass
 
 
@@ -46,57 +46,64 @@ def get_select_query(
         )
 
     @classmethod
-    def get_sample_query(cls, schema_name, table_name, column_list) -> str:
+    def get_sample_query(cls, schema_name, table_name, column_list, num_rows) -> str:
         raise NotImplementedError
 
 
 class MySQL(DbInfo):
     _sample_query_template = (
-        "select {column_list} from {schema_name}.{table_name} limit 10"
+        "select {column_list} from {schema_name}.{table_name} limit {num_rows}"
     )
     _column_escape = "`"
 
     @classmethod
-    def get_sample_query(cls, schema_name, table_name, column_list) -> str:
+    def get_sample_query(cls, schema_name, table_name, column_list, num_rows) -> str:
         return cls._sample_query_template.format(
             column_list="`{0}`".format("`,`".join(col for col in column_list)),
             schema_name=schema_name,
             table_name=table_name,
+            num_rows=num_rows,
         )
 
 
 class Postgres(DbInfo):
-    _sample_query_template = "select {column_list} from {schema_name}.{table_name} TABLESAMPLE BERNOULLI (10)"
+    _sample_query_template = "SELECT {column_list} FROM {schema_name}.{table_name} TABLESAMPLE BERNOULLI (10) LIMIT {num_rows}"
 
     @classmethod
     def get_sample_query(
-        cls, schema_name: str, table_name: str, column_list: List[str]
+        cls, schema_name: str, table_name: str, column_list: List[str], num_rows
     ) -> str:
         return cls._sample_query_template.format(
             column_list='"{0}"'.format('","'.join(col for col in column_list)),
             schema_name=schema_name,
             table_name=table_name,
+            num_rows=num_rows,
        )
 
 
 class Redshift(Postgres):
-    _sample_query_template = "SELECT {column_list} FROM {schema_name}.{table_name} ORDER BY random() LIMIT 10"
+    _sample_query_template = "SELECT {column_list} FROM {schema_name}.{table_name} TABLESAMPLE BERNOULLI (10) LIMIT {num_rows}"
 
 
 class Snowflake(DbInfo):
-    _sample_query_template = "select {column_list} from {schema_name}.{table_name} TABLESAMPLE BERNOULLI (10 ROWS)"
+    _sample_query_template = "SELECT {column_list} FROM {schema_name}.{table_name} TABLESAMPLE BERNOULLI ({num_rows} ROWS)"
 
     @classmethod
     def get_sample_query(
-        cls, schema_name: str, table_name: str, column_list: List[str]
+        cls, schema_name: str, table_name: str, column_list: List[str], num_rows
     ) -> str:
         return cls._sample_query_template.format(
             column_list=",".join(column_list),
             schema_name=schema_name,
             table_name=table_name,
+            num_rows=num_rows,
         )
 
 
+class Athena(Postgres):
+    pass
+
+
 def get_dbinfo(source_type: str) -> Type[DbInfo]:
     if source_type == "sqlite":
         return Sqlite
@@ -108,4 +115,6 @@ def get_dbinfo(source_type: str) -> Type[DbInfo]:
         return Redshift
     elif source_type == "snowflake":
         return Snowflake
+    elif source_type == "athena":
+        return Athena
     raise AttributeError
diff --git a/piicatcher/generators.py b/piicatcher/generators.py
index 721a685..fa80df6 100644
--- a/piicatcher/generators.py
+++ b/piicatcher/generators.py
@@ -70,7 +70,9 @@ def _get_query(
 
     if count > sample_boundary:
         try:
-            query = dbinfo.get_sample_query(schema.name, table.name, column_name_list)
+            query = dbinfo.get_sample_query(
+                schema.name, table.name, column_name_list, sample_boundary
+            )
             LOGGER.debug("Choosing a SAMPLE query as table size is big")
         except NotImplementedError:
             LOGGER.warning(
diff --git a/piicatcher/scanner.py b/piicatcher/scanner.py
index 012b525..fc73356 100644
--- a/piicatcher/scanner.py
+++ b/piicatcher/scanner.py
@@ -8,6 +8,9 @@
 from commonregex import CommonRegex
 from dbcat.catalog import Catalog
 from dbcat.catalog.models import CatColumn, CatSchema, CatTable, PiiTypes
 
+from tqdm import tqdm
+
+from piicatcher.generators import SMALL_TABLE_MAX, _filter_text_columns
 
 LOGGER = logging.getLogger(__name__)
@@ -146,37 +149,41 @@ def shallow_scan(
 
 
 def deep_scan(
     catalog: Catalog,
+    work_generator: Generator[Tuple[CatSchema, CatTable, CatColumn], None, None],
     generator: Generator[Tuple[CatSchema, CatTable, CatColumn, str], None, None],
 ):
     scanners = [RegexScanner(), NERScanner()]
 
+    total_columns = _filter_text_columns([c for s, t, c in work_generator])
+    total_work = len(total_columns) * SMALL_TABLE_MAX
+
     counter = 0
     set_number = 0
-    for schema, table, column, val in generator:
+    for schema, table, column, val in tqdm(
+        generator, total=total_work, desc="datum", unit="datum"
+    ):
         counter += 1
         LOGGER.debug("Scanning column name %s", column.fqdn)
         if val is not None:
-            pii_types = set()
             for scanner in scanners:
                 for pii in scanner.scan(val):
-                    pii_types.add(pii)
-
-            if len(pii_types) > 0:
-                set_number += 1
-
-                top = pii_types.pop()
-                catalog.set_column_pii_type(
-                    column=column, pii_type=top, pii_plugin=scanner.name()
-                )
-                LOGGER.debug("{} has {}".format(column.fqdn, top))
-
-                scan_logger.info(
-                    "deep_scan", extra={"column": column.fqdn, "pii_types": top}
-                )
-                data_logger.info(
-                    "deep_scan",
-                    extra={"column": column.fqdn, "data": val, "pii_types": top},
-                )
-
+                    set_number += 1
+
+                    catalog.set_column_pii_type(
+                        column=column, pii_type=pii, pii_plugin=scanner.name()
+                    )
+                    LOGGER.debug("{} has {}".format(column.fqdn, pii))
+
+                    scan_logger.info(
+                        "deep_scan", extra={"column": column.fqdn, "pii_types": pii}
+                    )
+                    data_logger.info(
+                        "deep_scan",
+                        extra={"column": column.fqdn, "data": val, "pii_types": pii},
+                    )
+                    break
+                else:
+                    continue
+                break
 
     LOGGER.info("Columns Scanned: %d, Columns Labeled: %d", counter, set_number)
diff --git a/poetry.lock b/poetry.lock
index c27d121..a23efb5 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -2936,7 +2936,7 @@ datahub = ["acryl-datahub", "great-expectations"]
 [metadata]
 lock-version = "1.1"
 python-versions = ">=3.6,<3.9"
-content-hash = "0791e569f5b93da980a033f4f1152b5d2943dfd286c8c7f8532d546b2afce1d1"
+content-hash = "386eacc4c2710fccabcd4900e20367a790a829b2e338ef8ba90ed21fc2387c79"
 
 [metadata.files]
 acryl-datahub = [
diff --git a/pyproject.toml b/pyproject.toml
index e759476..428d4b9 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "piicatcher"
-version = "0.17.5"
+version = "0.18.0"
 description = "Find PII data in databases"
 authors = ["Tokern "]
 license = "Apache 2.0"
@@ -35,6 +35,7 @@ tabulate = "^0.8.9"
 dataclasses = {version = ">=0.6", markers="python_version >= '3.6' and python_version < '3.7'"}
 great-expectations = {version = "^0.13.42", optional = true}
 acryl-datahub = {version = "^0.8.16", optional = true}
+tqdm = "^4.62.3"
 
 [tool.poetry.extras]
 datahub = ["acryl-datahub", "great-expectations"]
diff --git a/tests/test_generators.py b/tests/test_generators.py
index 214358a..ef779eb 100644
--- a/tests/test_generators.py
+++ b/tests/test_generators.py
@@ -1,7 +1,7 @@
 from typing import Any, Generator, Tuple
 
 import pytest
-from dbcat.catalog import Catalog, CatSource
+from dbcat.catalog import Catalog, CatColumn, CatSchema, CatSource, CatTable
 from sqlalchemy import create_engine
 
 from piicatcher.dbinfo import get_dbinfo
@@ -150,16 +150,49 @@ def test_get_sample_query(sqlalchemy_engine):
     )
 
     if source.source_type == "mysql":
-        assert query == """select `name`,`state` from piidb.full_pii limit 10"""
+        assert query == """select `name`,`state` from piidb.full_pii limit 1"""
     elif source.source_type == "postgresql":
         assert (
             query
-            == """select "name","state" from public.full_pii TABLESAMPLE BERNOULLI (10)"""
+            == """SELECT "name","state" FROM public.full_pii TABLESAMPLE BERNOULLI (10) LIMIT 1"""
         )
     elif source.source_type == "sqlite":
         assert query == """select "name","state" from full_pii"""
 
 
+@pytest.mark.parametrize(
+    ("source_type", "expected_query"),
+    [
+        (
+            "redshift",
+            'SELECT "column" FROM public.table TABLESAMPLE BERNOULLI (10) LIMIT 1',
+        ),
+        ("snowflake", "SELECT column FROM public.table TABLESAMPLE BERNOULLI (1 ROWS)"),
+        (
+            "athena",
+            'SELECT "column" FROM public.table TABLESAMPLE BERNOULLI (10) LIMIT 1',
+        ),
+    ],
+)
+def test_get_sample_query_redshift(mocker, source_type, expected_query):
+    source = CatSource(name="src", source_type=source_type)
+    schema = CatSchema(source=source, name="public")
+    table = CatTable(schema=schema, name="table")
+    column = CatColumn(table=table, name="column")
+
+    mocker.patch("piicatcher.generators._get_table_count", return_value=100)
+    query = _get_query(
+        schema=schema,
+        table=table,
+        column_list=[column],
+        dbinfo=get_dbinfo(source_type=source.source_type),
+        connection=None,
+        sample_boundary=1,
+    )
+
+    assert query == expected_query
+
+
 def test_row_generator(sqlalchemy_engine):
     catalog, source, conn = sqlalchemy_engine
     schemata = catalog.search_schema(source_like=source.name, schema_like="%")
diff --git a/tests/test_scanner.py b/tests/test_scanner.py
index 6422f3a..30232f3 100644
--- a/tests/test_scanner.py
+++ b/tests/test_scanner.py
@@ -160,7 +160,9 @@ def test_deep_scan(load_data_and_pull):
     with catalog.managed_session:
         source = catalog.get_source_by_id(source_id)
         deep_scan(
-            catalog=catalog, generator=data_generator(catalog=catalog, source=source)
+            catalog=catalog,
+            work_generator=column_generator(catalog=catalog, source=source),
+            generator=data_generator(catalog=catalog, source=source),
         )
 
     schemata = catalog.search_schema(source_like=source.name, schema_like="%")