From d6ec36be7c117e095449a92b86efeb2ae3d1604b Mon Sep 17 00:00:00 2001 From: Adam Ning Date: Wed, 20 Nov 2024 19:28:51 +0800 Subject: [PATCH] ENH: Refine connection exception handling (#111) Co-authored-by: qinxuye --- python/xoscar/backends/core.py | 91 +++++++++++++++++------------- python/xoscar/backends/pool.py | 40 +++++++------ python/xoscar/serialization/aio.py | 6 +- 3 files changed, 80 insertions(+), 57 deletions(-) diff --git a/python/xoscar/backends/core.py b/python/xoscar/backends/core.py index 5b79fc58..c8fb54ff 100644 --- a/python/xoscar/backends/core.py +++ b/python/xoscar/backends/core.py @@ -70,50 +70,61 @@ async def get_client(self, router: Router, dest_address: str) -> Client: return client async def _listen(self, client: Client): - while not client.closed: - try: + try: + while not client.closed: try: - message: _MessageBase = await client.recv() - except (EOFError, ConnectionError, BrokenPipeError): - # remote server closed, close client and raise ServerClosed try: - await client.close() - except (ConnectionError, BrokenPipeError): - # close failed, ignore it + message: _MessageBase = await client.recv() + except (EOFError, ConnectionError, BrokenPipeError) as e: + # AssertionError is from get_header + # remote server closed, close client and raise ServerClosed + logger.debug(f"{client.dest_address} close due to {e}") + try: + await client.close() + except (ConnectionError, BrokenPipeError): + # close failed, ignore it + pass + raise ServerClosed( + f"Remote server {client.dest_address} closed: {e}" + ) from None + future = self._client_to_message_futures[client].pop( + message.message_id + ) + if not future.done(): + future.set_result(message) + except DeserializeMessageFailed as e: + message_id = e.message_id + future = self._client_to_message_futures[client].pop(message_id) + future.set_exception(e.__cause__) # type: ignore + except Exception as e: # noqa: E722 # pylint: disable=bare-except + message_futures = self._client_to_message_futures[client] + self._client_to_message_futures[client] = dict() + for future in message_futures.values(): + future.set_exception(copy.copy(e)) + finally: + # message may have Ray ObjectRef, delete it early in case next loop doesn't run + # as soon as expected. + try: + del message + except NameError: pass - raise ServerClosed( - f"Remote server {client.dest_address} closed" - ) from None - future = self._client_to_message_futures[client].pop(message.message_id) - if not future.done(): - future.set_result(message) - except DeserializeMessageFailed as e: - message_id = e.message_id - future = self._client_to_message_futures[client].pop(message_id) - future.set_exception(e.__cause__) # type: ignore - except Exception as e: # noqa: E722 # pylint: disable=bare-except - message_futures = self._client_to_message_futures[client] - self._client_to_message_futures[client] = dict() - for future in message_futures.values(): - future.set_exception(copy.copy(e)) - finally: - # message may have Ray ObjectRef, delete it early in case next loop doesn't run - # as soon as expected. - try: - del message - except NameError: - pass - try: - del future - except NameError: - pass - await asyncio.sleep(0) + try: + del future + except NameError: + pass + await asyncio.sleep(0) - message_futures = self._client_to_message_futures[client] - self._client_to_message_futures[client] = dict() - error = ServerClosed(f"Remote server {client.dest_address} closed") - for future in message_futures.values(): - future.set_exception(copy.copy(error)) + message_futures = self._client_to_message_futures[client] + self._client_to_message_futures[client] = dict() + error = ServerClosed(f"Remote server {client.dest_address} closed") + for future in message_futures.values(): + future.set_exception(copy.copy(error)) + finally: + try: + await client.close() + except: # noqa: E722 # nosec # pylint: disable=bare-except + # ignore all error if fail to close at last + pass async def call_with_client( self, client: Client, message: _MessageBase, wait: bool = True diff --git a/python/xoscar/backends/pool.py b/python/xoscar/backends/pool.py index 52cb664e..93fd1645 100644 --- a/python/xoscar/backends/pool.py +++ b/python/xoscar/backends/pool.py @@ -551,23 +551,31 @@ async def _handle_ucx_meta_message( return False async def on_new_channel(self, channel: Channel): - while not self._stopped.is_set(): - try: - message = await channel.recv() - except EOFError: - # no data to read, check channel + try: + while not self._stopped.is_set(): try: - await channel.close() - except (ConnectionError, EOFError): - # close failed, ignore - pass - return - if await self._handle_ucx_meta_message(message, channel): - continue - asyncio.create_task(self.process_message(message, channel)) - # delete to release the reference of message - del message - await asyncio.sleep(0) + message = await channel.recv() + except (EOFError, ConnectionError, BrokenPipeError) as e: + logger.debug(f"pool: close connection due to {e}") + # no data to read, check channel + try: + await channel.close() + except (ConnectionError, EOFError): + # close failed, ignore + pass + return + if await self._handle_ucx_meta_message(message, channel): + continue + asyncio.create_task(self.process_message(message, channel)) + # delete to release the reference of message + del message + await asyncio.sleep(0) + finally: + try: + await channel.close() + except: # noqa: E722 # nosec # pylint: disable=bare-except + # ignore all error if fail to close at last + pass async def __aenter__(self): await self.start() diff --git a/python/xoscar/serialization/aio.py b/python/xoscar/serialization/aio.py index 68d4a18f..ad7d6926 100644 --- a/python/xoscar/serialization/aio.py +++ b/python/xoscar/serialization/aio.py @@ -77,7 +77,11 @@ async def run(self): def get_header_length(header_bytes: bytes): version = struct.unpack("B", header_bytes[:1])[0] # now we only have default version - assert version == DEFAULT_SERIALIZATION_VERSION, MALFORMED_MSG + if version != DEFAULT_SERIALIZATION_VERSION: + # when version not matched, + # we will immediately abort the connection + # EOFError will be captured by channel + raise EOFError(MALFORMED_MSG) # header length header_length = struct.unpack("