Skip to content

Commit

Permalink
fix:dev:Use generators in dbmetadata classes
Browse files Browse the repository at this point in the history
Previously, connections and cursors were passed to DbMetadata classes
such as tables and columns. This led to unclean interfaces and unit
test dependencies.
The move to a generator has helped in cleaning up the tech debt.
  • Loading branch information
Rajat Venkatesh authored and vrajat committed Mar 29, 2019
1 parent 5574fc4 commit 9614bb5
Show file tree
Hide file tree
Showing 5 changed files with 109 additions and 73 deletions.
2 changes: 1 addition & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ install:
- pip install -e .
- python -m spacy download en_core_web_sm
script:
- pytest --cov=./
- pytest -m "not dbtest" --cov=./
after_success:
- codecov
deploy:
Expand Down
22 changes: 19 additions & 3 deletions piicatcher/dbexplorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@


class Explorer(ABC):
query_template = "select {column_list} from {schema_name}.{table_name}"

def __init__(self, conn_string):
self.conn_string = conn_string
self._connection = None
Expand Down Expand Up @@ -47,7 +49,7 @@ def get_columns(self, schema_name, table_name):

def scan(self):
for schema in self.get_schemas():
schema.scan(self.get_connection().cursor())
schema.scan(self._generate_rows)

def get_tabular(self):
tabular = []
Expand All @@ -59,6 +61,20 @@ def get_tabular(self):

return tabular

def _generate_rows(self, schema_name, table_name, column_list):
    """Lazily yield the rows of one table, restricted to the given columns.

    A ``select`` statement is built from ``self.query_template``, executed
    on a fresh cursor obtained from ``self.get_connection()``, and the
    result is streamed one row at a time with ``fetchone`` so the whole
    table is never materialised at once.

    Args:
        schema_name: Object whose ``get_name()`` returns the schema name
            (despite the parameter name, this is not a plain string).
        table_name: Object whose ``get_name()`` returns the table name.
        column_list: Iterable of objects whose ``get_name()`` returns a
            column name; the names are joined with ``,`` in the query.

    Yields:
        Each row exactly as returned by ``cursor.fetchone()``.
    """
    # Imported locally so this method does not depend on an edit to the
    # module's import block.
    from contextlib import closing

    # NOTE(review): identifiers are interpolated unquoted; presumably they
    # always originate from the database catalog -- confirm against callers.
    query = self.query_template.format(
        column_list=",".join(col.get_name() for col in column_list),
        schema_name=schema_name.get_name(),
        table_name=table_name.get_name(),
    )
    logging.debug(query)

    # closing() only requires the cursor to have close(), which every DB-API
    # cursor provides. Using the cursor itself as a context manager fails on
    # drivers such as sqlite3 whose cursors do not implement __enter__/__exit__.
    with closing(self.get_connection().cursor()) as cursor:
        cursor.execute(query)
        # iter(callable, sentinel) keeps calling fetchone() until it
        # returns None, i.e. until the result set is exhausted.
        for row in iter(cursor.fetchone, None):
            yield row


class SqliteExplorer(Explorer):
pragma_query = """
Expand Down Expand Up @@ -157,7 +173,7 @@ def _load_catalog(self):
elif current_table.get_name() != row[1]:
current_schema.tables.append(current_table)
current_table = Table(current_schema, row[1])
current_table.columns.append(Column(row[2]))
current_table._columns.append(Column(row[2]))

row = cursor.fetchone()

Expand All @@ -183,7 +199,7 @@ def get_columns(self, schema_name, table_name):
tables = self.get_tables(schema_name)
for t in tables:
if t.get_name() == table_name:
return t.columns
return t.get_columns()

raise ValueError("{} table not found".format(table_name))

Expand Down
39 changes: 16 additions & 23 deletions piicatcher/dbmetadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,44 +35,37 @@ def add(self, table):
def get_tables(self):
return self.tables

def scan(self, context):
def scan(self, generator):
for table in self.tables:
table.scan(context)
table.scan(generator)
logging.debug("{} has {}".format(table.get_name(), table.get_pii_types()))
[self._pii.add(p) for p in table.get_pii_types()]

logging.debug("{} has {}".format(self, self._pii))


class Table(NamedObject):
query_template = "select {column_list} from {schema_name}.{table_name}"

def __init__(self, schema, name):
super(Table, self).__init__(name)
self._schema = schema
self.columns = []
self._columns = []

def add(self, col):
self.columns.append(col)
self._columns.append(col)

def get_columns(self):
return self.columns

def scan(self, context):
query = self.query_template.format(
column_list=",".join([col.get_name() for col in self.columns]),
schema_name=self._schema.get_name(),
table_name=self.get_name()
)
logging.debug(query)
context.execute(query)
row = context.fetchone()
while row is not None:
for col, val in zip(self.columns, row):
return self._columns

def scan(self, generator):
for row in generator(
column_list=self._columns,
schema_name=self._schema,
table_name=self
):
for col, val in zip(self._columns, row):
col.scan(val)
row = context.fetchone()

for col in self.columns:
for col in self._columns:
[self._pii.add(p) for p in col.get_pii_types()]

logging.debug(self._pii)
Expand All @@ -82,8 +75,8 @@ class Column(NamedObject):
def __init__(self, name):
super(Column, self).__init__(name)

def scan(self, context):
def scan(self, data):
for scanner in [RegexScanner(), NERScanner()]:
[self._pii.add(pii) for pii in scanner.scan(context)]
[self._pii.add(pii) for pii in scanner.scan(data)]

logging.debug(self._pii)
50 changes: 4 additions & 46 deletions tests/test_dbexplorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def setUp(self):

schema = Schema('s1')
table = Table(schema, 't1')
table.columns = [col1, col2]
table._columns = [col1, col2]

schema = Schema('testSchema')
schema.tables = [table]
Expand Down Expand Up @@ -68,48 +68,6 @@ def test_tables(self):
names = [tbl.get_name() for tbl in self.explorer.get_tables(self.get_test_schema())]
self.assertEqual(sorted(['no_pii', 'partial_pii', 'full_pii']), sorted(names))

def test_negative_scan_column(self):
col = Column('col')
col.scan('abc')
self.assertFalse(col.has_pii())

def test_positive_scan_column(self):
col = Column('col')
col.scan('Jonathan Smith')
self.assertTrue(col.has_pii())

def test_no_pii_table(self):
schema = Schema(self.get_test_schema())
table = Table(schema, 'no_pii')
table.add(Column('a'))
table.add(Column('b'))

table.scan(self.explorer.get_connection().cursor())
self.assertFalse(table.has_pii())

def test_partial_pii_table(self):
schema = Schema(self.get_test_schema())
table = Table(schema, 'partial_pii')
table.add(Column('a'))
table.add(Column('b'))

table.scan(self.explorer.get_connection().cursor())
self.assertTrue(table.has_pii())
cols = table.get_columns()
self.assertTrue(cols[0].has_pii())
self.assertFalse(cols[1].has_pii())

def test_full_pii_table(self):
schema = Schema(self.get_test_schema())
table = Table(schema, 'full_pii')
table.add(Column('name'))
table.add(Column('location'))

table.scan(self.explorer.get_connection().cursor())
self.assertTrue(table.has_pii())
cols = table.get_columns()
self.assertTrue(cols[0].has_pii())
self.assertTrue(cols[1].has_pii())

def test_scan_dbexplorer(self):
self.explorer.scan()
Expand All @@ -118,7 +76,7 @@ def test_scan_dbexplorer(self):


@pytest.mark.usefixtures("temp_sqlite")
@pytest.mark.skip(reason="Old version of sqlite in Travis")
@pytest.mark.dbtest
class SqliteTest(TestCase):

@pytest.fixture(scope="class")
Expand Down Expand Up @@ -149,7 +107,7 @@ def test_schema(self):


@pytest.mark.usefixtures("create_tables")
@pytest.mark.skip(reason="TODO Setup MySQL through docker for testing")
@pytest.mark.dbtest
class MySQLExplorerTest(CommonExplorerTestCases.CommonExplorerTests):
pii_db_query = """
CREATE DATABASE IF NOT EXISTS pii_db;
Expand Down Expand Up @@ -204,7 +162,7 @@ def get_test_schema(self):


@pytest.mark.usefixtures("create_tables")
@pytest.mark.skip(reason="TODO Setup Postgres through docker for testing")
@pytest.mark.dbtest
class PostgresExplorerTest(CommonExplorerTestCases.CommonExplorerTests):
pii_db_drop = """
DROP TABLE full_pii;
Expand Down
69 changes: 69 additions & 0 deletions tests/test_dbmetadata.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
from unittest import TestCase
from piicatcher.dbmetadata import Column, Table, Schema


class DbMetadataTests(TestCase):
    """Database-free tests for ``Column`` and ``Table`` scanning.

    Rows are served from the in-memory ``data`` fixture through
    ``data_generator`` rather than from a live connection or cursor.
    """

    # Canned rows keyed by table name; every tuple is one row whose values
    # line up positionally with the columns added to the table under test.
    data = {
        "no_pii": [
            ('abc', 'def'),
            ('xsfr', 'asawe'),
        ],
        "partial_pii": [
            ('917-908-2234', 'plkj'),
            ('215-099-2234', 'sfrf'),
        ],
        "full_pii": [
            ('Jonathan Smith', 'Virginia'),
            ('Chase Ryan', 'Chennai'),
        ],
    }

    @staticmethod
    def data_generator(schema_name, table_name, column_list):
        """Yield the canned rows registered under ``table_name.get_name()``.

        ``schema_name`` and ``column_list`` are accepted but unused; only
        the table's name selects which fixture rows are produced.
        """
        yield from DbMetadataTests.data[table_name.get_name()]

    def _scanned_table(self, table_name, column_names):
        """Build a table in schema ``public`` with the named columns and scan it."""
        tbl = Table(Schema('public'), table_name)
        for column_name in column_names:
            tbl.add(Column(column_name))
        tbl.scan(self.data_generator)
        return tbl

    def test_negative_scan_column(self):
        """A value with no PII leaves the column unflagged."""
        column = Column('col')
        column.scan('abc')
        self.assertFalse(column.has_pii())

    def test_positive_scan_column(self):
        """A person's name flags the column as containing PII."""
        column = Column('col')
        column.scan('Jonathan Smith')
        self.assertTrue(column.has_pii())

    def test_no_pii_table(self):
        """A table whose rows hold no PII is reported as clean."""
        tbl = self._scanned_table('no_pii', ('a', 'b'))
        self.assertFalse(tbl.has_pii())

    def test_partial_pii_table(self):
        """Only the first column of ``partial_pii`` is flagged."""
        tbl = self._scanned_table('partial_pii', ('a', 'b'))
        self.assertTrue(tbl.has_pii())

        first, second = tbl.get_columns()[0], tbl.get_columns()[1]
        self.assertTrue(first.has_pii())
        self.assertFalse(second.has_pii())

    def test_full_pii_table(self):
        """Both columns of ``full_pii`` are flagged."""
        tbl = self._scanned_table('full_pii', ('name', 'location'))
        self.assertTrue(tbl.has_pii())

        first, second = tbl.get_columns()[0], tbl.get_columns()[1]
        self.assertTrue(first.has_pii())
        self.assertTrue(second.has_pii())

0 comments on commit 9614bb5

Please sign in to comment.