Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Implement resource sync interface on the agent side #2529

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion src/ai/backend/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -1138,7 +1138,11 @@ async def _handle_clean_event(self, ev: ContainerLifecycleEvent) -> None:
),
)
# Notify cleanup waiters after all state updates.
if kernel_obj is not None and kernel_obj.clean_event is not None:
if (
kernel_obj is not None
and kernel_obj.clean_event is not None
and not kernel_obj.clean_event.done()
):
kernel_obj.clean_event.set_result(None)
if ev.done_future is not None and not ev.done_future.done():
ev.done_future.set_result(None)
Expand Down
112 changes: 111 additions & 1 deletion src/ai/backend/agent/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
ImageRegistry,
KernelCreationConfig,
KernelId,
KernelStatusCollection,
QueueSentinel,
SessionId,
aobject,
Expand All @@ -80,7 +81,7 @@
)
from .exception import ResourceError
from .monitor import AgentErrorPluginContext, AgentStatsPluginContext
from .types import AgentBackend, LifecycleEvent, VolumeInfo
from .types import AgentBackend, KernelLifecycleStatus, LifecycleEvent, VolumeInfo
from .utils import get_arch_name, get_subnet_ip

if TYPE_CHECKING:
Expand Down Expand Up @@ -478,6 +479,115 @@ async def sync_kernel_registry(
suppress_events=True,
)

@rpc_function
@collect_error
async def sync_and_get_kernels(
    self,
    preparing_kernels: Iterable[UUID],
    pulling_kernels: Iterable[UUID],
    running_kernels: Iterable[UUID],
    terminating_kernels: Iterable[UUID],
) -> dict[str, Any]:
    """
    Sync the agent-side kernel registry and containers to the manager's
    truth data and return kernel infos whose status is irreversible.

    :param preparing_kernels: Kernel IDs the manager believes are PREPARING.
    :param pulling_kernels: Kernel IDs the manager believes are PULLING.
    :param running_kernels: Kernel IDs the manager believes are RUNNING.
    :param terminating_kernels: Kernel IDs the manager believes are TERMINATING.
    :return: JSON-serialized ``KernelStatusCollection``.
    """
    # Materialize the inputs once.  RPC payloads may arrive as arbitrary
    # iterables (tuples, or even one-shot iterators), and each collection
    # is membership-tested repeatedly in the registry sweep below.
    preparing_kernel_ids = {KernelId(kid) for kid in preparing_kernels}
    pulling_kernel_ids = {KernelId(kid) for kid in pulling_kernels}
    running_kernel_ids = {KernelId(kid) for kid in running_kernels}
    terminating_kernel_ids = {KernelId(kid) for kid in terminating_kernels}

    actual_terminating_kernels: list[tuple[KernelId, str]] = []
    actual_terminated_kernels: list[tuple[KernelId, str]] = []

    async with self.agent.registry_lock:
        # Kernels this agent still considers alive.
        actual_existing_kernels = [
            kid
            for kid, obj in self.agent.kernel_registry.items()
            if obj.state == KernelLifecycleStatus.RUNNING
        ]

        # Manager thinks RUNNING: report kernels that are actually
        # terminating or already gone on the agent side.
        for kernel_id in running_kernel_ids:
            if (kernel_obj := self.agent.kernel_registry.get(kernel_id)) is not None:
                if kernel_obj.state == KernelLifecycleStatus.TERMINATING:
                    actual_terminating_kernels.append((
                        kernel_id,
                        str(
                            kernel_obj.termination_reason
                            or KernelLifecycleEventReason.ALREADY_TERMINATED
                        ),
                    ))
            else:
                actual_terminated_kernels.append((
                    kernel_id,
                    str(KernelLifecycleEventReason.ALREADY_TERMINATED),
                ))

        # Manager thinks TERMINATING: re-trigger destruction for kernels
        # still terminating here; report the ones that are already gone.
        for kernel_id in terminating_kernel_ids:
            if (kernel_obj := self.agent.kernel_registry.get(kernel_id)) is not None:
                if kernel_obj.state == KernelLifecycleStatus.TERMINATING:
                    await self.agent.inject_container_lifecycle_event(
                        kernel_id,
                        kernel_obj.session_id,
                        LifecycleEvent.DESTROY,
                        kernel_obj.termination_reason
                        or KernelLifecycleEventReason.ALREADY_TERMINATED,
                        suppress_events=False,
                    )
            else:
                actual_terminated_kernels.append((
                    kernel_id,
                    str(KernelLifecycleEventReason.ALREADY_TERMINATED),
                ))

        # Sweep the registry for kernels the manager no longer tracks.
        for kernel_id, kernel_obj in self.agent.kernel_registry.items():
            if kernel_id in terminating_kernel_ids:
                # Already handled by the loop above.  Previously this branch
                # injected a second DESTROY event for the same kernel (with a
                # conflicting NOT_FOUND_IN_MANAGER reason) — do not duplicate.
                pass
            elif kernel_id in running_kernel_ids:
                pass
            elif kernel_id in preparing_kernel_ids:
                pass
            elif kernel_id in pulling_kernel_ids:
                # kernel_registry does not have `pulling` state kernels.
                # Let's skip it.
                pass
            else:
                # This kernel is not alive according to the truth data.
                # The kernel should be destroyed or cleaned.
                if kernel_obj.state == KernelLifecycleStatus.TERMINATING:
                    await self.agent.inject_container_lifecycle_event(
                        kernel_id,
                        kernel_obj.session_id,
                        LifecycleEvent.CLEAN,
                        kernel_obj.termination_reason
                        or KernelLifecycleEventReason.NOT_FOUND_IN_MANAGER,
                        suppress_events=True,
                    )
                elif kernel_id in self.agent.restarting_kernels:
                    # Kernels in the middle of a restart are intentionally
                    # left alone.
                    pass
                else:
                    await self.agent.inject_container_lifecycle_event(
                        kernel_id,
                        kernel_obj.session_id,
                        LifecycleEvent.DESTROY,
                        kernel_obj.termination_reason
                        or KernelLifecycleEventReason.NOT_FOUND_IN_MANAGER,
                        suppress_events=True,
                    )

    result = KernelStatusCollection(
        actual_existing_kernels,
        actual_terminating_kernels,
        actual_terminated_kernels,
    )
    return result.to_json()

@rpc_function
@collect_error
async def create_kernels(
Expand Down
26 changes: 26 additions & 0 deletions src/ai/backend/common/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1263,3 +1263,29 @@ class ModelServiceProfile:
),
RuntimeVariant.CMD: ModelServiceProfile(name="Predefined Image Command"),
}


@dataclass
class KernelStatusCollection(JSONSerializableMixin):
    """
    A snapshot of kernel statuses reported by an agent, used to reconcile
    the manager's view of kernels with the agent's kernel registry.
    """

    # Alias for a (kernel ID, termination reason) pair.
    KernelTerminationInfo = tuple[KernelId, str]

    # Kernels the agent still considers alive (RUNNING).
    actual_existing_kernels: list[KernelId]
    # Kernels still being terminated on the agent, with termination reasons.
    actual_terminating_kernels: list[KernelTerminationInfo]
    # Kernels already gone on the agent, with termination reasons.
    actual_terminated_kernels: list[KernelTerminationInfo]

    def to_json(self) -> dict[str, Any]:
        # The values mix a list of kernel IDs with lists of (id, reason)
        # tuples, so the value type is ``Any``; the previous
        # ``dict[str, list[KernelId]]`` annotation was inaccurate.
        return dataclasses.asdict(self)

    @classmethod
    def from_json(cls, obj: Mapping[str, Any]) -> KernelStatusCollection:
        """Validate *obj* with :meth:`as_trafaret` and build an instance."""
        return cls(**cls.as_trafaret().check(obj))

    @classmethod
    def as_trafaret(cls) -> t.Trafaret:
        from . import validators as tx

        # NOTE(review): the tuple entries are validated as (str, str) here,
        # while the dataclass fields declare (KernelId, str) — confirm whether
        # consumers expect raw strings or UUID-typed kernel IDs.
        return t.Dict({
            t.Key("actual_existing_kernels"): tx.ToList(tx.UUID),
            t.Key("actual_terminating_kernels"): tx.ToList(t.Tuple(t.String, t.String)),
            t.Key("actual_terminated_kernels"): tx.ToList(t.Tuple(t.String, t.String)),
        })
13 changes: 13 additions & 0 deletions src/ai/backend/common/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -711,6 +711,19 @@ def check_and_return(self, value: Any) -> set:
self._failure("value must be Iterable")


class ToList(t.List):
    """
    A ``trafaret.List`` variant that accepts any iterable, not just lists.

    ``trafaret.base.List`` allows list-type data only, but sequential data
    fetched from an RPC response gets deserialized into a tuple, so the
    incoming value is converted to a list before element validation.
    """
    # (Scraped review-thread text that was interleaved here has been removed;
    # the discussion rationale is summarized in the docstring above.)

    def check_common(self, value: Any) -> None:
        # Convert first so the element trafarets run against an actual list.
        return super().check_common(self.check_and_return(value))  # type: ignore[misc]

    def check_and_return(self, value: Any) -> list:
        try:
            return list(value)
        except TypeError:
            self._failure(
                f"Cannot parse {type(value)} to list. value must be Iterable", value=value
            )


class Delay(t.Trafaret):
"""
Convert a float or a tuple of 2 floats into a random generated float value
Expand Down
104 changes: 101 additions & 3 deletions src/ai/backend/manager/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import uuid
import zlib
from collections import defaultdict
from collections.abc import Iterable, Mapping, MutableMapping, Sequence
from datetime import datetime
from decimal import Decimal
from io import BytesIO
Expand All @@ -22,10 +23,7 @@
Dict,
List,
Literal,
Mapping,
MutableMapping,
Optional,
Sequence,
Tuple,
TypeAlias,
Union,
Expand Down Expand Up @@ -101,6 +99,7 @@
ImageRegistry,
KernelEnqueueingConfig,
KernelId,
KernelStatusCollection,
ModelServiceStatus,
RedisConnectionInfo,
ResourceSlot,
Expand Down Expand Up @@ -3277,6 +3276,23 @@ async def _update_session(db_session: AsyncSession) -> None:
self._kernel_actual_allocated_resources[kernel_id] = actual_allocs
await self.set_status_updatable_session(session_id)

async def _sync_agent_resource_and_get_kerenels(
    self,
    agent_id: AgentId,
    preparing_kernels: Iterable[KernelId],
    pulling_kernels: Iterable[KernelId],
    running_kernels: Iterable[KernelId],
    terminating_kernels: Iterable[KernelId],
) -> KernelStatusCollection:
    """
    Invoke the ``sync_and_get_kernels`` RPC on a single agent with the
    manager-side kernel status lists and parse its JSON reply.
    """
    async with self.agent_cache.rpc_context(agent_id) as rpc:
        raw_response: dict[str, Any] = await rpc.call.sync_and_get_kernels(
            preparing_kernels,
            pulling_kernels,
            running_kernels,
            terminating_kernels,
        )
    return KernelStatusCollection.from_json(raw_response)

async def mark_kernel_terminated(
self,
kernel_id: KernelId,
Expand Down Expand Up @@ -3486,6 +3502,88 @@ async def get_status_updatable_sessions(self) -> list[SessionId]:
result.append(SessionId(msgpack.unpackb(raw_session_id)))
return result

async def sync_agent_resource(
self,
db: ExtendedAsyncSAEngine,
agent_ids: Iterable[AgentId],
) -> dict[AgentId, KernelStatusCollection | MultiAgentError]:
result: dict[AgentId, KernelStatusCollection | MultiAgentError] = {}
agent_kernel_by_status: dict[AgentId, dict[str, list[KernelId]]] = {}
stmt = (
sa.select(AgentRow)
.where(AgentRow.id.in_(agent_ids))
.options(
selectinload(
AgentRow.kernels.and_(
KernelRow.status.in_([
KernelStatus.PREPARING,
KernelStatus.PULLING,
KernelStatus.RUNNING,
KernelStatus.TERMINATING,
])
),
).options(load_only(KernelRow.id, KernelRow.status))
)
)
async with db.begin_readonly_session() as db_session:
for _agent_row in await db_session.scalars(stmt):
agent_row = cast(AgentRow, _agent_row)
preparing_kernels: list[KernelId] = []
pulling_kernels: list[KernelId] = []
running_kernels: list[KernelId] = []
terminating_kernels: list[KernelId] = []
for kernel in agent_row.kernels:
kernel_status = cast(KernelStatus, kernel.status)
match kernel_status:
case KernelStatus.PREPARING:
preparing_kernels.append(KernelId(kernel.id))
case KernelStatus.PULLING:
pulling_kernels.append(KernelId(kernel.id))
case KernelStatus.RUNNING:
running_kernels.append(KernelId(kernel.id))
case KernelStatus.TERMINATING:
terminating_kernels.append(KernelId(kernel.id))
case _:
continue
agent_kernel_by_status[AgentId(agent_row.id)] = {
"preparing_kernels": preparing_kernels,
"pulling_kernels": pulling_kernels,
"running_kernels": running_kernels,
"terminating_kernels": terminating_kernels,
}
aid_task_list: list[tuple[AgentId, asyncio.Task]] = []
async with aiotools.PersistentTaskGroup() as tg:
for agent_id in agent_ids:
task = tg.create_task(
self._sync_agent_resource_and_get_kerenels(
agent_id,
agent_kernel_by_status[agent_id]["preparing_kernels"],
agent_kernel_by_status[agent_id]["pulling_kernels"],
agent_kernel_by_status[agent_id]["running_kernels"],
agent_kernel_by_status[agent_id]["terminating_kernels"],
)
)
aid_task_list.append((agent_id, task))
for aid, task in aid_task_list:
agent_errors = []
try:
resp = await task
except aiotools.TaskGroupError as e:
agent_errors.extend(e.__errors__)
except Exception as e:
agent_errors.append(e)
if agent_errors:
result[aid] = MultiAgentError(
"agent(s) raise errors during kernel resource sync",
agent_errors,
)
else:
assert isinstance(
resp, KernelStatusCollection
), f"response should be `KernelStatusCollection`, not {type(resp)}"
result[aid] = resp
return result

async def _get_user_email(
self,
kernel: KernelRow,
Expand Down