Skip to content

Commit

Permalink
new: Glue support and Postgres tests
Browse files Browse the repository at this point in the history
Add support for writing PII metadata as parameters in the Glue catalog.

The commit also contains the following major changes:
- Reorganize output format options to add glue.
- Rename orm to store.
- Rename models to db.
- Remove a couple of warnings in tests.
- Refactor explorers and add glue store.
- Enable postgres tests
  • Loading branch information
vrajat authored Dec 10, 2019
1 parent edaef55 commit 35aab78
Show file tree
Hide file tree
Showing 24 changed files with 757 additions and 398 deletions.
19 changes: 19 additions & 0 deletions .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,30 @@ jobs:
environment: # environment variables for primary container
PIPENV_VENV_IN_PROJECT: true

# Specify service dependencies here if necessary
# CircleCI maintains a library of pre-built images
# documented at https://circleci.com/docs/2.0/circleci-images/
- image: circleci/postgres:12.0-alpine-ram
environment:
POSTGRES_USER: piiuser
POSTGRES_PASSWORD: p11secret
POSTGRES_DB: piidb

working_directory: ~/repo

steps:
- checkout

- run:
name: install dockerize
command: wget https://github.com/jwilder/dockerize/releases/download/$DOCKERIZE_VERSION/dockerize-linux-amd64-$DOCKERIZE_VERSION.tar.gz && sudo tar -C /usr/local/bin -xzvf dockerize-linux-amd64-$DOCKERIZE_VERSION.tar.gz && rm dockerize-linux-amd64-$DOCKERIZE_VERSION.tar.gz
environment:
DOCKERIZE_VERSION: v0.3.0

- run:
name: Wait for db
command: dockerize -wait tcp://localhost:5432 -timeout 1m

# Download and cache dependencies
- restore_cache:
key: deps9-{{ .Branch }}-{{ checksum "Pipfile.lock" }}
Expand Down
14 changes: 3 additions & 11 deletions piicatcher/command_line.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,9 @@
import argparse
import yaml
import logging

from piicatcher.db.aws import AthenaExplorer
from piicatcher.db.explorer import Explorer
from piicatcher.files.explorer import parser as files_parser

from piicatcher.config import set_global_config
from piicatcher.orm.models import init
from piicatcher.explorer.aws import AthenaExplorer
from piicatcher.explorer.explorer import Explorer
from piicatcher.explorer.files import parser as files_parser


def get_parser(parser_cls=argparse.ArgumentParser):
Expand All @@ -25,10 +21,6 @@ def get_parser(parser_cls=argparse.ArgumentParser):

def dispatch(ns):
logging.basicConfig(level=getattr(logging, ns.log_level.upper()))
if ns.config_file is not None:
with open(ns.config_file, 'r') as stream:
set_global_config(yaml.load(stream))
init()

ns.func(ns)

Expand Down
6 changes: 0 additions & 6 deletions piicatcher/config.py

This file was deleted.

File renamed without changes.
33 changes: 20 additions & 13 deletions piicatcher/db/aws.py → piicatcher/explorer/aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

import pyathena

from piicatcher.db.explorer import Explorer
from piicatcher.explorer.explorer import Explorer
from piicatcher.store.glue import GlueStore


class AthenaExplorer(Explorer):
Expand All @@ -20,18 +21,14 @@ class AthenaExplorer(Explorer):
_select_query_template = "select {column_list} from {schema_name}.{table_name} limit 10"
_count_query = "select count(*) from {schema_name}.{table_name}"

def __init__(self, access_key, secret_key, staging_dir, region_name):
super(AthenaExplorer, self).__init__()
self._access_key = access_key
self._secret_key = secret_key
self._staging_dir = staging_dir
self._region_name = region_name
def __init__(self, ns):
super(AthenaExplorer, self).__init__(ns)
self.config = ns

@classmethod
def factory(cls, ns):
logging.debug("AWS Dispatch entered")
explorer = AthenaExplorer(ns.access_key, ns.secret_key,
ns.staging_dir, ns.region)
explorer = AthenaExplorer(ns)
return explorer

@classmethod
Expand All @@ -46,15 +43,25 @@ def parser(cls, sub_parsers):
help="S3 Staging Directory for Athena results")
sub_parser.add_argument("-r", "--region", required=True,
help="AWS Region")
sub_parser.add_argument("-f", "--output-format", choices=["ascii_table", "json", "db", "glue"],
default="ascii_table",
help="Choose output format type")

cls.scan_options(sub_parser)
sub_parser.set_defaults(func=AthenaExplorer.dispatch)

@classmethod
def output(cls, ns, explorer):
    """Emit scan results in the format selected by ``ns.output_format``.

    The "glue" format is handled here by handing the explorer to
    ``GlueStore.save_schemas``; every other format is delegated to the
    parent class implementation.
    """
    if ns.output_format != "glue":
        super(AthenaExplorer, cls).output(ns, explorer)
        return

    GlueStore.save_schemas(explorer)

def _open_connection(self):
return pyathena.connect(aws_access_key_id=self._access_key,
aws_secret_access_key=self._secret_key,
s3_staging_dir=self._staging_dir,
region_name=self._region_name)
return pyathena.connect(aws_access_key_id=self.config.access_key,
aws_secret_access_key=self.config.secret_key,
s3_staging_dir=self.config.staging_dir,
region_name=self.config.region)

def _get_catalog_query(self):
return self._catalog_query
Expand Down
260 changes: 260 additions & 0 deletions piicatcher/explorer/databases.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,260 @@
import sqlite3
from abc import abstractmethod

import pymysql
import psycopg2
import pymssql
import cx_Oracle

import logging

from piicatcher.explorer.explorer import Explorer


class RelDbExplorer(Explorer):
    """Common base for explorers that connect with host, user, password and port.

    Subclasses provide ``default_port``; ``factory`` picks the concrete
    explorer class from ``ns.connection_type``.
    """

    def __init__(self, ns):
        """Copy connection settings from the parsed arguments ``ns``.

        ``ns.port`` is optional: when the attribute is absent or ``None``
        the subclass's ``default_port`` is used, otherwise the value is
        converted with ``int``.
        """
        super(RelDbExplorer, self).__init__(ns)
        self.host = ns.host
        self.user = ns.user
        self.password = ns.password

        requested_port = vars(ns).get('port')
        if requested_port is None:
            self.port = self.default_port
        else:
            self.port = int(requested_port)

    @property
    @abstractmethod
    def default_port(self):
        """Port used when ``ns`` does not carry one; supplied by subclasses."""

    @classmethod
    def factory(cls, ns):
        """Build the explorer matching ``ns.connection_type``.

        "postgres" and "redshift" share ``PostgreSQLExplorer``. An
        unrecognised connection type trips the assertion below.
        """
        logging.debug("Relational Db Factory entered")

        explorer_by_type = {
            "sqlite": SqliteExplorer,
            "mysql": MySQLExplorer,
            "postgres": PostgreSQLExplorer,
            "redshift": PostgreSQLExplorer,
            "sqlserver": MSSQLExplorer,
            "oracle": OracleExplorer,
        }
        explorer_cls = explorer_by_type.get(ns.connection_type)
        explorer = explorer_cls(ns) if explorer_cls is not None else None
        assert (explorer is not None)

        return explorer


class SqliteExplorer(Explorer):
    """Explorer for SQLite; ``ns.host`` is handed to ``sqlite3.connect``."""

    # Lists text-like columns (text / varchar* / char*) of every entry in
    # sqlite_master, with an empty schema name.
    _catalog_query = """
        SELECT
            "" as schema_name,
            m.name as table_name,
            p.name as column_name,
            p.type as data_type
        FROM
            sqlite_master AS m
        JOIN
            pragma_table_info(m.name) AS p
        WHERE
            p.type like 'text' or p.type like 'varchar%' or p.type like 'char%'
        ORDER BY
            m.name,
            p.name
    """

    _query_template = "select {column_list} from {table_name}"

    class CursorContextManager:
        """Context manager that yields a fresh cursor from a connection."""

        def __init__(self, connection):
            self.connection = connection

        def __enter__(self):
            self.cursor = self.connection.cursor()
            return self.cursor

        def __exit__(self, exc_type, exc_val, exc_tb):
            # NOTE(review): the cursor is not closed on exit and exceptions
            # are not suppressed - confirm this is intentional.
            pass

    def __init__(self, ns):
        """Remember ``ns.host`` as the value later passed to ``sqlite3.connect``."""
        super(SqliteExplorer, self).__init__(ns)
        self.host = ns.host

    @classmethod
    def factory(cls, ns):
        """Create a ``SqliteExplorer`` from the parsed arguments ``ns``."""
        logging.debug("Sqlite Factory entered")
        return SqliteExplorer(ns)

    def _open_connection(self):
        """Open and return a new sqlite3 connection to ``self.host``."""
        logging.debug("Sqlite connection string '{}'".format(self.host))
        return sqlite3.connect(self.host)

    def _get_catalog_query(self):
        """Return the SQL used to enumerate candidate columns."""
        return self._catalog_query

    def _get_context_manager(self):
        """Wrap the explorer's connection in a ``CursorContextManager``."""
        connection = self.get_connection()
        return SqliteExplorer.CursorContextManager(connection)

    @classmethod
    def _get_select_query(cls, schema_name, table_name, column_list):
        """Build the select statement for a table; the schema name is not used."""
        column_names = [column.get_name() for column in column_list]
        return cls._query_template.format(
            column_list=",".join(column_names),
            table_name=table_name.get_name()
        )


class MySQLExplorer(RelDbExplorer):
    """Explorer for MySQL servers, connecting through ``pymysql``."""

    # Lists char/varchar/text columns outside MySQL's own system schemas.
    _catalog_query = """
        SELECT
            TABLE_SCHEMA, TABLE_NAME, COLUMN_NAME, DATA_TYPE
        FROM
            INFORMATION_SCHEMA.COLUMNS
        WHERE
            TABLE_SCHEMA NOT IN ('information_schema', 'performance_schema', 'sys', 'mysql')
            AND DATA_TYPE RLIKE 'char.*|varchar.*|text'
        ORDER BY table_schema, table_name, column_name
    """

    def __init__(self, ns):
        """Initialise host, user, password and port from ``ns`` via the base class."""
        super(MySQLExplorer, self).__init__(ns)

    @property
    def default_port(self):
        """Port used when none is supplied.

        MySQL's standard port is 3306; the previous value 3036 had two
        digits transposed, so connections without an explicit port failed.
        """
        return 3306

    def _open_connection(self):
        """Open and return a new ``pymysql`` connection with the stored settings."""
        return pymysql.connect(host=self.host,
                               port=self.port,
                               user=self.user,
                               password=self.password)

    def _get_catalog_query(self):
        """Return the SQL used to enumerate candidate columns."""
        return self._catalog_query


class PostgreSQLExplorer(RelDbExplorer):
    """Explorer for PostgreSQL-compatible servers, connecting through ``psycopg2``."""

    # Lists columns whose data type contains "char" or "text", skipping the
    # information_schema and pg_catalog schemas.
    _catalog_query = """
        SELECT
            TABLE_SCHEMA, TABLE_NAME, COLUMN_NAME, DATA_TYPE
        FROM
            INFORMATION_SCHEMA.COLUMNS
        WHERE
            TABLE_SCHEMA NOT IN ('information_schema', 'pg_catalog')
            AND DATA_TYPE SIMILAR TO '%char%|%text%'
        ORDER BY table_schema, table_name, column_name
    """

    def __init__(self, ns):
        """Initialise connection settings; fall back to ``default_database``.

        The fallback previously read ``self.database``, which nothing in
        this class hierarchy assigns beforehand, so omitting the database
        raised ``AttributeError`` instead of using the default.
        """
        super(PostgreSQLExplorer, self).__init__(ns)
        self.database = self.default_database if ns.database is None else ns.database

    @property
    def default_database(self):
        """Database name used when ``ns.database`` is ``None``."""
        # NOTE(review): "public" is normally a PostgreSQL schema name rather
        # than a database name - confirm this default is intended.
        return "public"

    @property
    def default_port(self):
        """Port used when none is supplied."""
        return 5432

    def _open_connection(self):
        """Open and return a new ``psycopg2`` connection with the stored settings."""
        return psycopg2.connect(host=self.host,
                                port=self.port,
                                user=self.user,
                                password=self.password,
                                database=self.database)

    def _get_catalog_query(self):
        """Return the SQL used to enumerate candidate columns."""
        return self._catalog_query


class MSSQLExplorer(RelDbExplorer):
    """Explorer for Microsoft SQL Server, connecting through ``pymssql``."""

    # Lists columns whose data type contains "char".
    _catalog_query = """
        SELECT
            TABLE_SCHEMA, TABLE_NAME, COLUMN_NAME, DATA_TYPE
        FROM
            INFORMATION_SCHEMA.COLUMNS
        WHERE
            DATA_TYPE LIKE '%char%'
        ORDER BY TABLE_SCHEMA, table_name, ordinal_position
    """

    _sample_query_template = "SELECT TOP 10 * FROM {schema_name}.{table_name} TABLESAMPLE (1000 ROWS)"

    def __init__(self, ns):
        """Initialise connection settings; fall back to ``default_database``.

        The fallback previously read ``self.database``, which nothing in
        this class hierarchy assigns beforehand, so omitting the database
        raised ``AttributeError`` instead of using the default.
        """
        super(MSSQLExplorer, self).__init__(ns)
        self.database = self.default_database if ns.database is None else ns.database

    @property
    def default_database(self):
        """Database name used when ``ns.database`` is ``None``."""
        # NOTE(review): "public" looks copied from the PostgreSQL explorer -
        # confirm it is a sensible default for SQL Server.
        return "public"

    @property
    def default_port(self):
        """Port used when none is supplied."""
        return 1433

    def _open_connection(self):
        """Open and return a new ``pymssql`` connection with the stored settings."""
        return pymssql.connect(host=self.host,
                               port=self.port,
                               user=self.user,
                               password=self.password,
                               database=self.database)

    def _get_catalog_query(self):
        """Return the SQL used to enumerate candidate columns."""
        return self._catalog_query

    @classmethod
    def _get_sample_query(cls, schema_name, table_name, column_list):
        """Build the sampling query for ``schema_name.table_name``.

        The template selects ``*``, so the ``column_list`` value passed to
        ``format`` has no effect on the resulting SQL; it is still built to
        keep the call identical to the other explorers.
        """
        return cls._sample_query_template.format(
            column_list=",".join([col.get_name() for col in column_list]),
            schema_name=schema_name.get_name(),
            table_name=table_name.get_name()
        )


class OracleExplorer(RelDbExplorer):
    """Explorer for Oracle databases, connecting through ``cx_Oracle``."""

    # Lists the current user's columns whose data type contains "CHAR"; the
    # configured database name is substituted as the first selected value.
    # NOTE(review): unlike the other catalog queries this one does not
    # select DATA_TYPE - confirm the consumer copes with three columns.
    _catalog_query = """
        SELECT
            '{db}', TABLE_NAME, COLUMN_NAME
        FROM
            USER_TAB_COLUMNS
        WHERE UPPER(DATA_TYPE) LIKE '%CHAR%'
        ORDER BY TABLE_NAME, COLUMN_ID
    """

    _sample_query_template = "select {column_list} from {table_name} sample(5)"
    _select_query_template = "select {column_list} from {table_name}"
    _count_query = "select count(*) from {table_name}"

    def __init__(self, ns):
        """Initialise connection settings and remember ``ns.database``."""
        super(OracleExplorer, self).__init__(ns)
        self.database = ns.database

    @property
    def default_port(self):
        """Port used when none is supplied."""
        return 1521

    def _open_connection(self):
        """Open and return a new ``cx_Oracle`` connection to host:port/database."""
        dsn = "%s:%d/%s" % (self.host, self.port, self.database)
        return cx_Oracle.connect(self.user, self.password, dsn)

    def _get_catalog_query(self):
        """Return the catalog SQL with the database name filled in."""
        return self._catalog_query.format(db=self.database)

    @classmethod
    def _get_select_query(cls, schema_name, table_name, column_list):
        """Build the full select for a table; the schema name is not used."""
        column_names = [column.get_name() for column in column_list]
        return cls._select_query_template.format(
            column_list=",".join(column_names),
            table_name=table_name.get_name()
        )

    @classmethod
    def _get_sample_query(cls, schema_name, table_name, column_list):
        """Build the ``sample(5)`` select for a table; the schema name is not used."""
        column_names = [column.get_name() for column in column_list]
        return cls._sample_query_template.format(
            column_list=",".join(column_names),
            table_name=table_name.get_name()
        )

    @classmethod
    def _get_count_query(cls, schema_name, table_name):
        """Build the row-count query for a table; the schema name is not used."""
        name = table_name.get_name()
        return cls._count_query.format(table_name=name)
Loading

0 comments on commit 35aab78

Please sign in to comment.