From 1eb020af163bbae06eed368aab485c0bf91af2f4 Mon Sep 17 00:00:00 2001
From: Sanghun Lee
Date: Fri, 24 May 2024 11:23:23 +0900
Subject: [PATCH] feature: sync mismatch between db and containers

---
Review notes (this section is ignored when the patch is applied):
  * The agent now parses the stringified kernel IDs into KernelId values once,
    so membership tests against the UUID-keyed kernel registry can match.
  * The manager stringifies kernel IDs before the RPC call, like the
    existing sync_kernel_registry() call does.
  * AgentKernelRegistryByStatus.to_json() emits plain strings so that its
    output satisfies as_trafaret().
  * Line breaks and indentation of context lines were reconstructed; verify
    them against the target tree. The post-image hashes on the "index" lines
    are those of the original patch.

 src/ai/backend/agent/server.py     | 93 ++++++++++++++++++++++++++++++
 src/ai/backend/common/types.py     | 35 +++++++++++
 src/ai/backend/manager/registry.py | 20 ++++++
 3 files changed, 148 insertions(+)

diff --git a/src/ai/backend/agent/server.py b/src/ai/backend/agent/server.py
index 74f1a810377..81994c1b0be 100644
--- a/src/ai/backend/agent/server.py
+++ b/src/ai/backend/agent/server.py
@@ -12,6 +12,7 @@
 import signal
 import sys
 from collections import OrderedDict, defaultdict
+from collections.abc import Collection
 from ipaddress import _BaseAddress as BaseIPAddress
 from ipaddress import ip_network
 from pathlib import Path
@@ -58,6 +59,7 @@
 )
 from ai.backend.common.logging import BraceStyleAdapter, Logger
 from ai.backend.common.types import (
+    AgentKernelRegistryByStatus,
     ClusterInfo,
     CommitStatus,
     HardwareMetadata,
@@ -477,6 +479,97 @@ async def sync_kernel_registry(
                     suppress_events=True,
                 )
 
+    @rpc_function
+    @collect_error
+    async def sync_and_get_kernels(
+        self,
+        running_kernels_from_truth: Collection[str],
+        terminating_kernels_from_truth: Collection[str],
+    ) -> dict[str, Any]:
+        """
+        Sync kernel_registry and containers to truth
+        and return kernel infos whose status is irreversible.
+
+        :param running_kernels_from_truth: Stringified IDs of the kernels
+            that are running according to the truth.
+        :param terminating_kernels_from_truth: Stringified IDs of the kernels
+            that are terminating according to the truth.
+        :return: The JSON form of AgentKernelRegistryByStatus.
+        """
+
+        # The registry is keyed by KernelId (UUID) while the arguments are strings,
+        # so parse them once; comparing the raw strings with registry keys never matches.
+        running_ids = [KernelId(UUID(raw)) for raw in running_kernels_from_truth]
+        terminating_ids = [KernelId(UUID(raw)) for raw in terminating_kernels_from_truth]
+        ids_in_truth = frozenset(running_ids) | frozenset(terminating_ids)
+
+        actual_terminating_kernels: list[tuple[KernelId, KernelLifecycleEventReason]] = []
+        actual_terminated_kernels: list[tuple[KernelId, KernelLifecycleEventReason]] = []
+
+        async with self.agent.registry_lock:
+            all_running_kernels = [
+                kid
+                for kid in self.agent.kernel_registry
+                if kid not in self.agent.terminating_kernels
+            ]
+
+            for kernel_id in running_ids:
+                kernel_obj = self.agent.kernel_registry.get(kernel_id)
+                if kernel_obj is None:
+                    # Running in truth but absent from this agent's registry.
+                    actual_terminated_kernels.append((
+                        kernel_id,
+                        KernelLifecycleEventReason.ALREADY_TERMINATED,
+                    ))
+                elif kernel_id in self.agent.terminating_kernels:
+                    actual_terminating_kernels.append((
+                        kernel_id,
+                        kernel_obj.termination_reason
+                        or KernelLifecycleEventReason.ALREADY_TERMINATED,
+                    ))
+
+            for kernel_id in terminating_ids:
+                kernel_obj = self.agent.kernel_registry.get(kernel_id)
+                if kernel_obj is None:
+                    # Terminating in truth but already gone from this agent.
+                    actual_terminated_kernels.append((
+                        kernel_id,
+                        KernelLifecycleEventReason.ALREADY_TERMINATED,
+                    ))
+                elif kernel_id not in self.agent.terminating_kernels:
+                    await self.agent.inject_container_lifecycle_event(
+                        kernel_id,
+                        kernel_obj.session_id,
+                        LifecycleEvent.DESTROY,
+                        kernel_obj.termination_reason
+                        or KernelLifecycleEventReason.ALREADY_TERMINATED,
+                        suppress_events=False,
+                    )
+
+            for kernel_id, kernel_obj in self.agent.kernel_registry.items():
+                if kernel_id in ids_in_truth:
+                    # Running kernels are left alone, and terminating ones
+                    # were already handled by the loop above.
+                    continue
+                # The kernel status is not 'running' or 'terminating' in truth.
+                # It should be terminated.
+                if kernel_id not in self.agent.terminating_kernels:
+                    await self.agent.inject_container_lifecycle_event(
+                        kernel_id,
+                        kernel_obj.session_id,
+                        LifecycleEvent.DESTROY,
+                        kernel_obj.termination_reason
+                        or KernelLifecycleEventReason.ALREADY_TERMINATED,
+                        suppress_events=True,
+                    )
+
+        result = AgentKernelRegistryByStatus(
+            all_running_kernels,
+            actual_terminating_kernels,
+            actual_terminated_kernels,
+        )
+        return result.to_json()
+
     @rpc_function
     @collect_error
     async def create_kernels(
diff --git a/src/ai/backend/common/types.py b/src/ai/backend/common/types.py
index e9e2a37f0e5..871829bdbc0 100644
--- a/src/ai/backend/common/types.py
+++ b/src/ai/backend/common/types.py
@@ -1246,3 +1246,38 @@ def as_trafaret(cls) -> t.Trafaret:
 class ModelServiceStatus(enum.Enum):
     HEALTHY = "healthy"
     UNHEALTHY = "unhealthy"
+
+
+@dataclass
+class AgentKernelRegistryByStatus(JSONSerializableMixin):
+    """Kernels of an agent grouped by status: running, terminating, and terminated."""
+
+    from .events import KernelLifecycleEventReason
+
+    all_running_kernels: list[KernelId]
+    actual_terminating_kernels: list[tuple[KernelId, KernelLifecycleEventReason]]
+    actual_terminated_kernels: list[tuple[KernelId, KernelLifecycleEventReason]]
+
+    def to_json(self) -> dict[str, Any]:
+        # as_trafaret() accepts only strings, so flatten the kernel IDs and
+        # the reasons (via their enum value when they have one) to strings.
+        def _flatten(pairs: list[tuple[Any, Any]]) -> list[tuple[str, str]]:
+            return [(str(kid), str(getattr(why, "value", why))) for kid, why in pairs]
+
+        return {
+            "all_running_kernels": [str(kid) for kid in self.all_running_kernels],
+            "actual_terminating_kernels": _flatten(self.actual_terminating_kernels),
+            "actual_terminated_kernels": _flatten(self.actual_terminated_kernels),
+        }
+
+    @classmethod
+    def from_json(cls, obj: Mapping[str, Any]) -> AgentKernelRegistryByStatus:
+        return cls(**cls.as_trafaret().check(obj))
+
+    @classmethod
+    def as_trafaret(cls) -> t.Trafaret:
+        return t.Dict({
+            t.Key("all_running_kernels"): t.List(t.String),
+            t.Key("actual_terminating_kernels"): t.List(t.Tuple(t.String, t.String)),
+            t.Key("actual_terminated_kernels"): t.List(t.Tuple(t.String, t.String)),
+        })
diff --git a/src/ai/backend/manager/registry.py b/src/ai/backend/manager/registry.py
index c76782e8482..942f6d6d8e6 100644
--- a/src/ai/backend/manager/registry.py
+++ b/src/ai/backend/manager/registry.py
@@ -12,6 +12,7 @@
 import uuid
 import zlib
 from collections import defaultdict
+from collections.abc import Collection
 from datetime import datetime
 from decimal import Decimal
 from io import BytesIO
@@ -87,6 +88,7 @@
     AbuseReport,
     AccessKey,
     AgentId,
+    AgentKernelRegistryByStatus,
     BinarySize,
     ClusterInfo,
     ClusterMode,
@@ -3026,6 +3028,24 @@ async def sync_agent_kernel_registry(self, agent_id: AgentId) -> None:
                 (str(kernel.id), str(kernel.session_id)) for kernel in grouped_kernels
             ])
 
+    async def sync_and_get_kernels(
+        self,
+        agent_id: AgentId,
+        running_kernels: Collection[KernelId],
+        terminating_kernels: Collection[KernelId],
+    ) -> AgentKernelRegistryByStatus:
+        """
+        Ask the agent to sync its kernel registry with the given kernel IDs
+        and return the kernels grouped by their actual status on the agent.
+        """
+        async with self.agent_cache.rpc_context(agent_id) as rpc:
+            # The agent-side RPC function takes stringified kernel IDs.
+            resp: dict[str, Any] = await rpc.call.sync_and_get_kernels(
+                [str(kernel_id) for kernel_id in running_kernels],
+                [str(kernel_id) for kernel_id in terminating_kernels],
+            )
+        return AgentKernelRegistryByStatus.from_json(resp)
+
     async def mark_kernel_terminated(
         self,
         kernel_id: KernelId,