diff --git a/piicatcher/db/aws.py b/piicatcher/db/aws.py index 9adfbd1..d65851f 100644 --- a/piicatcher/db/aws.py +++ b/piicatcher/db/aws.py @@ -5,7 +5,6 @@ from piicatcher.db.explorer import Explorer - class AthenaExplorer(Explorer): _catalog_query = """ SELECT @@ -33,6 +32,7 @@ def factory(cls, ns): logging.debug("AWS Dispatch entered") explorer = AthenaExplorer(ns.access_key, ns.secret_key, ns.staging_dir, ns.region) + return explorer @classmethod def parser(cls, sub_parsers): @@ -48,6 +48,7 @@ def parser(cls, sub_parsers): help="AWS Region") cls.scan_options(sub_parser) + sub_parser.set_defaults(func=AthenaExplorer.dispatch) def _open_connection(self): return pyathena.connect(aws_access_key_id=self._access_key, diff --git a/piicatcher/db/explorer.py b/piicatcher/db/explorer.py index c6b3b29..0c057c3 100644 --- a/piicatcher/db/explorer.py +++ b/piicatcher/db/explorer.py @@ -75,6 +75,7 @@ def parser(cls, sub_parsers): help="Type of database") cls.scan_options(sub_parser) + sub_parser.set_defaults(func=Explorer.dispatch) @classmethod def scan_options(cls, sub_parser): @@ -90,10 +91,10 @@ def scan_options(cls, sub_parser): help="Choose output format type") sub_parser.add_argument("--list-all", action="store_true", default=False, help="List all columns. By default only columns with PII information is listed") - sub_parser.set_defaults(func=Explorer.dispatch) @classmethod def dispatch(cls, ns): + logging.debug("Dispatch of %s" % cls.__name__) explorer = cls.factory(ns) if ns.scan_type is None or ns.scan_type == "deep": explorer.scan() diff --git a/piicatcher/scanner.py b/piicatcher/scanner.py index fafd29a..a689fb2 100644 --- a/piicatcher/scanner.py +++ b/piicatcher/scanner.py @@ -72,7 +72,7 @@ class ColumnNameScanner(Scanner): PiiTypes.GENDER: re.compile("^.*(gender).*$", re.IGNORECASE), PiiTypes.NATIONALITY: re.compile("^.*(nationality).*$", re.IGNORECASE), PiiTypes.ADDRESS: re.compile("^.*(address|city|state|county|country|" - "zipcode|postal).*$", re.IGNORECASE), + "zipcode|postal|zone|borough).*$", re.IGNORECASE), PiiTypes.USER_NAME: re.compile("^.*user(id|name|).*$", re.IGNORECASE), PiiTypes.PASSWORD: re.compile("^.*pass.*$", re.IGNORECASE), PiiTypes.SSN: re.compile("^.*(ssn|social).*$", re.IGNORECASE) diff --git a/tests/test_awsexplorer.py b/tests/test_awsexplorer.py new file mode 100644 index 0000000..fc2803c --- /dev/null +++ b/tests/test_awsexplorer.py @@ -0,0 +1,22 @@ +from unittest import TestCase, mock +from argparse import Namespace + +from piicatcher.db.aws import AthenaExplorer + + +class AwsExplorerTest(TestCase): + def test_aws_dispath(self): + with mock.patch('piicatcher.db.aws.AthenaExplorer.scan', autospec=True) as mock_scan_method: + with mock.patch('piicatcher.db.aws.AthenaExplorer.get_tabular', + autospec=True) as mock_tabular_method: + with mock.patch('piicatcher.db.explorer.tableprint', autospec=True) as MockTablePrint: + AthenaExplorer.dispatch(Namespace(access_key='ACCESS KEY', + secret_key='SECRET KEY', + staging_dir='s3://DIR', + region='us-east-1', + scan_type=None, + output_format="ascii_table", + list_all=False)) + mock_scan_method.assert_called_once() + mock_tabular_method.assert_called_once() + MockTablePrint.table.assert_called_once()