Skip to content

Commit

Permalink
presto: Improve error handling, fix duplicate queries when things fail, and avoid running more than one query when writing to a table.
Browse files Browse the repository at this point in the history
  • Loading branch information
matthewwardrop committed Jun 1, 2024
1 parent 98c66e1 commit 2afcb64
Showing 1 changed file with 14 additions and 26 deletions.
40 changes: 14 additions & 26 deletions omniduct/databases/presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,6 @@ def _init(
self.schema = schema
self.server_protocol = server_protocol
self.source = source
self.__presto = None
self.connection_fields += ("catalog", "schema")
self._requests_session = requests_session

Expand All @@ -103,18 +102,11 @@ def _connect(self):

@override
def _is_connected(self):
try:
return self.__presto is not None
except: # pylint: disable=bare-except
return False
return True

@override
def _disconnect(self):
logger.info("Disconnecting from Presto coordinator...")
try:
self.__presto.close()
except: # pylint: disable=bare-except
pass
self._sqlalchemy_engine = None
self._sqlalchemy_metadata = None
self._schemas = None # pylint: disable=attribute-defined-outside-init
Expand Down Expand Up @@ -162,10 +154,6 @@ def _execute(self, statement, cursor, wait, session_properties):
logger.progress(100, complete=True)
return cursor
except (DatabaseError, pandas.io.sql.DatabaseError) as e:
# Attempt to parse database error, before ultimately reraising the same
# exception, maintaining the full stacktrace.
exception, exception_args, traceback = sys.exc_info()

try:
message = e.args[0]
if isinstance(message, str):
Expand All @@ -191,12 +179,6 @@ def _execute(self, statement, cursor, wait, session_properties):
)
)

class ErrContext:
def __repr__(self):
return context

# logged twice so that both notebook and console users see the error context
exception_args.args = [exception_args, ErrContext()]
logger.error(context)
except: # pylint: disable=bare-except
logger.warn(
Expand All @@ -206,17 +188,14 @@ def __repr__(self):
)
)

if isinstance(exception, type):
exception = exception(exception_args)

raise exception.with_traceback(traceback)
raise

@override
def _query_to_table(self, statement, table, if_exists, **kwargs):
from pyhive.exc import DatabaseError

statements = []

if if_exists == "fail" and self.table_exists(table):
raise RuntimeError(f"Table {table} already exists!")
if if_exists == "replace":
statements.append(f"DROP TABLE IF EXISTS {table};\n")
elif if_exists == "append":
Expand All @@ -225,7 +204,16 @@ def _query_to_table(self, statement, table, if_exists, **kwargs):
)

statements.append(f"CREATE TABLE {table} AS ({statement})")
return self.execute("\n".join(statements), **kwargs)

try:
return self.execute("\n".join(statements), **kwargs)
except DatabaseError as e:
if (
isinstance(e.args, dict)
and e.args.get("errorName") == "TABLE_ALREADY_EXISTS"
):
raise RuntimeError(f"Table {table} already exists!") from e
raise

@override
def _dataframe_to_table(self, df, table, if_exists="fail", **kwargs):
Expand Down

0 comments on commit 2afcb64

Please sign in to comment.