diff --git a/orator/connections/connection.py b/orator/connections/connection.py index 469b5989..91af470d 100644 --- a/orator/connections/connection.py +++ b/orator/connections/connection.py @@ -18,6 +18,17 @@ connection_logger = logging.getLogger("orator.connection") +def recoverable(wrapped): + @wraps(wrapped) + def _recoverable(self, *args, **kwargs): + self._reconnect_if_missing_connection() + try: + result = wrapped(self, *args, **kwargs) + except Exception as e: + result = self._recover_if_caused_by_lost_connection(e, wrapped, *args, **kwargs) + return result + return _recoverable + def run(wrapped): """ Special decorator encapsulating query method. @@ -356,6 +367,14 @@ def _try_again_if_caused_by_lost_connection( raise QueryException(query, bindings, e) + def _recover_if_caused_by_lost_connection(self, e, callback, *args, **kwargs): + if self._caused_by_lost_connection(e): + self.reconnect() + + return callback(self, *args, **kwargs) + + raise e + def _caused_by_lost_connection(self, e): message = str(e).lower()