feature: Add a progress bar. Support deep scans in Athena
Add a tqdm progress bar, shown only during deep scans.
Add deep-scan support for AWS Athena.
Previously, Postgres and Redshift sampled a percentage of rows rather
than a bounded number of rows. Add a LIMIT clause for these databases.

Fix #159
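
For reference, a minimal sketch of the new deep_scan entry point, mirroring the updated tests/test_scanner.py below; catalog and source_id are assumed to come from an existing dbcat setup.

from piicatcher.generators import column_generator, data_generator
from piicatcher.scanner import deep_scan

# Assumption: `catalog` and `source_id` come from an existing dbcat setup.
with catalog.managed_session:
    source = catalog.get_source_by_id(source_id)
    deep_scan(
        catalog=catalog,
        # consumed once up front so tqdm can be given a total ...
        work_generator=column_generator(catalog=catalog, source=source),
        # ... while this generator streams the sampled values
        generator=data_generator(catalog=catalog, source=source),
    )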
vrajat committed Nov 29, 2021
1 parent b3a8fa3 commit 9154a6c
Showing 9 changed files with 105 additions and 39 deletions.
2 changes: 1 addition & 1 deletion piicatcher/__init__.py
@@ -1,2 +1,2 @@
# flake8: noqa
__version__ = "0.17.5"
__version__ = "0.18.0"
14 changes: 13 additions & 1 deletion piicatcher/api.py
@@ -65,7 +65,10 @@ def scan_database(
if incremental:
last_task = catalog.get_latest_task("piicatcher.{}".format(source.name))
last_run = last_task.updated_at if last_task is not None else None
LOGGER.debug("Last Run at {}", last_run)
if last_run is not None:
LOGGER.debug("Last Run at {}", last_run)
else:
LOGGER.debug("No last run found")

try:
scanner = DbScanner(
@@ -97,6 +100,15 @@ def scan_database(
else:
deep_scan(
catalog=catalog,
work_generator=column_generator(
catalog=catalog,
source=source,
last_run=last_run,
exclude_schema_regex_str=exclude_schema_regex,
include_schema_regex_str=include_schema_regex,
exclude_table_regex_str=exclude_table_regex,
include_table_regex_str=include_table_regex,
),
generator=data_generator(
catalog=catalog,
source=source,
27 changes: 18 additions & 9 deletions piicatcher/dbinfo.py
@@ -28,7 +28,7 @@ def get_select_query(

@classmethod
@abstractmethod
def get_sample_query(cls, schema_name, table_name, column_list) -> str:
def get_sample_query(cls, schema_name, table_name, column_list, num_rows) -> str:
pass


@@ -46,57 +46,64 @@ def get_select_query(
)

@classmethod
def get_sample_query(cls, schema_name, table_name, column_list) -> str:
def get_sample_query(cls, schema_name, table_name, column_list, num_rows) -> str:
raise NotImplementedError


class MySQL(DbInfo):
_sample_query_template = (
"select {column_list} from {schema_name}.{table_name} limit 10"
"select {column_list} from {schema_name}.{table_name} limit {num_rows}"
)
_column_escape = "`"

@classmethod
def get_sample_query(cls, schema_name, table_name, column_list) -> str:
def get_sample_query(cls, schema_name, table_name, column_list, num_rows) -> str:
return cls._sample_query_template.format(
column_list="`{0}`".format("`,`".join(col for col in column_list)),
schema_name=schema_name,
table_name=table_name,
num_rows=num_rows,
)


class Postgres(DbInfo):
_sample_query_template = "select {column_list} from {schema_name}.{table_name} TABLESAMPLE BERNOULLI (10)"
_sample_query_template = "SELECT {column_list} FROM {schema_name}.{table_name} TABLESAMPLE BERNOULLI (10) LIMIT {num_rows}"

@classmethod
def get_sample_query(
cls, schema_name: str, table_name: str, column_list: List[str]
cls, schema_name: str, table_name: str, column_list: List[str], num_rows
) -> str:
return cls._sample_query_template.format(
column_list='"{0}"'.format('","'.join(col for col in column_list)),
schema_name=schema_name,
table_name=table_name,
num_rows=num_rows,
)


class Redshift(Postgres):
_sample_query_template = "SELECT {column_list} FROM {schema_name}.{table_name} ORDER BY random() LIMIT 10"
_sample_query_template = "SELECT {column_list} FROM {schema_name}.{table_name} TABLESAMPLE BERNOULLI (10) LIMIT {num_rows}"


class Snowflake(DbInfo):
_sample_query_template = "select {column_list} from {schema_name}.{table_name} TABLESAMPLE BERNOULLI (10 ROWS)"
_sample_query_template = "SELECT {column_list} FROM {schema_name}.{table_name} TABLESAMPLE BERNOULLI ({num_rows} ROWS)"

@classmethod
def get_sample_query(
cls, schema_name: str, table_name: str, column_list: List[str]
cls, schema_name: str, table_name: str, column_list: List[str], num_rows
) -> str:
return cls._sample_query_template.format(
column_list=",".join(column_list),
schema_name=schema_name,
table_name=table_name,
num_rows=num_rows,
)


class Athena(Postgres):
pass


def get_dbinfo(source_type: str) -> Type[DbInfo]:
if source_type == "sqlite":
return Sqlite
@@ -108,4 +115,6 @@ def get_dbinfo(source_type: str) -> Type[DbInfo]:
return Redshift
elif source_type == "snowflake":
return Snowflake
elif source_type == "athena":
return Athena
raise AttributeError
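
A short sketch of the queries these templates now render, matching the expectations in the new tests/test_generators.py cases below:

from piicatcher.dbinfo import get_dbinfo

# Print the bounded sample query each dialect now generates.
for source_type in ("redshift", "snowflake", "athena"):
    dbinfo = get_dbinfo(source_type=source_type)
    print(
        dbinfo.get_sample_query(
            schema_name="public",
            table_name="full_pii",
            column_list=["name", "state"],
            num_rows=1,
        )
    )
# redshift:  SELECT "name","state" FROM public.full_pii TABLESAMPLE BERNOULLI (10) LIMIT 1
# snowflake: SELECT name,state FROM public.full_pii TABLESAMPLE BERNOULLI (1 ROWS)
# athena:    SELECT "name","state" FROM public.full_pii TABLESAMPLE BERNOULLI (10) LIMIT 1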
4 changes: 3 additions & 1 deletion piicatcher/generators.py
@@ -70,7 +70,9 @@ def _get_query(

if count > sample_boundary:
try:
query = dbinfo.get_sample_query(schema.name, table.name, column_name_list)
query = dbinfo.get_sample_query(
schema.name, table.name, column_name_list, sample_boundary
)
LOGGER.debug("Choosing a SAMPLE query as table size is big")
except NotImplementedError:
LOGGER.warning(
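
For context, the surrounding decision in _get_query (partially elided above) works roughly as follows — a sketch, assuming the fallback after the warning issues a plain select, as the sqlite expectation in tests/test_generators.py suggests:

# Sketch of the sampling decision; `count` is the table's row count.
if count > sample_boundary:
    try:
        # Large table: ask the dialect for a bounded sample.
        query = dbinfo.get_sample_query(
            schema.name, table.name, column_name_list, sample_boundary
        )
    except NotImplementedError:
        # Dialects without sampling (e.g. sqlite) fall back to a full select.
        query = dbinfo.get_select_query(schema.name, table.name, column_name_list)
else:
    # Small table: scan everything.
    query = dbinfo.get_select_query(schema.name, table.name, column_name_list)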
49 changes: 28 additions & 21 deletions piicatcher/scanner.py
@@ -8,6 +8,9 @@
from commonregex import CommonRegex
from dbcat.catalog import Catalog
from dbcat.catalog.models import CatColumn, CatSchema, CatTable, PiiTypes
from tqdm import tqdm

from piicatcher.generators import SMALL_TABLE_MAX, _filter_text_columns

LOGGER = logging.getLogger(__name__)

@@ -146,37 +149,41 @@ def shallow_scan(

def deep_scan(
catalog: Catalog,
work_generator: Generator[Tuple[CatSchema, CatTable, CatColumn], None, None],
generator: Generator[Tuple[CatSchema, CatTable, CatColumn, str], None, None],
):
scanners = [RegexScanner(), NERScanner()]

total_columns = _filter_text_columns([c for s, t, c in work_generator])
total_work = len(total_columns) * SMALL_TABLE_MAX

counter = 0
set_number = 0

for schema, table, column, val in generator:
for schema, table, column, val in tqdm(
generator, total=total_work, desc="datum", unit="datum"
):
counter += 1
LOGGER.debug("Scanning column name %s", column.fqdn)
if val is not None:
pii_types = set()
for scanner in scanners:
for pii in scanner.scan(val):
pii_types.add(pii)

if len(pii_types) > 0:
set_number += 1

top = pii_types.pop()
catalog.set_column_pii_type(
column=column, pii_type=top, pii_plugin=scanner.name()
)
LOGGER.debug("{} has {}".format(column.fqdn, top))

scan_logger.info(
"deep_scan", extra={"column": column.fqdn, "pii_types": top}
)
data_logger.info(
"deep_scan",
extra={"column": column.fqdn, "data": val, "pii_types": top},
)

set_number += 1

catalog.set_column_pii_type(
column=column, pii_type=pii, pii_plugin=scanner.name()
)
LOGGER.debug("{} has {}".format(column.fqdn, pii))

scan_logger.info(
"deep_scan", extra={"column": column.fqdn, "pii_types": pii}
)
data_logger.info(
"deep_scan",
extra={"column": column.fqdn, "data": val, "pii_types": pii},
)
break
else:
continue
break
LOGGER.info("Columns Scanned: %d, Columns Labeled: %d", counter, set_number)
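
The nested exit above leans on Python's for/else idiom: the else clause runs only when the inner loop completes without break, so a single inner break cascades out of both loops. A self-contained sketch with fake scanner hits:

scanners = [["phone"], [], ["email", "ssn"]]  # fake per-scanner PII hits

for hits in scanners:
    for pii in hits:
        print("first hit:", pii)  # stand-in for the catalog/log calls above
        break  # first PII type found: stop scanning this value
    else:
        continue  # this scanner found nothing: try the next one
    break  # inner break happened: leave the outer loop too
# prints "first hit: phone" once, then stops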
2 changes: 1 addition & 1 deletion poetry.lock


3 changes: 2 additions & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "piicatcher"
version = "0.17.5"
version = "0.18.0"
description = "Find PII data in databases"
authors = ["Tokern <[email protected]>"]
license = "Apache 2.0"
@@ -35,6 +35,7 @@ tabulate = "^0.8.9"
dataclasses = {version = ">=0.6", markers="python_version >= '3.6' and python_version < '3.7'"}
great-expectations = {version = "^0.13.42", optional = true}
acryl-datahub = {version = "^0.8.16", optional = true}
tqdm = "^4.62.3"

[tool.poetry.extras]
datahub = ["acryl-datahub", "great-expectations"]
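
tqdm becomes a regular dependency. A minimal sketch of the pattern deep_scan uses it for: a generator has no len(), so tqdm needs an explicit total to render a percentage and ETA rather than a bare counter.

from tqdm import tqdm

def values(n):
    # A generator carries no length information.
    for i in range(n):
        yield i

# Without total=..., tqdm would only show a running count.
for _ in tqdm(values(1000), total=1000, desc="datum", unit="datum"):
    pass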
39 changes: 36 additions & 3 deletions tests/test_generators.py
@@ -1,7 +1,7 @@
from typing import Any, Generator, Tuple

import pytest
from dbcat.catalog import Catalog, CatSource
from dbcat.catalog import Catalog, CatColumn, CatSchema, CatSource, CatTable
from sqlalchemy import create_engine

from piicatcher.dbinfo import get_dbinfo
@@ -150,16 +150,49 @@ def test_get_sample_query(sqlalchemy_engine):
)

if source.source_type == "mysql":
assert query == """select `name`,`state` from piidb.full_pii limit 10"""
assert query == """select `name`,`state` from piidb.full_pii limit 1"""
elif source.source_type == "postgresql":
assert (
query
== """select "name","state" from public.full_pii TABLESAMPLE BERNOULLI (10)"""
== """SELECT "name","state" FROM public.full_pii TABLESAMPLE BERNOULLI (10) LIMIT 1"""
)
elif source.source_type == "sqlite":
assert query == """select "name","state" from full_pii"""


@pytest.mark.parametrize(
("source_type", "expected_query"),
[
(
"redshift",
'SELECT "column" FROM public.table TABLESAMPLE BERNOULLI (10) LIMIT 1',
),
("snowflake", "SELECT column FROM public.table TABLESAMPLE BERNOULLI (1 ROWS)"),
(
"athena",
'SELECT "column" FROM public.table TABLESAMPLE BERNOULLI (10) LIMIT 1',
),
],
)
def test_get_sample_query_redshift(mocker, source_type, expected_query):
source = CatSource(name="src", source_type=source_type)
schema = CatSchema(source=source, name="public")
table = CatTable(schema=schema, name="table")
column = CatColumn(table=table, name="column")

mocker.patch("piicatcher.generators._get_table_count", return_value=100)
query = _get_query(
schema=schema,
table=table,
column_list=[column],
dbinfo=get_dbinfo(source_type=source.source_type),
connection=None,
sample_boundary=1,
)

assert query == expected_query


def test_row_generator(sqlalchemy_engine):
catalog, source, conn = sqlalchemy_engine
schemata = catalog.search_schema(source_like=source.name, schema_like="%")
4 changes: 3 additions & 1 deletion tests/test_scanner.py
@@ -160,7 +160,9 @@ def test_deep_scan(load_data_and_pull):
with catalog.managed_session:
source = catalog.get_source_by_id(source_id)
deep_scan(
catalog=catalog, generator=data_generator(catalog=catalog, source=source)
catalog=catalog,
work_generator=column_generator(catalog=catalog, source=source),
generator=data_generator(catalog=catalog, source=source),
)

schemata = catalog.search_schema(source_like=source.name, schema_like="%")
