Skip to content

Commit

Permalink
Merge branch 'main' into connect_timeout
Browse files Browse the repository at this point in the history
  • Loading branch information
mergify[bot] authored Nov 20, 2024
2 parents e242f77 + d6ec36b commit 902f2e0
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 57 deletions.
91 changes: 51 additions & 40 deletions python/xoscar/backends/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,50 +70,61 @@ async def get_client(self, router: Router, dest_address: str) -> Client:
return client

async def _listen(self, client: Client):
while not client.closed:
try:
try:
while not client.closed:
try:
message: _MessageBase = await client.recv()
except (EOFError, ConnectionError, BrokenPipeError):
# remote server closed, close client and raise ServerClosed
try:
await client.close()
except (ConnectionError, BrokenPipeError):
# close failed, ignore it
message: _MessageBase = await client.recv()
except (EOFError, ConnectionError, BrokenPipeError) as e:
# AssertionError is from get_header
# remote server closed, close client and raise ServerClosed
logger.debug(f"{client.dest_address} close due to {e}")
try:
await client.close()
except (ConnectionError, BrokenPipeError):
# close failed, ignore it
pass
raise ServerClosed(
f"Remote server {client.dest_address} closed: {e}"
) from None
future = self._client_to_message_futures[client].pop(
message.message_id
)
if not future.done():
future.set_result(message)
except DeserializeMessageFailed as e:
message_id = e.message_id
future = self._client_to_message_futures[client].pop(message_id)
future.set_exception(e.__cause__) # type: ignore
except Exception as e: # noqa: E722 # pylint: disable=bare-except
message_futures = self._client_to_message_futures[client]
self._client_to_message_futures[client] = dict()
for future in message_futures.values():
future.set_exception(copy.copy(e))
finally:
# message may have Ray ObjectRef, delete it early in case next loop doesn't run
# as soon as expected.
try:
del message
except NameError:
pass
raise ServerClosed(
f"Remote server {client.dest_address} closed"
) from None
future = self._client_to_message_futures[client].pop(message.message_id)
if not future.done():
future.set_result(message)
except DeserializeMessageFailed as e:
message_id = e.message_id
future = self._client_to_message_futures[client].pop(message_id)
future.set_exception(e.__cause__) # type: ignore
except Exception as e: # noqa: E722 # pylint: disable=bare-except
message_futures = self._client_to_message_futures[client]
self._client_to_message_futures[client] = dict()
for future in message_futures.values():
future.set_exception(copy.copy(e))
finally:
# message may have Ray ObjectRef, delete it early in case next loop doesn't run
# as soon as expected.
try:
del message
except NameError:
pass
try:
del future
except NameError:
pass
await asyncio.sleep(0)
try:
del future
except NameError:
pass
await asyncio.sleep(0)

message_futures = self._client_to_message_futures[client]
self._client_to_message_futures[client] = dict()
error = ServerClosed(f"Remote server {client.dest_address} closed")
for future in message_futures.values():
future.set_exception(copy.copy(error))
message_futures = self._client_to_message_futures[client]
self._client_to_message_futures[client] = dict()
error = ServerClosed(f"Remote server {client.dest_address} closed")
for future in message_futures.values():
future.set_exception(copy.copy(error))
finally:
try:
await client.close()
except: # noqa: E722 # nosec # pylint: disable=bare-except
# ignore all error if fail to close at last
pass

async def call_with_client(
self, client: Client, message: _MessageBase, wait: bool = True
Expand Down
40 changes: 24 additions & 16 deletions python/xoscar/backends/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -551,23 +551,31 @@ async def _handle_ucx_meta_message(
return False

async def on_new_channel(self, channel: Channel):
while not self._stopped.is_set():
try:
message = await channel.recv()
except EOFError:
# no data to read, check channel
try:
while not self._stopped.is_set():
try:
await channel.close()
except (ConnectionError, EOFError):
# close failed, ignore
pass
return
if await self._handle_ucx_meta_message(message, channel):
continue
asyncio.create_task(self.process_message(message, channel))
# delete to release the reference of message
del message
await asyncio.sleep(0)
message = await channel.recv()
except (EOFError, ConnectionError, BrokenPipeError) as e:
logger.debug(f"pool: close connection due to {e}")
# no data to read, check channel
try:
await channel.close()
except (ConnectionError, EOFError):
# close failed, ignore
pass
return
if await self._handle_ucx_meta_message(message, channel):
continue
asyncio.create_task(self.process_message(message, channel))
# delete to release the reference of message
del message
await asyncio.sleep(0)
finally:
try:
await channel.close()
except: # noqa: E722 # nosec # pylint: disable=bare-except
# ignore all error if fail to close at last
pass

async def __aenter__(self):
await self.start()
Expand Down
6 changes: 5 additions & 1 deletion python/xoscar/serialization/aio.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,11 @@ async def run(self):
def get_header_length(header_bytes: bytes):
version = struct.unpack("B", header_bytes[:1])[0]
# now we only have default version
assert version == DEFAULT_SERIALIZATION_VERSION, MALFORMED_MSG
if version != DEFAULT_SERIALIZATION_VERSION:
# when version not matched,
# we will immediately abort the connection
# EOFError will be captured by channel
raise EOFError(MALFORMED_MSG)
# header length
header_length = struct.unpack("<Q", header_bytes[1:9])[0]
# compress
Expand Down

0 comments on commit 902f2e0

Please sign in to comment.