Skip to content

Commit

Permalink
fix:dev:Process only text columns
Browse files Browse the repository at this point in the history
Also use cached catalog for Sqlite and rearrange code
test code to test all explorers.
  • Loading branch information
Rajat Venkatesh authored and vrajat committed Mar 29, 2019
1 parent 8fc1261 commit 6eb85ea
Show file tree
Hide file tree
Showing 2 changed files with 244 additions and 100 deletions.
174 changes: 77 additions & 97 deletions piicatcher/dbexplorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@
class Explorer(ABC):
query_template = "select {column_list} from {schema_name}.{table_name}"

def __init__(self, conn_string):
self.conn_string = conn_string
def __init__(self):
self._connection = None
self._schemas = None
self._cache_ts = None

def __enter__(self):
return self
Expand All @@ -35,18 +35,6 @@ def close_connection(self):
if self._connection is not None:
self._connection.close()

@abstractmethod
def get_schemas(self):
pass

@abstractmethod
def get_tables(self, schema_name):
pass

@abstractmethod
def get_columns(self, schema_name, table_name):
pass

def scan(self):
for schema in self.get_schemas():
schema.scan(self._generate_rows)
Expand All @@ -61,101 +49,33 @@ def get_tabular(self):

return tabular

def _generate_rows(self, schema_name, table_name, column_list):
query = self.query_template.format(
def _get_select_query(self, schema_name, table_name, column_list):
return self.query_template.format(
column_list=",".join([col.get_name() for col in column_list]),
schema_name=schema_name.get_name(),
table_name=table_name.get_name()
)

def _generate_rows(self, schema_name, table_name, column_list):
query = self._get_select_query(schema_name, table_name, column_list)
logging.debug(query)
with self.get_connection().cursor() as cursor:
with self._get_context_manager() as cursor:
cursor.execute(query)
row = cursor.fetchone()
while row is not None:
yield row
row = cursor.fetchone()


class SqliteExplorer(Explorer):
pragma_query = """
SELECT
m.name as table_name,
p.name as column_name
FROM
sqlite_master AS m
JOIN
pragma_table_info(m.name) AS p
{where_clause}
ORDER BY
m.name,
p.cid
"""

def __init__(self, conn_string):
super(SqliteExplorer, self).__init__(conn_string)

def _open_connection(self):
logging.debug("Sqlite connection string '{}'".format(self.conn_string))
return sqlite3.connect(self.conn_string)

def get_schemas(self):
if self._schemas is None:
sch = Schema('main')
sch.tables.extend(self.get_tables(sch))
self._schemas = [sch]
return self._schemas

def get_tables(self, schema):
query = self.pragma_query.format(where_clause="")
logging.debug(query)
result_set = self.get_connection().execute(query)
tables = []
row = result_set.fetchone()
current_table = None
while row is not None:
if current_table is None:
current_table = Table(schema, row[0])
elif current_table.get_name() != row[0]:
tables.append(current_table)
current_table = Table(schema, row[0])
current_table.add(Column(row[1]))

row = result_set.fetchone()

if current_table is not None:
tables.append(current_table)

return tables

def get_columns(self, schema_name, table_name):
query = self.pragma_query.format(
where_clause="WHERE m.name = ?"
)
logging.debug(query)
logging.debug(table_name)
result_set = self.get_connection().execute(query, (table_name,))
columns = []
row = result_set.fetchone()
while row is not None:
columns.append(Column(row[1]))
row = result_set.fetchone()

return columns


class CachedExplorer(Explorer, ABC):

def __init__(self):
super(CachedExplorer, self).__init__("")
self._cache_ts = None

@abstractmethod
def _get_catalog_query(self):
pass

def _get_context_manager(self):
return self.get_connection().cursor()

def _load_catalog(self):
if self._cache_ts is None or self._cache_ts < datetime.now() - timedelta(minutes=10):
with self.get_connection().cursor() as cursor:
with self._get_context_manager() as cursor:
cursor.execute(self._get_catalog_query())
self._schemas = []

Expand All @@ -177,7 +97,7 @@ def _load_catalog(self):
elif current_table.get_name() != row[1]:
current_schema.tables.append(current_table)
current_table = Table(current_schema, row[1])
current_table._columns.append(Column(row[2]))
current_table.add(Column(row[2]))

row = cursor.fetchone()

Expand All @@ -194,6 +114,8 @@ def get_schemas(self):
def get_tables(self, schema_name):
self._load_catalog()
for s in self._schemas:
print(schema_name)
print(s.get_name())
if s.get_name() == schema_name:
return s.tables
raise ValueError("{} schema not found".format(schema_name))
Expand All @@ -208,13 +130,68 @@ def get_columns(self, schema_name, table_name):
raise ValueError("{} table not found".format(table_name))


class MySQLExplorer(CachedExplorer):
class SqliteExplorer(Explorer):
_catalog_query = """
SELECT
"" as schema_name,
m.name as table_name,
p.name as column_name,
p.type as data_type
FROM
sqlite_master AS m
JOIN
pragma_table_info(m.name) AS p
WHERE
p.type like 'text' or p.type like 'varchar%' or p.type like 'char%'
ORDER BY
m.name,
p.name
"""

_query_template = "select {column_list} from {table_name}"

class CursorContextManager():
def __init__(self, connection):
self.connection = connection

def __enter__(self):
self.cursor = self.connection.cursor()
return self.cursor

def __exit__(self, exc_type, exc_val, exc_tb):
pass

def __init__(self, conn_string):
super(SqliteExplorer, self).__init__()
self.conn_string = conn_string

def _open_connection(self):
logging.debug("Sqlite connection string '{}'".format(self.conn_string))
return sqlite3.connect(self.conn_string)

def _get_catalog_query(self):
return self._catalog_query

def _get_context_manager(self):
return SqliteExplorer.CursorContextManager(self.get_connection())

def _get_select_query(self, schema_name, table_name, column_list):
return self._query_template.format(
column_list=",".join([col.get_name() for col in column_list]),
table_name=table_name.get_name()
)


class MySQLExplorer(Explorer):
_catalog_query = """
SELECT
TABLE_SCHEMA, TABLE_NAME, COLUMN_NAME, DATA_TYPE
FROM
INFORMATION_SCHEMA.COLUMNS
WHERE TABLE_SCHEMA NOT IN ('information_schema', 'performance_schema', 'sys', 'mysql')
WHERE
TABLE_SCHEMA NOT IN ('information_schema', 'performance_schema', 'sys', 'mysql')
AND DATA_TYPE RLIKE 'char.*|varchar.*|text'
ORDER BY table_schema, table_name, column_name
"""

def __init__(self, host, user, password):
Expand All @@ -232,13 +209,16 @@ def _get_catalog_query(self):
return self._catalog_query


class PostgreSQLExplorer(CachedExplorer):
class PostgreSQLExplorer(Explorer):
_catalog_query = """
SELECT
TABLE_SCHEMA, TABLE_NAME, COLUMN_NAME, DATA_TYPE
FROM
INFORMATION_SCHEMA.COLUMNS
WHERE TABLE_SCHEMA NOT IN ('information_schema', 'pg_catalog')
WHERE
TABLE_SCHEMA NOT IN ('information_schema', 'pg_catalog')
AND DATA_TYPE SIMILAR TO '%char%|%text%'
ORDER BY table_schema, table_name, column_name
"""

def __init__(self, host, user, password, database='public'):
Expand Down
Loading

0 comments on commit 6eb85ea

Please sign in to comment.