diff --git a/.circleci/config.yml b/.circleci/config.yml index 809890d..fb4e95d 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -28,11 +28,30 @@ jobs: environment: # environment variables for primary container PIPENV_VENV_IN_PROJECT: true + # Specify service dependencies here if necessary + # CircleCI maintains a library of pre-built images + # documented at https://circleci.com/docs/2.0/circleci-images/ + - image: circleci/postgres:12.0-alpine-ram + environment: + POSTGRES_USER: piiuser + POSTGRES_PASSWORD: p11secret + POSTGRES_DB: piidb + working_directory: ~/repo steps: - checkout + - run: + name: install dockerize + command: wget https://github.com/jwilder/dockerize/releases/download/$DOCKERIZE_VERSION/dockerize-linux-amd64-$DOCKERIZE_VERSION.tar.gz && sudo tar -C /usr/local/bin -xzvf dockerize-linux-amd64-$DOCKERIZE_VERSION.tar.gz && rm dockerize-linux-amd64-$DOCKERIZE_VERSION.tar.gz + environment: + DOCKERIZE_VERSION: v0.3.0 + + - run: + name: Wait for db + command: dockerize -wait tcp://localhost:5432 -timeout 1m + # Download and cache dependencies - restore_cache: key: deps9-{{ .Branch }}-{{ checksum "Pipfile.lock" }} diff --git a/piicatcher/command_line.py b/piicatcher/command_line.py index 7342b6c..f36fc96 100644 --- a/piicatcher/command_line.py +++ b/piicatcher/command_line.py @@ -1,13 +1,9 @@ import argparse -import yaml import logging -from piicatcher.db.aws import AthenaExplorer -from piicatcher.db.explorer import Explorer -from piicatcher.files.explorer import parser as files_parser - -from piicatcher.config import set_global_config -from piicatcher.orm.models import init +from piicatcher.explorer.aws import AthenaExplorer +from piicatcher.explorer.explorer import Explorer +from piicatcher.explorer.files import parser as files_parser def get_parser(parser_cls=argparse.ArgumentParser): @@ -25,10 +21,6 @@ def get_parser(parser_cls=argparse.ArgumentParser): def dispatch(ns): logging.basicConfig(level=getattr(logging, ns.log_level.upper())) - if ns.config_file is not None: - with open(ns.config_file, 'r') as stream: - set_global_config(yaml.load(stream)) - init() ns.func(ns) diff --git a/piicatcher/config.py b/piicatcher/config.py deleted file mode 100644 index fc36c9a..0000000 --- a/piicatcher/config.py +++ /dev/null @@ -1,6 +0,0 @@ -config = {} - - -def set_global_config(c): - global config # Needed to modify global copy of config - config = c diff --git a/piicatcher/db/__init__.py b/piicatcher/explorer/__init__.py similarity index 100% rename from piicatcher/db/__init__.py rename to piicatcher/explorer/__init__.py diff --git a/piicatcher/db/aws.py b/piicatcher/explorer/aws.py similarity index 72% rename from piicatcher/db/aws.py rename to piicatcher/explorer/aws.py index d65851f..e74ecbd 100644 --- a/piicatcher/db/aws.py +++ b/piicatcher/explorer/aws.py @@ -2,7 +2,8 @@ import pyathena -from piicatcher.db.explorer import Explorer +from piicatcher.explorer.explorer import Explorer +from piicatcher.store.glue import GlueStore class AthenaExplorer(Explorer): @@ -20,18 +21,14 @@ class AthenaExplorer(Explorer): _select_query_template = "select {column_list} from {schema_name}.{table_name} limit 10" _count_query = "select count(*) from {schema_name}.{table_name}" - def __init__(self, access_key, secret_key, staging_dir, region_name): - super(AthenaExplorer, self).__init__() - self._access_key = access_key - self._secret_key = secret_key - self._staging_dir = staging_dir - self._region_name = region_name + def __init__(self, ns): + super(AthenaExplorer, 
self).__init__(ns) + self.config = ns @classmethod def factory(cls, ns): logging.debug("AWS Dispatch entered") - explorer = AthenaExplorer(ns.access_key, ns.secret_key, - ns.staging_dir, ns.region) + explorer = AthenaExplorer(ns) return explorer @classmethod @@ -46,15 +43,25 @@ def parser(cls, sub_parsers): help="S3 Staging Directory for Athena results") sub_parser.add_argument("-r", "--region", required=True, help="AWS Region") + sub_parser.add_argument("-f", "--output-format", choices=["ascii_table", "json", "db", "glue"], + default="ascii_table", + help="Choose output format type") cls.scan_options(sub_parser) sub_parser.set_defaults(func=AthenaExplorer.dispatch) + @classmethod + def output(cls, ns, explorer): + if ns.output_format == "glue": + GlueStore.save_schemas(explorer) + else: + super(AthenaExplorer, cls).output(ns, explorer) + def _open_connection(self): - return pyathena.connect(aws_access_key_id=self._access_key, - aws_secret_access_key=self._secret_key, - s3_staging_dir=self._staging_dir, - region_name=self._region_name) + return pyathena.connect(aws_access_key_id=self.config.access_key, + aws_secret_access_key=self.config.secret_key, + s3_staging_dir=self.config.staging_dir, + region_name=self.config.region) def _get_catalog_query(self): return self._catalog_query diff --git a/piicatcher/explorer/databases.py b/piicatcher/explorer/databases.py new file mode 100644 index 0000000..3108117 --- /dev/null +++ b/piicatcher/explorer/databases.py @@ -0,0 +1,260 @@ +import sqlite3 +from abc import abstractmethod + +import pymysql +import psycopg2 +import pymssql +import cx_Oracle + +import logging + +from piicatcher.explorer.explorer import Explorer + + +class RelDbExplorer(Explorer): + def __init__(self, ns): + super(RelDbExplorer, self).__init__(ns) + self.host = ns.host + self.user = ns.user + self.password = ns.password + self.port = int(ns.port) if 'port' in vars(ns) and ns.port is not None else self.default_port + + @property + @abstractmethod + def default_port(self): + pass + + @classmethod + def factory(cls, ns): + logging.debug("Relational Db Factory entered") + explorer = None + if ns.connection_type == "sqlite": + explorer = SqliteExplorer(ns) + elif ns.connection_type == "mysql": + explorer = MySQLExplorer(ns) + elif ns.connection_type == "postgres" or ns.connection_type == "redshift": + explorer = PostgreSQLExplorer(ns) + elif ns.connection_type == "sqlserver": + explorer = MSSQLExplorer(ns) + elif ns.connection_type == "oracle": + explorer = OracleExplorer(ns) + assert (explorer is not None) + + return explorer + + +class SqliteExplorer(Explorer): + _catalog_query = """ + SELECT + "" as schema_name, + m.name as table_name, + p.name as column_name, + p.type as data_type + FROM + sqlite_master AS m + JOIN + pragma_table_info(m.name) AS p + WHERE + p.type like 'text' or p.type like 'varchar%' or p.type like 'char%' + ORDER BY + m.name, + p.name + """ + + _query_template = "select {column_list} from {table_name}" + + class CursorContextManager: + def __init__(self, connection): + self.connection = connection + + def __enter__(self): + self.cursor = self.connection.cursor() + return self.cursor + + def __exit__(self, exc_type, exc_val, exc_tb): + pass + + def __init__(self, ns): + super(SqliteExplorer, self).__init__(ns) + self.host = ns.host + + @classmethod + def factory(cls, ns): + logging.debug("Sqlite Factory entered") + return SqliteExplorer(ns) + + def _open_connection(self): + logging.debug("Sqlite connection string '{}'".format(self.host)) + return 
sqlite3.connect(self.host)
+
+    def _get_catalog_query(self):
+        return self._catalog_query
+
+    def _get_context_manager(self):
+        return SqliteExplorer.CursorContextManager(self.get_connection())
+
+    @classmethod
+    def _get_select_query(cls, schema_name, table_name, column_list):
+        return cls._query_template.format(
+            column_list=",".join([col.get_name() for col in column_list]),
+            table_name=table_name.get_name()
+        )
+
+
+class MySQLExplorer(RelDbExplorer):
+    _catalog_query = """
+        SELECT
+            TABLE_SCHEMA, TABLE_NAME, COLUMN_NAME, DATA_TYPE
+        FROM
+            INFORMATION_SCHEMA.COLUMNS
+        WHERE
+            TABLE_SCHEMA NOT IN ('information_schema', 'performance_schema', 'sys', 'mysql')
+            AND DATA_TYPE RLIKE 'char.*|varchar.*|text'
+        ORDER BY table_schema, table_name, column_name
+    """
+
+    def __init__(self, ns):
+        super(MySQLExplorer, self).__init__(ns)
+
+    @property
+    def default_port(self):
+        return 3306
+
+    def _open_connection(self):
+        return pymysql.connect(host=self.host,
+                               port=self.port,
+                               user=self.user,
+                               password=self.password)
+
+    def _get_catalog_query(self):
+        return self._catalog_query
+
+
+class PostgreSQLExplorer(RelDbExplorer):
+    _catalog_query = """
+        SELECT
+            TABLE_SCHEMA, TABLE_NAME, COLUMN_NAME, DATA_TYPE
+        FROM
+            INFORMATION_SCHEMA.COLUMNS
+        WHERE
+            TABLE_SCHEMA NOT IN ('information_schema', 'pg_catalog')
+            AND DATA_TYPE SIMILAR TO '%char%|%text%'
+        ORDER BY table_schema, table_name, column_name
+    """
+
+    def __init__(self, ns):
+        super(PostgreSQLExplorer, self).__init__(ns)
+        self.database = self.default_database if ns.database is None else ns.database
+
+    @property
+    def default_database(self):
+        return "public"
+
+    @property
+    def default_port(self):
+        return 5432
+
+    def _open_connection(self):
+        return psycopg2.connect(host=self.host,
+                                port=self.port,
+                                user=self.user,
+                                password=self.password,
+                                database=self.database)
+
+    def _get_catalog_query(self):
+        return self._catalog_query
+
+
+class MSSQLExplorer(RelDbExplorer):
+    _catalog_query = """
+        SELECT
+            TABLE_SCHEMA, TABLE_NAME, COLUMN_NAME, DATA_TYPE
+        FROM
+            INFORMATION_SCHEMA.COLUMNS
+        WHERE
+            DATA_TYPE LIKE '%char%'
+        ORDER BY TABLE_SCHEMA, table_name, ordinal_position
+    """
+
+    _sample_query_template = "SELECT TOP 10 * FROM {schema_name}.{table_name} TABLESAMPLE (1000 ROWS)"
+
+    def __init__(self, ns):
+        super(MSSQLExplorer, self).__init__(ns)
+        self.database = self.default_database if ns.database is None else ns.database
+
+    @property
+    def default_database(self):
+        return "public"
+
+    @property
+    def default_port(self):
+        return 1433
+
+    def _open_connection(self):
+        return pymssql.connect(host=self.host,
+                               port=self.port,
+                               user=self.user,
+                               password=self.password,
+                               database=self.database)
+
+    def _get_catalog_query(self):
+        return self._catalog_query
+
+    @classmethod
+    def _get_sample_query(cls, schema_name, table_name, column_list):
+        return cls._sample_query_template.format(
+            column_list=",".join([col.get_name() for col in column_list]),
+            schema_name=schema_name.get_name(),
+            table_name=table_name.get_name()
+        )
+
+
+class OracleExplorer(RelDbExplorer):
+    _catalog_query = """
+        SELECT
+            '{db}', TABLE_NAME, COLUMN_NAME
+        FROM
+            USER_TAB_COLUMNS
+        WHERE UPPER(DATA_TYPE) LIKE '%CHAR%'
+        ORDER BY TABLE_NAME, COLUMN_ID
+    """
+
+    _sample_query_template = "select {column_list} from {table_name} sample(5)"
+    _select_query_template = "select {column_list} from {table_name}"
+    _count_query = "select count(*) from {table_name}"
+
+    def __init__(self, ns):
+        super(OracleExplorer, self).__init__(ns)
+        self.database = ns.database
+
+    @property
+    def default_port(self):
+        return 1521
+
+    def _open_connection(self):
+        return cx_Oracle.connect(self.user,
+                                 self.password,
+                                 "%s:%d/%s" % (self.host, self.port, self.database))
+
+    def _get_catalog_query(self):
+        return self._catalog_query.format(db=self.database)
+
+    @classmethod
+    def _get_select_query(cls, schema_name, table_name, column_list):
+        return cls._select_query_template.format(
+            column_list=",".join([col.get_name() for col in column_list]),
+            table_name=table_name.get_name()
+        )
+
+    @classmethod
+    def _get_sample_query(cls, schema_name, table_name, column_list):
+        return cls._sample_query_template.format(
+            column_list=",".join([col.get_name() for col in column_list]),
+            table_name=table_name.get_name()
+        )
+
+    @classmethod
+    def _get_count_query(cls, schema_name, table_name):
+        return cls._count_query.format(
+            table_name=table_name.get_name()
+        )
diff --git a/piicatcher/db/explorer.py b/piicatcher/explorer/explorer.py
similarity index 51%
rename from piicatcher/db/explorer.py
rename to piicatcher/explorer/explorer.py
index 0c057c3..c8dfe58 100644
--- a/piicatcher/db/explorer.py
+++ b/piicatcher/explorer/explorer.py
@@ -1,28 +1,27 @@
+import json
+import logging
 from abc import ABC, abstractmethod
 from datetime import datetime, timedelta
 
-import sqlite3
-import pymysql
-import psycopg2
-import pymssql
-import cx_Oracle
-
-import logging
-import json
+import yaml
 import tableprint
 
-from piicatcher.db.metadata import Schema, Table, Column
-from piicatcher.orm.models import Store
+from piicatcher.explorer.metadata import Schema, Table, Column
+from piicatcher.store.db import DbStore
 
 
 class Explorer(ABC):
    query_template = "select {column_list} from {schema_name}.{table_name}"
    _count_query = "select count(*) from {schema_name}.{table_name}"
 
-    def __init__(self):
+    def __init__(self, ns):
         self._connection = None
         self._schemas = None
         self._cache_ts = None
+        self.config = None
+
+        if getattr(ns, 'config_file', None) is not None:  # some callers build a Namespace without config_file
+            self.config = yaml.full_load(open(ns.config_file))
 
     def __enter__(self):
         return self
@@ -40,21 +39,7 @@ def _get_catalog_query(self):
 
     @classmethod
     def factory(cls, ns):
-        logging.debug("Db Dispatch entered")
-        explorer = None
-        if ns.connection_type == "sqlite":
-            explorer = SqliteExplorer(ns.host)
-        elif ns.connection_type == "mysql":
-            explorer = MySQLExplorer(ns.host, ns.port, ns.user, ns.password)
-        elif ns.connection_type == "postgres" or ns.connection_type == "redshift":
-            explorer = PostgreSQLExplorer(ns.host, ns.port, ns.user, ns.password, ns.database)
-        elif ns.connection_type == "sqlserver":
-            explorer = MSSQLExplorer(ns.host, ns.port, ns.user, ns.password, ns.database)
-        elif ns.connection_type == "oracle":
-            explorer = OracleExplorer(ns.host, ns.port, ns.user, ns.password, ns.database)
-        assert (explorer is not None)
-
-        return explorer
+        pass
 
     @classmethod
     def parser(cls, sub_parsers):
@@ -73,6 +58,9 @@ def parser(cls, sub_parsers):
         sub_parser.add_argument("-t", "--connection-type", default="sqlite",
                                 choices=["sqlite", "mysql", "postgres", "redshift", "oracle", "sqlserver"],
                                 help="Type of database")
+        sub_parser.add_argument("-f", "--output-format", choices=["ascii_table", "json", "db"],
+                                default="ascii_table",
+                                help="Choose output format type")
 
         cls.scan_options(sub_parser)
         sub_parser.set_defaults(func=Explorer.dispatch)
@@ -86,9 +74,6 @@ def scan_options(cls, sub_parser):
         sub_parser.add_argument("-o", "--output", default=None,
                                 help="File path for report. 
If not specified, " "then report is printed to sys.stdout") - sub_parser.add_argument("-f", "--output-format", choices=["ascii_table", "json", "orm"], - default="ascii_table", - help="Choose output format type") sub_parser.add_argument("--list-all", action="store_true", default=False, help="List all columns. By default only columns with PII information is listed") @@ -101,13 +86,17 @@ def dispatch(cls, ns): else: explorer.shallow_scan() + cls.output(ns, explorer) + + @classmethod + def output(cls, ns, explorer): if ns.output_format == "ascii_table": headers = ["schema", "table", "column", "has_pii"] tableprint.table(explorer.get_tabular(ns.list_all), headers) elif ns.output_format == "json": print(json.dumps(explorer.get_dict(), sort_keys=True, indent=2)) - elif ns.output_format == "orm": - Store.save_schemas(explorer) + elif ns.output_format == "db": + DbStore.save_schemas(explorer) def get_connection(self): if self._connection is None: @@ -255,215 +244,4 @@ def get_columns(self, schema_name, table_name): if t.get_name() == table_name: return t.get_columns() - raise ValueError("{} table not found".format(table_name)) - - -class SqliteExplorer(Explorer): - _catalog_query = """ - SELECT - "" as schema_name, - m.name as table_name, - p.name as column_name, - p.type as data_type - FROM - sqlite_master AS m - JOIN - pragma_table_info(m.name) AS p - WHERE - p.type like 'text' or p.type like 'varchar%' or p.type like 'char%' - ORDER BY - m.name, - p.name - """ - - _query_template = "select {column_list} from {table_name}" - - class CursorContextManager(): - def __init__(self, connection): - self.connection = connection - - def __enter__(self): - self.cursor = self.connection.cursor() - return self.cursor - - def __exit__(self, exc_type, exc_val, exc_tb): - pass - - def __init__(self, conn_string): - super(SqliteExplorer, self).__init__() - self.conn_string = conn_string - - def _open_connection(self): - logging.debug("Sqlite connection string '{}'".format(self.conn_string)) - return sqlite3.connect(self.conn_string) - - def _get_catalog_query(self): - return self._catalog_query - - def _get_context_manager(self): - return SqliteExplorer.CursorContextManager(self.get_connection()) - - @classmethod - def _get_select_query(cls, schema_name, table_name, column_list): - return self._query_template.format( - column_list=",".join([col.get_name() for col in column_list]), - table_name=table_name.get_name() - ) - - -class MySQLExplorer(Explorer): - _catalog_query = """ - SELECT - TABLE_SCHEMA, TABLE_NAME, COLUMN_NAME, DATA_TYPE - FROM - INFORMATION_SCHEMA.COLUMNS - WHERE - TABLE_SCHEMA NOT IN ('information_schema', 'performance_schema', 'sys', 'mysql') - AND DATA_TYPE RLIKE 'char.*|varchar.*|text' - ORDER BY table_schema, table_name, column_name - """ - - default_port = 3036 - - def __init__(self, host, port, user, password): - super(MySQLExplorer, self).__init__() - self.host = host - self.user = user - self.password = password - self.port = self.default_port if port is None else int(port) - - def _open_connection(self): - return pymysql.connect(host=self.host, - port=self.port, - user=self.user, - password=self.password) - - def _get_catalog_query(self): - return self._catalog_query - - -class PostgreSQLExplorer(Explorer): - _catalog_query = """ - SELECT - TABLE_SCHEMA, TABLE_NAME, COLUMN_NAME, DATA_TYPE - FROM - INFORMATION_SCHEMA.COLUMNS - WHERE - TABLE_SCHEMA NOT IN ('information_schema', 'pg_catalog') - AND DATA_TYPE SIMILAR TO '%char%|%text%' - ORDER BY table_schema, table_name, column_name - """ 
- - default_port = 5432 - - def __init__(self, host, port, user, password, database='public'): - super(PostgreSQLExplorer, self).__init__() - self.host = host - self.port = self.default_port if port is None else int(port) - self.user = user - self.password = password - self.database = database - - def _open_connection(self): - return psycopg2.connect(host=self.host, - port=self.port, - user=self.user, - password=self.password, - database=self.database) - - def _get_catalog_query(self): - return self._catalog_query - - -class MSSQLExplorer(Explorer): - _catalog_query = """ - SELECT - TABLE_SCHEMA, TABLE_NAME, COLUMN_NAME, DATA_TYPE - FROM - INFORMATION_SCHEMA.COLUMNS - WHERE - DATA_TYPE LIKE '%char%' - ORDER BY TABLE_SCHEMA, table_name, ordinal_position - """ - - _sample_query_template = "SELECT TOP 10 * FROM {schema_name}.{table_name} TABLESAMPLE (1000 ROWS)" - default_port = 1433 - - def __init__(self, host, port, user, password, database='public'): - super(MSSQLExplorer, self).__init__() - self.host = host - self.port = self.default_port if port is None else int(port) - self.user = user - self.password = password - self.database = database - - def _open_connection(self): - return pymssql.connect(host=self.host, - port=self.port, - user=self.user, - password=self.password, - database=self.database) - - def _get_catalog_query(self): - return self._catalog_query - - @classmethod - def _get_sample_query(cls, schema_name, table_name, column_list): - return cls._sample_query_template.format( - column_list=",".join([col.get_name() for col in column_list]), - schema_name=schema_name.get_name(), - table_name=table_name.get_name() - ) - - -class OracleExplorer(Explorer): - _catalog_query = """ - SELECT - '{db}', TABLE_NAME, COLUMN_NAME - FROM - USER_TAB_COLUMNS - WHERE UPPER(DATA_TYPE) LIKE '%CHAR%' - ORDER BY TABLE_NAME, COLUMN_ID - """ - - _sample_query_template = "select {column_list} from {table_name} sample(5)" - _select_query_template = "select {column_list} from {table_name}" - _count_query = "select count(*) from {table_name}" - - default_port = 1521 - - def __init__(self, host, port, user, password, database): - super(OracleExplorer, self).__init__() - self.host = host - self.port = self.default_port if port is None else int(port) - self.user = user - self.password = password - self.database = database - - def _open_connection(self): - return cx_Oracle.connect(self.user, - self.password, - "%s:%d/%s" % (self.host, self.port, self.database)) - - def _get_catalog_query(self): - return self._catalog_query.format(db=self.database) - - @classmethod - def _get_select_query(cls, schema_name, table_name, column_list): - return cls._select_query_template.format( - column_list=",".join([col.get_name() for col in column_list]), - table_name=table_name.get_name() - ) - - @classmethod - def _get_sample_query(cls, schema_name, table_name, column_list): - return cls._sample_query_template.format( - column_list=",".join([col.get_name() for col in column_list]), - table_name=table_name.get_name() - ) - - @classmethod - def _get_count_query(cls, schema_name, table_name): - return cls._count_query.format( - table_name=table_name.get_name() - ) + raise ValueError("{} table not found".format(table_name)) \ No newline at end of file diff --git a/piicatcher/files/explorer.py b/piicatcher/explorer/files.py similarity index 98% rename from piicatcher/files/explorer.py rename to piicatcher/explorer/files.py index fb1ed9e..e42e5e1 100644 --- a/piicatcher/files/explorer.py +++ b/piicatcher/explorer/files.py @@ -5,7 
+5,7 @@
 import magic
 
 from piicatcher.tokenizer import Tokenizer
-from piicatcher.db.metadata import NamedObject
+from piicatcher.explorer.metadata import NamedObject
 from piicatcher.piitypes import PiiTypes, PiiTypeEncoder
 from piicatcher.scanner import NERScanner, RegexScanner
diff --git a/piicatcher/db/metadata.py b/piicatcher/explorer/metadata.py
similarity index 100%
rename from piicatcher/db/metadata.py
rename to piicatcher/explorer/metadata.py
diff --git a/piicatcher/files/__init__.py b/piicatcher/files/__init__.py
deleted file mode 100644
index e69de29..0000000
diff --git a/piicatcher/orm/__init__.py b/piicatcher/orm/__init__.py
deleted file mode 100644
index e69de29..0000000
diff --git a/piicatcher/orm/PiiTypeField.py b/piicatcher/store/PiiTypeField.py
similarity index 95%
rename from piicatcher/orm/PiiTypeField.py
rename to piicatcher/store/PiiTypeField.py
index bec4d5a..554698a 100644
--- a/piicatcher/orm/PiiTypeField.py
+++ b/piicatcher/store/PiiTypeField.py
@@ -1,5 +1,4 @@
 import json
-import logging
 
 from peewee import Field
 from piicatcher.piitypes import PiiTypeEncoder, as_enum
diff --git a/piicatcher/store/__init__.py b/piicatcher/store/__init__.py
new file mode 100644
index 0000000..0638e0d
--- /dev/null
+++ b/piicatcher/store/__init__.py
@@ -0,0 +1,8 @@
+from abc import ABC, abstractmethod
+
+
+class Store(ABC):
+    @classmethod
+    @abstractmethod
+    def save_schemas(cls, explorer):
+        pass
diff --git a/piicatcher/orm/models.py b/piicatcher/store/db.py
similarity index 66%
rename from piicatcher/orm/models.py
rename to piicatcher/store/db.py
index 3312f0a..cd2b148 100644
--- a/piicatcher/orm/models.py
+++ b/piicatcher/store/db.py
@@ -1,9 +1,7 @@
-import json
-
 from peewee import *
 
-from piicatcher.orm.PiiTypeField import PiiTypeField
-from piicatcher.config import config
+from piicatcher.store import Store
+from piicatcher.store.PiiTypeField import PiiTypeField
 
 database_proxy = DatabaseProxy()
@@ -38,19 +36,6 @@ class DbFile(BaseModel):
     pii_types = PiiTypeField()
 
 
-def init():
-    if 'orm' in config:
-        orm = config['orm']
-        database = MySQLDatabase('tokern',
-                                 host=orm['host'],
-                                 port=int(orm['port']),
-                                 user=orm['user'],
-                                 password=orm['password'])
-        database_proxy.initialize(database)
-        database_proxy.connect()
-        database_proxy.create_tables([DbSchemas, DbTables, DbColumns, DbFile])
-
-
 def init_test(path):
     database = SqliteDatabase(path)
     database_proxy.initialize(database)
@@ -62,9 +47,23 @@ def model_db_close():
     database_proxy.close()
 
 
-class Store:
+class DbStore(Store):
+    @classmethod
+    def setup_database(cls, config):
+        if config is not None and 'store' in config:
+            orm = config['store']
+            database = MySQLDatabase('tokern',
+                                     host=orm['host'],
+                                     port=int(orm['port']),
+                                     user=orm['user'],
+                                     password=orm['password'])
+            database_proxy.initialize(database)
+            database_proxy.connect()
+            database_proxy.create_tables([DbSchemas, DbTables, DbColumns, DbFile])
+
     @classmethod
     def save_schemas(cls, explorer):
+        cls.setup_database(explorer.config)
         with database_proxy.atomic():
             schemas = explorer.get_schemas()
             for s in schemas:
diff --git a/piicatcher/store/glue.py b/piicatcher/store/glue.py
new file mode 100644
index 0000000..85031fc
--- /dev/null
+++ b/piicatcher/store/glue.py
@@ -0,0 +1,79 @@
+import logging
+
+from piicatcher.store import Store
+
+import boto3
+
+
+class GlueStore(Store):
+    @staticmethod
+    def update_column_parameters(column_parameters, pii_table):
+        updated_columns = []
+        is_table_updated = False
+
+        for c in column_parameters:
+            if c['Name'] in pii_table:
+                if 
'Parameters' not in c or c['Parameters'] is None: + c['Parameters'] = {} + + c['Parameters']['PII'] = pii_table[c['Name']][0] + is_table_updated = True + updated_columns.append(c) + + return updated_columns, is_table_updated + + @staticmethod + def get_pii_table(table): + logging.debug("Processing table %s" % table.get_name()) + field_value = {} + for c in table.get_columns(): + pii = c.get_pii_types() + if pii: + field_value[c.get_name()] = sorted([str(v) for v in pii]) + return field_value + + @staticmethod + def update_table_params(table_params, column_params): + updated_params = {} + for param in ['Name', 'Description', 'Owner', 'LastAccessTime', 'LastAnalyzedTime', 'Retention', + 'StorageDescriptor', 'PartitionKeys', 'ViewOriginalText', 'ViewExpandedText', + 'TableType', 'Parameters']: + if param in table_params: + updated_params[param] = table_params[param] + + updated_params['StorageDescriptor']['Columns'] = column_params + + logging.debug("Updated parameters are :") + logging.debug(updated_params) + return updated_params + + @classmethod + def save_schemas(cls, explorer): + schemas = explorer.get_schemas() + client = boto3.client("glue", + region_name=explorer.config.region, + aws_access_key_id=explorer.config.access_key, + aws_secret_access_key=explorer.config.secret_key) + + logging.debug(client) + for s in schemas: + logging.debug("Processing schema %s" % s.get_name()) + for t in s.get_tables(): + field_value = GlueStore.get_pii_table(t) + table_info = client.get_table( + DatabaseName=s.get_name(), + Name=t.get_name() + ) + + logging.debug(table_info) + + updated_columns, is_table_updated = GlueStore.update_column_parameters( + table_info['Table']['StorageDescriptor']['Columns'], field_value + ) + + if is_table_updated: + updated_params = GlueStore.update_table_params(table_info['Table'], updated_columns) + client.update_table( + DatabaseName=s.get_name(), + TableInput=updated_params + ) diff --git a/requirements.txt b/requirements.txt index eecf991..6c72773 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,3 +12,4 @@ peewee pyyaml python-magic pyathena[sqlalchemy] +boto3 == 1.10.34 diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..c9b8279 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,4 @@ +def pytest_configure(config): + config.addinivalue_line( + "markers", "dbtest: Tests that require a database to run" + ) diff --git a/tests/test_awsexplorer.py b/tests/test_awsexplorer.py index fc2803c..f668fab 100644 --- a/tests/test_awsexplorer.py +++ b/tests/test_awsexplorer.py @@ -1,21 +1,22 @@ from unittest import TestCase, mock from argparse import Namespace -from piicatcher.db.aws import AthenaExplorer +from piicatcher.explorer.aws import AthenaExplorer class AwsExplorerTest(TestCase): def test_aws_dispath(self): - with mock.patch('piicatcher.db.aws.AthenaExplorer.scan', autospec=True) as mock_scan_method: - with mock.patch('piicatcher.db.aws.AthenaExplorer.get_tabular', + with mock.patch('piicatcher.explorer.aws.AthenaExplorer.scan', autospec=True) as mock_scan_method: + with mock.patch('piicatcher.explorer.aws.AthenaExplorer.get_tabular', autospec=True) as mock_tabular_method: - with mock.patch('piicatcher.db.explorer.tableprint', autospec=True) as MockTablePrint: + with mock.patch('piicatcher.explorer.explorer.tableprint', autospec=True) as MockTablePrint: AthenaExplorer.dispatch(Namespace(access_key='ACCESS KEY', secret_key='SECRET KEY', staging_dir='s3://DIR', region='us-east-1', scan_type=None, output_format="ascii_table", + 
config_file=None,
                                                      list_all=False))
                    mock_scan_method.assert_called_once()
                    mock_tabular_method.assert_called_once()
diff --git a/tests/test_config.py b/tests/test_config.py
index f315476..1af99e8 100644
--- a/tests/test_config.py
+++ b/tests/test_config.py
@@ -2,7 +2,7 @@
 import yaml
 
 config_file = """
-orm:
+store:
   host: a_host
   port: 3306
   user: a_user
@@ -12,7 +12,7 @@ class TestConfigFile(TestCase):
 
     def testOrmConfig(self):
-        orm = yaml.load(config_file)['orm']
+        orm = yaml.full_load(config_file)['store']
         self.assertEqual("a_host", orm['host'])
         self.assertEqual(3306, orm["port"])
         self.assertEqual("a_user", orm['user'])
diff --git a/tests/test_dbexplorer.py b/tests/test_databases.py
similarity index 70%
rename from tests/test_dbexplorer.py
rename to tests/test_databases.py
index d19cda9..8982574 100644
--- a/tests/test_dbexplorer.py
+++ b/tests/test_databases.py
@@ -9,9 +9,9 @@ import logging
 
 import pytest
 
-from piicatcher.db.explorer import Explorer, SqliteExplorer, MySQLExplorer, PostgreSQLExplorer, OracleExplorer, \
+from piicatcher.explorer.databases import SqliteExplorer, MySQLExplorer, PostgreSQLExplorer, OracleExplorer, \
     MSSQLExplorer
-from piicatcher.db.metadata import Schema, Table, Column
+from piicatcher.explorer.metadata import Schema, Table, Column
 from piicatcher.piitypes import PiiTypes
 
 logging.basicConfig(level=logging.DEBUG)
@@ -19,7 +19,7 @@
 
 class ExplorerTest(TestCase):
     def setUp(self):
-        self.explorer = SqliteExplorer("mock_connection")
+        self.explorer = SqliteExplorer(Namespace(host="mock_connection", config_file=None))
 
         col1 = Column('c1')
         col2 = Column('c2')
@@ -217,9 +217,10 @@ def drop_tables():
         request.addfinalizer(drop_tables)
 
     def setUp(self):
-        self.explorer = MySQLExplorer(host="127.0.0.1",
-                                      user="pii_tester",
-                                      password="pii_secret")
+        self.explorer = MySQLExplorer(Namespace(
+            host="127.0.0.1",
+            user="pii_tester",
+            password="pii_secret"))
 
     def tearDown(self):
         self.explorer.get_connection().close()
@@ -252,9 +253,10 @@ def execute_script(cursor, script):
 
     @pytest.fixture(scope="class")
     def create_tables(self, request):
-        self.conn = pymysql.connect(host="127.0.0.1",
-                                    user="pii_tester",
-                                    password="pii_secret")
+        self.conn = pymysql.connect(
+            host="127.0.0.1",
+            user="pii_tester",
+            password="pii_secret")
 
         with self.conn.cursor() as cursor:
             self.execute_script(cursor, self.char_db_query)
@@ -272,9 +274,10 @@ def drop_tables():
         request.addfinalizer(drop_tables)
 
     def setUp(self):
-        self.explorer = MySQLExplorer(host="127.0.0.1",
-                                      user="pii_tester",
-                                      password="pii_secret")
+        self.explorer = MySQLExplorer(Namespace(
+            host="127.0.0.1",
+            user="pii_tester",
+            password="pii_secret"))
 
     def tearDown(self):
         self.explorer.get_connection().close()
@@ -284,7 +287,6 @@ def get_test_schema(self):
 
 
 @pytest.mark.usefixtures("create_tables")
-@pytest.mark.dbtest
 class PostgresDataTypeTest(CommonDataTypeTestCases.CommonDataTypeTests):
     char_db_drop = """
         DROP TABLE char_columns;
@@ -301,8 +303,9 @@ def execute_script(cursor, script):
 
     @pytest.fixture(scope="class")
     def create_tables(self, request):
         self.conn = psycopg2.connect(host="127.0.0.1",
-                                     user="postgres",
-                                     password="pii_secret")
+                                     user="piiuser",
+                                     password="p11secret",
+                                     database="piidb")
 
         self.conn.autocommit = True
 
@@ -320,10 +323,13 @@ def drop_tables():
         request.addfinalizer(drop_tables)
 
     def setUp(self):
-        self.explorer = PostgreSQLExplorer(host="127.0.0.1",
-                                           user="postgres",
-                                           password="pii_secret",
-                                           database="postgres")
+        self.explorer = PostgreSQLExplorer(Namespace(
+            host="127.0.0.1",
+            user="piiuser",
+ password="p11secret", + database="piidb", + config_file=None + )) def tearDown(self): self.explorer.get_connection().close() @@ -333,7 +339,6 @@ def get_test_schema(self): @pytest.mark.usefixtures("create_tables") -@pytest.mark.dbtest class PostgresExplorerTest(CommonExplorerTestCases.CommonExplorerTests): pii_db_drop = """ DROP TABLE full_pii; @@ -350,8 +355,9 @@ def execute_script(cursor, script): @pytest.fixture(scope="class") def create_tables(self, request): self.conn = psycopg2.connect(host="127.0.0.1", - user="postgres", - password="pii_secret") + user="piiuser", + password="p11secret", + database="piidb") self.conn.autocommit = True @@ -369,10 +375,13 @@ def drop_tables(): request.addfinalizer(drop_tables) def setUp(self): - self.explorer = PostgreSQLExplorer(host="127.0.0.1", - user="postgres", - password="pii_secret", - database="postgres") + self.explorer = PostgreSQLExplorer(Namespace( + host="127.0.0.1", + user="piiuser", + password="p11secret", + database="piidb", + config_file=None + )) def tearDown(self): self.explorer.get_connection().close() @@ -412,9 +421,9 @@ def test_sqlite(self): def test_postgres(self): self.assertEqual("select c1, c2 from testSchema.t1", - PostgreSQLExplorer._get_select_query(self.schema, - self.schema.get_tables()[0], - self.schema.get_tables()[0].get_columns())) + PostgreSQLExplorer._get_select_query(self.schema, + self.schema.get_tables()[0], + self.schema.get_tables()[0].get_columns())) def test_mysql(self): self.assertEqual("select c1, c2 from testSchema.t1", @@ -432,60 +441,64 @@ def test_mssql(self): class TestDispatcher(TestCase): def test_sqlite_dispatch(self): - with mock.patch('piicatcher.db.explorer.SqliteExplorer.scan', autospec=True) as mock_scan_method: - with mock.patch('piicatcher.db.explorer.SqliteExplorer.get_tabular', autospec=True) as mock_tabular_method: - with mock.patch('piicatcher.db.explorer.tableprint', autospec=True) as MockTablePrint: - Explorer.dispatch(Namespace(host='connection', list_all=None, output_format='ascii_table', - connection_type='sqlite', scan_type=None, port=None)) + with mock.patch('piicatcher.explorer.databases.SqliteExplorer.scan', autospec=True) as mock_scan_method: + with mock.patch('piicatcher.explorer.databases.SqliteExplorer.get_tabular', autospec=True) as mock_tabular_method: + with mock.patch('piicatcher.explorer.explorer.tableprint', autospec=True) as MockTablePrint: + SqliteExplorer.dispatch(Namespace(host='connection', list_all=None, output_format='ascii_table', + connection_type='sqlite', scan_type=None, config_file=None, + port=None)) mock_scan_method.assert_called_once() mock_tabular_method.assert_called_once() MockTablePrint.table.assert_called_once() def test_mysql_dispatch(self): - with mock.patch('piicatcher.db.explorer.MySQLExplorer.scan', autospec=True) as mock_scan_method: - with mock.patch('piicatcher.db.explorer.MySQLExplorer.get_tabular', autospec=True) as mock_tabular_method: - with mock.patch('piicatcher.db.explorer.tableprint', autospec=True) as MockTablePrint: - Explorer.dispatch(Namespace(host='connection', - port=None, - list_all=None, - output_format='ascii_table', - connection_type='mysql', - scan_type='deep', - user='user', - password='pass')) + with mock.patch('piicatcher.explorer.databases.MySQLExplorer.scan', autospec=True) as mock_scan_method: + with mock.patch('piicatcher.explorer.databases.MySQLExplorer.get_tabular', autospec=True) as mock_tabular_method: + with mock.patch('piicatcher.explorer.explorer.tableprint', autospec=True) as MockTablePrint: + 
MySQLExplorer.dispatch(Namespace(host='connection',
+                                                     port=None,
+                                                     list_all=None,
+                                                     output_format='ascii_table',
+                                                     connection_type='mysql',
+                                                     scan_type='deep',
+                                                     config_file=None,
+                                                     user='user',
+                                                     password='pass'))
             mock_scan_method.assert_called_once()
             mock_tabular_method.assert_called_once()
             MockTablePrint.table.assert_called_once()
 
     def test_postgres_dispatch(self):
-        with mock.patch('piicatcher.db.explorer.PostgreSQLExplorer.scan', autospec=True) as mock_scan_method:
-            with mock.patch('piicatcher.db.explorer.PostgreSQLExplorer.get_tabular', autospec=True) as mock_tabular_method:
-                with mock.patch('piicatcher.db.explorer.tableprint', autospec=True) as MockTablePrint:
-                    Explorer.dispatch(Namespace(host='connection',
-                                                port=None,
-                                                list_all=None,
-                                                output_format='ascii_table',
-                                                connection_type='postgres',
-                                                database='public',
-                                                scan_type=None,
-                                                user='user',
-                                                password='pass'))
+        with mock.patch('piicatcher.explorer.databases.PostgreSQLExplorer.scan', autospec=True) as mock_scan_method:
+            with mock.patch('piicatcher.explorer.databases.PostgreSQLExplorer.get_tabular', autospec=True) as mock_tabular_method:
+                with mock.patch('piicatcher.explorer.explorer.tableprint', autospec=True) as MockTablePrint:
+                    PostgreSQLExplorer.dispatch(Namespace(host='connection',
+                                                          port=None,
+                                                          list_all=None,
+                                                          output_format='ascii_table',
+                                                          connection_type='postgres',
+                                                          database='public',
+                                                          scan_type=None,
+                                                          config_file=None,
+                                                          user='user',
+                                                          password='pass'))
             mock_scan_method.assert_called_once()
             mock_tabular_method.assert_called_once()
             MockTablePrint.table.assert_called_once()
 
     def test_mysql_shallow_scan(self):
-        with mock.patch('piicatcher.db.explorer.MySQLExplorer.shallow_scan', autospec=True) as mock_shallow_scan_method:
-            with mock.patch('piicatcher.db.explorer.MySQLExplorer.get_tabular', autospec=True) as mock_tabular_method:
-                with mock.patch('piicatcher.db.explorer.tableprint', autospec=True) as MockTablePrint:
-                    Explorer.dispatch(Namespace(host='connection',
-                                                port=None,
-                                                list_all=None,
-                                                output_format='ascii_table',
-                                                connection_type='mysql',
-                                                user='user',
-                                                password='pass',
-                                                scan_type="shallow"))
+        with mock.patch('piicatcher.explorer.databases.MySQLExplorer.shallow_scan', autospec=True) as mock_shallow_scan_method:
+            with mock.patch('piicatcher.explorer.databases.MySQLExplorer.get_tabular', autospec=True) as mock_tabular_method:
+                with mock.patch('piicatcher.explorer.explorer.tableprint', autospec=True) as MockTablePrint:
+                    MySQLExplorer.dispatch(Namespace(host='connection',
+                                                     port=None,
+                                                     list_all=None,
+                                                     output_format='ascii_table',
+                                                     connection_type='mysql',
+                                                     config_file=None,
+                                                     user='user',
+                                                     password='pass',
+                                                     scan_type="shallow"))
             mock_shallow_scan_method.assert_called_once()
             mock_tabular_method.assert_called_once()
             MockTablePrint.table.assert_called_once()
diff --git a/tests/test_dbmetadata.py b/tests/test_dbmetadata.py
index b84d53a..e62634e 100644
--- a/tests/test_dbmetadata.py
+++ b/tests/test_dbmetadata.py
@@ -1,5 +1,5 @@
 from unittest import TestCase
-from piicatcher.db.metadata import Column, Table, Schema
+from piicatcher.explorer.metadata import Column, Table, Schema
 
 
 class DbMetadataTests(TestCase):
diff --git a/tests/test_file_explorer.py b/tests/test_file_explorer.py
index 752ee55..953f6e8 100644
--- a/tests/test_file_explorer.py
+++ b/tests/test_file_explorer.py
@@ -1,15 +1,15 @@
 from unittest import TestCase, mock
 from argparse import Namespace
 
-from piicatcher.files.explorer import File, FileExplorer, dispatch
+from piicatcher.explorer.files import File, FileExplorer, 
dispatch from piicatcher.piitypes import PiiTypes class TestDispatcher(TestCase): def test_file_dispatch(self): - with mock.patch('piicatcher.files.explorer.FileExplorer.scan', autospec=True) as mock_scan_method: - with mock.patch('piicatcher.files.explorer.FileExplorer.get_tabular', autospec=True) as mock_tabular_method: - with mock.patch('piicatcher.files.explorer.tableprint', autospec=True) as MockTablePrint: + with mock.patch('piicatcher.explorer.files.FileExplorer.scan', autospec=True) as mock_scan_method: + with mock.patch('piicatcher.explorer.files.FileExplorer.get_tabular', autospec=True) as mock_tabular_method: + with mock.patch('piicatcher.explorer.files.tableprint', autospec=True) as MockTablePrint: dispatch(Namespace(path='/a/b/c', output_format='ascii_table')) mock_scan_method.assert_called_once() mock_tabular_method.assert_called_once() diff --git a/tests/test_glue.py b/tests/test_glue.py new file mode 100644 index 0000000..e374ebb --- /dev/null +++ b/tests/test_glue.py @@ -0,0 +1,198 @@ +import unittest +import datetime + +from dateutil.tz import tzlocal + +from piicatcher.store.glue import GlueStore +from tests.test_models import MockExplorer + + +class PiiTable(unittest.TestCase): + def test_no_pii(self): + pii_table = GlueStore.get_pii_table(MockExplorer.get_no_pii_table()) + self.assertEqual({}, pii_table) + + def test_partial_pii(self): + pii_table = GlueStore.get_pii_table(MockExplorer.get_partial_pii_table()) + self.assertEqual({'a': ['PiiTypes.PHONE']}, pii_table) + + def test_full_pii(self): + pii_table = GlueStore.get_pii_table(MockExplorer.get_full_pii_table()) + self.assertEqual({'a': ['PiiTypes.PHONE'], 'b': ['PiiTypes.ADDRESS', 'PiiTypes.LOCATION']}, pii_table) + + +class UpdateParameters(unittest.TestCase): + def test_empty_table(self): + columns = [{ + 'Name': 'dispatching_base_num', "Type": "string" + }, { + "Name": "pickup_datetime", "Type": "string" + }, { + "Name": "dropoff_datetime", "Type": "string" + }, { + "Name": "pulocationid", "Type": "bigint" + }, { + "Name": "dolocationid", "Type": "bigint" + }, { + "Name": "sr_flag", "Type": "bigint" + }, { + "Name": "hvfhs_license_num", "Type": "string" + } + ] + updated, is_updated = GlueStore.update_column_parameters(columns, {}) + self.assertFalse(is_updated) + self.assertEqual(columns, updated) + + def test_for_update(self): + columns = [ + {'Name': 'locationid', 'Type': 'bigint'}, {'Name': 'borough', 'Type': 'string'}, + {'Name': 'zone', 'Type': 'string'}, {'Name': 'service_zone', 'Type': 'string'} + ] + + expected = [ + {'Name': 'locationid', 'Type': 'bigint'}, + {'Name': 'borough', 'Type': 'string', 'Parameters': {'PII': 'PiiTypes.ADDRESS'}}, + {'Name': 'zone', 'Type': 'string', 'Parameters': {'PII': 'PiiTypes.ADDRESS'}}, + {'Name': 'service_zone', 'Type': 'string', 'Parameters': {'PII': 'PiiTypes.ADDRESS'}} + ] + + pii_table = { + 'borough': ['PiiTypes.ADDRESS'], + 'zone': ['PiiTypes.ADDRESS'], + 'service_zone': ['PiiTypes.ADDRESS'] + } + + updated, is_updated = GlueStore.update_column_parameters(columns, pii_table) + self.assertTrue(is_updated) + self.assertEqual(expected, columns) + + def test_param_no_update(self): + columns = [ + {'Name': 'locationid', 'Type': 'bigint', 'Parameters': {'a': 'b'}}, {'Name': 'borough', 'Type': 'string'}, + ] + + updated, is_updated = GlueStore.update_column_parameters(columns, {}) + self.assertFalse(is_updated) + self.assertEqual(columns, updated) + + def test_param_update(self): + columns = [ + {'Name': 'locationid', 'Type': 'bigint', }, {'Name': 'borough', 'Type': 
'string', 'Parameters': {'a': 'b'}}, + ] + + pii_table = { + 'borough': ['PiiTypes.ADDRESS'], + } + + expected = [ + {'Name': 'locationid', 'Type': 'bigint'}, + {'Name': 'borough', 'Type': 'string', 'Parameters': {'a': 'b', 'PII': 'PiiTypes.ADDRESS'}} + ] + + updated, is_updated = GlueStore.update_column_parameters(columns, pii_table) + self.assertTrue(is_updated) + self.assertEqual(expected, columns) + + +class TableParams(unittest.TestCase): + def test_update(self): + updated_columns = [ + {'Name': 'locationid', 'Type': 'bigint'}, + {'Name': 'borough', 'Type': 'string', 'Parameters': {'PII': 'PiiTypes.ADDRESS'}}, + {'Name': 'zone', 'Type': 'string', 'Parameters': {'PII': 'PiiTypes.ADDRESS'}}, + {'Name': 'service_zone', 'Type': 'string', 'Parameters': {'PII': 'PiiTypes.ADDRESS'}} + ] + + table_params = { + 'Name': 'csv_misc', 'DatabaseName': 'taxidata', 'Owner': 'owner', + 'CreateTime': datetime.datetime(2019, 12, 9, 16, 12, 43, tzinfo=tzlocal()), + 'UpdateTime': datetime.datetime(2019, 12, 9, 16, 12, 43, tzinfo=tzlocal()), + 'LastAccessTime': datetime.datetime(2019, 12, 9, 16, 12, 43, tzinfo=tzlocal()), + 'Retention': 0, + 'StorageDescriptor': {'Columns': [ + {'Name': 'locationid', 'Type': 'bigint'}, + {'Name': 'borough', 'Type': 'string'}, + {'Name': 'zone', 'Type': 'string'}, + {'Name': 'service_zone', 'Type': 'string'} + ], + 'Location': 's3://nyc-tlc/misc/', + 'InputFormat': 'org.apache.hadoop.mapred.TextInputFormat', + 'OutputFormat': 'org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat', + 'Compressed': False, 'NumberOfBuckets': -1, + 'SerdeInfo': { + 'SerializationLibrary': 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe', + 'Parameters': {'field.delim': ','}}, 'BucketColumns': [], + 'SortColumns': [], + 'Parameters': { + 'CrawlerSchemaDeserializerVersion': '1.0', 'CrawlerSchemaSerializerVersion': '1.0', + 'UPDATED_BY_CRAWLER': 'TaxiCrawler', 'areColumnsQuoted': 'false', + 'averageRecordSize': '36', 'classification': 'csv', 'columnsOrdered': 'true', + 'compressionType': 'none', 'delimiter': ',', + 'exclusions': '["s3://nyc-tlc/misc/*foil*","s3://nyc-tlc/misc/shared*",' + '"s3://nyc-tlc/misc/uber*","s3://nyc-tlc/misc/*.html",' + '"s3://nyc-tlc/misc/*.zip","s3://nyc-tlc/misc/FOIL_*"]', + 'objectCount': '1', 'recordCount': '342', 'sizeKey': '12322', + 'skip.header.line.count': '1', 'typeOfData': 'file'}, + 'StoredAsSubDirectories': False + }, + 'PartitionKeys': [], 'TableType': 'EXTERNAL_TABLE', 'Parameters': { + 'CrawlerSchemaDeserializerVersion': '1.0', 'CrawlerSchemaSerializerVersion': '1.0', + 'UPDATED_BY_CRAWLER': 'TaxiCrawler', 'areColumnsQuoted': 'false', + 'averageRecordSize': '36', 'classification': 'csv', 'columnsOrdered': 'true', + 'compressionType': 'none', 'delimiter': ',', + 'exclusions': + '["s3://nyc-tlc/misc/*foil*","s3://nyc-tlc/misc/shared*","s3://nyc-tlc/misc/uber*",' + '"s3://nyc-tlc/misc/*.html","s3://nyc-tlc/misc/*.zip","s3://nyc-tlc/misc/FOIL_*"]', + 'objectCount': '1', 'recordCount': '342', 'sizeKey': '12322', + 'skip.header.line.count': '1', 'typeOfData': 'file'}, + 'CreatedBy': + 'arn:aws:sts::172965158661:assumed-role/LakeFormationWorkflowRole/AWS-Crawler', + 'IsRegisteredWithLakeFormation': False + } + + expected_table_params = { + 'Name': 'csv_misc', 'Owner': 'owner', + 'LastAccessTime': datetime.datetime(2019, 12, 9, 16, 12, 43, tzinfo=tzlocal()), + 'Retention': 0, + 'StorageDescriptor': { + 'Columns': [ + {'Name': 'locationid', 'Type': 'bigint'}, + {'Name': 'borough', 'Type': 'string', 'Parameters': {'PII': 'PiiTypes.ADDRESS'}}, + {'Name': 
'zone', 'Type': 'string', 'Parameters': {'PII': 'PiiTypes.ADDRESS'}}, + {'Name': 'service_zone', 'Type': 'string', 'Parameters': {'PII': 'PiiTypes.ADDRESS'}} + ], + 'Location': 's3://nyc-tlc/misc/', 'InputFormat': 'org.apache.hadoop.mapred.TextInputFormat', + 'OutputFormat': 'org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat', + 'Compressed': False, 'NumberOfBuckets': -1, + 'SerdeInfo': { + 'SerializationLibrary': 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe', + 'Parameters': {'field.delim': ','}}, 'BucketColumns': [], 'SortColumns': [], + 'Parameters': { + 'CrawlerSchemaDeserializerVersion': '1.0', 'CrawlerSchemaSerializerVersion': '1.0', + 'UPDATED_BY_CRAWLER': 'TaxiCrawler', 'areColumnsQuoted': 'false', 'averageRecordSize': '36', + 'classification': 'csv', 'columnsOrdered': 'true', 'compressionType': 'none', 'delimiter': ',', + 'exclusions': '["s3://nyc-tlc/misc/*foil*","s3://nyc-tlc/misc/shared*","s3://nyc-tlc/misc/uber*",' + '"s3://nyc-tlc/misc/*.html","s3://nyc-tlc/misc/*.zip","s3://nyc-tlc/misc/FOIL_*"]', + 'objectCount': '1', 'recordCount': '342', 'sizeKey': '12322', 'skip.header.line.count': '1', + 'typeOfData': 'file' + }, + 'StoredAsSubDirectories': False + }, + 'PartitionKeys': [], 'TableType': 'EXTERNAL_TABLE', + 'Parameters': { + 'CrawlerSchemaDeserializerVersion': '1.0', 'CrawlerSchemaSerializerVersion': '1.0', + 'UPDATED_BY_CRAWLER': 'TaxiCrawler', 'areColumnsQuoted': 'false', 'averageRecordSize': '36', + 'classification': 'csv', 'columnsOrdered': 'true', 'compressionType': 'none', 'delimiter': ',', + 'exclusions': '["s3://nyc-tlc/misc/*foil*","s3://nyc-tlc/misc/shared*","s3://nyc-tlc/misc/uber*",' + '"s3://nyc-tlc/misc/*.html","s3://nyc-tlc/misc/*.zip","s3://nyc-tlc/misc/FOIL_*"]', + 'objectCount': '1', 'recordCount': '342', 'sizeKey': '12322', 'skip.header.line.count': '1', + 'typeOfData': 'file' + } + } + + updated_table_params = GlueStore.update_table_params(table_params, updated_columns) + self.assertEqual(updated_table_params, expected_table_params) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_models.py b/tests/test_models.py index 421d60f..0dd51e4 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -1,14 +1,15 @@ +from argparse import Namespace from unittest import TestCase -from shutil import rmtree import logging import sqlite3 import pytest -from piicatcher.orm.models import * -from piicatcher.db.explorer import Explorer, SqliteExplorer -from piicatcher.db.metadata import Schema, Table, Column +from piicatcher.store.db import * +from piicatcher.explorer.databases import SqliteExplorer +from piicatcher.explorer.explorer import Explorer +from piicatcher.explorer.metadata import Schema, Table, Column from piicatcher.piitypes import PiiTypes -from piicatcher.orm.models import Store +from piicatcher.store.db import DbStore logging.basicConfig(level=logging.DEBUG) @@ -42,21 +43,8 @@ def _open_connection(self): def _get_catalog_query(self): pass - def _load_catalog(self): - pass - - def set_schema(self, schema): - self._schemas = [schema] - - -class TestStore(TestCase): - sqlite_path = 'file::memory:?cache=shared' - - @classmethod - def setUpClass(cls): - init_test(cls.sqlite_path) - schema = Schema("test_store") - + @staticmethod + def get_no_pii_table(): no_pii_table = Table("test_store", "no_pii") no_pii_a = Column("a") no_pii_b = Column("b") @@ -64,8 +52,10 @@ def setUpClass(cls): no_pii_table.add(no_pii_a) no_pii_table.add(no_pii_b) - schema.add(no_pii_table) + return no_pii_table + @staticmethod + def 
get_partial_pii_table(): partial_pii_table = Table("test_store", "partial_pii") partial_pii_a = Column("a") partial_pii_a.add_pii_type(PiiTypes.PHONE) @@ -74,8 +64,10 @@ def setUpClass(cls): partial_pii_table.add(partial_pii_a) partial_pii_table.add(partial_pii_b) - schema.add(partial_pii_table) + return partial_pii_table + @staticmethod + def get_full_pii_table(): full_pii_table = Table("test_store", "full_pii") full_pii_a = Column("a") full_pii_a.add_pii_type(PiiTypes.PHONE) @@ -86,12 +78,27 @@ def setUpClass(cls): full_pii_table.add(full_pii_a) full_pii_table.add(full_pii_b) - schema.add(full_pii_table) + return full_pii_table + + def _load_catalog(self): + schema = Schema("test_store") + schema.add(MockExplorer.get_no_pii_table()) + schema.add(MockExplorer.get_partial_pii_table()) + schema.add(MockExplorer.get_full_pii_table()) + + self._schemas = [schema] + + +class TestStore(TestCase): + sqlite_path = 'file::memory:?cache=shared' + + @classmethod + def setUpClass(cls): + init_test(cls.sqlite_path) - explorer = MockExplorer() - explorer.set_schema(schema) + explorer = MockExplorer(Namespace(config_file=None)) - Store.save_schemas(explorer) + DbStore.save_schemas(explorer) @classmethod def tearDownClass(cls):