Commit 58002f0

fix: hard-sync kernel_registry to real containers

fregataa committed May 28, 2024
1 parent 11f42db commit 58002f0
Showing 2 changed files with 50 additions and 21 deletions.
1 change: 1 addition & 0 deletions changes/2179.fix.md
@@ -0,0 +1 @@
+Sync the agent's kernel registry with the actual containers through a periodic loop.
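
For orientation, the sketch below is a minimal, self-contained illustration of the pruning idea described by this changelog entry: the registry keeps kernels keyed by ID, a periodic pass enumerates the kernel IDs backed by real containers, and every registry entry without a backing container is dropped. All names here are hypothetical stand-ins; the actual change lives in `sync_container_lifecycles()` and the new `prune_kernel_registry()` in the diff below.

    from uuid import uuid4

    # A stand-in for the agent's kernel_registry (kernel ID -> kernel object).
    kernel_registry: dict[object, str] = {
        (k_alive := uuid4()): "kernel-with-container",
        (k_stale := uuid4()): "kernel-whose-container-vanished",
    }
    # IDs observed by enumerating the actual containers (dead + alive sets).
    detected_kernels = {k_alive}

    for kernel_id in list(kernel_registry.keys()):
        if kernel_id not in detected_kernels:
            # In the agent, this is also where the runner is closed, host ports
            # are returned to the pool, and the kernel object is finalized.
            del kernel_registry[kernel_id]

    assert list(kernel_registry.values()) == ["kernel-with-container"]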
70 changes: 49 additions & 21 deletions src/ai/backend/agent/agent.py
@@ -16,6 +16,7 @@
 import zlib
 from abc import ABCMeta, abstractmethod
 from collections import defaultdict
+from collections.abc import Container as ContainerT
 from decimal import Decimal
 from io import SEEK_END, BytesIO
 from pathlib import Path
@@ -1076,27 +1077,10 @@ async def _handle_clean_event(self, ev: ContainerLifecycleEvent) -> None:
                 ev.done_future.set_exception(e)
             await self.produce_error_event()
         finally:
-            if ev.kernel_id in self.restarting_kernels:
-                # Don't forget as we are restarting it.
-                kernel_obj = self.kernel_registry.get(ev.kernel_id, None)
-            else:
-                # Forget as we are done with this kernel.
-                kernel_obj = self.kernel_registry.pop(ev.kernel_id, None)
+            kernel_obj = self.kernel_registry.get(ev.kernel_id, None)
             try:
                 if kernel_obj is not None:
-                    # Restore used ports to the port pool.
-                    port_range = self.local_config["container"]["port-range"]
-                    # Exclude out-of-range ports, because when the agent restarts
-                    # with a different port range, existing containers' host ports
-                    # may not belong to the new port range.
-                    if host_ports := kernel_obj.get("host_ports"):
-                        restored_ports = [
-                            *filter(
-                                lambda p: port_range[0] <= p <= port_range[1],
-                                host_ports,
-                            )
-                        ]
-                        self.port_pool.update(restored_ports)
+                    await self._restore_port_pool(kernel_obj)
                     await kernel_obj.close()
             finally:
                 self.terminating_kernels.discard(ev.kernel_id)
@@ -1116,6 +1100,20 @@ async def _handle_clean_event(self, ev: ContainerLifecycleEvent) -> None:
                 if ev.done_future is not None and not ev.done_future.done():
                     ev.done_future.set_result(None)
 
+    async def _restore_port_pool(self, kernel_obj: AbstractKernel) -> None:
+        port_range = self.local_config["container"]["port-range"]
+        # Exclude out-of-range ports, because when the agent restarts
+        # with a different port range, existing containers' host ports
+        # may not belong to the new port range.
+        if host_ports := kernel_obj.get("host_ports"):
+            restored_ports = [
+                *filter(
+                    lambda p: port_range[0] <= p <= port_range[1],
+                    host_ports,
+                )
+            ]
+            self.port_pool.update(restored_ports)
+
     async def process_lifecycle_events(self) -> None:
         async def lifecycle_task_exception_handler(
             exc_type: Type[Exception],
@@ -1260,6 +1258,8 @@ async def sync_container_lifecycles(self, interval: float) -> None:
         for cases when we miss the container lifecycle events from the underlying implementation APIs
         due to the agent restarts or crashes.
         """
+        all_detected_kernels: set[KernelId] = set()
+
         known_kernels: Dict[KernelId, ContainerId] = {}
         alive_kernels: Dict[KernelId, ContainerId] = {}
         kernel_session_map: Dict[KernelId, SessionId] = {}
@@ -1270,6 +1270,7 @@ async def sync_container_lifecycles(self, interval: float) -> None:
             try:
                 # Check if: there are dead containers
                 for kernel_id, container in await self.enumerate_containers(DEAD_STATUS_SET):
+                    all_detected_kernels.add(kernel_id)
                     if (
                         kernel_id in self.restarting_kernels
                         or kernel_id in self.terminating_kernels
@@ -1289,6 +1290,7 @@ async def sync_container_lifecycles(self, interval: float) -> None:
                         KernelLifecycleEventReason.SELF_TERMINATED,
                     )
                 for kernel_id, container in await self.enumerate_containers(ACTIVE_STATUS_SET):
+                    all_detected_kernels.add(kernel_id)
                     alive_kernels[kernel_id] = container.id
                     session_id = SessionId(UUID(container.labels["ai.backend.session-id"]))
                     kernel_session_map[kernel_id] = session_id
@@ -1323,13 +1325,41 @@ async def sync_container_lifecycles(self, interval: float) -> None:
                             KernelLifecycleEventReason.TERMINATED_UNKNOWN_CONTAINER,
                         )
             finally:
+                await self.prune_kernel_registry(all_detected_kernels)
                 # Enqueue the events.
                 for kernel_id, ev in terminated_kernels.items():
                     await self.container_lifecycle_queue.put(ev)
 
         # Set container count
         await self.set_container_count(len(own_kernels.keys()))
 
+    async def prune_kernel_registry(
+        self, detected_kernels: ContainerT[KernelId], *, ensure_cleaned: bool = True
+    ) -> None:
+        """
+        Deregister containerless kernels from `kernel_registry`
+        since `_handle_clean_event()` does not deregister them.
+        """
+        any_container_pruned = False
+        for kernel_id in [*self.kernel_registry.keys()]:
+            if kernel_id not in detected_kernels:
+                if ensure_cleaned:
+                    # Don't need to process this through event task
+                    # since there is no communication with any container here.
+                    kernel_obj = self.kernel_registry[kernel_id]
+                    kernel_obj.stats_enabled = False
+                    if kernel_obj.runner is not None:
+                        await kernel_obj.runner.close()
+                    if kernel_obj.clean_event is not None and not kernel_obj.clean_event.done():
+                        kernel_obj.clean_event.set_result(None)
+                    await self._restore_port_pool(kernel_obj)
+                    await kernel_obj.close()
+                del self.kernel_registry[kernel_id]
+                self.terminating_kernels.discard(kernel_id)
+                any_container_pruned = True
+        if any_container_pruned:
+            await self.reconstruct_resource_usage()
+
     async def set_container_count(self, container_count: int) -> None:
         await redis_helper.execute(
             self.redis_stat_pool, lambda r: r.set(f"container_count.{self.id}", container_count)
@@ -2025,8 +2055,6 @@ async def create_kernel(
                 " unregistered.",
                 kernel_id,
             )
-            async with self.registry_lock:
-                del self.kernel_registry[kernel_id]
            raise
        async with self.registry_lock:
            self.kernel_registry[ctx.kernel_id].data.update(container_data)
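
The extracted `_restore_port_pool()` helper only pushes back ports that fall inside the currently configured range, so host ports allocated under an older port range never re-enter the pool. A small, self-contained illustration of that filter, assuming a hypothetical port range of (30000, 31000):

    port_range = (30000, 31000)  # hypothetical ["container"]["port-range"] value
    host_ports = [20050, 30001, 30002, 31500]

    restored_ports = [p for p in host_ports if port_range[0] <= p <= port_range[1]]
    assert restored_ports == [30001, 30002]
    # 20050 and 31500 stay out of the pool: they were allocated under a previous configuration.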