Skip to content

Commit

Permalink
feature: sync mismatch between db and containers
Browse files Browse the repository at this point in the history
  • Loading branch information
fregataa committed May 26, 2024
1 parent 964c9f3 commit 1eb020a
Show file tree
Hide file tree
Showing 3 changed files with 127 additions and 0 deletions.
88 changes: 88 additions & 0 deletions src/ai/backend/agent/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import signal
import sys
from collections import OrderedDict, defaultdict
from collections.abc import Collection
from ipaddress import _BaseAddress as BaseIPAddress
from ipaddress import ip_network
from pathlib import Path
Expand Down Expand Up @@ -58,6 +59,7 @@
)
from ai.backend.common.logging import BraceStyleAdapter, Logger
from ai.backend.common.types import (
AgentKernelRegistryByStatus,
ClusterInfo,
CommitStatus,
HardwareMetadata,
Expand Down Expand Up @@ -477,6 +479,92 @@ async def sync_kernel_registry(
suppress_events=True,
)

@rpc_function
@collect_error
async def sync_and_get_kernels(
    self,
    running_kernels_from_truth: Collection[str],
    terminating_kernels_from_truth: Collection[str],
) -> dict[str, Any]:
    """
    Sync kernel_registry and containers to truth
    and return kernel infos whose status is irreversible.

    :param running_kernels_from_truth: Kernel IDs (UUID strings) that the
        caller considers to be running.
    :param terminating_kernels_from_truth: Kernel IDs (UUID strings) that the
        caller considers to be terminating.
    :return: ``AgentKernelRegistryByStatus.to_json()`` built from three lists:

        * every kernel ID in the registry that was not in
          ``self.agent.terminating_kernels`` when the call started,
        * ``(kernel_id, reason)`` for truth-running kernels that this agent
          is already terminating,
        * ``(kernel_id, reason)`` for truth-running or truth-terminating
          kernels that are absent from the registry.
    :raises ValueError: If any given ID is not a valid UUID string.  IDs are
        parsed before the registry lock is taken, so a malformed ID fails the
        call before any lifecycle event is injected.
    """
    # The registry is keyed by KernelId (a UUID), while the arguments are plain
    # strings.  Parse them once up front: comparing a UUID against a str is
    # always False, so membership tests must be done on parsed IDs.
    running_ids = [KernelId(UUID(raw_id)) for raw_id in running_kernels_from_truth]
    terminating_ids = [KernelId(UUID(raw_id)) for raw_id in terminating_kernels_from_truth]
    known_to_truth = frozenset(running_ids) | frozenset(terminating_ids)

    actual_terminating_kernels: list[tuple[KernelId, KernelLifecycleEventReason]] = []
    actual_terminated_kernels: list[tuple[KernelId, KernelLifecycleEventReason]] = []

    async with self.agent.registry_lock:
        registry = self.agent.kernel_registry
        locally_terminating = self.agent.terminating_kernels

        # Snapshot taken before any DESTROY event below is injected.
        all_running_kernels = [kid for kid in registry if kid not in locally_terminating]

        # Truth says "running": report the ones that are in fact going away or gone.
        for kernel_id in running_ids:
            kernel_obj = registry.get(kernel_id)
            if kernel_obj is None:
                # Not in the registry at all.
                actual_terminated_kernels.append((
                    kernel_id,
                    KernelLifecycleEventReason.ALREADY_TERMINATED,
                ))
            elif kernel_id in locally_terminating:
                actual_terminating_kernels.append((
                    kernel_id,
                    kernel_obj.termination_reason
                    or KernelLifecycleEventReason.ALREADY_TERMINATED,
                ))

        # Truth says "terminating": make sure termination is actually under way.
        for kernel_id in terminating_ids:
            kernel_obj = registry.get(kernel_id)
            if kernel_obj is None:
                actual_terminated_kernels.append((
                    kernel_id,
                    KernelLifecycleEventReason.ALREADY_TERMINATED,
                ))
            elif kernel_id not in locally_terminating:
                await self.agent.inject_container_lifecycle_event(
                    kernel_id,
                    kernel_obj.session_id,
                    LifecycleEvent.DESTROY,
                    kernel_obj.termination_reason
                    or KernelLifecycleEventReason.ALREADY_TERMINATED,
                    suppress_events=False,
                )

        # Kernels the truth does not list as running or terminating should be
        # terminated.  Truth-terminating kernels were handled by the loop
        # above, so they are skipped here to avoid a second DESTROY event.
        for kernel_id, kernel_obj in registry.items():
            if kernel_id in known_to_truth or kernel_id in locally_terminating:
                continue
            await self.agent.inject_container_lifecycle_event(
                kernel_id,
                kernel_obj.session_id,
                LifecycleEvent.DESTROY,
                kernel_obj.termination_reason
                or KernelLifecycleEventReason.ALREADY_TERMINATED,
                suppress_events=True,
            )

    result = AgentKernelRegistryByStatus(
        all_running_kernels,
        actual_terminating_kernels,
        actual_terminated_kernels,
    )
    return result.to_json()

@rpc_function
@collect_error
async def create_kernels(
Expand Down
24 changes: 24 additions & 0 deletions src/ai/backend/common/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1246,3 +1246,27 @@ def as_trafaret(cls) -> t.Trafaret:
# Two-state health flag whose values are the strings "healthy" and "unhealthy".
# NOTE(review): the name suggests it describes a model service; its usage is not
# visible in this view.
class ModelServiceStatus(enum.Enum):
HEALTHY = "healthy"
UNHEALTHY = "unhealthy"


@dataclass
class AgentKernelRegistryByStatus(JSONSerializableMixin):
    """
    Kernel IDs known to an agent, grouped by their actual status.

    ``to_json()`` produces plain strings for every kernel ID and reason so
    that its output satisfies ``as_trafaret()`` and can be fed straight back
    into ``from_json()``.
    """

    # Imported at class scope instead of at module top.
    # NOTE(review): presumably this avoids a circular import with `.events` - confirm.
    from .events import KernelLifecycleEventReason

    all_running_kernels: list[KernelId]
    actual_terminating_kernels: list[tuple[KernelId, KernelLifecycleEventReason]]
    actual_terminated_kernels: list[tuple[KernelId, KernelLifecycleEventReason]]

    @staticmethod
    def _reason_to_str(reason: Any) -> str:
        """Return the wire form of a reason: the enum value if it is an enum, else ``str(reason)``."""
        if isinstance(reason, enum.Enum):
            return str(reason.value)
        return str(reason)

    def to_json(self) -> dict[str, Any]:
        """
        Serialize to a mapping of primitives.

        Kernel IDs are stringified and reasons are reduced to their string
        values; each ``(kernel_id, reason)`` pair stays a 2-tuple.
        """
        return {
            "all_running_kernels": [str(kernel_id) for kernel_id in self.all_running_kernels],
            "actual_terminating_kernels": [
                (str(kernel_id), self._reason_to_str(reason))
                for kernel_id, reason in self.actual_terminating_kernels
            ],
            "actual_terminated_kernels": [
                (str(kernel_id), self._reason_to_str(reason))
                for kernel_id, reason in self.actual_terminated_kernels
            ],
        }

    @classmethod
    def from_json(cls, obj: Mapping[str, Any]) -> AgentKernelRegistryByStatus:
        """
        Validate ``obj`` with ``as_trafaret()`` and build an instance from it.

        NOTE(review): the trafaret checks for strings, so the fields of the
        returned instance hold strings as validated, not parsed ``KernelId``
        or reason-enum objects.
        """
        return cls(**cls.as_trafaret().check(obj))

    @classmethod
    def as_trafaret(cls) -> t.Trafaret:
        """Return the schema of the serialized form: three lists of strings / string pairs."""
        return t.Dict({
            t.Key("all_running_kernels"): t.List(t.String),
            t.Key("actual_terminating_kernels"): t.List(t.Tuple(t.String, t.String)),
            t.Key("actual_terminated_kernels"): t.List(t.Tuple(t.String, t.String)),
        })
15 changes: 15 additions & 0 deletions src/ai/backend/manager/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import uuid
import zlib
from collections import defaultdict
from collections.abc import Collection
from datetime import datetime
from decimal import Decimal
from io import BytesIO
Expand Down Expand Up @@ -87,6 +88,7 @@
AbuseReport,
AccessKey,
AgentId,
AgentKernelRegistryByStatus,
BinarySize,
ClusterInfo,
ClusterMode,
Expand Down Expand Up @@ -3026,6 +3028,19 @@ async def sync_agent_kernel_registry(self, agent_id: AgentId) -> None:
(str(kernel.id), str(kernel.session_id)) for kernel in grouped_kernels
])

async def sync_and_get_kernels(
    self,
    agent_id: AgentId,
    running_kernels: Collection[KernelId],
    terminating_kernels: Collection[KernelId],
) -> AgentKernelRegistryByStatus:
    """
    Ask an agent to sync its kernels with the given state and report back.

    :param agent_id: The agent whose ``sync_and_get_kernels`` RPC is called.
    :param running_kernels: Kernel IDs the caller considers running.
    :param terminating_kernels: Kernel IDs the caller considers terminating.
    :return: The agent's response, validated through
        ``AgentKernelRegistryByStatus.from_json()``.
    """
    # The agent-side RPC declares its arguments as collections of str and
    # parses each item with UUID(...), so send stringified IDs as plain lists.
    running_ids = [str(kernel_id) for kernel_id in running_kernels]
    terminating_ids = [str(kernel_id) for kernel_id in terminating_kernels]
    async with self.agent_cache.rpc_context(agent_id) as rpc:
        resp: dict[str, Any] = await rpc.call.sync_and_get_kernels(
            running_ids,
            terminating_ids,
        )
        return AgentKernelRegistryByStatus.from_json(resp)

async def mark_kernel_terminated(
self,
kernel_id: KernelId,
Expand Down

0 comments on commit 1eb020a

Please sign in to comment.