diff --git a/orator/connections/mysql_connection.py b/orator/connections/mysql_connection.py index e37d4b06..ab3a2534 100644 --- a/orator/connections/mysql_connection.py +++ b/orator/connections/mysql_connection.py @@ -2,7 +2,7 @@ from ..utils import decode from ..utils import PY2 -from .connection import Connection +from .connection import Connection, recoverable from ..query.grammars.mysql_grammar import MySQLQueryGrammar from ..query.processors.mysql_processor import MySQLQueryProcessor from ..schema.grammars import MySQLSchemaGrammar @@ -37,6 +37,7 @@ def get_default_schema_grammar(self): def get_schema_manager(self): return MySQLSchemaManager(self) + @recoverable def begin_transaction(self): self._connection.autocommit(False) diff --git a/tests/connections/test_mysql_connection.py b/tests/connections/test_mysql_connection.py index 2c9a2f9f..f00c42ed 100644 --- a/tests/connections/test_mysql_connection.py +++ b/tests/connections/test_mysql_connection.py @@ -1,6 +1,9 @@ # -*- coding: utf-8 -*- +from flexmock import flexmock + from .. import OratorTestCase +from .. import mock from orator.connections.mysql_connection import MySQLConnection @@ -20,3 +23,13 @@ def test_marker_use_qmark_false(self): connection = MySQLConnection(None, "database", "", {"use_qmark": False}) self.assertIsNone(connection.get_marker()) + + def test_recover_if_caused_by_lost_connection_is_called(self): + connection = flexmock(MySQLConnection(None, "database")) + connection._connection = mock.Mock() + connection._connection.autocommit.side_effect = Exception("lost connection") + + connection.should_receive("_recover_if_caused_by_lost_connection").once() + connection.should_receive("reconnect") + + connection.begin_transaction()